feat: improve auth

This commit is contained in:
2026-01-03 20:03:54 +07:00
parent c9aa7261e6
commit 69e250b439
2 changed files with 328 additions and 126 deletions

71
main.go
View File

@@ -2,8 +2,13 @@ package main
import ( import (
"context" "context"
"crypto/rand"
"encoding/base64"
"errors"
"log" "log"
"os" "os"
"os/signal"
"syscall"
"git.fossy.my.id/bagas/tunnel-please-controller/db/sqlc/repository" "git.fossy.my.id/bagas/tunnel-please-controller/db/sqlc/repository"
"git.fossy.my.id/bagas/tunnel-please-controller/server" "git.fossy.my.id/bagas/tunnel-please-controller/server"
@@ -18,9 +23,26 @@ func main() {
} }
} }
ctx := context.Background() ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
connect, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL")) log.SetOutput(os.Stdout)
log.SetFlags(log.LstdFlags | log.Lshortfile)
dbURL := os.Getenv("DATABASE_URL")
if dbURL == "" {
log.Fatal("DATABASE_URL is required")
}
controllerAddr := getenv("CONTROLLER_ADDR", ":8080")
apiAddr := getenv("API_ADDR", ":8081")
authToken := getenv("AUTH_TOKEN", "")
if authToken == "" {
authToken = generateAuthToken()
log.Printf("No AUTH_TOKEN provided. Generated token: %s", authToken)
}
connect, err := pgx.Connect(ctx, dbURL)
if err != nil { if err != nil {
panic(err) panic(err)
return return
@@ -33,15 +55,44 @@ func main() {
}(connect, ctx) }(connect, ctx)
repo := repository.New(connect) repo := repository.New(connect)
s := server.New(repo, "test_auth_key") s := server.New(repo, authToken)
log.SetOutput(os.Stdout) log.Printf("Listening controller on %s", controllerAddr)
log.SetFlags(log.LstdFlags | log.Lshortfile) log.Printf("Listening api on %s", apiAddr)
log.Printf("Listening on :8080\n") errCh := make(chan error, 2)
err = s.ListenAndServe(":8080")
if err != nil { go func() {
panic(err) if err := s.StartAPI(ctx, apiAddr); err != nil && !errors.Is(err, context.Canceled) {
return errCh <- err
}
}()
go func() {
if err := s.StartController(ctx, controllerAddr); err != nil && !errors.Is(err, context.Canceled) {
errCh <- err
}
}()
select {
case <-ctx.Done():
log.Printf("shutting down: %v", ctx.Err())
case err := <-errCh:
log.Fatalf("server error: %v", err)
} }
} }
func getenv(key, def string) string {
if val := os.Getenv(key); val != "" {
return val
}
return def
}
func generateAuthToken() string {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
panic(err)
}
return base64.StdEncoding.EncodeToString(buf)
}

View File

@@ -2,9 +2,13 @@ package server
import ( import (
"context" "context"
"encoding/json"
"errors"
"fmt" "fmt"
"log" "log"
"net" "net"
"net/http"
"reflect"
"sync" "sync"
"time" "time"
@@ -19,106 +23,52 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
const (
defaultSubscriberResponseWait = 5 * time.Second
)
type Subscriber struct { type Subscriber struct {
client chan *proto.Client node chan *proto.Node
controller chan *proto.Controller events chan *proto.Events
done chan struct{} done chan struct{}
closeOnce sync.Once closeOnce sync.Once
mu sync.Mutex
} }
type Server struct { type Server struct {
Database *repository.Queries Database *repository.Queries
Subscribers map[string]*Subscriber Subscribers map[string]*Subscriber
mu *sync.RWMutex mu *sync.RWMutex
authToken string authToken string
broadcastChan chan *proto.Controller
broadcastResultChan chan []SubscriberResult
notifyAllCancel context.CancelFunc
proto.UnimplementedEventServiceServer proto.UnimplementedEventServiceServer
proto.UnimplementedSlugChangeServer proto.UnimplementedSlugChangeServer
proto.UnimplementedUserServiceServer proto.UnimplementedUserServiceServer
proto.UnimplementedUserSessionsServer
} }
func New(database *repository.Queries, authToken string) *Server { func New(database *repository.Queries, authToken string) *Server {
broadcastChan := make(chan *proto.Controller, 10) return &Server{
broadcastResultChan := make(chan []SubscriberResult, 10)
ctx, cancel := context.WithCancel(context.Background())
srv := &Server{
Database: database, Database: database,
Subscribers: make(map[string]*Subscriber), Subscribers: make(map[string]*Subscriber),
mu: new(sync.RWMutex), mu: new(sync.RWMutex),
authToken: authToken, authToken: authToken,
broadcastChan: broadcastChan,
broadcastResultChan: broadcastResultChan,
notifyAllCancel: cancel,
} }
go srv.notifyAllSubscriber(ctx, broadcastChan, broadcastResultChan)
return srv
}
func (s *Server) GetSession(ctx context.Context, req *proto.GetSessionRequest) (*proto.GetSessionsResponse, error) {
if req == nil {
return nil, status.Error(codes.InvalidArgument, "request is nil")
}
controllerReq := &proto.Controller{
Type: proto.EventType_GET_SESSIONS,
Payload: &proto.Controller_GetSessionsEvent{
GetSessionsEvent: &proto.GetSessionsEvent{
Identity: req.GetIdentity(),
},
},
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case s.broadcastChan <- controllerReq:
}
var results []SubscriberResult
select {
case <-ctx.Done():
return nil, ctx.Err()
case results = <-s.broadcastResultChan:
}
responses := make([]*proto.Detail, 0, len(results))
for _, result := range results {
for _, detail := range result.Response.Payload.(*proto.Client_GetSessionsEvent).GetSessionsEvent.Details {
responses = append(responses, detail)
}
responses = append(responses)
}
if len(responses) == 0 {
return nil, status.Error(codes.NotFound, "no subscriber responded")
}
return &proto.GetSessionsResponse{Details: responses}, nil
} }
func (s *Server) Check(ctx context.Context, request *proto.CheckRequest) (*proto.CheckResponse, error) { func (s *Server) Check(ctx context.Context, request *proto.CheckRequest) (*proto.CheckResponse, error) {
exist, err := s.Database.UserExistsByIdentifier(ctx, request.GetAuthToken()) user, err := s.Database.GetVerifiedEmailBySSHIdentifier(ctx, request.GetAuthToken())
if err != nil { if err != nil || user == "" {
return nil, err
}
if exist {
return &proto.CheckResponse{
Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED,
}, nil
}
return &proto.CheckResponse{ return &proto.CheckResponse{
Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED, Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED,
User: "",
}, err
}
return &proto.CheckResponse{
Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED,
User: user,
}, nil }, nil
} }
func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error { func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Node, proto.Events]) error {
ctx := event.Context() ctx := event.Context()
recv, err := event.Recv() recv, err := event.Recv()
if err != nil { if err != nil {
@@ -130,7 +80,7 @@ func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Co
if recv.GetType() != proto.EventType_AUTHENTICATION { if recv.GetType() != proto.EventType_AUTHENTICATION {
return status.Errorf(codes.InvalidArgument, "invalid event type: %s", recv.GetType()) return status.Errorf(codes.InvalidArgument, "invalid event type: %s", recv.GetType())
} }
payload, ok := recv.GetPayload().(*proto.Client_AuthEvent) payload, ok := recv.GetPayload().(*proto.Node_AuthEvent)
if !ok || payload == nil || payload.AuthEvent == nil { if !ok || payload == nil || payload.AuthEvent == nil {
return status.Error(codes.InvalidArgument, "missing auth payload") return status.Error(codes.InvalidArgument, "missing auth payload")
} }
@@ -146,8 +96,8 @@ func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Co
log.Printf("Client %s authenticated successfully", identity) log.Printf("Client %s authenticated successfully", identity)
requestChan := &Subscriber{ requestChan := &Subscriber{
client: make(chan *proto.Client, 10), node: make(chan *proto.Node, 10),
controller: make(chan *proto.Controller, 10), events: make(chan *proto.Events, 10),
done: make(chan struct{}), done: make(chan struct{}),
} }
@@ -162,14 +112,14 @@ func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Co
return processEventStream(ctx, requestChan, event) return processEventStream(ctx, requestChan, event)
} }
func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error { func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc.BidiStreamingServer[proto.Node, proto.Events]) error {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case <-requestChan.done: case <-requestChan.done:
return nil return nil
case request, ok := <-requestChan.controller: case request, ok := <-requestChan.events:
if !ok { if !ok {
return nil return nil
} }
@@ -179,7 +129,7 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc
log.Printf("Received event request: %v", request) log.Printf("Received event request: %v", request)
switch request.GetType() { switch request.GetType() {
case proto.EventType_SLUG_CHANGE: case proto.EventType_SLUG_CHANGE:
payload, ok := request.GetPayload().(*proto.Controller_SlugEvent) payload, ok := request.GetPayload().(*proto.Events_SlugEvent)
if !ok || payload == nil || payload.SlugEvent == nil { if !ok || payload == nil || payload.SlugEvent == nil {
log.Printf("invalid slug change payload") log.Printf("invalid slug change payload")
continue continue
@@ -189,14 +139,14 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc
if err := event.Send(request); err != nil { if err := event.Send(request); err != nil {
return err return err
} }
recv, err := event.Recv() recv, err := recvClientWithTimeout(ctx, requestChan.done, event, defaultSubscriberResponseWait)
if err != nil { if err != nil {
return err return err
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case requestChan.client <- recv: case requestChan.node <- recv:
} }
log.Printf("Received slug change event: %v", recv) log.Printf("Received slug change event: %v", recv)
case proto.EventType_GET_SESSIONS: case proto.EventType_GET_SESSIONS:
@@ -204,14 +154,14 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc
if err := event.Send(request); err != nil { if err := event.Send(request); err != nil {
return err return err
} }
recv, err := event.Recv() recv, err := recvClientWithTimeout(ctx, requestChan.done, event, defaultSubscriberResponseWait)
if err != nil { if err != nil {
return err return err
} }
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
case requestChan.client <- recv: case requestChan.node <- recv:
} }
log.Printf("Received SESSIONS event: %v", recv) log.Printf("Received SESSIONS event: %v", recv)
default: default:
@@ -221,6 +171,36 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc
} }
} }
func recvClientWithTimeout(ctx context.Context, done <-chan struct{}, event grpc.BidiStreamingServer[proto.Node, proto.Events], timeout time.Duration) (*proto.Node, error) {
respCh := make(chan *proto.Node, 1)
errCh := make(chan error, 1)
go func() {
recv, err := event.Recv()
if err != nil {
errCh <- err
return
}
respCh <- recv
}()
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-done:
return nil, status.Error(codes.Canceled, "subscriber removed")
case err := <-errCh:
return nil, err
case <-timer.C:
return nil, status.Error(codes.DeadlineExceeded, "client response timeout")
case recv := <-respCh:
return recv, nil
}
}
func (s *Server) AddEventSubscriber(identity string, req *Subscriber) error { func (s *Server) AddEventSubscriber(identity string, req *Subscriber) error {
if identity == "" || req == nil { if identity == "" || req == nil {
return fmt.Errorf("invalid subscriber") return fmt.Errorf("invalid subscriber")
@@ -274,9 +254,9 @@ func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlu
return nil, err return nil, err
} }
controllerMsg := &proto.Controller{ controllerMsg := &proto.Events{
Type: proto.EventType_SLUG_CHANGE, Type: proto.EventType_SLUG_CHANGE,
Payload: &proto.Controller_SlugEvent{ Payload: &proto.Events_SlugEvent{
SlugEvent: &proto.SlugChangeEvent{ SlugEvent: &proto.SlugChangeEvent{
Old: request.Old, Old: request.Old,
New: request.New, New: request.New,
@@ -284,26 +264,14 @@ func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlu
}, },
} }
select { resp, err := s.sendAndReceive(ctx, subscriber, controllerMsg, defaultSubscriberResponseWait)
case <-ctx.Done(): if err != nil {
return nil, ctx.Err() return nil, err
case <-subscriber.done:
return nil, status.Error(codes.Canceled, "subscriber removed")
case subscriber.controller <- controllerMsg:
}
var resp *proto.Client
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-subscriber.done:
return nil, status.Error(codes.Canceled, "subscriber removed")
case resp = <-subscriber.client:
} }
if resp == nil { if resp == nil {
return nil, status.Error(codes.FailedPrecondition, "empty response from client") return nil, status.Error(codes.FailedPrecondition, "empty response from client")
} }
response, ok := resp.Payload.(*proto.Client_SlugEventResponse) response, ok := resp.Payload.(*proto.Node_SlugEventResponse)
if !ok || response == nil || response.SlugEventResponse == nil { if !ok || response == nil || response.SlugEventResponse == nil {
return nil, status.Error(codes.FailedPrecondition, "invalid slug response payload") return nil, status.Error(codes.FailedPrecondition, "invalid slug response payload")
} }
@@ -312,11 +280,11 @@ func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlu
type SubscriberResult struct { type SubscriberResult struct {
Identity string Identity string
Response *proto.Client Response *proto.Node
Err error Err error
} }
func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto.Controller, resultChan chan<- []SubscriberResult) { func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto.Events, resultChan chan<- []SubscriberResult) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -343,7 +311,7 @@ func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto
case <-sub.done: case <-sub.done:
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")}) results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")})
continue continue
case sub.controller <- controllerReq: case sub.events <- controllerReq:
default: default:
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Unavailable, "controller channel blocked")}) results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Unavailable, "controller channel blocked")})
continue continue
@@ -353,7 +321,7 @@ func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto
return return
case <-sub.done: case <-sub.done:
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")}) results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")})
case resp, ok := <-sub.client: case resp, ok := <-sub.node:
if !ok { if !ok {
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.FailedPrecondition, "client channel closed")}) results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.FailedPrecondition, "client channel closed")})
continue continue
@@ -370,7 +338,104 @@ func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto
} }
} }
func (s *Server) ListenAndServe(Addr string) error { func (s *Server) StartAPI(ctx context.Context, Addr string) error {
handler := http.NewServeMux()
httpServer := http.Server{
Addr: Addr,
Handler: handler,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 60 * time.Second,
}
handler.HandleFunc("/api/sessions", func(writer http.ResponseWriter, request *http.Request) {
identity := request.URL.Query().Get("identity")
results := s.broadcastAndCollect(request.Context(), func(ctx context.Context, subscriber *Subscriber) (interface{}, bool) {
receive, err := s.sendAndReceive(ctx, subscriber, &proto.Events{
Type: proto.EventType_GET_SESSIONS,
Payload: &proto.Events_GetSessionsEvent{
GetSessionsEvent: &proto.GetSessionsEvent{
Identity: identity,
},
},
}, defaultSubscriberResponseWait)
if err != nil {
log.Printf("get sessions request failed: %v", err)
return nil, false
}
if receive == nil {
log.Printf("get sessions request returned nil response")
return nil, false
}
payload, ok := receive.Payload.(*proto.Node_GetSessionsEvent)
if !ok || payload == nil || payload.GetSessionsEvent == nil {
log.Printf("unexpected get sessions payload type: %T", receive.Payload)
return nil, false
}
return payload.GetSessionsEvent.Details, true
})
flatten := flattenInterfaces(results)
writer.Header().Set("Content-Type", "application/json")
if len(flatten) == 0 {
_, err := writer.Write([]byte("[]"))
if err != nil {
return
}
return
}
marshal, err := json.Marshal(flatten)
if err != nil {
http.Error(writer, "failed to marshal sessions", http.StatusInternalServerError)
return
}
_, err = writer.Write(marshal)
if err != nil {
return
}
})
errCh := make(chan error, 1)
go func() {
errCh <- httpServer.ListenAndServe()
}()
select {
case <-ctx.Done():
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
_ = httpServer.Shutdown(shutdownCtx)
return ctx.Err()
case err := <-errCh:
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
}
}
func flattenInterfaces(values []interface{}) []interface{} {
var flat []interface{}
for _, v := range values {
if v == nil {
continue
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Slice {
for i := 0; i < rv.Len(); i++ {
flat = append(flat, rv.Index(i).Interface())
}
continue
}
flat = append(flat, v)
}
return flat
}
func (s *Server) StartController(ctx context.Context, Addr string) error {
listener, err := net.Listen("tcp", Addr) listener, err := net.Listen("tcp", Addr)
if err != nil { if err != nil {
return err return err
@@ -397,11 +462,97 @@ func (s *Server) ListenAndServe(Addr string) error {
proto.RegisterSlugChangeServer(grpcServer, s) proto.RegisterSlugChangeServer(grpcServer, s)
proto.RegisterEventServiceServer(grpcServer, s) proto.RegisterEventServiceServer(grpcServer, s)
proto.RegisterUserServiceServer(grpcServer, s) proto.RegisterUserServiceServer(grpcServer, s)
proto.RegisterUserSessionsServer(grpcServer, s)
healthServer := health.NewServer() healthServer := health.NewServer()
grpc_health_v1.RegisterHealthServer(grpcServer, healthServer) grpc_health_v1.RegisterHealthServer(grpcServer, healthServer)
healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING) healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING)
return grpcServer.Serve(listener) serveErr := make(chan error, 1)
go func() {
serveErr <- grpcServer.Serve(listener)
}()
select {
case <-ctx.Done():
grpcServer.GracefulStop()
err := <-serveErr
if err != nil {
return err
}
return ctx.Err()
case err := <-serveErr:
return err
}
}
func (s *Server) sendAndReceive(ctx context.Context, sub *Subscriber, req *proto.Events, wait time.Duration) (*proto.Node, error) {
if sub == nil || req == nil {
return nil, status.Error(codes.InvalidArgument, "missing subscriber or request")
}
sub.mu.Lock()
defer sub.mu.Unlock()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-sub.done:
return nil, status.Error(codes.Canceled, "subscriber removed")
case sub.events <- req:
}
timer := time.NewTimer(wait)
defer timer.Stop()
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-sub.done:
return nil, status.Error(codes.Canceled, "subscriber removed")
case <-timer.C:
return nil, status.Error(codes.DeadlineExceeded, "subscriber response timeout")
case resp, ok := <-sub.node:
if !ok {
return nil, status.Error(codes.FailedPrecondition, "client channel closed")
}
return resp, nil
}
}
func (s *Server) snapshotSubscribers() []*Subscriber {
s.mu.RLock()
defer s.mu.RUnlock()
subs := make([]*Subscriber, 0, len(s.Subscribers))
for _, sub := range s.Subscribers {
subs = append(subs, sub)
}
return subs
}
func (s *Server) broadcastAndCollect(ctx context.Context, worker func(context.Context, *Subscriber) (interface{}, bool)) []interface{} {
subs := s.snapshotSubscribers()
if len(subs) == 0 {
return nil
}
resultsCh := make(chan interface{}, len(subs))
var wg sync.WaitGroup
for _, sub := range subs {
wg.Add(1)
go func(subscriber *Subscriber) {
defer wg.Done()
if val, ok := worker(ctx, subscriber); ok {
resultsCh <- val
}
}(sub)
}
wg.Wait()
close(resultsCh)
var results []interface{}
for val := range resultsCh {
results = append(results, val)
}
return results
} }