diff --git a/session/handler.go b/session/handler.go index 19f42ad..425a5c8 100644 --- a/session/handler.go +++ b/session/handler.go @@ -3,7 +3,6 @@ package session import ( "bufio" "bytes" - "context" "encoding/binary" "errors" "fmt" @@ -557,60 +556,21 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se conn.Reader = bufio.NewReader(conn.Writer) } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - defer func() { - if r := recover(); r != nil { - log.Printf("Panic in reader copy: %v", r) - } - cancel() - }() - - _, err := io.Copy(channel, conn.Reader) - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error copying from conn.Reader to channel: %v", err) - } - cancel() - }() + go io.Copy(channel, conn.Reader) reader := bufio.NewReader(channel) - - peekChan := make(chan error, 1) - go func() { - _, err := reader.Peek(1) - peekChan <- err - }() - - select { - case err := <-peekChan: - if err == io.EOF { - s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - } - if err != nil { - log.Printf("Error peeking channel data: %v", err) - s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) - sendBadGatewayResponse(conn.Writer) - return - } - case <-time.After(5 * time.Second): - log.Printf("Timeout waiting for channel data from %s", conn.Writer.RemoteAddr()) - s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) + _, err = reader.Peek(1) + if err == io.EOF { + s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, "Could not forward request to the tunnel addr")) sendBadGatewayResponse(conn.Writer) - return - case <-ctx.Done(): + conn.Writer.Close() + channel.Close() return } - s.sendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp)) + s.sendMessage(fmt.Sprintf("\033[32m %s -> [%s] TUNNEL ADDRESS -- \"%s\" \r\n \033[0m", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp)) - _, err = io.Copy(conn.Writer, reader) - if err != nil && !errors.Is(err, io.EOF) { - log.Printf("Error copying from channel to conn.Writer: %v", err) - } + io.Copy(conn.Writer, reader) } func sendBadGatewayResponse(writer io.Writer) {