From f6ad5c81e38eafd31743f6f9f7772265d9677ff0 Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 22 Jul 2025 13:19:28 +0700 Subject: [PATCH] fix: resolve resource exhaustion with high connection counts --- session/handler.go | 57 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/session/handler.go b/session/handler.go index 425a5c8..8317c54 100644 --- a/session/handler.go +++ b/session/handler.go @@ -3,6 +3,7 @@ package session import ( "bufio" "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -541,6 +542,7 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se } defer channel.Close() + // Goroutine 1: Handle SSH channel requests (same as old code) go func() { defer func() { if r := recover(); r != nil { @@ -556,21 +558,60 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se conn.Reader = bufio.NewReader(conn.Writer) } - go io.Copy(channel, conn.Reader) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in reader copy: %v", r) + } + cancel() + }() + + _, err := io.Copy(channel, conn.Reader) + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error copying from conn.Reader to channel: %v", err) + } + cancel() + }() reader := bufio.NewReader(channel) - _, err = reader.Peek(1) - if err == io.EOF { - s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, "Could not forward request to the tunnel addr")) + + peekChan := make(chan error, 1) + go func() { + _, err := reader.Peek(1) + peekChan <- err + }() + + select { + case err := <-peekChan: + if err == io.EOF { + s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) + sendBadGatewayResponse(conn.Writer) + return + } + if err != nil { + log.Printf("Error peeking channel data: %v", err) + s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) + sendBadGatewayResponse(conn.Writer) + return + } + case <-time.After(5 * time.Second): + log.Printf("Timeout waiting for channel data from %s", conn.Writer.RemoteAddr()) + s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) sendBadGatewayResponse(conn.Writer) - conn.Writer.Close() - channel.Close() + return + case <-ctx.Done(): return } - s.sendMessage(fmt.Sprintf("\033[32m %s -> [%s] TUNNEL ADDRESS -- \"%s\" \r\n \033[0m", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp)) + s.sendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp)) - io.Copy(conn.Writer, reader) + _, err = io.Copy(conn.Writer, reader) + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error copying from channel to conn.Writer: %v", err) + } } func sendBadGatewayResponse(writer io.Writer) {