From 761139d626dd4a213dcccd6d9a14e953a0df5abc Mon Sep 17 00:00:00 2001 From: bagas Date: Sun, 18 Jan 2026 20:42:10 +0700 Subject: [PATCH] refactor: explicit initialization and dependency injection - Replace init() with config.Load() function when loading env variables - Inject portRegistry into session, server, and lifecycle structs - Inject sessionRegistry directly into interaction and lifecycle - Remove SetSessionRegistry function and global port variables - Pass ssh.Conn directly to forwarder constructor instead of lifecycle interface - Pass user and closeFunc callback to interaction constructor instead of lifecycle interface - Eliminate circular dependencies between lifecycle, forwarder, and interaction - Remove setter methods (SetLifecycle) from forwarder and interaction interfaces --- go.sum | 18 ++---------- internal/config/config.go | 8 ++---- internal/port/port.go | 45 ++++++++---------------------- main.go | 38 +++++++++++++++++++++++-- server/server.go | 11 +++++--- session/forwarder/forwarder.go | 21 ++++---------- session/interaction/dashboard.go | 2 +- session/interaction/interaction.go | 40 ++++++++++---------------- session/interaction/slug.go | 2 +- session/lifecycle/lifecycle.go | 25 ++++++++--------- session/session.go | 31 +++++++++----------- 11 files changed, 105 insertions(+), 136 deletions(-) diff --git a/go.sum b/go.sum index c9792ef..11912af 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,3 @@ -git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0 h1:RhcBKUG41/om4jgN+iF/vlY/RojTeX1QhBa4p4428ec= -git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY= -git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0 h1:tpJSKjaSmV+vxxbVx6qnStjxFVXjj2M0rygWXxLb99o= -git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY= git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 h1:3xszIhck4wo9CoeRq9vnkar4PhY7kz9QrR30qj2XszA= git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0/go.mod h1:Weh6ZujgWmT8XxD3Qba7sJ6r5eyUMB9XSWynqdyOoLo= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= @@ -10,12 +6,8 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= -github.com/caddyserver/certmagic v0.25.0 h1:VMleO/XA48gEWes5l+Fh6tRWo9bHkhwAEhx63i+F5ic= -github.com/caddyserver/certmagic v0.25.0/go.mod h1:m9yB7Mud24OQbPHOiipAoyKPn9pKHhpSJxXR1jydBxA= github.com/caddyserver/certmagic v0.25.1 h1:4sIKKbOt5pg6+sL7tEwymE1x2bj6CHr80da1CRRIPbY= github.com/caddyserver/certmagic v0.25.1/go.mod h1:VhyvndxtVton/Fo/wKhRoC46Rbw1fmjvQ3GjHYSQTEY= -github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA= -github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/caddyserver/zerossl v0.1.4 h1:CVJOE3MZeFisCERZjkxIcsqIH4fnFdlYWnPYeFtBHRw= github.com/caddyserver/zerossl v0.1.4/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= @@ -118,8 +110,6 @@ go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U= go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= @@ -132,14 +122,10 @@ golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= +golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= diff --git a/internal/config/config.go b/internal/config/config.go index 21cb4fb..45f1cc5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,19 +1,17 @@ package config import ( - "log" "os" "strconv" "github.com/joho/godotenv" ) -func init() { +func Load() error { if _, err := os.Stat(".env"); err == nil { - if err := godotenv.Load(".env"); err != nil { - log.Printf("Warning: Failed to load .env file: %s", err) - } + return godotenv.Load(".env") } + return nil } func Getenv(key, defaultValue string) string { diff --git a/internal/port/port.go b/internal/port/port.go index bd5073a..01ecf96 100644 --- a/internal/port/port.go +++ b/internal/port/port.go @@ -3,53 +3,30 @@ package port import ( "fmt" "sort" - "strconv" - "strings" "sync" - "tunnel_pls/internal/config" ) -type Manager interface { +type Registry interface { AddPortRange(startPort, endPort uint16) error GetUnassignedPort() (uint16, bool) SetPortStatus(port uint16, assigned bool) error ClaimPort(port uint16) (claimed bool) } -type manager struct { +type registry struct { mu sync.RWMutex ports map[uint16]bool sortedPorts []uint16 } -var Default Manager = &manager{ - ports: make(map[uint16]bool), - sortedPorts: []uint16{}, +func New() Registry { + return ®istry{ + ports: make(map[uint16]bool), + sortedPorts: []uint16{}, + } } -func init() { - rawRange := config.Getenv("ALLOWED_PORTS", "") - if rawRange == "" { - return - } - - splitRange := strings.Split(rawRange, "-") - if len(splitRange) != 2 { - return - } - - start, err := strconv.ParseUint(splitRange[0], 10, 16) - if err != nil { - return - } - end, err := strconv.ParseUint(splitRange[1], 10, 16) - if err != nil { - return - } - _ = Default.AddPortRange(uint16(start), uint16(end)) -} - -func (pm *manager) AddPortRange(startPort, endPort uint16) error { +func (pm *registry) AddPortRange(startPort, endPort uint16) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -68,7 +45,7 @@ func (pm *manager) AddPortRange(startPort, endPort uint16) error { return nil } -func (pm *manager) GetUnassignedPort() (uint16, bool) { +func (pm *registry) GetUnassignedPort() (uint16, bool) { pm.mu.Lock() defer pm.mu.Unlock() @@ -80,7 +57,7 @@ func (pm *manager) GetUnassignedPort() (uint16, bool) { return 0, false } -func (pm *manager) SetPortStatus(port uint16, assigned bool) error { +func (pm *registry) SetPortStatus(port uint16, assigned bool) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -88,7 +65,7 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error { return nil } -func (pm *manager) ClaimPort(port uint16) (claimed bool) { +func (pm *registry) ClaimPort(port uint16) (claimed bool) { pm.mu.Lock() defer pm.mu.Unlock() diff --git a/main.go b/main.go index e8d3884..2303718 100644 --- a/main.go +++ b/main.go @@ -8,12 +8,14 @@ import ( _ "net/http/pprof" "os" "os/signal" + "strconv" "strings" "syscall" "time" "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/key" + "tunnel_pls/internal/port" "tunnel_pls/server" "tunnel_pls/session" "tunnel_pls/version" @@ -32,6 +34,12 @@ func main() { log.Printf("Starting %s", version.GetVersion()) + err := config.Load() + if err != nil { + log.Fatalf("Failed to load configuration: %s", err) + return + } + mode := strings.ToLower(config.Getenv("MODE", "standalone")) isNodeMode := mode == "node" @@ -41,7 +49,7 @@ func main() { go func() { pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) - if err := http.ListenAndServe(pprofAddr, nil); err != nil { + if err = http.ListenAndServe(pprofAddr, nil); err != nil { log.Printf("pprof server error: %v", err) } }() @@ -53,7 +61,7 @@ func main() { } sshKeyPath := "certs/ssh/id_rsa" - if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { + if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { log.Fatalf("Failed to generate SSH key: %s", err) } @@ -107,9 +115,33 @@ func main() { }() } + portManager := port.New() + rawRange := config.Getenv("ALLOWED_PORTS", "") + if rawRange != "" { + splitRange := strings.Split(rawRange, "-") + if len(splitRange) == 2 { + var start, end uint64 + start, err = strconv.ParseUint(splitRange[0], 10, 16) + if err != nil { + log.Fatalf("Failed to parse start port: %s", err) + } + + end, err = strconv.ParseUint(splitRange[1], 10, 16) + if err != nil { + log.Fatalf("Failed to parse end port: %s", err) + } + + if err = portManager.AddPortRange(uint16(start), uint16(end)); err != nil { + log.Fatalf("Failed to add port range: %s", err) + } + log.Printf("PortRegistry range configured: %d-%d", start, end) + } else { + log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange) + } + } var app server.Server go func() { - app, err = server.New(sshConfig, sessionRegistry, grpcClient) + app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager) if err != nil { errChan <- fmt.Errorf("failed to start server: %s", err) return diff --git a/server/server.go b/server/server.go index 3e42c9a..792f47e 100644 --- a/server/server.go +++ b/server/server.go @@ -9,6 +9,7 @@ import ( "time" "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" + "tunnel_pls/internal/port" "tunnel_pls/session" "golang.org/x/crypto/ssh" @@ -21,11 +22,12 @@ type Server interface { type server struct { listener net.Listener config *ssh.ServerConfig - sessionRegistry session.Registry grpcClient client.Client + sessionRegistry session.Registry + portRegistry port.Registry } -func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client) (Server, error) { +func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client, portRegistry port.Registry) (Server, error) { listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) if err != nil { log.Fatalf("failed to listen on port 2200: %v", err) @@ -50,8 +52,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie return &server{ listener: listener, config: sshConfig, - sessionRegistry: sessionRegistry, grpcClient: grpcClient, + sessionRegistry: sessionRegistry, + portRegistry: portRegistry, }, nil } @@ -103,7 +106,7 @@ func (s *server) handleConnection(conn net.Conn) { cancel() } log.Println("SSH connection established:", sshConn.User()) - sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user) + sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) err = sshSession.Start() if err != nil { log.Printf("SSH session ended with error: %v", err) diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index c8478a4..5807ac4 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -35,26 +35,21 @@ type forwarder struct { tunnelType types.TunnelType forwardedPort uint16 slug slug.Slug - lifecycle Lifecycle + conn ssh.Conn } -func New(slug slug.Slug) Forwarder { +func New(slug slug.Slug, conn ssh.Conn) Forwarder { return &forwarder{ listener: nil, tunnelType: types.UNKNOWN, forwardedPort: 0, slug: slug, - lifecycle: nil, + conn: conn, } } -type Lifecycle interface { - Connection() ssh.Conn -} - type Forwarder interface { SetType(tunnelType types.TunnelType) - SetLifecycle(lifecycle Lifecycle) SetForwardedPort(port uint16) SetListener(listener net.Listener) Listener() net.Listener @@ -67,10 +62,6 @@ type Forwarder interface { Close() error } -func (f *forwarder) SetLifecycle(lifecycle Lifecycle) { - f.lifecycle = lifecycle -} - func (f *forwarder) AcceptTCPConnections() { for { conn, err := f.Listener().Accept() @@ -82,7 +73,7 @@ func (f *forwarder) AcceptTCPConnections() { continue } - if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { log.Printf("Failed to set connection deadline: %v", err) if closeErr := conn.Close(); closeErr != nil { log.Printf("Failed to close connection: %v", closeErr) @@ -100,7 +91,7 @@ func (f *forwarder) AcceptTCPConnections() { resultChan := make(chan channelResult, 1) go func() { - channel, reqs, err := f.lifecycle.Connection().OpenChannel("forwarded-tcpip", payload) + channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) resultChan <- channelResult{channel, reqs, err} }() @@ -114,7 +105,7 @@ func (f *forwarder) AcceptTCPConnections() { continue } - if err := conn.SetDeadline(time.Time{}); err != nil { + if err = conn.SetDeadline(time.Time{}); err != nil { log.Printf("Failed to clear connection deadline: %v", err) } diff --git a/session/interaction/dashboard.go b/session/interaction/dashboard.go index eee08db..a24ab7c 100644 --- a/session/interaction/dashboard.go +++ b/session/interaction/dashboard.go @@ -109,7 +109,7 @@ func (m *model) dashboardView() string { MarginBottom(boxMargin). Width(boxMaxWidth) - authenticatedUser := m.interaction.lifecycle.User() + authenticatedUser := m.interaction.user userInfoStyle := lipgloss.NewStyle(). Foreground(lipgloss.Color("#FAFAFA")). diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index c3bfc8a..3c02dae 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -17,20 +17,9 @@ import ( "golang.org/x/crypto/ssh" ) -type Lifecycle interface { - Close() error - User() string -} - -type SessionRegistry interface { - Update(user string, oldKey, newKey types.SessionKey) error -} - type Interaction interface { Mode() types.Mode SetChannel(channel ssh.Channel) - SetLifecycle(lifecycle Lifecycle) - SetSessionRegistry(registry SessionRegistry) SetMode(m types.Mode) SetWH(w, h int) Start() @@ -38,17 +27,23 @@ type Interaction interface { Send(message string) error } +type SessionRegistry interface { + Update(user string, oldKey, newKey types.SessionKey) error +} + type Forwarder interface { Close() error TunnelType() types.TunnelType ForwardedPort() uint16 } +type CloseFunc func() error type interaction struct { channel ssh.Channel slug slug.Slug forwarder Forwarder - lifecycle Lifecycle + closeFunc CloseFunc + user string sessionRegistry SessionRegistry program *tea.Program ctx context.Context @@ -80,28 +75,21 @@ func (i *interaction) SetWH(w, h int) { } } -func New(slug slug.Slug, forwarder Forwarder) Interaction { +func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { ctx, cancel := context.WithCancel(context.Background()) return &interaction{ channel: nil, slug: slug, forwarder: forwarder, - lifecycle: nil, - sessionRegistry: nil, + closeFunc: closeFunc, + user: user, + sessionRegistry: sessionRegistry, program: nil, ctx: ctx, cancel: cancel, } } -func (i *interaction) SetSessionRegistry(registry SessionRegistry) { - i.sessionRegistry = registry -} - -func (i *interaction) SetLifecycle(lifecycle Lifecycle) { - i.lifecycle = lifecycle -} - func (i *interaction) SetChannel(channel ssh.Channel) { i.channel = channel } @@ -262,7 +250,9 @@ func (i *interaction) Start() { } i.program.Kill() i.program = nil - if err := m.interaction.lifecycle.Close(); err != nil { - log.Printf("Cannot close session: %s \n", err) + if i.closeFunc != nil { + if err := i.closeFunc(); err != nil { + log.Printf("Cannot close session: %s \n", err) + } } } diff --git a/session/interaction/slug.go b/session/interaction/slug.go index 7a7bdaa..6c6a97b 100644 --- a/session/interaction/slug.go +++ b/session/interaction/slug.go @@ -28,7 +28,7 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, tea.Batch(tea.ClearScreen, textinput.Blink) case "enter": inputValue := m.slugInput.Value() - if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.User(), types.SessionKey{ + if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{ Id: m.interaction.slug.String(), Type: types.HTTP, }, types.SessionKey{ diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 704b4a8..8d134a2 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -28,47 +28,44 @@ type lifecycle struct { conn ssh.Conn channel ssh.Channel forwarder Forwarder - sessionRegistry SessionRegistry slug slug.Slug startedAt time.Time + sessionRegistry SessionRegistry + portRegistry portUtil.Registry user string } -func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle { +func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle { return &lifecycle{ status: types.INITIALIZING, conn: conn, channel: nil, forwarder: forwarder, slug: slugManager, - sessionRegistry: nil, startedAt: time.Now(), + sessionRegistry: sessionRegistry, + portRegistry: port, user: user, } } -func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) { - l.sessionRegistry = registry -} - type Lifecycle interface { Connection() ssh.Conn - Channel() ssh.Channel + PortRegistry() portUtil.Registry User() string SetChannel(channel ssh.Channel) - SetSessionRegistry(registry SessionRegistry) SetStatus(status types.Status) IsActive() bool StartedAt() time.Time Close() error } -func (l *lifecycle) User() string { - return l.user +func (l *lifecycle) PortRegistry() portUtil.Registry { + return l.portRegistry } -func (l *lifecycle) Channel() ssh.Channel { - return l.channel +func (l *lifecycle) User() string { + return l.user } func (l *lifecycle) SetChannel(channel ssh.Channel) { @@ -116,7 +113,7 @@ func (l *lifecycle) Close() error { l.sessionRegistry.Remove(key) if tunnelType == types.TCP { - if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { + if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { firstErr = err } } diff --git a/session/session.go b/session/session.go index f0fd5be..be9e9ed 100644 --- a/session/session.go +++ b/session/session.go @@ -54,16 +54,11 @@ type session struct { var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session { +func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session { slugManager := slug.New() - forwarderManager := forwarder.New(slugManager) - interactionManager := interaction.New(slugManager, forwarderManager) - lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user) - - interactionManager.SetLifecycle(lifecycleManager) - forwarderManager.SetLifecycle(lifecycleManager) - interactionManager.SetSessionRegistry(sessionRegistry) - lifecycleManager.SetSessionRegistry(sessionRegistry) + forwarderManager := forwarder.New(slugManager, conn) + lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) + interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) return &session{ initialReq: initialReq, @@ -135,7 +130,7 @@ func (s *session) Start() error { tcpipReq := s.waitForTCPIPForward() if tcpipReq == nil { - err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) + err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) if err != nil { return err } @@ -234,7 +229,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { } func (s *session) HandleTCPIPForward(req *ssh.Request) { - log.Println("Port forwarding request detected") + log.Println("PortRegistry forwarding request detected") fail := func(msg string) { log.Println(msg) @@ -262,13 +257,13 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) { } if rawPortToBind > 65535 { - fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind)) + fail(fmt.Sprintf("PortRegistry %d is larger than allowed port of 65535", rawPortToBind)) return } portToBind := uint16(rawPortToBind) if isBlockedPort(portToBind) { - fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind)) + fail(fmt.Sprintf("PortRegistry %d is blocked or restricted", portToBind)) return } @@ -340,7 +335,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin s.registry.Remove(*key) } if port != 0 { - if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil { + if setErr := s.lifecycle.PortRegistry().SetPortStatus(port, false); setErr != nil { log.Printf("Failed to reset port status: %v", setErr) } } @@ -356,7 +351,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin } if portToBind == 0 { - unassigned, ok := portUtil.Default.GetUnassignedPort() + unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort() if !ok { fail("No available port") return @@ -364,15 +359,15 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin portToBind = unassigned } - if claimed := portUtil.Default.ClaimPort(portToBind); !claimed { - fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind)) + if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed { + fail(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) return } log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { - cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil) + cleanup(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind), portToBind, nil, nil) return }