refactor: explicit initialization and dependency injection
- Replace init() with config.Load() function when loading env variables - Inject portRegistry into session, server, and lifecycle structs - Inject sessionRegistry directly into interaction and lifecycle - Remove SetSessionRegistry function and global port variables
This commit is contained in:
@@ -30,7 +30,6 @@ type Interaction interface {
|
||||
Mode() types.Mode
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
SetSessionRegistry(registry SessionRegistry)
|
||||
SetMode(m types.Mode)
|
||||
SetWH(w, h int)
|
||||
Start()
|
||||
@@ -80,24 +79,20 @@ func (i *interaction) SetWH(w, h int) {
|
||||
}
|
||||
}
|
||||
|
||||
func New(slug slug.Slug, forwarder Forwarder) Interaction {
|
||||
func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry) Interaction {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &interaction{
|
||||
channel: nil,
|
||||
slug: slug,
|
||||
forwarder: forwarder,
|
||||
lifecycle: nil,
|
||||
sessionRegistry: nil,
|
||||
sessionRegistry: sessionRegistry,
|
||||
program: nil,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *interaction) SetSessionRegistry(registry SessionRegistry) {
|
||||
i.sessionRegistry = registry
|
||||
}
|
||||
|
||||
func (i *interaction) SetLifecycle(lifecycle Lifecycle) {
|
||||
i.lifecycle = lifecycle
|
||||
}
|
||||
|
||||
@@ -28,41 +28,43 @@ type lifecycle struct {
|
||||
conn ssh.Conn
|
||||
channel ssh.Channel
|
||||
forwarder Forwarder
|
||||
sessionRegistry SessionRegistry
|
||||
slug slug.Slug
|
||||
startedAt time.Time
|
||||
sessionRegistry SessionRegistry
|
||||
portRegistry portUtil.Registry
|
||||
user string
|
||||
}
|
||||
|
||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle {
|
||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle {
|
||||
return &lifecycle{
|
||||
status: types.INITIALIZING,
|
||||
conn: conn,
|
||||
channel: nil,
|
||||
forwarder: forwarder,
|
||||
slug: slugManager,
|
||||
sessionRegistry: nil,
|
||||
startedAt: time.Now(),
|
||||
sessionRegistry: sessionRegistry,
|
||||
portRegistry: port,
|
||||
user: user,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) {
|
||||
l.sessionRegistry = registry
|
||||
}
|
||||
|
||||
type Lifecycle interface {
|
||||
Connection() ssh.Conn
|
||||
Channel() ssh.Channel
|
||||
PortRegistry() portUtil.Registry
|
||||
User() string
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetSessionRegistry(registry SessionRegistry)
|
||||
SetStatus(status types.Status)
|
||||
IsActive() bool
|
||||
StartedAt() time.Time
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (l *lifecycle) PortRegistry() portUtil.Registry {
|
||||
return l.portRegistry
|
||||
}
|
||||
|
||||
func (l *lifecycle) User() string {
|
||||
return l.user
|
||||
}
|
||||
@@ -116,7 +118,7 @@ func (l *lifecycle) Close() error {
|
||||
l.sessionRegistry.Remove(key)
|
||||
|
||||
if tunnelType == types.TCP {
|
||||
if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
|
||||
if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
|
||||
+12
-14
@@ -54,16 +54,14 @@ type session struct {
|
||||
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session {
|
||||
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session {
|
||||
slugManager := slug.New()
|
||||
forwarderManager := forwarder.New(slugManager)
|
||||
interactionManager := interaction.New(slugManager, forwarderManager)
|
||||
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user)
|
||||
interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry)
|
||||
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
|
||||
|
||||
interactionManager.SetLifecycle(lifecycleManager)
|
||||
forwarderManager.SetLifecycle(lifecycleManager)
|
||||
interactionManager.SetSessionRegistry(sessionRegistry)
|
||||
lifecycleManager.SetSessionRegistry(sessionRegistry)
|
||||
|
||||
return &session{
|
||||
initialReq: initialReq,
|
||||
@@ -135,7 +133,7 @@ func (s *session) Start() error {
|
||||
|
||||
tcpipReq := s.waitForTCPIPForward()
|
||||
if tcpipReq == nil {
|
||||
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
|
||||
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -234,7 +232,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||
}
|
||||
|
||||
func (s *session) HandleTCPIPForward(req *ssh.Request) {
|
||||
log.Println("Port forwarding request detected")
|
||||
log.Println("PortRegistry forwarding request detected")
|
||||
|
||||
fail := func(msg string) {
|
||||
log.Println(msg)
|
||||
@@ -262,13 +260,13 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) {
|
||||
}
|
||||
|
||||
if rawPortToBind > 65535 {
|
||||
fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind))
|
||||
fail(fmt.Sprintf("PortRegistry %d is larger than allowed port of 65535", rawPortToBind))
|
||||
return
|
||||
}
|
||||
|
||||
portToBind := uint16(rawPortToBind)
|
||||
if isBlockedPort(portToBind) {
|
||||
fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind))
|
||||
fail(fmt.Sprintf("PortRegistry %d is blocked or restricted", portToBind))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -340,7 +338,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
||||
s.registry.Remove(*key)
|
||||
}
|
||||
if port != 0 {
|
||||
if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil {
|
||||
if setErr := s.lifecycle.PortRegistry().SetPortStatus(port, false); setErr != nil {
|
||||
log.Printf("Failed to reset port status: %v", setErr)
|
||||
}
|
||||
}
|
||||
@@ -356,7 +354,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
||||
}
|
||||
|
||||
if portToBind == 0 {
|
||||
unassigned, ok := portUtil.Default.GetUnassignedPort()
|
||||
unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort()
|
||||
if !ok {
|
||||
fail("No available port")
|
||||
return
|
||||
@@ -364,15 +362,15 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
||||
portToBind = unassigned
|
||||
}
|
||||
|
||||
if claimed := portUtil.Default.ClaimPort(portToBind); !claimed {
|
||||
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||
if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed {
|
||||
fail(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||
if err != nil {
|
||||
cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil)
|
||||
cleanup(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind), portToBind, nil, nil)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user