From 4a25627ab5fa4c4d7caa95f95def0090480b4e5f Mon Sep 17 00:00:00 2001 From: bagas Date: Mon, 21 Jul 2025 16:36:05 +0700 Subject: [PATCH] fix: using close channel when attempting to close session --- server/handler.go | 8 ++++-- session/handler.go | 70 +++++++++++----------------------------------- session/session.go | 26 +++++++++++++++-- 3 files changed, 46 insertions(+), 58 deletions(-) diff --git a/server/handler.go b/server/handler.go index d167a55..7908c82 100644 --- a/server/handler.go +++ b/server/handler.go @@ -8,7 +8,7 @@ import ( ) func (s *Server) handleConnection(conn net.Conn) { - sshConn, chans, reqs, err := ssh.NewServerConn(conn, s.Config) + sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.Config) if err != nil { log.Printf("failed to establish SSH connection: %v", err) conn.Close() @@ -17,5 +17,9 @@ func (s *Server) handleConnection(conn net.Conn) { log.Println("SSH connection established:", sshConn.User()) - session.New(sshConn, chans, reqs) + newSession := session.New(sshConn, forwardingReqs) + for ch := range chans { + newSession.ChannelChan <- ch + } + return } diff --git a/session/handler.go b/session/handler.go index 542c9a4..779936c 100644 --- a/session/handler.go +++ b/session/handler.go @@ -40,17 +40,6 @@ var ( Clients = make(map[string]*Session) ) -type Session struct { - Connection *ssh.ServerConn - ConnChannel ssh.Channel - Listener net.Listener - TunnelType TunnelType - ForwardedPort uint16 - Status SessionStatus - Slug string - Done chan bool -} - func registerClient(slug string, session *Session) bool { clientsMutex.Lock() defer clientsMutex.Unlock() @@ -113,28 +102,17 @@ func (s *Session) Close() { close(s.Done) } -func (s *Session) handleGlobalRequest(GlobalRequest <-chan *ssh.Request) { - ticker := time.NewTicker(1 * time.Second) - for { - select { - case req := <-GlobalRequest: - ticker.Stop() - if req == nil { - return - } - if req.Type == "tcpip-forward" { - s.handleTCPIPForward(req) - } else if req.Type == "shell" || req.Type == "pty-req" || req.Type == "window-change" { - req.Reply(true, nil) - } else { - req.Reply(false, nil) - } - case <-s.Done: - return - case <-ticker.C: - s.sendMessage(fmt.Sprintf("Please specify the forwarding tunnel. For example: 'ssh %s -p %s -R 443:localhost:8080' \r\n\n\n", utils.Getenv("domain"), utils.Getenv("port"))) - s.Close() +func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { + for req := range GlobalRequest { + switch req.Type { + case "tcpip-forward": + s.handleTCPIPForward(req) return + case "shell", "pty-req", "window-change": + req.Reply(true, nil) + default: + log.Println("Unknown request type:", req.Type) + req.Reply(false, nil) } } } @@ -181,6 +159,7 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) { showWelcomeMessage(s.ConnChannel) s.Status = RUNNING + go s.handleUserInput() if portToBind == 80 || portToBind == 443 { s.handleHTTPForward(req, portToBind) @@ -338,29 +317,14 @@ func (s *Session) sendMessage(message string) { } } -func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel, initialRequest <-chan *ssh.Request) { - connection, requests, err := newChannel.Accept() - if err != nil { - log.Printf("Could not accept channel: %s", err) - return - } - - s.ConnChannel = connection - s.Status = RUNNING - - go s.handleGlobalRequest(initialRequest) - go s.handleGlobalRequest(requests) - go s.handleUserInput(connection) -} - -func (s *Session) handleUserInput(connection ssh.Channel) { +func (s *Session) handleUserInput() { var commandBuffer bytes.Buffer buf := make([]byte, 1) inSlugEditMode := false editSlug := s.Slug for { - n, err := connection.Read(buf) + n, err := s.ConnChannel.Read(buf) if err != nil { if err != io.EOF { log.Printf("Error reading from client: %s", err) @@ -372,16 +336,16 @@ func (s *Session) handleUserInput(connection ssh.Channel) { char := buf[0] if inSlugEditMode { - s.handleSlugEditMode(connection, &inSlugEditMode, &editSlug, char, &commandBuffer) + s.handleSlugEditMode(s.ConnChannel, &inSlugEditMode, &editSlug, char, &commandBuffer) continue } - connection.Write(buf[:n]) + s.ConnChannel.Write(buf[:n]) if char == 8 || char == 127 { if commandBuffer.Len() > 0 { commandBuffer.Truncate(commandBuffer.Len() - 1) - connection.Write([]byte("\b \b")) + s.ConnChannel.Write([]byte("\b \b")) } continue } @@ -394,7 +358,7 @@ func (s *Session) handleUserInput(connection ssh.Channel) { if commandBuffer.Len() > 0 { if char == 13 { - s.handleCommand(connection, commandBuffer.String(), &inSlugEditMode, &editSlug, &commandBuffer) + s.handleCommand(s.ConnChannel, commandBuffer.String(), &inSlugEditMode, &editSlug, &commandBuffer) continue } commandBuffer.WriteByte(char) diff --git a/session/session.go b/session/session.go index bc82a82..e61eec4 100644 --- a/session/session.go +++ b/session/session.go @@ -2,6 +2,7 @@ package session import ( "golang.org/x/crypto/ssh" + "net" ) type TunnelType string @@ -13,19 +14,38 @@ const ( UNKNOWN TunnelType = "unknown" ) -func New(conn *ssh.ServerConn, sshChannel <-chan ssh.NewChannel, req <-chan *ssh.Request) *Session { +type Session struct { + Connection *ssh.ServerConn + ConnChannel ssh.Channel + Listener net.Listener + TunnelType TunnelType + ForwardedPort uint16 + Status SessionStatus + Slug string + ChannelChan chan ssh.NewChannel + Done chan bool +} + +func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session { session := &Session{ Status: SETUP, Slug: "", ConnChannel: nil, Connection: conn, TunnelType: UNKNOWN, + ChannelChan: make(chan ssh.NewChannel), Done: make(chan bool), } go func() { - for newChannel := range sshChannel { - go session.HandleSessionChannel(newChannel, req) + for channel := range session.ChannelChan { + ch, reqs, _ := channel.Accept() + if session.ConnChannel == nil { + session.ConnChannel = ch + session.Status = RUNNING + go session.HandleGlobalRequest(forwardingReq) + } + go session.HandleGlobalRequest(reqs) } }()