package client import ( "context" "errors" "fmt" "io" "testing" "time" "tunnel_pls/internal/registry" "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" "tunnel_pls/session/slug" "tunnel_pls/types" proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" "github.com/stretchr/testify/mock" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/status" ) func TestClient_ClientConn(t *testing.T) { conn := &grpc.ClientConn{} c := &client{conn: conn} if c.ClientConn() != conn { t.Errorf("ClientConn() did not return expected connection") } } func TestClient_Close(t *testing.T) { c := &client{} if err := c.Close(); err != nil { t.Errorf("Close() on nil connection returned error: %v", err) } } func TestAuthorizeConn(t *testing.T) { mockUserSvc := &mockUserServiceClient{} c := &client{authorizeConnectionService: mockUserSvc} tests := []struct { name string token string mockResp *proto.CheckResponse mockErr error wantAuth bool wantUser string wantErr bool }{ { name: "Success", token: "valid", mockResp: &proto.CheckResponse{Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED, User: "mas-fuad"}, wantAuth: true, wantUser: "mas-fuad", wantErr: false, }, { name: "Unauthorized", token: "invalid", mockResp: &proto.CheckResponse{Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED}, wantAuth: false, wantUser: "UNAUTHORIZED", wantErr: false, }, { name: "Error", token: "error", mockErr: errors.New("grpc error"), wantAuth: false, wantUser: "UNAUTHORIZED", wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockUserSvc.checkFunc = func(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) { if in.AuthToken != tt.token { t.Errorf("expected token %s, got %s", tt.token, in.AuthToken) } return tt.mockResp, tt.mockErr } auth, user, err := c.AuthorizeConn(context.Background(), tt.token) if (err != nil) != tt.wantErr { t.Errorf("AuthorizeConn() error = %v, wantErr %v", err, tt.wantErr) } if auth != tt.wantAuth { t.Errorf("AuthorizeConn() auth = %v, wantAuth %v", auth, tt.wantAuth) } if user != tt.wantUser { t.Errorf("AuthorizeConn() user = %s, wantUser %s", user, tt.wantUser) } }) } } func TestHandleSubscribeError(t *testing.T) { c := &client{} ctx := context.Background() canceledCtx, cancel := context.WithCancel(ctx) cancel() tests := []struct { name string ctx context.Context err error backoff time.Duration wantErr bool wantB time.Duration }{ { name: "ContextCanceled", ctx: canceledCtx, err: context.Canceled, backoff: time.Second, wantErr: true, }, { name: "GrpcCanceled", ctx: ctx, err: status.Error(codes.Canceled, "canceled"), backoff: time.Second, wantErr: true, }, { name: "CtxErrSet", ctx: canceledCtx, err: errors.New("other error"), backoff: time.Second, wantErr: true, }, { name: "Unauthenticated", ctx: ctx, err: status.Error(codes.Unauthenticated, "unauth"), backoff: time.Second, wantErr: true, }, { name: "ConnectionError", ctx: ctx, err: status.Error(codes.Unavailable, "unavailable"), backoff: time.Second, wantErr: false, wantB: 2 * time.Second, }, { name: "NonConnectionError", ctx: ctx, err: status.Error(codes.Internal, "internal"), backoff: time.Second, wantErr: true, }, { name: "WaitCanceled", ctx: canceledCtx, err: status.Error(codes.Unavailable, "unavailable"), backoff: time.Second, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { backoff := tt.backoff err := c.handleSubscribeError(tt.ctx, tt.err, &backoff) if (err != nil) != tt.wantErr { t.Errorf("handleSubscribeError() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr && backoff != tt.wantB { t.Errorf("handleSubscribeError() backoff = %v, want %v", backoff, tt.wantB) } }) } t.Run("WaitCanceledReal", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) backoff := 50 * time.Millisecond go func() { time.Sleep(10 * time.Millisecond) cancel() }() err := c.handleSubscribeError(ctx, status.Error(codes.Unavailable, "unavailable"), &backoff) if err == nil { t.Errorf("expected error from wait") } }) } func TestHandleStreamError(t *testing.T) { c := &client{} ctx := context.Background() canceledCtx, cancel := context.WithCancel(ctx) cancel() tests := []struct { name string ctx context.Context err error backoff time.Duration wantErr bool wantB time.Duration }{ { name: "ContextCanceled", ctx: canceledCtx, err: context.Canceled, backoff: time.Second, wantErr: true, }, { name: "GrpcCanceled", ctx: ctx, err: status.Error(codes.Canceled, "canceled"), backoff: time.Second, wantErr: true, }, { name: "CtxErrSet", ctx: canceledCtx, err: errors.New("other error"), backoff: time.Second, wantErr: true, }, { name: "ConnectionError", ctx: ctx, err: status.Error(codes.Unavailable, "unavailable"), backoff: time.Second, wantErr: false, wantB: 2 * time.Second, }, { name: "NonConnectionError", ctx: ctx, err: status.Error(codes.Internal, "internal"), backoff: time.Second, wantErr: true, }, { name: "WaitCanceled", ctx: canceledCtx, err: status.Error(codes.Unavailable, "unavailable"), backoff: time.Second, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { backoff := tt.backoff err := c.handleStreamError(tt.ctx, tt.err, &backoff) if (err != nil) != tt.wantErr { t.Errorf("handleStreamError() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr && backoff != tt.wantB { t.Errorf("handleStreamError() backoff = %v, want %v", backoff, tt.wantB) } }) } t.Run("WaitCanceledReal", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) backoff := 50 * time.Millisecond go func() { time.Sleep(10 * time.Millisecond) cancel() }() err := c.handleStreamError(ctx, status.Error(codes.Unavailable, "unavailable"), &backoff) if err == nil { t.Errorf("expected error from wait") } }) } func TestHandleAuthError(t *testing.T) { c := &client{} ctx := context.Background() tests := []struct { name string err error backoff time.Duration wantErr bool wantB time.Duration }{ { name: "ConnectionError", err: status.Error(codes.Unavailable, "unavailable"), backoff: time.Second, wantErr: false, wantB: 2 * time.Second, }, { name: "NonConnectionError", err: status.Error(codes.Internal, "internal"), backoff: time.Second, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { backoff := tt.backoff err := c.handleAuthError(ctx, tt.err, &backoff) if (err != nil) != tt.wantErr { t.Errorf("handleAuthError() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr && backoff != tt.wantB { t.Errorf("handleAuthError() backoff = %v, want %v", backoff, tt.wantB) } }) } } func TestHandleAuthError_WaitFail(t *testing.T) { c := &client{} ctx, cancel := context.WithCancel(context.Background()) cancel() backoff := time.Second err := c.handleAuthError(ctx, status.Error(codes.Unavailable, "unavailable"), &backoff) if err == nil { t.Errorf("expected error when wait fails") } } func TestProcessEventStream(t *testing.T) { mockStream := &mockSubscribeClient{} c := &client{} t.Run("UnknownEventType", func(t *testing.T) { mockStream.recvFunc = func() (*proto.Events, error) { return &proto.Events{Type: proto.EventType(999)}, nil } first := true mockStream.recvFunc = func() (*proto.Events, error) { if first { first = false return &proto.Events{Type: proto.EventType(999)}, nil } return nil, io.EOF } err := c.processEventStream(mockStream) if !errors.Is(err, io.EOF) { t.Errorf("expected EOF, got %v", err) } }) t.Run("DispatchSuccess", func(t *testing.T) { events := []proto.EventType{ proto.EventType_SLUG_CHANGE, proto.EventType_GET_SESSIONS, proto.EventType_TERMINATE_SESSION, } for _, et := range events { t.Run(et.String(), func(t *testing.T) { first := true mockStream.recvFunc = func() (*proto.Events, error) { if first { first = false payload := &proto.Events{Type: et} switch et { case proto.EventType_SLUG_CHANGE: payload.Payload = &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}} case proto.EventType_GET_SESSIONS: payload.Payload = &proto.Events_GetSessionsEvent{GetSessionsEvent: &proto.GetSessionsEvent{}} case proto.EventType_TERMINATE_SESSION: payload.Payload = &proto.Events_TerminateSessionEvent{TerminateSessionEvent: &proto.TerminateSessionEvent{}} } return payload, nil } return nil, io.EOF } mockReg := &mockRegistry{} mockReg.getAllSessionFromUserFunc = func(user string) []registry.Session { return nil } mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") } mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return nil, errors.New("fail") } c.sessionRegistry = mockReg c.config = &MockConfig{} c.config.(*MockConfig).On("Domain").Return("test.com") mockStream.sendFunc = func(n *proto.Node) error { return nil } err := c.processEventStream(mockStream) if !errors.Is(err, io.EOF) { t.Errorf("expected EOF, got %v", err) } }) } }) t.Run("HandlerError", func(t *testing.T) { first := true mockStream.recvFunc = func() (*proto.Events, error) { if first { first = false return &proto.Events{Type: proto.EventType_SLUG_CHANGE, Payload: &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}}}, nil } return nil, io.EOF } mockReg := &mockRegistry{} mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") } c.sessionRegistry = mockReg mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Unavailable, "send fail") } err := c.processEventStream(mockStream) if !errors.Is(err, status.Error(codes.Unavailable, "send fail")) { t.Errorf("expected send fail error, got %v", err) } }) } func TestSendNode(t *testing.T) { c := &client{} mockStream := &mockSubscribeClient{} t.Run("Success", func(t *testing.T) { mockStream.sendFunc = func(n *proto.Node) error { return nil } err := c.sendNode(mockStream, &proto.Node{}, "context") if err != nil { t.Errorf("sendNode error = %v", err) } }) t.Run("ConnectionError", func(t *testing.T) { mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Unavailable, "fail") } err := c.sendNode(mockStream, &proto.Node{}, "context") if err == nil { t.Errorf("expected error") } }) t.Run("OtherError", func(t *testing.T) { mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Internal, "fail") } err := c.sendNode(mockStream, &proto.Node{}, "context") if err != nil { t.Errorf("expected nil error for non-connection error (logged only)") } }) } func TestHandleSlugChange(t *testing.T) { mockReg := &mockRegistry{} mockStream := &mockSubscribeClient{} c := &client{sessionRegistry: mockReg} evt := &proto.Events{ Payload: &proto.Events_SlugEvent{ SlugEvent: &proto.SlugChangeEvent{ User: "mas-fuad", Old: "old-slug", New: "new-slug", }, }, } t.Run("Success", func(t *testing.T) { mockSess := &mockSession{} mockInter := &mockInteraction{} mockSess.interactionFunc = func() interaction.Interaction { return mockInter } mockReg.getFunc = func(key registry.Key) (registry.Session, error) { if key.Id != "old-slug" { t.Errorf("expected old-slug, got %s", key.Id) } return mockSess, nil } mockReg.updateFunc = func(user string, oldKey, newKey registry.Key) error { if user != "mas-fuad" || oldKey.Id != "old-slug" || newKey.Id != "new-slug" { t.Errorf("unexpected update args") } return nil } sent := false mockStream.sendFunc = func(n *proto.Node) error { sent = true if n.Type != proto.EventType_SLUG_CHANGE_RESPONSE { t.Errorf("expected slug change response") } resp := n.GetSlugEventResponse() if !resp.Success { t.Errorf("expected success") } return nil } err := c.handleSlugChange(mockStream, evt) if err != nil { t.Errorf("handleSlugChange error = %v", err) } if !mockInter.redrawCalled { t.Errorf("redraw was not called") } if !sent { t.Errorf("response not sent") } }) t.Run("SessionNotFound", func(t *testing.T) { mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("not found") } mockStream.sendFunc = func(n *proto.Node) error { resp := n.GetSlugEventResponse() if resp.Success || resp.Message != "not found" { t.Errorf("unexpected failure response: %v", resp) } return nil } err := c.handleSlugChange(mockStream, evt) if err != nil { t.Errorf("handleSlugChange should return nil if error is handled via response, but it currently returns whatever sendNode returns") } }) t.Run("UpdateError", func(t *testing.T) { mockSess := &mockSession{} mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return mockSess, nil } mockReg.updateFunc = func(user string, oldKey, newKey registry.Key) error { return errors.New("update fail") } mockStream.sendFunc = func(n *proto.Node) error { resp := n.GetSlugEventResponse() if resp.Success || resp.Message != "update fail" { t.Errorf("unexpected failure response: %v", resp) } return nil } err := c.handleSlugChange(mockStream, evt) if err != nil { t.Errorf("handleSlugChange error = %v", err) } }) } func TestHandleGetSessions(t *testing.T) { mockReg := &mockRegistry{} mockStream := &mockSubscribeClient{} mockCfg := &MockConfig{} mockCfg.On("Domain").Return("test.com") c := &client{sessionRegistry: mockReg, config: mockCfg} evt := &proto.Events{ Payload: &proto.Events_GetSessionsEvent{ GetSessionsEvent: &proto.GetSessionsEvent{ Identity: "mas-fuad", }, }, } t.Run("Success", func(t *testing.T) { now := time.Now() mockSess := &mockSession{} mockSess.detailFunc = func() *types.Detail { return &types.Detail{ ForwardingType: "http", Slug: "myslug", UserID: "mas-fuad", Active: true, StartedAt: now, } } mockReg.getAllSessionFromUserFunc = func(user string) []registry.Session { if user != "mas-fuad" { t.Errorf("expected mas-fuad, got %s", user) } return []registry.Session{mockSess} } sent := false mockStream.sendFunc = func(n *proto.Node) error { sent = true if n.Type != proto.EventType_GET_SESSIONS { t.Errorf("expected get sessions response type") } resp := n.GetGetSessionsEvent() if len(resp.Details) != 1 || resp.Details[0].Slug != "myslug" { t.Errorf("unexpected details: %v", resp.Details) } return nil } err := c.handleGetSessions(mockStream, evt) if err != nil { t.Errorf("handleGetSessions error = %v", err) } if !sent { t.Errorf("response not sent") } }) } func TestHandleTerminateSession(t *testing.T) { mockReg := &mockRegistry{} mockStream := &mockSubscribeClient{} c := &client{sessionRegistry: mockReg} evt := &proto.Events{ Payload: &proto.Events_TerminateSessionEvent{ TerminateSessionEvent: &proto.TerminateSessionEvent{ User: "mas-fuad", Slug: "myslug", TunnelType: proto.TunnelType_HTTP, }, }, } t.Run("Success", func(t *testing.T) { mockSess := &mockSession{} mockLife := &mockLifecycle{} mockSess.lifecycleFunc = func() lifecycle.Lifecycle { return mockLife } closed := false mockLife.closeFunc = func() error { closed = true return nil } mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { if user != "mas-fuad" || key.Id != "myslug" || key.Type != types.TunnelTypeHTTP { t.Errorf("unexpected get args") } return mockSess, nil } sent := false mockStream.sendFunc = func(n *proto.Node) error { sent = true resp := n.GetTerminateSessionEventResponse() if !resp.Success { t.Errorf("expected success") } return nil } err := c.handleTerminateSession(mockStream, evt) if err != nil { t.Errorf("handleTerminateSession error = %v", err) } if !closed { t.Errorf("close was not called") } if !sent { t.Errorf("response not sent") } }) t.Run("TunnelTypeUnknown", func(t *testing.T) { badEvt := &proto.Events{ Payload: &proto.Events_TerminateSessionEvent{ TerminateSessionEvent: &proto.TerminateSessionEvent{ TunnelType: proto.TunnelType(999), }, }, } mockStream.sendFunc = func(n *proto.Node) error { resp := n.GetTerminateSessionEventResponse() if resp.Success || resp.Message == "" { t.Errorf("expected failure response") } return nil } err := c.handleTerminateSession(mockStream, badEvt) if err != nil { t.Errorf("handleTerminateSession error = %v", err) } }) t.Run("SessionNotFound", func(t *testing.T) { mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return nil, errors.New("not found") } mockStream.sendFunc = func(n *proto.Node) error { resp := n.GetTerminateSessionEventResponse() if resp.Success || resp.Message != "not found" { t.Errorf("unexpected failure response: %v", resp) } return nil } err := c.handleTerminateSession(mockStream, evt) if err != nil { t.Errorf("handleTerminateSession error = %v", err) } }) t.Run("CloseError", func(t *testing.T) { mockSess := &mockSession{} mockLife := &mockLifecycle{} mockSess.lifecycleFunc = func() lifecycle.Lifecycle { return mockLife } mockLife.closeFunc = func() error { return errors.New("close fail") } mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return mockSess, nil } mockStream.sendFunc = func(n *proto.Node) error { resp := n.GetTerminateSessionEventResponse() if resp.Success || resp.Message != "close fail" { t.Errorf("expected failure response: %v", resp) } return nil } err := c.handleTerminateSession(mockStream, evt) if err != nil { t.Errorf("handleTerminateSession error = %v", err) } }) } func TestSubscribeAndProcess(t *testing.T) { mockEventSvc := &mockEventServiceClient{} c := &client{eventService: mockEventSvc} ctx := context.Background() backoff := time.Second t.Run("SubscribeError", func(t *testing.T) { mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { return nil, status.Error(codes.Unauthenticated, "unauth") } err := c.subscribeAndProcess(ctx, "id", "token", &backoff) if !errors.Is(err, status.Error(codes.Unauthenticated, "unauth")) { t.Errorf("expected unauth error, got %v", err) } }) t.Run("AuthSendError", func(t *testing.T) { mockStream := &mockSubscribeClient{} mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { return mockStream, nil } mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Internal, "send fail") } err := c.subscribeAndProcess(ctx, "id", "token", &backoff) if !errors.Is(err, status.Error(codes.Internal, "send fail")) { t.Errorf("expected send fail, got %v", err) } }) t.Run("StreamError", func(t *testing.T) { mockStream := &mockSubscribeClient{} mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { return mockStream, nil } mockStream.sendFunc = func(n *proto.Node) error { return nil } mockStream.recvFunc = func() (*proto.Events, error) { return nil, status.Error(codes.Internal, "stream fail") } err := c.subscribeAndProcess(ctx, "id", "token", &backoff) if !errors.Is(err, status.Error(codes.Internal, "stream fail")) { t.Errorf("expected stream fail, got %v", err) } }) } func TestSubscribeEvents(t *testing.T) { mockEventSvc := &mockEventServiceClient{} c := &client{eventService: mockEventSvc} t.Run("ReturnsOnError", func(t *testing.T) { mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { return nil, errors.New("fatal error") } err := c.SubscribeEvents(context.Background(), "id", "token") if err == nil || err.Error() != "fatal error" { t.Errorf("expected fatal error, got %v", err) } }) t.Run("RetryLoop", func(t *testing.T) { oldB := initialBackoff initialBackoff = 5 * time.Millisecond defer func() { initialBackoff = oldB }() ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() callCount := 0 mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { callCount++ return nil, status.Error(codes.Unavailable, "unavailable") } err := c.SubscribeEvents(ctx, "id", "token") if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { t.Errorf("expected timeout/canceled error, got %v", err) } if callCount <= 1 { t.Errorf("expected multiple calls due to retry, got %d", callCount) } }) } func TestCheckServerHealth(t *testing.T) { mockHealth := &mockHealthClient{} old := healthNewHealthClient healthNewHealthClient = func(cc grpc.ClientConnInterface) grpc_health_v1.HealthClient { return mockHealth } defer func() { healthNewHealthClient = old }() c := &client{} t.Run("Success", func(t *testing.T) { mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) { return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil } err := c.CheckServerHealth(context.Background()) if err != nil { t.Errorf("expected nil error, got %v", err) } }) t.Run("Error", func(t *testing.T) { mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) { return nil, errors.New("health fail") } err := c.CheckServerHealth(context.Background()) if err == nil || err.Error() != "health check failed: health fail" { t.Errorf("expected health fail error, got %v", err) } }) t.Run("NotServing", func(t *testing.T) { mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) { return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING}, nil } err := c.CheckServerHealth(context.Background()) if err == nil || err.Error() != "server not serving: NOT_SERVING" { t.Errorf("expected not serving error, got %v", err) } }) } func TestNew_Error(t *testing.T) { old := grpcNewClient grpcNewClient = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { return nil, errors.New("dial fail") } defer func() { grpcNewClient = old }() mockConfig := &MockConfig{} mockConfig.On("GRPCAddress").Return("localhost") mockConfig.On("GRPCPort").Return("1234") cli, err := New(mockConfig, &mockRegistry{}) if err == nil || err.Error() != "failed to connect to gRPC server at localhost:1234: dial fail" { t.Errorf("expected dial fail error, got %v", err) } if cli != nil { t.Errorf("expected nil client") } } func TestNew(t *testing.T) { mockConfig := &MockConfig{} mockReg := &mockRegistry{} mockConfig.On("GRPCAddress").Return("localhost") mockConfig.On("GRPCPort").Return("1234") cli, err := New(mockConfig, mockReg) if err != nil { t.Errorf("New() error = %v", err) } if cli == nil { t.Fatal("New() returned nil client") } defer cli.Close() } type MockConfig struct { mock.Mock } func (m *MockConfig) Domain() string { return m.Called().String(0) } func (m *MockConfig) SSHPort() string { return m.Called().String(0) } func (m *MockConfig) HTTPPort() string { return m.Called().String(0) } func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) } func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) } func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) } func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) } func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) } func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) } func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } func (m *MockConfig) NodeToken() string { return m.Called().String(0) } func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } type mockRegistry struct { registry.Registry getFunc func(key registry.Key) (registry.Session, error) getWithUserFunc func(user string, key registry.Key) (registry.Session, error) updateFunc func(user string, oldKey, newKey registry.Key) error getAllSessionFromUserFunc func(user string) []registry.Session } func (m *mockRegistry) Get(key registry.Key) (registry.Session, error) { return m.getFunc(key) } func (m *mockRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) { return m.getWithUserFunc(user, key) } func (m *mockRegistry) Update(user string, oldKey, newKey registry.Key) error { return m.updateFunc(user, oldKey, newKey) } func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session { return m.getAllSessionFromUserFunc(user) } type mockSession struct { registry.Session lifecycleFunc func() lifecycle.Lifecycle interactionFunc func() interaction.Interaction detailFunc func() *types.Detail slugFunc func() slug.Slug } func (m *mockSession) Lifecycle() lifecycle.Lifecycle { return m.lifecycleFunc() } func (m *mockSession) Interaction() interaction.Interaction { return m.interactionFunc() } func (m *mockSession) Detail() *types.Detail { return m.detailFunc() } func (m *mockSession) Slug() slug.Slug { return m.slugFunc() } type mockInteraction struct { interaction.Interaction redrawCalled bool } func (m *mockInteraction) Redraw() { m.redrawCalled = true } type mockLifecycle struct { lifecycle.Lifecycle closeFunc func() error } func (m *mockLifecycle) Close() error { return m.closeFunc() } type mockEventServiceClient struct { proto.EventServiceClient subscribeFunc func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) } func (m *mockEventServiceClient) Subscribe(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { return m.subscribeFunc(ctx, opts...) } type mockSubscribeClient struct { grpc.ClientStream sendFunc func(*proto.Node) error recvFunc func() (*proto.Events, error) } func (m *mockSubscribeClient) Send(n *proto.Node) error { return m.sendFunc(n) } func (m *mockSubscribeClient) Recv() (*proto.Events, error) { return m.recvFunc() } func (m *mockSubscribeClient) Context() context.Context { return context.Background() } type mockUserServiceClient struct { proto.UserServiceClient checkFunc func(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) } func (m *mockUserServiceClient) Check(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) { return m.checkFunc(ctx, in, opts...) } type mockHealthClient struct { grpc_health_v1.HealthClient checkFunc func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) } func (m *mockHealthClient) Check(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) { return m.checkFunc(ctx, in, opts...) } func TestProtoToTunnelType(t *testing.T) { c := &client{} tests := []struct { name string input proto.TunnelType want types.TunnelType wantErr bool }{ {"HTTP", proto.TunnelType_HTTP, types.TunnelTypeHTTP, false}, {"TCP", proto.TunnelType_TCP, types.TunnelTypeTCP, false}, {"Unknown", proto.TunnelType(999), types.TunnelTypeUNKNOWN, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := c.protoToTunnelType(tt.input) if (err != nil) != tt.wantErr { t.Errorf("protoToTunnelType() error = %v, wantErr %v", err, tt.wantErr) return } if got != tt.want { t.Errorf("protoToTunnelType() got = %v, want %v", got, tt.want) } }) } } func TestIsConnectionError(t *testing.T) { c := &client{} tests := []struct { name string closing bool err error want bool }{ {"NilError", false, nil, false}, {"Closing", true, io.EOF, false}, {"EOF", false, io.EOF, true}, {"Unavailable", false, status.Error(codes.Unavailable, "unavailable"), true}, {"Canceled", false, status.Error(codes.Canceled, "canceled"), true}, {"DeadlineExceeded", false, status.Error(codes.DeadlineExceeded, "deadline"), true}, {"Internal", false, status.Error(codes.Internal, "internal"), false}, {"WrappedEOF", false, errors.New("wrapped: " + io.EOF.Error()), false}, } tests[7].err = fmt.Errorf("wrapped: %w", io.EOF) tests[7].want = true for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c.closing = tt.closing if got := c.isConnectionError(tt.err); got != tt.want { t.Errorf("isConnectionError() = %v, want %v for error %v", got, tt.want, tt.err) } }) } } func TestGrowBackoff(t *testing.T) { c := &client{} tests := []struct { name string initial time.Duration want time.Duration }{ {"NormalGrow", time.Second, 2 * time.Second}, {"MaxLimit", 20 * time.Second, 30 * time.Second}, {"AlreadyAtMax", 30 * time.Second, 30 * time.Second}, {"OverMax", 40 * time.Second, 30 * time.Second}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { backoff := tt.initial c.growBackoff(&backoff) if backoff != tt.want { t.Errorf("growBackoff() = %v, want %v", backoff, tt.want) } }) } } func TestWait(t *testing.T) { c := &client{} t.Run("ZeroDuration", func(t *testing.T) { err := c.wait(context.Background(), 0) if err != nil { t.Errorf("wait() zero duration error = %v", err) } }) t.Run("ContextCanceled", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() err := c.wait(ctx, time.Minute) if !errors.Is(err, context.Canceled) { t.Errorf("wait() context canceled error = %v", err) } }) t.Run("Timeout", func(t *testing.T) { start := time.Now() err := c.wait(context.Background(), 10*time.Millisecond) if err != nil { t.Errorf("wait() timeout error = %v", err) } if time.Since(start) < 10*time.Millisecond { t.Errorf("wait() returned too early") } }) }