From 8c15da6131a88bf48427b4d7694828266486f057 Mon Sep 17 00:00:00 2001 From: bagas Date: Mon, 21 Jul 2025 13:07:42 +0700 Subject: [PATCH] feat: Support dynamic port allocation for SSH forwarding --- internal/port/port.go | 75 +++++++++++++++++++++++++++++++++++++++++++ session/handler.go | 51 +++++++++++++++++++++++------ 2 files changed, 117 insertions(+), 9 deletions(-) create mode 100644 internal/port/port.go diff --git a/internal/port/port.go b/internal/port/port.go new file mode 100644 index 0000000..a47f155 --- /dev/null +++ b/internal/port/port.go @@ -0,0 +1,75 @@ +package port + +import ( + "fmt" + "sort" + "strconv" + "strings" + "tunnel_pls/utils" +) + +type PortManager struct { + ports map[uint16]bool + sortedPorts []uint16 +} + +var Manager = PortManager{ + ports: make(map[uint16]bool), + sortedPorts: []uint16{}, +} + +func init() { + rawRange := utils.Getenv("ALLOWED_PORTS") + splitRange := strings.Split(rawRange, "-") + if len(splitRange) != 2 { + Manager.AddPortRange(30000, 31000) + } else { + start, err := strconv.ParseUint(splitRange[0], 10, 16) + if err != nil { + start = 30000 + } + end, err := strconv.ParseUint(splitRange[1], 10, 16) + if err != nil { + end = 31000 + } + Manager.AddPortRange(uint16(start), uint16(end)) + } +} + +func (pm *PortManager) AddPortRange(startPort, endPort uint16) error { + if startPort > endPort { + return fmt.Errorf("start port cannot be greater than end port") + } + for port := startPort; port <= endPort; port++ { + if _, exists := pm.ports[port]; !exists { + pm.ports[port] = false + pm.sortedPorts = append(pm.sortedPorts, port) + } + } + sort.Slice(pm.sortedPorts, func(i, j int) bool { + return pm.sortedPorts[i] < pm.sortedPorts[j] + }) + return nil +} + +func (pm *PortManager) GetUnassignedPort() (uint16, bool) { + for _, port := range pm.sortedPorts { + if !pm.ports[port] { + return port, true + } + } + return 0, false +} + +func (pm *PortManager) SetPortStatus(port uint16, assigned bool) error { + if _, exists := pm.ports[port]; !exists { + return fmt.Errorf("port %d is not in the allowed range", port) + } + pm.ports[port] = assigned + return nil +} + +func (pm *PortManager) GetPortStatus(port uint16) (bool, bool) { + status, exists := pm.ports[port] + return status, exists +} diff --git a/session/handler.go b/session/handler.go index b461673..0547299 100644 --- a/session/handler.go +++ b/session/handler.go @@ -13,6 +13,7 @@ import ( "strings" "sync" "time" + portUtil "tunnel_pls/internal/port" "golang.org/x/crypto/ssh" "golang.org/x/net/context" @@ -105,6 +106,10 @@ func (s *Session) Close() { unregisterClient(s.Slug) } + if s.TunnelType == TCP { + portUtil.Manager.SetPortStatus(s.ForwardedPort, false) + } + close(s.Done) } @@ -147,41 +152,69 @@ func (s *Session) handleTCPIPForward(req *ssh.Request) { return } - var portToBind uint32 - if err := binary.Read(reader, binary.BigEndian, &portToBind); err != nil { + var rawPortToBind uint32 + if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { log.Println("Failed to read port from payload:", err) - s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) + s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) req.Reply(false, nil) s.Close() return } + if rawPortToBind > 65535 { + s.sendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) + req.Reply(false, nil) + s.Close() + return + } + + portToBind := uint16(rawPortToBind) + if isBlockedPort(portToBind) { - s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) + s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) + req.Reply(false, nil) + s.Close() + return + } + + if portToBind == 0 { + unassign, success := portUtil.Manager.GetUnassignedPort() + portToBind = unassign + if !success { + s.sendMessage(fmt.Sprintf("No available port\r\n", portToBind)) + req.Reply(false, nil) + s.Close() + return + } + } else if isUse, isExist := portUtil.Manager.GetPortStatus(portToBind); !isExist || isUse { + s.sendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) req.Reply(false, nil) s.Close() return } s.sendMessage("\033[H\033[2J") + showWelcomeMessage(s.ConnChannel) s.Status = RUNNING if portToBind == 80 || portToBind == 443 { s.handleHTTPForward(req, portToBind) return + } else { + portUtil.Manager.SetPortStatus(portToBind, true) } s.handleTCPForward(req, addr, portToBind) } -var blockedReservedPorts = []uint32{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} +var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func isBlockedPort(port uint32) bool { +func isBlockedPort(port uint16) bool { if port == 80 || port == 443 { return false } - if port < 1024 { + if port < 1024 && port != 0 { return true } for _, p := range blockedReservedPorts { @@ -192,7 +225,7 @@ func isBlockedPort(port uint32) bool { return false } -func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint32) { +func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint16) { s.TunnelType = HTTP s.ForwardedPort = uint16(portToBind) @@ -221,7 +254,7 @@ func (s *Session) handleHTTPForward(req *ssh.Request, portToBind uint32) { req.Reply(true, buf.Bytes()) } -func (s *Session) handleTCPForward(req *ssh.Request, addr string, portToBind uint32) { +func (s *Session) handleTCPForward(req *ssh.Request, addr string, portToBind uint16) { s.TunnelType = TCP log.Printf("Requested forwarding on %s:%d", addr, portToBind)