diff --git a/server/server.go b/server/server.go index 759bac9..4bd4804 100644 --- a/server/server.go +++ b/server/server.go @@ -90,7 +90,6 @@ func (s *Server) handleConnection(conn net.Conn) { user = u cancel() } - log.Println("SSH connection established:", sshConn.User()) sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user) err = sshSession.Start() diff --git a/session/handler.go b/session/handler.go index ca080d8..3b8e3c5 100644 --- a/session/handler.go +++ b/session/handler.go @@ -165,7 +165,6 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { slug := random.GenerateRandomString(20) key := types.SessionKey{Id: slug, Type: types.HTTP} - if !s.registry.Register(key, s) { log.Printf("Failed to register client with slug: %s", slug) err := req.Reply(false, nil) @@ -203,7 +202,6 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { s.forwarder.SetForwardedPort(portToBind) s.slugManager.Set(slug) s.lifecycle.SetStatus(types.RUNNING) - s.interaction.Start() } func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { @@ -282,7 +280,6 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind s.slugManager.Set(key.Id) s.lifecycle.SetStatus(types.RUNNING) go s.forwarder.AcceptTCPConnections() - s.interaction.Start() } func readSSHString(reader *bytes.Reader) (string, error) { diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 874fbf2..6d37f69 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -37,6 +37,9 @@ type Controller interface { SetWH(w, h int) Redraw() SetSessionRegistry(registry SessionRegistry) + SetMode(m types.Mode) + GetMode() types.Mode + Send(message string) error } type Forwarder interface { @@ -54,8 +57,24 @@ type Interaction struct { program *tea.Program ctx context.Context cancel context.CancelFunc + mode types.Mode } +func (i *Interaction) SetMode(m types.Mode) { + i.mode = m +} + +func (i *Interaction) GetMode() types.Mode { + return i.mode +} + +func (i *Interaction) Send(message string) error { + if i.channel != nil { + _, err := i.channel.Write([]byte(message)) + return err + } + return nil +} func (i *Interaction) SetWH(w, h int) { if i.program != nil { i.program.Send(tea.WindowSizeMsg{ @@ -749,6 +768,9 @@ func (m *model) View() string { } func (i *Interaction) Start() { + if i.mode == types.HEADLESS { + return + } lipgloss.SetColorProfile(termenv.TrueColor) domain := config.Getenv("DOMAIN", "localhost") diff --git a/session/session.go b/session/session.go index b1a9ac5..a7bbb5d 100644 --- a/session/session.go +++ b/session/session.go @@ -9,6 +9,7 @@ import ( "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" "tunnel_pls/session/slug" + "tunnel_pls/types" "golang.org/x/crypto/ssh" ) @@ -87,31 +88,56 @@ func (s *SSHSession) Detail() Detail { } func (s *SSHSession) Start() error { - channel := <-s.sshReqChannel - ch, reqs, err := channel.Accept() - if err != nil { - log.Printf("failed to accept channel: %v", err) - return err + var channel ssh.NewChannel + var ok bool + select { + case channel, ok = <-s.sshReqChannel: + if !ok { + log.Println("Forwarding request channel closed") + return nil + } + ch, reqs, err := channel.Accept() + if err != nil { + log.Printf("failed to accept channel: %v", err) + return err + } + go s.HandleGlobalRequest(reqs) + + s.lifecycle.SetChannel(ch) + s.interaction.SetChannel(ch) + s.interaction.SetMode(types.INTERACTIVE) + case <-time.After(500 * time.Millisecond): + s.interaction.SetMode(types.HEADLESS) } - go s.HandleGlobalRequest(reqs) tcpipReq := s.waitForTCPIPForward() if tcpipReq == nil { - _, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))) + err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) if err != nil { return err } - if err := s.lifecycle.Close(); err != nil { + if err = s.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) } return fmt.Errorf("no forwarding Request") } - s.lifecycle.SetChannel(ch) - s.interaction.SetChannel(ch) + if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") || s.lifecycle.GetUser() == "UNAUTHORIZED" { + if err := tcpipReq.Reply(false, nil); err != nil { + log.Printf("cannot reply to tcpip req: %s\n", err) + return err + } + if err := s.lifecycle.Close(); err != nil { + log.Printf("failed to close session: %v", err) + return err + } + return nil + } s.HandleTCPIPForward(tcpipReq) + s.interaction.Start() + s.lifecycle.GetConnection().Wait() if err := s.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) return err diff --git a/types/types.go b/types/types.go index 7d3f4de..bb8d199 100644 --- a/types/types.go +++ b/types/types.go @@ -8,6 +8,13 @@ const ( SETUP Status = "SETUP" ) +type Mode string + +const ( + INTERACTIVE Mode = "INTERACTIVE" + HEADLESS Mode = "HEADLESS" +) + type TunnelType string const (