diff --git a/main.go b/main.go index 6b7af01..6ee73c5 100644 --- a/main.go +++ b/main.go @@ -33,7 +33,10 @@ func main() { }(connect, ctx) repo := repository.New(connect) - s := server.New(repo) + s := server.New(repo, "test_auth_key") + + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) log.Printf("Listening on :8080\n") err = s.ListenAndServe(":8080") diff --git a/server/server.go b/server/server.go index d83c9df..0d2f20d 100644 --- a/server/server.go +++ b/server/server.go @@ -2,126 +2,222 @@ package server import ( "context" + "fmt" "log" - mathrand "math/rand" "net" - "strings" "sync" "time" "git.fossy.my.id/bagas/tunnel-please-controller/db/sqlc/repository" proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" - "github.com/google/uuid" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/keepalive" "google.golang.org/grpc/reflection" - "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/grpc/status" ) +type Subscriber struct { + client chan *proto.Client + controller chan *proto.Controller + done chan struct{} + closeOnce sync.Once +} type Server struct { - Database *repository.Queries - Subscriber []chan *proto.Event - mu sync.RWMutex - proto.UnimplementedIdentityServer + Database *repository.Queries + Subscribers map[string]*Subscriber + mu *sync.RWMutex + authToken string proto.UnimplementedEventServiceServer - proto.UnimplementedSlugServer + proto.UnimplementedSlugChangeServer } -func (s *Server) ChangeSlug(ctx context.Context, request *proto.ChangeSlugRequest) (*proto.ChangeSlugResponse, error) { - s.NotifyAllSubscriber(&proto.Event{ - Type: proto.EventType_SLUG_CHANGE, - TimestampUnixMs: time.Now().Unix(), - Data: &proto.Event_DataEvent{DataEvent: &proto.SlugChangeEvent{ - Old: request.GetOld(), - New: request.GetNew(), - }}, - }) - return &proto.ChangeSlugResponse{}, nil +func New(database *repository.Queries, authToken string) *Server { + return &Server{ + Database: database, + Subscribers: make(map[string]*Subscriber), + mu: new(sync.RWMutex), + authToken: authToken, + } } -func (s *Server) Subscribe(request *emptypb.Empty, g grpc.ServerStreamingServer[proto.Event]) error { - sr := make(chan *proto.Event) - s.AddSubscriberChan(sr) - defer s.RemoveSubscriberChan(sr) +func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error { + ctx := event.Context() + recv, err := event.Recv() + if err != nil { + return err + } + if recv == nil { + return status.Error(codes.InvalidArgument, "missing authentication event") + } + if recv.GetType() != proto.EventType_AUTHENTICATION { + return status.Errorf(codes.InvalidArgument, "invalid event type: %s", recv.GetType()) + } + payload, ok := recv.GetPayload().(*proto.Client_AuthEvent) + if !ok || payload == nil || payload.AuthEvent == nil { + return status.Error(codes.InvalidArgument, "missing auth payload") + } + identity := payload.AuthEvent.Identity + if identity == "" { + return status.Error(codes.InvalidArgument, "missing identity") + } + token := payload.AuthEvent.AuthToken + if token != s.authToken { + return status.Error(codes.Unauthenticated, "invalid auth token") + } - for ev := range sr { - if err := g.Send(ev); err != nil { - return err + log.Printf("Client %s authenticated successfully", identity) + + requestChan := &Subscriber{ + client: make(chan *proto.Client, 10), + controller: make(chan *proto.Controller, 10), + done: make(chan struct{}), + } + + if err = s.AddEventSubscriber(identity, requestChan); err != nil { + return status.Error(codes.AlreadyExists, err.Error()) + } + defer func() { + s.RemoveEventSubscriber(identity) + log.Printf("Client %s disconnected and unsubscribed", identity) + }() + + return processEventStream(ctx, requestChan, event) +} + +func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-requestChan.done: + return nil + case request, ok := <-requestChan.controller: + if !ok { + return nil + } + if request == nil { + continue + } + log.Printf("Received event request: %v", request) + switch request.GetType() { + case proto.EventType_SLUG_CHANGE: + payload, ok := request.GetPayload().(*proto.Controller_SlugEvent) + if !ok || payload == nil || payload.SlugEvent == nil { + log.Printf("invalid slug change payload") + continue + } + slugEvent := payload.SlugEvent + log.Printf("Processing slug change event: old=%s, new=%s", slugEvent.Old, slugEvent.New) + if err := event.Send(request); err != nil { + return err + } + recv, err := event.Recv() + if err != nil { + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + case requestChan.client <- recv: + } + log.Printf("Received slug change event: %v", recv) + default: + log.Printf("Unknown event type: %v", request.GetType()) + } } } +} + +func (s *Server) AddEventSubscriber(identity string, req *Subscriber) error { + if identity == "" || req == nil { + return fmt.Errorf("invalid subscriber") + } + s.mu.Lock() + defer s.mu.Unlock() + if _, exist := s.Subscribers[identity]; exist { + return fmt.Errorf("identity %s already subscribed", identity) + } + s.Subscribers[identity] = req return nil } -func (s *Server) NotifyAllSubscriber(event *proto.Event) { - for _, subs := range s.Subscriber { - subs <- event - } -} - -func (s *Server) AddSubscriberChan(event chan *proto.Event) { +func (s *Server) RemoveEventSubscriber(identity string) { s.mu.Lock() - defer s.mu.Unlock() - s.Subscriber = append(s.Subscriber, event) + sub := s.Subscribers[identity] + delete(s.Subscribers, identity) + s.mu.Unlock() + + if sub != nil { + sub.closeOnce.Do(func() { close(sub.done) }) + } } -func (s *Server) RemoveSubscriberChan(ch chan *proto.Event) { - s.mu.Lock() - defer s.mu.Unlock() - if len(s.Subscriber) == 0 || ch == nil { - return +func (s *Server) GetEventSubscriber(identity string) (*Subscriber, error) { + if identity == "" { + return nil, status.Error(codes.InvalidArgument, "missing identity") } - newSubs := s.Subscriber[:0] - for _, c := range s.Subscriber { - if c == ch { - continue - } - newSubs = append(newSubs, c) + s.mu.RLock() + defer s.mu.RUnlock() + req, ok := s.Subscribers[identity] + if !ok { + return nil, status.Errorf(codes.NotFound, "identity %s not subscribed", identity) } - s.Subscriber = newSubs - - close(ch) + return req, nil } -func (s *Server) Get(ctx context.Context, request *proto.IdentifierRequest) (*proto.IdentifierResponse, error) { - parse, err := uuid.Parse(request.GetId()) +func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlugRequest) (*proto.ChangeSlugResponse, error) { + if request == nil { + return nil, status.Error(codes.InvalidArgument, "request is nil") + } + if request.GetNode() == "" { + return nil, status.Error(codes.InvalidArgument, "node is required") + } + if request.Old == "" || request.New == "" { + return nil, status.Error(codes.InvalidArgument, "old and new slugs are required") + } + + subscriber, err := s.GetEventSubscriber(request.GetNode()) if err != nil { return nil, err } - data, err := s.Database.GetIdentifierById(ctx, parse) - if err != nil { - return nil, err - } - return &proto.IdentifierResponse{ - Id: data.ID.String(), - Slug: data.Slug, - }, nil -} -func (s *Server) Create(ctx context.Context, request *emptypb.Empty) (*proto.IdentifierResponse, error) { - createIdentifier, err := s.Database.CreateIdentifier(ctx, GenerateRandomString(32)) - if err != nil { - return nil, err + controllerMsg := &proto.Controller{ + Type: proto.EventType_SLUG_CHANGE, + Payload: &proto.Controller_SlugEvent{ + SlugEvent: &proto.SlugChangeEvent{ + Old: request.Old, + New: request.New, + }, + }, } - return &proto.IdentifierResponse{ - Id: createIdentifier.ID.String(), - Slug: createIdentifier.Slug, - }, nil -} -func GenerateRandomString(length int) string { - const charset = "abcdefghijklmnopqrstuvwxyz" - seededRand := mathrand.New(mathrand.NewSource(time.Now().UnixNano() + int64(mathrand.Intn(9999)))) - var result strings.Builder - for i := 0; i < length; i++ { - randomIndex := seededRand.Intn(len(charset)) - result.WriteString(string(charset[randomIndex])) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-subscriber.done: + return nil, status.Error(codes.Canceled, "subscriber removed") + case subscriber.controller <- controllerMsg: } - return result.String() -} -func New(database *repository.Queries) *Server { - return &Server{Database: database, Subscriber: make([]chan *proto.Event, 0)} + var resp *proto.Client + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-subscriber.done: + return nil, status.Error(codes.Canceled, "subscriber removed") + case resp = <-subscriber.client: + } + if resp == nil { + return nil, status.Error(codes.FailedPrecondition, "empty response from client") + } + response, ok := resp.Payload.(*proto.Client_SlugEventResponse) + if !ok || response == nil || response.SlugEventResponse == nil { + return nil, status.Error(codes.FailedPrecondition, "invalid slug response payload") + } + return (*proto.ChangeSlugResponse)(response.SlugEventResponse), nil } func (s *Server) ListenAndServe(Addr string) error { @@ -130,19 +226,30 @@ func (s *Server) ListenAndServe(Addr string) error { return err } - grpcServer := grpc.NewServer() + kaParams := keepalive.ServerParameters{ + MaxConnectionIdle: 0, + MaxConnectionAge: 0, + MaxConnectionAgeGrace: 0, + Time: 30 * time.Second, + Timeout: 10 * time.Second, + } + kaPolicy := keepalive.EnforcementPolicy{ + MinTime: 5 * time.Second, + PermitWithoutStream: true, + } + + grpcServer := grpc.NewServer( + grpc.KeepaliveParams(kaParams), + grpc.KeepaliveEnforcementPolicy(kaPolicy), + ) reflection.Register(grpcServer) - proto.RegisterIdentityServer(grpcServer, s) + proto.RegisterSlugChangeServer(grpcServer, s) proto.RegisterEventServiceServer(grpcServer, s) - proto.RegisterSlugServer(grpcServer, s) healthServer := health.NewServer() grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING) - if err := grpcServer.Serve(listener); err != nil { - log.Fatalf("failed to serve: %v", err) - } - return nil + return grpcServer.Serve(listener) }