refactor: inject SessionRegistry interface instead of individual functions

This commit is contained in:
2026-01-05 16:44:03 +07:00
parent 6de0a618ee
commit 767737a719
6 changed files with 83 additions and 75 deletions

2
go.mod
View File

@@ -52,4 +52,4 @@ require (
golang.org/x/text v0.32.0 // indirect golang.org/x/text v0.32.0 // indirect
golang.org/x/tools v0.40.0 // indirect golang.org/x/tools v0.40.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
) )

View File

@@ -217,6 +217,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod
} }
switch recv.GetType() { switch recv.GetType() {
case proto.EventType_SLUG_CHANGE: case proto.EventType_SLUG_CHANGE:
user := recv.GetSlugEvent().GetUser()
oldSlug := recv.GetSlugEvent().GetOld() oldSlug := recv.GetSlugEvent().GetOld()
newSlug := recv.GetSlugEvent().GetNew() newSlug := recv.GetSlugEvent().GetNew()
var userSession *session.SSHSession var userSession *session.SSHSession
@@ -242,7 +243,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod
} }
continue continue
} }
err = c.sessionRegistry.Update(types.SessionKey{ err = c.sessionRegistry.Update(user, types.SessionKey{
Id: oldSlug, Id: oldSlug,
Type: types.HTTP, Type: types.HTTP,
}, types.SessionKey{ }, types.SessionKey{

View File

@@ -23,15 +23,20 @@ import (
type Lifecycle interface { type Lifecycle interface {
Close() error Close() error
GetUser() string
}
type SessionRegistry interface {
Update(user string, oldKey, newKey types.SessionKey) error
} }
type Controller interface { type Controller interface {
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle) SetLifecycle(lifecycle Lifecycle)
SetSlugModificator(func(oldSlug, newSlug string) error)
Start() Start()
SetWH(w, h int) SetWH(w, h int)
Redraw() Redraw()
SetSessionRegistry(registry SessionRegistry)
} }
type Forwarder interface { type Forwarder interface {
@@ -41,14 +46,14 @@ type Forwarder interface {
} }
type Interaction struct { type Interaction struct {
channel ssh.Channel channel ssh.Channel
slugManager slug.Manager slugManager slug.Manager
forwarder Forwarder forwarder Forwarder
lifecycle Lifecycle lifecycle Lifecycle
updateClientSlug func(oldSlug, newSlug string) error sessionRegistry SessionRegistry
program *tea.Program program *tea.Program
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
} }
func (i *Interaction) SetWH(w, h int) { func (i *Interaction) SetWH(w, h int) {
@@ -102,17 +107,21 @@ type tickMsg time.Time
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction { func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Interaction{ return &Interaction{
channel: nil, channel: nil,
slugManager: slugManager, slugManager: slugManager,
forwarder: forwarder, forwarder: forwarder,
lifecycle: nil, lifecycle: nil,
updateClientSlug: nil, sessionRegistry: nil,
program: nil, program: nil,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }
} }
func (i *Interaction) SetSessionRegistry(registry SessionRegistry) {
i.sessionRegistry = registry
}
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
i.lifecycle = lifecycle i.lifecycle = lifecycle
} }
@@ -121,10 +130,6 @@ func (i *Interaction) SetChannel(channel ssh.Channel) {
i.channel = channel i.channel = channel
} }
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) error) {
i.updateClientSlug = modificator
}
func (i *Interaction) Stop() { func (i *Interaction) Stop() {
if i.cancel != nil { if i.cancel != nil {
i.cancel() i.cancel()
@@ -218,7 +223,13 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter": case "enter":
inputValue := m.slugInput.Value() inputValue := m.slugInput.Value()
if err := m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue); err != nil { if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.GetUser(), types.SessionKey{
Id: m.interaction.slugManager.Get(),
Type: types.HTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.HTTP,
}); err != nil {
m.slugError = err.Error() m.slugError = err.Error()
return m, nil return m, nil
} }

View File

@@ -19,30 +19,36 @@ type Forwarder interface {
GetForwardedPort() uint16 GetForwardedPort() uint16
} }
type Lifecycle struct { type SessionRegistry interface {
status types.Status Remove(key types.SessionKey)
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
slugManager slug.Manager
unregisterClient func(key types.SessionKey)
startedAt time.Time
} }
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle { type Lifecycle struct {
status types.Status
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
sessionRegistry SessionRegistry
slugManager slug.Manager
startedAt time.Time
user string
}
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager, user string) *Lifecycle {
return &Lifecycle{ return &Lifecycle{
status: types.INITIALIZING, status: types.INITIALIZING,
conn: conn, conn: conn,
channel: nil, channel: nil,
forwarder: forwarder, forwarder: forwarder,
slugManager: slugManager, slugManager: slugManager,
unregisterClient: nil, sessionRegistry: nil,
startedAt: time.Now(), startedAt: time.Now(),
user: user,
} }
} }
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(key types.SessionKey)) { func (l *Lifecycle) SetSessionRegistry(registry SessionRegistry) {
l.unregisterClient = unregisterClient l.sessionRegistry = registry
} }
type SessionLifecycle interface { type SessionLifecycle interface {
@@ -50,12 +56,17 @@ type SessionLifecycle interface {
SetStatus(status types.Status) SetStatus(status types.Status)
GetConnection() ssh.Conn GetConnection() ssh.Conn
GetChannel() ssh.Channel GetChannel() ssh.Channel
GetUser() string
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetUnregisterClient(unregisterClient func(key types.SessionKey)) SetSessionRegistry(registry SessionRegistry)
IsActive() bool IsActive() bool
StartedAt() time.Time StartedAt() time.Time
} }
func (l *Lifecycle) GetUser() string {
return l.user
}
func (l *Lifecycle) GetChannel() ssh.Channel { func (l *Lifecycle) GetChannel() ssh.Channel {
return l.channel return l.channel
} }
@@ -94,13 +105,13 @@ func (l *Lifecycle) Close() error {
} }
clientSlug := l.slugManager.Get() clientSlug := l.slugManager.Get()
if clientSlug != "" && l.unregisterClient != nil { if clientSlug != "" && l.sessionRegistry.Remove != nil {
key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()} key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()}
l.unregisterClient(key) l.sessionRegistry.Remove(key)
} }
if l.forwarder.GetTunnelType() == types.TCP { if l.forwarder.GetTunnelType() == types.TCP {
err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false) err = portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -10,7 +10,7 @@ type Key = types.SessionKey
type Registry interface { type Registry interface {
Get(key Key) (session *SSHSession, err error) Get(key Key) (session *SSHSession, err error)
Update(oldKey, newKey Key) error Update(user string, oldKey, newKey Key) error
Register(key Key, session *SSHSession) (success bool) Register(key Key, session *SSHSession) (success bool)
Remove(key Key) Remove(key Key)
GetAllSessionFromUser(user string) []*SSHSession GetAllSessionFromUser(user string) []*SSHSession
@@ -44,7 +44,7 @@ func (r *registry) Get(key Key) (session *SSHSession, err error) {
return client, nil return client, nil
} }
func (r *registry) Update(oldKey, newKey Key) error { func (r *registry) Update(user string, oldKey, newKey Key) error {
if oldKey.Type != newKey.Type { if oldKey.Type != newKey.Type {
return fmt.Errorf("tunnel type cannot change") return fmt.Errorf("tunnel type cannot change")
} }
@@ -64,30 +64,24 @@ func (r *registry) Update(oldKey, newKey Key) error {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
userID, ok := r.slugIndex[oldKey]
if !ok {
return fmt.Errorf("session not found")
}
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
return fmt.Errorf("someone already uses this subdomain") return fmt.Errorf("someone already uses this subdomain")
} }
client, ok := r.byUser[user][oldKey]
client, ok := r.byUser[userID][oldKey]
if !ok { if !ok {
return fmt.Errorf("session not found") return fmt.Errorf("session not found")
} }
delete(r.byUser[userID], oldKey) delete(r.byUser[user], oldKey)
delete(r.slugIndex, oldKey) delete(r.slugIndex, oldKey)
client.slugManager.Set(newKey.Id) client.slugManager.Set(newKey.Id)
r.slugIndex[newKey] = userID r.slugIndex[newKey] = user
if r.byUser[userID] == nil { if r.byUser[user] == nil {
r.byUser[userID] = make(map[Key]*SSHSession) r.byUser[user] = make(map[Key]*SSHSession)
} }
r.byUser[userID][newKey] = client r.byUser[user][newKey] = client
return nil return nil
} }
@@ -99,7 +93,7 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) {
return false return false
} }
userID := session.userID userID := session.lifecycle.GetUser()
if r.byUser[userID] == nil { if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]*SSHSession) r.byUser[userID] = make(map[Key]*SSHSession)
} }

View File

@@ -9,7 +9,6 @@ import (
"tunnel_pls/session/interaction" "tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle" "tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -29,7 +28,6 @@ type SSHSession struct {
forwarder forwarder.ForwardingController forwarder forwarder.ForwardingController
slugManager slug.Manager slugManager slug.Manager
registry Registry registry Registry
userID string
} }
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle { func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
@@ -48,22 +46,16 @@ func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager return s.slugManager
} }
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession { func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) *SSHSession {
slugManager := slug.NewManager() slugManager := slug.NewManager()
forwarderManager := forwarder.NewForwarder(slugManager) forwarderManager := forwarder.NewForwarder(slugManager)
interactionManager := interaction.NewInteraction(slugManager, forwarderManager) interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager) lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager, user)
interactionManager.SetLifecycle(lifecycleManager) interactionManager.SetLifecycle(lifecycleManager)
interactionManager.SetSlugModificator(func(oldSlug, newSlug string) error {
oldKey := types.SessionKey{Id: oldSlug, Type: forwarderManager.GetTunnelType()}
newKey := types.SessionKey{Id: newSlug, Type: forwarderManager.GetTunnelType()}
return sessionRegistry.Update(oldKey, newKey)
})
forwarderManager.SetLifecycle(lifecycleManager) forwarderManager.SetLifecycle(lifecycleManager)
lifecycleManager.SetUnregisterClient(func(key types.SessionKey) { interactionManager.SetSessionRegistry(sessionRegistry)
sessionRegistry.Remove(key) lifecycleManager.SetSessionRegistry(sessionRegistry)
})
return &SSHSession{ return &SSHSession{
initialReq: forwardingReq, initialReq: forwardingReq,
@@ -73,7 +65,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
forwarder: forwarderManager, forwarder: forwarderManager,
slugManager: slugManager, slugManager: slugManager,
registry: sessionRegistry, registry: sessionRegistry,
userID: userID,
} }
} }
@@ -89,7 +80,7 @@ func (s *SSHSession) Detail() Detail {
return Detail{ return Detail{
ForwardingType: string(s.forwarder.GetTunnelType()), ForwardingType: string(s.forwarder.GetTunnelType()),
Slug: s.slugManager.Get(), Slug: s.slugManager.Get(),
UserID: s.userID, UserID: s.lifecycle.GetUser(),
Active: s.lifecycle.IsActive(), Active: s.lifecycle.IsActive(),
StartedAt: s.lifecycle.StartedAt(), StartedAt: s.lifecycle.StartedAt(),
} }