fix: resolve copy goroutine deadlock on early connection close
- Add proper CloseWrite handling to signal EOF to other goroutine
- Ensure both copy goroutines terminate when either side closes
- Prevent goroutine leaks for SSH forwarded-tcpip channels:
- Use select with default when sending result to resultChan
- Close unused SSH channels and discard requests if main goroutine has already timed out
This commit is contained in:
+12
-1
@@ -347,7 +347,18 @@ func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSessi
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
select {
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
default:
|
||||
if channel != nil {
|
||||
err := channel.Close()
|
||||
if err != nil {
|
||||
log.Printf("Failed to close unused channel: %v", err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var channel ssh.Channel
|
||||
|
||||
@@ -73,14 +73,6 @@ func (f *forwarder) AcceptTCPConnections() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
log.Printf("Failed to set connection deadline: %v", err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("Failed to close connection: %v", closeErr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||
|
||||
type channelResult struct {
|
||||
@@ -92,7 +84,18 @@ func (f *forwarder) AcceptTCPConnections() {
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
select {
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
default:
|
||||
if channel != nil {
|
||||
err := channel.Close()
|
||||
if err != nil {
|
||||
log.Printf("Failed to close unused channel: %v", err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -104,14 +107,8 @@ func (f *forwarder) AcceptTCPConnections() {
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err = conn.SetDeadline(time.Time{}); err != nil {
|
||||
log.Printf("Failed to clear connection deadline: %v", err)
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(result.reqs)
|
||||
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
|
||||
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("Timeout opening forwarded-tcpip channel")
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
@@ -150,7 +147,18 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
|
||||
defer wg.Done()
|
||||
_, err := copyWithBuffer(dst, src)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error copying src→dst: %v", err)
|
||||
log.Printf("Error copying src to dst: %v", err)
|
||||
}
|
||||
if conn, ok := dst.(interface{ CloseWrite() error }); ok {
|
||||
if err = conn.CloseWrite(); err != nil {
|
||||
log.Printf("Error closing write side of dst: %v", err)
|
||||
}
|
||||
} else {
|
||||
if closer, closerOk := dst.(io.Closer); closerOk {
|
||||
if err = closer.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error closing dst connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -158,7 +166,10 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
|
||||
defer wg.Done()
|
||||
_, err := copyWithBuffer(src, dst)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error copying dst→src: %v", err)
|
||||
log.Printf("Error copying dst to src: %v", err)
|
||||
}
|
||||
if err = src.CloseWrite(); err != nil {
|
||||
log.Printf("Error closing write side of src: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user