diff --git a/internal/port/port.go b/internal/port/port.go index 6512f40..68e185a 100644 --- a/internal/port/port.go +++ b/internal/port/port.go @@ -9,36 +9,47 @@ import ( "tunnel_pls/utils" ) -type PortManager struct { +type Manager interface { + AddPortRange(startPort, endPort uint16) error + GetUnassignedPort() (uint16, bool) + SetPortStatus(port uint16, assigned bool) error + GetPortStatus(port uint16) (bool, bool) +} + +type manager struct { mu sync.RWMutex ports map[uint16]bool sortedPorts []uint16 } -var Manager = PortManager{ +var Default Manager = &manager{ ports: make(map[uint16]bool), sortedPorts: []uint16{}, } func init() { - rawRange := utils.Getenv("ALLOWED_PORTS", "40000-41000") + rawRange := utils.Getenv("ALLOWED_PORTS", "") + if rawRange == "" { + return + } + splitRange := strings.Split(rawRange, "-") if len(splitRange) != 2 { - Manager.AddPortRange(30000, 31000) - } else { - start, err := strconv.ParseUint(splitRange[0], 10, 16) - if err != nil { - start = 30000 - } - end, err := strconv.ParseUint(splitRange[1], 10, 16) - if err != nil { - end = 31000 - } - Manager.AddPortRange(uint16(start), uint16(end)) + return } + + start, err := strconv.ParseUint(splitRange[0], 10, 16) + if err != nil { + return + } + end, err := strconv.ParseUint(splitRange[1], 10, 16) + if err != nil { + return + } + _ = Default.AddPortRange(uint16(start), uint16(end)) } -func (pm *PortManager) AddPortRange(startPort, endPort uint16) error { +func (pm *manager) AddPortRange(startPort, endPort uint16) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -57,7 +68,7 @@ func (pm *PortManager) AddPortRange(startPort, endPort uint16) error { return nil } -func (pm *PortManager) GetUnassignedPort() (uint16, bool) { +func (pm *manager) GetUnassignedPort() (uint16, bool) { pm.mu.Lock() defer pm.mu.Unlock() @@ -70,7 +81,7 @@ func (pm *PortManager) GetUnassignedPort() (uint16, bool) { return 0, false } -func (pm *PortManager) SetPortStatus(port uint16, assigned bool) error { +func (pm *manager) SetPortStatus(port uint16, assigned bool) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -78,7 +89,7 @@ func (pm *PortManager) SetPortStatus(port uint16, assigned bool) error { return nil } -func (pm *PortManager) GetPortStatus(port uint16) (bool, bool) { +func (pm *manager) GetPortStatus(port uint16) (bool, bool) { pm.mu.RLock() defer pm.mu.RUnlock() diff --git a/session/handler.go b/session/handler.go index eb61cfd..d536b51 100644 --- a/session/handler.go +++ b/session/handler.go @@ -107,7 +107,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { return } else { if portToBind == 0 { - unassign, success := portUtil.Manager.GetUnassignedPort() + unassign, success := portUtil.Default.GetUnassignedPort() portToBind = unassign if !success { s.Interaction.SendMessage("No available port\r\n") @@ -122,7 +122,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { } return } - } else if isUse, isExist := portUtil.Manager.GetPortStatus(portToBind); isExist && isUse { + } else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse { s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { @@ -135,7 +135,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { } return } - err := portUtil.Manager.SetPortStatus(portToBind, true) + err := portUtil.Default.SetPortStatus(portToBind, true) if err != nil { log.Println("Failed to set port status:", err) return @@ -208,7 +208,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) - if setErr := portUtil.Manager.SetPortStatus(portToBind, false); setErr != nil { + if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { log.Printf("Failed to reset port status: %v", setErr) } err = req.Reply(false, nil) @@ -227,7 +227,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind err = binary.Write(buf, binary.BigEndian, uint32(portToBind)) if err != nil { log.Println("Failed to write port to buffer:", err) - if setErr := portUtil.Manager.SetPortStatus(portToBind, false); setErr != nil { + if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { log.Printf("Failed to reset port status: %v", setErr) } err = listener.Close() @@ -242,7 +242,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind err = req.Reply(true, buf.Bytes()) if err != nil { log.Println("Failed to reply to request:", err) - if setErr := portUtil.Manager.SetPortStatus(portToBind, false); setErr != nil { + if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { log.Printf("Failed to reset port status: %v", setErr) } err = listener.Close() diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index ecfc206..11106f8 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -85,7 +85,7 @@ func (l *Lifecycle) Close() error { } if l.Forwarder.GetTunnelType() == types.TCP { - err := portUtil.Manager.SetPortStatus(l.Forwarder.GetForwardedPort(), false) + err := portUtil.Default.SetPortStatus(l.Forwarder.GetForwardedPort(), false) if err != nil { return err }