From 69e250b4395024231fedcc9b0f088425a38017f9 Mon Sep 17 00:00:00 2001 From: bagas Date: Sat, 3 Jan 2026 20:03:54 +0700 Subject: [PATCH] feat: improve auth --- main.go | 71 +++++++-- server/server.go | 383 +++++++++++++++++++++++++++++++++-------------- 2 files changed, 328 insertions(+), 126 deletions(-) diff --git a/main.go b/main.go index 6ee73c5..b3a2cf4 100644 --- a/main.go +++ b/main.go @@ -2,8 +2,13 @@ package main import ( "context" + "crypto/rand" + "encoding/base64" + "errors" "log" "os" + "os/signal" + "syscall" "git.fossy.my.id/bagas/tunnel-please-controller/db/sqlc/repository" "git.fossy.my.id/bagas/tunnel-please-controller/server" @@ -18,9 +23,26 @@ func main() { } } - ctx := context.Background() + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() - connect, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL")) + log.SetOutput(os.Stdout) + log.SetFlags(log.LstdFlags | log.Lshortfile) + + dbURL := os.Getenv("DATABASE_URL") + if dbURL == "" { + log.Fatal("DATABASE_URL is required") + } + + controllerAddr := getenv("CONTROLLER_ADDR", ":8080") + apiAddr := getenv("API_ADDR", ":8081") + authToken := getenv("AUTH_TOKEN", "") + if authToken == "" { + authToken = generateAuthToken() + log.Printf("No AUTH_TOKEN provided. Generated token: %s", authToken) + } + + connect, err := pgx.Connect(ctx, dbURL) if err != nil { panic(err) return @@ -33,15 +55,44 @@ func main() { }(connect, ctx) repo := repository.New(connect) - s := server.New(repo, "test_auth_key") + s := server.New(repo, authToken) - log.SetOutput(os.Stdout) - log.SetFlags(log.LstdFlags | log.Lshortfile) + log.Printf("Listening controller on %s", controllerAddr) + log.Printf("Listening api on %s", apiAddr) - log.Printf("Listening on :8080\n") - err = s.ListenAndServe(":8080") - if err != nil { - panic(err) - return + errCh := make(chan error, 2) + + go func() { + if err := s.StartAPI(ctx, apiAddr); err != nil && !errors.Is(err, context.Canceled) { + errCh <- err + } + }() + + go func() { + if err := s.StartController(ctx, controllerAddr); err != nil && !errors.Is(err, context.Canceled) { + errCh <- err + } + }() + + select { + case <-ctx.Done(): + log.Printf("shutting down: %v", ctx.Err()) + case err := <-errCh: + log.Fatalf("server error: %v", err) } } + +func getenv(key, def string) string { + if val := os.Getenv(key); val != "" { + return val + } + return def +} + +func generateAuthToken() string { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + panic(err) + } + return base64.StdEncoding.EncodeToString(buf) +} diff --git a/server/server.go b/server/server.go index d36a8ac..4a46653 100644 --- a/server/server.go +++ b/server/server.go @@ -2,9 +2,13 @@ package server import ( "context" + "encoding/json" + "errors" "fmt" "log" "net" + "net/http" + "reflect" "sync" "time" @@ -19,106 +23,52 @@ import ( "google.golang.org/grpc/status" ) +const ( + defaultSubscriberResponseWait = 5 * time.Second +) + type Subscriber struct { - client chan *proto.Client - controller chan *proto.Controller - done chan struct{} - closeOnce sync.Once + node chan *proto.Node + events chan *proto.Events + done chan struct{} + closeOnce sync.Once + mu sync.Mutex } type Server struct { - Database *repository.Queries - Subscribers map[string]*Subscriber - mu *sync.RWMutex - authToken string - broadcastChan chan *proto.Controller - broadcastResultChan chan []SubscriberResult - notifyAllCancel context.CancelFunc + Database *repository.Queries + Subscribers map[string]*Subscriber + mu *sync.RWMutex + authToken string proto.UnimplementedEventServiceServer proto.UnimplementedSlugChangeServer proto.UnimplementedUserServiceServer - proto.UnimplementedUserSessionsServer } func New(database *repository.Queries, authToken string) *Server { - broadcastChan := make(chan *proto.Controller, 10) - broadcastResultChan := make(chan []SubscriberResult, 10) - - ctx, cancel := context.WithCancel(context.Background()) - - srv := &Server{ - Database: database, - Subscribers: make(map[string]*Subscriber), - mu: new(sync.RWMutex), - authToken: authToken, - broadcastChan: broadcastChan, - broadcastResultChan: broadcastResultChan, - notifyAllCancel: cancel, + return &Server{ + Database: database, + Subscribers: make(map[string]*Subscriber), + mu: new(sync.RWMutex), + authToken: authToken, } - - go srv.notifyAllSubscriber(ctx, broadcastChan, broadcastResultChan) - - return srv -} - -func (s *Server) GetSession(ctx context.Context, req *proto.GetSessionRequest) (*proto.GetSessionsResponse, error) { - if req == nil { - return nil, status.Error(codes.InvalidArgument, "request is nil") - } - - controllerReq := &proto.Controller{ - Type: proto.EventType_GET_SESSIONS, - Payload: &proto.Controller_GetSessionsEvent{ - GetSessionsEvent: &proto.GetSessionsEvent{ - Identity: req.GetIdentity(), - }, - }, - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case s.broadcastChan <- controllerReq: - } - - var results []SubscriberResult - select { - case <-ctx.Done(): - return nil, ctx.Err() - case results = <-s.broadcastResultChan: - } - responses := make([]*proto.Detail, 0, len(results)) - for _, result := range results { - for _, detail := range result.Response.Payload.(*proto.Client_GetSessionsEvent).GetSessionsEvent.Details { - responses = append(responses, detail) - } - responses = append(responses) - } - - if len(responses) == 0 { - return nil, status.Error(codes.NotFound, "no subscriber responded") - } - - return &proto.GetSessionsResponse{Details: responses}, nil } func (s *Server) Check(ctx context.Context, request *proto.CheckRequest) (*proto.CheckResponse, error) { - exist, err := s.Database.UserExistsByIdentifier(ctx, request.GetAuthToken()) - if err != nil { - return nil, err - } - - if exist { + user, err := s.Database.GetVerifiedEmailBySSHIdentifier(ctx, request.GetAuthToken()) + if err != nil || user == "" { return &proto.CheckResponse{ - Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED, - }, nil + Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED, + User: "", + }, err } return &proto.CheckResponse{ - Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED, + Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED, + User: user, }, nil } -func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error { +func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Node, proto.Events]) error { ctx := event.Context() recv, err := event.Recv() if err != nil { @@ -130,7 +80,7 @@ func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Co if recv.GetType() != proto.EventType_AUTHENTICATION { return status.Errorf(codes.InvalidArgument, "invalid event type: %s", recv.GetType()) } - payload, ok := recv.GetPayload().(*proto.Client_AuthEvent) + payload, ok := recv.GetPayload().(*proto.Node_AuthEvent) if !ok || payload == nil || payload.AuthEvent == nil { return status.Error(codes.InvalidArgument, "missing auth payload") } @@ -146,9 +96,9 @@ func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Co log.Printf("Client %s authenticated successfully", identity) requestChan := &Subscriber{ - client: make(chan *proto.Client, 10), - controller: make(chan *proto.Controller, 10), - done: make(chan struct{}), + node: make(chan *proto.Node, 10), + events: make(chan *proto.Events, 10), + done: make(chan struct{}), } if err = s.AddEventSubscriber(identity, requestChan); err != nil { @@ -162,14 +112,14 @@ func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Co return processEventStream(ctx, requestChan, event) } -func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error { +func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc.BidiStreamingServer[proto.Node, proto.Events]) error { for { select { case <-ctx.Done(): return ctx.Err() case <-requestChan.done: return nil - case request, ok := <-requestChan.controller: + case request, ok := <-requestChan.events: if !ok { return nil } @@ -179,7 +129,7 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc log.Printf("Received event request: %v", request) switch request.GetType() { case proto.EventType_SLUG_CHANGE: - payload, ok := request.GetPayload().(*proto.Controller_SlugEvent) + payload, ok := request.GetPayload().(*proto.Events_SlugEvent) if !ok || payload == nil || payload.SlugEvent == nil { log.Printf("invalid slug change payload") continue @@ -189,14 +139,14 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc if err := event.Send(request); err != nil { return err } - recv, err := event.Recv() + recv, err := recvClientWithTimeout(ctx, requestChan.done, event, defaultSubscriberResponseWait) if err != nil { return err } select { case <-ctx.Done(): return ctx.Err() - case requestChan.client <- recv: + case requestChan.node <- recv: } log.Printf("Received slug change event: %v", recv) case proto.EventType_GET_SESSIONS: @@ -204,14 +154,14 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc if err := event.Send(request); err != nil { return err } - recv, err := event.Recv() + recv, err := recvClientWithTimeout(ctx, requestChan.done, event, defaultSubscriberResponseWait) if err != nil { return err } select { case <-ctx.Done(): return ctx.Err() - case requestChan.client <- recv: + case requestChan.node <- recv: } log.Printf("Received SESSIONS event: %v", recv) default: @@ -221,6 +171,36 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc } } +func recvClientWithTimeout(ctx context.Context, done <-chan struct{}, event grpc.BidiStreamingServer[proto.Node, proto.Events], timeout time.Duration) (*proto.Node, error) { + respCh := make(chan *proto.Node, 1) + errCh := make(chan error, 1) + + go func() { + recv, err := event.Recv() + if err != nil { + errCh <- err + return + } + respCh <- recv + }() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-done: + return nil, status.Error(codes.Canceled, "subscriber removed") + case err := <-errCh: + return nil, err + case <-timer.C: + return nil, status.Error(codes.DeadlineExceeded, "client response timeout") + case recv := <-respCh: + return recv, nil + } +} + func (s *Server) AddEventSubscriber(identity string, req *Subscriber) error { if identity == "" || req == nil { return fmt.Errorf("invalid subscriber") @@ -274,9 +254,9 @@ func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlu return nil, err } - controllerMsg := &proto.Controller{ + controllerMsg := &proto.Events{ Type: proto.EventType_SLUG_CHANGE, - Payload: &proto.Controller_SlugEvent{ + Payload: &proto.Events_SlugEvent{ SlugEvent: &proto.SlugChangeEvent{ Old: request.Old, New: request.New, @@ -284,26 +264,14 @@ func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlu }, } - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-subscriber.done: - return nil, status.Error(codes.Canceled, "subscriber removed") - case subscriber.controller <- controllerMsg: - } - - var resp *proto.Client - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-subscriber.done: - return nil, status.Error(codes.Canceled, "subscriber removed") - case resp = <-subscriber.client: + resp, err := s.sendAndReceive(ctx, subscriber, controllerMsg, defaultSubscriberResponseWait) + if err != nil { + return nil, err } if resp == nil { return nil, status.Error(codes.FailedPrecondition, "empty response from client") } - response, ok := resp.Payload.(*proto.Client_SlugEventResponse) + response, ok := resp.Payload.(*proto.Node_SlugEventResponse) if !ok || response == nil || response.SlugEventResponse == nil { return nil, status.Error(codes.FailedPrecondition, "invalid slug response payload") } @@ -312,11 +280,11 @@ func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlu type SubscriberResult struct { Identity string - Response *proto.Client + Response *proto.Node Err error } -func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto.Controller, resultChan chan<- []SubscriberResult) { +func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto.Events, resultChan chan<- []SubscriberResult) { for { select { case <-ctx.Done(): @@ -343,7 +311,7 @@ func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto case <-sub.done: results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")}) continue - case sub.controller <- controllerReq: + case sub.events <- controllerReq: default: results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Unavailable, "controller channel blocked")}) continue @@ -353,7 +321,7 @@ func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto return case <-sub.done: results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")}) - case resp, ok := <-sub.client: + case resp, ok := <-sub.node: if !ok { results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.FailedPrecondition, "client channel closed")}) continue @@ -370,7 +338,104 @@ func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto } } -func (s *Server) ListenAndServe(Addr string) error { +func (s *Server) StartAPI(ctx context.Context, Addr string) error { + handler := http.NewServeMux() + httpServer := http.Server{ + Addr: Addr, + Handler: handler, + ReadTimeout: 15 * time.Second, + WriteTimeout: 15 * time.Second, + IdleTimeout: 60 * time.Second, + } + + handler.HandleFunc("/api/sessions", func(writer http.ResponseWriter, request *http.Request) { + identity := request.URL.Query().Get("identity") + results := s.broadcastAndCollect(request.Context(), func(ctx context.Context, subscriber *Subscriber) (interface{}, bool) { + receive, err := s.sendAndReceive(ctx, subscriber, &proto.Events{ + Type: proto.EventType_GET_SESSIONS, + Payload: &proto.Events_GetSessionsEvent{ + GetSessionsEvent: &proto.GetSessionsEvent{ + Identity: identity, + }, + }, + }, defaultSubscriberResponseWait) + if err != nil { + log.Printf("get sessions request failed: %v", err) + return nil, false + } + if receive == nil { + log.Printf("get sessions request returned nil response") + return nil, false + } + payload, ok := receive.Payload.(*proto.Node_GetSessionsEvent) + if !ok || payload == nil || payload.GetSessionsEvent == nil { + log.Printf("unexpected get sessions payload type: %T", receive.Payload) + return nil, false + } + return payload.GetSessionsEvent.Details, true + }) + + flatten := flattenInterfaces(results) + + writer.Header().Set("Content-Type", "application/json") + + if len(flatten) == 0 { + _, err := writer.Write([]byte("[]")) + if err != nil { + return + } + return + } + + marshal, err := json.Marshal(flatten) + if err != nil { + http.Error(writer, "failed to marshal sessions", http.StatusInternalServerError) + return + } + _, err = writer.Write(marshal) + if err != nil { + return + } + }) + + errCh := make(chan error, 1) + go func() { + errCh <- httpServer.ListenAndServe() + }() + + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _ = httpServer.Shutdown(shutdownCtx) + return ctx.Err() + case err := <-errCh: + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err + } +} + +func flattenInterfaces(values []interface{}) []interface{} { + var flat []interface{} + for _, v := range values { + if v == nil { + continue + } + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Slice { + for i := 0; i < rv.Len(); i++ { + flat = append(flat, rv.Index(i).Interface()) + } + continue + } + flat = append(flat, v) + } + return flat +} + +func (s *Server) StartController(ctx context.Context, Addr string) error { listener, err := net.Listen("tcp", Addr) if err != nil { return err @@ -397,11 +462,97 @@ func (s *Server) ListenAndServe(Addr string) error { proto.RegisterSlugChangeServer(grpcServer, s) proto.RegisterEventServiceServer(grpcServer, s) proto.RegisterUserServiceServer(grpcServer, s) - proto.RegisterUserSessionsServer(grpcServer, s) healthServer := health.NewServer() grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING) - return grpcServer.Serve(listener) + serveErr := make(chan error, 1) + go func() { + serveErr <- grpcServer.Serve(listener) + }() + + select { + case <-ctx.Done(): + grpcServer.GracefulStop() + err := <-serveErr + if err != nil { + return err + } + return ctx.Err() + case err := <-serveErr: + return err + } +} + +func (s *Server) sendAndReceive(ctx context.Context, sub *Subscriber, req *proto.Events, wait time.Duration) (*proto.Node, error) { + if sub == nil || req == nil { + return nil, status.Error(codes.InvalidArgument, "missing subscriber or request") + } + + sub.mu.Lock() + defer sub.mu.Unlock() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-sub.done: + return nil, status.Error(codes.Canceled, "subscriber removed") + case sub.events <- req: + } + + timer := time.NewTimer(wait) + defer timer.Stop() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-sub.done: + return nil, status.Error(codes.Canceled, "subscriber removed") + case <-timer.C: + return nil, status.Error(codes.DeadlineExceeded, "subscriber response timeout") + case resp, ok := <-sub.node: + if !ok { + return nil, status.Error(codes.FailedPrecondition, "client channel closed") + } + return resp, nil + } +} + +func (s *Server) snapshotSubscribers() []*Subscriber { + s.mu.RLock() + defer s.mu.RUnlock() + subs := make([]*Subscriber, 0, len(s.Subscribers)) + for _, sub := range s.Subscribers { + subs = append(subs, sub) + } + return subs +} + +func (s *Server) broadcastAndCollect(ctx context.Context, worker func(context.Context, *Subscriber) (interface{}, bool)) []interface{} { + subs := s.snapshotSubscribers() + if len(subs) == 0 { + return nil + } + + resultsCh := make(chan interface{}, len(subs)) + var wg sync.WaitGroup + for _, sub := range subs { + wg.Add(1) + go func(subscriber *Subscriber) { + defer wg.Done() + if val, ok := worker(ctx, subscriber); ok { + resultsCh <- val + } + }(sub) + } + + wg.Wait() + close(resultsCh) + + var results []interface{} + for val := range resultsCh { + results = append(results, val) + } + return results }