refactor: inject SessionRegistry interface instead of individual functions
This commit is contained in:
4
go.mod
4
go.mod
@@ -53,3 +53,7 @@ require (
|
||||
golang.org/x/tools v0.40.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
|
||||
)
|
||||
|
||||
replace (
|
||||
git.fossy.my.id/bagas/tunnel-please-grpc => ../tunnel-please-grpc
|
||||
)
|
||||
|
||||
@@ -217,6 +217,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod
|
||||
}
|
||||
switch recv.GetType() {
|
||||
case proto.EventType_SLUG_CHANGE:
|
||||
user := recv.GetSlugEvent().GetUser()
|
||||
oldSlug := recv.GetSlugEvent().GetOld()
|
||||
newSlug := recv.GetSlugEvent().GetNew()
|
||||
var userSession *session.SSHSession
|
||||
@@ -242,7 +243,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod
|
||||
}
|
||||
continue
|
||||
}
|
||||
err = c.sessionRegistry.Update(types.SessionKey{
|
||||
err = c.sessionRegistry.Update(user, types.SessionKey{
|
||||
Id: oldSlug,
|
||||
Type: types.HTTP,
|
||||
}, types.SessionKey{
|
||||
|
||||
@@ -23,15 +23,20 @@ import (
|
||||
|
||||
type Lifecycle interface {
|
||||
Close() error
|
||||
GetUser() string
|
||||
}
|
||||
|
||||
type SessionRegistry interface {
|
||||
Update(user string, oldKey, newKey types.SessionKey) error
|
||||
}
|
||||
|
||||
type Controller interface {
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
SetSlugModificator(func(oldSlug, newSlug string) error)
|
||||
Start()
|
||||
SetWH(w, h int)
|
||||
Redraw()
|
||||
SetSessionRegistry(registry SessionRegistry)
|
||||
}
|
||||
|
||||
type Forwarder interface {
|
||||
@@ -45,7 +50,7 @@ type Interaction struct {
|
||||
slugManager slug.Manager
|
||||
forwarder Forwarder
|
||||
lifecycle Lifecycle
|
||||
updateClientSlug func(oldSlug, newSlug string) error
|
||||
sessionRegistry SessionRegistry
|
||||
program *tea.Program
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -106,13 +111,17 @@ func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction
|
||||
slugManager: slugManager,
|
||||
forwarder: forwarder,
|
||||
lifecycle: nil,
|
||||
updateClientSlug: nil,
|
||||
sessionRegistry: nil,
|
||||
program: nil,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Interaction) SetSessionRegistry(registry SessionRegistry) {
|
||||
i.sessionRegistry = registry
|
||||
}
|
||||
|
||||
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
|
||||
i.lifecycle = lifecycle
|
||||
}
|
||||
@@ -121,10 +130,6 @@ func (i *Interaction) SetChannel(channel ssh.Channel) {
|
||||
i.channel = channel
|
||||
}
|
||||
|
||||
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) error) {
|
||||
i.updateClientSlug = modificator
|
||||
}
|
||||
|
||||
func (i *Interaction) Stop() {
|
||||
if i.cancel != nil {
|
||||
i.cancel()
|
||||
@@ -218,7 +223,13 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case "enter":
|
||||
inputValue := m.slugInput.Value()
|
||||
if err := m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue); err != nil {
|
||||
if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.GetUser(), types.SessionKey{
|
||||
Id: m.interaction.slugManager.Get(),
|
||||
Type: types.HTTP,
|
||||
}, types.SessionKey{
|
||||
Id: inputValue,
|
||||
Type: types.HTTP,
|
||||
}); err != nil {
|
||||
m.slugError = err.Error()
|
||||
return m, nil
|
||||
}
|
||||
|
||||
@@ -19,30 +19,36 @@ type Forwarder interface {
|
||||
GetForwardedPort() uint16
|
||||
}
|
||||
|
||||
type SessionRegistry interface {
|
||||
Remove(key types.SessionKey)
|
||||
}
|
||||
|
||||
type Lifecycle struct {
|
||||
status types.Status
|
||||
conn ssh.Conn
|
||||
channel ssh.Channel
|
||||
forwarder Forwarder
|
||||
sessionRegistry SessionRegistry
|
||||
slugManager slug.Manager
|
||||
unregisterClient func(key types.SessionKey)
|
||||
startedAt time.Time
|
||||
user string
|
||||
}
|
||||
|
||||
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
|
||||
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager, user string) *Lifecycle {
|
||||
return &Lifecycle{
|
||||
status: types.INITIALIZING,
|
||||
conn: conn,
|
||||
channel: nil,
|
||||
forwarder: forwarder,
|
||||
slugManager: slugManager,
|
||||
unregisterClient: nil,
|
||||
sessionRegistry: nil,
|
||||
startedAt: time.Now(),
|
||||
user: user,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(key types.SessionKey)) {
|
||||
l.unregisterClient = unregisterClient
|
||||
func (l *Lifecycle) SetSessionRegistry(registry SessionRegistry) {
|
||||
l.sessionRegistry = registry
|
||||
}
|
||||
|
||||
type SessionLifecycle interface {
|
||||
@@ -50,12 +56,17 @@ type SessionLifecycle interface {
|
||||
SetStatus(status types.Status)
|
||||
GetConnection() ssh.Conn
|
||||
GetChannel() ssh.Channel
|
||||
GetUser() string
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetUnregisterClient(unregisterClient func(key types.SessionKey))
|
||||
SetSessionRegistry(registry SessionRegistry)
|
||||
IsActive() bool
|
||||
StartedAt() time.Time
|
||||
}
|
||||
|
||||
func (l *Lifecycle) GetUser() string {
|
||||
return l.user
|
||||
}
|
||||
|
||||
func (l *Lifecycle) GetChannel() ssh.Channel {
|
||||
return l.channel
|
||||
}
|
||||
@@ -94,13 +105,13 @@ func (l *Lifecycle) Close() error {
|
||||
}
|
||||
|
||||
clientSlug := l.slugManager.Get()
|
||||
if clientSlug != "" && l.unregisterClient != nil {
|
||||
if clientSlug != "" && l.sessionRegistry.Remove != nil {
|
||||
key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()}
|
||||
l.unregisterClient(key)
|
||||
l.sessionRegistry.Remove(key)
|
||||
}
|
||||
|
||||
if l.forwarder.GetTunnelType() == types.TCP {
|
||||
err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
|
||||
err = portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ type Key = types.SessionKey
|
||||
|
||||
type Registry interface {
|
||||
Get(key Key) (session *SSHSession, err error)
|
||||
Update(oldKey, newKey Key) error
|
||||
Update(user string, oldKey, newKey Key) error
|
||||
Register(key Key, session *SSHSession) (success bool)
|
||||
Remove(key Key)
|
||||
GetAllSessionFromUser(user string) []*SSHSession
|
||||
@@ -44,7 +44,7 @@ func (r *registry) Get(key Key) (session *SSHSession, err error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (r *registry) Update(oldKey, newKey Key) error {
|
||||
func (r *registry) Update(user string, oldKey, newKey Key) error {
|
||||
if oldKey.Type != newKey.Type {
|
||||
return fmt.Errorf("tunnel type cannot change")
|
||||
}
|
||||
@@ -64,30 +64,24 @@ func (r *registry) Update(oldKey, newKey Key) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
userID, ok := r.slugIndex[oldKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
||||
return fmt.Errorf("someone already uses this subdomain")
|
||||
}
|
||||
|
||||
client, ok := r.byUser[userID][oldKey]
|
||||
client, ok := r.byUser[user][oldKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
delete(r.byUser[userID], oldKey)
|
||||
delete(r.byUser[user], oldKey)
|
||||
delete(r.slugIndex, oldKey)
|
||||
|
||||
client.slugManager.Set(newKey.Id)
|
||||
r.slugIndex[newKey] = userID
|
||||
r.slugIndex[newKey] = user
|
||||
|
||||
if r.byUser[userID] == nil {
|
||||
r.byUser[userID] = make(map[Key]*SSHSession)
|
||||
if r.byUser[user] == nil {
|
||||
r.byUser[user] = make(map[Key]*SSHSession)
|
||||
}
|
||||
r.byUser[userID][newKey] = client
|
||||
r.byUser[user][newKey] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -99,7 +93,7 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
userID := session.userID
|
||||
userID := session.lifecycle.GetUser()
|
||||
if r.byUser[userID] == nil {
|
||||
r.byUser[userID] = make(map[Key]*SSHSession)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
@@ -29,7 +28,6 @@ type SSHSession struct {
|
||||
forwarder forwarder.ForwardingController
|
||||
slugManager slug.Manager
|
||||
registry Registry
|
||||
userID string
|
||||
}
|
||||
|
||||
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
|
||||
@@ -48,22 +46,16 @@ func (s *SSHSession) GetSlugManager() slug.Manager {
|
||||
return s.slugManager
|
||||
}
|
||||
|
||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession {
|
||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) *SSHSession {
|
||||
slugManager := slug.NewManager()
|
||||
forwarderManager := forwarder.NewForwarder(slugManager)
|
||||
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
|
||||
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager)
|
||||
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager, user)
|
||||
|
||||
interactionManager.SetLifecycle(lifecycleManager)
|
||||
interactionManager.SetSlugModificator(func(oldSlug, newSlug string) error {
|
||||
oldKey := types.SessionKey{Id: oldSlug, Type: forwarderManager.GetTunnelType()}
|
||||
newKey := types.SessionKey{Id: newSlug, Type: forwarderManager.GetTunnelType()}
|
||||
return sessionRegistry.Update(oldKey, newKey)
|
||||
})
|
||||
forwarderManager.SetLifecycle(lifecycleManager)
|
||||
lifecycleManager.SetUnregisterClient(func(key types.SessionKey) {
|
||||
sessionRegistry.Remove(key)
|
||||
})
|
||||
interactionManager.SetSessionRegistry(sessionRegistry)
|
||||
lifecycleManager.SetSessionRegistry(sessionRegistry)
|
||||
|
||||
return &SSHSession{
|
||||
initialReq: forwardingReq,
|
||||
@@ -73,7 +65,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
|
||||
forwarder: forwarderManager,
|
||||
slugManager: slugManager,
|
||||
registry: sessionRegistry,
|
||||
userID: userID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +80,7 @@ func (s *SSHSession) Detail() Detail {
|
||||
return Detail{
|
||||
ForwardingType: string(s.forwarder.GetTunnelType()),
|
||||
Slug: s.slugManager.Get(),
|
||||
UserID: s.userID,
|
||||
UserID: s.lifecycle.GetUser(),
|
||||
Active: s.lifecycle.IsActive(),
|
||||
StartedAt: s.lifecycle.StartedAt(),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user