feat: implement get sessions by user
This commit is contained in:
@@ -4,6 +4,8 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
@@ -24,16 +26,18 @@ type Lifecycle struct {
|
||||
forwarder Forwarder
|
||||
slugManager slug.Manager
|
||||
unregisterClient func(slug string)
|
||||
startedAt time.Time
|
||||
}
|
||||
|
||||
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
|
||||
return &Lifecycle{
|
||||
status: "",
|
||||
status: types.INITIALIZING,
|
||||
conn: conn,
|
||||
channel: nil,
|
||||
forwarder: forwarder,
|
||||
slugManager: slugManager,
|
||||
unregisterClient: nil,
|
||||
startedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,6 +52,8 @@ type SessionLifecycle interface {
|
||||
GetChannel() ssh.Channel
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetUnregisterClient(unregisterClient func(slug string))
|
||||
IsActive() bool
|
||||
StartedAt() time.Time
|
||||
}
|
||||
|
||||
func (l *Lifecycle) GetChannel() ssh.Channel {
|
||||
@@ -62,6 +68,9 @@ func (l *Lifecycle) GetConnection() ssh.Conn {
|
||||
}
|
||||
func (l *Lifecycle) SetStatus(status types.Status) {
|
||||
l.status = status
|
||||
if status == types.RUNNING && l.startedAt.IsZero() {
|
||||
l.startedAt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Lifecycle) Close() error {
|
||||
@@ -98,3 +107,11 @@ func (l *Lifecycle) Close() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *Lifecycle) IsActive() bool {
|
||||
return l.status == types.RUNNING
|
||||
}
|
||||
|
||||
func (l *Lifecycle) StartedAt() time.Time {
|
||||
return l.startedAt
|
||||
}
|
||||
|
||||
@@ -10,15 +10,18 @@ type Registry interface {
|
||||
Update(oldSlug, newSlug string) error
|
||||
Register(slug string, session *SSHSession) (success bool)
|
||||
Remove(slug string)
|
||||
GetAllSessionFromUser(user string) []*SSHSession
|
||||
}
|
||||
type registry struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]*SSHSession
|
||||
mu sync.RWMutex
|
||||
byUser map[string]map[string]*SSHSession
|
||||
slugIndex map[string]string
|
||||
}
|
||||
|
||||
func NewRegistry() Registry {
|
||||
return ®istry{
|
||||
clients: make(map[string]*SSHSession),
|
||||
byUser: make(map[string]map[string]*SSHSession),
|
||||
slugIndex: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +29,12 @@ func (r *registry) Get(slug string) (session *SSHSession, err error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
client, ok := r.clients[slug]
|
||||
userID, ok := r.slugIndex[slug]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
client, ok := r.byUser[userID][slug]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session not found")
|
||||
}
|
||||
@@ -43,18 +51,30 @@ func (r *registry) Update(oldSlug, newSlug string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.clients[newSlug]; exists && newSlug != oldSlug {
|
||||
return fmt.Errorf("someone already uses this subdomain")
|
||||
}
|
||||
|
||||
client, ok := r.clients[oldSlug]
|
||||
userID, ok := r.slugIndex[oldSlug]
|
||||
if !ok {
|
||||
return fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
delete(r.clients, oldSlug)
|
||||
if _, exists := r.slugIndex[newSlug]; exists && newSlug != oldSlug {
|
||||
return fmt.Errorf("someone already uses this subdomain")
|
||||
}
|
||||
|
||||
client, ok := r.byUser[userID][oldSlug]
|
||||
if !ok {
|
||||
return fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
delete(r.byUser[userID], oldSlug)
|
||||
delete(r.slugIndex, oldSlug)
|
||||
|
||||
client.slugManager.Set(newSlug)
|
||||
r.clients[newSlug] = client
|
||||
r.slugIndex[newSlug] = userID
|
||||
|
||||
if r.byUser[userID] == nil {
|
||||
r.byUser[userID] = make(map[string]*SSHSession)
|
||||
}
|
||||
r.byUser[userID][newSlug] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -62,19 +82,50 @@ func (r *registry) Register(slug string, session *SSHSession) (success bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.clients[slug]; exists {
|
||||
if _, exists := r.slugIndex[slug]; exists {
|
||||
return false
|
||||
}
|
||||
|
||||
r.clients[slug] = session
|
||||
userID := session.userID
|
||||
if r.byUser[userID] == nil {
|
||||
r.byUser[userID] = make(map[string]*SSHSession)
|
||||
}
|
||||
|
||||
r.byUser[userID][slug] = session
|
||||
r.slugIndex[slug] = userID
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *registry) GetAllSessionFromUser(user string) []*SSHSession {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
m := r.byUser[user]
|
||||
if len(m) == 0 {
|
||||
return []*SSHSession{}
|
||||
}
|
||||
|
||||
sessions := make([]*SSHSession, 0, len(m))
|
||||
for _, s := range m {
|
||||
sessions = append(sessions, s)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (r *registry) Remove(slug string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
delete(r.clients, slug)
|
||||
userID, ok := r.slugIndex[slug]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(r.byUser[userID], slug)
|
||||
if len(r.byUser[userID]) == 0 {
|
||||
delete(r.byUser, userID)
|
||||
}
|
||||
delete(r.slugIndex, slug)
|
||||
}
|
||||
|
||||
func isValidSlug(slug string) bool {
|
||||
|
||||
@@ -28,6 +28,7 @@ type SSHSession struct {
|
||||
forwarder forwarder.ForwardingController
|
||||
slugManager slug.Manager
|
||||
registry Registry
|
||||
userID string
|
||||
}
|
||||
|
||||
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
|
||||
@@ -46,7 +47,7 @@ func (s *SSHSession) GetSlugManager() slug.Manager {
|
||||
return s.slugManager
|
||||
}
|
||||
|
||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry) *SSHSession {
|
||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession {
|
||||
slugManager := slug.NewManager()
|
||||
forwarderManager := forwarder.NewForwarder(slugManager)
|
||||
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
|
||||
@@ -65,6 +66,25 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
|
||||
forwarder: forwarderManager,
|
||||
slugManager: slugManager,
|
||||
registry: sessionRegistry,
|
||||
userID: userID,
|
||||
}
|
||||
}
|
||||
|
||||
type Detail struct {
|
||||
ForwardingType string `json:"forwarding_type,omitempty"`
|
||||
Slug string `json:"slug,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
Active bool `json:"active,omitempty"`
|
||||
StartedAt time.Time `json:"started_at,omitempty"`
|
||||
}
|
||||
|
||||
func (s *SSHSession) Detail() Detail {
|
||||
return Detail{
|
||||
ForwardingType: string(s.forwarder.GetTunnelType()),
|
||||
Slug: s.slugManager.Get(),
|
||||
UserID: s.userID,
|
||||
Active: s.lifecycle.IsActive(),
|
||||
StartedAt: s.lifecycle.StartedAt(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,7 +106,7 @@ func (s *SSHSession) Start() error {
|
||||
if err := s.lifecycle.Close(); err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return fmt.Errorf("No forwarding Request")
|
||||
return fmt.Errorf("no forwarding Request")
|
||||
}
|
||||
|
||||
s.lifecycle.SetChannel(ch)
|
||||
|
||||
Reference in New Issue
Block a user