From 8fd9f8b567f6784fa3cfe6521e508a5f75bbc843 Mon Sep 17 00:00:00 2001 From: bagas Date: Sat, 3 Jan 2026 20:06:14 +0700 Subject: [PATCH] feat: implement sessions request from grpc server --- README.md | 4 + go.mod | 3 +- internal/grpc/client/client.go | 138 +++++++++++++++++++++++---------- main.go | 77 +++++++++++------- server/server.go | 14 ++-- 5 files changed, 155 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index dbcf475..474f430 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,10 @@ The following environment variables can be configured in the `.env` file: | `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No | | `PPROF_ENABLED` | Enable pprof profiling server | `false` | No | | `PPROF_PORT` | Port for pprof server | `6060` | No | +| `MODE` | Runtime mode: `standalone` (default, no gRPC/auth) or `node` (enable gRPC + auth) | `standalone` | No | +| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No | +| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No | +| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | - (required in `node`) | Yes (node mode) | **Note:** All environment variables now use UPPERCASE naming. The application includes sensible defaults for all variables, so you can run it without a `.env` file for basic functionality. diff --git a/go.mod b/go.mod index 0806c93..87897ae 100644 --- a/go.mod +++ b/go.mod @@ -8,12 +8,12 @@ require ( github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/lipgloss v1.1.0 - github.com/golang/protobuf v1.5.4 github.com/joho/godotenv v1.5.1 github.com/libdns/cloudflare v0.2.2 github.com/muesli/termenv v0.16.0 golang.org/x/crypto v0.46.0 google.golang.org/grpc v1.78.0 + google.golang.org/protobuf v1.36.11 ) require ( @@ -49,7 +49,6 @@ require ( golang.org/x/text v0.32.0 // indirect golang.org/x/tools v0.39.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect - google.golang.org/protobuf v1.36.11 // indirect ) replace git.fossy.my.id/bagas/tunnel-please-grpc => ../tunnel-please-grpc diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index ea86b42..0883a86 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -8,6 +8,8 @@ import ( "io" "log" "time" + "tunnel_pls/internal/config" + "tunnel_pls/session" proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" @@ -125,34 +127,86 @@ func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error) }, nil } -func (c *Client) SubscribeEvents(ctx context.Context, identity string) error { - subscribe, err := c.eventService.Subscribe(ctx) - if err != nil { - return err - } - err = subscribe.Send(&proto.Client{ - Type: proto.EventType_AUTHENTICATION, - Payload: &proto.Client_AuthEvent{ - AuthEvent: &proto.Authentication{ - Identity: identity, - AuthToken: "test_auth_key", - }, - }, - }) +func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string) error { + const ( + baseBackoff = time.Second + maxBackoff = 30 * time.Second + ) - if err != nil { - log.Println("Authentication failed to send to gRPC server:", err) - return err + backoff := baseBackoff + wait := func() error { + if backoff <= 0 { + return nil + } + select { + case <-time.After(backoff): + return nil + case <-ctx.Done(): + return ctx.Err() + } } - log.Println("Authentication Successfully sent to gRPC server") - err = c.processEventStream(subscribe) - if err != nil { - return err + growBackoff := func() { + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + } + + for { + subscribe, err := c.eventService.Subscribe(ctx) + if err != nil { + if !isConnectionError(err) { + return err + } + if status.Code(err) == codes.Unauthenticated { + return err + } + if err := wait(); err != nil { + return err + } + growBackoff() + continue + } + + err = subscribe.Send(&proto.Node{ + Type: proto.EventType_AUTHENTICATION, + Payload: &proto.Node_AuthEvent{ + AuthEvent: &proto.Authentication{ + Identity: identity, + AuthToken: authToken, + }, + }, + }) + + if err != nil { + log.Println("Authentication failed to send to gRPC server:", err) + if isConnectionError(err) { + if err := wait(); err != nil { + return err + } + growBackoff() + continue + } + return err + } + log.Println("Authentication Successfully sent to gRPC server") + backoff = baseBackoff + + if err = c.processEventStream(subscribe); err != nil { + if isConnectionError(err) { + log.Printf("Reconnect to controller within %v sec", backoff.Seconds()) + if err := wait(); err != nil { + return err + } + growBackoff() + continue + } + return err + } } - return nil } -func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Client, proto.Controller]) error { +func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error { for { recv, err := subscribe.Recv() if err != nil { @@ -160,6 +214,10 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Cli log.Printf("connection error receiving from gRPC server: %v", err) return err } + if status.Code(err) == codes.Unauthenticated { + log.Printf("Authentication failed: %v", err) + return err + } log.Printf("non-connection receive error from gRPC server: %v", err) continue } @@ -167,11 +225,11 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Cli case proto.EventType_SLUG_CHANGE: oldSlug := recv.GetSlugEvent().GetOld() newSlug := recv.GetSlugEvent().GetNew() - session, err := c.sessionRegistry.Get(oldSlug) + sess, err := c.sessionRegistry.Get(oldSlug) if err != nil { - errSend := subscribe.Send(&proto.Client{ + errSend := subscribe.Send(&proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, - Payload: &proto.Client_SlugEventResponse{ + Payload: &proto.Node_SlugEventResponse{ SlugEventResponse: &proto.SlugChangeEventResponse{ Success: false, Message: err.Error(), @@ -189,9 +247,9 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Cli } err = c.sessionRegistry.Update(oldSlug, newSlug) if err != nil { - errSend := subscribe.Send(&proto.Client{ + errSend := subscribe.Send(&proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, - Payload: &proto.Client_SlugEventResponse{ + Payload: &proto.Node_SlugEventResponse{ SlugEventResponse: &proto.SlugChangeEventResponse{ Success: false, Message: err.Error(), @@ -207,10 +265,10 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Cli } continue } - session.GetInteraction().Redraw() - err = subscribe.Send(&proto.Client{ + sess.GetInteraction().Redraw() + err = subscribe.Send(&proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, - Payload: &proto.Client_SlugEventResponse{ + Payload: &proto.Node_SlugEventResponse{ SlugEventResponse: &proto.SlugChangeEventResponse{ Success: true, Message: "", @@ -231,6 +289,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Cli for _, ses := range sessions { detail := ses.Detail() details = append(details, &proto.Detail{ + Node: config.Getenv("domain", "localhost"), ForwardingType: detail.ForwardingType, Slug: detail.Slug, UserId: detail.UserID, @@ -238,9 +297,9 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Cli StartedAt: timestamppb.New(detail.StartedAt), }) } - err = subscribe.Send(&proto.Client{ + err = subscribe.Send(&proto.Node{ Type: proto.EventType_GET_SESSIONS, - Payload: &proto.Client_GetSessionsEvent{ + Payload: &proto.Node_GetSessionsEvent{ GetSessionsEvent: &proto.GetSessionsResponse{ Details: details, }, @@ -264,16 +323,16 @@ func (c *Client) GetConnection() *grpc.ClientConn { return c.conn } -func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bool, err error) { +func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) { check, err := c.authorizeConnectionService.Check(ctx, &proto.CheckRequest{AuthToken: token}) if err != nil { - return false, err + return false, "UNAUTHORIZED", err + } - } if check.GetResponse() == proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED { - return false, nil + return false, "UNAUTHORIZED", nil } - return true, nil + return true, check.GetUser(), nil } func (c *Client) Close() error { @@ -289,15 +348,12 @@ func (c *Client) CheckServerHealth(ctx context.Context) error { resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{ Service: "", }) - if err != nil { return fmt.Errorf("health check failed: %w", err) } - if resp.Status != grpc_health_v1.HealthCheckResponse_SERVING { return fmt.Errorf("server not serving: %v", resp.Status) } - return nil } diff --git a/main.go b/main.go index 1daa6e0..36ec8ca 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "net/http" _ "net/http/pprof" "os" + "strings" "time" "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" @@ -29,6 +30,9 @@ func main() { log.Printf("Starting %s", version.GetVersion()) + mode := strings.ToLower(config.Getenv("MODE", "standalone")) + isNodeMode := mode == "node" + pprofEnabled := config.Getenv("PPROF_ENABLED", "false") if pprofEnabled == "true" { pprofPort := config.Getenv("PPROF_PORT", "6060") @@ -64,40 +68,55 @@ func main() { sshConfig.AddHostKey(private) sessionRegistry := session.NewRegistry() - grpcClient, err := client.New(&client.GrpcConfig{ - Address: "localhost:8080", - UseTLS: false, - InsecureSkipVerify: false, - Timeout: 10 * time.Second, - KeepAlive: true, - MaxRetries: 3, - }, sessionRegistry) - if err != nil { - return - } - defer func(grpcClient *client.Client) { - err := grpcClient.Close() - if err != nil { + var grpcClient *client.Client + var cancel context.CancelFunc = func() {} + var ctx context.Context = context.Background() + if isNodeMode { + grpcHost := config.Getenv("GRPC_ADDRESS", "localhost") + grpcPort := config.Getenv("GRPC_PORT", "8080") + grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort) + nodeToken := config.Getenv("NODE_TOKEN", "") + if nodeToken == "" { + log.Fatalf("NODE_TOKEN is required in node mode") + return } - }(grpcClient) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - err = grpcClient.CheckServerHealth(ctx) - if err != nil { - log.Fatalf("gRPC health check failed: %s", err) - return - } - cancel() - - ctx, cancel = context.WithCancel(context.Background()) - go func() { - identity := config.Getenv("DOMAIN", "localhost") - err = grpcClient.SubscribeEvents(ctx, identity) + + grpcClient, err = client.New(&client.GrpcConfig{ + Address: grpcAddr, + UseTLS: false, + InsecureSkipVerify: false, + Timeout: 10 * time.Second, + KeepAlive: true, + MaxRetries: 3, + }, sessionRegistry) if err != nil { return } - }() + defer func(grpcClient *client.Client) { + err := grpcClient.Close() + if err != nil { + + } + }(grpcClient) + + ctx, cancel = context.WithTimeout(context.Background(), time.Second*5) + err = grpcClient.CheckServerHealth(ctx) + if err != nil { + log.Fatalf("gRPC health check failed: %s", err) + return + } + cancel() + + ctx, cancel = context.WithCancel(context.Background()) + go func() { + identity := config.Getenv("DOMAIN", "localhost") + err = grpcClient.SubscribeEvents(ctx, identity, nodeToken) + if err != nil { + return + } + }() + } app, err := server.NewServer(sshConfig, sessionRegistry, grpcClient) if err != nil { diff --git a/server/server.go b/server/server.go index f377a4b..0ee111c 100644 --- a/server/server.go +++ b/server/server.go @@ -83,17 +83,13 @@ func (s *Server) handleConnection(conn net.Conn) { ctx := context.Background() log.Println("SSH connection established:", sshConn.User()) - //Fallback: kalau auth gagal userID di set UNAUTHORIZED - authorized, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User()) - - var userID string - if authorized { - userID = sshConn.User() - } else { - userID = "UNAUTHORIZED" + user := "UNAUTHORIZED" + if s.grpcClient != nil { + _, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User()) + user = u } - sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, userID) + sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user) err = sshSession.Start() if err != nil { log.Printf("SSH session ended with error: %v", err)