From 560c98b8690dc3dfde98ee5089b1352b89a1b797 Mon Sep 17 00:00:00 2001 From: bagas Date: Mon, 12 Jan 2026 14:42:42 +0700 Subject: [PATCH] refactor: consolidate error handling with fail() function in session handlers - Replace repetitive error handling code with fail() function in HandleGlobalRequest - Standardize error response pattern across all handler methods - Improve code maintainability and reduce duplication --- internal/port/port.go | 1 - session/handler.go | 225 ++++++++++++++++-------------------------- 2 files changed, 86 insertions(+), 140 deletions(-) diff --git a/internal/port/port.go b/internal/port/port.go index 8eb17b9..50f6878 100644 --- a/internal/port/port.go +++ b/internal/port/port.go @@ -74,7 +74,6 @@ func (pm *manager) GetUnassignedPort() (uint16, bool) { for _, port := range pm.sortedPorts { if !pm.ports[port] { - pm.ports[port] = true return port, true } } diff --git a/session/handler.go b/session/handler.go index 3b8e3c5..26d2394 100644 --- a/session/handler.go +++ b/session/handler.go @@ -59,142 +59,79 @@ func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { log.Println("Port forwarding request detected") + fail := func(msg string) { + log.Println(msg) + if err := req.Reply(false, nil); err != nil { + log.Println("Failed to reply to request:", err) + return + } + if err := s.lifecycle.Close(); err != nil { + log.Printf("failed to close session: %v", err) + } + } + reader := bytes.NewReader(req.Payload) addr, err := readSSHString(reader) if err != nil { - log.Println("Failed to read address from payload:", err) - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - err = s.lifecycle.Close() - if err != nil { - log.Printf("failed to close session: %v", err) - } + fail(fmt.Sprintf("Failed to read address from payload: %v", err)) return } var rawPortToBind uint32 - if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { - log.Println("Failed to read port from payload:", err) - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - err = s.lifecycle.Close() - if err != nil { - log.Printf("failed to close session: %v", err) - } + if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { + fail(fmt.Sprintf("Failed to read port from payload: %v", err)) return } if rawPortToBind > 65535 { - log.Printf("Port %d is larger than allowed port of 65535", rawPortToBind) - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - err = s.lifecycle.Close() - if err != nil { - log.Printf("failed to close session: %v", err) - } + fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind)) return } portToBind := uint16(rawPortToBind) if isBlockedPort(portToBind) { - log.Printf("Port %d is blocked or restricted", portToBind) - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - err = s.lifecycle.Close() - if err != nil { - log.Printf("failed to close session: %v", err) - } + fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind)) return } - if portToBind == 80 || portToBind == 443 { + switch portToBind { + case 80, 443: s.HandleHTTPForward(req, portToBind) - return + default: + s.HandleTCPForward(req, addr, portToBind) } - if portToBind == 0 { - unassign, success := portUtil.Default.GetUnassignedPort() - portToBind = unassign - if !success { - log.Println("No available port") - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - err = s.lifecycle.Close() - if err != nil { - log.Printf("failed to close session: %v", err) - } - return - } - } else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse { - log.Printf("Port %d is already in use or restricted", portToBind) - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - return - } - err = s.lifecycle.Close() - if err != nil { - log.Printf("failed to close session: %v", err) - } - return - } - err = portUtil.Default.SetPortStatus(portToBind, true) - if err != nil { - log.Println("Failed to set port status:", err) - return - } - - s.HandleTCPForward(req, addr, portToBind) } func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { + fail := func(msg string, key *types.SessionKey) { + log.Println(msg) + if key != nil { + s.registry.Remove(*key) + } + if err := req.Reply(false, nil); err != nil { + log.Println("Failed to reply to request:", err) + } + } + slug := random.GenerateRandomString(20) key := types.SessionKey{Id: slug, Type: types.HTTP} if !s.registry.Register(key, s) { - log.Printf("Failed to register client with slug: %s", slug) - err := req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - } + fail(fmt.Sprintf("Failed to register client with slug: %s", slug), nil) return } buf := new(bytes.Buffer) err := binary.Write(buf, binary.BigEndian, uint32(portToBind)) if err != nil { - log.Println("Failed to write port to buffer:", err) - s.registry.Remove(key) - err = req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - } + fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key) return } log.Printf("HTTP forwarding approved on port: %d", portToBind) err = req.Reply(true, buf.Bytes()) if err != nil { - log.Println("Failed to reply to request:", err) - s.registry.Remove(key) - err = req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - } + fail(fmt.Sprintf("Failed to reply to request: %v", err), &key) return } @@ -205,72 +142,82 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { } func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { - log.Printf("Requested forwarding on %s:%d", addr, portToBind) - listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) - if err != nil { - log.Printf("Port %d is already in use or restricted", portToBind) - if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { - log.Printf("Failed to reset port status: %v", setErr) - } - err = req.Reply(false, nil) - if err != nil { + fail := func(msg string) { + log.Println(msg) + if err := req.Reply(false, nil); err != nil { log.Println("Failed to reply to request:", err) return } - err = s.lifecycle.Close() - if err != nil { + if err := s.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) } + } + + cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) { + log.Println(msg) + if key != nil { + s.registry.Remove(*key) + } + if port != 0 { + if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil { + log.Printf("Failed to reset port status: %v", setErr) + } + } + if listener != nil { + if closeErr := listener.Close(); closeErr != nil { + log.Printf("Failed to close listener: %v", closeErr) + } + } + if err := req.Reply(false, nil); err != nil { + log.Println("Failed to reply to request:", err) + } + _ = s.lifecycle.Close() + } + + if portToBind == 0 { + unassigned, ok := portUtil.Default.GetUnassignedPort() + if !ok { + fail("No available port") + return + } + portToBind = unassigned + } + + if isUsed, exists := portUtil.Default.GetPortStatus(portToBind); exists && isUsed { + fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind)) + return + } + + if err := portUtil.Default.SetPortStatus(portToBind, true); err != nil { + fail(fmt.Sprintf("Failed to set port status: %v", err)) + return + } + + log.Printf("Requested forwarding on %s:%d", addr, portToBind) + listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) + if err != nil { + cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil) return } key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} if !s.registry.Register(key, s) { - log.Printf("Failed to register TCP client with id: %s", key.Id) - if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { - log.Printf("Failed to reset port status: %v", setErr) - } - if closeErr := listener.Close(); closeErr != nil { - log.Printf("Failed to close listener: %s", closeErr) - } - err = req.Reply(false, nil) - if err != nil { - log.Println("Failed to reply to request:", err) - } - _ = s.lifecycle.Close() + cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil) return } buf := new(bytes.Buffer) err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) if err != nil { - log.Println("Failed to write port to buffer:", err) - s.registry.Remove(key) - if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { - log.Printf("Failed to reset port status: %v", setErr) - } - err = listener.Close() - if err != nil { - log.Printf("Failed to close listener: %s", err) - return - } + cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key) return } log.Printf("TCP forwarding approved on port: %d", portToBind) err = req.Reply(true, buf.Bytes()) if err != nil { - log.Println("Failed to reply to request:", err) - s.registry.Remove(key) - if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { - log.Printf("Failed to reset port status: %v", setErr) - } - err = listener.Close() - if err != nil { - log.Printf("Failed to close listener: %s", err) - return - } + cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key) return }