refactor(forwarder): improve connection handling and cleanup
- Extract copyAndClose method for bidirectional data transfe - Add closeWriter helper for graceful connection shutdown - Add handleIncomingConnection helper - Add openForwardedChannel helper
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@@ -62,6 +63,55 @@ type Forwarder interface {
|
|||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
|
type channelResult struct {
|
||||||
|
channel ssh.Channel
|
||||||
|
reqs <-chan *ssh.Request
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultChan := make(chan channelResult, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
||||||
|
select {
|
||||||
|
case resultChan <- channelResult{channel, reqs, err}:
|
||||||
|
default:
|
||||||
|
if channel != nil {
|
||||||
|
err = channel.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to close unused channel: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go ssh.DiscardRequests(reqs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case result := <-resultChan:
|
||||||
|
return result.channel, result.reqs, result.err
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
return nil, nil, errors.New("timeout opening forwarded-tcpip channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *forwarder) handleIncomingConnection(conn net.Conn) {
|
||||||
|
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||||
|
|
||||||
|
channel, reqs, err := f.openForwardedChannel(payload)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||||
|
err = conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go ssh.DiscardRequests(reqs)
|
||||||
|
go f.HandleConnection(conn, channel, conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
func (f *forwarder) AcceptTCPConnections() {
|
func (f *forwarder) AcceptTCPConnections() {
|
||||||
for {
|
for {
|
||||||
conn, err := f.Listener().Accept()
|
conn, err := f.Listener().Accept()
|
||||||
@@ -73,51 +123,33 @@ func (f *forwarder) AcceptTCPConnections() {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
go f.handleIncomingConnection(conn)
|
||||||
|
|
||||||
type channelResult struct {
|
|
||||||
channel ssh.Channel
|
|
||||||
reqs <-chan *ssh.Request
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
resultChan := make(chan channelResult, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
|
||||||
select {
|
|
||||||
case resultChan <- channelResult{channel, reqs, err}:
|
|
||||||
default:
|
|
||||||
if channel != nil {
|
|
||||||
err := channel.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to close unused channel: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case result := <-resultChan:
|
|
||||||
if result.err != nil {
|
|
||||||
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
|
|
||||||
if closeErr := conn.Close(); closeErr != nil {
|
|
||||||
log.Printf("Failed to close connection: %v", closeErr)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(result.reqs)
|
|
||||||
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
log.Printf("Timeout opening forwarded-tcpip channel")
|
|
||||||
if closeErr := conn.Close(); closeErr != nil {
|
|
||||||
log.Printf("Failed to close connection: %v", closeErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func closeWriter(w io.Writer) error {
|
||||||
|
if cw, ok := w.(interface{ CloseWrite() error }); ok {
|
||||||
|
return cw.CloseWrite()
|
||||||
|
}
|
||||||
|
if closer, ok := w.(io.Closer); ok {
|
||||||
|
return closer.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
|
||||||
|
var errs []error
|
||||||
|
_, err := copyWithBuffer(dst, src)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||||
|
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
errs = append(errs, fmt.Errorf("close writer error (%s): %w", direction, err))
|
||||||
|
}
|
||||||
|
return errors.Join(errs...)
|
||||||
|
}
|
||||||
|
|
||||||
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||||
defer func() {
|
defer func() {
|
||||||
_, err := io.Copy(io.Discard, src)
|
_, err := io.Copy(io.Discard, src)
|
||||||
@@ -133,31 +165,19 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
_, err := copyWithBuffer(dst, src)
|
err := f.copyAndClose(dst, src, "src to dst")
|
||||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
if err != nil {
|
||||||
log.Printf("Error copying src to dst: %v", err)
|
log.Println("Error during copy: ", err)
|
||||||
}
|
return
|
||||||
if conn, ok := dst.(interface{ CloseWrite() error }); ok {
|
|
||||||
if err = conn.CloseWrite(); err != nil {
|
|
||||||
log.Printf("Error closing write side of dst: %v", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if closer, closerOk := dst.(io.Closer); closerOk {
|
|
||||||
if err = closer.Close(); err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
log.Printf("Error closing dst connection: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
_, err := copyWithBuffer(src, dst)
|
err := f.copyAndClose(src, dst, "dst to src")
|
||||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
if err != nil {
|
||||||
log.Printf("Error copying dst to src: %v", err)
|
log.Println("Error during copy: ", err)
|
||||||
}
|
return
|
||||||
if err = src.CloseWrite(); err != nil {
|
|
||||||
log.Printf("Error closing write side of src: %v", err)
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user