refactor: inject SessionRegistry interface instead of individual functions

This commit is contained in:
2026-01-05 16:44:03 +07:00
parent 6de0a618ee
commit 767737a719
6 changed files with 83 additions and 75 deletions

2
go.mod
View File

@@ -52,4 +52,4 @@ require (
golang.org/x/text v0.32.0 // indirect
golang.org/x/tools v0.40.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
)
)

View File

@@ -217,6 +217,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod
}
switch recv.GetType() {
case proto.EventType_SLUG_CHANGE:
user := recv.GetSlugEvent().GetUser()
oldSlug := recv.GetSlugEvent().GetOld()
newSlug := recv.GetSlugEvent().GetNew()
var userSession *session.SSHSession
@@ -242,7 +243,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod
}
continue
}
err = c.sessionRegistry.Update(types.SessionKey{
err = c.sessionRegistry.Update(user, types.SessionKey{
Id: oldSlug,
Type: types.HTTP,
}, types.SessionKey{

View File

@@ -23,15 +23,20 @@ import (
type Lifecycle interface {
Close() error
GetUser() string
}
type SessionRegistry interface {
Update(user string, oldKey, newKey types.SessionKey) error
}
type Controller interface {
SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle)
SetSlugModificator(func(oldSlug, newSlug string) error)
Start()
SetWH(w, h int)
Redraw()
SetSessionRegistry(registry SessionRegistry)
}
type Forwarder interface {
@@ -41,14 +46,14 @@ type Forwarder interface {
}
type Interaction struct {
channel ssh.Channel
slugManager slug.Manager
forwarder Forwarder
lifecycle Lifecycle
updateClientSlug func(oldSlug, newSlug string) error
program *tea.Program
ctx context.Context
cancel context.CancelFunc
channel ssh.Channel
slugManager slug.Manager
forwarder Forwarder
lifecycle Lifecycle
sessionRegistry SessionRegistry
program *tea.Program
ctx context.Context
cancel context.CancelFunc
}
func (i *Interaction) SetWH(w, h int) {
@@ -102,17 +107,21 @@ type tickMsg time.Time
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
ctx, cancel := context.WithCancel(context.Background())
return &Interaction{
channel: nil,
slugManager: slugManager,
forwarder: forwarder,
lifecycle: nil,
updateClientSlug: nil,
program: nil,
ctx: ctx,
cancel: cancel,
channel: nil,
slugManager: slugManager,
forwarder: forwarder,
lifecycle: nil,
sessionRegistry: nil,
program: nil,
ctx: ctx,
cancel: cancel,
}
}
func (i *Interaction) SetSessionRegistry(registry SessionRegistry) {
i.sessionRegistry = registry
}
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
i.lifecycle = lifecycle
}
@@ -121,10 +130,6 @@ func (i *Interaction) SetChannel(channel ssh.Channel) {
i.channel = channel
}
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) error) {
i.updateClientSlug = modificator
}
func (i *Interaction) Stop() {
if i.cancel != nil {
i.cancel()
@@ -218,7 +223,13 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter":
inputValue := m.slugInput.Value()
if err := m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue); err != nil {
if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.GetUser(), types.SessionKey{
Id: m.interaction.slugManager.Get(),
Type: types.HTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.HTTP,
}); err != nil {
m.slugError = err.Error()
return m, nil
}

View File

@@ -19,30 +19,36 @@ type Forwarder interface {
GetForwardedPort() uint16
}
type Lifecycle struct {
status types.Status
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
slugManager slug.Manager
unregisterClient func(key types.SessionKey)
startedAt time.Time
type SessionRegistry interface {
Remove(key types.SessionKey)
}
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
type Lifecycle struct {
status types.Status
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
sessionRegistry SessionRegistry
slugManager slug.Manager
startedAt time.Time
user string
}
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager, user string) *Lifecycle {
return &Lifecycle{
status: types.INITIALIZING,
conn: conn,
channel: nil,
forwarder: forwarder,
slugManager: slugManager,
unregisterClient: nil,
startedAt: time.Now(),
status: types.INITIALIZING,
conn: conn,
channel: nil,
forwarder: forwarder,
slugManager: slugManager,
sessionRegistry: nil,
startedAt: time.Now(),
user: user,
}
}
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(key types.SessionKey)) {
l.unregisterClient = unregisterClient
func (l *Lifecycle) SetSessionRegistry(registry SessionRegistry) {
l.sessionRegistry = registry
}
type SessionLifecycle interface {
@@ -50,12 +56,17 @@ type SessionLifecycle interface {
SetStatus(status types.Status)
GetConnection() ssh.Conn
GetChannel() ssh.Channel
GetUser() string
SetChannel(channel ssh.Channel)
SetUnregisterClient(unregisterClient func(key types.SessionKey))
SetSessionRegistry(registry SessionRegistry)
IsActive() bool
StartedAt() time.Time
}
func (l *Lifecycle) GetUser() string {
return l.user
}
func (l *Lifecycle) GetChannel() ssh.Channel {
return l.channel
}
@@ -94,13 +105,13 @@ func (l *Lifecycle) Close() error {
}
clientSlug := l.slugManager.Get()
if clientSlug != "" && l.unregisterClient != nil {
if clientSlug != "" && l.sessionRegistry.Remove != nil {
key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()}
l.unregisterClient(key)
l.sessionRegistry.Remove(key)
}
if l.forwarder.GetTunnelType() == types.TCP {
err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
err = portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
if err != nil {
return err
}

View File

@@ -10,7 +10,7 @@ type Key = types.SessionKey
type Registry interface {
Get(key Key) (session *SSHSession, err error)
Update(oldKey, newKey Key) error
Update(user string, oldKey, newKey Key) error
Register(key Key, session *SSHSession) (success bool)
Remove(key Key)
GetAllSessionFromUser(user string) []*SSHSession
@@ -44,7 +44,7 @@ func (r *registry) Get(key Key) (session *SSHSession, err error) {
return client, nil
}
func (r *registry) Update(oldKey, newKey Key) error {
func (r *registry) Update(user string, oldKey, newKey Key) error {
if oldKey.Type != newKey.Type {
return fmt.Errorf("tunnel type cannot change")
}
@@ -64,30 +64,24 @@ func (r *registry) Update(oldKey, newKey Key) error {
r.mu.Lock()
defer r.mu.Unlock()
userID, ok := r.slugIndex[oldKey]
if !ok {
return fmt.Errorf("session not found")
}
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
return fmt.Errorf("someone already uses this subdomain")
}
client, ok := r.byUser[userID][oldKey]
client, ok := r.byUser[user][oldKey]
if !ok {
return fmt.Errorf("session not found")
}
delete(r.byUser[userID], oldKey)
delete(r.byUser[user], oldKey)
delete(r.slugIndex, oldKey)
client.slugManager.Set(newKey.Id)
r.slugIndex[newKey] = userID
r.slugIndex[newKey] = user
if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]*SSHSession)
if r.byUser[user] == nil {
r.byUser[user] = make(map[Key]*SSHSession)
}
r.byUser[userID][newKey] = client
r.byUser[user][newKey] = client
return nil
}
@@ -99,7 +93,7 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) {
return false
}
userID := session.userID
userID := session.lifecycle.GetUser()
if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]*SSHSession)
}

View File

@@ -9,7 +9,6 @@ import (
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
@@ -29,7 +28,6 @@ type SSHSession struct {
forwarder forwarder.ForwardingController
slugManager slug.Manager
registry Registry
userID string
}
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
@@ -48,22 +46,16 @@ func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager
}
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession {
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) *SSHSession {
slugManager := slug.NewManager()
forwarderManager := forwarder.NewForwarder(slugManager)
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager)
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager, user)
interactionManager.SetLifecycle(lifecycleManager)
interactionManager.SetSlugModificator(func(oldSlug, newSlug string) error {
oldKey := types.SessionKey{Id: oldSlug, Type: forwarderManager.GetTunnelType()}
newKey := types.SessionKey{Id: newSlug, Type: forwarderManager.GetTunnelType()}
return sessionRegistry.Update(oldKey, newKey)
})
forwarderManager.SetLifecycle(lifecycleManager)
lifecycleManager.SetUnregisterClient(func(key types.SessionKey) {
sessionRegistry.Remove(key)
})
interactionManager.SetSessionRegistry(sessionRegistry)
lifecycleManager.SetSessionRegistry(sessionRegistry)
return &SSHSession{
initialReq: forwardingReq,
@@ -73,7 +65,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
forwarder: forwarderManager,
slugManager: slugManager,
registry: sessionRegistry,
userID: userID,
}
}
@@ -89,7 +80,7 @@ func (s *SSHSession) Detail() Detail {
return Detail{
ForwardingType: string(s.forwarder.GetTunnelType()),
Slug: s.slugManager.Get(),
UserID: s.userID,
UserID: s.lifecycle.GetUser(),
Active: s.lifecycle.IsActive(),
StartedAt: s.lifecycle.StartedAt(),
}