Improve concurrency and resource management #2
@ -11,7 +11,11 @@ func (s *Server) handleConnection(conn net.Conn) {
|
|||||||
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.Config)
|
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.Config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to establish SSH connection: %v", err)
|
log.Printf("failed to establish SSH connection: %v", err)
|
||||||
conn.Close()
|
err := conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close SSH connection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,5 +25,12 @@ func (s *Server) handleConnection(conn net.Conn) {
|
|||||||
for ch := range chans {
|
for ch := range chans {
|
||||||
newSession.ChannelChan <- ch
|
newSession.ChannelChan <- ch
|
||||||
}
|
}
|
||||||
|
|
||||||
|
defer func(newSession *session.Session) {
|
||||||
|
err := newSession.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
}(newSession)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -83,17 +83,30 @@ func (s *Session) safeClose() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) Close() {
|
func (s *Session) Close() error {
|
||||||
if s.Listener != nil {
|
if s.Listener != nil {
|
||||||
s.Listener.Close()
|
err := s.Listener.Close()
|
||||||
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
fmt.Println("1")
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.ConnChannel != nil {
|
if s.ConnChannel != nil {
|
||||||
s.ConnChannel.Close()
|
err := s.ConnChannel.Close()
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
fmt.Println("2")
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.Connection != nil {
|
if s.Connection != nil {
|
||||||
s.Connection.Close()
|
err := s.Connection.Close()
|
||||||
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
fmt.Println("3")
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.Slug != "" {
|
if s.Slug != "" {
|
||||||
@ -101,10 +114,15 @@ func (s *Session) Close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.TunnelType == TCP {
|
if s.TunnelType == TCP {
|
||||||
portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
|
err := portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("4")
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.safeClose()
|
s.safeClose()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||||
@ -542,7 +560,6 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se
|
|||||||
}
|
}
|
||||||
defer channel.Close()
|
defer channel.Close()
|
||||||
|
|
||||||
// Goroutine 1: Handle SSH channel requests (same as old code)
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@ -570,7 +587,7 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := io.Copy(channel, conn.Reader)
|
_, err := io.Copy(channel, conn.Reader)
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||||
log.Printf("Error copying from conn.Reader to channel: %v", err)
|
log.Printf("Error copying from conn.Reader to channel: %v", err)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
|
|||||||
Reference in New Issue
Block a user