refactor: explicit initialization and dependency injection
- Replace init() with config.Load() function when loading env variables - Inject portRegistry into session, server, and lifecycle structs - Inject sessionRegistry directly into interaction and lifecycle - Remove SetSessionRegistry function and global port variables - Pass ssh.Conn directly to forwarder constructor instead of lifecycle interface - Pass user and closeFunc callback to interaction constructor instead of lifecycle interface - Eliminate circular dependencies between lifecycle, forwarder, and interaction - Remove setter methods (SetLifecycle) from forwarder and interaction interfaces
This commit is contained in:
@@ -35,26 +35,21 @@ type forwarder struct {
|
||||
tunnelType types.TunnelType
|
||||
forwardedPort uint16
|
||||
slug slug.Slug
|
||||
lifecycle Lifecycle
|
||||
conn ssh.Conn
|
||||
}
|
||||
|
||||
func New(slug slug.Slug) Forwarder {
|
||||
func New(slug slug.Slug, conn ssh.Conn) Forwarder {
|
||||
return &forwarder{
|
||||
listener: nil,
|
||||
tunnelType: types.UNKNOWN,
|
||||
forwardedPort: 0,
|
||||
slug: slug,
|
||||
lifecycle: nil,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
type Lifecycle interface {
|
||||
Connection() ssh.Conn
|
||||
}
|
||||
|
||||
type Forwarder interface {
|
||||
SetType(tunnelType types.TunnelType)
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
SetForwardedPort(port uint16)
|
||||
SetListener(listener net.Listener)
|
||||
Listener() net.Listener
|
||||
@@ -67,10 +62,6 @@ type Forwarder interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (f *forwarder) SetLifecycle(lifecycle Lifecycle) {
|
||||
f.lifecycle = lifecycle
|
||||
}
|
||||
|
||||
func (f *forwarder) AcceptTCPConnections() {
|
||||
for {
|
||||
conn, err := f.Listener().Accept()
|
||||
@@ -82,7 +73,7 @@ func (f *forwarder) AcceptTCPConnections() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
if err = conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
log.Printf("Failed to set connection deadline: %v", err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("Failed to close connection: %v", closeErr)
|
||||
@@ -100,7 +91,7 @@ func (f *forwarder) AcceptTCPConnections() {
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := f.lifecycle.Connection().OpenChannel("forwarded-tcpip", payload)
|
||||
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
}()
|
||||
|
||||
@@ -114,7 +105,7 @@ func (f *forwarder) AcceptTCPConnections() {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||||
if err = conn.SetDeadline(time.Time{}); err != nil {
|
||||
log.Printf("Failed to clear connection deadline: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ func (m *model) dashboardView() string {
|
||||
MarginBottom(boxMargin).
|
||||
Width(boxMaxWidth)
|
||||
|
||||
authenticatedUser := m.interaction.lifecycle.User()
|
||||
authenticatedUser := m.interaction.user
|
||||
|
||||
userInfoStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
|
||||
@@ -17,20 +17,9 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Lifecycle interface {
|
||||
Close() error
|
||||
User() string
|
||||
}
|
||||
|
||||
type SessionRegistry interface {
|
||||
Update(user string, oldKey, newKey types.SessionKey) error
|
||||
}
|
||||
|
||||
type Interaction interface {
|
||||
Mode() types.Mode
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
SetSessionRegistry(registry SessionRegistry)
|
||||
SetMode(m types.Mode)
|
||||
SetWH(w, h int)
|
||||
Start()
|
||||
@@ -38,17 +27,23 @@ type Interaction interface {
|
||||
Send(message string) error
|
||||
}
|
||||
|
||||
type SessionRegistry interface {
|
||||
Update(user string, oldKey, newKey types.SessionKey) error
|
||||
}
|
||||
|
||||
type Forwarder interface {
|
||||
Close() error
|
||||
TunnelType() types.TunnelType
|
||||
ForwardedPort() uint16
|
||||
}
|
||||
|
||||
type CloseFunc func() error
|
||||
type interaction struct {
|
||||
channel ssh.Channel
|
||||
slug slug.Slug
|
||||
forwarder Forwarder
|
||||
lifecycle Lifecycle
|
||||
closeFunc CloseFunc
|
||||
user string
|
||||
sessionRegistry SessionRegistry
|
||||
program *tea.Program
|
||||
ctx context.Context
|
||||
@@ -80,28 +75,21 @@ func (i *interaction) SetWH(w, h int) {
|
||||
}
|
||||
}
|
||||
|
||||
func New(slug slug.Slug, forwarder Forwarder) Interaction {
|
||||
func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &interaction{
|
||||
channel: nil,
|
||||
slug: slug,
|
||||
forwarder: forwarder,
|
||||
lifecycle: nil,
|
||||
sessionRegistry: nil,
|
||||
closeFunc: closeFunc,
|
||||
user: user,
|
||||
sessionRegistry: sessionRegistry,
|
||||
program: nil,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *interaction) SetSessionRegistry(registry SessionRegistry) {
|
||||
i.sessionRegistry = registry
|
||||
}
|
||||
|
||||
func (i *interaction) SetLifecycle(lifecycle Lifecycle) {
|
||||
i.lifecycle = lifecycle
|
||||
}
|
||||
|
||||
func (i *interaction) SetChannel(channel ssh.Channel) {
|
||||
i.channel = channel
|
||||
}
|
||||
@@ -262,7 +250,9 @@ func (i *interaction) Start() {
|
||||
}
|
||||
i.program.Kill()
|
||||
i.program = nil
|
||||
if err := m.interaction.lifecycle.Close(); err != nil {
|
||||
log.Printf("Cannot close session: %s \n", err)
|
||||
if i.closeFunc != nil {
|
||||
if err := i.closeFunc(); err != nil {
|
||||
log.Printf("Cannot close session: %s \n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case "enter":
|
||||
inputValue := m.slugInput.Value()
|
||||
if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.User(), types.SessionKey{
|
||||
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
|
||||
Id: m.interaction.slug.String(),
|
||||
Type: types.HTTP,
|
||||
}, types.SessionKey{
|
||||
|
||||
@@ -28,41 +28,43 @@ type lifecycle struct {
|
||||
conn ssh.Conn
|
||||
channel ssh.Channel
|
||||
forwarder Forwarder
|
||||
sessionRegistry SessionRegistry
|
||||
slug slug.Slug
|
||||
startedAt time.Time
|
||||
sessionRegistry SessionRegistry
|
||||
portRegistry portUtil.Registry
|
||||
user string
|
||||
}
|
||||
|
||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle {
|
||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Registry, sessionRegistry SessionRegistry, user string) Lifecycle {
|
||||
return &lifecycle{
|
||||
status: types.INITIALIZING,
|
||||
conn: conn,
|
||||
channel: nil,
|
||||
forwarder: forwarder,
|
||||
slug: slugManager,
|
||||
sessionRegistry: nil,
|
||||
startedAt: time.Now(),
|
||||
sessionRegistry: sessionRegistry,
|
||||
portRegistry: port,
|
||||
user: user,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) {
|
||||
l.sessionRegistry = registry
|
||||
}
|
||||
|
||||
type Lifecycle interface {
|
||||
Connection() ssh.Conn
|
||||
Channel() ssh.Channel
|
||||
PortRegistry() portUtil.Registry
|
||||
User() string
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetSessionRegistry(registry SessionRegistry)
|
||||
SetStatus(status types.Status)
|
||||
IsActive() bool
|
||||
StartedAt() time.Time
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (l *lifecycle) PortRegistry() portUtil.Registry {
|
||||
return l.portRegistry
|
||||
}
|
||||
|
||||
func (l *lifecycle) User() string {
|
||||
return l.user
|
||||
}
|
||||
@@ -116,7 +118,7 @@ func (l *lifecycle) Close() error {
|
||||
l.sessionRegistry.Remove(key)
|
||||
|
||||
if tunnelType == types.TCP {
|
||||
if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
|
||||
if err := l.PortRegistry().SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
|
||||
+13
-18
@@ -54,16 +54,11 @@ type session struct {
|
||||
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session {
|
||||
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, portRegistry portUtil.Registry, user string) Session {
|
||||
slugManager := slug.New()
|
||||
forwarderManager := forwarder.New(slugManager)
|
||||
interactionManager := interaction.New(slugManager, forwarderManager)
|
||||
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user)
|
||||
|
||||
interactionManager.SetLifecycle(lifecycleManager)
|
||||
forwarderManager.SetLifecycle(lifecycleManager)
|
||||
interactionManager.SetSessionRegistry(sessionRegistry)
|
||||
lifecycleManager.SetSessionRegistry(sessionRegistry)
|
||||
forwarderManager := forwarder.New(slugManager, conn)
|
||||
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
|
||||
interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
|
||||
|
||||
return &session{
|
||||
initialReq: initialReq,
|
||||
@@ -135,7 +130,7 @@ func (s *session) Start() error {
|
||||
|
||||
tcpipReq := s.waitForTCPIPForward()
|
||||
if tcpipReq == nil {
|
||||
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
|
||||
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -234,7 +229,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||
}
|
||||
|
||||
func (s *session) HandleTCPIPForward(req *ssh.Request) {
|
||||
log.Println("Port forwarding request detected")
|
||||
log.Println("PortRegistry forwarding request detected")
|
||||
|
||||
fail := func(msg string) {
|
||||
log.Println(msg)
|
||||
@@ -262,13 +257,13 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) {
|
||||
}
|
||||
|
||||
if rawPortToBind > 65535 {
|
||||
fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind))
|
||||
fail(fmt.Sprintf("PortRegistry %d is larger than allowed port of 65535", rawPortToBind))
|
||||
return
|
||||
}
|
||||
|
||||
portToBind := uint16(rawPortToBind)
|
||||
if isBlockedPort(portToBind) {
|
||||
fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind))
|
||||
fail(fmt.Sprintf("PortRegistry %d is blocked or restricted", portToBind))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -340,7 +335,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
||||
s.registry.Remove(*key)
|
||||
}
|
||||
if port != 0 {
|
||||
if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil {
|
||||
if setErr := s.lifecycle.PortRegistry().SetPortStatus(port, false); setErr != nil {
|
||||
log.Printf("Failed to reset port status: %v", setErr)
|
||||
}
|
||||
}
|
||||
@@ -356,7 +351,7 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
||||
}
|
||||
|
||||
if portToBind == 0 {
|
||||
unassigned, ok := portUtil.Default.GetUnassignedPort()
|
||||
unassigned, ok := s.lifecycle.PortRegistry().GetUnassignedPort()
|
||||
if !ok {
|
||||
fail("No available port")
|
||||
return
|
||||
@@ -364,15 +359,15 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
||||
portToBind = unassigned
|
||||
}
|
||||
|
||||
if claimed := portUtil.Default.ClaimPort(portToBind); !claimed {
|
||||
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||
if claimed := s.lifecycle.PortRegistry().ClaimPort(portToBind); !claimed {
|
||||
fail(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||
if err != nil {
|
||||
cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil)
|
||||
cleanup(fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind), portToBind, nil, nil)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user