Improve concurrency and resource management #2

Merged
bagas merged 12 commits from staging into main 2025-07-23 06:51:09 +00:00
7 changed files with 298 additions and 208 deletions
Showing only changes of commit c4dd086fb3 - Show all commits

View File

@ -11,7 +11,11 @@ func (s *Server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.Config)
if err != nil {
log.Printf("failed to establish SSH connection: %v", err)
conn.Close()
err := conn.Close()
if err != nil {
log.Printf("failed to close SSH connection: %v", err)
return
}
return
}
@ -21,5 +25,12 @@ func (s *Server) handleConnection(conn net.Conn) {
for ch := range chans {
newSession.ChannelChan <- ch
}
defer func(newSession *session.Session) {
err := newSession.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
}(newSession)
return
}

View File

@ -83,17 +83,30 @@ func (s *Session) safeClose() {
})
}
func (s *Session) Close() {
func (s *Session) Close() error {
if s.Listener != nil {
s.Listener.Close()
err := s.Listener.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
fmt.Println("1")
return err
}
}
if s.ConnChannel != nil {
s.ConnChannel.Close()
err := s.ConnChannel.Close()
if err != nil && !errors.Is(err, io.EOF) {
fmt.Println("2")
return err
}
}
if s.Connection != nil {
s.Connection.Close()
err := s.Connection.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
fmt.Println("3")
return err
}
}
if s.Slug != "" {
@ -101,10 +114,15 @@ func (s *Session) Close() {
}
if s.TunnelType == TCP {
portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
err := portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
if err != nil {
fmt.Println("4")
return err
}
}
s.safeClose()
return nil
}
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
@ -542,7 +560,6 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se
}
defer channel.Close()
// Goroutine 1: Handle SSH channel requests (same as old code)
go func() {
defer func() {
if r := recover(); r != nil {
@ -570,7 +587,7 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se
}()
_, err := io.Copy(channel, conn.Reader)
if err != nil && !errors.Is(err, io.EOF) {
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying from conn.Reader to channel: %v", err)
}
cancel()