diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index 38835c1..ea86b42 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" ) type GrpcConfig struct { @@ -33,11 +34,12 @@ type GrpcConfig struct { } type Client struct { - conn *grpc.ClientConn - config *GrpcConfig - sessionRegistry session.Registry - slugService proto.SlugChangeClient - eventService proto.EventServiceClient + conn *grpc.ClientConn + config *GrpcConfig + sessionRegistry session.Registry + slugService proto.SlugChangeClient + eventService proto.EventServiceClient + authorizeConnectionService proto.UserServiceClient } func DefaultConfig() *GrpcConfig { @@ -111,13 +113,15 @@ func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error) slugService := proto.NewSlugChangeClient(conn) eventService := proto.NewEventServiceClient(conn) + authorizeConnectionService := proto.NewUserServiceClient(conn) return &Client{ - conn: conn, - config: config, - slugService: slugService, - sessionRegistry: sessionRegistry, - eventService: eventService, + conn: conn, + config: config, + slugService: slugService, + sessionRegistry: sessionRegistry, + eventService: eventService, + authorizeConnectionService: authorizeConnectionService, }, nil } @@ -221,6 +225,35 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Cli log.Printf("non-connection send error for slug change success: %v", err) continue } + case proto.EventType_GET_SESSIONS: + sessions := c.sessionRegistry.GetAllSessionFromUser(recv.GetGetSessionsEvent().GetIdentity()) + var details []*proto.Detail + for _, ses := range sessions { + detail := ses.Detail() + details = append(details, &proto.Detail{ + ForwardingType: detail.ForwardingType, + Slug: detail.Slug, + UserId: detail.UserID, + Active: detail.Active, + StartedAt: timestamppb.New(detail.StartedAt), + }) + } + err = subscribe.Send(&proto.Client{ + Type: proto.EventType_GET_SESSIONS, + Payload: &proto.Client_GetSessionsEvent{ + GetSessionsEvent: &proto.GetSessionsResponse{ + Details: details, + }, + }, + }) + if err != nil { + if isConnectionError(err) { + log.Printf("connection error sending sessions success: %v", err) + return err + } + log.Printf("non-connection send error for sessions success: %v", err) + continue + } default: log.Printf("Unknown event type received: %v", recv.GetType()) } @@ -231,6 +264,18 @@ func (c *Client) GetConnection() *grpc.ClientConn { return c.conn } +func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bool, err error) { + check, err := c.authorizeConnectionService.Check(ctx, &proto.CheckRequest{AuthToken: token}) + if err != nil { + return false, err + + } + if check.GetResponse() == proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED { + return false, nil + } + return true, nil +} + func (c *Client) Close() error { if c.conn != nil { log.Printf("Closing gRPC connection to %s", c.config.Address) diff --git a/server/server.go b/server/server.go index 53e8e3f..f377a4b 100644 --- a/server/server.go +++ b/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "log" "net" @@ -63,6 +64,13 @@ func (s *Server) Start() { func (s *Server) handleConnection(conn net.Conn) { sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) + defer func(sshConn *ssh.ServerConn) { + err = sshConn.Close() + if err != nil { + log.Printf("failed to close SSH server: %v", err) + } + }(sshConn) + if err != nil { log.Printf("failed to establish SSH connection: %v", err) err := conn.Close() @@ -72,14 +80,20 @@ func (s *Server) handleConnection(conn net.Conn) { } return } - //ctx := context.Background() - //log.Println("SSH connection established:", sshConn.User()) - //get, err := s.grpcClient.IdentityService.Get(ctx, &gen.IdentifierRequest{Id: sshConn.User()}) - //if err != nil { - // return - //} - //fmt.Println(get) - sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry) + ctx := context.Background() + log.Println("SSH connection established:", sshConn.User()) + + //Fallback: kalau auth gagal userID di set UNAUTHORIZED + authorized, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User()) + + var userID string + if authorized { + userID = sshConn.User() + } else { + userID = "UNAUTHORIZED" + } + + sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, userID) err = sshSession.Start() if err != nil { log.Printf("SSH session ended with error: %v", err) diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 5917ba0..f2dea50 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -4,6 +4,8 @@ import ( "errors" "io" "net" + "time" + portUtil "tunnel_pls/internal/port" "tunnel_pls/session/slug" "tunnel_pls/types" @@ -24,16 +26,18 @@ type Lifecycle struct { forwarder Forwarder slugManager slug.Manager unregisterClient func(slug string) + startedAt time.Time } func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle { return &Lifecycle{ - status: "", + status: types.INITIALIZING, conn: conn, channel: nil, forwarder: forwarder, slugManager: slugManager, unregisterClient: nil, + startedAt: time.Now(), } } @@ -48,6 +52,8 @@ type SessionLifecycle interface { GetChannel() ssh.Channel SetChannel(channel ssh.Channel) SetUnregisterClient(unregisterClient func(slug string)) + IsActive() bool + StartedAt() time.Time } func (l *Lifecycle) GetChannel() ssh.Channel { @@ -62,6 +68,9 @@ func (l *Lifecycle) GetConnection() ssh.Conn { } func (l *Lifecycle) SetStatus(status types.Status) { l.status = status + if status == types.RUNNING && l.startedAt.IsZero() { + l.startedAt = time.Now() + } } func (l *Lifecycle) Close() error { @@ -98,3 +107,11 @@ func (l *Lifecycle) Close() error { return nil } + +func (l *Lifecycle) IsActive() bool { + return l.status == types.RUNNING +} + +func (l *Lifecycle) StartedAt() time.Time { + return l.startedAt +} diff --git a/session/registry.go b/session/registry.go index 8ff70fa..b6f655d 100644 --- a/session/registry.go +++ b/session/registry.go @@ -10,15 +10,18 @@ type Registry interface { Update(oldSlug, newSlug string) error Register(slug string, session *SSHSession) (success bool) Remove(slug string) + GetAllSessionFromUser(user string) []*SSHSession } type registry struct { - mu sync.RWMutex - clients map[string]*SSHSession + mu sync.RWMutex + byUser map[string]map[string]*SSHSession + slugIndex map[string]string } func NewRegistry() Registry { return ®istry{ - clients: make(map[string]*SSHSession), + byUser: make(map[string]map[string]*SSHSession), + slugIndex: make(map[string]string), } } @@ -26,7 +29,12 @@ func (r *registry) Get(slug string) (session *SSHSession, err error) { r.mu.RLock() defer r.mu.RUnlock() - client, ok := r.clients[slug] + userID, ok := r.slugIndex[slug] + if !ok { + return nil, fmt.Errorf("session not found") + } + + client, ok := r.byUser[userID][slug] if !ok { return nil, fmt.Errorf("session not found") } @@ -43,18 +51,30 @@ func (r *registry) Update(oldSlug, newSlug string) error { r.mu.Lock() defer r.mu.Unlock() - if _, exists := r.clients[newSlug]; exists && newSlug != oldSlug { - return fmt.Errorf("someone already uses this subdomain") - } - - client, ok := r.clients[oldSlug] + userID, ok := r.slugIndex[oldSlug] if !ok { return fmt.Errorf("session not found") } - delete(r.clients, oldSlug) + if _, exists := r.slugIndex[newSlug]; exists && newSlug != oldSlug { + return fmt.Errorf("someone already uses this subdomain") + } + + client, ok := r.byUser[userID][oldSlug] + if !ok { + return fmt.Errorf("session not found") + } + + delete(r.byUser[userID], oldSlug) + delete(r.slugIndex, oldSlug) + client.slugManager.Set(newSlug) - r.clients[newSlug] = client + r.slugIndex[newSlug] = userID + + if r.byUser[userID] == nil { + r.byUser[userID] = make(map[string]*SSHSession) + } + r.byUser[userID][newSlug] = client return nil } @@ -62,19 +82,50 @@ func (r *registry) Register(slug string, session *SSHSession) (success bool) { r.mu.Lock() defer r.mu.Unlock() - if _, exists := r.clients[slug]; exists { + if _, exists := r.slugIndex[slug]; exists { return false } - r.clients[slug] = session + userID := session.userID + if r.byUser[userID] == nil { + r.byUser[userID] = make(map[string]*SSHSession) + } + + r.byUser[userID][slug] = session + r.slugIndex[slug] = userID return true } +func (r *registry) GetAllSessionFromUser(user string) []*SSHSession { + r.mu.RLock() + defer r.mu.RUnlock() + + m := r.byUser[user] + if len(m) == 0 { + return []*SSHSession{} + } + + sessions := make([]*SSHSession, 0, len(m)) + for _, s := range m { + sessions = append(sessions, s) + } + return sessions +} + func (r *registry) Remove(slug string) { r.mu.Lock() defer r.mu.Unlock() - delete(r.clients, slug) + userID, ok := r.slugIndex[slug] + if !ok { + return + } + + delete(r.byUser[userID], slug) + if len(r.byUser[userID]) == 0 { + delete(r.byUser, userID) + } + delete(r.slugIndex, slug) } func isValidSlug(slug string) bool { diff --git a/session/session.go b/session/session.go index 9a35770..45d497a 100644 --- a/session/session.go +++ b/session/session.go @@ -28,6 +28,7 @@ type SSHSession struct { forwarder forwarder.ForwardingController slugManager slug.Manager registry Registry + userID string } func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle { @@ -46,7 +47,7 @@ func (s *SSHSession) GetSlugManager() slug.Manager { return s.slugManager } -func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry) *SSHSession { +func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession { slugManager := slug.NewManager() forwarderManager := forwarder.NewForwarder(slugManager) interactionManager := interaction.NewInteraction(slugManager, forwarderManager) @@ -65,6 +66,25 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan forwarder: forwarderManager, slugManager: slugManager, registry: sessionRegistry, + userID: userID, + } +} + +type Detail struct { + ForwardingType string `json:"forwarding_type,omitempty"` + Slug string `json:"slug,omitempty"` + UserID string `json:"user_id,omitempty"` + Active bool `json:"active,omitempty"` + StartedAt time.Time `json:"started_at,omitempty"` +} + +func (s *SSHSession) Detail() Detail { + return Detail{ + ForwardingType: string(s.forwarder.GetTunnelType()), + Slug: s.slugManager.Get(), + UserID: s.userID, + Active: s.lifecycle.IsActive(), + StartedAt: s.lifecycle.StartedAt(), } } @@ -86,7 +106,7 @@ func (s *SSHSession) Start() error { if err := s.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) } - return fmt.Errorf("No forwarding Request") + return fmt.Errorf("no forwarding Request") } s.lifecycle.SetChannel(ch)