refactor(session): add registry to manage SSH sessions
All checks were successful
renovate / renovate (push) Successful in 39s
Docker Build and Push / build-and-push-branches (push) Successful in 4m27s
Docker Build and Push / build-and-push-tags (push) Successful in 4m22s

- Implement thread-safe session registry with sync.RWMutex
- Add Registry interface for session management operations
- Support Get, Register, Update, and Remove session operations
- Enable dynamic slug updates for existing sessions
This commit is contained in:
2025-12-31 17:47:35 +07:00
parent acd02aadd3
commit f8a6f0bafe
10 changed files with 218 additions and 224 deletions

View File

@@ -1,8 +1,8 @@
package session
import (
"fmt"
"log"
"sync"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/session/forwarder"
@@ -13,11 +13,6 @@ import (
"golang.org/x/crypto/ssh"
)
var (
clientsMutex sync.RWMutex
Clients = make(map[string]*SSHSession)
)
type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request)
@@ -26,10 +21,13 @@ type Session interface {
}
type SSHSession struct {
lifecycle lifecycle.SessionLifecycle
interaction interaction.Controller
forwarder forwarder.ForwardingController
slugManager slug.Manager
initialReq <-chan *ssh.Request
sshReqChannel <-chan ssh.NewChannel
lifecycle lifecycle.SessionLifecycle
interaction interaction.Controller
forwarder forwarder.ForwardingController
slugManager slug.Manager
registry Registry
}
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
@@ -48,55 +46,64 @@ func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager
}
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry) *SSHSession {
slugManager := slug.NewManager()
forwarderManager := forwarder.NewForwarder(slugManager)
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager)
interactionManager.SetLifecycle(lifecycleManager)
interactionManager.SetSlugModificator(updateClientSlug)
interactionManager.SetSlugModificator(sessionRegistry.Update)
forwarderManager.SetLifecycle(lifecycleManager)
lifecycleManager.SetUnregisterClient(unregisterClient)
lifecycleManager.SetUnregisterClient(sessionRegistry.Remove)
session := &SSHSession{
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slugManager: slugManager,
}
var once sync.Once
for channel := range sshChan {
ch, reqs, err := channel.Accept()
if err != nil {
log.Printf("failed to accept channel: %v", err)
continue
}
once.Do(func() {
session.lifecycle.SetChannel(ch)
session.interaction.SetChannel(ch)
tcpipReq := session.waitForTCPIPForward(forwardingReq)
if tcpipReq == nil {
log.Printf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))
if err := session.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
go session.HandleTCPIPForward(tcpipReq)
})
session.HandleGlobalRequest(reqs)
}
if err := session.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return &SSHSession{
initialReq: forwardingReq,
sshReqChannel: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slugManager: slugManager,
registry: sessionRegistry,
}
}
func (s *SSHSession) waitForTCPIPForward(forwardingReq <-chan *ssh.Request) *ssh.Request {
func (s *SSHSession) Start() error {
channel := <-s.sshReqChannel
ch, reqs, err := channel.Accept()
if err != nil {
log.Printf("failed to accept channel: %v", err)
return err
}
go s.HandleGlobalRequest(reqs)
tcpipReq := s.waitForTCPIPForward()
if tcpipReq == nil {
_, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))))
if err != nil {
return err
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("No forwarding Request")
}
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
go s.HandleTCPIPForward(tcpipReq)
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return err
}
return nil
}
func (s *SSHSession) waitForTCPIPForward() *ssh.Request {
select {
case req, ok := <-forwardingReq:
case req, ok := <-s.initialReq:
if !ok {
log.Println("Forwarding request channel closed")
return nil
@@ -114,41 +121,3 @@ func (s *SSHSession) waitForTCPIPForward(forwardingReq <-chan *ssh.Request) *ssh
return nil
}
}
func updateClientSlug(oldSlug, newSlug string) bool {
clientsMutex.Lock()
defer clientsMutex.Unlock()
if _, exists := Clients[newSlug]; exists && newSlug != oldSlug {
return false
}
client, ok := Clients[oldSlug]
if !ok {
return false
}
delete(Clients, oldSlug)
client.slugManager.Set(newSlug)
Clients[newSlug] = client
return true
}
func registerClient(slug string, session *SSHSession) bool {
clientsMutex.Lock()
defer clientsMutex.Unlock()
if _, exists := Clients[slug]; exists {
return false
}
Clients[slug] = session
return true
}
func unregisterClient(slug string) {
clientsMutex.Lock()
defer clientsMutex.Unlock()
delete(Clients, slug)
}