diff --git a/server/http.go b/server/http.go index 46a00d5..36b17e1 100644 --- a/server/http.go +++ b/server/http.go @@ -5,17 +5,14 @@ import ( "bytes" "errors" "fmt" - "golang.org/x/net/context" "log" "net" - "strconv" "strings" - "time" "tunnel_pls/session" "tunnel_pls/utils" ) -var redirectTLS bool = false +var redirectTLS = false func NewHTTPServer() error { listener, err := net.Listen("tcp", ":80") @@ -81,23 +78,10 @@ func Handler(conn net.Conn) { conn.Close() return } - keepalive, timeout := parseConnectionDetails(headers) - var ctx context.Context - var cancel context.CancelFunc - if keepalive { - if timeout >= 300 { - timeout = 300 - } - ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(timeout)*time.Second)) - } else { - ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) - } sshSession.HandleForwardedConnection(session.UserConnection{ - Reader: reader, - Writer: conn, - Context: ctx, - Cancel: cancel, + Reader: reader, + Writer: conn, }, sshSession.Connection) return } @@ -131,42 +115,3 @@ func parseHostFromHeader(data []byte) string { } return "" } - -func parseConnectionDetails(data []byte) (keepAlive bool, timeout int) { - keepAlive = false - timeout = 30 - - lines := strings.Split(string(data), "\r\n") - - for _, line := range lines { - if strings.HasPrefix(strings.ToLower(line), "connection:") { - value := strings.TrimSpace(strings.TrimPrefix(strings.ToLower(line), "connection:")) - keepAlive = (value == "keep-alive") - break - } - } - - if keepAlive { - for _, line := range lines { - if strings.HasPrefix(strings.ToLower(line), "keep-alive:") { - value := strings.TrimSpace(strings.TrimPrefix(line, "Keep-Alive:")) - - if strings.Contains(value, "timeout=") { - parts := strings.Split(value, ",") - for _, part := range parts { - part = strings.TrimSpace(part) - if strings.HasPrefix(part, "timeout=") { - timeoutStr := strings.TrimPrefix(part, "timeout=") - if t, err := strconv.Atoi(timeoutStr); err == nil { - timeout = t - } - } - } - } - break - } - } - } - - return keepAlive, timeout -} diff --git a/server/https.go b/server/https.go index 2d83e5b..0b882b7 100644 --- a/server/https.go +++ b/server/https.go @@ -4,11 +4,9 @@ import ( "bufio" "crypto/tls" "errors" - "golang.org/x/net/context" "log" "net" "strings" - "time" "tunnel_pls/session" "tunnel_pls/utils" ) @@ -70,23 +68,10 @@ func HandlerTLS(conn net.Conn) { conn.Close() return } - keepalive, timeout := parseConnectionDetails(headers) - var ctx context.Context - var cancel context.CancelFunc - if keepalive { - if timeout >= 300 { - timeout = 300 - } - ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(timeout)*time.Second)) - } else { - ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(5*time.Second)) - } sshSession.HandleForwardedConnection(session.UserConnection{ - Reader: reader, - Writer: conn, - Context: ctx, - Cancel: cancel, + Reader: reader, + Writer: conn, }, sshSession.Connection) return } diff --git a/session/handler.go b/session/handler.go index 779936c..0290190 100644 --- a/session/handler.go +++ b/session/handler.go @@ -3,6 +3,7 @@ package session import ( "bufio" "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -16,7 +17,6 @@ import ( portUtil "tunnel_pls/internal/port" "golang.org/x/crypto/ssh" - "golang.org/x/net/context" "tunnel_pls/utils" ) @@ -29,10 +29,8 @@ const ( ) type UserConnection struct { - Reader io.Reader - Writer net.Conn - Context context.Context - Cancel context.CancelFunc + Reader io.Reader + Writer net.Conn } var ( @@ -78,6 +76,13 @@ func updateClientSlug(oldSlug, newSlug string) bool { return true } +func (s *Session) safeClose() { + s.once.Do(func() { + close(s.ChannelChan) + close(s.Done) + }) +} + func (s *Session) Close() { if s.Listener != nil { s.Listener.Close() @@ -99,7 +104,7 @@ func (s *Session) Close() { portUtil.Manager.SetPortStatus(s.ForwardedPort, false) } - close(s.Done) + s.safeClose() } func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { @@ -267,9 +272,8 @@ func (s *Session) acceptTCPConnections() { } go s.HandleForwardedConnection(UserConnection{ - Reader: nil, - Writer: conn, - Context: context.Background(), + Reader: nil, + Writer: conn, }, s.Connection) } } @@ -538,40 +542,105 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se } defer channel.Close() - go handleChannelRequests(reqs, conn, channel) + go func() { + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in request handler: %v", r) + } + }() + for req := range reqs { + req.Reply(false, nil) + } + }() if conn.Reader == nil { conn.Reader = bufio.NewReader(conn.Writer) } - go io.Copy(channel, conn.Reader) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var wg sync.WaitGroup + errChan := make(chan error, 2) + + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in reader copy: %v", r) + errChan <- fmt.Errorf("panic: %v", r) + } + }() + + _, err := io.Copy(channel, conn.Reader) + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error copying from conn.Reader to channel: %v", err) + errChan <- err + } + cancel() + }() reader := bufio.NewReader(channel) - _, err = reader.Peek(1) - if err == io.EOF { - s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, "Could not forward request to the tunnel addr")) + + peekDone := make(chan error, 1) + go func() { + _, err := reader.Peek(1) + peekDone <- err + }() + + select { + case err := <-peekDone: + if err == io.EOF { + s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) + sendBadGatewayResponse(conn.Writer) + return + } + if err != nil { + log.Printf("Error peeking channel data: %v", err) + s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) + sendBadGatewayResponse(conn.Writer) + return + } + case <-time.After(5 * time.Second): + log.Printf("Timeout waiting for channel data from %s", conn.Writer.RemoteAddr()) + s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType)) sendBadGatewayResponse(conn.Writer) - conn.Writer.Close() - channel.Close() + return + case <-ctx.Done(): return } - s.sendMessage(fmt.Sprintf("\033[32m %s -> [%s] TUNNEL ADDRESS -- \"%s\" \r\n \033[0m", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp)) + s.sendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp)) - io.Copy(conn.Writer, reader) -} + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + log.Printf("Panic in writer copy: %v", r) + errChan <- fmt.Errorf("panic: %v", r) + } + }() -func handleChannelRequests(reqs <-chan *ssh.Request, conn UserConnection, channel ssh.Channel) { - select { - case <-reqs: - for req := range reqs { - req.Reply(false, nil) + _, err := io.Copy(conn.Writer, reader) + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error copying from channel to conn.Writer: %v", err) + errChan <- err + } + cancel() + }() + + go func() { + wg.Wait() + close(errChan) + }() + + for err := range errChan { + if err != nil { + log.Printf("Connection error: %v", err) + break } - case <-conn.Context.Done(): - conn.Writer.Close() - channel.Close() - log.Println("Connection closed by timeout") - return } } diff --git a/session/session.go b/session/session.go index e61eec4..63177c5 100644 --- a/session/session.go +++ b/session/session.go @@ -3,6 +3,7 @@ package session import ( "golang.org/x/crypto/ssh" "net" + "sync" ) type TunnelType string @@ -24,6 +25,7 @@ type Session struct { Slug string ChannelChan chan ssh.NewChannel Done chan bool + once sync.Once } func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session {