From 76d1202b8ec4bdf144dc05f67d17ec5b687a6f48 Mon Sep 17 00:00:00 2001 From: bagas Date: Fri, 26 Dec 2025 23:17:13 +0700 Subject: [PATCH] fix: correct logic when checking tcpip-forward request --- README.md | 15 +++++++++ server/http.go | 56 ++++++++++++++++++++-------------- server/server.go | 19 +++++------- session/forwarder/forwarder.go | 51 +++++++++++++++++++++++-------- session/handler.go | 3 -- session/lifecycle/lifecycle.go | 31 ------------------- session/session.go | 44 +++++++++++++++++++++----- 7 files changed, 130 insertions(+), 89 deletions(-) diff --git a/README.md b/README.md index d66785f..28f8f4a 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,21 @@ The following environment variables can be configured in the `.env` file: If the SSH private key specified in `SSH_PRIVATE_KEY` doesn't exist, the application will automatically generate a new 4096-bit RSA key pair at the specified location. This makes it easier to get started without manually creating SSH keys. +### Memory Optimization + +The application uses a buffer pool with controlled buffer sizes to prevent excessive memory usage under high concurrent loads. The `BUFFER_SIZE` environment variable controls the size of buffers used for io.Copy operations: + +- **Default:** 32768 bytes (32 KB) - Good balance for most scenarios +- **Minimum:** 4096 bytes (4 KB) - Lower memory usage, more CPU overhead +- **Maximum:** 1048576 bytes (1 MB) - Higher throughput, more memory usage + +**Recommended settings based on load:** +- **Low traffic (<100 concurrent):** `BUFFER_SIZE=32768` (default) +- **High traffic (>100 concurrent):** `BUFFER_SIZE=16384` or `BUFFER_SIZE=8192` +- **Very high traffic (>1000 concurrent):** `BUFFER_SIZE=8192` or `BUFFER_SIZE=4096` + +The buffer pool reuses buffers across connections, preventing memory fragmentation and reducing garbage collection pressure. + ### Profiling with pprof To enable profiling for performance analysis: diff --git a/server/http.go b/server/http.go index 34fad0e..6c716d1 100644 --- a/server/http.go +++ b/server/http.go @@ -10,8 +10,11 @@ import ( "net" "regexp" "strings" + "time" "tunnel_pls/session" "tunnel_pls/utils" + + "golang.org/x/crypto/ssh" ) type Interaction interface { @@ -295,30 +298,38 @@ func Handler(conn net.Conn) { func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr) - channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) - if closer, ok := cw.writer.(io.Closer); ok { - if closeErr := closer.Close(); closeErr != nil { - log.Printf("Failed to close connection: %v", closeErr) - } + + type channelResult struct { + channel ssh.Channel + reqs <-chan *ssh.Request + err error + } + resultChan := make(chan channelResult, 1) + + go func() { + channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) + resultChan <- channelResult{channel, reqs, err} + }() + + var channel ssh.Channel + var reqs <-chan *ssh.Request + + select { + case result := <-resultChan: + if result.err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) + sshSession.Forwarder.WriteBadGatewayResponse(cw.writer) + return } + channel = result.channel + reqs = result.reqs + case <-time.After(5 * time.Second): + log.Printf("Timeout opening forwarded-tcpip channel") + sshSession.Forwarder.WriteBadGatewayResponse(cw.writer) return } - go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in request handler goroutine: %v", r) - } - }() - for req := range reqs { - if err := req.Reply(false, nil); err != nil { - log.Printf("Failed to reply to request: %v", err) - return - } - } - }() + go ssh.DiscardRequests(reqs) fingerprintMiddleware := NewTunnelFingerprint() forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr) @@ -329,14 +340,13 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS cw.reqHeader = initialRequest for _, m := range cw.reqStartMW { - err = m.HandleRequest(cw.reqHeader) - if err != nil { + if err := m.HandleRequest(cw.reqHeader); err != nil { log.Printf("Error handling request: %v", err) return } } - _, err = channel.Write(initialRequest.Finalize()) + _, err := channel.Write(initialRequest.Finalize()) if err != nil { log.Printf("Failed to forward request: %v", err) return diff --git a/server/server.go b/server/server.go index e81ac62..8fb85b0 100644 --- a/server/server.go +++ b/server/server.go @@ -23,20 +23,15 @@ func NewServer(config *ssh.ServerConfig) *Server { return nil } if utils.Getenv("TLS_ENABLED", "false") == "true" { - go func() { - err = NewHTTPSServer() - if err != nil { - log.Fatalf("failed to start https server: %v", err) - } - return - }() - } - go func() { - err = NewHTTPServer() + err = NewHTTPSServer() if err != nil { - log.Fatalf("failed to start http server: %v", err) + log.Fatalf("failed to start https server: %v", err) } - }() + } + err = NewHTTPServer() + if err != nil { + log.Fatalf("failed to start http server: %v", err) + } return &Server{ Conn: &listener, Config: config, diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 3bf41bb..c993183 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -9,6 +9,7 @@ import ( "net" "strconv" "sync" + "time" "tunnel_pls/session/slug" "tunnel_pls/types" "tunnel_pls/utils" @@ -70,26 +71,52 @@ func (f *Forwarder) AcceptTCPConnections() { log.Printf("Error accepting connection: %v", err) continue } - payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) - channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) - if err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", err) + + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + log.Printf("Failed to set connection deadline: %v", err) if closeErr := conn.Close(); closeErr != nil { log.Printf("Failed to close connection: %v", closeErr) } continue } + payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) + + type channelResult struct { + channel ssh.Channel + reqs <-chan *ssh.Request + err error + } + resultChan := make(chan channelResult, 1) + go func() { - for req := range reqs { - err := req.Reply(false, nil) - if err != nil { - log.Printf("Failed to reply to request: %v", err) - return - } - } + channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) + resultChan <- channelResult{channel, reqs, err} }() - go f.HandleConnection(conn, channel, conn.RemoteAddr()) + + select { + case result := <-resultChan: + if result.err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) + if closeErr := conn.Close(); closeErr != nil { + log.Printf("Failed to close connection: %v", closeErr) + } + continue + } + + if err := conn.SetDeadline(time.Time{}); err != nil { + log.Printf("Failed to clear connection deadline: %v", err) + } + + go ssh.DiscardRequests(result.reqs) + go f.HandleConnection(conn, result.channel, conn.RemoteAddr()) + + case <-time.After(5 * time.Second): + log.Printf("Timeout opening forwarded-tcpip channel") + if closeErr := conn.Close(); closeErr != nil { + log.Printf("Failed to close connection: %v", closeErr) + } + } } } diff --git a/session/handler.go b/session/handler.go index 962aa65..eb61cfd 100644 --- a/session/handler.go +++ b/session/handler.go @@ -19,9 +19,6 @@ var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 54 func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { for req := range GlobalRequest { switch req.Type { - case "tcpip-forward": - s.HandleTCPIPForward(req) - return case "shell", "pty-req", "window-change": err := req.Reply(true, nil) if err != nil { diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 29b02ed..ecfc206 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -2,11 +2,8 @@ package lifecycle import ( "errors" - "fmt" "io" - "log" "net" - "time" portUtil "tunnel_pls/internal/port" "tunnel_pls/session/slug" "tunnel_pls/types" @@ -41,7 +38,6 @@ func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) { type SessionLifecycle interface { Close() error - WaitForRunningStatus() SetStatus(status types.Status) GetConnection() ssh.Conn GetChannel() ssh.Channel @@ -62,33 +58,6 @@ func (l *Lifecycle) GetConnection() ssh.Conn { func (l *Lifecycle) SetStatus(status types.Status) { l.Status = status } -func (l *Lifecycle) WaitForRunningStatus() { - timeout := time.After(3 * time.Second) - ticker := time.NewTicker(150 * time.Millisecond) - defer ticker.Stop() - frames := []string{"-", "\\", "|", "/"} - i := 0 - for { - select { - case <-ticker.C: - l.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i])) - i = (i + 1) % len(frames) - if l.Status == types.RUNNING { - l.Interaction.SendMessage("\r\033[K") - return - } - case <-timeout: - l.Interaction.SendMessage("\r\033[K") - l.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n") - err := l.Close() - if err != nil { - log.Printf("failed to close session: %v", err) - } - log.Println("Timeout waiting for session to start running") - return - } - } -} func (l *Lifecycle) Close() error { err := l.Forwarder.Close() diff --git a/session/session.go b/session/session.go index f5f9ed5..1d23994 100644 --- a/session/session.go +++ b/session/session.go @@ -2,13 +2,15 @@ package session import ( "bytes" + "fmt" "log" "sync" + "time" "tunnel_pls/session/forwarder" "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" "tunnel_pls/session/slug" - "tunnel_pls/types" + "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) @@ -30,8 +32,6 @@ type SSHSession struct { Interaction interaction.Controller Forwarder forwarder.ForwardingController SlugManager slug.Manager - - channelOnce sync.Once } func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { @@ -71,20 +71,27 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan SlugManager: slugManager, } + var once sync.Once for channel := range sshChan { ch, reqs, err := channel.Accept() if err != nil { log.Printf("failed to accept channel: %v", err) continue } - session.channelOnce.Do(func() { + once.Do(func() { session.Lifecycle.SetChannel(ch) session.Interaction.SetChannel(ch) - session.Lifecycle.SetStatus(types.SETUP) - go session.HandleGlobalRequest(forwardingReq) - session.Lifecycle.WaitForRunningStatus() - }) + tcpipReq := session.waitForTCPIPForward(forwardingReq) + if tcpipReq == nil { + session.Interaction.SendMessage(fmt.Sprintf("Port forwarding request not received.\r\nEnsure you ran the correct command with -R flag.\r\nExample: ssh %s -p %s -R 80:localhost:3000\r\nFor more details, visit https://tunnl.live.\r\n\r\n", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200"))) + if err := session.Lifecycle.Close(); err != nil { + log.Printf("failed to close session: %v", err) + } + return + } + session.HandleTCPIPForward(tcpipReq) + }) go session.HandleGlobalRequest(reqs) } if err := session.Lifecycle.Close(); err != nil { @@ -92,6 +99,27 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan } } +func (s *SSHSession) waitForTCPIPForward(forwardingReq <-chan *ssh.Request) *ssh.Request { + select { + case req, ok := <-forwardingReq: + if !ok { + log.Println("Forwarding request channel closed") + return nil + } + if req.Type == "tcpip-forward" { + return req + } + if err := req.Reply(false, nil); err != nil { + log.Printf("Failed to reply to request: %v", err) + } + log.Printf("Expected tcpip-forward request, got: %s", req.Type) + return nil + case <-time.After(500 * time.Millisecond): + log.Println("No forwarding request received") + return nil + } +} + func updateClientSlug(oldSlug, newSlug string) bool { clientsMutex.Lock() defer clientsMutex.Unlock()