chore(tests): migrate to Testify for mocking and assertions
SonarQube Scan / SonarQube Trigger (push) Successful in 2m36s
SonarQube Scan / SonarQube Trigger (push) Successful in 2m36s
This commit is contained in:
+314
-333
@@ -8,14 +8,18 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
@@ -78,23 +82,15 @@ func TestAuthorizeConn(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockUserSvc.checkFunc = func(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) {
|
||||
if in.AuthToken != tt.token {
|
||||
t.Errorf("expected token %s, got %s", tt.token, in.AuthToken)
|
||||
}
|
||||
return tt.mockResp, tt.mockErr
|
||||
}
|
||||
mockUserSvc.On("Check", mock.Anything, &proto.CheckRequest{AuthToken: tt.token}, mock.Anything).Return(tt.mockResp, tt.mockErr).Once()
|
||||
|
||||
auth, user, err := c.AuthorizeConn(context.Background(), tt.token)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AuthorizeConn() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if auth != tt.wantAuth {
|
||||
t.Errorf("AuthorizeConn() auth = %v, wantAuth %v", auth, tt.wantAuth)
|
||||
}
|
||||
if user != tt.wantUser {
|
||||
t.Errorf("AuthorizeConn() user = %s, wantUser %s", user, tt.wantUser)
|
||||
}
|
||||
assert.Equal(t, tt.wantAuth, auth)
|
||||
assert.Equal(t, tt.wantUser, user)
|
||||
mockUserSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -330,25 +326,15 @@ func TestHandleAuthError_WaitFail(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProcessEventStream(t *testing.T) {
|
||||
mockStream := &mockSubscribeClient{}
|
||||
c := &client{}
|
||||
|
||||
t.Run("UnknownEventType", func(t *testing.T) {
|
||||
mockStream.recvFunc = func() (*proto.Events, error) {
|
||||
return &proto.Events{Type: proto.EventType(999)}, nil
|
||||
}
|
||||
first := true
|
||||
mockStream.recvFunc = func() (*proto.Events, error) {
|
||||
if first {
|
||||
first = false
|
||||
return &proto.Events{Type: proto.EventType(999)}, nil
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
mockStream := &mockSubscribeClient{}
|
||||
mockStream.On("Recv").Return(&proto.Events{Type: proto.EventType(999)}, nil).Once()
|
||||
mockStream.On("Recv").Return(nil, io.EOF).Once()
|
||||
|
||||
err := c.processEventStream(mockStream)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Errorf("expected EOF, got %v", err)
|
||||
}
|
||||
assert.ErrorIs(t, err, io.EOF)
|
||||
})
|
||||
|
||||
t.Run("DispatchSuccess", func(t *testing.T) {
|
||||
@@ -360,87 +346,87 @@ func TestProcessEventStream(t *testing.T) {
|
||||
|
||||
for _, et := range events {
|
||||
t.Run(et.String(), func(t *testing.T) {
|
||||
first := true
|
||||
mockStream.recvFunc = func() (*proto.Events, error) {
|
||||
if first {
|
||||
first = false
|
||||
payload := &proto.Events{Type: et}
|
||||
switch et {
|
||||
case proto.EventType_SLUG_CHANGE:
|
||||
payload.Payload = &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}}
|
||||
case proto.EventType_GET_SESSIONS:
|
||||
payload.Payload = &proto.Events_GetSessionsEvent{GetSessionsEvent: &proto.GetSessionsEvent{}}
|
||||
case proto.EventType_TERMINATE_SESSION:
|
||||
payload.Payload = &proto.Events_TerminateSessionEvent{TerminateSessionEvent: &proto.TerminateSessionEvent{}}
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
return nil, io.EOF
|
||||
mockStream := &mockSubscribeClient{}
|
||||
payload := &proto.Events{Type: et}
|
||||
switch et {
|
||||
case proto.EventType_SLUG_CHANGE:
|
||||
payload.Payload = &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}}
|
||||
case proto.EventType_GET_SESSIONS:
|
||||
payload.Payload = &proto.Events_GetSessionsEvent{GetSessionsEvent: &proto.GetSessionsEvent{}}
|
||||
case proto.EventType_TERMINATE_SESSION:
|
||||
payload.Payload = &proto.Events_TerminateSessionEvent{TerminateSessionEvent: &proto.TerminateSessionEvent{}}
|
||||
}
|
||||
|
||||
mockStream.On("Recv").Return(payload, nil).Once()
|
||||
mockStream.On("Recv").Return(nil, io.EOF).Once()
|
||||
|
||||
mockReg := &mockRegistry{}
|
||||
mockReg.getAllSessionFromUserFunc = func(user string) []registry.Session { return nil }
|
||||
mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") }
|
||||
mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return nil, errors.New("fail") }
|
||||
c.sessionRegistry = mockReg
|
||||
c.config = &MockConfig{}
|
||||
c.config.(*MockConfig).On("Domain").Return("test.com")
|
||||
mockStream.sendFunc = func(n *proto.Node) error { return nil }
|
||||
mCfg := &MockConfig{}
|
||||
c.config = mCfg
|
||||
mCfg.On("Domain").Return("test.com").Maybe()
|
||||
|
||||
switch et {
|
||||
case proto.EventType_SLUG_CHANGE:
|
||||
mockReg.On("Get", mock.Anything).Return(nil, errors.New("fail")).Once()
|
||||
case proto.EventType_GET_SESSIONS:
|
||||
mockReg.On("GetAllSessionFromUser", mock.Anything).Return(nil).Once()
|
||||
case proto.EventType_TERMINATE_SESSION:
|
||||
mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(nil, errors.New("fail")).Once()
|
||||
}
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||
|
||||
err := c.processEventStream(mockStream)
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Errorf("expected EOF, got %v", err)
|
||||
}
|
||||
assert.ErrorIs(t, err, io.EOF)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandlerError", func(t *testing.T) {
|
||||
first := true
|
||||
mockStream.recvFunc = func() (*proto.Events, error) {
|
||||
if first {
|
||||
first = false
|
||||
return &proto.Events{Type: proto.EventType_SLUG_CHANGE, Payload: &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}}}, nil
|
||||
}
|
||||
return nil, io.EOF
|
||||
}
|
||||
mockStream := &mockSubscribeClient{}
|
||||
mockStream.On("Recv").Return(&proto.Events{
|
||||
Type: proto.EventType_SLUG_CHANGE,
|
||||
Payload: &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}},
|
||||
}, nil).Once()
|
||||
|
||||
mockReg := &mockRegistry{}
|
||||
mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") }
|
||||
mockReg.On("Get", mock.Anything).Return(nil, errors.New("fail")).Once()
|
||||
c.sessionRegistry = mockReg
|
||||
mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Unavailable, "send fail") }
|
||||
|
||||
expectedErr := status.Error(codes.Unavailable, "send fail")
|
||||
mockStream.On("Send", mock.Anything).Return(expectedErr).Once()
|
||||
|
||||
err := c.processEventStream(mockStream)
|
||||
if !errors.Is(err, status.Error(codes.Unavailable, "send fail")) {
|
||||
t.Errorf("expected send fail error, got %v", err)
|
||||
}
|
||||
assert.Equal(t, expectedErr, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendNode(t *testing.T) {
|
||||
c := &client{}
|
||||
mockStream := &mockSubscribeClient{}
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
mockStream.sendFunc = func(n *proto.Node) error { return nil }
|
||||
mockStream := &mockSubscribeClient{}
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||
err := c.sendNode(mockStream, &proto.Node{}, "context")
|
||||
if err != nil {
|
||||
t.Errorf("sendNode error = %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockStream.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("ConnectionError", func(t *testing.T) {
|
||||
mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Unavailable, "fail") }
|
||||
mockStream := &mockSubscribeClient{}
|
||||
expectedErr := status.Error(codes.Unavailable, "fail")
|
||||
mockStream.On("Send", mock.Anything).Return(expectedErr).Once()
|
||||
err := c.sendNode(mockStream, &proto.Node{}, "context")
|
||||
if err == nil {
|
||||
t.Errorf("expected error")
|
||||
}
|
||||
assert.ErrorIs(t, err, expectedErr)
|
||||
mockStream.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("OtherError", func(t *testing.T) {
|
||||
mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Internal, "fail") }
|
||||
mockStream := &mockSubscribeClient{}
|
||||
mockStream.On("Send", mock.Anything).Return(status.Error(codes.Internal, "fail")).Once()
|
||||
err := c.sendNode(mockStream, &proto.Node{}, "context")
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error for non-connection error (logged only)")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockStream.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -462,80 +448,47 @@ func TestHandleSlugChange(t *testing.T) {
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
mockSess := &mockSession{}
|
||||
mockInter := &mockInteraction{}
|
||||
mockSess.interactionFunc = func() interaction.Interaction { return mockInter }
|
||||
mockSess.On("Interaction").Return(mockInter).Once()
|
||||
mockInter.On("Redraw").Return().Once()
|
||||
|
||||
mockReg.getFunc = func(key registry.Key) (registry.Session, error) {
|
||||
if key.Id != "old-slug" {
|
||||
t.Errorf("expected old-slug, got %s", key.Id)
|
||||
}
|
||||
return mockSess, nil
|
||||
}
|
||||
mockReg.updateFunc = func(user string, oldKey, newKey registry.Key) error {
|
||||
if user != "mas-fuad" || oldKey.Id != "old-slug" || newKey.Id != "new-slug" {
|
||||
t.Errorf("unexpected update args")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
mockReg.On("Get", types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}).Return(mockSess, nil).Once()
|
||||
mockReg.On("Update", "mas-fuad", types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}, types.SessionKey{Id: "new-slug", Type: types.TunnelTypeHTTP}).Return(nil).Once()
|
||||
|
||||
sent := false
|
||||
mockStream.sendFunc = func(n *proto.Node) error {
|
||||
sent = true
|
||||
if n.Type != proto.EventType_SLUG_CHANGE_RESPONSE {
|
||||
t.Errorf("expected slug change response")
|
||||
}
|
||||
resp := n.GetSlugEventResponse()
|
||||
if !resp.Success {
|
||||
t.Errorf("expected success")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||
return n.Type == proto.EventType_SLUG_CHANGE_RESPONSE && n.GetSlugEventResponse().Success
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := c.handleSlugChange(mockStream, evt)
|
||||
if err != nil {
|
||||
t.Errorf("handleSlugChange error = %v", err)
|
||||
}
|
||||
if !mockInter.redrawCalled {
|
||||
t.Errorf("redraw was not called")
|
||||
}
|
||||
if !sent {
|
||||
t.Errorf("response not sent")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockReg.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
mockInter.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("SessionNotFound", func(t *testing.T) {
|
||||
mockReg.getFunc = func(key registry.Key) (registry.Session, error) {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
mockStream.sendFunc = func(n *proto.Node) error {
|
||||
resp := n.GetSlugEventResponse()
|
||||
if resp.Success || resp.Message != "not found" {
|
||||
t.Errorf("unexpected failure response: %v", resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
mockReg.On("Get", mock.Anything).Return(nil, errors.New("not found")).Once()
|
||||
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||
return !n.GetSlugEventResponse().Success && n.GetSlugEventResponse().Message == "not found"
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := c.handleSlugChange(mockStream, evt)
|
||||
if err != nil {
|
||||
t.Errorf("handleSlugChange should return nil if error is handled via response, but it currently returns whatever sendNode returns")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockReg.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("UpdateError", func(t *testing.T) {
|
||||
mockSess := &mockSession{}
|
||||
mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return mockSess, nil }
|
||||
mockReg.updateFunc = func(user string, oldKey, newKey registry.Key) error {
|
||||
return errors.New("update fail")
|
||||
}
|
||||
mockStream.sendFunc = func(n *proto.Node) error {
|
||||
resp := n.GetSlugEventResponse()
|
||||
if resp.Success || resp.Message != "update fail" {
|
||||
t.Errorf("unexpected failure response: %v", resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
mockReg.On("Get", mock.Anything).Return(mockSess, nil).Once()
|
||||
mockReg.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("update fail")).Once()
|
||||
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||
return !n.GetSlugEventResponse().Success && n.GetSlugEventResponse().Message == "update fail"
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := c.handleSlugChange(mockStream, evt)
|
||||
if err != nil {
|
||||
t.Errorf("handleSlugChange error = %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockReg.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -543,7 +496,6 @@ func TestHandleGetSessions(t *testing.T) {
|
||||
mockReg := &mockRegistry{}
|
||||
mockStream := &mockSubscribeClient{}
|
||||
mockCfg := &MockConfig{}
|
||||
mockCfg.On("Domain").Return("test.com")
|
||||
c := &client{sessionRegistry: mockReg, config: mockCfg}
|
||||
|
||||
evt := &proto.Events{
|
||||
@@ -557,43 +509,30 @@ func TestHandleGetSessions(t *testing.T) {
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
now := time.Now()
|
||||
mockSess := &mockSession{}
|
||||
mockSess.detailFunc = func() *types.Detail {
|
||||
return &types.Detail{
|
||||
ForwardingType: "http",
|
||||
Slug: "myslug",
|
||||
UserID: "mas-fuad",
|
||||
Active: true,
|
||||
StartedAt: now,
|
||||
}
|
||||
}
|
||||
mockSess.On("Detail").Return(&types.Detail{
|
||||
ForwardingType: "http",
|
||||
Slug: "myslug",
|
||||
UserID: "mas-fuad",
|
||||
Active: true,
|
||||
StartedAt: now,
|
||||
}).Once()
|
||||
|
||||
mockReg.getAllSessionFromUserFunc = func(user string) []registry.Session {
|
||||
if user != "mas-fuad" {
|
||||
t.Errorf("expected mas-fuad, got %s", user)
|
||||
}
|
||||
return []registry.Session{mockSess}
|
||||
}
|
||||
mockReg.On("GetAllSessionFromUser", "mas-fuad").Return([]registry.Session{mockSess}).Once()
|
||||
mockCfg.On("Domain").Return("test.com").Once()
|
||||
|
||||
sent := false
|
||||
mockStream.sendFunc = func(n *proto.Node) error {
|
||||
sent = true
|
||||
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||
if n.Type != proto.EventType_GET_SESSIONS {
|
||||
t.Errorf("expected get sessions response type")
|
||||
return false
|
||||
}
|
||||
resp := n.GetGetSessionsEvent()
|
||||
if len(resp.Details) != 1 || resp.Details[0].Slug != "myslug" {
|
||||
t.Errorf("unexpected details: %v", resp.Details)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
details := n.GetGetSessionsEvent().Details
|
||||
return len(details) == 1 && details[0].Slug == "myslug"
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := c.handleGetSessions(mockStream, evt)
|
||||
if err != nil {
|
||||
t.Errorf("handleGetSessions error = %v", err)
|
||||
}
|
||||
if !sent {
|
||||
t.Errorf("response not sent")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockReg.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
mockCfg.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -615,41 +554,20 @@ func TestHandleTerminateSession(t *testing.T) {
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
mockSess := &mockSession{}
|
||||
mockLife := &mockLifecycle{}
|
||||
mockSess.lifecycleFunc = func() lifecycle.Lifecycle { return mockLife }
|
||||
mockSess.On("Lifecycle").Return(mockLife).Once()
|
||||
mockLife.On("Close").Return(nil).Once()
|
||||
|
||||
closed := false
|
||||
mockLife.closeFunc = func() error {
|
||||
closed = true
|
||||
return nil
|
||||
}
|
||||
mockReg.On("GetWithUser", "mas-fuad", types.SessionKey{Id: "myslug", Type: types.TunnelTypeHTTP}).Return(mockSess, nil).Once()
|
||||
|
||||
mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) {
|
||||
if user != "mas-fuad" || key.Id != "myslug" || key.Type != types.TunnelTypeHTTP {
|
||||
t.Errorf("unexpected get args")
|
||||
}
|
||||
return mockSess, nil
|
||||
}
|
||||
|
||||
sent := false
|
||||
mockStream.sendFunc = func(n *proto.Node) error {
|
||||
sent = true
|
||||
resp := n.GetTerminateSessionEventResponse()
|
||||
if !resp.Success {
|
||||
t.Errorf("expected success")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||
return n.GetTerminateSessionEventResponse().Success
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := c.handleTerminateSession(mockStream, evt)
|
||||
if err != nil {
|
||||
t.Errorf("handleTerminateSession error = %v", err)
|
||||
}
|
||||
if !closed {
|
||||
t.Errorf("close was not called")
|
||||
}
|
||||
if !sent {
|
||||
t.Errorf("response not sent")
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockReg.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
mockLife.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("TunnelTypeUnknown", func(t *testing.T) {
|
||||
@@ -660,54 +578,46 @@ func TestHandleTerminateSession(t *testing.T) {
|
||||
},
|
||||
},
|
||||
}
|
||||
mockStream.sendFunc = func(n *proto.Node) error {
|
||||
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||
resp := n.GetTerminateSessionEventResponse()
|
||||
if resp.Success || resp.Message == "" {
|
||||
t.Errorf("expected failure response")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return !resp.Success && resp.Message != ""
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := c.handleTerminateSession(mockStream, badEvt)
|
||||
if err != nil {
|
||||
t.Errorf("handleTerminateSession error = %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockStream.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("SessionNotFound", func(t *testing.T) {
|
||||
mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
mockStream.sendFunc = func(n *proto.Node) error {
|
||||
mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(nil, errors.New("not found")).Once()
|
||||
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||
resp := n.GetTerminateSessionEventResponse()
|
||||
if resp.Success || resp.Message != "not found" {
|
||||
t.Errorf("unexpected failure response: %v", resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return !resp.Success && resp.Message == "not found"
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := c.handleTerminateSession(mockStream, evt)
|
||||
if err != nil {
|
||||
t.Errorf("handleTerminateSession error = %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockReg.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("CloseError", func(t *testing.T) {
|
||||
mockSess := &mockSession{}
|
||||
mockLife := &mockLifecycle{}
|
||||
mockSess.lifecycleFunc = func() lifecycle.Lifecycle { return mockLife }
|
||||
mockLife.closeFunc = func() error { return errors.New("close fail") }
|
||||
mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return mockSess, nil }
|
||||
mockSess.On("Lifecycle").Return(mockLife).Once()
|
||||
mockLife.On("Close").Return(errors.New("close fail")).Once()
|
||||
mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(mockSess, nil).Once()
|
||||
|
||||
mockStream.sendFunc = func(n *proto.Node) error {
|
||||
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||
resp := n.GetTerminateSessionEventResponse()
|
||||
if resp.Success || resp.Message != "close fail" {
|
||||
t.Errorf("expected failure response: %v", resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return !resp.Success && resp.Message == "close fail"
|
||||
})).Return(nil).Once()
|
||||
|
||||
err := c.handleTerminateSession(mockStream, evt)
|
||||
if err != nil {
|
||||
t.Errorf("handleTerminateSession error = %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockReg.AssertExpectations(t)
|
||||
mockStream.AssertExpectations(t)
|
||||
mockLife.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -718,42 +628,29 @@ func TestSubscribeAndProcess(t *testing.T) {
|
||||
backoff := time.Second
|
||||
|
||||
t.Run("SubscribeError", func(t *testing.T) {
|
||||
mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
||||
return nil, status.Error(codes.Unauthenticated, "unauth")
|
||||
}
|
||||
expectedErr := status.Error(codes.Unauthenticated, "unauth")
|
||||
mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, expectedErr).Once()
|
||||
err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
|
||||
if !errors.Is(err, status.Error(codes.Unauthenticated, "unauth")) {
|
||||
t.Errorf("expected unauth error, got %v", err)
|
||||
}
|
||||
assert.ErrorIs(t, err, expectedErr)
|
||||
})
|
||||
|
||||
t.Run("AuthSendError", func(t *testing.T) {
|
||||
mockStream := &mockSubscribeClient{}
|
||||
mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
||||
return mockStream, nil
|
||||
}
|
||||
mockStream.sendFunc = func(n *proto.Node) error {
|
||||
return status.Error(codes.Internal, "send fail")
|
||||
}
|
||||
mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(mockStream, nil).Once()
|
||||
expectedErr := status.Error(codes.Internal, "send fail")
|
||||
mockStream.On("Send", mock.Anything).Return(expectedErr).Once()
|
||||
err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
|
||||
if !errors.Is(err, status.Error(codes.Internal, "send fail")) {
|
||||
t.Errorf("expected send fail, got %v", err)
|
||||
}
|
||||
assert.ErrorIs(t, err, expectedErr)
|
||||
})
|
||||
|
||||
t.Run("StreamError", func(t *testing.T) {
|
||||
mockStream := &mockSubscribeClient{}
|
||||
mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
||||
return mockStream, nil
|
||||
}
|
||||
mockStream.sendFunc = func(n *proto.Node) error { return nil }
|
||||
mockStream.recvFunc = func() (*proto.Events, error) {
|
||||
return nil, status.Error(codes.Internal, "stream fail")
|
||||
}
|
||||
mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(mockStream, nil).Once()
|
||||
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||
expectedErr := status.Error(codes.Internal, "stream fail")
|
||||
mockStream.On("Recv").Return(nil, expectedErr).Once()
|
||||
err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
|
||||
if !errors.Is(err, status.Error(codes.Internal, "stream fail")) {
|
||||
t.Errorf("expected stream fail, got %v", err)
|
||||
}
|
||||
assert.ErrorIs(t, err, expectedErr)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -762,13 +659,10 @@ func TestSubscribeEvents(t *testing.T) {
|
||||
c := &client{eventService: mockEventSvc}
|
||||
|
||||
t.Run("ReturnsOnError", func(t *testing.T) {
|
||||
mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
||||
return nil, errors.New("fatal error")
|
||||
}
|
||||
expectedErr := errors.New("fatal error")
|
||||
mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, expectedErr).Once()
|
||||
err := c.SubscribeEvents(context.Background(), "id", "token")
|
||||
if err == nil || err.Error() != "fatal error" {
|
||||
t.Errorf("expected fatal error, got %v", err)
|
||||
}
|
||||
assert.ErrorIs(t, err, expectedErr)
|
||||
})
|
||||
|
||||
t.Run("RetryLoop", func(t *testing.T) {
|
||||
@@ -779,19 +673,11 @@ func TestSubscribeEvents(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
callCount := 0
|
||||
mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
||||
callCount++
|
||||
return nil, status.Error(codes.Unavailable, "unavailable")
|
||||
}
|
||||
mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, status.Error(codes.Unavailable, "unavailable"))
|
||||
|
||||
err := c.SubscribeEvents(ctx, "id", "token")
|
||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("expected timeout/canceled error, got %v", err)
|
||||
}
|
||||
if callCount <= 1 {
|
||||
t.Errorf("expected multiple calls due to retry, got %d", callCount)
|
||||
}
|
||||
assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled))
|
||||
mockEventSvc.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -806,33 +692,24 @@ func TestCheckServerHealth(t *testing.T) {
|
||||
c := &client{}
|
||||
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) {
|
||||
return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil
|
||||
}
|
||||
mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(&grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil).Once()
|
||||
err := c.CheckServerHealth(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("expected nil error, got %v", err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
mockHealth.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("Error", func(t *testing.T) {
|
||||
mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) {
|
||||
return nil, errors.New("health fail")
|
||||
}
|
||||
mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("health fail")).Once()
|
||||
err := c.CheckServerHealth(context.Background())
|
||||
if err == nil || err.Error() != "health check failed: health fail" {
|
||||
t.Errorf("expected health fail error, got %v", err)
|
||||
}
|
||||
assert.ErrorContains(t, err, "health check failed: health fail")
|
||||
mockHealth.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("NotServing", func(t *testing.T) {
|
||||
mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) {
|
||||
return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING}, nil
|
||||
}
|
||||
mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(&grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING}, nil).Once()
|
||||
err := c.CheckServerHealth(context.Background())
|
||||
if err == nil || err.Error() != "server not serving: NOT_SERVING" {
|
||||
t.Errorf("expected not serving error, got %v", err)
|
||||
}
|
||||
assert.ErrorContains(t, err, "server not serving: NOT_SERVING")
|
||||
mockHealth.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -897,88 +774,192 @@ func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
||||
|
||||
type mockRegistry struct {
|
||||
registry.Registry
|
||||
getFunc func(key registry.Key) (registry.Session, error)
|
||||
getWithUserFunc func(user string, key registry.Key) (registry.Session, error)
|
||||
updateFunc func(user string, oldKey, newKey registry.Key) error
|
||||
getAllSessionFromUserFunc func(user string) []registry.Session
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||
return m.getFunc(key)
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
func (m *mockRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||
return m.getWithUserFunc(user, key)
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
func (m *mockRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||
return m.updateFunc(user, oldKey, newKey)
|
||||
return m.Called(user, oldKey, newKey).Error(0)
|
||||
}
|
||||
func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
return m.getAllSessionFromUserFunc(user)
|
||||
args := m.Called(user)
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).([]registry.Session)
|
||||
}
|
||||
func (m *mockRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||
return m.Called(key, session).Bool(0)
|
||||
}
|
||||
func (m *mockRegistry) Remove(key registry.Key) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
type mockSession struct {
|
||||
registry.Session
|
||||
lifecycleFunc func() lifecycle.Lifecycle
|
||||
interactionFunc func() interaction.Interaction
|
||||
detailFunc func() *types.Detail
|
||||
slugFunc func() slug.Slug
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockSession) Lifecycle() lifecycle.Lifecycle { return m.lifecycleFunc() }
|
||||
func (m *mockSession) Interaction() interaction.Interaction { return m.interactionFunc() }
|
||||
func (m *mockSession) Detail() *types.Detail { return m.detailFunc() }
|
||||
func (m *mockSession) Slug() slug.Slug { return m.slugFunc() }
|
||||
func (m *mockSession) Lifecycle() lifecycle.Lifecycle {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(lifecycle.Lifecycle)
|
||||
}
|
||||
func (m *mockSession) Interaction() interaction.Interaction {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(interaction.Interaction)
|
||||
}
|
||||
func (m *mockSession) Detail() *types.Detail {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(*types.Detail)
|
||||
}
|
||||
func (m *mockSession) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
func (m *mockSession) Forwarder() forwarder.Forwarder {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(forwarder.Forwarder)
|
||||
}
|
||||
|
||||
type mockInteraction struct {
|
||||
interaction.Interaction
|
||||
redrawCalled bool
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockInteraction) Redraw() { m.redrawCalled = true }
|
||||
func (m *mockInteraction) Start() { m.Called() }
|
||||
func (m *mockInteraction) Stop() { m.Called() }
|
||||
func (m *mockInteraction) Redraw() { m.Called() }
|
||||
func (m *mockInteraction) SetWH(w, h int) { m.Called(w, h) }
|
||||
func (m *mockInteraction) SetChannel(channel ssh.Channel) { m.Called(channel) }
|
||||
func (m *mockInteraction) SetMode(mode types.InteractiveMode) { m.Called(mode) }
|
||||
func (m *mockInteraction) Mode() types.InteractiveMode {
|
||||
return m.Called().Get(0).(types.InteractiveMode)
|
||||
}
|
||||
func (m *mockInteraction) Send(message string) error { return m.Called(message).Error(0) }
|
||||
|
||||
type mockLifecycle struct {
|
||||
lifecycle.Lifecycle
|
||||
closeFunc func() error
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockLifecycle) Close() error { return m.closeFunc() }
|
||||
func (m *mockLifecycle) Close() error { return m.Called().Error(0) }
|
||||
func (m *mockLifecycle) Channel() ssh.Channel {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(ssh.Channel)
|
||||
}
|
||||
func (m *mockLifecycle) Connection() ssh.Conn {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(ssh.Conn)
|
||||
}
|
||||
func (m *mockLifecycle) User() string { return m.Called().String(0) }
|
||||
func (m *mockLifecycle) SetChannel(channel ssh.Channel) { m.Called(channel) }
|
||||
func (m *mockLifecycle) SetStatus(status types.SessionStatus) { m.Called(status) }
|
||||
func (m *mockLifecycle) IsActive() bool { return m.Called().Bool(0) }
|
||||
func (m *mockLifecycle) StartedAt() time.Time { return m.Called().Get(0).(time.Time) }
|
||||
func (m *mockLifecycle) PortRegistry() port.Port {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(port.Port)
|
||||
}
|
||||
|
||||
type mockEventServiceClient struct {
|
||||
proto.EventServiceClient
|
||||
subscribeFunc func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error)
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockEventServiceClient) Subscribe(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
||||
return m.subscribeFunc(ctx, opts...)
|
||||
args := m.Called(ctx, opts)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(proto.EventService_SubscribeClient), args.Error(1)
|
||||
}
|
||||
|
||||
type mockSubscribeClient struct {
|
||||
mock.Mock
|
||||
grpc.ClientStream
|
||||
sendFunc func(*proto.Node) error
|
||||
recvFunc func() (*proto.Events, error)
|
||||
}
|
||||
|
||||
func (m *mockSubscribeClient) Send(n *proto.Node) error { return m.sendFunc(n) }
|
||||
func (m *mockSubscribeClient) Recv() (*proto.Events, error) { return m.recvFunc() }
|
||||
func (m *mockSubscribeClient) Context() context.Context { return context.Background() }
|
||||
func (m *mockSubscribeClient) Send(n *proto.Node) error { return m.Called(n).Error(0) }
|
||||
func (m *mockSubscribeClient) Recv() (*proto.Events, error) {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*proto.Events), args.Error(1)
|
||||
}
|
||||
func (m *mockSubscribeClient) Context() context.Context { return m.Called().Get(0).(context.Context) }
|
||||
|
||||
type mockUserServiceClient struct {
|
||||
proto.UserServiceClient
|
||||
checkFunc func(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error)
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockUserServiceClient) Check(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) {
|
||||
return m.checkFunc(ctx, in, opts...)
|
||||
args := m.Called(ctx, in, opts)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*proto.CheckResponse), args.Error(1)
|
||||
}
|
||||
|
||||
type mockHealthClient struct {
|
||||
grpc_health_v1.HealthClient
|
||||
checkFunc func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error)
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockHealthClient) Check(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) {
|
||||
return m.checkFunc(ctx, in, opts...)
|
||||
args := m.Called(ctx, in, opts)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*grpc_health_v1.HealthCheckResponse), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockHealthClient) Watch(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (grpc_health_v1.Health_WatchClient, error) {
|
||||
args := m.Called(ctx, in, opts)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(grpc_health_v1.Health_WatchClient), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockHealthClient) List(ctx context.Context, in *grpc_health_v1.HealthListRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthListResponse, error) {
|
||||
args := m.Called(ctx, in, opts)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(*grpc_health_v1.HealthListResponse), args.Error(1)
|
||||
}
|
||||
|
||||
func TestProtoToTunnelType(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user