From 4ffaec9d9a6ce678c4859fbb9d41c56e10c3282f Mon Sep 17 00:00:00 2001 From: bagas Date: Mon, 5 Jan 2026 16:44:03 +0700 Subject: [PATCH] refactor: inject SessionRegistry interface instead of individual functions --- go.mod | 4 +-- go.sum | 2 ++ internal/grpc/client/client.go | 3 +- session/interaction/interaction.go | 55 ++++++++++++++++++------------ session/lifecycle/lifecycle.go | 55 ++++++++++++++++++------------ session/registry.go | 24 +++++-------- session/session.go | 19 +++-------- 7 files changed, 86 insertions(+), 76 deletions(-) diff --git a/go.mod b/go.mod index ae797c0..fe62062 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module tunnel_pls go 1.25.5 require ( - git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0 + git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0 github.com/caddyserver/certmagic v0.25.0 github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbletea v1.3.10 @@ -52,4 +52,4 @@ require ( golang.org/x/text v0.32.0 // indirect golang.org/x/tools v0.40.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect -) +) \ No newline at end of file diff --git a/go.sum b/go.sum index 1ce04dc..d477230 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0 h1:RhcBKUG41/om4jgN+iF/vlY/RojTeX1QhBa4p4428ec= git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY= +git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0 h1:tpJSKjaSmV+vxxbVx6qnStjxFVXjj2M0rygWXxLb99o= +git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index 2243951..00eac89 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -217,6 +217,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod } switch recv.GetType() { case proto.EventType_SLUG_CHANGE: + user := recv.GetSlugEvent().GetUser() oldSlug := recv.GetSlugEvent().GetOld() newSlug := recv.GetSlugEvent().GetNew() var userSession *session.SSHSession @@ -242,7 +243,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod } continue } - err = c.sessionRegistry.Update(types.SessionKey{ + err = c.sessionRegistry.Update(user, types.SessionKey{ Id: oldSlug, Type: types.HTTP, }, types.SessionKey{ diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 3a36f4c..99723c0 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -23,15 +23,20 @@ import ( type Lifecycle interface { Close() error + GetUser() string +} + +type SessionRegistry interface { + Update(user string, oldKey, newKey types.SessionKey) error } type Controller interface { SetChannel(channel ssh.Channel) SetLifecycle(lifecycle Lifecycle) - SetSlugModificator(func(oldSlug, newSlug string) error) Start() SetWH(w, h int) Redraw() + SetSessionRegistry(registry SessionRegistry) } type Forwarder interface { @@ -41,14 +46,14 @@ type Forwarder interface { } type Interaction struct { - channel ssh.Channel - slugManager slug.Manager - forwarder Forwarder - lifecycle Lifecycle - updateClientSlug func(oldSlug, newSlug string) error - program *tea.Program - ctx context.Context - cancel context.CancelFunc + channel ssh.Channel + slugManager slug.Manager + forwarder Forwarder + lifecycle Lifecycle + sessionRegistry SessionRegistry + program *tea.Program + ctx context.Context + cancel context.CancelFunc } func (i *Interaction) SetWH(w, h int) { @@ -102,17 +107,21 @@ type tickMsg time.Time func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction { ctx, cancel := context.WithCancel(context.Background()) return &Interaction{ - channel: nil, - slugManager: slugManager, - forwarder: forwarder, - lifecycle: nil, - updateClientSlug: nil, - program: nil, - ctx: ctx, - cancel: cancel, + channel: nil, + slugManager: slugManager, + forwarder: forwarder, + lifecycle: nil, + sessionRegistry: nil, + program: nil, + ctx: ctx, + cancel: cancel, } } +func (i *Interaction) SetSessionRegistry(registry SessionRegistry) { + i.sessionRegistry = registry +} + func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { i.lifecycle = lifecycle } @@ -121,10 +130,6 @@ func (i *Interaction) SetChannel(channel ssh.Channel) { i.channel = channel } -func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) error) { - i.updateClientSlug = modificator -} - func (i *Interaction) Stop() { if i.cancel != nil { i.cancel() @@ -218,7 +223,13 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(tea.ClearScreen, textinput.Blink) case "enter": inputValue := m.slugInput.Value() - if err := m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue); err != nil { + if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.GetUser(), types.SessionKey{ + Id: m.interaction.slugManager.Get(), + Type: types.HTTP, + }, types.SessionKey{ + Id: inputValue, + Type: types.HTTP, + }); err != nil { m.slugError = err.Error() return m, nil } diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index ea96081..ccc01f0 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -19,30 +19,36 @@ type Forwarder interface { GetForwardedPort() uint16 } -type Lifecycle struct { - status types.Status - conn ssh.Conn - channel ssh.Channel - forwarder Forwarder - slugManager slug.Manager - unregisterClient func(key types.SessionKey) - startedAt time.Time +type SessionRegistry interface { + Remove(key types.SessionKey) } -func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle { +type Lifecycle struct { + status types.Status + conn ssh.Conn + channel ssh.Channel + forwarder Forwarder + sessionRegistry SessionRegistry + slugManager slug.Manager + startedAt time.Time + user string +} + +func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager, user string) *Lifecycle { return &Lifecycle{ - status: types.INITIALIZING, - conn: conn, - channel: nil, - forwarder: forwarder, - slugManager: slugManager, - unregisterClient: nil, - startedAt: time.Now(), + status: types.INITIALIZING, + conn: conn, + channel: nil, + forwarder: forwarder, + slugManager: slugManager, + sessionRegistry: nil, + startedAt: time.Now(), + user: user, } } -func (l *Lifecycle) SetUnregisterClient(unregisterClient func(key types.SessionKey)) { - l.unregisterClient = unregisterClient +func (l *Lifecycle) SetSessionRegistry(registry SessionRegistry) { + l.sessionRegistry = registry } type SessionLifecycle interface { @@ -50,12 +56,17 @@ type SessionLifecycle interface { SetStatus(status types.Status) GetConnection() ssh.Conn GetChannel() ssh.Channel + GetUser() string SetChannel(channel ssh.Channel) - SetUnregisterClient(unregisterClient func(key types.SessionKey)) + SetSessionRegistry(registry SessionRegistry) IsActive() bool StartedAt() time.Time } +func (l *Lifecycle) GetUser() string { + return l.user +} + func (l *Lifecycle) GetChannel() ssh.Channel { return l.channel } @@ -94,13 +105,13 @@ func (l *Lifecycle) Close() error { } clientSlug := l.slugManager.Get() - if clientSlug != "" && l.unregisterClient != nil { + if clientSlug != "" && l.sessionRegistry.Remove != nil { key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()} - l.unregisterClient(key) + l.sessionRegistry.Remove(key) } if l.forwarder.GetTunnelType() == types.TCP { - err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false) + err = portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false) if err != nil { return err } diff --git a/session/registry.go b/session/registry.go index 6a607f9..60b86d3 100644 --- a/session/registry.go +++ b/session/registry.go @@ -10,7 +10,7 @@ type Key = types.SessionKey type Registry interface { Get(key Key) (session *SSHSession, err error) - Update(oldKey, newKey Key) error + Update(user string, oldKey, newKey Key) error Register(key Key, session *SSHSession) (success bool) Remove(key Key) GetAllSessionFromUser(user string) []*SSHSession @@ -44,7 +44,7 @@ func (r *registry) Get(key Key) (session *SSHSession, err error) { return client, nil } -func (r *registry) Update(oldKey, newKey Key) error { +func (r *registry) Update(user string, oldKey, newKey Key) error { if oldKey.Type != newKey.Type { return fmt.Errorf("tunnel type cannot change") } @@ -64,30 +64,24 @@ func (r *registry) Update(oldKey, newKey Key) error { r.mu.Lock() defer r.mu.Unlock() - userID, ok := r.slugIndex[oldKey] - if !ok { - return fmt.Errorf("session not found") - } - if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { return fmt.Errorf("someone already uses this subdomain") } - - client, ok := r.byUser[userID][oldKey] + client, ok := r.byUser[user][oldKey] if !ok { return fmt.Errorf("session not found") } - delete(r.byUser[userID], oldKey) + delete(r.byUser[user], oldKey) delete(r.slugIndex, oldKey) client.slugManager.Set(newKey.Id) - r.slugIndex[newKey] = userID + r.slugIndex[newKey] = user - if r.byUser[userID] == nil { - r.byUser[userID] = make(map[Key]*SSHSession) + if r.byUser[user] == nil { + r.byUser[user] = make(map[Key]*SSHSession) } - r.byUser[userID][newKey] = client + r.byUser[user][newKey] = client return nil } @@ -99,7 +93,7 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) { return false } - userID := session.userID + userID := session.lifecycle.GetUser() if r.byUser[userID] == nil { r.byUser[userID] = make(map[Key]*SSHSession) } diff --git a/session/session.go b/session/session.go index d406160..b1a9ac5 100644 --- a/session/session.go +++ b/session/session.go @@ -9,7 +9,6 @@ import ( "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" "tunnel_pls/session/slug" - "tunnel_pls/types" "golang.org/x/crypto/ssh" ) @@ -29,7 +28,6 @@ type SSHSession struct { forwarder forwarder.ForwardingController slugManager slug.Manager registry Registry - userID string } func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle { @@ -48,22 +46,16 @@ func (s *SSHSession) GetSlugManager() slug.Manager { return s.slugManager } -func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession { +func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) *SSHSession { slugManager := slug.NewManager() forwarderManager := forwarder.NewForwarder(slugManager) interactionManager := interaction.NewInteraction(slugManager, forwarderManager) - lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager) + lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager, user) interactionManager.SetLifecycle(lifecycleManager) - interactionManager.SetSlugModificator(func(oldSlug, newSlug string) error { - oldKey := types.SessionKey{Id: oldSlug, Type: forwarderManager.GetTunnelType()} - newKey := types.SessionKey{Id: newSlug, Type: forwarderManager.GetTunnelType()} - return sessionRegistry.Update(oldKey, newKey) - }) forwarderManager.SetLifecycle(lifecycleManager) - lifecycleManager.SetUnregisterClient(func(key types.SessionKey) { - sessionRegistry.Remove(key) - }) + interactionManager.SetSessionRegistry(sessionRegistry) + lifecycleManager.SetSessionRegistry(sessionRegistry) return &SSHSession{ initialReq: forwardingReq, @@ -73,7 +65,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan forwarder: forwarderManager, slugManager: slugManager, registry: sessionRegistry, - userID: userID, } } @@ -89,7 +80,7 @@ func (s *SSHSession) Detail() Detail { return Detail{ ForwardingType: string(s.forwarder.GetTunnelType()), Slug: s.slugManager.Get(), - UserID: s.userID, + UserID: s.lifecycle.GetUser(), Active: s.lifecycle.IsActive(), StartedAt: s.lifecycle.StartedAt(), }