fix(port): add atomic ClaimPort() to prevent race condition
All checks were successful
Docker Build and Push / build-and-push-tags (push) Successful in 3m23s
Docker Build and Push / build-and-push-branches (push) Has been skipped

- Replace GetPortStatus/SetPortStatus calls with atomic ClaimPort() operation.
- Fixed a logic error when handling headless tunneling.
This commit is contained in:
2026-01-12 18:17:20 +07:00
parent 560c98b869
commit abd103b5ab
3 changed files with 18 additions and 12 deletions

View File

@@ -13,7 +13,7 @@ type Manager interface {
AddPortRange(startPort, endPort uint16) error AddPortRange(startPort, endPort uint16) error
GetUnassignedPort() (uint16, bool) GetUnassignedPort() (uint16, bool)
SetPortStatus(port uint16, assigned bool) error SetPortStatus(port uint16, assigned bool) error
GetPortStatus(port uint16) (bool, bool) ClaimPort(port uint16) (claimed bool)
} }
type manager struct { type manager struct {
@@ -88,10 +88,21 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
return nil return nil
} }
func (pm *manager) GetPortStatus(port uint16) (bool, bool) { func (pm *manager) ClaimPort(port uint16) (claimed bool) {
pm.mu.RLock() pm.mu.Lock()
defer pm.mu.RUnlock() defer pm.mu.Unlock()
status, exists := pm.ports[port] status, exists := pm.ports[port]
return status, exists
if exists && status {
return false
}
if !exists {
pm.ports[port] = true
return true
}
pm.ports[port] = true
return true
} }

View File

@@ -183,16 +183,11 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
portToBind = unassigned portToBind = unassigned
} }
if isUsed, exists := portUtil.Default.GetPortStatus(portToBind); exists && isUsed { if claimed := portUtil.Default.ClaimPort(portToBind); !claimed {
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind)) fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
return return
} }
if err := portUtil.Default.SetPortStatus(portToBind, true); err != nil {
fail(fmt.Sprintf("Failed to set port status: %v", err))
return
}
log.Printf("Requested forwarding on %s:%d", addr, portToBind) log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil { if err != nil {

View File

@@ -122,7 +122,7 @@ func (s *SSHSession) Start() error {
return fmt.Errorf("no forwarding Request") return fmt.Errorf("no forwarding Request")
} }
if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") || s.lifecycle.GetUser() == "UNAUTHORIZED" { if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.GetUser() == "UNAUTHORIZED" {
if err := tcpipReq.Reply(false, nil); err != nil { if err := tcpipReq.Reply(false, nil); err != nil {
log.Printf("cannot reply to tcpip req: %s\n", err) log.Printf("cannot reply to tcpip req: %s\n", err)
return err return err