diff --git a/go.mod b/go.mod index fe62062..ba15d04 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module tunnel_pls go 1.25.5 require ( - git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0 + git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 github.com/caddyserver/certmagic v0.25.0 github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbletea v1.3.10 diff --git a/go.sum b/go.sum index d477230..96c2ea5 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0 h1:RhcBKUG41/om4jgN+iF/vlY/RojTe git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY= git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0 h1:tpJSKjaSmV+vxxbVx6qnStjxFVXjj2M0rygWXxLb99o= git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY= +git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 h1:3xszIhck4wo9CoeRq9vnkar4PhY7kz9QrR30qj2XszA= +git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0/go.mod h1:Weh6ZujgWmT8XxD3Qba7sJ6r5eyUMB9XSWynqdyOoLo= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index 00eac89..8aaf949 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -316,6 +316,96 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod log.Printf("non-connection send error for sessions success: %v", err) continue } + case proto.EventType_TERMINATE_SESSION: + user := recv.GetTerminateSessionEvent().GetUser() + tunnelTypeRaw := recv.GetTerminateSessionEvent().GetTunnelType() + slug := recv.GetTerminateSessionEvent().GetSlug() + + var userSession *session.SSHSession + var tunnelType types.TunnelType + if tunnelTypeRaw == proto.TunnelType_HTTP { + tunnelType = types.HTTP + } else if tunnelTypeRaw == proto.TunnelType_TCP { + tunnelType = types.TCP + } else { + err = subscribe.Send(&proto.Node{ + Type: proto.EventType_TERMINATE_SESSION, + Payload: &proto.Node_TerminateSessionEventResponse{ + TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{ + Success: false, + Message: "unknown tunnel type recived", + }, + }, + }) + if err != nil { + if c.isConnectionError(err) { + log.Printf("connection error sending sessions success: %v", err) + return err + } + log.Printf("non-connection send error for sessions success: %v", err) + } + continue + } + userSession, err = c.sessionRegistry.GetWithUser(user, types.SessionKey{ + Id: slug, + Type: tunnelType, + }) + if err != nil { + err = subscribe.Send(&proto.Node{ + Type: proto.EventType_TERMINATE_SESSION, + Payload: &proto.Node_TerminateSessionEventResponse{ + TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{ + Success: false, + Message: err.Error(), + }, + }, + }) + if err != nil { + if c.isConnectionError(err) { + log.Printf("connection error sending sessions success: %v", err) + return err + } + log.Printf("non-connection send error for sessions success: %v", err) + } + continue + } + err = userSession.GetLifecycle().Close() + if err != nil { + err = subscribe.Send(&proto.Node{ + Type: proto.EventType_TERMINATE_SESSION, + Payload: &proto.Node_TerminateSessionEventResponse{ + TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{ + Success: false, + Message: err.Error(), + }, + }, + }) + if err != nil { + if c.isConnectionError(err) { + log.Printf("connection error sending sessions success: %v", err) + return err + } + log.Printf("non-connection send error for sessions success: %v", err) + } + continue + } + err = subscribe.Send(&proto.Node{ + Type: proto.EventType_TERMINATE_SESSION, + Payload: &proto.Node_TerminateSessionEventResponse{ + TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{ + Success: true, + Message: "", + }, + }, + }) + if err != nil { + if c.isConnectionError(err) { + log.Printf("connection error sending sessions success: %v", err) + return err + } + log.Printf("non-connection send error for sessions success: %v", err) + continue + } default: log.Printf("Unknown event type received: %v", recv.GetType()) } diff --git a/session/registry.go b/session/registry.go index 60b86d3..3113dd6 100644 --- a/session/registry.go +++ b/session/registry.go @@ -10,6 +10,7 @@ type Key = types.SessionKey type Registry interface { Get(key Key) (session *SSHSession, err error) + GetWithUser(user string, key Key) (session *SSHSession, err error) Update(user string, oldKey, newKey Key) error Register(key Key, session *SSHSession) (success bool) Remove(key Key) @@ -44,6 +45,17 @@ func (r *registry) Get(key Key) (session *SSHSession, err error) { return client, nil } +func (r *registry) GetWithUser(user string, key Key) (session *SSHSession, err error) { + r.mu.RLock() + defer r.mu.RUnlock() + + client, ok := r.byUser[user][key] + if !ok { + return nil, fmt.Errorf("session not found") + } + return client, nil +} + func (r *registry) Update(user string, oldKey, newKey Key) error { if oldKey.Type != newKey.Type { return fmt.Errorf("tunnel type cannot change")