- Rename the customWriter struct to httpWriter for clarity.
- Add a closeWriter field so the write side of connections is closed properly.
- Update all cw variable references to hw.
- Merge handlerTLS into the handler function to reduce code duplication.
- Extract the handler into smaller, focused methods.
- Split Read/Write/forwardRequest into composable functions.

Fixes a resource leak where the write side of connections was never closed, matching the forwarder's CloseWrite() pattern.
118 lines
2.8 KiB
Go
118 lines
2.8 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"time"
|
|
"tunnel_pls/internal/config"
|
|
"tunnel_pls/internal/grpc/client"
|
|
"tunnel_pls/internal/port"
|
|
"tunnel_pls/session"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// Server is the lifecycle contract of the SSH front end returned by New.
type Server interface {
	// Start runs the accept loop. The implementation returned by New blocks
	// here until the server is closed.
	Start()

	// Close stops accepting new connections and reports any error raised
	// while shutting the listener down.
	Close() error
}
|
|
type server struct {
|
|
listener net.Listener
|
|
config *ssh.ServerConfig
|
|
grpcClient client.Client
|
|
sessionRegistry session.Registry
|
|
portRegistry port.Registry
|
|
}
|
|
|
|
// New binds the SSH listener on the port named by the PORT environment
// variable (default "2200") and starts the HTTP server. When TLS_ENABLED is
// "true" it also starts the HTTPS server. It returns a Server whose Start
// method runs the SSH accept loop.
//
// Setup failures are returned to the caller rather than terminating the
// process. If a later step fails, the SSH listener is closed again so the
// port is not leaked.
//
// NOTE(review): ListenAndServe/ListenAndServeTLS are assumed to start serving
// in the background and return (otherwise New would never finish); confirm in
// NewHTTPServer. An HTTP server already started is not stopped when the HTTPS
// step fails, because no shutdown method is visible here.
func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client, portRegistry port.Registry) (Server, error) {
	sshPort := config.Getenv("PORT", "2200")
	listener, err := net.Listen("tcp", net.JoinHostPort("", sshPort))
	if err != nil {
		return nil, fmt.Errorf("listening on port %s: %w", sshPort, err)
	}

	tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true"
	// Redirecting plain HTTP to HTTPS only makes sense when TLS is enabled.
	redirectTLS := tlsEnabled && config.Getenv("TLS_REDIRECT", "false") == "true"

	httpServer := NewHTTPServer(sessionRegistry, redirectTLS)
	if err = httpServer.ListenAndServe(); err != nil {
		// Best effort: the startup error is the one worth reporting.
		_ = listener.Close()
		return nil, fmt.Errorf("starting http server: %w", err)
	}

	if tlsEnabled {
		if err = httpServer.ListenAndServeTLS(); err != nil {
			// Best effort: the startup error is the one worth reporting.
			_ = listener.Close()
			return nil, fmt.Errorf("starting https server: %w", err)
		}
	}

	return &server{
		listener:        listener,
		config:          sshConfig,
		grpcClient:      grpcClient,
		sessionRegistry: sessionRegistry,
		portRegistry:    portRegistry,
	}, nil
}
|
|
|
|
func (s *server) Start() {
|
|
log.Println("SSH server is starting on port 2200...")
|
|
for {
|
|
conn, err := s.listener.Accept()
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
log.Println("listener closed, stopping server")
|
|
return
|
|
}
|
|
log.Printf("failed to accept connection: %v", err)
|
|
continue
|
|
}
|
|
|
|
go s.handleConnection(conn)
|
|
}
|
|
}
|
|
|
|
func (s *server) Close() error {
|
|
return s.listener.Close()
|
|
}
|
|
|
|
func (s *server) handleConnection(conn net.Conn) {
|
|
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
|
|
if err != nil {
|
|
log.Printf("failed to establish SSH connection: %v", err)
|
|
err = conn.Close()
|
|
if err != nil {
|
|
log.Printf("failed to close SSH connection: %v", err)
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
defer func(sshConn *ssh.ServerConn) {
|
|
err = sshConn.Close()
|
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
|
log.Printf("failed to close SSH server: %v", err)
|
|
}
|
|
}(sshConn)
|
|
|
|
user := "UNAUTHORIZED"
|
|
if s.grpcClient != nil {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
_, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User())
|
|
user = u
|
|
cancel()
|
|
}
|
|
log.Println("SSH connection established:", sshConn.User())
|
|
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
|
|
err = sshSession.Start()
|
|
if err != nil {
|
|
log.Printf("SSH session ended with error: %v", err)
|
|
return
|
|
}
|
|
return
|
|
}
|