From 4b21541668356c40da9457c46a7f56688cfb5fda Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 22 Jul 2025 12:38:52 +0700 Subject: [PATCH] fix: resolve resource exhaustion with high connection counts --- session/handler.go | 45 +++++++-------------------------------------- 1 file changed, 7 insertions(+), 38 deletions(-) diff --git a/session/handler.go b/session/handler.go index 0290190..19f42ad 100644 --- a/session/handler.go +++ b/session/handler.go @@ -560,37 +560,31 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se ctx, cancel := context.WithCancel(context.Background()) defer cancel() - var wg sync.WaitGroup - errChan := make(chan error, 2) - - wg.Add(1) go func() { - defer wg.Done() defer func() { if r := recover(); r != nil { log.Printf("Panic in reader copy: %v", r) - errChan <- fmt.Errorf("panic: %v", r) } + cancel() }() _, err := io.Copy(channel, conn.Reader) if err != nil && !errors.Is(err, io.EOF) { log.Printf("Error copying from conn.Reader to channel: %v", err) - errChan <- err } cancel() }() reader := bufio.NewReader(channel) - peekDone := make(chan error, 1) + peekChan := make(chan error, 1) go func() { _, err := reader.Peek(1) - peekDone <- err + peekChan <- err }() select { - case err := <-peekDone: + case err := <-peekChan: if err == io.EOF { s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) sendBadGatewayResponse(conn.Writer) @@ -613,34 +607,9 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se s.sendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp)) - wg.Add(1) - go func() { - defer wg.Done() - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in writer copy: %v", r) - errChan <- fmt.Errorf("panic: %v", r) - } - }() - - _, err := io.Copy(conn.Writer, reader) - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error copying from channel to conn.Writer: %v", err) - errChan <- err - } - cancel() - }() - - go func() { - wg.Wait() - close(errChan) - }() - - for err := range errChan { - if err != nil { - log.Printf("Connection error: %v", err) - break - } + _, err = io.Copy(conn.Writer, reader) + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error copying from channel to conn.Writer: %v", err) } }