From c4dd086fb34ddec10ed131b8144514ef0daecf29 Mon Sep 17 00:00:00 2001 From: bagas Date: Wed, 23 Jul 2025 12:40:20 +0700 Subject: [PATCH] fix: ensure SSH connections close on client disconnect --- server/handler.go | 13 ++++++++++++- session/handler.go | 31 ++++++++++++++++++++++++------- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/server/handler.go b/server/handler.go index 7908c82..bf6fd9e 100644 --- a/server/handler.go +++ b/server/handler.go @@ -11,7 +11,11 @@ func (s *Server) handleConnection(conn net.Conn) { sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.Config) if err != nil { log.Printf("failed to establish SSH connection: %v", err) - conn.Close() + err := conn.Close() + if err != nil { + log.Printf("failed to close SSH connection: %v", err) + return + } return } @@ -21,5 +25,12 @@ func (s *Server) handleConnection(conn net.Conn) { for ch := range chans { newSession.ChannelChan <- ch } + + defer func(newSession *session.Session) { + err := newSession.Close() + if err != nil { + log.Printf("failed to close session: %v", err) + } + }(newSession) return } diff --git a/session/handler.go b/session/handler.go index 8317c54..fb840d7 100644 --- a/session/handler.go +++ b/session/handler.go @@ -83,17 +83,30 @@ func (s *Session) safeClose() { }) } -func (s *Session) Close() { +func (s *Session) Close() error { if s.Listener != nil { - s.Listener.Close() + err := s.Listener.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + fmt.Println("1") + return err + } } if s.ConnChannel != nil { - s.ConnChannel.Close() + err := s.ConnChannel.Close() + if err != nil && !errors.Is(err, io.EOF) { + fmt.Println("2") + return err + } } if s.Connection != nil { - s.Connection.Close() + err := s.Connection.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + fmt.Println("3") + + return err + } } if s.Slug != "" { @@ -101,10 +114,15 @@ func (s *Session) Close() { } if s.TunnelType == TCP { - portUtil.Manager.SetPortStatus(s.ForwardedPort, false) + err := portUtil.Manager.SetPortStatus(s.ForwardedPort, false) + if err != nil { + fmt.Println("4") + return err + } } s.safeClose() + return nil } func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { @@ -542,7 +560,6 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se } defer channel.Close() - // Goroutine 1: Handle SSH channel requests (same as old code) go func() { defer func() { if r := recover(); r != nil { @@ -570,7 +587,7 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se }() _, err := io.Copy(channel, conn.Reader) - if err != nil && !errors.Is(err, io.EOF) { + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { log.Printf("Error copying from conn.Reader to channel: %v", err) } cancel()