diff --git a/main.exe b/main.exe new file mode 100644 index 0000000..ca198ce Binary files /dev/null and b/main.exe differ diff --git a/session/handler.go b/session/handler.go index 2181daa..b461673 100644 --- a/session/handler.go +++ b/session/handler.go @@ -41,8 +41,7 @@ var ( type Session struct { Connection *ssh.ServerConn - ConnChannels []ssh.Channel - GlobalRequest <-chan *ssh.Request + ConnChannel ssh.Channel Listener net.Listener TunnelType TunnelType ForwardedPort uint16 @@ -94,8 +93,8 @@ func (s *Session) Close() { s.Listener.Close() } - for _, ch := range s.ConnChannels { - ch.Close() + if s.ConnChannel != nil { + s.ConnChannel.Close() } if s.Connection != nil { @@ -109,17 +108,19 @@ func (s *Session) Close() { close(s.Done) } -func (s *Session) handleGlobalRequest() { +func (s *Session) handleGlobalRequest(GlobalRequest <-chan *ssh.Request) { ticker := time.NewTicker(1 * time.Second) for { select { - case req := <-s.GlobalRequest: + case req := <-GlobalRequest: ticker.Stop() if req == nil { return } if req.Type == "tcpip-forward" { s.handleTCPIPForward(req) + } else if req.Type == "shell" || req.Type == "pty-req" || req.Type == "window-change" { + req.Reply(true, nil) } else { req.Reply(false, nil) } @@ -128,6 +129,7 @@ func (s *Session) handleGlobalRequest() { case <-ticker.C: s.sendMessage(fmt.Sprintf("Please specify the forwarding tunnel. For example: 'ssh %s -p %s -R 443:localhost:8080' \r\n\n\n", utils.Getenv("domain"), utils.Getenv("port"))) s.Close() + return } } } @@ -162,7 +164,7 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) { } s.sendMessage("\033[H\033[2J") - showWelcomeMessage(s.ConnChannels[0]) + showWelcomeMessage(s.ConnChannel) s.Status = RUNNING if portToBind == 80 || portToBind == 443 { @@ -299,23 +301,24 @@ func (s *Session) waitForRunningStatus() { } func (s *Session) sendMessage(message string) { - if len(s.ConnChannels) > 0 { - s.ConnChannels[0].Write([]byte(message)) + if s.ConnChannel != nil { + s.ConnChannel.Write([]byte(message)) } } -func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel) { +func (s *Session) HandleSessionChannel(newChannel ssh.NewChannel, initialRequest <-chan *ssh.Request) { connection, requests, err := newChannel.Accept() if err != nil { log.Printf("Could not accept channel: %s", err) return } - s.ConnChannels = append(s.ConnChannels, connection) + s.ConnChannel = connection + s.Status = RUNNING + go s.handleGlobalRequest(initialRequest) + go s.handleGlobalRequest(requests) go s.handleUserInput(connection) - - go s.handleChannelRequests(connection, requests) } func (s *Session) handleUserInput(connection ssh.Channel) { @@ -493,7 +496,7 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd connection.Write([]byte("\r\nAvailable commands: /bye, /help, /clear, /slug")) case "/clear": connection.Write([]byte("\033[H\033[2J")) - showWelcomeMessage(s.ConnChannels[0]) + showWelcomeMessage(s.ConnChannel) domain := utils.Getenv("domain") if s.TunnelType == HTTP { protocol := "http" @@ -522,20 +525,6 @@ func (s *Session) handleCommand(connection ssh.Channel, command string, inSlugEd commandBuffer.Reset() } -func (s *Session) handleChannelRequests(connection ssh.Channel, requests <-chan *ssh.Request) { - go s.handleGlobalRequest() - - for req := range requests { - switch req.Type { - case "shell", "pty-req", "window-change": - req.Reply(true, nil) - default: - log.Println("Unknown request type:", req.Type) - req.Reply(false, nil) - } - } -} - func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.ServerConn) { defer conn.Writer.Close() diff --git a/session/session.go b/session/session.go index 9e05b72..bc82a82 100644 --- a/session/session.go +++ b/session/session.go @@ -15,19 +15,17 @@ const ( func New(conn *ssh.ServerConn, sshChannel <-chan ssh.NewChannel, req <-chan *ssh.Request) *Session { session := &Session{ - Status: SETUP, - Slug: "", - ConnChannels: []ssh.Channel{}, - Connection: conn, - GlobalRequest: req, - TunnelType: UNKNOWN, - SlugChannel: make(chan bool), - Done: make(chan bool), + Status: SETUP, + Slug: "", + ConnChannel: nil, + Connection: conn, + TunnelType: UNKNOWN, + Done: make(chan bool), } go func() { for newChannel := range sshChannel { - go session.HandleSessionChannel(newChannel) + go session.HandleSessionChannel(newChannel, req) } }()