package server

import (
	"context"
	"crypto/subtle"
	"fmt"
	"log"
	"net"
	"sync"
	"time"

	"git.fossy.my.id/bagas/tunnel-please-controller/db/sqlc/repository"
	proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/health"
	"google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/reflection"
	"google.golang.org/grpc/status"
)

// Subscriber is the in-process endpoint of one authenticated event stream.
// Requests pushed onto controller are relayed to the remote client by
// processEventStream, and the client's answers come back on client.
type Subscriber struct {
	client     chan *proto.Client
	controller chan *proto.Controller
	done       chan struct{}
	closeOnce  sync.Once

	// busy is a one-slot semaphore held for the duration of an exchange.
	// The stream carries no request IDs, so replies can only be matched to
	// requests by arrival order; allowing a single outstanding request per
	// subscriber keeps that order unambiguous.
	busy chan struct{}
	// abandoned counts replies still owed to callers that stopped waiting
	// after their request was sent. Those replies are discarded before the
	// next exchange reads its own. Only accessed while holding busy.
	abandoned int
}

// exchange sends req to the subscriber's stream and waits for the matching
// reply. It returns ctx.Err() when ctx ends first, and a codes.Canceled
// status when the subscriber is removed while the exchange is in progress.
func (sub *Subscriber) exchange(ctx context.Context, req *proto.Controller) (*proto.Client, error) {
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-sub.done:
		return nil, status.Error(codes.Canceled, "subscriber removed")
	case sub.busy <- struct{}{}:
	}
	defer func() { <-sub.busy }()

	// A previous caller may have given up after its request was sent; its
	// reply is still coming and must not be mistaken for ours.
	for sub.abandoned > 0 {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-sub.done:
			return nil, status.Error(codes.Canceled, "subscriber removed")
		case <-sub.client:
			sub.abandoned--
		}
	}

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-sub.done:
		return nil, status.Error(codes.Canceled, "subscriber removed")
	case sub.controller <- req:
	}

	select {
	case <-ctx.Done():
		// The request is already on its way; remember that one reply must
		// be skipped by whoever exchanges next.
		sub.abandoned++
		return nil, ctx.Err()
	case <-sub.done:
		return nil, status.Error(codes.Canceled, "subscriber removed")
	case resp, ok := <-sub.client:
		if !ok {
			return nil, status.Error(codes.FailedPrecondition, "client channel closed")
		}
		return resp, nil
	}
}

// Server implements the event, slug-change, user and user-sessions gRPC
// services, and tracks the event-stream subscribers that are currently
// connected, keyed by the identity they authenticated with.
type Server struct {
	Database    *repository.Queries
	Subscribers map[string]*Subscriber // guarded by mu
	mu          *sync.RWMutex
	authToken   string

	proto.UnimplementedEventServiceServer
	proto.UnimplementedSlugChangeServer
	proto.UnimplementedUserServiceServer
	proto.UnimplementedUserSessionsServer
}

// New returns a Server that uses database for lookups and accepts event
// stream subscribers that present authToken.
func New(database *repository.Queries, authToken string) *Server {
	return &Server{
		Database:    database,
		Subscribers: make(map[string]*Subscriber),
		mu:          new(sync.RWMutex),
		authToken:   authToken,
	}
}

// GetSession asks every connected subscriber for its sessions matching
// req's identity and merges the details they return. Subscribers that fail
// or reply with an unexpected payload are skipped. It returns NotFound when
// no session details were collected, and ctx.Err() if ctx ends first.
func (s *Server) GetSession(ctx context.Context, req *proto.GetSessionRequest) (*proto.GetSessionsResponse, error) {
	if req == nil {
		return nil, status.Error(codes.InvalidArgument, "request is nil")
	}

	controllerReq := &proto.Controller{
		Type: proto.EventType_GET_SESSIONS,
		Payload: &proto.Controller_GetSessionsEvent{
			GetSessionsEvent: &proto.GetSessionsEvent{
				Identity: req.GetIdentity(),
			},
		},
	}

	results := s.notifyAllSubscriber(ctx, controllerReq)
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	var details []*proto.Detail
	for _, result := range results {
		// Errored results carry a nil Response; dereferencing it would panic.
		if result.Err != nil || result.Response == nil {
			log.Printf("GetSession: subscriber %s: no response: %v", result.Identity, result.Err)
			continue
		}
		payload, ok := result.Response.GetPayload().(*proto.Client_GetSessionsEvent)
		if !ok || payload == nil || payload.GetSessionsEvent == nil {
			log.Printf("GetSession: subscriber %s: unexpected payload %T", result.Identity, result.Response.GetPayload())
			continue
		}
		details = append(details, payload.GetSessionsEvent.Details...)
	}

	if len(details) == 0 {
		return nil, status.Error(codes.NotFound, "no subscriber responded")
	}
	return &proto.GetSessionsResponse{Details: details}, nil
}

// Check reports whether request's auth token matches a user identifier in
// the database. Database errors are returned unchanged.
func (s *Server) Check(ctx context.Context, request *proto.CheckRequest) (*proto.CheckResponse, error) {
	exist, err := s.Database.UserExistsByIdentifier(ctx, request.GetAuthToken())
	if err != nil {
		return nil, err
	}
	if exist {
		return &proto.CheckResponse{
			Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED,
		}, nil
	}
	return &proto.CheckResponse{
		Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED,
	}, nil
}

// Subscribe serves one bidirectional event stream. The first message must be
// an AUTHENTICATION event carrying a non-empty identity and the server's auth
// token. The stream is then registered under that identity and used to relay
// controller requests and their replies until it ends, at which point the
// subscription is removed.
func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error {
	ctx := event.Context()

	recv, err := event.Recv()
	if err != nil {
		return err
	}
	if recv == nil {
		return status.Error(codes.InvalidArgument, "missing authentication event")
	}
	if recv.GetType() != proto.EventType_AUTHENTICATION {
		return status.Errorf(codes.InvalidArgument, "invalid event type: %s", recv.GetType())
	}

	payload, ok := recv.GetPayload().(*proto.Client_AuthEvent)
	if !ok || payload == nil || payload.AuthEvent == nil {
		return status.Error(codes.InvalidArgument, "missing auth payload")
	}

	identity := payload.AuthEvent.Identity
	if identity == "" {
		return status.Error(codes.InvalidArgument, "missing identity")
	}

	// Constant-time comparison so response timing does not reveal how much
	// of the token a client guessed correctly.
	token := payload.AuthEvent.AuthToken
	if subtle.ConstantTimeCompare([]byte(token), []byte(s.authToken)) != 1 {
		return status.Error(codes.Unauthenticated, "invalid auth token")
	}

	log.Printf("Client %s authenticated successfully", identity)

	requestChan := &Subscriber{
		client:     make(chan *proto.Client, 10),
		controller: make(chan *proto.Controller, 10),
		done:       make(chan struct{}),
		busy:       make(chan struct{}, 1),
	}

	if err = s.AddEventSubscriber(identity, requestChan); err != nil {
		return status.Error(codes.AlreadyExists, err.Error())
	}
	defer func() {
		s.RemoveEventSubscriber(identity)
		log.Printf("Client %s disconnected and unsubscribed", identity)
	}()

	return processEventStream(ctx, requestChan, event)
}

// processEventStream relays requests from requestChan.controller to the
// remote client over event and passes each reply back on requestChan.client.
// Exactly one reply is produced for every request taken from the controller
// channel (nil when the request could not be forwarded), so that
// Subscriber.exchange can pair requests and replies by order. It returns when
// ctx ends, the subscriber is closed, or the stream fails.
func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error {
	reply := func(msg *proto.Client) error {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case requestChan.client <- msg:
			return nil
		}
	}

	// forward sends one request to the remote client, reads its answer and
	// hands that answer back to the waiting exchange.
	forward := func(request *proto.Controller) (*proto.Client, error) {
		if err := event.Send(request); err != nil {
			return nil, err
		}
		recv, err := event.Recv()
		if err != nil {
			return nil, err
		}
		if err := reply(recv); err != nil {
			return nil, err
		}
		return recv, nil
	}

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-requestChan.done:
			return nil
		case request, ok := <-requestChan.controller:
			if !ok {
				return nil
			}
			if request == nil {
				if err := reply(nil); err != nil {
					return err
				}
				continue
			}

			log.Printf("Received event request: %v", request)

			switch request.GetType() {
			case proto.EventType_SLUG_CHANGE:
				payload, ok := request.GetPayload().(*proto.Controller_SlugEvent)
				if !ok || payload == nil || payload.SlugEvent == nil {
					log.Printf("invalid slug change payload")
					if err := reply(nil); err != nil {
						return err
					}
					continue
				}
				slugEvent := payload.SlugEvent
				log.Printf("Processing slug change event: old=%s, new=%s", slugEvent.Old, slugEvent.New)

				recv, err := forward(request)
				if err != nil {
					return err
				}
				log.Printf("Received slug change event: %v", recv)

			case proto.EventType_GET_SESSIONS:
				log.Printf("Processing session event")

				recv, err := forward(request)
				if err != nil {
					return err
				}
				log.Printf("Received SESSIONS event: %v", recv)

			default:
				log.Printf("Unknown event type: %v", request.GetType())
				if err := reply(nil); err != nil {
					return err
				}
			}
		}
	}
}

// AddEventSubscriber registers req under identity. It fails if identity is
// empty, req is nil, or identity is already registered.
func (s *Server) AddEventSubscriber(identity string, req *Subscriber) error {
	if identity == "" || req == nil {
		return fmt.Errorf("invalid subscriber")
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if _, exist := s.Subscribers[identity]; exist {
		return fmt.Errorf("identity %s already subscribed", identity)
	}
	s.Subscribers[identity] = req
	return nil
}

// RemoveEventSubscriber unregisters identity and closes its done channel
// (at most once), which releases anyone waiting on that subscriber.
func (s *Server) RemoveEventSubscriber(identity string) {
	s.mu.Lock()
	sub := s.Subscribers[identity]
	delete(s.Subscribers, identity)
	s.mu.Unlock()

	if sub != nil {
		sub.closeOnce.Do(func() { close(sub.done) })
	}
}

// GetEventSubscriber returns the subscriber registered under identity, or a
// NotFound status if there is none.
func (s *Server) GetEventSubscriber(identity string) (*Subscriber, error) {
	if identity == "" {
		return nil, status.Error(codes.InvalidArgument, "missing identity")
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	req, ok := s.Subscribers[identity]
	if !ok {
		return nil, status.Errorf(codes.NotFound, "identity %s not subscribed", identity)
	}
	return req, nil
}

// RequestChangeSlug forwards a slug change to the subscriber registered
// under request's node and returns that subscriber's answer.
func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlugRequest) (*proto.ChangeSlugResponse, error) {
	if request == nil {
		return nil, status.Error(codes.InvalidArgument, "request is nil")
	}
	if request.GetNode() == "" {
		return nil, status.Error(codes.InvalidArgument, "node is required")
	}
	if request.Old == "" || request.New == "" {
		return nil, status.Error(codes.InvalidArgument, "old and new slugs are required")
	}

	subscriber, err := s.GetEventSubscriber(request.GetNode())
	if err != nil {
		return nil, err
	}

	controllerMsg := &proto.Controller{
		Type: proto.EventType_SLUG_CHANGE,
		Payload: &proto.Controller_SlugEvent{
			SlugEvent: &proto.SlugChangeEvent{
				Old: request.Old,
				New: request.New,
			},
		},
	}

	resp, err := subscriber.exchange(ctx, controllerMsg)
	if err != nil {
		return nil, err
	}
	if resp == nil {
		return nil, status.Error(codes.FailedPrecondition, "empty response from client")
	}

	response, ok := resp.Payload.(*proto.Client_SlugEventResponse)
	if !ok || response == nil || response.SlugEventResponse == nil {
		return nil, status.Error(codes.FailedPrecondition, "invalid slug response payload")
	}
	return (*proto.ChangeSlugResponse)(response.SlugEventResponse), nil
}

// SubscriberResult is one subscriber's outcome of a broadcast: either its
// Response or the Err that prevented one.
type SubscriberResult struct {
	Identity string
	Response *proto.Client
	Err      error
}

// notifyAllSubscriber sends controllerReq to every subscriber registered at
// call time, one after another, and collects one result per subscriber. It
// stops early (returning partial results) once ctx ends.
func (s *Server) notifyAllSubscriber(ctx context.Context, controllerReq *proto.Controller) []SubscriberResult {
	s.mu.RLock()
	subs := make(map[string]*Subscriber, len(s.Subscribers))
	for id, sub := range s.Subscribers {
		subs[id] = sub
	}
	s.mu.RUnlock()

	results := make([]SubscriberResult, 0, len(subs))
	for id, sub := range subs {
		if ctx.Err() != nil {
			break
		}
		resp, err := sub.exchange(ctx, controllerReq)
		results = append(results, SubscriberResult{Identity: id, Response: resp, Err: err})
	}
	return results
}

// ListenAndServe listens on the TCP address Addr and serves the server's gRPC
// services, reflection and a health service reporting SERVING. It blocks
// until serving stops and returns the resulting error.
func (s *Server) ListenAndServe(Addr string) error {
	listener, err := net.Listen("tcp", Addr)
	if err != nil {
		return err
	}

	kaParams := keepalive.ServerParameters{
		MaxConnectionIdle:     0,
		MaxConnectionAge:      0,
		MaxConnectionAgeGrace: 0,
		Time:                  30 * time.Second,
		Timeout:               10 * time.Second,
	}

	kaPolicy := keepalive.EnforcementPolicy{
		MinTime:             5 * time.Second,
		PermitWithoutStream: true,
	}

	grpcServer := grpc.NewServer(
		grpc.KeepaliveParams(kaParams),
		grpc.KeepaliveEnforcementPolicy(kaPolicy),
	)

	reflection.Register(grpcServer)
	proto.RegisterSlugChangeServer(grpcServer, s)
	proto.RegisterEventServiceServer(grpcServer, s)
	proto.RegisterUserServiceServer(grpcServer, s)
	proto.RegisterUserSessionsServer(grpcServer, s)

	healthServer := health.NewServer()
	grpc_health_v1.RegisterHealthServer(grpcServer, healthServer)
	healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING)

	return grpcServer.Serve(listener)
}