diff --git a/db/sqlc/migrations/000001_init_schema.down.sql b/db/sqlc/migrations/000001_init_schema.down.sql new file mode 100644 index 0000000..2629106 --- /dev/null +++ b/db/sqlc/migrations/000001_init_schema.down.sql @@ -0,0 +1,4 @@ +DROP TABLE IF EXISTS public.verification; +DROP TABLE IF EXISTS public.session; +DROP TABLE IF EXISTS public.account; +DROP TABLE IF EXISTS public."user"; \ No newline at end of file diff --git a/db/sqlc/migrations/000001_init_schema.up.sql b/db/sqlc/migrations/000001_init_schema.up.sql new file mode 100644 index 0000000..6a793a3 --- /dev/null +++ b/db/sqlc/migrations/000001_init_schema.up.sql @@ -0,0 +1,52 @@ +CREATE EXTENSION IF NOT EXISTS pgcrypto; + +CREATE TABLE IF NOT EXISTS public."user" ( + id text NOT NULL, + identifier text UNIQUE NOT NULL DEFAULT substr(encode(gen_random_bytes(16), 'hex'), 1, 32), + name text NOT NULL, + email text NOT NULL, + email_verified boolean DEFAULT false NOT NULL, + image text, + created_at timestamp without time zone DEFAULT now() NOT NULL, + updated_at timestamp without time zone DEFAULT now() NOT NULL, + CONSTRAINT user_pkey PRIMARY KEY (id) +); + +CREATE TABLE IF NOT EXISTS public.account ( + id text NOT NULL, + account_id text NOT NULL, + provider_id text NOT NULL, + user_id text NOT NULL, + access_token text, + refresh_token text, + id_token text, + access_token_expires_at timestamp without time zone, + refresh_token_expires_at timestamp without time zone, + scope text, + password text, + created_at timestamp without time zone DEFAULT now() NOT NULL, + updated_at timestamp without time zone NOT NULL, + CONSTRAINT account_pkey PRIMARY KEY (id) +); + +CREATE TABLE IF NOT EXISTS public.session ( + id text NOT NULL, + expires_at timestamp without time zone NOT NULL, + token text NOT NULL, + created_at timestamp without time zone DEFAULT now() NOT NULL, + updated_at timestamp without time zone NOT NULL, + ip_address text, + user_agent text, + user_id text NOT NULL, + CONSTRAINT session_pkey PRIMARY KEY (id) +); + +CREATE TABLE IF NOT EXISTS public.verification ( + id text NOT NULL, + identifier text NOT NULL, + value text NOT NULL, + expires_at timestamp without time zone NOT NULL, + created_at timestamp without time zone DEFAULT now() NOT NULL, + updated_at timestamp without time zone DEFAULT now() NOT NULL, + CONSTRAINT verification_pkey PRIMARY KEY (id) +); \ No newline at end of file diff --git a/db/sqlc/migrations/001_create_identifiers.down.sql b/db/sqlc/migrations/001_create_identifiers.down.sql deleted file mode 100644 index f83303a..0000000 --- a/db/sqlc/migrations/001_create_identifiers.down.sql +++ /dev/null @@ -1,5 +0,0 @@ -DROP INDEX IF EXISTS idx_identifiers_created_at; -DROP INDEX IF EXISTS idx_identifiers_slug; - -DROP TABLE IF EXISTS identifiers; - diff --git a/db/sqlc/migrations/001_create_identifiers.up.sql b/db/sqlc/migrations/001_create_identifiers.up.sql deleted file mode 100644 index 5692521..0000000 --- a/db/sqlc/migrations/001_create_identifiers.up.sql +++ /dev/null @@ -1,12 +0,0 @@ -CREATE TABLE identifiers ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - slug VARCHAR(255) UNIQUE NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT NOW(), - updated_at TIMESTAMP NOT NULL DEFAULT NOW() -); - -CREATE INDEX idx_identifiers_created_at - ON identifiers(created_at); - -CREATE INDEX idx_identifiers_slug - ON identifiers(slug); diff --git a/db/sqlc/queries/query.sql b/db/sqlc/queries/query.sql index 80246be..c609363 100644 --- a/db/sqlc/queries/query.sql +++ b/db/sqlc/queries/query.sql @@ -1,34 +1,6 @@ --- name: CreateIdentifier :one -INSERT INTO identifiers (slug) -VALUES ($1) -RETURNING id, slug, created_at, updated_at; - --- name: GetIdentifierById :one -SELECT id, slug, created_at, updated_at -FROM identifiers -WHERE id = $1; - --- name: GetIdentifierBySlug :one -SELECT id, slug, created_at, updated_at -FROM identifiers -WHERE slug = $1; - --- name: ListIdentifiers :many -SELECT id, slug, created_at, updated_at -FROM identifiers -ORDER BY created_at DESC -LIMIT $1 OFFSET $2; - --- name: DeleteIdentifier :exec -DELETE FROM identifiers -WHERE id = $1; - --- name: UpdateIdentifierSlug :one -UPDATE identifiers -SET slug = $2, updated_at = NOW() -WHERE id = $1 -RETURNING id, slug, created_at, updated_at; - --- name: CountIdentifiers :one -SELECT COUNT(*) FROM identifiers; - +-- name: UserExistsByIdentifier :one +SELECT EXISTS ( + SELECT 1 + FROM public."user" + WHERE identifier = $1 +) AS exists; \ No newline at end of file diff --git a/db/sqlc/repository/models.go b/db/sqlc/repository/models.go index 6a124ed..345378a 100644 --- a/db/sqlc/repository/models.go +++ b/db/sqlc/repository/models.go @@ -5,13 +5,52 @@ package repository import ( - "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" ) -type Identifier struct { - ID uuid.UUID - Slug string +type Account struct { + ID string + AccountID string + ProviderID string + UserID string + AccessToken pgtype.Text + RefreshToken pgtype.Text + IDToken pgtype.Text + AccessTokenExpiresAt pgtype.Timestamp + RefreshTokenExpiresAt pgtype.Timestamp + Scope pgtype.Text + Password pgtype.Text + CreatedAt pgtype.Timestamp + UpdatedAt pgtype.Timestamp +} + +type Session struct { + ID string + ExpiresAt pgtype.Timestamp + Token string CreatedAt pgtype.Timestamp UpdatedAt pgtype.Timestamp + IpAddress pgtype.Text + UserAgent pgtype.Text + UserID string +} + +type User struct { + ID string + Identifier string + Name string + Email string + EmailVerified bool + Image pgtype.Text + CreatedAt pgtype.Timestamp + UpdatedAt pgtype.Timestamp +} + +type Verification struct { + ID string + Identifier string + Value string + ExpiresAt pgtype.Timestamp + CreatedAt pgtype.Timestamp + UpdatedAt pgtype.Timestamp } diff --git a/db/sqlc/repository/query.sql.go b/db/sqlc/repository/query.sql.go index e9baf2e..acd965d 100644 --- a/db/sqlc/repository/query.sql.go +++ b/db/sqlc/repository/query.sql.go @@ -7,142 +7,19 @@ package repository import ( "context" - - "github.com/google/uuid" ) -const countIdentifiers = `-- name: CountIdentifiers :one -SELECT COUNT(*) FROM identifiers +const userExistsByIdentifier = `-- name: UserExistsByIdentifier :one +SELECT EXISTS ( + SELECT 1 + FROM public."user" + WHERE identifier = $1 +) AS exists ` -func (q *Queries) CountIdentifiers(ctx context.Context) (int64, error) { - row := q.db.QueryRow(ctx, countIdentifiers) - var count int64 - err := row.Scan(&count) - return count, err -} - -const createIdentifier = `-- name: CreateIdentifier :one -INSERT INTO identifiers (slug) -VALUES ($1) -RETURNING id, slug, created_at, updated_at -` - -func (q *Queries) CreateIdentifier(ctx context.Context, slug string) (Identifier, error) { - row := q.db.QueryRow(ctx, createIdentifier, slug) - var i Identifier - err := row.Scan( - &i.ID, - &i.Slug, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const deleteIdentifier = `-- name: DeleteIdentifier :exec -DELETE FROM identifiers -WHERE id = $1 -` - -func (q *Queries) DeleteIdentifier(ctx context.Context, id uuid.UUID) error { - _, err := q.db.Exec(ctx, deleteIdentifier, id) - return err -} - -const getIdentifierById = `-- name: GetIdentifierById :one -SELECT id, slug, created_at, updated_at -FROM identifiers -WHERE id = $1 -` - -func (q *Queries) GetIdentifierById(ctx context.Context, id uuid.UUID) (Identifier, error) { - row := q.db.QueryRow(ctx, getIdentifierById, id) - var i Identifier - err := row.Scan( - &i.ID, - &i.Slug, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const getIdentifierBySlug = `-- name: GetIdentifierBySlug :one -SELECT id, slug, created_at, updated_at -FROM identifiers -WHERE slug = $1 -` - -func (q *Queries) GetIdentifierBySlug(ctx context.Context, slug string) (Identifier, error) { - row := q.db.QueryRow(ctx, getIdentifierBySlug, slug) - var i Identifier - err := row.Scan( - &i.ID, - &i.Slug, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const listIdentifiers = `-- name: ListIdentifiers :many -SELECT id, slug, created_at, updated_at -FROM identifiers -ORDER BY created_at DESC -LIMIT $1 OFFSET $2 -` - -type ListIdentifiersParams struct { - Limit int32 - Offset int32 -} - -func (q *Queries) ListIdentifiers(ctx context.Context, arg ListIdentifiersParams) ([]Identifier, error) { - rows, err := q.db.Query(ctx, listIdentifiers, arg.Limit, arg.Offset) - if err != nil { - return nil, err - } - defer rows.Close() - var items []Identifier - for rows.Next() { - var i Identifier - if err := rows.Scan( - &i.ID, - &i.Slug, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const updateIdentifierSlug = `-- name: UpdateIdentifierSlug :one -UPDATE identifiers -SET slug = $2, updated_at = NOW() -WHERE id = $1 -RETURNING id, slug, created_at, updated_at -` - -type UpdateIdentifierSlugParams struct { - ID uuid.UUID - Slug string -} - -func (q *Queries) UpdateIdentifierSlug(ctx context.Context, arg UpdateIdentifierSlugParams) (Identifier, error) { - row := q.db.QueryRow(ctx, updateIdentifierSlug, arg.ID, arg.Slug) - var i Identifier - err := row.Scan( - &i.ID, - &i.Slug, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err +func (q *Queries) UserExistsByIdentifier(ctx context.Context, identifier string) (bool, error) { + row := q.db.QueryRow(ctx, userExistsByIdentifier, identifier) + var exists bool + err := row.Scan(&exists) + return exists, err } diff --git a/server/server.go b/server/server.go index 0d2f20d..d36a8ac 100644 --- a/server/server.go +++ b/server/server.go @@ -26,21 +26,96 @@ type Subscriber struct { closeOnce sync.Once } type Server struct { - Database *repository.Queries - Subscribers map[string]*Subscriber - mu *sync.RWMutex - authToken string + Database *repository.Queries + Subscribers map[string]*Subscriber + mu *sync.RWMutex + authToken string + broadcastChan chan *proto.Controller + broadcastResultChan chan []SubscriberResult + notifyAllCancel context.CancelFunc proto.UnimplementedEventServiceServer proto.UnimplementedSlugChangeServer + proto.UnimplementedUserServiceServer + proto.UnimplementedUserSessionsServer } func New(database *repository.Queries, authToken string) *Server { - return &Server{ - Database: database, - Subscribers: make(map[string]*Subscriber), - mu: new(sync.RWMutex), - authToken: authToken, + broadcastChan := make(chan *proto.Controller, 10) + broadcastResultChan := make(chan []SubscriberResult, 10) + + ctx, cancel := context.WithCancel(context.Background()) + + srv := &Server{ + Database: database, + Subscribers: make(map[string]*Subscriber), + mu: new(sync.RWMutex), + authToken: authToken, + broadcastChan: broadcastChan, + broadcastResultChan: broadcastResultChan, + notifyAllCancel: cancel, } + + go srv.notifyAllSubscriber(ctx, broadcastChan, broadcastResultChan) + + return srv +} + +func (s *Server) GetSession(ctx context.Context, req *proto.GetSessionRequest) (*proto.GetSessionsResponse, error) { + if req == nil { + return nil, status.Error(codes.InvalidArgument, "request is nil") + } + + controllerReq := &proto.Controller{ + Type: proto.EventType_GET_SESSIONS, + Payload: &proto.Controller_GetSessionsEvent{ + GetSessionsEvent: &proto.GetSessionsEvent{ + Identity: req.GetIdentity(), + }, + }, + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case s.broadcastChan <- controllerReq: + } + + var results []SubscriberResult + select { + case <-ctx.Done(): + return nil, ctx.Err() + case results = <-s.broadcastResultChan: + } + responses := make([]*proto.Detail, 0, len(results)) + for _, result := range results { + for _, detail := range result.Response.Payload.(*proto.Client_GetSessionsEvent).GetSessionsEvent.Details { + responses = append(responses, detail) + } + responses = append(responses) + } + + if len(responses) == 0 { + return nil, status.Error(codes.NotFound, "no subscriber responded") + } + + return &proto.GetSessionsResponse{Details: responses}, nil +} + +func (s *Server) Check(ctx context.Context, request *proto.CheckRequest) (*proto.CheckResponse, error) { + exist, err := s.Database.UserExistsByIdentifier(ctx, request.GetAuthToken()) + if err != nil { + return nil, err + } + + if exist { + return &proto.CheckResponse{ + Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED, + }, nil + } + + return &proto.CheckResponse{ + Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED, + }, nil } func (s *Server) Subscribe(event grpc.BidiStreamingServer[proto.Client, proto.Controller]) error { @@ -124,6 +199,21 @@ func processEventStream(ctx context.Context, requestChan *Subscriber, event grpc case requestChan.client <- recv: } log.Printf("Received slug change event: %v", recv) + case proto.EventType_GET_SESSIONS: + log.Printf("Processing session event") + if err := event.Send(request); err != nil { + return err + } + recv, err := event.Recv() + if err != nil { + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + case requestChan.client <- recv: + } + log.Printf("Received SESSIONS event: %v", recv) default: log.Printf("Unknown event type: %v", request.GetType()) } @@ -220,6 +310,66 @@ func (s *Server) RequestChangeSlug(ctx context.Context, request *proto.ChangeSlu return (*proto.ChangeSlugResponse)(response.SlugEventResponse), nil } +type SubscriberResult struct { + Identity string + Response *proto.Client + Err error +} + +func (s *Server) notifyAllSubscriber(ctx context.Context, recvChan <-chan *proto.Controller, resultChan chan<- []SubscriberResult) { + for { + select { + case <-ctx.Done(): + return + case controllerReq, ok := <-recvChan: + if !ok { + return + } + if controllerReq == nil { + continue + } + + s.mu.RLock() + subs := make(map[string]*Subscriber, len(s.Subscribers)) + for id, sub := range s.Subscribers { + subs[id] = sub + } + s.mu.RUnlock() + results := make([]SubscriberResult, 0, len(subs)) + for id, sub := range subs { + select { + case <-ctx.Done(): + return + case <-sub.done: + results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")}) + continue + case sub.controller <- controllerReq: + default: + results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Unavailable, "controller channel blocked")}) + continue + } + select { + case <-ctx.Done(): + return + case <-sub.done: + results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.Canceled, "subscriber removed")}) + case resp, ok := <-sub.client: + if !ok { + results = append(results, SubscriberResult{Identity: id, Err: status.Error(codes.FailedPrecondition, "client channel closed")}) + continue + } + results = append(results, SubscriberResult{Identity: id, Response: resp}) + } + } + select { + case <-ctx.Done(): + return + case resultChan <- results: + } + } + } +} + func (s *Server) ListenAndServe(Addr string) error { listener, err := net.Listen("tcp", Addr) if err != nil { @@ -246,6 +396,8 @@ func (s *Server) ListenAndServe(Addr string) error { proto.RegisterSlugChangeServer(grpcServer, s) proto.RegisterEventServiceServer(grpcServer, s) + proto.RegisterUserServiceServer(grpcServer, s) + proto.RegisterUserSessionsServer(grpcServer, s) healthServer := health.NewServer() grpc_health_v1.RegisterHealthServer(grpcServer, healthServer)