diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 4558533..3d32a43 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -152,25 +152,26 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA log.Printf("Handling new forwarded connection from %s", remoteAddr) - done := make(chan struct{}, 2) - - go func() { - _, err := copyWithBuffer(src, dst) - if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying from conn.Reader to channel: %v", err) - } - done <- struct{}{} - }() + var wg sync.WaitGroup + wg.Add(2) go func() { + defer wg.Done() _, err := copyWithBuffer(dst, src) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { - log.Printf("Error copying from channel to conn.Writer: %v", err) + log.Printf("Error copying src→dst: %v", err) } - done <- struct{}{} }() - <-done + go func() { + defer wg.Done() + _, err := copyWithBuffer(src, dst) + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { + log.Printf("Error copying dst→src: %v", err) + } + }() + + wg.Wait() } func (f *Forwarder) SetType(tunnelType types.TunnelType) {