diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index da76a55..38835c1 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -3,13 +3,14 @@ package client import ( "context" "crypto/tls" + "errors" "fmt" + "io" "log" "time" "tunnel_pls/session" - "git.fossy.my.id/bagas/tunnel-please-grpc/gen" - "github.com/golang/protobuf/ptypes/empty" + proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -20,36 +21,59 @@ import ( ) type GrpcConfig struct { - Address string - UseTLS bool - InsecureSkipVerify bool - Timeout time.Duration - KeepAlive bool - MaxRetries int + Address string + UseTLS bool + InsecureSkipVerify bool + Timeout time.Duration + KeepAlive bool + MaxRetries int + KeepAliveTime time.Duration + KeepAliveTimeout time.Duration + PermitWithoutStream bool } type Client struct { conn *grpc.ClientConn config *GrpcConfig sessionRegistry session.Registry - IdentityService gen.IdentityClient - eventService gen.EventServiceClient + slugService proto.SlugChangeClient + eventService proto.EventServiceClient } func DefaultConfig() *GrpcConfig { return &GrpcConfig{ - Address: "localhost:50051", - UseTLS: false, - InsecureSkipVerify: false, - Timeout: 10 * time.Second, - KeepAlive: true, - MaxRetries: 3, + Address: "localhost:50051", + UseTLS: false, + InsecureSkipVerify: false, + Timeout: 10 * time.Second, + KeepAlive: true, + MaxRetries: 3, + KeepAliveTime: 2 * time.Minute, + KeepAliveTimeout: 10 * time.Second, + PermitWithoutStream: false, } } func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error) { if config == nil { config = DefaultConfig() + } else { + defaults := DefaultConfig() + if config.Address == "" { + config.Address = defaults.Address + } + if config.Timeout == 0 { + config.Timeout = defaults.Timeout + } + if config.MaxRetries == 0 { + config.MaxRetries = defaults.MaxRetries + } + if config.KeepAliveTime == 0 { + config.KeepAliveTime = defaults.KeepAliveTime + } + if config.KeepAliveTimeout == 0 { + config.KeepAliveTimeout = defaults.KeepAliveTimeout + } } var opts []grpc.DialOption @@ -66,9 +90,9 @@ func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error) if config.KeepAlive { kaParams := keepalive.ClientParameters{ - Time: 10 * time.Second, - Timeout: 3 * time.Second, - PermitWithoutStream: false, + Time: config.KeepAliveTime, + Timeout: config.KeepAliveTimeout, + PermitWithoutStream: config.PermitWithoutStream, } opts = append(opts, grpc.WithKeepaliveParams(kaParams)) } @@ -85,94 +109,120 @@ func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error) return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", config.Address, err) } - identityService := gen.NewIdentityClient(conn) - eventService := gen.NewEventServiceClient(conn) + slugService := proto.NewSlugChangeClient(conn) + eventService := proto.NewEventServiceClient(conn) return &Client{ conn: conn, config: config, - IdentityService: identityService, - eventService: eventService, + slugService: slugService, sessionRegistry: sessionRegistry, + eventService: eventService, }, nil } -func (c *Client) SubscribeEvents(ctx context.Context) error { - for { - if ctx.Err() != nil { - log.Println("Context cancelled, stopping event subscription") - return ctx.Err() - } - - log.Println("Subscribing to events...") - stream, err := c.eventService.Subscribe(ctx, &empty.Empty{}) - if err != nil { - log.Printf("Failed to subscribe: %v. Retrying in 10 seconds...", err) - select { - case <-time.After(10 * time.Second): - case <-ctx.Done(): - return ctx.Err() - } - continue - } - - if err := c.processEventStream(ctx, stream); err != nil { - if ctx.Err() != nil { - return ctx.Err() - } - log.Printf("Stream error: %v. Reconnecting in 10 seconds...", err) - select { - case <-time.After(10 * time.Second): - case <-ctx.Done(): - return ctx.Err() - } - } +func (c *Client) SubscribeEvents(ctx context.Context, identity string) error { + subscribe, err := c.eventService.Subscribe(ctx) + if err != nil { + return err } + err = subscribe.Send(&proto.Client{ + Type: proto.EventType_AUTHENTICATION, + Payload: &proto.Client_AuthEvent{ + AuthEvent: &proto.Authentication{ + Identity: identity, + AuthToken: "test_auth_key", + }, + }, + }) + + if err != nil { + log.Println("Authentication failed to send to gRPC server:", err) + return err + } + log.Println("Authentication Successfully sent to gRPC server") + err = c.processEventStream(subscribe) + if err != nil { + return err + } + return nil } -func (c *Client) processEventStream(ctx context.Context, stream gen.EventService_SubscribeClient) error { +func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Client, proto.Controller]) error { for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - event, err := stream.Recv() + recv, err := subscribe.Recv() if err != nil { - st, ok := status.FromError(err) - if !ok { - return fmt.Errorf("non-gRPC error: %w", err) - } - - switch st.Code() { - case codes.Unavailable, codes.Canceled, codes.DeadlineExceeded: - return fmt.Errorf("stream closed [%s]: %s", st.Code(), st.Message()) - default: - return fmt.Errorf("gRPC error [%s]: %s", st.Code(), st.Message()) + if isConnectionError(err) { + log.Printf("connection error receiving from gRPC server: %v", err) + return err } + log.Printf("non-connection receive error from gRPC server: %v", err) + continue } - - if event != nil { - dataEvent := event.GetDataEvent() - if dataEvent != nil { - oldSlug := dataEvent.GetOld() - newSlug := dataEvent.GetNew() - - userSession, exist := c.sessionRegistry.Get(oldSlug) - if !exist { - log.Printf("Session with slug '%s' not found, ignoring event", oldSlug) - continue - } - success := c.sessionRegistry.Update(oldSlug, newSlug) - - if success { - log.Printf("Successfully updated session slug from '%s' to '%s'", oldSlug, newSlug) - userSession.GetInteraction().Redraw() - } else { - log.Printf("Failed to update session slug from '%s' to '%s'", oldSlug, newSlug) + switch recv.GetType() { + case proto.EventType_SLUG_CHANGE: + oldSlug := recv.GetSlugEvent().GetOld() + newSlug := recv.GetSlugEvent().GetNew() + session, err := c.sessionRegistry.Get(oldSlug) + if err != nil { + errSend := subscribe.Send(&proto.Client{ + Type: proto.EventType_SLUG_CHANGE_RESPONSE, + Payload: &proto.Client_SlugEventResponse{ + SlugEventResponse: &proto.SlugChangeEventResponse{ + Success: false, + Message: err.Error(), + }, + }, + }) + if errSend != nil { + if isConnectionError(errSend) { + log.Printf("connection error sending slug change failure: %v", errSend) + return errSend + } + log.Printf("non-connection send error for slug change failure: %v", errSend) } + continue } + err = c.sessionRegistry.Update(oldSlug, newSlug) + if err != nil { + errSend := subscribe.Send(&proto.Client{ + Type: proto.EventType_SLUG_CHANGE_RESPONSE, + Payload: &proto.Client_SlugEventResponse{ + SlugEventResponse: &proto.SlugChangeEventResponse{ + Success: false, + Message: err.Error(), + }, + }, + }) + if errSend != nil { + if isConnectionError(errSend) { + log.Printf("connection error sending slug change failure: %v", errSend) + return errSend + } + log.Printf("non-connection send error for slug change failure: %v", errSend) + } + continue + } + session.GetInteraction().Redraw() + err = subscribe.Send(&proto.Client{ + Type: proto.EventType_SLUG_CHANGE_RESPONSE, + Payload: &proto.Client_SlugEventResponse{ + SlugEventResponse: &proto.SlugChangeEventResponse{ + Success: true, + Message: "", + }, + }, + }) + if err != nil { + if isConnectionError(err) { + log.Printf("connection error sending slug change success: %v", err) + return err + } + log.Printf("non-connection send error for slug change success: %v", err) + continue + } + default: + log.Printf("Unknown event type received: %v", recv.GetType()) } } } @@ -209,3 +259,18 @@ func (c *Client) CheckServerHealth(ctx context.Context) error { func (c *Client) GetConfig() *GrpcConfig { return c.config } + +func isConnectionError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) { + return true + } + switch status.Code(err) { + case codes.Unavailable, codes.Canceled, codes.DeadlineExceeded: + return true + default: + return false + } +} diff --git a/main.go b/main.go index 5f7a82b..1daa6e0 100644 --- a/main.go +++ b/main.go @@ -91,13 +91,9 @@ func main() { cancel() ctx, cancel = context.WithCancel(context.Background()) - //go func(err error) { - // if !errors.Is(err, ctx.Err()) { - // log.Fatalf("Event subscription error: %s", err) - // } - //}(grpcClient.SubscribeEvents(ctx)) go func() { - err := grpcClient.SubscribeEvents(ctx) + identity := config.Getenv("DOMAIN", "localhost") + err = grpcClient.SubscribeEvents(ctx, identity) if err != nil { return } diff --git a/server/http.go b/server/http.go index 433b9a0..8add118 100644 --- a/server/http.go +++ b/server/http.go @@ -313,8 +313,8 @@ func (hs *httpServer) handler(conn net.Conn) { return } - sshSession, exist := hs.sessionRegistry.Get(slug) - if !exist { + sshSession, err := hs.sessionRegistry.Get(slug) + if err != nil { _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + "Content-Length: 0\r\n" + diff --git a/server/https.go b/server/https.go index 90ffd49..3c502a5 100644 --- a/server/https.go +++ b/server/https.go @@ -89,8 +89,8 @@ func (hs *httpServer) handlerTLS(conn net.Conn) { return } - sshSession, exist := hs.sessionRegistry.Get(slug) - if !exist { + sshSession, err := hs.sessionRegistry.Get(slug) + if err != nil { _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) + "Content-Length: 0\r\n" + diff --git a/session/interaction/constants.go b/session/interaction/constants.go deleted file mode 100644 index 91024d3..0000000 --- a/session/interaction/constants.go +++ /dev/null @@ -1,152 +0,0 @@ -package interaction - -const ( - backspaceChar = 8 - deleteChar = 127 - enterChar = 13 - escapeChar = 27 - ctrlC = 3 - forwardSlash = '/' - minPrintableChar = 32 - maxPrintableChar = 126 - - minSlugLength = 3 - maxSlugLength = 20 - - clearScreen = "\033[H\033[2J" - clearLine = "\033[K" - clearToLineEnd = "\r\033[K" - backspaceSeq = "\b \b" - - minBoxWidth = 50 - paddingRight = 4 -) - -var forbiddenSlugs = map[string]struct{}{ - "ping": {}, - "staging": {}, - "admin": {}, - "root": {}, - "api": {}, - "www": {}, - "support": {}, - "help": {}, - "status": {}, - "health": {}, - "login": {}, - "logout": {}, - "signup": {}, - "register": {}, - "settings": {}, - "config": {}, - "null": {}, - "undefined": {}, - "example": {}, - "test": {}, - "dev": {}, - "system": {}, - "administrator": {}, - "dashboard": {}, - "account": {}, - "profile": {}, - "user": {}, - "users": {}, - "auth": {}, - "oauth": {}, - "callback": {}, - "webhook": {}, - "webhooks": {}, - "static": {}, - "assets": {}, - "cdn": {}, - "mail": {}, - "email": {}, - "ftp": {}, - "ssh": {}, - "git": {}, - "svn": {}, - "blog": {}, - "news": {}, - "about": {}, - "contact": {}, - "terms": {}, - "privacy": {}, - "legal": {}, - "billing": {}, - "payment": {}, - "checkout": {}, - "cart": {}, - "shop": {}, - "store": {}, - "download": {}, - "uploads": {}, - "images": {}, - "img": {}, - "css": {}, - "js": {}, - "fonts": {}, - "public": {}, - "private": {}, - "internal": {}, - "external": {}, - "proxy": {}, - "cache": {}, - "debug": {}, - "metrics": {}, - "monitoring": {}, - "graphql": {}, - "rest": {}, - "rpc": {}, - "socket": {}, - "ws": {}, - "wss": {}, - "app": {}, - "apps": {}, - "mobile": {}, - "desktop": {}, - "embed": {}, - "widget": {}, - "docs": {}, - "documentation": {}, - "wiki": {}, - "forum": {}, - "community": {}, - "feedback": {}, - "report": {}, - "abuse": {}, - "spam": {}, - "security": {}, - "verify": {}, - "confirm": {}, - "reset": {}, - "password": {}, - "recovery": {}, - "unsubscribe": {}, - "subscribe": {}, - "notifications": {}, - "alerts": {}, - "messages": {}, - "inbox": {}, - "outbox": {}, - "sent": {}, - "draft": {}, - "trash": {}, - "archive": {}, - "search": {}, - "explore": {}, - "discover": {}, - "trending": {}, - "popular": {}, - "featured": {}, - "new": {}, - "latest": {}, - "top": {}, - "best": {}, - "hot": {}, - "random": {}, - "all": {}, - "any": {}, - "none": {}, - "true": {}, - "false": {}, -} diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 9356a3a..3a36f4c 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -28,7 +28,7 @@ type Lifecycle interface { type Controller interface { SetChannel(channel ssh.Channel) SetLifecycle(lifecycle Lifecycle) - SetSlugModificator(func(oldSlug, newSlug string) bool) + SetSlugModificator(func(oldSlug, newSlug string) error) Start() SetWH(w, h int) Redraw() @@ -45,7 +45,7 @@ type Interaction struct { slugManager slug.Manager forwarder Forwarder lifecycle Lifecycle - updateClientSlug func(oldSlug, newSlug string) bool + updateClientSlug func(oldSlug, newSlug string) error program *tea.Program ctx context.Context cancel context.CancelFunc @@ -121,7 +121,7 @@ func (i *Interaction) SetChannel(channel ssh.Channel) { i.channel = channel } -func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) (success bool)) { +func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) error) { i.updateClientSlug = modificator } @@ -218,20 +218,10 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, tea.Batch(tea.ClearScreen, textinput.Blink) case "enter": inputValue := m.slugInput.Value() - - if isForbiddenSlug(inputValue) { - m.slugError = "This subdomain is reserved. Please choose a different one." - return m, nil - } else if !isValidSlug(inputValue) { - m.slugError = "Invalid subdomain. Follow the rules." + if err := m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue); err != nil { + m.slugError = err.Error() return m, nil } - - if !m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue) { - m.slugError = "Someone already uses this subdomain." - return m, nil - } - m.editingSlug = false m.slugError = "" return m, tea.Batch(tea.ClearScreen, textinput.Blink) @@ -823,30 +813,3 @@ func buildURL(protocol, subdomain, domain string) string { func generateRandomSubdomain() string { return random.GenerateRandomString(20) } - -func isValidSlug(slug string) bool { - if len(slug) < minSlugLength || len(slug) > maxSlugLength { - return false - } - - if slug[0] == '-' || slug[len(slug)-1] == '-' { - return false - } - - for _, c := range slug { - if !isValidSlugChar(byte(c)) { - return false - } - } - - return true -} - -func isValidSlugChar(c byte) bool { - return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' -} - -func isForbiddenSlug(slug string) bool { - _, ok := forbiddenSlugs[slug] - return ok -} diff --git a/session/registry.go b/session/registry.go index cc1e955..8ff70fa 100644 --- a/session/registry.go +++ b/session/registry.go @@ -1,10 +1,13 @@ package session -import "sync" +import ( + "fmt" + "sync" +) type Registry interface { - Get(slug string) (session *SSHSession, exist bool) - Update(oldSlug, newSlug string) (success bool) + Get(slug string) (session *SSHSession, err error) + Update(oldSlug, newSlug string) error Register(slug string, session *SSHSession) (success bool) Remove(slug string) } @@ -19,31 +22,40 @@ func NewRegistry() Registry { } } -func (r *registry) Get(slug string) (session *SSHSession, exist bool) { +func (r *registry) Get(slug string) (session *SSHSession, err error) { r.mu.RLock() defer r.mu.RUnlock() - session, exist = r.clients[slug] - return + client, ok := r.clients[slug] + if !ok { + return nil, fmt.Errorf("session not found") + } + return client, nil } -func (r *registry) Update(oldSlug, newSlug string) (success bool) { +func (r *registry) Update(oldSlug, newSlug string) error { + if isForbiddenSlug(newSlug) { + return fmt.Errorf("this subdomain is reserved. Please choose a different one") + } else if !isValidSlug(newSlug) { + return fmt.Errorf("invalid subdomain. Follow the rules") + } + r.mu.Lock() defer r.mu.Unlock() if _, exists := r.clients[newSlug]; exists && newSlug != oldSlug { - return false + return fmt.Errorf("someone already uses this subdomain") } client, ok := r.clients[oldSlug] if !ok { - return false + return fmt.Errorf("session not found") } delete(r.clients, oldSlug) client.slugManager.Set(newSlug) r.clients[newSlug] = client - return true + return nil } func (r *registry) Register(slug string, session *SSHSession) (success bool) { @@ -64,3 +76,164 @@ func (r *registry) Remove(slug string) { delete(r.clients, slug) } + +func isValidSlug(slug string) bool { + if len(slug) < minSlugLength || len(slug) > maxSlugLength { + return false + } + + if slug[0] == '-' || slug[len(slug)-1] == '-' { + return false + } + + for _, c := range slug { + if !isValidSlugChar(byte(c)) { + return false + } + } + + return true +} + +func isValidSlugChar(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' +} + +func isForbiddenSlug(slug string) bool { + _, ok := forbiddenSlugs[slug] + return ok +} + +var forbiddenSlugs = map[string]struct{}{ + "ping": {}, + "staging": {}, + "admin": {}, + "root": {}, + "api": {}, + "www": {}, + "support": {}, + "help": {}, + "status": {}, + "health": {}, + "login": {}, + "logout": {}, + "signup": {}, + "register": {}, + "settings": {}, + "config": {}, + "null": {}, + "undefined": {}, + "example": {}, + "test": {}, + "dev": {}, + "system": {}, + "administrator": {}, + "dashboard": {}, + "account": {}, + "profile": {}, + "user": {}, + "users": {}, + "auth": {}, + "oauth": {}, + "callback": {}, + "webhook": {}, + "webhooks": {}, + "static": {}, + "assets": {}, + "cdn": {}, + "mail": {}, + "email": {}, + "ftp": {}, + "ssh": {}, + "git": {}, + "svn": {}, + "blog": {}, + "news": {}, + "about": {}, + "contact": {}, + "terms": {}, + "privacy": {}, + "legal": {}, + "billing": {}, + "payment": {}, + "checkout": {}, + "cart": {}, + "shop": {}, + "store": {}, + "download": {}, + "uploads": {}, + "images": {}, + "img": {}, + "css": {}, + "js": {}, + "fonts": {}, + "public": {}, + "private": {}, + "internal": {}, + "external": {}, + "proxy": {}, + "cache": {}, + "debug": {}, + "metrics": {}, + "monitoring": {}, + "graphql": {}, + "rest": {}, + "rpc": {}, + "socket": {}, + "ws": {}, + "wss": {}, + "app": {}, + "apps": {}, + "mobile": {}, + "desktop": {}, + "embed": {}, + "widget": {}, + "docs": {}, + "documentation": {}, + "wiki": {}, + "forum": {}, + "community": {}, + "feedback": {}, + "report": {}, + "abuse": {}, + "spam": {}, + "security": {}, + "verify": {}, + "confirm": {}, + "reset": {}, + "password": {}, + "recovery": {}, + "unsubscribe": {}, + "subscribe": {}, + "notifications": {}, + "alerts": {}, + "messages": {}, + "inbox": {}, + "outbox": {}, + "sent": {}, + "draft": {}, + "trash": {}, + "archive": {}, + "search": {}, + "explore": {}, + "discover": {}, + "trending": {}, + "popular": {}, + "featured": {}, + "new": {}, + "latest": {}, + "top": {}, + "best": {}, + "hot": {}, + "random": {}, + "all": {}, + "any": {}, + "none": {}, + "true": {}, + "false": {}, +} + +var ( + minSlugLength = 3 + maxSlugLength = 20 +)