fix: resolve copy goroutine deadlock on early connection close

- Add proper CloseWrite handling to signal EOF to other goroutine
- Ensure both copy goroutines terminate when either side closes
- Prevent goroutine leaks for SSH forwarded-tcpip channels:
    - Use select with default when sending result to resultChan
    - Close unused SSH channels and discard requests if main goroutine has already timed out
This commit is contained in:
2026-01-19 00:13:09 +07:00
parent 41fdb5639c
commit 8fb19af5a6
2 changed files with 40 additions and 18 deletions
+28 -17
View File
@@ -73,14 +73,6 @@ func (f *forwarder) AcceptTCPConnections() {
continue
}
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Printf("Failed to set connection deadline: %v", err)
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
continue
}
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
type channelResult struct {
@@ -92,7 +84,18 @@ func (f *forwarder) AcceptTCPConnections() {
go func() {
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
select {
case resultChan <- channelResult{channel, reqs, err}:
default:
if channel != nil {
err := channel.Close()
if err != nil {
log.Printf("Failed to close unused channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
}
}
}()
select {
@@ -104,14 +107,8 @@ func (f *forwarder) AcceptTCPConnections() {
}
continue
}
if err = conn.SetDeadline(time.Time{}); err != nil {
log.Printf("Failed to clear connection deadline: %v", err)
}
go ssh.DiscardRequests(result.reqs)
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
if closeErr := conn.Close(); closeErr != nil {
@@ -150,7 +147,18 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
defer wg.Done()
_, err := copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying srcdst: %v", err)
log.Printf("Error copying src to dst: %v", err)
}
if conn, ok := dst.(interface{ CloseWrite() error }); ok {
if err = conn.CloseWrite(); err != nil {
log.Printf("Error closing write side of dst: %v", err)
}
} else {
if closer, closerOk := dst.(io.Closer); closerOk {
if err = closer.Close(); err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing dst connection: %v", err)
}
}
}
}()
@@ -158,7 +166,10 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
defer wg.Done()
_, err := copyWithBuffer(src, dst)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying dstsrc: %v", err)
log.Printf("Error copying dst to src: %v", err)
}
if err = src.CloseWrite(); err != nil {
log.Printf("Error closing write side of src: %v", err)
}
}()