feat: add headless mode support for SSH -N connections
- use s.lifecycle.GetConnection().Wait() to block until SSH connection closes - Prevent premature session closure in headless mode In headless mode (ssh -N), there's no channel interaction to block on, so the session would immediately return and close. Now blocking on conn.Wait() keeps the session alive until the client disconnects.
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
@@ -87,31 +88,56 @@ func (s *SSHSession) Detail() Detail {
|
||||
}
|
||||
|
||||
func (s *SSHSession) Start() error {
|
||||
channel := <-s.sshReqChannel
|
||||
ch, reqs, err := channel.Accept()
|
||||
if err != nil {
|
||||
log.Printf("failed to accept channel: %v", err)
|
||||
return err
|
||||
var channel ssh.NewChannel
|
||||
var ok bool
|
||||
select {
|
||||
case channel, ok = <-s.sshReqChannel:
|
||||
if !ok {
|
||||
log.Println("Forwarding request channel closed")
|
||||
return nil
|
||||
}
|
||||
ch, reqs, err := channel.Accept()
|
||||
if err != nil {
|
||||
log.Printf("failed to accept channel: %v", err)
|
||||
return err
|
||||
}
|
||||
go s.HandleGlobalRequest(reqs)
|
||||
|
||||
s.lifecycle.SetChannel(ch)
|
||||
s.interaction.SetChannel(ch)
|
||||
s.interaction.SetMode(types.INTERACTIVE)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
s.interaction.SetMode(types.HEADLESS)
|
||||
}
|
||||
go s.HandleGlobalRequest(reqs)
|
||||
|
||||
tcpipReq := s.waitForTCPIPForward()
|
||||
if tcpipReq == nil {
|
||||
_, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))))
|
||||
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.lifecycle.Close(); err != nil {
|
||||
if err = s.lifecycle.Close(); err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return fmt.Errorf("no forwarding Request")
|
||||
}
|
||||
|
||||
s.lifecycle.SetChannel(ch)
|
||||
s.interaction.SetChannel(ch)
|
||||
if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") || s.lifecycle.GetUser() == "UNAUTHORIZED" {
|
||||
if err := tcpipReq.Reply(false, nil); err != nil {
|
||||
log.Printf("cannot reply to tcpip req: %s\n", err)
|
||||
return err
|
||||
}
|
||||
if err := s.lifecycle.Close(); err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
s.HandleTCPIPForward(tcpipReq)
|
||||
s.interaction.Start()
|
||||
|
||||
s.lifecycle.GetConnection().Wait()
|
||||
if err := s.lifecycle.Close(); err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user