From 9bd2bead9e543fc7469d7ed0955f714d769288f2 Mon Sep 17 00:00:00 2001 From: bagas Date: Sat, 6 Dec 2025 23:14:13 +0700 Subject: [PATCH] refactor: instantiate new session object once forwarding is approved --- session/handler.go | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/session/handler.go b/session/handler.go index db45de3..8e7019e 100644 --- a/session/handler.go +++ b/session/handler.go @@ -105,10 +105,6 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { return } - s.Interaction.SendMessage("\033[H\033[2J") - s.Lifecycle.SetStatus(types.RUNNING) - go s.Interaction.HandleUserInput() - if portToBind == 80 || portToBind == 443 { s.HandleHTTPForward(req, portToBind) return @@ -152,9 +148,6 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { } func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { - s.Forwarder.SetType(types.HTTP) - s.Forwarder.SetForwardedPort(portToBind) - slug := generateUniqueSlug() if slug == "" { err := req.Reply(false, nil) @@ -173,13 +166,15 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { return } - s.SlugManager.Set(slug) - buf := new(bytes.Buffer) err := binary.Write(buf, binary.BigEndian, uint32(portToBind)) if err != nil { log.Println("Failed to write port to buffer:", err) unregisterClient(slug) + err = req.Reply(false, nil) + if err != nil { + log.Println("Failed to reply to request:", err) + } return } log.Printf("HTTP forwarding approved on port: %d", portToBind) @@ -190,24 +185,33 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { protocol = "https" } - s.Interaction.ShowWelcomeMessage() - s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) - err = req.Reply(true, buf.Bytes()) if err != nil { log.Println("Failed to reply to request:", err) unregisterClient(slug) + err = req.Reply(false, nil) + if err != nil { + log.Println("Failed to reply to request:", err) + } return } + + s.Forwarder.SetType(types.HTTP) + s.Forwarder.SetForwardedPort(portToBind) + s.SlugManager.Set(slug) + s.Interaction.SendMessage("\033[H\033[2J") + s.Interaction.ShowWelcomeMessage() + s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) + s.Lifecycle.SetStatus(types.RUNNING) + s.Interaction.HandleUserInput() } func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - s.Forwarder.SetType(types.TCP) log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) - err := req.Reply(false, nil) + err = req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) return @@ -243,11 +247,15 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind return } + s.Forwarder.SetType(types.TCP) s.Forwarder.SetListener(listener) s.Forwarder.SetForwardedPort(portToBind) + s.Interaction.SendMessage("\033[H\033[2J") s.Interaction.ShowWelcomeMessage() s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("domain"), s.Forwarder.GetForwardedPort())) + s.Lifecycle.SetStatus(types.RUNNING) go s.Forwarder.AcceptTCPConnections() + s.Interaction.HandleUserInput() } func generateUniqueSlug() string {