From aa1a46517831f10b84de517362fa85ee1b754cbe Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 20 Jan 2026 19:01:15 +0700 Subject: [PATCH] refactor(forwarder): improve connection handling and cleanup - Extract copyAndClose method for bidirectional data transfe - Add closeWriter helper for graceful connection shutdown - Add handleIncomingConnection helper - Add openForwardedChannel helper --- session/forwarder/forwarder.go | 144 +++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 62 deletions(-) diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index fa1dff4..cac6691 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "errors" + "fmt" "io" "log" "net" @@ -62,6 +63,55 @@ type Forwarder interface { Close() error } +func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { + type channelResult struct { + channel ssh.Channel + reqs <-chan *ssh.Request + err error + } + resultChan := make(chan channelResult, 1) + + go func() { + channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) + select { + case resultChan <- channelResult{channel, reqs, err}: + default: + if channel != nil { + err = channel.Close() + if err != nil { + log.Printf("Failed to close unused channel: %v", err) + return + } + go ssh.DiscardRequests(reqs) + } + } + }() + + select { + case result := <-resultChan: + return result.channel, result.reqs, result.err + case <-time.After(5 * time.Second): + return nil, nil, errors.New("timeout opening forwarded-tcpip channel") + } +} + +func (f *forwarder) handleIncomingConnection(conn net.Conn) { + payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) + + channel, reqs, err := f.openForwardedChannel(payload) + if err != nil { + log.Printf("Failed to open forwarded-tcpip channel: %v", err) + err = conn.Close() + if err != nil { + log.Printf("Failed to close connection: %v", err) + } + return + } + + go ssh.DiscardRequests(reqs) + go f.HandleConnection(conn, channel, conn.RemoteAddr()) +} + func (f *forwarder) AcceptTCPConnections() { for { conn, err := f.Listener().Accept() @@ -73,51 +123,33 @@ func (f *forwarder) AcceptTCPConnections() { continue } - payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr()) - - type channelResult struct { - channel ssh.Channel - reqs <-chan *ssh.Request - err error - } - resultChan := make(chan channelResult, 1) - - go func() { - channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) - select { - case resultChan <- channelResult{channel, reqs, err}: - default: - if channel != nil { - err := channel.Close() - if err != nil { - log.Printf("Failed to close unused channel: %v", err) - return - } - go ssh.DiscardRequests(reqs) - } - } - }() - - select { - case result := <-resultChan: - if result.err != nil { - log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) - if closeErr := conn.Close(); closeErr != nil { - log.Printf("Failed to close connection: %v", closeErr) - } - continue - } - go ssh.DiscardRequests(result.reqs) - go f.HandleConnection(conn, result.channel, conn.RemoteAddr()) - case <-time.After(5 * time.Second): - log.Printf("Timeout opening forwarded-tcpip channel") - if closeErr := conn.Close(); closeErr != nil { - log.Printf("Failed to close connection: %v", closeErr) - } - } + go f.handleIncomingConnection(conn) } } +func closeWriter(w io.Writer) error { + if cw, ok := w.(interface{ CloseWrite() error }); ok { + return cw.CloseWrite() + } + if closer, ok := w.(io.Closer); ok { + return closer.Close() + } + return nil +} + +func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error { + var errs []error + _, err := copyWithBuffer(dst, src) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err)) + } + + if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) { + errs = append(errs, fmt.Errorf("close writer error (%s): %w", direction, err)) + } + return errors.Join(errs...) +} + func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { defer func() { _, err := io.Copy(io.Discard, src) @@ -133,31 +165,19 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA go func() { defer wg.Done() - _, err := copyWithBuffer(dst, src) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying src to dst: %v", err) - } - if conn, ok := dst.(interface{ CloseWrite() error }); ok { - if err = conn.CloseWrite(); err != nil { - log.Printf("Error closing write side of dst: %v", err) - } - } else { - if closer, closerOk := dst.(io.Closer); closerOk { - if err = closer.Close(); err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error closing dst connection: %v", err) - } - } + err := f.copyAndClose(dst, src, "src to dst") + if err != nil { + log.Println("Error during copy: ", err) + return } }() go func() { defer wg.Done() - _, err := copyWithBuffer(src, dst) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying dst to src: %v", err) - } - if err = src.CloseWrite(); err != nil { - log.Printf("Error closing write side of src: %v", err) + err := f.copyAndClose(src, dst, "dst to src") + if err != nil { + log.Println("Error during copy: ", err) + return } }()