feat: improve auth
This commit is contained in:
69
main.go
69
main.go
@@ -2,8 +2,13 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"git.fossy.my.id/bagas/tunnel-please-controller/db/sqlc/repository"
|
"git.fossy.my.id/bagas/tunnel-please-controller/db/sqlc/repository"
|
||||||
"git.fossy.my.id/bagas/tunnel-please-controller/server"
|
"git.fossy.my.id/bagas/tunnel-please-controller/server"
|
||||||
@@ -18,9 +23,26 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
connect, err := pgx.Connect(ctx, os.Getenv("DATABASE_URL"))
|
log.SetOutput(os.Stdout)
|
||||||
|
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||||
|
|
||||||
|
dbURL := os.Getenv("DATABASE_URL")
|
||||||
|
if dbURL == "" {
|
||||||
|
log.Fatal("DATABASE_URL is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
controllerAddr := getenv("CONTROLLER_ADDR", ":8080")
|
||||||
|
apiAddr := getenv("API_ADDR", ":8081")
|
||||||
|
authToken := getenv("AUTH_TOKEN", "")
|
||||||
|
if authToken == "" {
|
||||||
|
authToken = generateAuthToken()
|
||||||
|
log.Printf("No AUTH_TOKEN provided. Generated token: %s", authToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
connect, err := pgx.Connect(ctx, dbURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
return
|
return
|
||||||
@@ -33,15 +55,44 @@ func main() {
|
|||||||
}(connect, ctx)
|
}(connect, ctx)
|
||||||
|
|
||||||
repo := repository.New(connect)
|
repo := repository.New(connect)
|
||||||
s := server.New(repo, "test_auth_key")
|
s := server.New(repo, authToken)
|
||||||
|
|
||||||
log.SetOutput(os.Stdout)
|
log.Printf("Listening controller on %s", controllerAddr)
|
||||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
log.Printf("Listening api on %s", apiAddr)
|
||||||
|
|
||||||
log.Printf("Listening on :8080\n")
|
errCh := make(chan error, 2)
|
||||||
err = s.ListenAndServe(":8080")
|
|
||||||
if err != nil {
|
go func() {
|
||||||
|
if err := s.StartAPI(ctx, apiAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := s.StartController(ctx, controllerAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Printf("shutting down: %v", ctx.Err())
|
||||||
|
case err := <-errCh:
|
||||||
|
log.Fatalf("server error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getenv(key, def string) string {
|
||||||
|
if val := os.Getenv(key); val != "" {
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateAuthToken() string {
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(buf)
|
||||||
}
|
}
|
||||||
|
|||||||
365
server/server.go
365
server/server.go
@@ -2,9 +2,13 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,106 +23,52 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultSubscriberResponseWait = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
type Subscriber struct {
|
type Subscriber struct {
|
||||||
client chan *proto.Client
|
node chan *proto.Node
|
||||||
controller chan *proto.Controller
|
events chan *proto.Events
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
type Server struct {
|
type Server struct {
|
||||||
Database *repository.Queries
|
Database *repository.Queries
|
||||||
Subscribers map[string]*Subscriber
|
Subscribers map[string]*Subscriber
|
||||||
mu *sync.RWMutex
|
mu *sync.RWMutex
|
||||||
authToken string
|
authToken string
|
||||||
broadcastChan chan *proto.Controller
|
|
||||||
broadcastResultChan chan []SubscriberResult
|
|
||||||
notifyAllCancel context.CancelFunc
|
|
||||||
proto.UnimplementedEventServiceServer
|
proto.UnimplementedEventServiceServer
|
||||||
proto.UnimplementedSlugChangeServer
|
proto.UnimplementedSlugChangeServer
|
||||||
proto.UnimplementedUserServiceServer
|
proto.UnimplementedUserServiceServer
|
||||||
proto.UnimplementedUserSessionsServer
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(database *repository.Queries, authToken string) *Server {
|
func New(database *repository.Queries, authToken string) *Server {
|
||||||
broadcastChan := make(chan *proto.Controller, 10)
|
return &Server{
|
||||||
broadcastResultChan := make(chan []SubscriberResult, 10)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
|
|
||||||
srv := &Server{
|
|
||||||
Database: database,
|
Database: database,
|
||||||
Subscribers: make(map[string]*Subscriber),
|
Subscribers: make(map[string]*Subscriber),
|
||||||
mu: new(sync.RWMutex),
|
mu: new(sync.RWMutex),
|
||||||
authToken: authToken,
|
authToken: authToken,
|
||||||
broadcastChan: broadcastChan,
|
|
||||||
broadcastResultChan: broadcastResultChan,
|
|
||||||
notifyAllCancel: cancel,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go srv.notifyAllSubscriber(ctx, broadcastChan, broadcastResultChan)
|
|
||||||
|
|
||||||
return srv
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) GetSession(ctx context.Context, req *proto.GetSessionRequest) (*proto.GetSessionsResponse, error) {
|
|
||||||
if req == nil {
|
|
||||||
return nil, status.Error(codes.InvalidArgument, "request is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
controllerReq := &proto.Controller{
|
|
||||||
Type: proto.EventType_GET_SESSIONS,
|
|
||||||
Payload: &proto.Controller_GetSessionsEvent{
|
|
||||||
GetSessionsEvent: &proto.GetSessionsEvent{
|
|
||||||
Identity: req.GetIdentity(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case s.broadcastChan <- controllerReq:
|
|
||||||
}
|
|
||||||
|
|
||||||
var results []SubscriberResult
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case results = <-s.broadcastResultChan:
|
|
||||||
}
|
|
||||||
responses := make([]*proto.Detail, 0, len(results))
|
|
||||||
for _, result := range results {
|
|
||||||
for _, detail := range result.Response.Payload.(*proto.Client_GetSessionsEvent).GetSessionsEvent.Details {
|
|
||||||
responses = append(responses, detail)
|
|
||||||
}
|
|
||||||
responses = append(responses)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(responses) == 0 {
|
|
||||||
return nil, status.Error(codes.NotFound, "no subscriber responded")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &proto.GetSessionsResponse{Details: responses}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Check(ctx context.Context, request *proto.CheckRequest) (*proto.CheckResponse, error) {
|
func (s *Server) Check(ctx context.Context, request *proto.CheckRequest) (*proto.CheckResponse, error) {
|
||||||
exist, err := s.Database.UserExistsByIdentifier(ctx, request.GetAuthToken())
|
user, err := s.Database.GetVerifiedEmailBySSHIdentifier(ctx, request.GetAuthToken())
|
||||||
if err != nil {
|
if err != nil || user == "" {
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if exist {
|
|
||||||
return &proto.CheckResponse{
|
|
||||||
Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &proto.CheckResponse{
|
return &proto.CheckResponse{
|
||||||
Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED,
|
Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED,
|
||||||
|
User: "",
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.CheckResponse{
|
||||||
|
Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED,
|
||||||
|
User: user,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error {
|
func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Node, proto.Events]) error {
|
||||||
ctx := event.Context()
|
ctx := event.Context()
|
||||||
recv, err := event.Recv()
|
recv, err := event.Recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -130,7 +80,7 @@ func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Co
|
|||||||
if recv.GetType() != proto.EventType_AUTHENTICATION {
|
if recv.GetType() != proto.EventType_AUTHENTICATION {
|
||||||
return status.Errorf(codes.InvalidArgument, "invalid event type: %s", recv.GetType())
|
return status.Errorf(codes.InvalidArgument, "invalid event type: %s", recv.GetType())
|
||||||
}
|
}
|
||||||
payload, ok := recv.GetPayload().(*proto.Client_AuthEvent)
|
payload, ok := recv.GetPayload().(*proto.Node_AuthEvent)
|
||||||
if !ok || payload == nil || payload.AuthEvent == nil {
|
if !ok || payload == nil || payload.AuthEvent == nil {
|
||||||
return status.Error(codes.InvalidArgument, "missing auth payload")
|
return status.Error(codes.InvalidArgument, "missing auth payload")
|
||||||
}
|
}
|
||||||
@@ -146,8 +96,8 @@ func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Co
|
|||||||
log.Printf("Client %s authenticated successfully", identity)
|
log.Printf("Client %s authenticated successfully", identity)
|
||||||
|
|
||||||
requestChan := &Subscriber{
|
requestChan := &Subscriber{
|
||||||
client: make(chan *proto.Client, 10),
|
node: make(chan *proto.Node, 10),
|
||||||
controller: make(chan *proto.Controller, 10),
|
events: make(chan *proto.Events, 10),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,14 +112,14 @@ func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Co
|
|||||||
return processEventStream(ctx, requestChan, event)
|
return processEventStream(ctx, requestChan, event)
|
||||||
}
|
}
|
||||||
|
|
||||||
func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error {
|
func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc.BidiStreamingServer[proto.Node, proto.Events]) error {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case <-requestChan.done:
|
case <-requestChan.done:
|
||||||
return nil
|
return nil
|
||||||
case request, ok := <-requestChan.controller:
|
case request, ok := <-requestChan.events:
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -179,7 +129,7 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc
|
|||||||
log.Printf("Received event request: %v", request)
|
log.Printf("Received event request: %v", request)
|
||||||
switch request.GetType() {
|
switch request.GetType() {
|
||||||
case proto.EventType_SLUG_CHANGE:
|
case proto.EventType_SLUG_CHANGE:
|
||||||
payload, ok := request.GetPayload().(*proto.Controller_SlugEvent)
|
payload, ok := request.GetPayload().(*proto.Events_SlugEvent)
|
||||||
if !ok || payload == nil || payload.SlugEvent == nil {
|
if !ok || payload == nil || payload.SlugEvent == nil {
|
||||||
log.Printf("invalid slug change payload")
|
log.Printf("invalid slug change payload")
|
||||||
continue
|
continue
|
||||||
@@ -189,14 +139,14 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc
|
|||||||
if err := event.Send(request); err != nil {
|
if err := event.Send(request); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
recv, err := event.Recv()
|
recv, err := recvClientWithTimeout(ctx, requestChan.done, event, defaultSubscriberResponseWait)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case requestChan.client <- recv:
|
case requestChan.node <- recv:
|
||||||
}
|
}
|
||||||
log.Printf("Received slug change event: %v", recv)
|
log.Printf("Received slug change event: %v", recv)
|
||||||
case proto.EventType_GET_SESSIONS:
|
case proto.EventType_GET_SESSIONS:
|
||||||
@@ -204,14 +154,14 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc
|
|||||||
if err := event.Send(request); err != nil {
|
if err := event.Send(request); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
recv, err := event.Recv()
|
recv, err := recvClientWithTimeout(ctx, requestChan.done, event, defaultSubscriberResponseWait)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
case requestChan.client <- recv:
|
case requestChan.node <- recv:
|
||||||
}
|
}
|
||||||
log.Printf("Received SESSIONS event: %v", recv)
|
log.Printf("Received SESSIONS event: %v", recv)
|
||||||
default:
|
default:
|
||||||
@@ -221,6 +171,36 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func recvClientWithTimeout(ctx context.Context, done <-chan struct{}, event grpc.BidiStreamingServer[proto.Node, proto.Events], timeout time.Duration) (*proto.Node, error) {
|
||||||
|
respCh := make(chan *proto.Node, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
recv, err := event.Recv()
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
respCh <- recv
|
||||||
|
}()
|
||||||
|
|
||||||
|
timer := time.NewTimer(timeout)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-done:
|
||||||
|
return nil, status.Error(codes.Canceled, "subscriber removed")
|
||||||
|
case err := <-errCh:
|
||||||
|
return nil, err
|
||||||
|
case <-timer.C:
|
||||||
|
return nil, status.Error(codes.DeadlineExceeded, "client response timeout")
|
||||||
|
case recv := <-respCh:
|
||||||
|
return recv, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) AddEventSubscriber(identity string, req *Subscriber) error {
|
func (s *Server) AddEventSubscriber(identity string, req *Subscriber) error {
|
||||||
if identity == "" || req == nil {
|
if identity == "" || req == nil {
|
||||||
return fmt.Errorf("invalid subscriber")
|
return fmt.Errorf("invalid subscriber")
|
||||||
@@ -274,9 +254,9 @@ func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlu
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
controllerMsg := &proto.Controller{
|
controllerMsg := &proto.Events{
|
||||||
Type: proto.EventType_SLUG_CHANGE,
|
Type: proto.EventType_SLUG_CHANGE,
|
||||||
Payload: &proto.Controller_SlugEvent{
|
Payload: &proto.Events_SlugEvent{
|
||||||
SlugEvent: &proto.SlugChangeEvent{
|
SlugEvent: &proto.SlugChangeEvent{
|
||||||
Old: request.Old,
|
Old: request.Old,
|
||||||
New: request.New,
|
New: request.New,
|
||||||
@@ -284,26 +264,14 @@ func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlu
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
resp, err := s.sendAndReceive(ctx, subscriber, controllerMsg, defaultSubscriberResponseWait)
|
||||||
case <-ctx.Done():
|
if err != nil {
|
||||||
return nil, ctx.Err()
|
return nil, err
|
||||||
case <-subscriber.done:
|
|
||||||
return nil, status.Error(codes.Canceled, "subscriber removed")
|
|
||||||
case subscriber.controller <- controllerMsg:
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp *proto.Client
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-subscriber.done:
|
|
||||||
return nil, status.Error(codes.Canceled, "subscriber removed")
|
|
||||||
case resp = <-subscriber.client:
|
|
||||||
}
|
}
|
||||||
if resp == nil {
|
if resp == nil {
|
||||||
return nil, status.Error(codes.FailedPrecondition, "empty response from client")
|
return nil, status.Error(codes.FailedPrecondition, "empty response from client")
|
||||||
}
|
}
|
||||||
response, ok := resp.Payload.(*proto.Client_SlugEventResponse)
|
response, ok := resp.Payload.(*proto.Node_SlugEventResponse)
|
||||||
if !ok || response == nil || response.SlugEventResponse == nil {
|
if !ok || response == nil || response.SlugEventResponse == nil {
|
||||||
return nil, status.Error(codes.FailedPrecondition, "invalid slug response payload")
|
return nil, status.Error(codes.FailedPrecondition, "invalid slug response payload")
|
||||||
}
|
}
|
||||||
@@ -312,11 +280,11 @@ func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlu
|
|||||||
|
|
||||||
type SubscriberResult struct {
|
type SubscriberResult struct {
|
||||||
Identity string
|
Identity string
|
||||||
Response *proto.Client
|
Response *proto.Node
|
||||||
Err error
|
Err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto.Controller, resultChan chan<- []SubscriberResult) {
|
func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto.Events, resultChan chan<- []SubscriberResult) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -343,7 +311,7 @@ func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto
|
|||||||
case <-sub.done:
|
case <-sub.done:
|
||||||
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")})
|
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")})
|
||||||
continue
|
continue
|
||||||
case sub.controller <- controllerReq:
|
case sub.events <- controllerReq:
|
||||||
default:
|
default:
|
||||||
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Unavailable, "controller channel blocked")})
|
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Unavailable, "controller channel blocked")})
|
||||||
continue
|
continue
|
||||||
@@ -353,7 +321,7 @@ func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto
|
|||||||
return
|
return
|
||||||
case <-sub.done:
|
case <-sub.done:
|
||||||
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")})
|
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")})
|
||||||
case resp, ok := <-sub.client:
|
case resp, ok := <-sub.node:
|
||||||
if !ok {
|
if !ok {
|
||||||
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.FailedPrecondition, "client channel closed")})
|
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.FailedPrecondition, "client channel closed")})
|
||||||
continue
|
continue
|
||||||
@@ -370,7 +338,104 @@ func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ListenAndServe(Addr string) error {
|
func (s *Server) StartAPI(ctx context.Context, Addr string) error {
|
||||||
|
handler := http.NewServeMux()
|
||||||
|
httpServer := http.Server{
|
||||||
|
Addr: Addr,
|
||||||
|
Handler: handler,
|
||||||
|
ReadTimeout: 15 * time.Second,
|
||||||
|
WriteTimeout: 15 * time.Second,
|
||||||
|
IdleTimeout: 60 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler.HandleFunc("/api/sessions", func(writer http.ResponseWriter, request *http.Request) {
|
||||||
|
identity := request.URL.Query().Get("identity")
|
||||||
|
results := s.broadcastAndCollect(request.Context(), func(ctx context.Context, subscriber *Subscriber) (interface{}, bool) {
|
||||||
|
receive, err := s.sendAndReceive(ctx, subscriber, &proto.Events{
|
||||||
|
Type: proto.EventType_GET_SESSIONS,
|
||||||
|
Payload: &proto.Events_GetSessionsEvent{
|
||||||
|
GetSessionsEvent: &proto.GetSessionsEvent{
|
||||||
|
Identity: identity,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, defaultSubscriberResponseWait)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("get sessions request failed: %v", err)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if receive == nil {
|
||||||
|
log.Printf("get sessions request returned nil response")
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
payload, ok := receive.Payload.(*proto.Node_GetSessionsEvent)
|
||||||
|
if !ok || payload == nil || payload.GetSessionsEvent == nil {
|
||||||
|
log.Printf("unexpected get sessions payload type: %T", receive.Payload)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return payload.GetSessionsEvent.Details, true
|
||||||
|
})
|
||||||
|
|
||||||
|
flatten := flattenInterfaces(results)
|
||||||
|
|
||||||
|
writer.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
if len(flatten) == 0 {
|
||||||
|
_, err := writer.Write([]byte("[]"))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
marshal, err := json.Marshal(flatten)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(writer, "failed to marshal sessions", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = writer.Write(marshal)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errCh <- httpServer.ListenAndServe()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = httpServer.Shutdown(shutdownCtx)
|
||||||
|
return ctx.Err()
|
||||||
|
case err := <-errCh:
|
||||||
|
if errors.Is(err, http.ErrServerClosed) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func flattenInterfaces(values []interface{}) []interface{} {
|
||||||
|
var flat []interface{}
|
||||||
|
for _, v := range values {
|
||||||
|
if v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
if rv.Kind() == reflect.Slice {
|
||||||
|
for i := 0; i < rv.Len(); i++ {
|
||||||
|
flat = append(flat, rv.Index(i).Interface())
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
flat = append(flat, v)
|
||||||
|
}
|
||||||
|
return flat
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) StartController(ctx context.Context, Addr string) error {
|
||||||
listener, err := net.Listen("tcp", Addr)
|
listener, err := net.Listen("tcp", Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -397,11 +462,97 @@ func (s *Server) ListenAndServe(Addr string) error {
|
|||||||
proto.RegisterSlugChangeServer(grpcServer, s)
|
proto.RegisterSlugChangeServer(grpcServer, s)
|
||||||
proto.RegisterEventServiceServer(grpcServer, s)
|
proto.RegisterEventServiceServer(grpcServer, s)
|
||||||
proto.RegisterUserServiceServer(grpcServer, s)
|
proto.RegisterUserServiceServer(grpcServer, s)
|
||||||
proto.RegisterUserSessionsServer(grpcServer, s)
|
|
||||||
|
|
||||||
healthServer := health.NewServer()
|
healthServer := health.NewServer()
|
||||||
grpc_health_v1.RegisterHealthServer(grpcServer, healthServer)
|
grpc_health_v1.RegisterHealthServer(grpcServer, healthServer)
|
||||||
healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING)
|
healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING)
|
||||||
|
|
||||||
return grpcServer.Serve(listener)
|
serveErr := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
serveErr <- grpcServer.Serve(listener)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
grpcServer.GracefulStop()
|
||||||
|
err := <-serveErr
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
|
case err := <-serveErr:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) sendAndReceive(ctx context.Context, sub *Subscriber, req *proto.Events, wait time.Duration) (*proto.Node, error) {
|
||||||
|
if sub == nil || req == nil {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "missing subscriber or request")
|
||||||
|
}
|
||||||
|
|
||||||
|
sub.mu.Lock()
|
||||||
|
defer sub.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-sub.done:
|
||||||
|
return nil, status.Error(codes.Canceled, "subscriber removed")
|
||||||
|
case sub.events <- req:
|
||||||
|
}
|
||||||
|
|
||||||
|
timer := time.NewTimer(wait)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-sub.done:
|
||||||
|
return nil, status.Error(codes.Canceled, "subscriber removed")
|
||||||
|
case <-timer.C:
|
||||||
|
return nil, status.Error(codes.DeadlineExceeded, "subscriber response timeout")
|
||||||
|
case resp, ok := <-sub.node:
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Error(codes.FailedPrecondition, "client channel closed")
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) snapshotSubscribers() []*Subscriber {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
subs := make([]*Subscriber, 0, len(s.Subscribers))
|
||||||
|
for _, sub := range s.Subscribers {
|
||||||
|
subs = append(subs, sub)
|
||||||
|
}
|
||||||
|
return subs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) broadcastAndCollect(ctx context.Context, worker func(context.Context, *Subscriber) (interface{}, bool)) []interface{} {
|
||||||
|
subs := s.snapshotSubscribers()
|
||||||
|
if len(subs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resultsCh := make(chan interface{}, len(subs))
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, sub := range subs {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(subscriber *Subscriber) {
|
||||||
|
defer wg.Done()
|
||||||
|
if val, ok := worker(ctx, subscriber); ok {
|
||||||
|
resultsCh <- val
|
||||||
|
}
|
||||||
|
}(sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
close(resultsCh)
|
||||||
|
|
||||||
|
var results []interface{}
|
||||||
|
for val := range resultsCh {
|
||||||
|
results = append(results, val)
|
||||||
|
}
|
||||||
|
return results
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user