From 19135ceb428389c33aede22eff43691935fb0cf9 Mon Sep 17 00:00:00 2001 From: bagas Date: Fri, 16 Jan 2026 15:17:33 +0700 Subject: [PATCH] refactor: convert structs to interfaces and rename accessors - Convert struct types to interfaces - Rename getter and setter methods - Add Close method to server interface - Merge handler functionality into session file - Handle lifecycle.Connection().Wait() - fix panic on nil connection in SSH server --- internal/grpc/client/client.go | 135 ++++++------------ main.go | 31 ++-- server/server.go | 30 ++-- session/handler.go | 252 --------------------------------- session/session.go | 250 +++++++++++++++++++++++++++++++- 5 files changed, 325 insertions(+), 373 deletions(-) delete mode 100644 session/handler.go diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index 8c701e1..ffcc7d9 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -2,7 +2,6 @@ package client import ( "context" - "crypto/tls" "errors" "fmt" "io" @@ -16,7 +15,6 @@ import ( proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/keepalive" @@ -24,83 +22,34 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -type GrpcConfig struct { - Address string - UseTLS bool - InsecureSkipVerify bool - Timeout time.Duration - KeepAlive bool - MaxRetries int - KeepAliveTime time.Duration - KeepAliveTimeout time.Duration - PermitWithoutStream bool +type Client interface { + SubscribeEvents(ctx context.Context, identity, authToken string) error + ClientConn() *grpc.ClientConn + AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) + Close() error + CheckServerHealth(ctx context.Context) error } - -type Client struct { +type client struct { conn *grpc.ClientConn - config *GrpcConfig + address string sessionRegistry session.Registry eventService proto.EventServiceClient authorizeConnectionService proto.UserServiceClient closing bool } -func DefaultConfig() *GrpcConfig { - return &GrpcConfig{ - Address: "localhost:50051", - UseTLS: false, - InsecureSkipVerify: false, - Timeout: 10 * time.Second, - KeepAlive: true, - MaxRetries: 3, - KeepAliveTime: 2 * time.Minute, - KeepAliveTimeout: 10 * time.Second, - PermitWithoutStream: false, - } -} - -func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error) { - if config == nil { - config = DefaultConfig() - } else { - defaults := DefaultConfig() - if config.Address == "" { - config.Address = defaults.Address - } - if config.Timeout == 0 { - config.Timeout = defaults.Timeout - } - if config.MaxRetries == 0 { - config.MaxRetries = defaults.MaxRetries - } - if config.KeepAliveTime == 0 { - config.KeepAliveTime = defaults.KeepAliveTime - } - if config.KeepAliveTimeout == 0 { - config.KeepAliveTimeout = defaults.KeepAliveTimeout - } - } - +func New(address string, sessionRegistry session.Registry) (Client, error) { var opts []grpc.DialOption - if config.UseTLS { - tlsConfig := &tls.Config{ - InsecureSkipVerify: config.InsecureSkipVerify, - } - creds := credentials.NewTLS(tlsConfig) - opts = append(opts, grpc.WithTransportCredentials(creds)) - } else { - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + + kaParams := keepalive.ClientParameters{ + Time: 2 * time.Minute, + Timeout: 10 * time.Second, + PermitWithoutStream: false, } - if config.KeepAlive { - kaParams := keepalive.ClientParameters{ - Time: config.KeepAliveTime, - Timeout: config.KeepAliveTimeout, - PermitWithoutStream: config.PermitWithoutStream, - } - opts = append(opts, grpc.WithKeepaliveParams(kaParams)) - } + opts = append(opts, grpc.WithKeepaliveParams(kaParams)) opts = append(opts, grpc.WithDefaultCallOptions( @@ -109,24 +58,24 @@ func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error) ), ) - conn, err := grpc.NewClient(config.Address, opts...) + conn, err := grpc.NewClient(address, opts...) if err != nil { - return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", config.Address, err) + return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err) } eventService := proto.NewEventServiceClient(conn) authorizeConnectionService := proto.NewUserServiceClient(conn) - return &Client{ + return &client{ conn: conn, - config: config, + address: address, sessionRegistry: sessionRegistry, eventService: eventService, authorizeConnectionService: authorizeConnectionService, }, nil } -func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string) error { +func (c *client) SubscribeEvents(ctx context.Context, identity, authToken string) error { const ( baseBackoff = time.Second maxBackoff = 30 * time.Second @@ -209,7 +158,7 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string } } -func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error { +func (c *client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error { handlers := c.eventHandlers(subscribe) for { @@ -230,7 +179,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod } } -func (c *Client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) map[proto.EventType]func(*proto.Events) error { +func (c *client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) map[proto.EventType]func(*proto.Events) error { return map[proto.EventType]func(*proto.Events) error{ proto.EventType_SLUG_CHANGE: func(evt *proto.Events) error { return c.handleSlugChange(subscribe, evt) }, proto.EventType_GET_SESSIONS: func(evt *proto.Events) error { return c.handleGetSessions(subscribe, evt) }, @@ -238,7 +187,7 @@ func (c *Client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, pr } } -func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { +func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { slugEvent := evt.GetSlugEvent() user := slugEvent.GetUser() oldSlug := slugEvent.GetOld() @@ -272,7 +221,7 @@ func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, }, "slug change success response") } -func (c *Client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { +func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { sessions := c.sessionRegistry.GetAllSessionFromUser(evt.GetGetSessionsEvent().GetIdentity()) var details []*proto.Detail @@ -296,7 +245,7 @@ func (c *Client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node }, "send get sessions response") } -func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { +func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { terminate := evt.GetTerminateSessionEvent() user := terminate.GetUser() slug := terminate.GetSlug() @@ -338,7 +287,7 @@ func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto }, "terminate session success response") } -func (c *Client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error { +func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error { if err := subscribe.Send(node); err != nil { if c.isConnectionError(err) { return err @@ -348,7 +297,7 @@ func (c *Client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E return nil } -func (c *Client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) { +func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) { switch t { case proto.TunnelType_HTTP: return types.HTTP, nil @@ -359,11 +308,11 @@ func (c *Client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) } } -func (c *Client) GetConnection() *grpc.ClientConn { +func (c *client) ClientConn() *grpc.ClientConn { return c.conn } -func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) { +func (c *client) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) { check, err := c.authorizeConnectionService.Check(ctx, &proto.CheckRequest{AuthToken: token}) if err != nil { return false, "UNAUTHORIZED", err @@ -375,17 +324,8 @@ func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bo return true, check.GetUser(), nil } -func (c *Client) Close() error { - if c.conn != nil { - log.Printf("Closing gRPC connection to %s", c.config.Address) - c.closing = true - return c.conn.Close() - } - return nil -} - -func (c *Client) CheckServerHealth(ctx context.Context) error { - healthClient := grpc_health_v1.NewHealthClient(c.GetConnection()) +func (c *client) CheckServerHealth(ctx context.Context) error { + healthClient := grpc_health_v1.NewHealthClient(c.ClientConn()) resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{ Service: "", }) @@ -398,11 +338,16 @@ func (c *Client) CheckServerHealth(ctx context.Context) error { return nil } -func (c *Client) GetConfig() *GrpcConfig { - return c.config +func (c *client) Close() error { + if c.conn != nil { + log.Printf("Closing gRPC connection to %s", c.address) + c.closing = true + return c.conn.Close() + } + return nil } -func (c *Client) isConnectionError(err error) bool { +func (c *client) isConnectionError(err error) bool { if c.closing { return false } diff --git a/main.go b/main.go index e069f7c..e8d3884 100644 --- a/main.go +++ b/main.go @@ -49,7 +49,7 @@ func main() { sshConfig := &ssh.ServerConfig{ NoClientAuth: true, - ServerVersion: fmt.Sprintf("SSH-2.0-TunnlPls-%s", version.GetShortVersion()), + ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()), } sshKeyPath := "certs/ssh/id_rsa" @@ -77,7 +77,7 @@ func main() { shutdownChan := make(chan os.Signal, 1) signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) - var grpcClient *client.Client + var grpcClient client.Client if isNodeMode { grpcHost := config.Getenv("GRPC_ADDRESS", "localhost") grpcPort := config.Getenv("GRPC_PORT", "8080") @@ -87,21 +87,13 @@ func main() { log.Fatalf("NODE_TOKEN is required in node mode") } - c, err := client.New(&client.GrpcConfig{ - Address: grpcAddr, - UseTLS: false, - InsecureSkipVerify: false, - Timeout: 10 * time.Second, - KeepAlive: true, - MaxRetries: 3, - }, sessionRegistry) + grpcClient, err = client.New(grpcAddr, sessionRegistry) if err != nil { log.Fatalf("failed to create grpc client: %v", err) } - grpcClient = c healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second) - if err := grpcClient.CheckServerHealth(healthCtx); err != nil { + if err = grpcClient.CheckServerHealth(healthCtx); err != nil { healthCancel() log.Fatalf("gRPC health check failed: %v", err) } @@ -109,14 +101,15 @@ func main() { go func() { identity := config.Getenv("DOMAIN", "localhost") - if err := grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil { + if err = grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil { errChan <- fmt.Errorf("failed to subscribe to events: %w", err) } }() } + var app server.Server go func() { - app, err := server.NewServer(sshConfig, sessionRegistry, grpcClient) + app, err = server.New(sshConfig, sessionRegistry, grpcClient) if err != nil { errChan <- fmt.Errorf("failed to start server: %s", err) return @@ -125,7 +118,7 @@ func main() { }() select { - case err := <-errChan: + case err = <-errChan: log.Printf("error happen : %s", err) case sig := <-shutdownChan: log.Printf("received signal %s, shutting down", sig) @@ -133,8 +126,14 @@ func main() { cancel() + if app != nil { + if err = app.Close(); err != nil { + log.Printf("failed to close server : %s", err) + } + } + if grpcClient != nil { - if err := grpcClient.Close(); err != nil { + if err = grpcClient.Close(); err != nil { log.Printf("failed to close grpc conn : %s", err) } } diff --git a/server/server.go b/server/server.go index 4bd4804..3e42c9a 100644 --- a/server/server.go +++ b/server/server.go @@ -14,14 +14,18 @@ import ( "golang.org/x/crypto/ssh" ) -type Server struct { - conn *net.Listener +type Server interface { + Start() + Close() error +} +type server struct { + listener net.Listener config *ssh.ServerConfig sessionRegistry session.Registry - grpcClient *client.Client + grpcClient client.Client } -func NewServer(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient *client.Client) (*Server, error) { +func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client) (Server, error) { listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) if err != nil { log.Fatalf("failed to listen on port 2200: %v", err) @@ -43,19 +47,23 @@ func NewServer(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, gr } } - return &Server{ - conn: &listener, + return &server{ + listener: listener, config: sshConfig, sessionRegistry: sessionRegistry, grpcClient: grpcClient, }, nil } -func (s *Server) Start() { +func (s *server) Start() { log.Println("SSH server is starting on port 2200...") for { - conn, err := (*s.conn).Accept() + conn, err := s.listener.Accept() if err != nil { + if errors.Is(err, net.ErrClosed) { + log.Println("listener closed, stopping server") + return + } log.Printf("failed to accept connection: %v", err) continue } @@ -64,7 +72,11 @@ func (s *Server) Start() { } } -func (s *Server) handleConnection(conn net.Conn) { +func (s *server) Close() error { + return s.listener.Close() +} + +func (s *server) handleConnection(conn net.Conn) { sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) if err != nil { log.Printf("failed to establish SSH connection: %v", err) diff --git a/session/handler.go b/session/handler.go deleted file mode 100644 index f80f222..0000000 --- a/session/handler.go +++ /dev/null @@ -1,252 +0,0 @@ -package session - -import ( - "bytes" - "encoding/binary" - "fmt" - "log" - "net" - portUtil "tunnel_pls/internal/port" - "tunnel_pls/internal/random" - "tunnel_pls/types" - - "golang.org/x/crypto/ssh" -) - -var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} - -func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { - for req := range GlobalRequest { - switch req.Type { - case "shell", "pty-req": - err := req.Reply(true, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - case "window-change": - p := req.Payload - if len(p) < 16 { - log.Println("invalid window-change payload") - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - return - } - cols := binary.BigEndian.Uint32(p[0:4]) - rows := binary.BigEndian.Uint32(p[4:8]) - - s.interaction.SetWH(int(cols), int(rows)) - - err := req.Reply(true, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - default: - log.Println("Unknown request type:", req.Type) - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - } - } -} - -func (s *session) HandleTCPIPForward(req *ssh.Request) { - log.Println("Port forwarding request detected") - - fail := func(msg string) { - log.Println(msg) - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - return - } - if err := s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } - } - - reader := bytes.NewReader(req.Payload) - - addr, err := readSSHString(reader) - if err != nil { - fail(fmt.Sprintf("Failed to read address from payload: %v", err)) - return - } - - var rawPortToBind uint32 - if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { - fail(fmt.Sprintf("Failed to read port from payload: %v", err)) - return - } - - if rawPortToBind > 65535 { - fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind)) - return - } - - portToBind := uint16(rawPortToBind) - if isBlockedPort(portToBind) { - fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind)) - return - } - - switch portToBind { - case 80, 443: - s.HandleHTTPForward(req, portToBind) - default: - s.HandleTCPForward(req, addr, portToBind) - } -} - -func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) { - fail := func(msg string, key *types.SessionKey) { - log.Println(msg) - if key != nil { - s.registry.Remove(*key) - } - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - } - } - - slug := random.GenerateRandomString(20) - key := types.SessionKey{Id: slug, Type: types.HTTP} - if !s.registry.Register(key, s) { - fail(fmt.Sprintf("Failed to register client with slug: %s", slug), nil) - return - } - - buf := new(bytes.Buffer) - err := binary.Write(buf, binary.BigEndian, uint32(portToBind)) - if err != nil { - fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key) - return - } - log.Printf("HTTP forwarding approved on port: %d", portToBind) - - err = req.Reply(true, buf.Bytes()) - if err != nil { - fail(fmt.Sprintf("Failed to reply to request: %v", err), &key) - return - } - - s.forwarder.SetType(types.HTTP) - s.forwarder.SetForwardedPort(portToBind) - s.slug.Set(slug) - s.lifecycle.SetStatus(types.RUNNING) -} - -func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - fail := func(msg string) { - log.Println(msg) - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - return - } - if err := s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } - } - - cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) { - log.Println(msg) - if key != nil { - s.registry.Remove(*key) - } - if port != 0 { - if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil { - log.Printf("Failed to reset port status: %v", setErr) - } - } - if listener != nil { - if closeErr := listener.Close(); closeErr != nil { - log.Printf("Failed to close listener: %v", closeErr) - } - } - if err := req.Reply(false, nil); err != nil { - log.Println("Failed to reply to request:", err) - } - _ = s.lifecycle.Close() - } - - if portToBind == 0 { - unassigned, ok := portUtil.Default.GetUnassignedPort() - if !ok { - fail("No available port") - return - } - portToBind = unassigned - } - - if claimed := portUtil.Default.ClaimPort(portToBind); !claimed { - fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind)) - return - } - - log.Printf("Requested forwarding on %s:%d", addr, portToBind) - listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) - if err != nil { - cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil) - return - } - - key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} - - if !s.registry.Register(key, s) { - cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil) - return - } - - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) - if err != nil { - cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key) - return - } - - log.Printf("TCP forwarding approved on port: %d", portToBind) - err = req.Reply(true, buf.Bytes()) - if err != nil { - cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key) - return - } - - s.forwarder.SetType(types.TCP) - s.forwarder.SetListener(listener) - s.forwarder.SetForwardedPort(portToBind) - s.slug.Set(key.Id) - s.lifecycle.SetStatus(types.RUNNING) - go s.forwarder.AcceptTCPConnections() -} - -func readSSHString(reader *bytes.Reader) (string, error) { - var length uint32 - if err := binary.Read(reader, binary.BigEndian, &length); err != nil { - return "", err - } - strBytes := make([]byte, length) - if _, err := reader.Read(strBytes); err != nil { - return "", err - } - return string(strBytes), nil -} - -func isBlockedPort(port uint16) bool { - if port == 80 || port == 443 { - return false - } - if port < 1024 && port != 0 { - return true - } - for _, p := range blockedReservedPorts { - if p == port { - return true - } - } - return false -} diff --git a/session/session.go b/session/session.go index e01355c..82e6916 100644 --- a/session/session.go +++ b/session/session.go @@ -1,10 +1,17 @@ package session import ( + "bytes" + "encoding/binary" + "errors" "fmt" + "io" "log" + "net" "time" "tunnel_pls/internal/config" + portUtil "tunnel_pls/internal/port" + "tunnel_pls/internal/random" "tunnel_pls/session/forwarder" "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" @@ -45,6 +52,8 @@ type session struct { registry Registry } +var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} + func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session { slugManager := slug.New() forwarderManager := forwarder.New(slugManager) @@ -151,7 +160,10 @@ func (s *session) Start() error { s.HandleTCPIPForward(tcpipReq) s.interaction.Start() - s.lifecycle.Connection().Wait() + if err := s.lifecycle.Connection().Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + log.Printf("ssh connection closed with error: %v", err) + } + if err := s.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) return err @@ -179,3 +191,239 @@ func (s *session) waitForTCPIPForward() *ssh.Request { return nil } } + +func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { + for req := range GlobalRequest { + switch req.Type { + case "shell", "pty-req": + err := req.Reply(true, nil) + if err != nil { + log.Println("Failed to reply to request:", err) + return + } + case "window-change": + p := req.Payload + if len(p) < 16 { + log.Println("invalid window-change payload") + err := req.Reply(false, nil) + if err != nil { + log.Println("Failed to reply to request:", err) + return + } + return + } + cols := binary.BigEndian.Uint32(p[0:4]) + rows := binary.BigEndian.Uint32(p[4:8]) + + s.interaction.SetWH(int(cols), int(rows)) + + err := req.Reply(true, nil) + if err != nil { + log.Println("Failed to reply to request:", err) + return + } + default: + log.Println("Unknown request type:", req.Type) + err := req.Reply(false, nil) + if err != nil { + log.Println("Failed to reply to request:", err) + return + } + } + } +} + +func (s *session) HandleTCPIPForward(req *ssh.Request) { + log.Println("Port forwarding request detected") + + fail := func(msg string) { + log.Println(msg) + if err := req.Reply(false, nil); err != nil { + log.Println("Failed to reply to request:", err) + return + } + if err := s.lifecycle.Close(); err != nil { + log.Printf("failed to close session: %v", err) + } + } + + reader := bytes.NewReader(req.Payload) + + addr, err := readSSHString(reader) + if err != nil { + fail(fmt.Sprintf("Failed to read address from payload: %v", err)) + return + } + + var rawPortToBind uint32 + if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { + fail(fmt.Sprintf("Failed to read port from payload: %v", err)) + return + } + + if rawPortToBind > 65535 { + fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind)) + return + } + + portToBind := uint16(rawPortToBind) + if isBlockedPort(portToBind) { + fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind)) + return + } + + switch portToBind { + case 80, 443: + s.HandleHTTPForward(req, portToBind) + default: + s.HandleTCPForward(req, addr, portToBind) + } +} + +func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) { + fail := func(msg string, key *types.SessionKey) { + log.Println(msg) + if key != nil { + s.registry.Remove(*key) + } + if err := req.Reply(false, nil); err != nil { + log.Println("Failed to reply to request:", err) + } + } + + randomString := random.GenerateRandomString(20) + key := types.SessionKey{Id: randomString, Type: types.HTTP} + if !s.registry.Register(key, s) { + fail(fmt.Sprintf("Failed to register client with slug: %s", randomString), nil) + return + } + + buf := new(bytes.Buffer) + err := binary.Write(buf, binary.BigEndian, uint32(portToBind)) + if err != nil { + fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key) + return + } + log.Printf("HTTP forwarding approved on port: %d", portToBind) + + err = req.Reply(true, buf.Bytes()) + if err != nil { + fail(fmt.Sprintf("Failed to reply to request: %v", err), &key) + return + } + + s.forwarder.SetType(types.HTTP) + s.forwarder.SetForwardedPort(portToBind) + s.slug.Set(randomString) + s.lifecycle.SetStatus(types.RUNNING) +} + +func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { + fail := func(msg string) { + log.Println(msg) + if err := req.Reply(false, nil); err != nil { + log.Println("Failed to reply to request:", err) + return + } + if err := s.lifecycle.Close(); err != nil { + log.Printf("failed to close session: %v", err) + } + } + + cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) { + log.Println(msg) + if key != nil { + s.registry.Remove(*key) + } + if port != 0 { + if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil { + log.Printf("Failed to reset port status: %v", setErr) + } + } + if listener != nil { + if closeErr := listener.Close(); closeErr != nil { + log.Printf("Failed to close listener: %v", closeErr) + } + } + if err := req.Reply(false, nil); err != nil { + log.Println("Failed to reply to request:", err) + } + _ = s.lifecycle.Close() + } + + if portToBind == 0 { + unassigned, ok := portUtil.Default.GetUnassignedPort() + if !ok { + fail("No available port") + return + } + portToBind = unassigned + } + + if claimed := portUtil.Default.ClaimPort(portToBind); !claimed { + fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind)) + return + } + + log.Printf("Requested forwarding on %s:%d", addr, portToBind) + listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) + if err != nil { + cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil) + return + } + + key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} + + if !s.registry.Register(key, s) { + cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil) + return + } + + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) + if err != nil { + cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key) + return + } + + log.Printf("TCP forwarding approved on port: %d", portToBind) + err = req.Reply(true, buf.Bytes()) + if err != nil { + cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key) + return + } + + s.forwarder.SetType(types.TCP) + s.forwarder.SetListener(listener) + s.forwarder.SetForwardedPort(portToBind) + s.slug.Set(key.Id) + s.lifecycle.SetStatus(types.RUNNING) + go s.forwarder.AcceptTCPConnections() +} + +func readSSHString(reader *bytes.Reader) (string, error) { + var length uint32 + if err := binary.Read(reader, binary.BigEndian, &length); err != nil { + return "", err + } + strBytes := make([]byte, length) + if _, err := reader.Read(strBytes); err != nil { + return "", err + } + return string(strBytes), nil +} + +func isBlockedPort(port uint16) bool { + if port == 80 || port == 443 { + return false + } + if port < 1024 && port != 0 { + return true + } + for _, p := range blockedReservedPorts { + if p == port { + return true + } + } + return false +}