From 4410c9b9938c43266a4b55a5f589c0c91d0d0fc2 Mon Sep 17 00:00:00 2001 From: bagas Date: Mon, 26 Jan 2026 11:53:00 +0700 Subject: [PATCH] chore(tests): migrate to Testify for mocking and assertions --- internal/bootstrap/bootstrap_test.go | 4 +- internal/grpc/client/client_test.go | 647 +++++++++--------- internal/middleware/forwardedfor_test.go | 43 +- internal/middleware/tunnelfingerprint_test.go | 61 +- internal/port/port_test.go | 33 +- internal/random/random_test.go | 50 +- internal/registry/registry_test.go | 205 +++--- internal/transport/httphandler_test.go | 13 - session/forwarder/forwarder.go | 11 +- session/lifecycle/lifecycle_test.go | 10 +- 10 files changed, 530 insertions(+), 547 deletions(-) diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go index 6586ea3..f5b5c0c 100644 --- a/internal/bootstrap/bootstrap_test.go +++ b/internal/bootstrap/bootstrap_test.go @@ -198,8 +198,8 @@ func (m *MockGRPCClient) ClientConn() *grpc.ClientConn { } func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) { - m.Called() - return + args := m.Called(ctx, token) + return args.Bool(0), args.String(1), args.Error(2) } func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error { diff --git a/internal/grpc/client/client_test.go b/internal/grpc/client/client_test.go index f19a0b9..fb2147e 100644 --- a/internal/grpc/client/client_test.go +++ b/internal/grpc/client/client_test.go @@ -8,14 +8,18 @@ import ( "testing" "time" + "tunnel_pls/internal/port" "tunnel_pls/internal/registry" + "tunnel_pls/session/forwarder" "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" "tunnel_pls/session/slug" "tunnel_pls/types" proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "golang.org/x/crypto/ssh" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/health/grpc_health_v1" @@ -78,23 +82,15 @@ func TestAuthorizeConn(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockUserSvc.checkFunc = func(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) { - if in.AuthToken != tt.token { - t.Errorf("expected token %s, got %s", tt.token, in.AuthToken) - } - return tt.mockResp, tt.mockErr - } + mockUserSvc.On("Check", mock.Anything, &proto.CheckRequest{AuthToken: tt.token}, mock.Anything).Return(tt.mockResp, tt.mockErr).Once() auth, user, err := c.AuthorizeConn(context.Background(), tt.token) if (err != nil) != tt.wantErr { t.Errorf("AuthorizeConn() error = %v, wantErr %v", err, tt.wantErr) } - if auth != tt.wantAuth { - t.Errorf("AuthorizeConn() auth = %v, wantAuth %v", auth, tt.wantAuth) - } - if user != tt.wantUser { - t.Errorf("AuthorizeConn() user = %s, wantUser %s", user, tt.wantUser) - } + assert.Equal(t, tt.wantAuth, auth) + assert.Equal(t, tt.wantUser, user) + mockUserSvc.AssertExpectations(t) }) } } @@ -330,25 +326,15 @@ func TestHandleAuthError_WaitFail(t *testing.T) { } func TestProcessEventStream(t *testing.T) { - mockStream := &mockSubscribeClient{} c := &client{} t.Run("UnknownEventType", func(t *testing.T) { - mockStream.recvFunc = func() (*proto.Events, error) { - return &proto.Events{Type: proto.EventType(999)}, nil - } - first := true - mockStream.recvFunc = func() (*proto.Events, error) { - if first { - first = false - return &proto.Events{Type: proto.EventType(999)}, nil - } - return nil, io.EOF - } + mockStream := &mockSubscribeClient{} + mockStream.On("Recv").Return(&proto.Events{Type: proto.EventType(999)}, nil).Once() + mockStream.On("Recv").Return(nil, io.EOF).Once() + err := c.processEventStream(mockStream) - if !errors.Is(err, io.EOF) { - t.Errorf("expected EOF, got %v", err) - } + assert.ErrorIs(t, err, io.EOF) }) t.Run("DispatchSuccess", func(t *testing.T) { @@ -360,87 +346,87 @@ func TestProcessEventStream(t *testing.T) { for _, et := range events { t.Run(et.String(), func(t *testing.T) { - first := true - mockStream.recvFunc = func() (*proto.Events, error) { - if first { - first = false - payload := &proto.Events{Type: et} - switch et { - case proto.EventType_SLUG_CHANGE: - payload.Payload = &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}} - case proto.EventType_GET_SESSIONS: - payload.Payload = &proto.Events_GetSessionsEvent{GetSessionsEvent: &proto.GetSessionsEvent{}} - case proto.EventType_TERMINATE_SESSION: - payload.Payload = &proto.Events_TerminateSessionEvent{TerminateSessionEvent: &proto.TerminateSessionEvent{}} - } - return payload, nil - } - return nil, io.EOF + mockStream := &mockSubscribeClient{} + payload := &proto.Events{Type: et} + switch et { + case proto.EventType_SLUG_CHANGE: + payload.Payload = &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}} + case proto.EventType_GET_SESSIONS: + payload.Payload = &proto.Events_GetSessionsEvent{GetSessionsEvent: &proto.GetSessionsEvent{}} + case proto.EventType_TERMINATE_SESSION: + payload.Payload = &proto.Events_TerminateSessionEvent{TerminateSessionEvent: &proto.TerminateSessionEvent{}} } + + mockStream.On("Recv").Return(payload, nil).Once() + mockStream.On("Recv").Return(nil, io.EOF).Once() + mockReg := &mockRegistry{} - mockReg.getAllSessionFromUserFunc = func(user string) []registry.Session { return nil } - mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") } - mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return nil, errors.New("fail") } c.sessionRegistry = mockReg - c.config = &MockConfig{} - c.config.(*MockConfig).On("Domain").Return("test.com") - mockStream.sendFunc = func(n *proto.Node) error { return nil } + mCfg := &MockConfig{} + c.config = mCfg + mCfg.On("Domain").Return("test.com").Maybe() + + switch et { + case proto.EventType_SLUG_CHANGE: + mockReg.On("Get", mock.Anything).Return(nil, errors.New("fail")).Once() + case proto.EventType_GET_SESSIONS: + mockReg.On("GetAllSessionFromUser", mock.Anything).Return(nil).Once() + case proto.EventType_TERMINATE_SESSION: + mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(nil, errors.New("fail")).Once() + } + mockStream.On("Send", mock.Anything).Return(nil).Once() err := c.processEventStream(mockStream) - if !errors.Is(err, io.EOF) { - t.Errorf("expected EOF, got %v", err) - } + assert.ErrorIs(t, err, io.EOF) }) } }) t.Run("HandlerError", func(t *testing.T) { - first := true - mockStream.recvFunc = func() (*proto.Events, error) { - if first { - first = false - return &proto.Events{Type: proto.EventType_SLUG_CHANGE, Payload: &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}}}, nil - } - return nil, io.EOF - } + mockStream := &mockSubscribeClient{} + mockStream.On("Recv").Return(&proto.Events{ + Type: proto.EventType_SLUG_CHANGE, + Payload: &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}}, + }, nil).Once() + mockReg := &mockRegistry{} - mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") } + mockReg.On("Get", mock.Anything).Return(nil, errors.New("fail")).Once() c.sessionRegistry = mockReg - mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Unavailable, "send fail") } + + expectedErr := status.Error(codes.Unavailable, "send fail") + mockStream.On("Send", mock.Anything).Return(expectedErr).Once() err := c.processEventStream(mockStream) - if !errors.Is(err, status.Error(codes.Unavailable, "send fail")) { - t.Errorf("expected send fail error, got %v", err) - } + assert.Equal(t, expectedErr, err) }) } func TestSendNode(t *testing.T) { c := &client{} - mockStream := &mockSubscribeClient{} t.Run("Success", func(t *testing.T) { - mockStream.sendFunc = func(n *proto.Node) error { return nil } + mockStream := &mockSubscribeClient{} + mockStream.On("Send", mock.Anything).Return(nil).Once() err := c.sendNode(mockStream, &proto.Node{}, "context") - if err != nil { - t.Errorf("sendNode error = %v", err) - } + assert.NoError(t, err) + mockStream.AssertExpectations(t) }) t.Run("ConnectionError", func(t *testing.T) { - mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Unavailable, "fail") } + mockStream := &mockSubscribeClient{} + expectedErr := status.Error(codes.Unavailable, "fail") + mockStream.On("Send", mock.Anything).Return(expectedErr).Once() err := c.sendNode(mockStream, &proto.Node{}, "context") - if err == nil { - t.Errorf("expected error") - } + assert.ErrorIs(t, err, expectedErr) + mockStream.AssertExpectations(t) }) t.Run("OtherError", func(t *testing.T) { - mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Internal, "fail") } + mockStream := &mockSubscribeClient{} + mockStream.On("Send", mock.Anything).Return(status.Error(codes.Internal, "fail")).Once() err := c.sendNode(mockStream, &proto.Node{}, "context") - if err != nil { - t.Errorf("expected nil error for non-connection error (logged only)") - } + assert.NoError(t, err) + mockStream.AssertExpectations(t) }) } @@ -462,80 +448,47 @@ func TestHandleSlugChange(t *testing.T) { t.Run("Success", func(t *testing.T) { mockSess := &mockSession{} mockInter := &mockInteraction{} - mockSess.interactionFunc = func() interaction.Interaction { return mockInter } + mockSess.On("Interaction").Return(mockInter).Once() + mockInter.On("Redraw").Return().Once() - mockReg.getFunc = func(key registry.Key) (registry.Session, error) { - if key.Id != "old-slug" { - t.Errorf("expected old-slug, got %s", key.Id) - } - return mockSess, nil - } - mockReg.updateFunc = func(user string, oldKey, newKey registry.Key) error { - if user != "mas-fuad" || oldKey.Id != "old-slug" || newKey.Id != "new-slug" { - t.Errorf("unexpected update args") - } - return nil - } + mockReg.On("Get", types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}).Return(mockSess, nil).Once() + mockReg.On("Update", "mas-fuad", types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}, types.SessionKey{Id: "new-slug", Type: types.TunnelTypeHTTP}).Return(nil).Once() - sent := false - mockStream.sendFunc = func(n *proto.Node) error { - sent = true - if n.Type != proto.EventType_SLUG_CHANGE_RESPONSE { - t.Errorf("expected slug change response") - } - resp := n.GetSlugEventResponse() - if !resp.Success { - t.Errorf("expected success") - } - return nil - } + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + return n.Type == proto.EventType_SLUG_CHANGE_RESPONSE && n.GetSlugEventResponse().Success + })).Return(nil).Once() err := c.handleSlugChange(mockStream, evt) - if err != nil { - t.Errorf("handleSlugChange error = %v", err) - } - if !mockInter.redrawCalled { - t.Errorf("redraw was not called") - } - if !sent { - t.Errorf("response not sent") - } + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + mockInter.AssertExpectations(t) }) t.Run("SessionNotFound", func(t *testing.T) { - mockReg.getFunc = func(key registry.Key) (registry.Session, error) { - return nil, errors.New("not found") - } - mockStream.sendFunc = func(n *proto.Node) error { - resp := n.GetSlugEventResponse() - if resp.Success || resp.Message != "not found" { - t.Errorf("unexpected failure response: %v", resp) - } - return nil - } + mockReg.On("Get", mock.Anything).Return(nil, errors.New("not found")).Once() + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + return !n.GetSlugEventResponse().Success && n.GetSlugEventResponse().Message == "not found" + })).Return(nil).Once() + err := c.handleSlugChange(mockStream, evt) - if err != nil { - t.Errorf("handleSlugChange should return nil if error is handled via response, but it currently returns whatever sendNode returns") - } + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) }) t.Run("UpdateError", func(t *testing.T) { mockSess := &mockSession{} - mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return mockSess, nil } - mockReg.updateFunc = func(user string, oldKey, newKey registry.Key) error { - return errors.New("update fail") - } - mockStream.sendFunc = func(n *proto.Node) error { - resp := n.GetSlugEventResponse() - if resp.Success || resp.Message != "update fail" { - t.Errorf("unexpected failure response: %v", resp) - } - return nil - } + mockReg.On("Get", mock.Anything).Return(mockSess, nil).Once() + mockReg.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("update fail")).Once() + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + return !n.GetSlugEventResponse().Success && n.GetSlugEventResponse().Message == "update fail" + })).Return(nil).Once() + err := c.handleSlugChange(mockStream, evt) - if err != nil { - t.Errorf("handleSlugChange error = %v", err) - } + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) }) } @@ -543,7 +496,6 @@ func TestHandleGetSessions(t *testing.T) { mockReg := &mockRegistry{} mockStream := &mockSubscribeClient{} mockCfg := &MockConfig{} - mockCfg.On("Domain").Return("test.com") c := &client{sessionRegistry: mockReg, config: mockCfg} evt := &proto.Events{ @@ -557,43 +509,30 @@ func TestHandleGetSessions(t *testing.T) { t.Run("Success", func(t *testing.T) { now := time.Now() mockSess := &mockSession{} - mockSess.detailFunc = func() *types.Detail { - return &types.Detail{ - ForwardingType: "http", - Slug: "myslug", - UserID: "mas-fuad", - Active: true, - StartedAt: now, - } - } + mockSess.On("Detail").Return(&types.Detail{ + ForwardingType: "http", + Slug: "myslug", + UserID: "mas-fuad", + Active: true, + StartedAt: now, + }).Once() - mockReg.getAllSessionFromUserFunc = func(user string) []registry.Session { - if user != "mas-fuad" { - t.Errorf("expected mas-fuad, got %s", user) - } - return []registry.Session{mockSess} - } + mockReg.On("GetAllSessionFromUser", "mas-fuad").Return([]registry.Session{mockSess}).Once() + mockCfg.On("Domain").Return("test.com").Once() - sent := false - mockStream.sendFunc = func(n *proto.Node) error { - sent = true + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { if n.Type != proto.EventType_GET_SESSIONS { - t.Errorf("expected get sessions response type") + return false } - resp := n.GetGetSessionsEvent() - if len(resp.Details) != 1 || resp.Details[0].Slug != "myslug" { - t.Errorf("unexpected details: %v", resp.Details) - } - return nil - } + details := n.GetGetSessionsEvent().Details + return len(details) == 1 && details[0].Slug == "myslug" + })).Return(nil).Once() err := c.handleGetSessions(mockStream, evt) - if err != nil { - t.Errorf("handleGetSessions error = %v", err) - } - if !sent { - t.Errorf("response not sent") - } + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + mockCfg.AssertExpectations(t) }) } @@ -615,41 +554,20 @@ func TestHandleTerminateSession(t *testing.T) { t.Run("Success", func(t *testing.T) { mockSess := &mockSession{} mockLife := &mockLifecycle{} - mockSess.lifecycleFunc = func() lifecycle.Lifecycle { return mockLife } + mockSess.On("Lifecycle").Return(mockLife).Once() + mockLife.On("Close").Return(nil).Once() - closed := false - mockLife.closeFunc = func() error { - closed = true - return nil - } + mockReg.On("GetWithUser", "mas-fuad", types.SessionKey{Id: "myslug", Type: types.TunnelTypeHTTP}).Return(mockSess, nil).Once() - mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { - if user != "mas-fuad" || key.Id != "myslug" || key.Type != types.TunnelTypeHTTP { - t.Errorf("unexpected get args") - } - return mockSess, nil - } - - sent := false - mockStream.sendFunc = func(n *proto.Node) error { - sent = true - resp := n.GetTerminateSessionEventResponse() - if !resp.Success { - t.Errorf("expected success") - } - return nil - } + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + return n.GetTerminateSessionEventResponse().Success + })).Return(nil).Once() err := c.handleTerminateSession(mockStream, evt) - if err != nil { - t.Errorf("handleTerminateSession error = %v", err) - } - if !closed { - t.Errorf("close was not called") - } - if !sent { - t.Errorf("response not sent") - } + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + mockLife.AssertExpectations(t) }) t.Run("TunnelTypeUnknown", func(t *testing.T) { @@ -660,54 +578,46 @@ func TestHandleTerminateSession(t *testing.T) { }, }, } - mockStream.sendFunc = func(n *proto.Node) error { + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { resp := n.GetTerminateSessionEventResponse() - if resp.Success || resp.Message == "" { - t.Errorf("expected failure response") - } - return nil - } + return !resp.Success && resp.Message != "" + })).Return(nil).Once() + err := c.handleTerminateSession(mockStream, badEvt) - if err != nil { - t.Errorf("handleTerminateSession error = %v", err) - } + assert.NoError(t, err) + mockStream.AssertExpectations(t) }) t.Run("SessionNotFound", func(t *testing.T) { - mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { - return nil, errors.New("not found") - } - mockStream.sendFunc = func(n *proto.Node) error { + mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(nil, errors.New("not found")).Once() + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { resp := n.GetTerminateSessionEventResponse() - if resp.Success || resp.Message != "not found" { - t.Errorf("unexpected failure response: %v", resp) - } - return nil - } + return !resp.Success && resp.Message == "not found" + })).Return(nil).Once() + err := c.handleTerminateSession(mockStream, evt) - if err != nil { - t.Errorf("handleTerminateSession error = %v", err) - } + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) }) t.Run("CloseError", func(t *testing.T) { mockSess := &mockSession{} mockLife := &mockLifecycle{} - mockSess.lifecycleFunc = func() lifecycle.Lifecycle { return mockLife } - mockLife.closeFunc = func() error { return errors.New("close fail") } - mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return mockSess, nil } + mockSess.On("Lifecycle").Return(mockLife).Once() + mockLife.On("Close").Return(errors.New("close fail")).Once() + mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(mockSess, nil).Once() - mockStream.sendFunc = func(n *proto.Node) error { + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { resp := n.GetTerminateSessionEventResponse() - if resp.Success || resp.Message != "close fail" { - t.Errorf("expected failure response: %v", resp) - } - return nil - } + return !resp.Success && resp.Message == "close fail" + })).Return(nil).Once() + err := c.handleTerminateSession(mockStream, evt) - if err != nil { - t.Errorf("handleTerminateSession error = %v", err) - } + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + mockLife.AssertExpectations(t) }) } @@ -718,42 +628,29 @@ func TestSubscribeAndProcess(t *testing.T) { backoff := time.Second t.Run("SubscribeError", func(t *testing.T) { - mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { - return nil, status.Error(codes.Unauthenticated, "unauth") - } + expectedErr := status.Error(codes.Unauthenticated, "unauth") + mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, expectedErr).Once() err := c.subscribeAndProcess(ctx, "id", "token", &backoff) - if !errors.Is(err, status.Error(codes.Unauthenticated, "unauth")) { - t.Errorf("expected unauth error, got %v", err) - } + assert.ErrorIs(t, err, expectedErr) }) t.Run("AuthSendError", func(t *testing.T) { mockStream := &mockSubscribeClient{} - mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { - return mockStream, nil - } - mockStream.sendFunc = func(n *proto.Node) error { - return status.Error(codes.Internal, "send fail") - } + mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(mockStream, nil).Once() + expectedErr := status.Error(codes.Internal, "send fail") + mockStream.On("Send", mock.Anything).Return(expectedErr).Once() err := c.subscribeAndProcess(ctx, "id", "token", &backoff) - if !errors.Is(err, status.Error(codes.Internal, "send fail")) { - t.Errorf("expected send fail, got %v", err) - } + assert.ErrorIs(t, err, expectedErr) }) t.Run("StreamError", func(t *testing.T) { mockStream := &mockSubscribeClient{} - mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { - return mockStream, nil - } - mockStream.sendFunc = func(n *proto.Node) error { return nil } - mockStream.recvFunc = func() (*proto.Events, error) { - return nil, status.Error(codes.Internal, "stream fail") - } + mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(mockStream, nil).Once() + mockStream.On("Send", mock.Anything).Return(nil).Once() + expectedErr := status.Error(codes.Internal, "stream fail") + mockStream.On("Recv").Return(nil, expectedErr).Once() err := c.subscribeAndProcess(ctx, "id", "token", &backoff) - if !errors.Is(err, status.Error(codes.Internal, "stream fail")) { - t.Errorf("expected stream fail, got %v", err) - } + assert.ErrorIs(t, err, expectedErr) }) } @@ -762,13 +659,10 @@ func TestSubscribeEvents(t *testing.T) { c := &client{eventService: mockEventSvc} t.Run("ReturnsOnError", func(t *testing.T) { - mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { - return nil, errors.New("fatal error") - } + expectedErr := errors.New("fatal error") + mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, expectedErr).Once() err := c.SubscribeEvents(context.Background(), "id", "token") - if err == nil || err.Error() != "fatal error" { - t.Errorf("expected fatal error, got %v", err) - } + assert.ErrorIs(t, err, expectedErr) }) t.Run("RetryLoop", func(t *testing.T) { @@ -779,19 +673,11 @@ func TestSubscribeEvents(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() - callCount := 0 - mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { - callCount++ - return nil, status.Error(codes.Unavailable, "unavailable") - } + mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, status.Error(codes.Unavailable, "unavailable")) err := c.SubscribeEvents(ctx, "id", "token") - if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) { - t.Errorf("expected timeout/canceled error, got %v", err) - } - if callCount <= 1 { - t.Errorf("expected multiple calls due to retry, got %d", callCount) - } + assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)) + mockEventSvc.AssertExpectations(t) }) } @@ -806,33 +692,24 @@ func TestCheckServerHealth(t *testing.T) { c := &client{} t.Run("Success", func(t *testing.T) { - mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) { - return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil - } + mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(&grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil).Once() err := c.CheckServerHealth(context.Background()) - if err != nil { - t.Errorf("expected nil error, got %v", err) - } + assert.NoError(t, err) + mockHealth.AssertExpectations(t) }) t.Run("Error", func(t *testing.T) { - mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) { - return nil, errors.New("health fail") - } + mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("health fail")).Once() err := c.CheckServerHealth(context.Background()) - if err == nil || err.Error() != "health check failed: health fail" { - t.Errorf("expected health fail error, got %v", err) - } + assert.ErrorContains(t, err, "health check failed: health fail") + mockHealth.AssertExpectations(t) }) t.Run("NotServing", func(t *testing.T) { - mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) { - return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING}, nil - } + mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(&grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING}, nil).Once() err := c.CheckServerHealth(context.Background()) - if err == nil || err.Error() != "server not serving: NOT_SERVING" { - t.Errorf("expected not serving error, got %v", err) - } + assert.ErrorContains(t, err, "server not serving: NOT_SERVING") + mockHealth.AssertExpectations(t) }) } @@ -897,88 +774,192 @@ func (m *MockConfig) NodeToken() string { return m.Called().String(0) } func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } type mockRegistry struct { - registry.Registry - getFunc func(key registry.Key) (registry.Session, error) - getWithUserFunc func(user string, key registry.Key) (registry.Session, error) - updateFunc func(user string, oldKey, newKey registry.Key) error - getAllSessionFromUserFunc func(user string) []registry.Session + mock.Mock } func (m *mockRegistry) Get(key registry.Key) (registry.Session, error) { - return m.getFunc(key) + args := m.Called(key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) } func (m *mockRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) { - return m.getWithUserFunc(user, key) + args := m.Called(user, key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) } func (m *mockRegistry) Update(user string, oldKey, newKey registry.Key) error { - return m.updateFunc(user, oldKey, newKey) + return m.Called(user, oldKey, newKey).Error(0) } func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session { - return m.getAllSessionFromUserFunc(user) + args := m.Called(user) + if args.Get(0) == nil { + return nil + } + return args.Get(0).([]registry.Session) +} +func (m *mockRegistry) Register(key registry.Key, session registry.Session) bool { + return m.Called(key, session).Bool(0) +} +func (m *mockRegistry) Remove(key registry.Key) { + m.Called(key) } type mockSession struct { - registry.Session - lifecycleFunc func() lifecycle.Lifecycle - interactionFunc func() interaction.Interaction - detailFunc func() *types.Detail - slugFunc func() slug.Slug + mock.Mock } -func (m *mockSession) Lifecycle() lifecycle.Lifecycle { return m.lifecycleFunc() } -func (m *mockSession) Interaction() interaction.Interaction { return m.interactionFunc() } -func (m *mockSession) Detail() *types.Detail { return m.detailFunc() } -func (m *mockSession) Slug() slug.Slug { return m.slugFunc() } +func (m *mockSession) Lifecycle() lifecycle.Lifecycle { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(lifecycle.Lifecycle) +} +func (m *mockSession) Interaction() interaction.Interaction { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(interaction.Interaction) +} +func (m *mockSession) Detail() *types.Detail { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(*types.Detail) +} +func (m *mockSession) Slug() slug.Slug { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(slug.Slug) +} +func (m *mockSession) Forwarder() forwarder.Forwarder { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(forwarder.Forwarder) +} type mockInteraction struct { - interaction.Interaction - redrawCalled bool + mock.Mock } -func (m *mockInteraction) Redraw() { m.redrawCalled = true } +func (m *mockInteraction) Start() { m.Called() } +func (m *mockInteraction) Stop() { m.Called() } +func (m *mockInteraction) Redraw() { m.Called() } +func (m *mockInteraction) SetWH(w, h int) { m.Called(w, h) } +func (m *mockInteraction) SetChannel(channel ssh.Channel) { m.Called(channel) } +func (m *mockInteraction) SetMode(mode types.InteractiveMode) { m.Called(mode) } +func (m *mockInteraction) Mode() types.InteractiveMode { + return m.Called().Get(0).(types.InteractiveMode) +} +func (m *mockInteraction) Send(message string) error { return m.Called(message).Error(0) } type mockLifecycle struct { - lifecycle.Lifecycle - closeFunc func() error + mock.Mock } -func (m *mockLifecycle) Close() error { return m.closeFunc() } +func (m *mockLifecycle) Close() error { return m.Called().Error(0) } +func (m *mockLifecycle) Channel() ssh.Channel { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(ssh.Channel) +} +func (m *mockLifecycle) Connection() ssh.Conn { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(ssh.Conn) +} +func (m *mockLifecycle) User() string { return m.Called().String(0) } +func (m *mockLifecycle) SetChannel(channel ssh.Channel) { m.Called(channel) } +func (m *mockLifecycle) SetStatus(status types.SessionStatus) { m.Called(status) } +func (m *mockLifecycle) IsActive() bool { return m.Called().Bool(0) } +func (m *mockLifecycle) StartedAt() time.Time { return m.Called().Get(0).(time.Time) } +func (m *mockLifecycle) PortRegistry() port.Port { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(port.Port) +} type mockEventServiceClient struct { - proto.EventServiceClient - subscribeFunc func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) + mock.Mock } func (m *mockEventServiceClient) Subscribe(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { - return m.subscribeFunc(ctx, opts...) + args := m.Called(ctx, opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(proto.EventService_SubscribeClient), args.Error(1) } type mockSubscribeClient struct { + mock.Mock grpc.ClientStream - sendFunc func(*proto.Node) error - recvFunc func() (*proto.Events, error) } -func (m *mockSubscribeClient) Send(n *proto.Node) error { return m.sendFunc(n) } -func (m *mockSubscribeClient) Recv() (*proto.Events, error) { return m.recvFunc() } -func (m *mockSubscribeClient) Context() context.Context { return context.Background() } +func (m *mockSubscribeClient) Send(n *proto.Node) error { return m.Called(n).Error(0) } +func (m *mockSubscribeClient) Recv() (*proto.Events, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*proto.Events), args.Error(1) +} +func (m *mockSubscribeClient) Context() context.Context { return m.Called().Get(0).(context.Context) } type mockUserServiceClient struct { - proto.UserServiceClient - checkFunc func(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) + mock.Mock } func (m *mockUserServiceClient) Check(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) { - return m.checkFunc(ctx, in, opts...) + args := m.Called(ctx, in, opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*proto.CheckResponse), args.Error(1) } type mockHealthClient struct { - grpc_health_v1.HealthClient - checkFunc func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) + mock.Mock } func (m *mockHealthClient) Check(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) { - return m.checkFunc(ctx, in, opts...) + args := m.Called(ctx, in, opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*grpc_health_v1.HealthCheckResponse), args.Error(1) +} + +func (m *mockHealthClient) Watch(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (grpc_health_v1.Health_WatchClient, error) { + args := m.Called(ctx, in, opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(grpc_health_v1.Health_WatchClient), args.Error(1) +} + +func (m *mockHealthClient) List(ctx context.Context, in *grpc_health_v1.HealthListRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthListResponse, error) { + args := m.Called(ctx, in, opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*grpc_health_v1.HealthListResponse), args.Error(1) } func TestProtoToTunnelType(t *testing.T) { diff --git a/internal/middleware/forwardedfor_test.go b/internal/middleware/forwardedfor_test.go index ef6a536..5a45dc0 100644 --- a/internal/middleware/forwardedfor_test.go +++ b/internal/middleware/forwardedfor_test.go @@ -1,40 +1,42 @@ package middleware import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "net" "testing" ) type mockRequestHeader struct { - headers map[string]string + mock.Mock } func (m *mockRequestHeader) Value(key string) string { - return m.headers[key] + return m.Called(key).String(0) } func (m *mockRequestHeader) Set(key string, value string) { - m.headers[key] = value + m.Called(key, value) } func (m *mockRequestHeader) Remove(key string) { - delete(m.headers, key) + m.Called(key) } func (m *mockRequestHeader) Finalize() []byte { - return []byte{} + return m.Called().Get(0).([]byte) } func (m *mockRequestHeader) Method() string { - return "" + return m.Called().String(0) } func (m *mockRequestHeader) Path() string { - return "" + return m.Called().String(0) } func (m *mockRequestHeader) Version() string { - return "" + return m.Called().String(0) } func TestForwardedFor_HandleRequest(t *testing.T) { @@ -73,23 +75,19 @@ func TestForwardedFor_HandleRequest(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { ff := NewForwardedFor(tc.addr) - reqHeader := &mockRequestHeader{headers: make(map[string]string)} + reqHeader := new(mockRequestHeader) + + if !tc.expectError { + reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return() + } err := ff.HandleRequest(reqHeader) if tc.expectError { - if err == nil { - t.Fatalf("expected error but got none") - } + assert.Error(t, err) } else { - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - host := reqHeader.Value("X-Forwarded-For") - if host != tc.expectedHost { - t.Errorf("expected X-Forwarded-For header to be '%s', got '%s'", tc.expectedHost, host) - } + assert.NoError(t, err) + reqHeader.AssertExpectations(t) } }) } @@ -121,10 +119,7 @@ func TestNewForwardedFor(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { ff := NewForwardedFor(tc.addr) - - if ff.addr.String() != tc.expectAddr.String() { - t.Errorf("expected addr to be '%v', got '%v'", tc.expectAddr, ff.addr) - } + assert.Equal(t, tc.expectAddr.String(), ff.addr.String()) }) } } diff --git a/internal/middleware/tunnelfingerprint_test.go b/internal/middleware/tunnelfingerprint_test.go index 4753ac0..21e8b15 100644 --- a/internal/middleware/tunnelfingerprint_test.go +++ b/internal/middleware/tunnelfingerprint_test.go @@ -1,76 +1,69 @@ package middleware import ( - "errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "testing" ) type mockResponseHeader struct { - headers map[string]string + mock.Mock } func (m *mockResponseHeader) Value(key string) string { - return m.headers[key] + return m.Called(key).String(0) } func (m *mockResponseHeader) Set(key string, value string) { - m.headers[key] = value + m.Called(key, value) } func (m *mockResponseHeader) Remove(key string) { - delete(m.headers, key) + m.Called(key) } func (m *mockResponseHeader) Finalize() []byte { - return nil + return m.Called().Get(0).([]byte) } func TestTunnelFingerprintHandleResponse(t *testing.T) { tests := []struct { - name string - initialState map[string]string - expected map[string]string - body []byte - wantErr error + name string + expected map[string]string + body []byte + wantErr error }{ { - name: "Sets Server Header", - initialState: map[string]string{}, - expected: map[string]string{"Server": "Tunnel Please"}, - body: []byte("Sample body"), - wantErr: nil, + name: "Sets Server Header", + expected: map[string]string{"Server": "Tunnel Please"}, + body: []byte("Sample body"), + wantErr: nil, }, { - name: "Overwrites Server Header", - initialState: map[string]string{"Server": "Old Value"}, - expected: map[string]string{"Server": "Tunnel Please"}, - body: nil, - wantErr: nil, + name: "Overwrites Server Header", + expected: map[string]string{"Server": "Tunnel Please"}, + body: nil, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockHeader := &mockResponseHeader{headers: tt.initialState} + mockHeader := new(mockResponseHeader) + for k, v := range tt.expected { + mockHeader.On("Set", k, v).Return() + } + tunnelFingerprint := NewTunnelFingerprint() err := tunnelFingerprint.HandleResponse(mockHeader, tt.body) - if !errors.Is(err, tt.wantErr) { - t.Fatalf("unexpected error, got: %v, want: %v", err, tt.wantErr) - } - - for key, expectedValue := range tt.expected { - if val := mockHeader.Value(key); val != expectedValue { - t.Errorf("header[%q] = %q; want %q", key, val, expectedValue) - } - } + assert.ErrorIs(t, err, tt.wantErr) + mockHeader.AssertExpectations(t) }) } } func TestNewTunnelFingerprint(t *testing.T) { instance := NewTunnelFingerprint() - if instance == nil { - t.Errorf("NewTunnelFingerprint() = nil; want non-nil instance") - } + assert.NotNil(t, instance) } diff --git a/internal/port/port_test.go b/internal/port/port_test.go index b787f30..56526b3 100644 --- a/internal/port/port_test.go +++ b/internal/port/port_test.go @@ -1,6 +1,7 @@ package port import ( + "github.com/stretchr/testify/assert" "testing" ) @@ -20,8 +21,10 @@ func TestAddRange(t *testing.T) { t.Run(tt.name, func(t *testing.T) { pm := New() err := pm.AddRange(tt.startPort, tt.endPort) - if (err != nil) != tt.wantErr { - t.Errorf("AddRange() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) } }) } @@ -48,9 +51,8 @@ func TestUnassigned(t *testing.T) { _ = pm.SetStatus(k, v) } got, gotOk := pm.Unassigned() - if got != tt.want || gotOk != tt.wantOk { - t.Errorf("Unassigned() got = %v, want %v, gotOk = %v, wantOk %v", got, tt.want, gotOk, tt.wantOk) - } + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.wantOk, gotOk) }) } } @@ -70,12 +72,12 @@ func TestSetStatus(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := pm.SetStatus(tt.port, tt.assigned); err != nil { - t.Errorf("SetStatus() error = %v", err) - } - if status, _ := pm.(*port).ports[tt.port]; status != tt.assigned { - t.Errorf("SetStatus() failed, port %v has status %v, want %v", tt.port, status, tt.assigned) - } + err := pm.SetStatus(tt.port, tt.assigned) + assert.NoError(t, err) + + status, ok := pm.(*port).ports[tt.port] + assert.True(t, ok) + assert.Equal(t, tt.assigned, status) }) } } @@ -102,13 +104,10 @@ func TestClaim(t *testing.T) { } got := pm.Claim(tt.port) - if got != tt.want { - t.Errorf("Claim() got = %v, want %v", got, tt.want) - } + assert.Equal(t, tt.want, got) - if finalState := pm.(*port).ports[tt.port]; finalState != true { - t.Errorf("Claim() did not update port %v status to 'assigned'", tt.port) - } + finalState := pm.(*port).ports[tt.port] + assert.True(t, finalState) }) } } diff --git a/internal/random/random_test.go b/internal/random/random_test.go index 057487b..e0cd512 100644 --- a/internal/random/random_test.go +++ b/internal/random/random_test.go @@ -1,20 +1,12 @@ package random import ( - "errors" - "fmt" "io" "testing" + + "github.com/stretchr/testify/assert" ) -type brainrotReader struct { - err error -} - -func (f *brainrotReader) Read(p []byte) (int, error) { - return 0, f.err -} - func TestRandom_String(t *testing.T) { tests := []struct { name string @@ -32,20 +24,18 @@ func TestRandom_String(t *testing.T) { randomizer := New() result, err := randomizer.String(tt.length) - if (err != nil) != tt.wantErr { - t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if !tt.wantErr && len(result) != tt.length { - t.Errorf("String() length = %v, want %v", len(result), tt.length) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Len(t, result, tt.length) } }) } } func TestRandomWithFailingReader_String(t *testing.T) { - errBrainrot := fmt.Errorf("you are not sigma enough") + errBrainrot := assert.AnError tests := []struct { name string @@ -53,8 +43,10 @@ func TestRandomWithFailingReader_String(t *testing.T) { expectErr error }{ { - name: "failing reader", - reader: &brainrotReader{err: errBrainrot}, + name: "failing reader", + reader: func() io.Reader { + return &failingReader{err: errBrainrot} + }(), expectErr: errBrainrot, }, } @@ -63,14 +55,16 @@ func TestRandomWithFailingReader_String(t *testing.T) { t.Run(tt.name, func(t *testing.T) { randomizer := &random{reader: tt.reader} result, err := randomizer.String(20) - if !errors.Is(err, tt.expectErr) { - t.Errorf("String() error = %v, wantErr %v", err, tt.expectErr) - return - } - - if result != "" { - t.Errorf("String() result = %v, want an empty string due to error", result) - } + assert.ErrorIs(t, err, tt.expectErr) + assert.Empty(t, result) }) } } + +type failingReader struct { + err error +} + +func (f *failingReader) Read(p []byte) (int, error) { + return 0, f.err +} diff --git a/internal/registry/registry_test.go b/internal/registry/registry_test.go index ce3c14b..6e93ceb 100644 --- a/internal/registry/registry_test.go +++ b/internal/registry/registry_test.go @@ -1,7 +1,9 @@ package registry import ( - "errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "sync" "testing" "time" @@ -15,47 +17,109 @@ import ( "golang.org/x/crypto/ssh" ) -type mockSession struct{ user string } +type mockSession struct { + mock.Mock +} -func (m *mockSession) Lifecycle() lifecycle.Lifecycle { return &mockLifecycle{user: m.user} } +func (m *mockSession) Lifecycle() lifecycle.Lifecycle { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(lifecycle.Lifecycle) +} func (m *mockSession) Interaction() interaction.Interaction { - return nil + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(interaction.Interaction) } func (m *mockSession) Forwarder() forwarder.Forwarder { - return nil + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(forwarder.Forwarder) } func (m *mockSession) Slug() slug.Slug { - return &mockSlug{} + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(slug.Slug) } func (m *mockSession) Detail() *types.Detail { - return nil + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(*types.Detail) } -type mockLifecycle struct{ user string } +type mockLifecycle struct { + mock.Mock +} func (ml *mockLifecycle) Channel() ssh.Channel { - return nil + args := ml.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(ssh.Channel) } -func (ml *mockLifecycle) Connection() ssh.Conn { return nil } -func (ml *mockLifecycle) PortRegistry() port.Port { return nil } -func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { _ = channel } -func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { _ = status } -func (ml *mockLifecycle) IsActive() bool { return false } -func (ml *mockLifecycle) StartedAt() time.Time { return time.Time{} } -func (ml *mockLifecycle) Close() error { return nil } -func (ml *mockLifecycle) User() string { return ml.user } +func (ml *mockLifecycle) Connection() ssh.Conn { + args := ml.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(ssh.Conn) +} -type mockSlug struct{} +func (ml *mockLifecycle) PortRegistry() port.Port { + args := ml.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(port.Port) +} -func (ms *mockSlug) Set(slug string) { _ = slug } -func (ms *mockSlug) String() string { return "" } +func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { ml.Called(channel) } +func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { ml.Called(status) } +func (ml *mockLifecycle) IsActive() bool { return ml.Called().Bool(0) } +func (ml *mockLifecycle) StartedAt() time.Time { return ml.Called().Get(0).(time.Time) } +func (ml *mockLifecycle) Close() error { return ml.Called().Error(0) } +func (ml *mockLifecycle) User() string { return ml.Called().String(0) } + +type mockSlug struct { + mock.Mock +} + +func (ms *mockSlug) Set(slug string) { ms.Called(slug) } +func (ms *mockSlug) String() string { return ms.Called().String(0) } + +func createMockSession(user ...string) *mockSession { + u := "user1" + if len(user) > 0 { + u = user[0] + } + m := new(mockSession) + ml := new(mockLifecycle) + ml.On("User").Return(u).Maybe() + m.On("Lifecycle").Return(ml).Maybe() + ms := new(mockSlug) + ms.On("Set", mock.Anything).Maybe() + m.On("Slug").Return(ms).Maybe() + m.On("Interaction").Return(nil).Maybe() + m.On("Forwarder").Return(nil).Maybe() + m.On("Detail").Return(nil).Maybe() + return m +} func TestNewRegistry(t *testing.T) { r := NewRegistry() - if r == nil { - t.Fatal("NewRegistry returned nil") - } + require.NotNil(t, r) } func TestRegistry_Get(t *testing.T) { @@ -71,7 +135,7 @@ func TestRegistry_Get(t *testing.T) { setupFunc: func(r *registry) { user := "user1" key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} - session := &mockSession{user: user} + session := createMockSession(user) r.mu.Lock() defer r.mu.Unlock() @@ -113,13 +177,8 @@ func TestRegistry_Get(t *testing.T) { session, err := r.Get(tt.key) - if !errors.Is(err, tt.wantErr) { - t.Fatalf("expected error %v, got %v", tt.wantErr, err) - } - - if (session != nil) != tt.wantResult { - t.Fatalf("expected session existence to be %v, got %v", tt.wantResult, session != nil) - } + assert.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantResult, session != nil) }) } } @@ -138,7 +197,7 @@ func TestRegistry_GetWithUser(t *testing.T) { setupFunc: func(r *registry) { user := "user1" key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} - session := &mockSession{user: user} + session := createMockSession() r.mu.Lock() defer r.mu.Unlock() @@ -183,13 +242,8 @@ func TestRegistry_GetWithUser(t *testing.T) { session, err := r.GetWithUser(tt.user, tt.key) - if !errors.Is(err, tt.wantErr) { - t.Fatalf("expected error %v, got %v", tt.wantErr, err) - } - - if (session != nil) != tt.wantResult { - t.Fatalf("expected session existence to be %v, got %v", tt.wantResult, session != nil) - } + assert.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantResult, session != nil) }) } } @@ -207,7 +261,7 @@ func TestRegistry_Update(t *testing.T) { setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP} - session := &mockSession{user: "user1"} + session := createMockSession("user1") r.mu.Lock() defer r.mu.Unlock() @@ -226,7 +280,7 @@ func TestRegistry_Update(t *testing.T) { setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP} - session := &mockSession{user: "user1"} + session := createMockSession() r.mu.Lock() defer r.mu.Unlock() @@ -247,7 +301,7 @@ func TestRegistry_Update(t *testing.T) { setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP} - session := &mockSession{user: "user1"} + session := createMockSession() r.mu.Lock() defer r.mu.Unlock() @@ -266,7 +320,7 @@ func TestRegistry_Update(t *testing.T) { setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP} - session := &mockSession{user: "user1"} + session := createMockSession() r.mu.Lock() defer r.mu.Unlock() @@ -285,7 +339,7 @@ func TestRegistry_Update(t *testing.T) { setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP} newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP} - session := &mockSession{user: "user1"} + session := createMockSession() r.mu.Lock() defer r.mu.Unlock() @@ -304,7 +358,7 @@ func TestRegistry_Update(t *testing.T) { setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP} newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP} - session := &mockSession{user: "user1"} + session := createMockSession() r.mu.Lock() defer r.mu.Unlock() @@ -323,7 +377,7 @@ func TestRegistry_Update(t *testing.T) { setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP} - session := &mockSession{user: "user1"} + session := createMockSession() r.mu.Lock() defer r.mu.Unlock() @@ -342,7 +396,7 @@ func TestRegistry_Update(t *testing.T) { setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP} newKey := oldKey - session := &mockSession{user: "user1"} + session := createMockSession() r.mu.Lock() defer r.mu.Unlock() @@ -371,19 +425,15 @@ func TestRegistry_Update(t *testing.T) { oldKey, newKey := tt.setupFunc(r) err := r.Update(tt.user, oldKey, newKey) - if !errors.Is(err, tt.wantErr) { - t.Fatalf("expected error %v, got %v", tt.wantErr, err) - } + assert.ErrorIs(t, err, tt.wantErr) if err == nil { r.mu.RLock() defer r.mu.RUnlock() - if _, ok := r.byUser[tt.user][newKey]; !ok { - t.Errorf("newKey not found in registry") - } - if _, ok := r.byUser[tt.user][oldKey]; ok { - t.Errorf("oldKey still exists in registry") - } + _, ok := r.byUser[tt.user][newKey] + assert.True(t, ok, "newKey not found in registry") + _, ok = r.byUser[tt.user][oldKey] + assert.False(t, ok, "oldKey still exists in registry") } }) } @@ -410,7 +460,7 @@ func TestRegistry_Register(t *testing.T) { user: "user1", setupFunc: func(r *registry) Key { key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} - session := &mockSession{user: "user1"} + session := createMockSession() r.mu.Lock() r.byUser["user1"] = map[Key]Session{key: session} @@ -426,7 +476,7 @@ func TestRegistry_Register(t *testing.T) { user: "user1", setupFunc: func(r *registry) Key { firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP} - session := &mockSession{user: "user1"} + session := createMockSession() r.mu.Lock() r.byUser["user1"] = map[Key]Session{firstKey: session} r.slugIndex[firstKey] = "user1" @@ -450,22 +500,16 @@ func TestRegistry_Register(t *testing.T) { } key := tt.setupFunc(r) - session := &mockSession{user: tt.user} + session := createMockSession() ok := r.Register(key, session) - if ok != tt.wantOK { - t.Fatalf("expected success %v, got %v", tt.wantOK, ok) - } + assert.Equal(t, tt.wantOK, ok) if ok { r.mu.RLock() defer r.mu.RUnlock() - if r.byUser[tt.user][key] != session { - t.Errorf("session not stored in byUser") - } - if r.slugIndex[key] != tt.user { - t.Errorf("slugIndex not updated") - } + assert.Equal(t, session, r.byUser[tt.user][key], "session not stored in byUser") + assert.Equal(t, tt.user, r.slugIndex[key], "slugIndex not updated") } }) } @@ -492,8 +536,8 @@ func TestRegistry_GetAllSessionFromUser(t *testing.T) { key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP} r.mu.Lock() r.byUser[user] = map[Key]Session{ - key1: &mockSession{user: user}, - key2: &mockSession{user: user}, + key1: createMockSession(), + key2: createMockSession(), } r.mu.Unlock() return user @@ -511,9 +555,7 @@ func TestRegistry_GetAllSessionFromUser(t *testing.T) { } user := tt.setupFunc(r) sessions := r.GetAllSessionFromUser(user) - if len(sessions) != tt.expectN { - t.Errorf("expected %d sessions, got %d", tt.expectN, len(sessions)) - } + assert.Len(t, sessions, tt.expectN) }) } } @@ -530,7 +572,7 @@ func TestRegistry_Remove(t *testing.T) { setupFunc: func(r *registry) (string, types.SessionKey) { user := "user1" key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP} - session := &mockSession{user: user} + session := createMockSession() r.mu.Lock() r.byUser[user] = map[Key]Session{key: session} r.slugIndex[key] = user @@ -538,15 +580,12 @@ func TestRegistry_Remove(t *testing.T) { return user, key }, verify: func(t *testing.T, r *registry, user string, key types.SessionKey) { - if _, ok := r.byUser[user][key]; ok { - t.Errorf("expected key to be removed from byUser") - } - if _, ok := r.slugIndex[key]; ok { - t.Errorf("expected key to be removed from slugIndex") - } - if _, ok := r.byUser[user]; ok { - t.Errorf("expected user to be removed from byUser map") - } + _, ok := r.byUser[user][key] + assert.False(t, ok, "expected key to be removed from byUser") + _, ok = r.slugIndex[key] + assert.False(t, ok, "expected key to be removed from slugIndex") + _, ok = r.byUser[user] + assert.False(t, ok, "expected user to be removed from byUser map") }, }, { diff --git a/internal/transport/httphandler_test.go b/internal/transport/httphandler_test.go index de1dd24..b30d9d5 100644 --- a/internal/transport/httphandler_test.go +++ b/internal/transport/httphandler_test.go @@ -144,19 +144,6 @@ func (m *MockLifecycle) Close() error { return args.Error(0) } -type MockSSHConn struct { - ssh.Conn - mock.Mock -} - -func (m *MockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { - args := m.Called(name, data) - if args.Get(0) == nil { - return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) - } - return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) -} - type MockSSHChannel struct { ssh.Channel mock.Mock diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 629fffd..b894536 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -73,11 +73,7 @@ func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) ( case resultChan <- channelResult{channel, reqs, err}: case <-ctx.Done(): if channel != nil { - err = channel.Close() - if err != nil { - log.Printf("Failed to close unused channel: %v", err) - return - } + _ = channel.Close() go ssh.DiscardRequests(reqs) } } @@ -116,10 +112,7 @@ func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { defer func() { - _, err := io.Copy(io.Discard, src) - if err != nil { - log.Printf("Failed to discard connection: %v", err) - } + _, _ = io.Copy(io.Discard, src) }() var wg sync.WaitGroup diff --git a/session/lifecycle/lifecycle_test.go b/session/lifecycle/lifecycle_test.go index 73333e5..4f4335b 100644 --- a/session/lifecycle/lifecycle_test.go +++ b/session/lifecycle/lifecycle_test.go @@ -1,6 +1,7 @@ package lifecycle import ( + "context" "errors" "io" "net" @@ -65,10 +66,10 @@ func (m *MockForwarder) Listener() net.Listener { return args.Get(0).(net.Listener) } -func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { - args := m.Called(payload) +func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(ctx, origin) if args.Get(0) == nil { - return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) + return nil, nil, args.Error(2) } return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) } @@ -208,7 +209,8 @@ func TestLifecycle_SetStatus(t *testing.T) { mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") - assert.NotNil(t, mockLifecycle.StartedAt()) + mockLifecycle.SetStatus(types.SessionStatusRUNNING) + assert.True(t, mockLifecycle.IsActive()) } func TestLifecycle_IsActive(t *testing.T) {