diff --git a/internal/port/port.go b/internal/port/port.go index 50f6878..bd5073a 100644 --- a/internal/port/port.go +++ b/internal/port/port.go @@ -13,7 +13,7 @@ type Manager interface { AddPortRange(startPort, endPort uint16) error GetUnassignedPort() (uint16, bool) SetPortStatus(port uint16, assigned bool) error - GetPortStatus(port uint16) (bool, bool) + ClaimPort(port uint16) (claimed bool) } type manager struct { @@ -88,10 +88,21 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error { return nil } -func (pm *manager) GetPortStatus(port uint16) (bool, bool) { - pm.mu.RLock() - defer pm.mu.RUnlock() +func (pm *manager) ClaimPort(port uint16) (claimed bool) { + pm.mu.Lock() + defer pm.mu.Unlock() status, exists := pm.ports[port] - return status, exists + + if exists && status { + return false + } + + if !exists { + pm.ports[port] = true + return true + } + + pm.ports[port] = true + return true } diff --git a/session/handler.go b/session/handler.go index 26d2394..c5aad63 100644 --- a/session/handler.go +++ b/session/handler.go @@ -183,16 +183,11 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind portToBind = unassigned } - if isUsed, exists := portUtil.Default.GetPortStatus(portToBind); exists && isUsed { + if claimed := portUtil.Default.ClaimPort(portToBind); !claimed { fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind)) return } - if err := portUtil.Default.SetPortStatus(portToBind, true); err != nil { - fail(fmt.Sprintf("Failed to set port status: %v", err)) - return - } - log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { diff --git a/session/session.go b/session/session.go index a7bbb5d..98a6cd4 100644 --- a/session/session.go +++ b/session/session.go @@ -122,7 +122,7 @@ func (s *SSHSession) Start() error { return fmt.Errorf("no forwarding Request") } - if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") || s.lifecycle.GetUser() == "UNAUTHORIZED" { + if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.GetUser() == "UNAUTHORIZED" { if err := tcpipReq.Reply(false, nil); err != nil { log.Printf("cannot reply to tcpip req: %s\n", err) return err