diff --git a/internal/config/config.go b/internal/config/config.go index 21cb4fb..45f1cc5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,19 +1,17 @@ package config import ( - "log" "os" "strconv" "github.com/joho/godotenv" ) -func init() { +func Load() error { if _, err := os.Stat(".env"); err == nil { - if err := godotenv.Load(".env"); err != nil { - log.Printf("Warning: Failed to load .env file: %s", err) - } + return godotenv.Load(".env") } + return nil } func Getenv(key, defaultValue string) string { diff --git a/internal/port/port.go b/internal/port/port.go index bd5073a..01ecf96 100644 --- a/internal/port/port.go +++ b/internal/port/port.go @@ -3,53 +3,30 @@ package port import ( "fmt" "sort" - "strconv" - "strings" "sync" - "tunnel_pls/internal/config" ) -type Manager interface { +type Registry interface { AddPortRange(startPort, endPort uint16) error GetUnassignedPort() (uint16, bool) SetPortStatus(port uint16, assigned bool) error ClaimPort(port uint16) (claimed bool) } -type manager struct { +type registry struct { mu sync.RWMutex ports map[uint16]bool sortedPorts []uint16 } -var Default Manager = &manager{ - ports: make(map[uint16]bool), - sortedPorts: []uint16{}, +func New() Registry { + return ®istry{ + ports: make(map[uint16]bool), + sortedPorts: []uint16{}, + } } -func init() { - rawRange := config.Getenv("ALLOWED_PORTS", "") - if rawRange == "" { - return - } - - splitRange := strings.Split(rawRange, "-") - if len(splitRange) != 2 { - return - } - - start, err := strconv.ParseUint(splitRange[0], 10, 16) - if err != nil { - return - } - end, err := strconv.ParseUint(splitRange[1], 10, 16) - if err != nil { - return - } - _ = Default.AddPortRange(uint16(start), uint16(end)) -} - -func (pm *manager) AddPortRange(startPort, endPort uint16) error { +func (pm *registry) AddPortRange(startPort, endPort uint16) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -68,7 +45,7 @@ func (pm *manager) AddPortRange(startPort, endPort uint16) error { return nil } -func (pm *manager) GetUnassignedPort() (uint16, bool) { +func (pm *registry) GetUnassignedPort() (uint16, bool) { pm.mu.Lock() defer pm.mu.Unlock() @@ -80,7 +57,7 @@ func (pm *manager) GetUnassignedPort() (uint16, bool) { return 0, false } -func (pm *manager) SetPortStatus(port uint16, assigned bool) error { +func (pm *registry) SetPortStatus(port uint16, assigned bool) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -88,7 +65,7 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error { return nil } -func (pm *manager) ClaimPort(port uint16) (claimed bool) { +func (pm *registry) ClaimPort(port uint16) (claimed bool) { pm.mu.Lock() defer pm.mu.Unlock() diff --git a/main.go b/main.go index e8d3884..2303718 100644 --- a/main.go +++ b/main.go @@ -8,12 +8,14 @@ import ( _ "net/http/pprof" "os" "os/signal" + "strconv" "strings" "syscall" "time" "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/key" + "tunnel_pls/internal/port" "tunnel_pls/server" "tunnel_pls/session" "tunnel_pls/version" @@ -32,6 +34,12 @@ func main() { log.Printf("Starting %s", version.GetVersion()) + err := config.Load() + if err != nil { + log.Fatalf("Failed to load configuration: %s", err) + return + } + mode := strings.ToLower(config.Getenv("MODE", "standalone")) isNodeMode := mode == "node" @@ -41,7 +49,7 @@ func main() { go func() { pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) - if err := http.ListenAndServe(pprofAddr, nil); err != nil { + if err = http.ListenAndServe(pprofAddr, nil); err != nil { log.Printf("pprof server error: %v", err) } }() @@ -53,7 +61,7 @@ func main() { } sshKeyPath := "certs/ssh/id_rsa" - if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { + if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { log.Fatalf("Failed to generate SSH key: %s", err) } @@ -107,9 +115,33 @@ func main() { }() } + portManager := port.New() + rawRange := config.Getenv("ALLOWED_PORTS", "") + if rawRange != "" { + splitRange := strings.Split(rawRange, "-") + if len(splitRange) == 2 { + var start, end uint64 + start, err = strconv.ParseUint(splitRange[0], 10, 16) + if err != nil { + log.Fatalf("Failed to parse start port: %s", err) + } + + end, err = strconv.ParseUint(splitRange[1], 10, 16) + if err != nil { + log.Fatalf("Failed to parse end port: %s", err) + } + + if err = portManager.AddPortRange(uint16(start), uint16(end)); err != nil { + log.Fatalf("Failed to add port range: %s", err) + } + log.Printf("PortRegistry range configured: %d-%d", start, end) + } else { + log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange) + } + } var app server.Server go func() { - app, err = server.New(sshConfig, sessionRegistry, grpcClient) + app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager) if err != nil { errChan <- fmt.Errorf("failed to start server: %s", err) return diff --git a/server/server.go b/server/server.go index 3e42c9a..792f47e 100644 --- a/server/server.go +++ b/server/server.go @@ -9,6 +9,7 @@ import ( "time" "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" + "tunnel_pls/internal/port" "tunnel_pls/session" "golang.org/x/crypto/ssh" @@ -21,11 +22,12 @@ type Server interface { type server struct { listener net.Listener config *ssh.ServerConfig - sessionRegistry session.Registry grpcClient client.Client + sessionRegistry session.Registry + portRegistry port.Registry } -func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client) (Server, error) { +func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client, portRegistry port.Registry) (Server, error) { listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) if err != nil { log.Fatalf("failed to listen on port 2200: %v", err) @@ -50,8 +52,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie return &server{ listener: listener, config: sshConfig, - sessionRegistry: sessionRegistry, grpcClient: grpcClient, + sessionRegistry: sessionRegistry, + portRegistry: portRegistry, }, nil } @@ -103,7 +106,7 @@ func (s *server) handleConnection(conn net.Conn) { cancel() } log.Println("SSH connection established:", sshConn.User()) - sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user) + sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) err = sshSession.Start() if err != nil { log.Printf("SSH session ended with error: %v", err) diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index c3bfc8a..1a9bb8f 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -30,7 +30,6 @@ type Interaction interface { Mode() types.Mode SetChannel(channel ssh.Channel) SetLifecycle(lifecycle Lifecycle) - SetSessionRegistry(registry SessionRegistry) SetMode(m types.Mode) SetWH(w, h int) Start() @@ -80,24 +79,20 @@ func (i *interaction) SetWH(w, h int) { } } -func New(slug slug.Slug, forwarder Forwarder) Interaction { +func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry) Interaction { ctx, cancel := context.WithCancel(context.Background()) return &interaction{ channel: nil, slug: slug, forwarder: forwarder, lifecycle: nil, - sessionRegistry: nil, + sessionRegistry: sessionRegistry, program: nil, ctx: ctx, cancel: cancel, } } -func (i *interaction) SetSessionRegistry(registry SessionRegistry) { - i.sessionRegistry = registry -} - func (i *interaction) SetLifecycle(lifecycle Lifecycle) { i.lifecycle = lifecycle } diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 704b4a8..0b5ff33 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -28,41 +28,43 @@ type lifecycle struct { conn ssh.Conn channel ssh.Channel forwarder Forwarder - sessionRegistry SessionRegistry slug slug.Slug startedAt time.Time + sessionRegistry SessionRegistry + portRegistry portUtil.Registry user string } -func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle { +func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle { return &lifecycle{ status: types.INITIALIZING, conn: conn, channel: nil, forwarder: forwarder, slug: slugManager, - sessionRegistry: nil, startedAt: time.Now(), + sessionRegistry: sessionRegistry, + portRegistry: port, user: user, } } -func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) { - l.sessionRegistry = registry -} - type Lifecycle interface { Connection() ssh.Conn Channel() ssh.Channel + PortRegistry() portUtil.Registry User() string SetChannel(channel ssh.Channel) - SetSessionRegistry(registry SessionRegistry) SetStatus(status types.Status) IsActive() bool StartedAt() time.Time Close() error } +func (l *lifecycle) PortRegistry() portUtil.Registry { + return l.portRegistry +} + func (l *lifecycle) User() string { return l.user } @@ -116,7 +118,7 @@ func (l *lifecycle) Close() error { l.sessionRegistry.Remove(key) if tunnelType == types.TCP { - if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { + if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { firstErr = err } } diff --git a/session/session.go b/session/session.go index f0fd5be..c4f6feb 100644 --- a/session/session.go +++ b/session/session.go @@ -54,16 +54,14 @@ type session struct { var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session { +func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session { slugManager := slug.New() forwarderManager := forwarder.New(slugManager) - interactionManager := interaction.New(slugManager, forwarderManager) - lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user) + interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry) + lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) interactionManager.SetLifecycle(lifecycleManager) forwarderManager.SetLifecycle(lifecycleManager) - interactionManager.SetSessionRegistry(sessionRegistry) - lifecycleManager.SetSessionRegistry(sessionRegistry) return &session{ initialReq: initialReq, @@ -135,7 +133,7 @@ func (s *session) Start() error { tcpipReq := s.waitForTCPIPForward() if tcpipReq == nil { - err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) + err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) if err != nil { return err } @@ -234,7 +232,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { } func (s *session) HandleTCPIPForward(req *ssh.Request) { - log.Println("Port forwarding request detected") + log.Println("PortRegistry forwarding request detected") fail := func(msg string) { log.Println(msg) @@ -262,13 +260,13 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) { } if rawPortToBind > 65535 { - fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind)) + fail(fmt.Sprintf("PortRegistry %d is larger than allowed port of 65535", rawPortToBind)) return } portToBind := uint16(rawPortToBind) if isBlockedPort(portToBind) { - fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind)) + fail(fmt.Sprintf("PortRegistry %d is blocked or restricted", portToBind)) return } @@ -340,7 +338,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin s.registry.Remove(*key) } if port != 0 { - if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil { + if setErr := s.lifecycle.PortRegistry().SetPortStatus(port, false); setErr != nil { log.Printf("Failed to reset port status: %v", setErr) } } @@ -356,7 +354,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin } if portToBind == 0 { - unassigned, ok := portUtil.Default.GetUnassignedPort() + unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort() if !ok { fail("No available port") return @@ -364,15 +362,15 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin portToBind = unassigned } - if claimed := portUtil.Default.ClaimPort(portToBind); !claimed { - fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind)) + if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed { + fail(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) return } log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { - cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil) + cleanup(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind), portToBind, nil, nil) return }