feat: implement get sessions by user

This commit is contained in:
2026-01-02 22:58:54 +07:00
parent fd6ffc2500
commit 30e84ac3b7
5 changed files with 182 additions and 35 deletions

View File

@@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
) )
type GrpcConfig struct { type GrpcConfig struct {
@@ -38,6 +39,7 @@ type Client struct {
sessionRegistry session.Registry sessionRegistry session.Registry
slugService proto.SlugChangeClient slugService proto.SlugChangeClient
eventService proto.EventServiceClient eventService proto.EventServiceClient
authorizeConnectionService proto.UserServiceClient
} }
func DefaultConfig() *GrpcConfig { func DefaultConfig() *GrpcConfig {
@@ -111,6 +113,7 @@ func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error)
slugService := proto.NewSlugChangeClient(conn) slugService := proto.NewSlugChangeClient(conn)
eventService := proto.NewEventServiceClient(conn) eventService := proto.NewEventServiceClient(conn)
authorizeConnectionService := proto.NewUserServiceClient(conn)
return &Client{ return &Client{
conn: conn, conn: conn,
@@ -118,6 +121,7 @@ func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error)
slugService: slugService, slugService: slugService,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
eventService: eventService, eventService: eventService,
authorizeConnectionService: authorizeConnectionService,
}, nil }, nil
} }
@@ -221,6 +225,35 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Cli
log.Printf("non-connection send error for slug change success: %v", err) log.Printf("non-connection send error for slug change success: %v", err)
continue continue
} }
case proto.EventType_GET_SESSIONS:
sessions := c.sessionRegistry.GetAllSessionFromUser(recv.GetGetSessionsEvent().GetIdentity())
var details []*proto.Detail
for _, ses := range sessions {
detail := ses.Detail()
details = append(details, &proto.Detail{
ForwardingType: detail.ForwardingType,
Slug: detail.Slug,
UserId: detail.UserID,
Active: detail.Active,
StartedAt: timestamppb.New(detail.StartedAt),
})
}
err = subscribe.Send(&proto.Client{
Type: proto.EventType_GET_SESSIONS,
Payload: &proto.Client_GetSessionsEvent{
GetSessionsEvent: &proto.GetSessionsResponse{
Details: details,
},
},
})
if err != nil {
if isConnectionError(err) {
log.Printf("connection error sending sessions success: %v", err)
return err
}
log.Printf("non-connection send error for sessions success: %v", err)
continue
}
default: default:
log.Printf("Unknown event type received: %v", recv.GetType()) log.Printf("Unknown event type received: %v", recv.GetType())
} }
@@ -231,6 +264,18 @@ func (c *Client) GetConnection() *grpc.ClientConn {
return c.conn return c.conn
} }
func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bool, err error) {
check, err := c.authorizeConnectionService.Check(ctx, &proto.CheckRequest{AuthToken: token})
if err != nil {
return false, err
}
if check.GetResponse() == proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED {
return false, nil
}
return true, nil
}
func (c *Client) Close() error { func (c *Client) Close() error {
if c.conn != nil { if c.conn != nil {
log.Printf("Closing gRPC connection to %s", c.config.Address) log.Printf("Closing gRPC connection to %s", c.config.Address)

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"net" "net"
@@ -63,6 +64,13 @@ func (s *Server) Start() {
func (s *Server) handleConnection(conn net.Conn) { func (s *Server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
defer func(sshConn *ssh.ServerConn) {
err = sshConn.Close()
if err != nil {
log.Printf("failed to close SSH server: %v", err)
}
}(sshConn)
if err != nil { if err != nil {
log.Printf("failed to establish SSH connection: %v", err) log.Printf("failed to establish SSH connection: %v", err)
err := conn.Close() err := conn.Close()
@@ -72,14 +80,20 @@ func (s *Server) handleConnection(conn net.Conn) {
} }
return return
} }
//ctx := context.Background() ctx := context.Background()
//log.Println("SSH connection established:", sshConn.User()) log.Println("SSH connection established:", sshConn.User())
//get, err := s.grpcClient.IdentityService.Get(ctx, &gen.IdentifierRequest{Id: sshConn.User()})
//if err != nil { //Fallback: kalau auth gagal userID di set UNAUTHORIZED
// return authorized, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User())
//}
//fmt.Println(get) var userID string
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry) if authorized {
userID = sshConn.User()
} else {
userID = "UNAUTHORIZED"
}
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, userID)
err = sshSession.Start() err = sshSession.Start()
if err != nil { if err != nil {
log.Printf("SSH session ended with error: %v", err) log.Printf("SSH session ended with error: %v", err)

View File

@@ -4,6 +4,8 @@ import (
"errors" "errors"
"io" "io"
"net" "net"
"time"
portUtil "tunnel_pls/internal/port" portUtil "tunnel_pls/internal/port"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
@@ -24,16 +26,18 @@ type Lifecycle struct {
forwarder Forwarder forwarder Forwarder
slugManager slug.Manager slugManager slug.Manager
unregisterClient func(slug string) unregisterClient func(slug string)
startedAt time.Time
} }
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle { func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
return &Lifecycle{ return &Lifecycle{
status: "", status: types.INITIALIZING,
conn: conn, conn: conn,
channel: nil, channel: nil,
forwarder: forwarder, forwarder: forwarder,
slugManager: slugManager, slugManager: slugManager,
unregisterClient: nil, unregisterClient: nil,
startedAt: time.Now(),
} }
} }
@@ -48,6 +52,8 @@ type SessionLifecycle interface {
GetChannel() ssh.Channel GetChannel() ssh.Channel
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetUnregisterClient(unregisterClient func(slug string)) SetUnregisterClient(unregisterClient func(slug string))
IsActive() bool
StartedAt() time.Time
} }
func (l *Lifecycle) GetChannel() ssh.Channel { func (l *Lifecycle) GetChannel() ssh.Channel {
@@ -62,6 +68,9 @@ func (l *Lifecycle) GetConnection() ssh.Conn {
} }
func (l *Lifecycle) SetStatus(status types.Status) { func (l *Lifecycle) SetStatus(status types.Status) {
l.status = status l.status = status
if status == types.RUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now()
}
} }
func (l *Lifecycle) Close() error { func (l *Lifecycle) Close() error {
@@ -98,3 +107,11 @@ func (l *Lifecycle) Close() error {
return nil return nil
} }
func (l *Lifecycle) IsActive() bool {
return l.status == types.RUNNING
}
func (l *Lifecycle) StartedAt() time.Time {
return l.startedAt
}

View File

@@ -10,15 +10,18 @@ type Registry interface {
Update(oldSlug, newSlug string) error Update(oldSlug, newSlug string) error
Register(slug string, session *SSHSession) (success bool) Register(slug string, session *SSHSession) (success bool)
Remove(slug string) Remove(slug string)
GetAllSessionFromUser(user string) []*SSHSession
} }
type registry struct { type registry struct {
mu sync.RWMutex mu sync.RWMutex
clients map[string]*SSHSession byUser map[string]map[string]*SSHSession
slugIndex map[string]string
} }
func NewRegistry() Registry { func NewRegistry() Registry {
return &registry{ return &registry{
clients: make(map[string]*SSHSession), byUser: make(map[string]map[string]*SSHSession),
slugIndex: make(map[string]string),
} }
} }
@@ -26,7 +29,12 @@ func (r *registry) Get(slug string) (session *SSHSession, err error) {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
client, ok := r.clients[slug] userID, ok := r.slugIndex[slug]
if !ok {
return nil, fmt.Errorf("session not found")
}
client, ok := r.byUser[userID][slug]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, fmt.Errorf("session not found")
} }
@@ -43,18 +51,30 @@ func (r *registry) Update(oldSlug, newSlug string) error {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, exists := r.clients[newSlug]; exists && newSlug != oldSlug { userID, ok := r.slugIndex[oldSlug]
return fmt.Errorf("someone already uses this subdomain")
}
client, ok := r.clients[oldSlug]
if !ok { if !ok {
return fmt.Errorf("session not found") return fmt.Errorf("session not found")
} }
delete(r.clients, oldSlug) if _, exists := r.slugIndex[newSlug]; exists && newSlug != oldSlug {
return fmt.Errorf("someone already uses this subdomain")
}
client, ok := r.byUser[userID][oldSlug]
if !ok {
return fmt.Errorf("session not found")
}
delete(r.byUser[userID], oldSlug)
delete(r.slugIndex, oldSlug)
client.slugManager.Set(newSlug) client.slugManager.Set(newSlug)
r.clients[newSlug] = client r.slugIndex[newSlug] = userID
if r.byUser[userID] == nil {
r.byUser[userID] = make(map[string]*SSHSession)
}
r.byUser[userID][newSlug] = client
return nil return nil
} }
@@ -62,19 +82,50 @@ func (r *registry) Register(slug string, session *SSHSession) (success bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, exists := r.clients[slug]; exists { if _, exists := r.slugIndex[slug]; exists {
return false return false
} }
r.clients[slug] = session userID := session.userID
if r.byUser[userID] == nil {
r.byUser[userID] = make(map[string]*SSHSession)
}
r.byUser[userID][slug] = session
r.slugIndex[slug] = userID
return true return true
} }
func (r *registry) GetAllSessionFromUser(user string) []*SSHSession {
r.mu.RLock()
defer r.mu.RUnlock()
m := r.byUser[user]
if len(m) == 0 {
return []*SSHSession{}
}
sessions := make([]*SSHSession, 0, len(m))
for _, s := range m {
sessions = append(sessions, s)
}
return sessions
}
func (r *registry) Remove(slug string) { func (r *registry) Remove(slug string) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
delete(r.clients, slug) userID, ok := r.slugIndex[slug]
if !ok {
return
}
delete(r.byUser[userID], slug)
if len(r.byUser[userID]) == 0 {
delete(r.byUser, userID)
}
delete(r.slugIndex, slug)
} }
func isValidSlug(slug string) bool { func isValidSlug(slug string) bool {

View File

@@ -28,6 +28,7 @@ type SSHSession struct {
forwarder forwarder.ForwardingController forwarder forwarder.ForwardingController
slugManager slug.Manager slugManager slug.Manager
registry Registry registry Registry
userID string
} }
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle { func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
@@ -46,7 +47,7 @@ func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager return s.slugManager
} }
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry) *SSHSession { func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession {
slugManager := slug.NewManager() slugManager := slug.NewManager()
forwarderManager := forwarder.NewForwarder(slugManager) forwarderManager := forwarder.NewForwarder(slugManager)
interactionManager := interaction.NewInteraction(slugManager, forwarderManager) interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
@@ -65,6 +66,25 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
forwarder: forwarderManager, forwarder: forwarderManager,
slugManager: slugManager, slugManager: slugManager,
registry: sessionRegistry, registry: sessionRegistry,
userID: userID,
}
}
type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"`
UserID string `json:"user_id,omitempty"`
Active bool `json:"active,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
}
func (s *SSHSession) Detail() Detail {
return Detail{
ForwardingType: string(s.forwarder.GetTunnelType()),
Slug: s.slugManager.Get(),
UserID: s.userID,
Active: s.lifecycle.IsActive(),
StartedAt: s.lifecycle.StartedAt(),
} }
} }
@@ -86,7 +106,7 @@ func (s *SSHSession) Start() error {
if err := s.lifecycle.Close(); err != nil { if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
return fmt.Errorf("No forwarding Request") return fmt.Errorf("no forwarding Request")
} }
s.lifecycle.SetChannel(ch) s.lifecycle.SetChannel(ch)