refactor: inject SessionRegistry interface instead of individual functions
This commit is contained in:
2
go.mod
2
go.mod
@@ -52,4 +52,4 @@ require (
|
|||||||
golang.org/x/text v0.32.0 // indirect
|
golang.org/x/text v0.32.0 // indirect
|
||||||
golang.org/x/tools v0.40.0 // indirect
|
golang.org/x/tools v0.40.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
|
||||||
)
|
)
|
||||||
@@ -217,6 +217,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod
|
|||||||
}
|
}
|
||||||
switch recv.GetType() {
|
switch recv.GetType() {
|
||||||
case proto.EventType_SLUG_CHANGE:
|
case proto.EventType_SLUG_CHANGE:
|
||||||
|
user := recv.GetSlugEvent().GetUser()
|
||||||
oldSlug := recv.GetSlugEvent().GetOld()
|
oldSlug := recv.GetSlugEvent().GetOld()
|
||||||
newSlug := recv.GetSlugEvent().GetNew()
|
newSlug := recv.GetSlugEvent().GetNew()
|
||||||
var userSession *session.SSHSession
|
var userSession *session.SSHSession
|
||||||
@@ -242,7 +243,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = c.sessionRegistry.Update(types.SessionKey{
|
err = c.sessionRegistry.Update(user, types.SessionKey{
|
||||||
Id: oldSlug,
|
Id: oldSlug,
|
||||||
Type: types.HTTP,
|
Type: types.HTTP,
|
||||||
}, types.SessionKey{
|
}, types.SessionKey{
|
||||||
|
|||||||
@@ -23,15 +23,20 @@ import (
|
|||||||
|
|
||||||
type Lifecycle interface {
|
type Lifecycle interface {
|
||||||
Close() error
|
Close() error
|
||||||
|
GetUser() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionRegistry interface {
|
||||||
|
Update(user string, oldKey, newKey types.SessionKey) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Controller interface {
|
type Controller interface {
|
||||||
SetChannel(channel ssh.Channel)
|
SetChannel(channel ssh.Channel)
|
||||||
SetLifecycle(lifecycle Lifecycle)
|
SetLifecycle(lifecycle Lifecycle)
|
||||||
SetSlugModificator(func(oldSlug, newSlug string) error)
|
|
||||||
Start()
|
Start()
|
||||||
SetWH(w, h int)
|
SetWH(w, h int)
|
||||||
Redraw()
|
Redraw()
|
||||||
|
SetSessionRegistry(registry SessionRegistry)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Forwarder interface {
|
type Forwarder interface {
|
||||||
@@ -41,14 +46,14 @@ type Forwarder interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Interaction struct {
|
type Interaction struct {
|
||||||
channel ssh.Channel
|
channel ssh.Channel
|
||||||
slugManager slug.Manager
|
slugManager slug.Manager
|
||||||
forwarder Forwarder
|
forwarder Forwarder
|
||||||
lifecycle Lifecycle
|
lifecycle Lifecycle
|
||||||
updateClientSlug func(oldSlug, newSlug string) error
|
sessionRegistry SessionRegistry
|
||||||
program *tea.Program
|
program *tea.Program
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Interaction) SetWH(w, h int) {
|
func (i *Interaction) SetWH(w, h int) {
|
||||||
@@ -102,17 +107,21 @@ type tickMsg time.Time
|
|||||||
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
|
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Interaction{
|
return &Interaction{
|
||||||
channel: nil,
|
channel: nil,
|
||||||
slugManager: slugManager,
|
slugManager: slugManager,
|
||||||
forwarder: forwarder,
|
forwarder: forwarder,
|
||||||
lifecycle: nil,
|
lifecycle: nil,
|
||||||
updateClientSlug: nil,
|
sessionRegistry: nil,
|
||||||
program: nil,
|
program: nil,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *Interaction) SetSessionRegistry(registry SessionRegistry) {
|
||||||
|
i.sessionRegistry = registry
|
||||||
|
}
|
||||||
|
|
||||||
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
|
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
|
||||||
i.lifecycle = lifecycle
|
i.lifecycle = lifecycle
|
||||||
}
|
}
|
||||||
@@ -121,10 +130,6 @@ func (i *Interaction) SetChannel(channel ssh.Channel) {
|
|||||||
i.channel = channel
|
i.channel = channel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) error) {
|
|
||||||
i.updateClientSlug = modificator
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *Interaction) Stop() {
|
func (i *Interaction) Stop() {
|
||||||
if i.cancel != nil {
|
if i.cancel != nil {
|
||||||
i.cancel()
|
i.cancel()
|
||||||
@@ -218,7 +223,13 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
case "enter":
|
case "enter":
|
||||||
inputValue := m.slugInput.Value()
|
inputValue := m.slugInput.Value()
|
||||||
if err := m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue); err != nil {
|
if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.GetUser(), types.SessionKey{
|
||||||
|
Id: m.interaction.slugManager.Get(),
|
||||||
|
Type: types.HTTP,
|
||||||
|
}, types.SessionKey{
|
||||||
|
Id: inputValue,
|
||||||
|
Type: types.HTTP,
|
||||||
|
}); err != nil {
|
||||||
m.slugError = err.Error()
|
m.slugError = err.Error()
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,30 +19,36 @@ type Forwarder interface {
|
|||||||
GetForwardedPort() uint16
|
GetForwardedPort() uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
type Lifecycle struct {
|
type SessionRegistry interface {
|
||||||
status types.Status
|
Remove(key types.SessionKey)
|
||||||
conn ssh.Conn
|
|
||||||
channel ssh.Channel
|
|
||||||
forwarder Forwarder
|
|
||||||
slugManager slug.Manager
|
|
||||||
unregisterClient func(key types.SessionKey)
|
|
||||||
startedAt time.Time
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
|
type Lifecycle struct {
|
||||||
|
status types.Status
|
||||||
|
conn ssh.Conn
|
||||||
|
channel ssh.Channel
|
||||||
|
forwarder Forwarder
|
||||||
|
sessionRegistry SessionRegistry
|
||||||
|
slugManager slug.Manager
|
||||||
|
startedAt time.Time
|
||||||
|
user string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager, user string) *Lifecycle {
|
||||||
return &Lifecycle{
|
return &Lifecycle{
|
||||||
status: types.INITIALIZING,
|
status: types.INITIALIZING,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
channel: nil,
|
channel: nil,
|
||||||
forwarder: forwarder,
|
forwarder: forwarder,
|
||||||
slugManager: slugManager,
|
slugManager: slugManager,
|
||||||
unregisterClient: nil,
|
sessionRegistry: nil,
|
||||||
startedAt: time.Now(),
|
startedAt: time.Now(),
|
||||||
|
user: user,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(key types.SessionKey)) {
|
func (l *Lifecycle) SetSessionRegistry(registry SessionRegistry) {
|
||||||
l.unregisterClient = unregisterClient
|
l.sessionRegistry = registry
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionLifecycle interface {
|
type SessionLifecycle interface {
|
||||||
@@ -50,12 +56,17 @@ type SessionLifecycle interface {
|
|||||||
SetStatus(status types.Status)
|
SetStatus(status types.Status)
|
||||||
GetConnection() ssh.Conn
|
GetConnection() ssh.Conn
|
||||||
GetChannel() ssh.Channel
|
GetChannel() ssh.Channel
|
||||||
|
GetUser() string
|
||||||
SetChannel(channel ssh.Channel)
|
SetChannel(channel ssh.Channel)
|
||||||
SetUnregisterClient(unregisterClient func(key types.SessionKey))
|
SetSessionRegistry(registry SessionRegistry)
|
||||||
IsActive() bool
|
IsActive() bool
|
||||||
StartedAt() time.Time
|
StartedAt() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *Lifecycle) GetUser() string {
|
||||||
|
return l.user
|
||||||
|
}
|
||||||
|
|
||||||
func (l *Lifecycle) GetChannel() ssh.Channel {
|
func (l *Lifecycle) GetChannel() ssh.Channel {
|
||||||
return l.channel
|
return l.channel
|
||||||
}
|
}
|
||||||
@@ -94,13 +105,13 @@ func (l *Lifecycle) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
clientSlug := l.slugManager.Get()
|
clientSlug := l.slugManager.Get()
|
||||||
if clientSlug != "" && l.unregisterClient != nil {
|
if clientSlug != "" && l.sessionRegistry.Remove != nil {
|
||||||
key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()}
|
key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()}
|
||||||
l.unregisterClient(key)
|
l.sessionRegistry.Remove(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
if l.forwarder.GetTunnelType() == types.TCP {
|
if l.forwarder.GetTunnelType() == types.TCP {
|
||||||
err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
|
err = portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ type Key = types.SessionKey
|
|||||||
|
|
||||||
type Registry interface {
|
type Registry interface {
|
||||||
Get(key Key) (session *SSHSession, err error)
|
Get(key Key) (session *SSHSession, err error)
|
||||||
Update(oldKey, newKey Key) error
|
Update(user string, oldKey, newKey Key) error
|
||||||
Register(key Key, session *SSHSession) (success bool)
|
Register(key Key, session *SSHSession) (success bool)
|
||||||
Remove(key Key)
|
Remove(key Key)
|
||||||
GetAllSessionFromUser(user string) []*SSHSession
|
GetAllSessionFromUser(user string) []*SSHSession
|
||||||
@@ -44,7 +44,7 @@ func (r *registry) Get(key Key) (session *SSHSession, err error) {
|
|||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registry) Update(oldKey, newKey Key) error {
|
func (r *registry) Update(user string, oldKey, newKey Key) error {
|
||||||
if oldKey.Type != newKey.Type {
|
if oldKey.Type != newKey.Type {
|
||||||
return fmt.Errorf("tunnel type cannot change")
|
return fmt.Errorf("tunnel type cannot change")
|
||||||
}
|
}
|
||||||
@@ -64,30 +64,24 @@ func (r *registry) Update(oldKey, newKey Key) error {
|
|||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
userID, ok := r.slugIndex[oldKey]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("session not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
||||||
return fmt.Errorf("someone already uses this subdomain")
|
return fmt.Errorf("someone already uses this subdomain")
|
||||||
}
|
}
|
||||||
|
client, ok := r.byUser[user][oldKey]
|
||||||
client, ok := r.byUser[userID][oldKey]
|
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("session not found")
|
return fmt.Errorf("session not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(r.byUser[userID], oldKey)
|
delete(r.byUser[user], oldKey)
|
||||||
delete(r.slugIndex, oldKey)
|
delete(r.slugIndex, oldKey)
|
||||||
|
|
||||||
client.slugManager.Set(newKey.Id)
|
client.slugManager.Set(newKey.Id)
|
||||||
r.slugIndex[newKey] = userID
|
r.slugIndex[newKey] = user
|
||||||
|
|
||||||
if r.byUser[userID] == nil {
|
if r.byUser[user] == nil {
|
||||||
r.byUser[userID] = make(map[Key]*SSHSession)
|
r.byUser[user] = make(map[Key]*SSHSession)
|
||||||
}
|
}
|
||||||
r.byUser[userID][newKey] = client
|
r.byUser[user][newKey] = client
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,7 +93,7 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
userID := session.userID
|
userID := session.lifecycle.GetUser()
|
||||||
if r.byUser[userID] == nil {
|
if r.byUser[userID] == nil {
|
||||||
r.byUser[userID] = make(map[Key]*SSHSession)
|
r.byUser[userID] = make(map[Key]*SSHSession)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"tunnel_pls/session/interaction"
|
"tunnel_pls/session/interaction"
|
||||||
"tunnel_pls/session/lifecycle"
|
"tunnel_pls/session/lifecycle"
|
||||||
"tunnel_pls/session/slug"
|
"tunnel_pls/session/slug"
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
@@ -29,7 +28,6 @@ type SSHSession struct {
|
|||||||
forwarder forwarder.ForwardingController
|
forwarder forwarder.ForwardingController
|
||||||
slugManager slug.Manager
|
slugManager slug.Manager
|
||||||
registry Registry
|
registry Registry
|
||||||
userID string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
|
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
|
||||||
@@ -48,22 +46,16 @@ func (s *SSHSession) GetSlugManager() slug.Manager {
|
|||||||
return s.slugManager
|
return s.slugManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession {
|
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) *SSHSession {
|
||||||
slugManager := slug.NewManager()
|
slugManager := slug.NewManager()
|
||||||
forwarderManager := forwarder.NewForwarder(slugManager)
|
forwarderManager := forwarder.NewForwarder(slugManager)
|
||||||
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
|
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
|
||||||
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager)
|
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager, user)
|
||||||
|
|
||||||
interactionManager.SetLifecycle(lifecycleManager)
|
interactionManager.SetLifecycle(lifecycleManager)
|
||||||
interactionManager.SetSlugModificator(func(oldSlug, newSlug string) error {
|
|
||||||
oldKey := types.SessionKey{Id: oldSlug, Type: forwarderManager.GetTunnelType()}
|
|
||||||
newKey := types.SessionKey{Id: newSlug, Type: forwarderManager.GetTunnelType()}
|
|
||||||
return sessionRegistry.Update(oldKey, newKey)
|
|
||||||
})
|
|
||||||
forwarderManager.SetLifecycle(lifecycleManager)
|
forwarderManager.SetLifecycle(lifecycleManager)
|
||||||
lifecycleManager.SetUnregisterClient(func(key types.SessionKey) {
|
interactionManager.SetSessionRegistry(sessionRegistry)
|
||||||
sessionRegistry.Remove(key)
|
lifecycleManager.SetSessionRegistry(sessionRegistry)
|
||||||
})
|
|
||||||
|
|
||||||
return &SSHSession{
|
return &SSHSession{
|
||||||
initialReq: forwardingReq,
|
initialReq: forwardingReq,
|
||||||
@@ -73,7 +65,6 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
|
|||||||
forwarder: forwarderManager,
|
forwarder: forwarderManager,
|
||||||
slugManager: slugManager,
|
slugManager: slugManager,
|
||||||
registry: sessionRegistry,
|
registry: sessionRegistry,
|
||||||
userID: userID,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,7 +80,7 @@ func (s *SSHSession) Detail() Detail {
|
|||||||
return Detail{
|
return Detail{
|
||||||
ForwardingType: string(s.forwarder.GetTunnelType()),
|
ForwardingType: string(s.forwarder.GetTunnelType()),
|
||||||
Slug: s.slugManager.Get(),
|
Slug: s.slugManager.Get(),
|
||||||
UserID: s.userID,
|
UserID: s.lifecycle.GetUser(),
|
||||||
Active: s.lifecycle.IsActive(),
|
Active: s.lifecycle.IsActive(),
|
||||||
StartedAt: s.lifecycle.StartedAt(),
|
StartedAt: s.lifecycle.StartedAt(),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user