refactor: explicit initialization and dependency injection
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 10m2s

- Replace init() with config.Load() function when loading env variables
- Inject portRegistry into session, server, and lifecycle structs
- Inject sessionRegistry directly into interaction and lifecycle
- Remove SetSessionRegistry function and global port variables
This commit is contained in:
2026-01-18 20:42:10 +07:00
parent 2b9bca65d5
commit 317ab2dbe4
7 changed files with 81 additions and 76 deletions
+3 -5
View File
@@ -1,19 +1,17 @@
package config package config
import ( import (
"log"
"os" "os"
"strconv" "strconv"
"github.com/joho/godotenv" "github.com/joho/godotenv"
) )
func init() { func Load() error {
if _, err := os.Stat(".env"); err == nil { if _, err := os.Stat(".env"); err == nil {
if err := godotenv.Load(".env"); err != nil { return godotenv.Load(".env")
log.Printf("Warning: Failed to load .env file: %s", err)
}
} }
return nil
} }
func Getenv(key, defaultValue string) string { func Getenv(key, defaultValue string) string {
+9 -32
View File
@@ -3,53 +3,30 @@ package port
import ( import (
"fmt" "fmt"
"sort" "sort"
"strconv"
"strings"
"sync" "sync"
"tunnel_pls/internal/config"
) )
type Manager interface { type Registry interface {
AddPortRange(startPort, endPort uint16) error AddPortRange(startPort, endPort uint16) error
GetUnassignedPort() (uint16, bool) GetUnassignedPort() (uint16, bool)
SetPortStatus(port uint16, assigned bool) error SetPortStatus(port uint16, assigned bool) error
ClaimPort(port uint16) (claimed bool) ClaimPort(port uint16) (claimed bool)
} }
type manager struct { type registry struct {
mu sync.RWMutex mu sync.RWMutex
ports map[uint16]bool ports map[uint16]bool
sortedPorts []uint16 sortedPorts []uint16
} }
var Default Manager = &manager{ func New() Registry {
return &registry{
ports: make(map[uint16]bool), ports: make(map[uint16]bool),
sortedPorts: []uint16{}, sortedPorts: []uint16{},
}
} }
func init() { func (pm *registry) AddPortRange(startPort, endPort uint16) error {
rawRange := config.Getenv("ALLOWED_PORTS", "")
if rawRange == "" {
return
}
splitRange := strings.Split(rawRange, "-")
if len(splitRange) != 2 {
return
}
start, err := strconv.ParseUint(splitRange[0], 10, 16)
if err != nil {
return
}
end, err := strconv.ParseUint(splitRange[1], 10, 16)
if err != nil {
return
}
_ = Default.AddPortRange(uint16(start), uint16(end))
}
func (pm *manager) AddPortRange(startPort, endPort uint16) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
@@ -68,7 +45,7 @@ func (pm *manager) AddPortRange(startPort, endPort uint16) error {
return nil return nil
} }
func (pm *manager) GetUnassignedPort() (uint16, bool) { func (pm *registry) GetUnassignedPort() (uint16, bool) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
@@ -80,7 +57,7 @@ func (pm *manager) GetUnassignedPort() (uint16, bool) {
return 0, false return 0, false
} }
func (pm *manager) SetPortStatus(port uint16, assigned bool) error { func (pm *registry) SetPortStatus(port uint16, assigned bool) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
@@ -88,7 +65,7 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
return nil return nil
} }
func (pm *manager) ClaimPort(port uint16) (claimed bool) { func (pm *registry) ClaimPort(port uint16) (claimed bool) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
+35 -3
View File
@@ -8,12 +8,14 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
"strconv"
"strings" "strings"
"syscall" "syscall"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/key" "tunnel_pls/internal/key"
"tunnel_pls/internal/port"
"tunnel_pls/server" "tunnel_pls/server"
"tunnel_pls/session" "tunnel_pls/session"
"tunnel_pls/version" "tunnel_pls/version"
@@ -32,6 +34,12 @@ func main() {
log.Printf("Starting %s", version.GetVersion()) log.Printf("Starting %s", version.GetVersion())
err := config.Load()
if err != nil {
log.Fatalf("Failed to load configuration: %s", err)
return
}
mode := strings.ToLower(config.Getenv("MODE", "standalone")) mode := strings.ToLower(config.Getenv("MODE", "standalone"))
isNodeMode := mode == "node" isNodeMode := mode == "node"
@@ -41,7 +49,7 @@ func main() {
go func() { go func() {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil { if err = http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err) log.Printf("pprof server error: %v", err)
} }
}() }()
@@ -53,7 +61,7 @@ func main() {
} }
sshKeyPath := "certs/ssh/id_rsa" sshKeyPath := "certs/ssh/id_rsa"
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
log.Fatalf("Failed to generate SSH key: %s", err) log.Fatalf("Failed to generate SSH key: %s", err)
} }
@@ -107,9 +115,33 @@ func main() {
}() }()
} }
portManager := port.New()
rawRange := config.Getenv("ALLOWED_PORTS", "")
if rawRange != "" {
splitRange := strings.Split(rawRange, "-")
if len(splitRange) == 2 {
var start, end uint64
start, err = strconv.ParseUint(splitRange[0], 10, 16)
if err != nil {
log.Fatalf("Failed to parse start port: %s", err)
}
end, err = strconv.ParseUint(splitRange[1], 10, 16)
if err != nil {
log.Fatalf("Failed to parse end port: %s", err)
}
if err = portManager.AddPortRange(uint16(start), uint16(end)); err != nil {
log.Fatalf("Failed to add port range: %s", err)
}
log.Printf("PortRegistry range configured: %d-%d", start, end)
} else {
log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange)
}
}
var app server.Server var app server.Server
go func() { go func() {
app, err = server.New(sshConfig, sessionRegistry, grpcClient) app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager)
if err != nil { if err != nil {
errChan <- fmt.Errorf("failed to start server: %s", err) errChan <- fmt.Errorf("failed to start server: %s", err)
return return
+7 -4
View File
@@ -9,6 +9,7 @@ import (
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/port"
"tunnel_pls/session" "tunnel_pls/session"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@@ -21,11 +22,12 @@ type Server interface {
type server struct { type server struct {
listener net.Listener listener net.Listener
config *ssh.ServerConfig config *ssh.ServerConfig
sessionRegistry session.Registry
grpcClient client.Client grpcClient client.Client
sessionRegistry session.Registry
portRegistry port.Registry
} }
func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client) (Server, error) { func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client, portRegistry port.Registry) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200")))
if err != nil { if err != nil {
log.Fatalf("failed to listen on port 2200: %v", err) log.Fatalf("failed to listen on port 2200: %v", err)
@@ -50,8 +52,9 @@ func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClie
return &server{ return &server{
listener: listener, listener: listener,
config: sshConfig, config: sshConfig,
sessionRegistry: sessionRegistry,
grpcClient: grpcClient, grpcClient: grpcClient,
sessionRegistry: sessionRegistry,
portRegistry: portRegistry,
}, nil }, nil
} }
@@ -103,7 +106,7 @@ func (s *server) handleConnection(conn net.Conn) {
cancel() cancel()
} }
log.Println("SSH connection established:", sshConn.User()) log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user) sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
err = sshSession.Start() err = sshSession.Start()
if err != nil { if err != nil {
log.Printf("SSH session ended with error: %v", err) log.Printf("SSH session ended with error: %v", err)
+2 -7
View File
@@ -30,7 +30,6 @@ type Interaction interface {
Mode() types.Mode Mode() types.Mode
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle) SetLifecycle(lifecycle Lifecycle)
SetSessionRegistry(registry SessionRegistry)
SetMode(m types.Mode) SetMode(m types.Mode)
SetWH(w, h int) SetWH(w, h int)
Start() Start()
@@ -80,24 +79,20 @@ func (i *interaction) SetWH(w, h int) {
} }
} }
func New(slug slug.Slug, forwarder Forwarder) Interaction { func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry) Interaction {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &interaction{ return &interaction{
channel: nil, channel: nil,
slug: slug, slug: slug,
forwarder: forwarder, forwarder: forwarder,
lifecycle: nil, lifecycle: nil,
sessionRegistry: nil, sessionRegistry: sessionRegistry,
program: nil, program: nil,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }
} }
func (i *interaction) SetSessionRegistry(registry SessionRegistry) {
i.sessionRegistry = registry
}
func (i *interaction) SetLifecycle(lifecycle Lifecycle) { func (i *interaction) SetLifecycle(lifecycle Lifecycle) {
i.lifecycle = lifecycle i.lifecycle = lifecycle
} }
+11 -9
View File
@@ -28,41 +28,43 @@ type lifecycle struct {
conn ssh.Conn conn ssh.Conn
channel ssh.Channel channel ssh.Channel
forwarder Forwarder forwarder Forwarder
sessionRegistry SessionRegistry
slug slug.Slug slug slug.Slug
startedAt time.Time startedAt time.Time
sessionRegistry SessionRegistry
portRegistry portUtil.Registry
user string user string
} }
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle { func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle {
return &lifecycle{ return &lifecycle{
status: types.INITIALIZING, status: types.INITIALIZING,
conn: conn, conn: conn,
channel: nil, channel: nil,
forwarder: forwarder, forwarder: forwarder,
slug: slugManager, slug: slugManager,
sessionRegistry: nil,
startedAt: time.Now(), startedAt: time.Now(),
sessionRegistry: sessionRegistry,
portRegistry: port,
user: user, user: user,
} }
} }
func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) {
l.sessionRegistry = registry
}
type Lifecycle interface { type Lifecycle interface {
Connection() ssh.Conn Connection() ssh.Conn
Channel() ssh.Channel Channel() ssh.Channel
PortRegistry() portUtil.Registry
User() string User() string
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetSessionRegistry(registry SessionRegistry)
SetStatus(status types.Status) SetStatus(status types.Status)
IsActive() bool IsActive() bool
StartedAt() time.Time StartedAt() time.Time
Close() error Close() error
} }
func (l *lifecycle) PortRegistry() portUtil.Registry {
return l.portRegistry
}
func (l *lifecycle) User() string { func (l *lifecycle) User() string {
return l.user return l.user
} }
@@ -116,7 +118,7 @@ func (l *lifecycle) Close() error {
l.sessionRegistry.Remove(key) l.sessionRegistry.Remove(key)
if tunnelType == types.TCP { if tunnelType == types.TCP {
if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
firstErr = err firstErr = err
} }
} }
+12 -14
View File
@@ -54,16 +54,14 @@ type session struct {
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session { func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session {
slugManager := slug.New() slugManager := slug.New()
forwarderManager := forwarder.New(slugManager) forwarderManager := forwarder.New(slugManager)
interactionManager := interaction.New(slugManager, forwarderManager) interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
interactionManager.SetLifecycle(lifecycleManager) interactionManager.SetLifecycle(lifecycleManager)
forwarderManager.SetLifecycle(lifecycleManager) forwarderManager.SetLifecycle(lifecycleManager)
interactionManager.SetSessionRegistry(sessionRegistry)
lifecycleManager.SetSessionRegistry(sessionRegistry)
return &session{ return &session{
initialReq: initialReq, initialReq: initialReq,
@@ -135,7 +133,7 @@ func (s *session) Start() error {
tcpipReq := s.waitForTCPIPForward() tcpipReq := s.waitForTCPIPForward()
if tcpipReq == nil { if tcpipReq == nil {
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
if err != nil { if err != nil {
return err return err
} }
@@ -234,7 +232,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
} }
func (s *session) HandleTCPIPForward(req *ssh.Request) { func (s *session) HandleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected") log.Println("PortRegistry forwarding request detected")
fail := func(msg string) { fail := func(msg string) {
log.Println(msg) log.Println(msg)
@@ -262,13 +260,13 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) {
} }
if rawPortToBind > 65535 { if rawPortToBind > 65535 {
fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind)) fail(fmt.Sprintf("PortRegistry %d is larger than allowed port of 65535", rawPortToBind))
return return
} }
portToBind := uint16(rawPortToBind) portToBind := uint16(rawPortToBind)
if isBlockedPort(portToBind) { if isBlockedPort(portToBind) {
fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind)) fail(fmt.Sprintf("PortRegistry %d is blocked or restricted", portToBind))
return return
} }
@@ -340,7 +338,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
s.registry.Remove(*key) s.registry.Remove(*key)
} }
if port != 0 { if port != 0 {
if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil { if setErr := s.lifecycle.PortRegistry().SetPortStatus(port, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr) log.Printf("Failed to reset port status: %v", setErr)
} }
} }
@@ -356,7 +354,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
} }
if portToBind == 0 { if portToBind == 0 {
unassigned, ok := portUtil.Default.GetUnassignedPort() unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort()
if !ok { if !ok {
fail("No available port") fail("No available port")
return return
@@ -364,15 +362,15 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
portToBind = unassigned portToBind = unassigned
} }
if claimed := portUtil.Default.ClaimPort(portToBind); !claimed { if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed {
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind)) fail(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
return return
} }
log.Printf("Requested forwarding on %s:%d", addr, portToBind) log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil { if err != nil {
cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil) cleanup(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind), portToBind, nil, nil)
return return
} }