revert-54069ad305 #11
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
@@ -62,6 +63,55 @@ type Forwarder interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (f *forwarder) openForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
reqs <-chan *ssh.Request
|
||||
err error
|
||||
}
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
||||
select {
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
default:
|
||||
if channel != nil {
|
||||
err = channel.Close()
|
||||
if err != nil {
|
||||
log.Printf("Failed to close unused channel: %v", err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
return result.channel, result.reqs, result.err
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, nil, errors.New("timeout opening forwarded-tcpip channel")
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) handleIncomingConnection(conn net.Conn) {
|
||||
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||
|
||||
channel, reqs, err := f.openForwardedChannel(payload)
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
log.Printf("Failed to close connection: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
go f.HandleConnection(conn, channel, conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (f *forwarder) AcceptTCPConnections() {
|
||||
for {
|
||||
conn, err := f.Listener().Accept()
|
||||
@@ -73,51 +123,33 @@ func (f *forwarder) AcceptTCPConnections() {
|
||||
continue
|
||||
}
|
||||
|
||||
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
reqs <-chan *ssh.Request
|
||||
err error
|
||||
}
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
||||
select {
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
default:
|
||||
if channel != nil {
|
||||
err := channel.Close()
|
||||
if err != nil {
|
||||
log.Printf("Failed to close unused channel: %v", err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result.err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("Failed to close connection: %v", closeErr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
go ssh.DiscardRequests(result.reqs)
|
||||
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("Timeout opening forwarded-tcpip channel")
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("Failed to close connection: %v", closeErr)
|
||||
}
|
||||
}
|
||||
go f.handleIncomingConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func closeWriter(w io.Writer) error {
|
||||
if cw, ok := w.(interface{ CloseWrite() error }); ok {
|
||||
return cw.CloseWrite()
|
||||
}
|
||||
if closer, ok := w.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
|
||||
var errs []error
|
||||
_, err := copyWithBuffer(dst, src)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
|
||||
}
|
||||
|
||||
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
|
||||
errs = append(errs, fmt.Errorf("close writer error (%s): %w", direction, err))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||
defer func() {
|
||||
_, err := io.Copy(io.Discard, src)
|
||||
@@ -133,31 +165,19 @@ func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := copyWithBuffer(dst, src)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error copying src to dst: %v", err)
|
||||
}
|
||||
if conn, ok := dst.(interface{ CloseWrite() error }); ok {
|
||||
if err = conn.CloseWrite(); err != nil {
|
||||
log.Printf("Error closing write side of dst: %v", err)
|
||||
}
|
||||
} else {
|
||||
if closer, closerOk := dst.(io.Closer); closerOk {
|
||||
if err = closer.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error closing dst connection: %v", err)
|
||||
}
|
||||
}
|
||||
err := f.copyAndClose(dst, src, "src to dst")
|
||||
if err != nil {
|
||||
log.Println("Error during copy: ", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := copyWithBuffer(src, dst)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error copying dst to src: %v", err)
|
||||
}
|
||||
if err = src.CloseWrite(); err != nil {
|
||||
log.Printf("Error closing write side of src: %v", err)
|
||||
err := f.copyAndClose(src, dst, "dst to src")
|
||||
if err != nil {
|
||||
log.Println("Error during copy: ", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user