fix(port): add atomic ClaimPort() to prevent race condition
- Replace GetPortStatus/SetPortStatus calls with atomic ClaimPort() operation. - Fixed a logic error when handling headless tunneling.
This commit is contained in:
@@ -13,7 +13,7 @@ type Manager interface {
|
|||||||
AddPortRange(startPort, endPort uint16) error
|
AddPortRange(startPort, endPort uint16) error
|
||||||
GetUnassignedPort() (uint16, bool)
|
GetUnassignedPort() (uint16, bool)
|
||||||
SetPortStatus(port uint16, assigned bool) error
|
SetPortStatus(port uint16, assigned bool) error
|
||||||
GetPortStatus(port uint16) (bool, bool)
|
ClaimPort(port uint16) (claimed bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type manager struct {
|
type manager struct {
|
||||||
@@ -88,10 +88,21 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *manager) GetPortStatus(port uint16) (bool, bool) {
|
func (pm *manager) ClaimPort(port uint16) (claimed bool) {
|
||||||
pm.mu.RLock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.RUnlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
status, exists := pm.ports[port]
|
status, exists := pm.ports[port]
|
||||||
return status, exists
|
|
||||||
|
if exists && status {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
pm.ports[port] = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.ports[port] = true
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -183,16 +183,11 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
|
|||||||
portToBind = unassigned
|
portToBind = unassigned
|
||||||
}
|
}
|
||||||
|
|
||||||
if isUsed, exists := portUtil.Default.GetPortStatus(portToBind); exists && isUsed {
|
if claimed := portUtil.Default.ClaimPort(portToBind); !claimed {
|
||||||
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := portUtil.Default.SetPortStatus(portToBind, true); err != nil {
|
|
||||||
fail(fmt.Sprintf("Failed to set port status: %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
|
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ func (s *SSHSession) Start() error {
|
|||||||
return fmt.Errorf("no forwarding Request")
|
return fmt.Errorf("no forwarding Request")
|
||||||
}
|
}
|
||||||
|
|
||||||
if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") || s.lifecycle.GetUser() == "UNAUTHORIZED" {
|
if (s.interaction.GetMode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.GetUser() == "UNAUTHORIZED" {
|
||||||
if err := tcpipReq.Reply(false, nil); err != nil {
|
if err := tcpipReq.Reply(false, nil); err != nil {
|
||||||
log.Printf("cannot reply to tcpip req: %s\n", err)
|
log.Printf("cannot reply to tcpip req: %s\n", err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
Reference in New Issue
Block a user