diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index 8aaf949..3d3d1c2 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -210,205 +210,152 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string } func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error { + handlers := c.eventHandlers(subscribe) + for { recv, err := subscribe.Recv() if err != nil { return err } - switch recv.GetType() { - case proto.EventType_SLUG_CHANGE: - user := recv.GetSlugEvent().GetUser() - oldSlug := recv.GetSlugEvent().GetOld() - newSlug := recv.GetSlugEvent().GetNew() - var userSession *session.SSHSession - userSession, err = c.sessionRegistry.Get(types.SessionKey{ - Id: oldSlug, - Type: types.HTTP, - }) - if err != nil { - errSend := subscribe.Send(&proto.Node{ - Type: proto.EventType_SLUG_CHANGE_RESPONSE, - Payload: &proto.Node_SlugEventResponse{ - SlugEventResponse: &proto.SlugChangeEventResponse{ - Success: false, - Message: err.Error(), - }, - }, - }) - if errSend != nil { - if c.isConnectionError(errSend) { - return errSend - } - log.Printf("non-connection send error for slug change failure: %v", errSend) - } - continue - } - err = c.sessionRegistry.Update(user, types.SessionKey{ - Id: oldSlug, - Type: types.HTTP, - }, types.SessionKey{ - Id: newSlug, - Type: types.HTTP, - }) - if err != nil { - errSend := subscribe.Send(&proto.Node{ - Type: proto.EventType_SLUG_CHANGE_RESPONSE, - Payload: &proto.Node_SlugEventResponse{ - SlugEventResponse: &proto.SlugChangeEventResponse{ - Success: false, - Message: err.Error(), - }, - }, - }) - if errSend != nil { - if c.isConnectionError(errSend) { - return errSend - } - log.Printf("non-connection send error for slug change failure: %v", errSend) - } - continue - } - userSession.GetInteraction().Redraw() - err = subscribe.Send(&proto.Node{ - Type: proto.EventType_SLUG_CHANGE_RESPONSE, - Payload: &proto.Node_SlugEventResponse{ - SlugEventResponse: &proto.SlugChangeEventResponse{ - Success: true, - Message: "", - }, - }, - }) - if err != nil { - if c.isConnectionError(err) { - log.Printf("connection error sending slug change success: %v", err) - return err - } - log.Printf("non-connection send error for slug change success: %v", err) - continue - } - case proto.EventType_GET_SESSIONS: - sessions := c.sessionRegistry.GetAllSessionFromUser(recv.GetGetSessionsEvent().GetIdentity()) - var details []*proto.Detail - for _, ses := range sessions { - detail := ses.Detail() - details = append(details, &proto.Detail{ - Node: config.Getenv("DOMAIN", "localhost"), - ForwardingType: detail.ForwardingType, - Slug: detail.Slug, - UserId: detail.UserID, - Active: detail.Active, - StartedAt: timestamppb.New(detail.StartedAt), - }) - } - err = subscribe.Send(&proto.Node{ - Type: proto.EventType_GET_SESSIONS, - Payload: &proto.Node_GetSessionsEvent{ - GetSessionsEvent: &proto.GetSessionsResponse{ - Details: details, - }, - }, - }) - if err != nil { - if c.isConnectionError(err) { - log.Printf("connection error sending sessions success: %v", err) - return err - } - log.Printf("non-connection send error for sessions success: %v", err) - continue - } - case proto.EventType_TERMINATE_SESSION: - user := recv.GetTerminateSessionEvent().GetUser() - tunnelTypeRaw := recv.GetTerminateSessionEvent().GetTunnelType() - slug := recv.GetTerminateSessionEvent().GetSlug() - var userSession *session.SSHSession - var tunnelType types.TunnelType - if tunnelTypeRaw == proto.TunnelType_HTTP { - tunnelType = types.HTTP - } else if tunnelTypeRaw == proto.TunnelType_TCP { - tunnelType = types.TCP - } else { - err = subscribe.Send(&proto.Node{ - Type: proto.EventType_TERMINATE_SESSION, - Payload: &proto.Node_TerminateSessionEventResponse{ - TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{ - Success: false, - Message: "unknown tunnel type recived", - }, - }, - }) - if err != nil { - if c.isConnectionError(err) { - log.Printf("connection error sending sessions success: %v", err) - return err - } - log.Printf("non-connection send error for sessions success: %v", err) - } - continue - } - userSession, err = c.sessionRegistry.GetWithUser(user, types.SessionKey{ - Id: slug, - Type: tunnelType, - }) - if err != nil { - err = subscribe.Send(&proto.Node{ - Type: proto.EventType_TERMINATE_SESSION, - Payload: &proto.Node_TerminateSessionEventResponse{ - TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{ - Success: false, - Message: err.Error(), - }, - }, - }) - if err != nil { - if c.isConnectionError(err) { - log.Printf("connection error sending sessions success: %v", err) - return err - } - log.Printf("non-connection send error for sessions success: %v", err) - } - continue - } - err = userSession.GetLifecycle().Close() - if err != nil { - err = subscribe.Send(&proto.Node{ - Type: proto.EventType_TERMINATE_SESSION, - Payload: &proto.Node_TerminateSessionEventResponse{ - TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{ - Success: false, - Message: err.Error(), - }, - }, - }) - if err != nil { - if c.isConnectionError(err) { - log.Printf("connection error sending sessions success: %v", err) - return err - } - log.Printf("non-connection send error for sessions success: %v", err) - } - continue - } - err = subscribe.Send(&proto.Node{ - Type: proto.EventType_TERMINATE_SESSION, - Payload: &proto.Node_TerminateSessionEventResponse{ - TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{ - Success: true, - Message: "", - }, - }, - }) - if err != nil { - if c.isConnectionError(err) { - log.Printf("connection error sending sessions success: %v", err) - return err - } - log.Printf("non-connection send error for sessions success: %v", err) - continue - } - default: + handler, ok := handlers[recv.GetType()] + if !ok { log.Printf("Unknown event type received: %v", recv.GetType()) + continue } + + if err = handler(recv); err != nil { + return err + } + } +} + +func (c *Client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) map[proto.EventType]func(*proto.Events) error { + return map[proto.EventType]func(*proto.Events) error{ + proto.EventType_SLUG_CHANGE: func(evt *proto.Events) error { return c.handleSlugChange(subscribe, evt) }, + proto.EventType_GET_SESSIONS: func(evt *proto.Events) error { return c.handleGetSessions(subscribe, evt) }, + proto.EventType_TERMINATE_SESSION: func(evt *proto.Events) error { return c.handleTerminateSession(subscribe, evt) }, + } +} + +func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { + slugEvent := evt.GetSlugEvent() + user := slugEvent.GetUser() + oldSlug := slugEvent.GetOld() + newSlug := slugEvent.GetNew() + + userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP}) + if err != nil { + return c.sendNode(subscribe, &proto.Node{ + Type: proto.EventType_SLUG_CHANGE_RESPONSE, + Payload: &proto.Node_SlugEventResponse{ + SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()}, + }, + }, "slug change failure response") + } + + if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil { + return c.sendNode(subscribe, &proto.Node{ + Type: proto.EventType_SLUG_CHANGE_RESPONSE, + Payload: &proto.Node_SlugEventResponse{ + SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()}, + }, + }, "slug change failure response") + } + + userSession.GetInteraction().Redraw() + return c.sendNode(subscribe, &proto.Node{ + Type: proto.EventType_SLUG_CHANGE_RESPONSE, + Payload: &proto.Node_SlugEventResponse{ + SlugEventResponse: &proto.SlugChangeEventResponse{Success: true, Message: ""}, + }, + }, "slug change success response") +} + +func (c *Client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { + sessions := c.sessionRegistry.GetAllSessionFromUser(evt.GetGetSessionsEvent().GetIdentity()) + + var details []*proto.Detail + for _, ses := range sessions { + detail := ses.Detail() + details = append(details, &proto.Detail{ + Node: config.Getenv("DOMAIN", "localhost"), + ForwardingType: detail.ForwardingType, + Slug: detail.Slug, + UserId: detail.UserID, + Active: detail.Active, + StartedAt: timestamppb.New(detail.StartedAt), + }) + } + + return c.sendNode(subscribe, &proto.Node{ + Type: proto.EventType_GET_SESSIONS, + Payload: &proto.Node_GetSessionsEvent{ + GetSessionsEvent: &proto.GetSessionsResponse{Details: details}, + }, + }, "send get sessions response") +} + +func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { + terminate := evt.GetTerminateSessionEvent() + user := terminate.GetUser() + slug := terminate.GetSlug() + + tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType()) + if err != nil { + return c.sendNode(subscribe, &proto.Node{ + Type: proto.EventType_TERMINATE_SESSION, + Payload: &proto.Node_TerminateSessionEventResponse{ + TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()}, + }, + }, "terminate session invalid tunnel type") + } + + userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType}) + if err != nil { + return c.sendNode(subscribe, &proto.Node{ + Type: proto.EventType_TERMINATE_SESSION, + Payload: &proto.Node_TerminateSessionEventResponse{ + TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()}, + }, + }, "terminate session fetch failed") + } + + if err = userSession.GetLifecycle().Close(); err != nil { + return c.sendNode(subscribe, &proto.Node{ + Type: proto.EventType_TERMINATE_SESSION, + Payload: &proto.Node_TerminateSessionEventResponse{ + TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()}, + }, + }, "terminate session close failed") + } + + return c.sendNode(subscribe, &proto.Node{ + Type: proto.EventType_TERMINATE_SESSION, + Payload: &proto.Node_TerminateSessionEventResponse{ + TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: true, Message: ""}, + }, + }, "terminate session success response") +} + +func (c *Client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error { + if err := subscribe.Send(node); err != nil { + if c.isConnectionError(err) { + return err + } + log.Printf("%s: %v", context, err) + } + return nil +} + +func (c *Client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) { + switch t { + case proto.TunnelType_HTTP: + return types.HTTP, nil + case proto.TunnelType_TCP: + return types.TCP, nil + default: + return types.UNKNOWN, fmt.Errorf("unknown tunnel type received") } } diff --git a/types/types.go b/types/types.go index 5c4eece..7d3f4de 100644 --- a/types/types.go +++ b/types/types.go @@ -11,8 +11,9 @@ const ( type TunnelType string const ( - HTTP TunnelType = "HTTP" - TCP TunnelType = "TCP" + UNKNOWN TunnelType = "UNKNOWN" + HTTP TunnelType = "HTTP" + TCP TunnelType = "TCP" ) type SessionKey struct {