feat(port): disable TCP forwarding by default and refactor port manager
This commit is contained in:
@@ -9,36 +9,47 @@ import (
|
|||||||
"tunnel_pls/utils"
|
"tunnel_pls/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PortManager struct {
|
type Manager interface {
|
||||||
|
AddPortRange(startPort, endPort uint16) error
|
||||||
|
GetUnassignedPort() (uint16, bool)
|
||||||
|
SetPortStatus(port uint16, assigned bool) error
|
||||||
|
GetPortStatus(port uint16) (bool, bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type manager struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ports map[uint16]bool
|
ports map[uint16]bool
|
||||||
sortedPorts []uint16
|
sortedPorts []uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
var Manager = PortManager{
|
var Default Manager = &manager{
|
||||||
ports: make(map[uint16]bool),
|
ports: make(map[uint16]bool),
|
||||||
sortedPorts: []uint16{},
|
sortedPorts: []uint16{},
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
rawRange := utils.Getenv("ALLOWED_PORTS", "40000-41000")
|
rawRange := utils.Getenv("ALLOWED_PORTS", "")
|
||||||
|
if rawRange == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
splitRange := strings.Split(rawRange, "-")
|
splitRange := strings.Split(rawRange, "-")
|
||||||
if len(splitRange) != 2 {
|
if len(splitRange) != 2 {
|
||||||
Manager.AddPortRange(30000, 31000)
|
return
|
||||||
} else {
|
}
|
||||||
|
|
||||||
start, err := strconv.ParseUint(splitRange[0], 10, 16)
|
start, err := strconv.ParseUint(splitRange[0], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
start = 30000
|
return
|
||||||
}
|
}
|
||||||
end, err := strconv.ParseUint(splitRange[1], 10, 16)
|
end, err := strconv.ParseUint(splitRange[1], 10, 16)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
end = 31000
|
return
|
||||||
}
|
|
||||||
Manager.AddPortRange(uint16(start), uint16(end))
|
|
||||||
}
|
}
|
||||||
|
_ = Default.AddPortRange(uint16(start), uint16(end))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *PortManager) AddPortRange(startPort, endPort uint16) error {
|
func (pm *manager) AddPortRange(startPort, endPort uint16) error {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
@@ -57,7 +68,7 @@ func (pm *PortManager) AddPortRange(startPort, endPort uint16) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *PortManager) GetUnassignedPort() (uint16, bool) {
|
func (pm *manager) GetUnassignedPort() (uint16, bool) {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
@@ -70,7 +81,7 @@ func (pm *PortManager) GetUnassignedPort() (uint16, bool) {
|
|||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *PortManager) SetPortStatus(port uint16, assigned bool) error {
|
func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
@@ -78,7 +89,7 @@ func (pm *PortManager) SetPortStatus(port uint16, assigned bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *PortManager) GetPortStatus(port uint16) (bool, bool) {
|
func (pm *manager) GetPortStatus(port uint16) (bool, bool) {
|
||||||
pm.mu.RLock()
|
pm.mu.RLock()
|
||||||
defer pm.mu.RUnlock()
|
defer pm.mu.RUnlock()
|
||||||
|
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
|
|||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
if portToBind == 0 {
|
if portToBind == 0 {
|
||||||
unassign, success := portUtil.Manager.GetUnassignedPort()
|
unassign, success := portUtil.Default.GetUnassignedPort()
|
||||||
portToBind = unassign
|
portToBind = unassign
|
||||||
if !success {
|
if !success {
|
||||||
s.Interaction.SendMessage("No available port\r\n")
|
s.Interaction.SendMessage("No available port\r\n")
|
||||||
@@ -122,7 +122,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else if isUse, isExist := portUtil.Manager.GetPortStatus(portToBind); isExist && isUse {
|
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
|
||||||
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind))
|
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind))
|
||||||
err := req.Reply(false, nil)
|
err := req.Reply(false, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -135,7 +135,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err := portUtil.Manager.SetPortStatus(portToBind, true)
|
err := portUtil.Default.SetPortStatus(portToBind, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to set port status:", err)
|
log.Println("Failed to set port status:", err)
|
||||||
return
|
return
|
||||||
@@ -208,7 +208,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
|
|||||||
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
|
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
|
||||||
if setErr := portUtil.Manager.SetPortStatus(portToBind, false); setErr != nil {
|
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
|
||||||
log.Printf("Failed to reset port status: %v", setErr)
|
log.Printf("Failed to reset port status: %v", setErr)
|
||||||
}
|
}
|
||||||
err = req.Reply(false, nil)
|
err = req.Reply(false, nil)
|
||||||
@@ -227,7 +227,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
|
|||||||
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to write port to buffer:", err)
|
log.Println("Failed to write port to buffer:", err)
|
||||||
if setErr := portUtil.Manager.SetPortStatus(portToBind, false); setErr != nil {
|
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
|
||||||
log.Printf("Failed to reset port status: %v", setErr)
|
log.Printf("Failed to reset port status: %v", setErr)
|
||||||
}
|
}
|
||||||
err = listener.Close()
|
err = listener.Close()
|
||||||
@@ -242,7 +242,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
|
|||||||
err = req.Reply(true, buf.Bytes())
|
err = req.Reply(true, buf.Bytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Failed to reply to request:", err)
|
log.Println("Failed to reply to request:", err)
|
||||||
if setErr := portUtil.Manager.SetPortStatus(portToBind, false); setErr != nil {
|
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
|
||||||
log.Printf("Failed to reset port status: %v", setErr)
|
log.Printf("Failed to reset port status: %v", setErr)
|
||||||
}
|
}
|
||||||
err = listener.Close()
|
err = listener.Close()
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func (l *Lifecycle) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if l.Forwarder.GetTunnelType() == types.TCP {
|
if l.Forwarder.GetTunnelType() == types.TCP {
|
||||||
err := portUtil.Manager.SetPortStatus(l.Forwarder.GetForwardedPort(), false)
|
err := portUtil.Default.SetPortStatus(l.Forwarder.GetForwardedPort(), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user