408 lines
11 KiB
Go
408 lines
11 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.fossy.my.id/bagas/tunnel-please-controller/db/sqlc/repository"
|
|
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/health"
|
|
"google.golang.org/grpc/health/grpc_health_v1"
|
|
"google.golang.org/grpc/keepalive"
|
|
"google.golang.org/grpc/reflection"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
type Subscriber struct {
|
|
client chan *proto.Client
|
|
controller chan *proto.Controller
|
|
done chan struct{}
|
|
closeOnce sync.Once
|
|
}
|
|
type Server struct {
|
|
Database *repository.Queries
|
|
Subscribers map[string]*Subscriber
|
|
mu *sync.RWMutex
|
|
authToken string
|
|
broadcastChan chan *proto.Controller
|
|
broadcastResultChan chan []SubscriberResult
|
|
notifyAllCancel context.CancelFunc
|
|
proto.UnimplementedEventServiceServer
|
|
proto.UnimplementedSlugChangeServer
|
|
proto.UnimplementedUserServiceServer
|
|
proto.UnimplementedUserSessionsServer
|
|
}
|
|
|
|
func New(database *repository.Queries, authToken string) *Server {
|
|
broadcastChan := make(chan *proto.Controller, 10)
|
|
broadcastResultChan := make(chan []SubscriberResult, 10)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
srv := &Server{
|
|
Database: database,
|
|
Subscribers: make(map[string]*Subscriber),
|
|
mu: new(sync.RWMutex),
|
|
authToken: authToken,
|
|
broadcastChan: broadcastChan,
|
|
broadcastResultChan: broadcastResultChan,
|
|
notifyAllCancel: cancel,
|
|
}
|
|
|
|
go srv.notifyAllSubscriber(ctx, broadcastChan, broadcastResultChan)
|
|
|
|
return srv
|
|
}
|
|
|
|
func (s *Server) GetSession(ctx context.Context, req *proto.GetSessionRequest) (*proto.GetSessionsResponse, error) {
|
|
if req == nil {
|
|
return nil, status.Error(codes.InvalidArgument, "request is nil")
|
|
}
|
|
|
|
controllerReq := &proto.Controller{
|
|
Type: proto.EventType_GET_SESSIONS,
|
|
Payload: &proto.Controller_GetSessionsEvent{
|
|
GetSessionsEvent: &proto.GetSessionsEvent{
|
|
Identity: req.GetIdentity(),
|
|
},
|
|
},
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case s.broadcastChan <- controllerReq:
|
|
}
|
|
|
|
var results []SubscriberResult
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case results = <-s.broadcastResultChan:
|
|
}
|
|
responses := make([]*proto.Detail, 0, len(results))
|
|
for _, result := range results {
|
|
for _, detail := range result.Response.Payload.(*proto.Client_GetSessionsEvent).GetSessionsEvent.Details {
|
|
responses = append(responses, detail)
|
|
}
|
|
responses = append(responses)
|
|
}
|
|
|
|
if len(responses) == 0 {
|
|
return nil, status.Error(codes.NotFound, "no subscriber responded")
|
|
}
|
|
|
|
return &proto.GetSessionsResponse{Details: responses}, nil
|
|
}
|
|
|
|
func (s *Server) Check(ctx context.Context, request *proto.CheckRequest) (*proto.CheckResponse, error) {
|
|
exist, err := s.Database.UserExistsByIdentifier(ctx, request.GetAuthToken())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if exist {
|
|
return &proto.CheckResponse{
|
|
Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED,
|
|
}, nil
|
|
}
|
|
|
|
return &proto.CheckResponse{
|
|
Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error {
|
|
ctx := event.Context()
|
|
recv, err := event.Recv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if recv == nil {
|
|
return status.Error(codes.InvalidArgument, "missing authentication event")
|
|
}
|
|
if recv.GetType() != proto.EventType_AUTHENTICATION {
|
|
return status.Errorf(codes.InvalidArgument, "invalid event type: %s", recv.GetType())
|
|
}
|
|
payload, ok := recv.GetPayload().(*proto.Client_AuthEvent)
|
|
if !ok || payload == nil || payload.AuthEvent == nil {
|
|
return status.Error(codes.InvalidArgument, "missing auth payload")
|
|
}
|
|
identity := payload.AuthEvent.Identity
|
|
if identity == "" {
|
|
return status.Error(codes.InvalidArgument, "missing identity")
|
|
}
|
|
token := payload.AuthEvent.AuthToken
|
|
if token != s.authToken {
|
|
return status.Error(codes.Unauthenticated, "invalid auth token")
|
|
}
|
|
|
|
log.Printf("Client %s authenticated successfully", identity)
|
|
|
|
requestChan := &Subscriber{
|
|
client: make(chan *proto.Client, 10),
|
|
controller: make(chan *proto.Controller, 10),
|
|
done: make(chan struct{}),
|
|
}
|
|
|
|
if err = s.AddEventSubscriber(identity, requestChan); err != nil {
|
|
return status.Error(codes.AlreadyExists, err.Error())
|
|
}
|
|
defer func() {
|
|
s.RemoveEventSubscriber(identity)
|
|
log.Printf("Client %s disconnected and unsubscribed", identity)
|
|
}()
|
|
|
|
return processEventStream(ctx, requestChan, event)
|
|
}
|
|
|
|
func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case <-requestChan.done:
|
|
return nil
|
|
case request, ok := <-requestChan.controller:
|
|
if !ok {
|
|
return nil
|
|
}
|
|
if request == nil {
|
|
continue
|
|
}
|
|
log.Printf("Received event request: %v", request)
|
|
switch request.GetType() {
|
|
case proto.EventType_SLUG_CHANGE:
|
|
payload, ok := request.GetPayload().(*proto.Controller_SlugEvent)
|
|
if !ok || payload == nil || payload.SlugEvent == nil {
|
|
log.Printf("invalid slug change payload")
|
|
continue
|
|
}
|
|
slugEvent := payload.SlugEvent
|
|
log.Printf("Processing slug change event: old=%s, new=%s", slugEvent.Old, slugEvent.New)
|
|
if err := event.Send(request); err != nil {
|
|
return err
|
|
}
|
|
recv, err := event.Recv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case requestChan.client <- recv:
|
|
}
|
|
log.Printf("Received slug change event: %v", recv)
|
|
case proto.EventType_GET_SESSIONS:
|
|
log.Printf("Processing session event")
|
|
if err := event.Send(request); err != nil {
|
|
return err
|
|
}
|
|
recv, err := event.Recv()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case requestChan.client <- recv:
|
|
}
|
|
log.Printf("Received SESSIONS event: %v", recv)
|
|
default:
|
|
log.Printf("Unknown event type: %v", request.GetType())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) AddEventSubscriber(identity string, req *Subscriber) error {
|
|
if identity == "" || req == nil {
|
|
return fmt.Errorf("invalid subscriber")
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if _, exist := s.Subscribers[identity]; exist {
|
|
return fmt.Errorf("identity %s already subscribed", identity)
|
|
}
|
|
s.Subscribers[identity] = req
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) RemoveEventSubscriber(identity string) {
|
|
s.mu.Lock()
|
|
sub := s.Subscribers[identity]
|
|
delete(s.Subscribers, identity)
|
|
s.mu.Unlock()
|
|
|
|
if sub != nil {
|
|
sub.closeOnce.Do(func() { close(sub.done) })
|
|
}
|
|
}
|
|
|
|
func (s *Server) GetEventSubscriber(identity string) (*Subscriber, error) {
|
|
if identity == "" {
|
|
return nil, status.Error(codes.InvalidArgument, "missing identity")
|
|
}
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
req, ok := s.Subscribers[identity]
|
|
if !ok {
|
|
return nil, status.Errorf(codes.NotFound, "identity %s not subscribed", identity)
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlugRequest) (*proto.ChangeSlugResponse, error) {
|
|
if request == nil {
|
|
return nil, status.Error(codes.InvalidArgument, "request is nil")
|
|
}
|
|
if request.GetNode() == "" {
|
|
return nil, status.Error(codes.InvalidArgument, "node is required")
|
|
}
|
|
if request.Old == "" || request.New == "" {
|
|
return nil, status.Error(codes.InvalidArgument, "old and new slugs are required")
|
|
}
|
|
|
|
subscriber, err := s.GetEventSubscriber(request.GetNode())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
controllerMsg := &proto.Controller{
|
|
Type: proto.EventType_SLUG_CHANGE,
|
|
Payload: &proto.Controller_SlugEvent{
|
|
SlugEvent: &proto.SlugChangeEvent{
|
|
Old: request.Old,
|
|
New: request.New,
|
|
},
|
|
},
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-subscriber.done:
|
|
return nil, status.Error(codes.Canceled, "subscriber removed")
|
|
case subscriber.controller <- controllerMsg:
|
|
}
|
|
|
|
var resp *proto.Client
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-subscriber.done:
|
|
return nil, status.Error(codes.Canceled, "subscriber removed")
|
|
case resp = <-subscriber.client:
|
|
}
|
|
if resp == nil {
|
|
return nil, status.Error(codes.FailedPrecondition, "empty response from client")
|
|
}
|
|
response, ok := resp.Payload.(*proto.Client_SlugEventResponse)
|
|
if !ok || response == nil || response.SlugEventResponse == nil {
|
|
return nil, status.Error(codes.FailedPrecondition, "invalid slug response payload")
|
|
}
|
|
return (*proto.ChangeSlugResponse)(response.SlugEventResponse), nil
|
|
}
|
|
|
|
type SubscriberResult struct {
|
|
Identity string
|
|
Response *proto.Client
|
|
Err error
|
|
}
|
|
|
|
func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto.Controller, resultChan chan<- []SubscriberResult) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case controllerReq, ok := <-recvChan:
|
|
if !ok {
|
|
return
|
|
}
|
|
if controllerReq == nil {
|
|
continue
|
|
}
|
|
|
|
s.mu.RLock()
|
|
subs := make(map[string]*Subscriber, len(s.Subscribers))
|
|
for id, sub := range s.Subscribers {
|
|
subs[id] = sub
|
|
}
|
|
s.mu.RUnlock()
|
|
results := make([]SubscriberResult, 0, len(subs))
|
|
for id, sub := range subs {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-sub.done:
|
|
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")})
|
|
continue
|
|
case sub.controller <- controllerReq:
|
|
default:
|
|
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Unavailable, "controller channel blocked")})
|
|
continue
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-sub.done:
|
|
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")})
|
|
case resp, ok := <-sub.client:
|
|
if !ok {
|
|
results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.FailedPrecondition, "client channel closed")})
|
|
continue
|
|
}
|
|
results = append(results, SubscriberResult{Identity: id, Response: resp})
|
|
}
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case resultChan <- results:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) ListenAndServe(Addr string) error {
|
|
listener, err := net.Listen("tcp", Addr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
kaParams := keepalive.ServerParameters{
|
|
MaxConnectionIdle: 0,
|
|
MaxConnectionAge: 0,
|
|
MaxConnectionAgeGrace: 0,
|
|
Time: 30 * time.Second,
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
kaPolicy := keepalive.EnforcementPolicy{
|
|
MinTime: 5 * time.Second,
|
|
PermitWithoutStream: true,
|
|
}
|
|
|
|
grpcServer := grpc.NewServer(
|
|
grpc.KeepaliveParams(kaParams),
|
|
grpc.KeepaliveEnforcementPolicy(kaPolicy),
|
|
)
|
|
reflection.Register(grpcServer)
|
|
|
|
proto.RegisterSlugChangeServer(grpcServer, s)
|
|
proto.RegisterEventServiceServer(grpcServer, s)
|
|
proto.RegisterUserServiceServer(grpcServer, s)
|
|
proto.RegisterUserSessionsServer(grpcServer, s)
|
|
|
|
healthServer := health.NewServer()
|
|
grpc_health_v1.RegisterHealthServer(grpcServer, healthServer)
|
|
healthServer.SetServingStatus("", grpc_health_v1.HealthCheckResponse_SERVING)
|
|
|
|
return grpcServer.Serve(listener)
|
|
}
|