fix(deps): update module github.com/caddyserver/certmagic to v0.25.1 - autoclosed #63
@@ -198,8 +198,8 @@ func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
||||||
m.Called()
|
args := m.Called(ctx, token)
|
||||||
return
|
return args.Bool(0), args.String(1), args.Error(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
|
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
|
||||||
|
|||||||
+301
-320
@@ -8,14 +8,18 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"tunnel_pls/internal/port"
|
||||||
"tunnel_pls/internal/registry"
|
"tunnel_pls/internal/registry"
|
||||||
|
"tunnel_pls/session/forwarder"
|
||||||
"tunnel_pls/session/interaction"
|
"tunnel_pls/session/interaction"
|
||||||
"tunnel_pls/session/lifecycle"
|
"tunnel_pls/session/lifecycle"
|
||||||
"tunnel_pls/session/slug"
|
"tunnel_pls/session/slug"
|
||||||
"tunnel_pls/types"
|
"tunnel_pls/types"
|
||||||
|
|
||||||
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
|
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/health/grpc_health_v1"
|
"google.golang.org/grpc/health/grpc_health_v1"
|
||||||
@@ -78,23 +82,15 @@ func TestAuthorizeConn(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
mockUserSvc.checkFunc = func(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) {
|
mockUserSvc.On("Check", mock.Anything, &proto.CheckRequest{AuthToken: tt.token}, mock.Anything).Return(tt.mockResp, tt.mockErr).Once()
|
||||||
if in.AuthToken != tt.token {
|
|
||||||
t.Errorf("expected token %s, got %s", tt.token, in.AuthToken)
|
|
||||||
}
|
|
||||||
return tt.mockResp, tt.mockErr
|
|
||||||
}
|
|
||||||
|
|
||||||
auth, user, err := c.AuthorizeConn(context.Background(), tt.token)
|
auth, user, err := c.AuthorizeConn(context.Background(), tt.token)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("AuthorizeConn() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("AuthorizeConn() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
if auth != tt.wantAuth {
|
assert.Equal(t, tt.wantAuth, auth)
|
||||||
t.Errorf("AuthorizeConn() auth = %v, wantAuth %v", auth, tt.wantAuth)
|
assert.Equal(t, tt.wantUser, user)
|
||||||
}
|
mockUserSvc.AssertExpectations(t)
|
||||||
if user != tt.wantUser {
|
|
||||||
t.Errorf("AuthorizeConn() user = %s, wantUser %s", user, tt.wantUser)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -330,25 +326,15 @@ func TestHandleAuthError_WaitFail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessEventStream(t *testing.T) {
|
func TestProcessEventStream(t *testing.T) {
|
||||||
mockStream := &mockSubscribeClient{}
|
|
||||||
c := &client{}
|
c := &client{}
|
||||||
|
|
||||||
t.Run("UnknownEventType", func(t *testing.T) {
|
t.Run("UnknownEventType", func(t *testing.T) {
|
||||||
mockStream.recvFunc = func() (*proto.Events, error) {
|
mockStream := &mockSubscribeClient{}
|
||||||
return &proto.Events{Type: proto.EventType(999)}, nil
|
mockStream.On("Recv").Return(&proto.Events{Type: proto.EventType(999)}, nil).Once()
|
||||||
}
|
mockStream.On("Recv").Return(nil, io.EOF).Once()
|
||||||
first := true
|
|
||||||
mockStream.recvFunc = func() (*proto.Events, error) {
|
|
||||||
if first {
|
|
||||||
first = false
|
|
||||||
return &proto.Events{Type: proto.EventType(999)}, nil
|
|
||||||
}
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
err := c.processEventStream(mockStream)
|
err := c.processEventStream(mockStream)
|
||||||
if !errors.Is(err, io.EOF) {
|
assert.ErrorIs(t, err, io.EOF)
|
||||||
t.Errorf("expected EOF, got %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("DispatchSuccess", func(t *testing.T) {
|
t.Run("DispatchSuccess", func(t *testing.T) {
|
||||||
@@ -360,10 +346,7 @@ func TestProcessEventStream(t *testing.T) {
|
|||||||
|
|
||||||
for _, et := range events {
|
for _, et := range events {
|
||||||
t.Run(et.String(), func(t *testing.T) {
|
t.Run(et.String(), func(t *testing.T) {
|
||||||
first := true
|
mockStream := &mockSubscribeClient{}
|
||||||
mockStream.recvFunc = func() (*proto.Events, error) {
|
|
||||||
if first {
|
|
||||||
first = false
|
|
||||||
payload := &proto.Events{Type: et}
|
payload := &proto.Events{Type: et}
|
||||||
switch et {
|
switch et {
|
||||||
case proto.EventType_SLUG_CHANGE:
|
case proto.EventType_SLUG_CHANGE:
|
||||||
@@ -373,74 +356,77 @@ func TestProcessEventStream(t *testing.T) {
|
|||||||
case proto.EventType_TERMINATE_SESSION:
|
case proto.EventType_TERMINATE_SESSION:
|
||||||
payload.Payload = &proto.Events_TerminateSessionEvent{TerminateSessionEvent: &proto.TerminateSessionEvent{}}
|
payload.Payload = &proto.Events_TerminateSessionEvent{TerminateSessionEvent: &proto.TerminateSessionEvent{}}
|
||||||
}
|
}
|
||||||
return payload, nil
|
|
||||||
}
|
mockStream.On("Recv").Return(payload, nil).Once()
|
||||||
return nil, io.EOF
|
mockStream.On("Recv").Return(nil, io.EOF).Once()
|
||||||
}
|
|
||||||
mockReg := &mockRegistry{}
|
mockReg := &mockRegistry{}
|
||||||
mockReg.getAllSessionFromUserFunc = func(user string) []registry.Session { return nil }
|
|
||||||
mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") }
|
|
||||||
mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return nil, errors.New("fail") }
|
|
||||||
c.sessionRegistry = mockReg
|
c.sessionRegistry = mockReg
|
||||||
c.config = &MockConfig{}
|
mCfg := &MockConfig{}
|
||||||
c.config.(*MockConfig).On("Domain").Return("test.com")
|
c.config = mCfg
|
||||||
mockStream.sendFunc = func(n *proto.Node) error { return nil }
|
mCfg.On("Domain").Return("test.com").Maybe()
|
||||||
|
|
||||||
|
switch et {
|
||||||
|
case proto.EventType_SLUG_CHANGE:
|
||||||
|
mockReg.On("Get", mock.Anything).Return(nil, errors.New("fail")).Once()
|
||||||
|
case proto.EventType_GET_SESSIONS:
|
||||||
|
mockReg.On("GetAllSessionFromUser", mock.Anything).Return(nil).Once()
|
||||||
|
case proto.EventType_TERMINATE_SESSION:
|
||||||
|
mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(nil, errors.New("fail")).Once()
|
||||||
|
}
|
||||||
|
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||||
|
|
||||||
err := c.processEventStream(mockStream)
|
err := c.processEventStream(mockStream)
|
||||||
if !errors.Is(err, io.EOF) {
|
assert.ErrorIs(t, err, io.EOF)
|
||||||
t.Errorf("expected EOF, got %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("HandlerError", func(t *testing.T) {
|
t.Run("HandlerError", func(t *testing.T) {
|
||||||
first := true
|
mockStream := &mockSubscribeClient{}
|
||||||
mockStream.recvFunc = func() (*proto.Events, error) {
|
mockStream.On("Recv").Return(&proto.Events{
|
||||||
if first {
|
Type: proto.EventType_SLUG_CHANGE,
|
||||||
first = false
|
Payload: &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}},
|
||||||
return &proto.Events{Type: proto.EventType_SLUG_CHANGE, Payload: &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}}}, nil
|
}, nil).Once()
|
||||||
}
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
mockReg := &mockRegistry{}
|
mockReg := &mockRegistry{}
|
||||||
mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") }
|
mockReg.On("Get", mock.Anything).Return(nil, errors.New("fail")).Once()
|
||||||
c.sessionRegistry = mockReg
|
c.sessionRegistry = mockReg
|
||||||
mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Unavailable, "send fail") }
|
|
||||||
|
expectedErr := status.Error(codes.Unavailable, "send fail")
|
||||||
|
mockStream.On("Send", mock.Anything).Return(expectedErr).Once()
|
||||||
|
|
||||||
err := c.processEventStream(mockStream)
|
err := c.processEventStream(mockStream)
|
||||||
if !errors.Is(err, status.Error(codes.Unavailable, "send fail")) {
|
assert.Equal(t, expectedErr, err)
|
||||||
t.Errorf("expected send fail error, got %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSendNode(t *testing.T) {
|
func TestSendNode(t *testing.T) {
|
||||||
c := &client{}
|
c := &client{}
|
||||||
mockStream := &mockSubscribeClient{}
|
|
||||||
|
|
||||||
t.Run("Success", func(t *testing.T) {
|
t.Run("Success", func(t *testing.T) {
|
||||||
mockStream.sendFunc = func(n *proto.Node) error { return nil }
|
mockStream := &mockSubscribeClient{}
|
||||||
|
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||||
err := c.sendNode(mockStream, &proto.Node{}, "context")
|
err := c.sendNode(mockStream, &proto.Node{}, "context")
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("sendNode error = %v", err)
|
mockStream.AssertExpectations(t)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ConnectionError", func(t *testing.T) {
|
t.Run("ConnectionError", func(t *testing.T) {
|
||||||
mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Unavailable, "fail") }
|
mockStream := &mockSubscribeClient{}
|
||||||
|
expectedErr := status.Error(codes.Unavailable, "fail")
|
||||||
|
mockStream.On("Send", mock.Anything).Return(expectedErr).Once()
|
||||||
err := c.sendNode(mockStream, &proto.Node{}, "context")
|
err := c.sendNode(mockStream, &proto.Node{}, "context")
|
||||||
if err == nil {
|
assert.ErrorIs(t, err, expectedErr)
|
||||||
t.Errorf("expected error")
|
mockStream.AssertExpectations(t)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("OtherError", func(t *testing.T) {
|
t.Run("OtherError", func(t *testing.T) {
|
||||||
mockStream.sendFunc = func(n *proto.Node) error { return status.Error(codes.Internal, "fail") }
|
mockStream := &mockSubscribeClient{}
|
||||||
|
mockStream.On("Send", mock.Anything).Return(status.Error(codes.Internal, "fail")).Once()
|
||||||
err := c.sendNode(mockStream, &proto.Node{}, "context")
|
err := c.sendNode(mockStream, &proto.Node{}, "context")
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("expected nil error for non-connection error (logged only)")
|
mockStream.AssertExpectations(t)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -462,80 +448,47 @@ func TestHandleSlugChange(t *testing.T) {
|
|||||||
t.Run("Success", func(t *testing.T) {
|
t.Run("Success", func(t *testing.T) {
|
||||||
mockSess := &mockSession{}
|
mockSess := &mockSession{}
|
||||||
mockInter := &mockInteraction{}
|
mockInter := &mockInteraction{}
|
||||||
mockSess.interactionFunc = func() interaction.Interaction { return mockInter }
|
mockSess.On("Interaction").Return(mockInter).Once()
|
||||||
|
mockInter.On("Redraw").Return().Once()
|
||||||
|
|
||||||
mockReg.getFunc = func(key registry.Key) (registry.Session, error) {
|
mockReg.On("Get", types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}).Return(mockSess, nil).Once()
|
||||||
if key.Id != "old-slug" {
|
mockReg.On("Update", "mas-fuad", types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}, types.SessionKey{Id: "new-slug", Type: types.TunnelTypeHTTP}).Return(nil).Once()
|
||||||
t.Errorf("expected old-slug, got %s", key.Id)
|
|
||||||
}
|
|
||||||
return mockSess, nil
|
|
||||||
}
|
|
||||||
mockReg.updateFunc = func(user string, oldKey, newKey registry.Key) error {
|
|
||||||
if user != "mas-fuad" || oldKey.Id != "old-slug" || newKey.Id != "new-slug" {
|
|
||||||
t.Errorf("unexpected update args")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sent := false
|
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||||
mockStream.sendFunc = func(n *proto.Node) error {
|
return n.Type == proto.EventType_SLUG_CHANGE_RESPONSE && n.GetSlugEventResponse().Success
|
||||||
sent = true
|
})).Return(nil).Once()
|
||||||
if n.Type != proto.EventType_SLUG_CHANGE_RESPONSE {
|
|
||||||
t.Errorf("expected slug change response")
|
|
||||||
}
|
|
||||||
resp := n.GetSlugEventResponse()
|
|
||||||
if !resp.Success {
|
|
||||||
t.Errorf("expected success")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.handleSlugChange(mockStream, evt)
|
err := c.handleSlugChange(mockStream, evt)
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("handleSlugChange error = %v", err)
|
mockReg.AssertExpectations(t)
|
||||||
}
|
mockStream.AssertExpectations(t)
|
||||||
if !mockInter.redrawCalled {
|
mockInter.AssertExpectations(t)
|
||||||
t.Errorf("redraw was not called")
|
|
||||||
}
|
|
||||||
if !sent {
|
|
||||||
t.Errorf("response not sent")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("SessionNotFound", func(t *testing.T) {
|
t.Run("SessionNotFound", func(t *testing.T) {
|
||||||
mockReg.getFunc = func(key registry.Key) (registry.Session, error) {
|
mockReg.On("Get", mock.Anything).Return(nil, errors.New("not found")).Once()
|
||||||
return nil, errors.New("not found")
|
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||||
}
|
return !n.GetSlugEventResponse().Success && n.GetSlugEventResponse().Message == "not found"
|
||||||
mockStream.sendFunc = func(n *proto.Node) error {
|
})).Return(nil).Once()
|
||||||
resp := n.GetSlugEventResponse()
|
|
||||||
if resp.Success || resp.Message != "not found" {
|
|
||||||
t.Errorf("unexpected failure response: %v", resp)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err := c.handleSlugChange(mockStream, evt)
|
err := c.handleSlugChange(mockStream, evt)
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("handleSlugChange should return nil if error is handled via response, but it currently returns whatever sendNode returns")
|
mockReg.AssertExpectations(t)
|
||||||
}
|
mockStream.AssertExpectations(t)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("UpdateError", func(t *testing.T) {
|
t.Run("UpdateError", func(t *testing.T) {
|
||||||
mockSess := &mockSession{}
|
mockSess := &mockSession{}
|
||||||
mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return mockSess, nil }
|
mockReg.On("Get", mock.Anything).Return(mockSess, nil).Once()
|
||||||
mockReg.updateFunc = func(user string, oldKey, newKey registry.Key) error {
|
mockReg.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("update fail")).Once()
|
||||||
return errors.New("update fail")
|
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||||
}
|
return !n.GetSlugEventResponse().Success && n.GetSlugEventResponse().Message == "update fail"
|
||||||
mockStream.sendFunc = func(n *proto.Node) error {
|
})).Return(nil).Once()
|
||||||
resp := n.GetSlugEventResponse()
|
|
||||||
if resp.Success || resp.Message != "update fail" {
|
|
||||||
t.Errorf("unexpected failure response: %v", resp)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err := c.handleSlugChange(mockStream, evt)
|
err := c.handleSlugChange(mockStream, evt)
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("handleSlugChange error = %v", err)
|
mockReg.AssertExpectations(t)
|
||||||
}
|
mockStream.AssertExpectations(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -543,7 +496,6 @@ func TestHandleGetSessions(t *testing.T) {
|
|||||||
mockReg := &mockRegistry{}
|
mockReg := &mockRegistry{}
|
||||||
mockStream := &mockSubscribeClient{}
|
mockStream := &mockSubscribeClient{}
|
||||||
mockCfg := &MockConfig{}
|
mockCfg := &MockConfig{}
|
||||||
mockCfg.On("Domain").Return("test.com")
|
|
||||||
c := &client{sessionRegistry: mockReg, config: mockCfg}
|
c := &client{sessionRegistry: mockReg, config: mockCfg}
|
||||||
|
|
||||||
evt := &proto.Events{
|
evt := &proto.Events{
|
||||||
@@ -557,43 +509,30 @@ func TestHandleGetSessions(t *testing.T) {
|
|||||||
t.Run("Success", func(t *testing.T) {
|
t.Run("Success", func(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
mockSess := &mockSession{}
|
mockSess := &mockSession{}
|
||||||
mockSess.detailFunc = func() *types.Detail {
|
mockSess.On("Detail").Return(&types.Detail{
|
||||||
return &types.Detail{
|
|
||||||
ForwardingType: "http",
|
ForwardingType: "http",
|
||||||
Slug: "myslug",
|
Slug: "myslug",
|
||||||
UserID: "mas-fuad",
|
UserID: "mas-fuad",
|
||||||
Active: true,
|
Active: true,
|
||||||
StartedAt: now,
|
StartedAt: now,
|
||||||
}
|
}).Once()
|
||||||
}
|
|
||||||
|
|
||||||
mockReg.getAllSessionFromUserFunc = func(user string) []registry.Session {
|
mockReg.On("GetAllSessionFromUser", "mas-fuad").Return([]registry.Session{mockSess}).Once()
|
||||||
if user != "mas-fuad" {
|
mockCfg.On("Domain").Return("test.com").Once()
|
||||||
t.Errorf("expected mas-fuad, got %s", user)
|
|
||||||
}
|
|
||||||
return []registry.Session{mockSess}
|
|
||||||
}
|
|
||||||
|
|
||||||
sent := false
|
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||||
mockStream.sendFunc = func(n *proto.Node) error {
|
|
||||||
sent = true
|
|
||||||
if n.Type != proto.EventType_GET_SESSIONS {
|
if n.Type != proto.EventType_GET_SESSIONS {
|
||||||
t.Errorf("expected get sessions response type")
|
return false
|
||||||
}
|
|
||||||
resp := n.GetGetSessionsEvent()
|
|
||||||
if len(resp.Details) != 1 || resp.Details[0].Slug != "myslug" {
|
|
||||||
t.Errorf("unexpected details: %v", resp.Details)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
details := n.GetGetSessionsEvent().Details
|
||||||
|
return len(details) == 1 && details[0].Slug == "myslug"
|
||||||
|
})).Return(nil).Once()
|
||||||
|
|
||||||
err := c.handleGetSessions(mockStream, evt)
|
err := c.handleGetSessions(mockStream, evt)
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("handleGetSessions error = %v", err)
|
mockReg.AssertExpectations(t)
|
||||||
}
|
mockStream.AssertExpectations(t)
|
||||||
if !sent {
|
mockCfg.AssertExpectations(t)
|
||||||
t.Errorf("response not sent")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -615,41 +554,20 @@ func TestHandleTerminateSession(t *testing.T) {
|
|||||||
t.Run("Success", func(t *testing.T) {
|
t.Run("Success", func(t *testing.T) {
|
||||||
mockSess := &mockSession{}
|
mockSess := &mockSession{}
|
||||||
mockLife := &mockLifecycle{}
|
mockLife := &mockLifecycle{}
|
||||||
mockSess.lifecycleFunc = func() lifecycle.Lifecycle { return mockLife }
|
mockSess.On("Lifecycle").Return(mockLife).Once()
|
||||||
|
mockLife.On("Close").Return(nil).Once()
|
||||||
|
|
||||||
closed := false
|
mockReg.On("GetWithUser", "mas-fuad", types.SessionKey{Id: "myslug", Type: types.TunnelTypeHTTP}).Return(mockSess, nil).Once()
|
||||||
mockLife.closeFunc = func() error {
|
|
||||||
closed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) {
|
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||||
if user != "mas-fuad" || key.Id != "myslug" || key.Type != types.TunnelTypeHTTP {
|
return n.GetTerminateSessionEventResponse().Success
|
||||||
t.Errorf("unexpected get args")
|
})).Return(nil).Once()
|
||||||
}
|
|
||||||
return mockSess, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
sent := false
|
|
||||||
mockStream.sendFunc = func(n *proto.Node) error {
|
|
||||||
sent = true
|
|
||||||
resp := n.GetTerminateSessionEventResponse()
|
|
||||||
if !resp.Success {
|
|
||||||
t.Errorf("expected success")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.handleTerminateSession(mockStream, evt)
|
err := c.handleTerminateSession(mockStream, evt)
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("handleTerminateSession error = %v", err)
|
mockReg.AssertExpectations(t)
|
||||||
}
|
mockStream.AssertExpectations(t)
|
||||||
if !closed {
|
mockLife.AssertExpectations(t)
|
||||||
t.Errorf("close was not called")
|
|
||||||
}
|
|
||||||
if !sent {
|
|
||||||
t.Errorf("response not sent")
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("TunnelTypeUnknown", func(t *testing.T) {
|
t.Run("TunnelTypeUnknown", func(t *testing.T) {
|
||||||
@@ -660,54 +578,46 @@ func TestHandleTerminateSession(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
mockStream.sendFunc = func(n *proto.Node) error {
|
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||||
resp := n.GetTerminateSessionEventResponse()
|
resp := n.GetTerminateSessionEventResponse()
|
||||||
if resp.Success || resp.Message == "" {
|
return !resp.Success && resp.Message != ""
|
||||||
t.Errorf("expected failure response")
|
})).Return(nil).Once()
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err := c.handleTerminateSession(mockStream, badEvt)
|
err := c.handleTerminateSession(mockStream, badEvt)
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("handleTerminateSession error = %v", err)
|
mockStream.AssertExpectations(t)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("SessionNotFound", func(t *testing.T) {
|
t.Run("SessionNotFound", func(t *testing.T) {
|
||||||
mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) {
|
mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(nil, errors.New("not found")).Once()
|
||||||
return nil, errors.New("not found")
|
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||||
}
|
|
||||||
mockStream.sendFunc = func(n *proto.Node) error {
|
|
||||||
resp := n.GetTerminateSessionEventResponse()
|
resp := n.GetTerminateSessionEventResponse()
|
||||||
if resp.Success || resp.Message != "not found" {
|
return !resp.Success && resp.Message == "not found"
|
||||||
t.Errorf("unexpected failure response: %v", resp)
|
})).Return(nil).Once()
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err := c.handleTerminateSession(mockStream, evt)
|
err := c.handleTerminateSession(mockStream, evt)
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("handleTerminateSession error = %v", err)
|
mockReg.AssertExpectations(t)
|
||||||
}
|
mockStream.AssertExpectations(t)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("CloseError", func(t *testing.T) {
|
t.Run("CloseError", func(t *testing.T) {
|
||||||
mockSess := &mockSession{}
|
mockSess := &mockSession{}
|
||||||
mockLife := &mockLifecycle{}
|
mockLife := &mockLifecycle{}
|
||||||
mockSess.lifecycleFunc = func() lifecycle.Lifecycle { return mockLife }
|
mockSess.On("Lifecycle").Return(mockLife).Once()
|
||||||
mockLife.closeFunc = func() error { return errors.New("close fail") }
|
mockLife.On("Close").Return(errors.New("close fail")).Once()
|
||||||
mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return mockSess, nil }
|
mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(mockSess, nil).Once()
|
||||||
|
|
||||||
mockStream.sendFunc = func(n *proto.Node) error {
|
mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
|
||||||
resp := n.GetTerminateSessionEventResponse()
|
resp := n.GetTerminateSessionEventResponse()
|
||||||
if resp.Success || resp.Message != "close fail" {
|
return !resp.Success && resp.Message == "close fail"
|
||||||
t.Errorf("expected failure response: %v", resp)
|
})).Return(nil).Once()
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err := c.handleTerminateSession(mockStream, evt)
|
err := c.handleTerminateSession(mockStream, evt)
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("handleTerminateSession error = %v", err)
|
mockReg.AssertExpectations(t)
|
||||||
}
|
mockStream.AssertExpectations(t)
|
||||||
|
mockLife.AssertExpectations(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -718,42 +628,29 @@ func TestSubscribeAndProcess(t *testing.T) {
|
|||||||
backoff := time.Second
|
backoff := time.Second
|
||||||
|
|
||||||
t.Run("SubscribeError", func(t *testing.T) {
|
t.Run("SubscribeError", func(t *testing.T) {
|
||||||
mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
expectedErr := status.Error(codes.Unauthenticated, "unauth")
|
||||||
return nil, status.Error(codes.Unauthenticated, "unauth")
|
mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, expectedErr).Once()
|
||||||
}
|
|
||||||
err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
|
err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
|
||||||
if !errors.Is(err, status.Error(codes.Unauthenticated, "unauth")) {
|
assert.ErrorIs(t, err, expectedErr)
|
||||||
t.Errorf("expected unauth error, got %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("AuthSendError", func(t *testing.T) {
|
t.Run("AuthSendError", func(t *testing.T) {
|
||||||
mockStream := &mockSubscribeClient{}
|
mockStream := &mockSubscribeClient{}
|
||||||
mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(mockStream, nil).Once()
|
||||||
return mockStream, nil
|
expectedErr := status.Error(codes.Internal, "send fail")
|
||||||
}
|
mockStream.On("Send", mock.Anything).Return(expectedErr).Once()
|
||||||
mockStream.sendFunc = func(n *proto.Node) error {
|
|
||||||
return status.Error(codes.Internal, "send fail")
|
|
||||||
}
|
|
||||||
err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
|
err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
|
||||||
if !errors.Is(err, status.Error(codes.Internal, "send fail")) {
|
assert.ErrorIs(t, err, expectedErr)
|
||||||
t.Errorf("expected send fail, got %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("StreamError", func(t *testing.T) {
|
t.Run("StreamError", func(t *testing.T) {
|
||||||
mockStream := &mockSubscribeClient{}
|
mockStream := &mockSubscribeClient{}
|
||||||
mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(mockStream, nil).Once()
|
||||||
return mockStream, nil
|
mockStream.On("Send", mock.Anything).Return(nil).Once()
|
||||||
}
|
expectedErr := status.Error(codes.Internal, "stream fail")
|
||||||
mockStream.sendFunc = func(n *proto.Node) error { return nil }
|
mockStream.On("Recv").Return(nil, expectedErr).Once()
|
||||||
mockStream.recvFunc = func() (*proto.Events, error) {
|
|
||||||
return nil, status.Error(codes.Internal, "stream fail")
|
|
||||||
}
|
|
||||||
err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
|
err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
|
||||||
if !errors.Is(err, status.Error(codes.Internal, "stream fail")) {
|
assert.ErrorIs(t, err, expectedErr)
|
||||||
t.Errorf("expected stream fail, got %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -762,13 +659,10 @@ func TestSubscribeEvents(t *testing.T) {
|
|||||||
c := &client{eventService: mockEventSvc}
|
c := &client{eventService: mockEventSvc}
|
||||||
|
|
||||||
t.Run("ReturnsOnError", func(t *testing.T) {
|
t.Run("ReturnsOnError", func(t *testing.T) {
|
||||||
mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
expectedErr := errors.New("fatal error")
|
||||||
return nil, errors.New("fatal error")
|
mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, expectedErr).Once()
|
||||||
}
|
|
||||||
err := c.SubscribeEvents(context.Background(), "id", "token")
|
err := c.SubscribeEvents(context.Background(), "id", "token")
|
||||||
if err == nil || err.Error() != "fatal error" {
|
assert.ErrorIs(t, err, expectedErr)
|
||||||
t.Errorf("expected fatal error, got %v", err)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("RetryLoop", func(t *testing.T) {
|
t.Run("RetryLoop", func(t *testing.T) {
|
||||||
@@ -779,19 +673,11 @@ func TestSubscribeEvents(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
callCount := 0
|
mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, status.Error(codes.Unavailable, "unavailable"))
|
||||||
mockEventSvc.subscribeFunc = func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
|
||||||
callCount++
|
|
||||||
return nil, status.Error(codes.Unavailable, "unavailable")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.SubscribeEvents(ctx, "id", "token")
|
err := c.SubscribeEvents(ctx, "id", "token")
|
||||||
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
|
assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled))
|
||||||
t.Errorf("expected timeout/canceled error, got %v", err)
|
mockEventSvc.AssertExpectations(t)
|
||||||
}
|
|
||||||
if callCount <= 1 {
|
|
||||||
t.Errorf("expected multiple calls due to retry, got %d", callCount)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -806,33 +692,24 @@ func TestCheckServerHealth(t *testing.T) {
|
|||||||
c := &client{}
|
c := &client{}
|
||||||
|
|
||||||
t.Run("Success", func(t *testing.T) {
|
t.Run("Success", func(t *testing.T) {
|
||||||
mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) {
|
mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(&grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil).Once()
|
||||||
return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil
|
|
||||||
}
|
|
||||||
err := c.CheckServerHealth(context.Background())
|
err := c.CheckServerHealth(context.Background())
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Errorf("expected nil error, got %v", err)
|
mockHealth.AssertExpectations(t)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Error", func(t *testing.T) {
|
t.Run("Error", func(t *testing.T) {
|
||||||
mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) {
|
mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("health fail")).Once()
|
||||||
return nil, errors.New("health fail")
|
|
||||||
}
|
|
||||||
err := c.CheckServerHealth(context.Background())
|
err := c.CheckServerHealth(context.Background())
|
||||||
if err == nil || err.Error() != "health check failed: health fail" {
|
assert.ErrorContains(t, err, "health check failed: health fail")
|
||||||
t.Errorf("expected health fail error, got %v", err)
|
mockHealth.AssertExpectations(t)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("NotServing", func(t *testing.T) {
|
t.Run("NotServing", func(t *testing.T) {
|
||||||
mockHealth.checkFunc = func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) {
|
mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(&grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING}, nil).Once()
|
||||||
return &grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING}, nil
|
|
||||||
}
|
|
||||||
err := c.CheckServerHealth(context.Background())
|
err := c.CheckServerHealth(context.Background())
|
||||||
if err == nil || err.Error() != "server not serving: NOT_SERVING" {
|
assert.ErrorContains(t, err, "server not serving: NOT_SERVING")
|
||||||
t.Errorf("expected not serving error, got %v", err)
|
mockHealth.AssertExpectations(t)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -897,88 +774,192 @@ func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
|||||||
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
||||||
|
|
||||||
type mockRegistry struct {
|
type mockRegistry struct {
|
||||||
registry.Registry
|
mock.Mock
|
||||||
getFunc func(key registry.Key) (registry.Session, error)
|
|
||||||
getWithUserFunc func(user string, key registry.Key) (registry.Session, error)
|
|
||||||
updateFunc func(user string, oldKey, newKey registry.Key) error
|
|
||||||
getAllSessionFromUserFunc func(user string) []registry.Session
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockRegistry) Get(key registry.Key) (registry.Session, error) {
|
func (m *mockRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||||
return m.getFunc(key)
|
args := m.Called(key)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(registry.Session), args.Error(1)
|
||||||
}
|
}
|
||||||
func (m *mockRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
func (m *mockRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||||
return m.getWithUserFunc(user, key)
|
args := m.Called(user, key)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(registry.Session), args.Error(1)
|
||||||
}
|
}
|
||||||
func (m *mockRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
func (m *mockRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||||
return m.updateFunc(user, oldKey, newKey)
|
return m.Called(user, oldKey, newKey).Error(0)
|
||||||
}
|
}
|
||||||
func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||||
return m.getAllSessionFromUserFunc(user)
|
args := m.Called(user)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).([]registry.Session)
|
||||||
|
}
|
||||||
|
func (m *mockRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||||
|
return m.Called(key, session).Bool(0)
|
||||||
|
}
|
||||||
|
func (m *mockRegistry) Remove(key registry.Key) {
|
||||||
|
m.Called(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockSession struct {
|
type mockSession struct {
|
||||||
registry.Session
|
mock.Mock
|
||||||
lifecycleFunc func() lifecycle.Lifecycle
|
|
||||||
interactionFunc func() interaction.Interaction
|
|
||||||
detailFunc func() *types.Detail
|
|
||||||
slugFunc func() slug.Slug
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSession) Lifecycle() lifecycle.Lifecycle { return m.lifecycleFunc() }
|
func (m *mockSession) Lifecycle() lifecycle.Lifecycle {
|
||||||
func (m *mockSession) Interaction() interaction.Interaction { return m.interactionFunc() }
|
args := m.Called()
|
||||||
func (m *mockSession) Detail() *types.Detail { return m.detailFunc() }
|
if args.Get(0) == nil {
|
||||||
func (m *mockSession) Slug() slug.Slug { return m.slugFunc() }
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(lifecycle.Lifecycle)
|
||||||
|
}
|
||||||
|
func (m *mockSession) Interaction() interaction.Interaction {
|
||||||
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(interaction.Interaction)
|
||||||
|
}
|
||||||
|
func (m *mockSession) Detail() *types.Detail {
|
||||||
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(*types.Detail)
|
||||||
|
}
|
||||||
|
func (m *mockSession) Slug() slug.Slug {
|
||||||
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(slug.Slug)
|
||||||
|
}
|
||||||
|
func (m *mockSession) Forwarder() forwarder.Forwarder {
|
||||||
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(forwarder.Forwarder)
|
||||||
|
}
|
||||||
|
|
||||||
type mockInteraction struct {
|
type mockInteraction struct {
|
||||||
interaction.Interaction
|
mock.Mock
|
||||||
redrawCalled bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockInteraction) Redraw() { m.redrawCalled = true }
|
func (m *mockInteraction) Start() { m.Called() }
|
||||||
|
func (m *mockInteraction) Stop() { m.Called() }
|
||||||
|
func (m *mockInteraction) Redraw() { m.Called() }
|
||||||
|
func (m *mockInteraction) SetWH(w, h int) { m.Called(w, h) }
|
||||||
|
func (m *mockInteraction) SetChannel(channel ssh.Channel) { m.Called(channel) }
|
||||||
|
func (m *mockInteraction) SetMode(mode types.InteractiveMode) { m.Called(mode) }
|
||||||
|
func (m *mockInteraction) Mode() types.InteractiveMode {
|
||||||
|
return m.Called().Get(0).(types.InteractiveMode)
|
||||||
|
}
|
||||||
|
func (m *mockInteraction) Send(message string) error { return m.Called(message).Error(0) }
|
||||||
|
|
||||||
type mockLifecycle struct {
|
type mockLifecycle struct {
|
||||||
lifecycle.Lifecycle
|
mock.Mock
|
||||||
closeFunc func() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockLifecycle) Close() error { return m.closeFunc() }
|
func (m *mockLifecycle) Close() error { return m.Called().Error(0) }
|
||||||
|
func (m *mockLifecycle) Channel() ssh.Channel {
|
||||||
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(ssh.Channel)
|
||||||
|
}
|
||||||
|
func (m *mockLifecycle) Connection() ssh.Conn {
|
||||||
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(ssh.Conn)
|
||||||
|
}
|
||||||
|
func (m *mockLifecycle) User() string { return m.Called().String(0) }
|
||||||
|
func (m *mockLifecycle) SetChannel(channel ssh.Channel) { m.Called(channel) }
|
||||||
|
func (m *mockLifecycle) SetStatus(status types.SessionStatus) { m.Called(status) }
|
||||||
|
func (m *mockLifecycle) IsActive() bool { return m.Called().Bool(0) }
|
||||||
|
func (m *mockLifecycle) StartedAt() time.Time { return m.Called().Get(0).(time.Time) }
|
||||||
|
func (m *mockLifecycle) PortRegistry() port.Port {
|
||||||
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(port.Port)
|
||||||
|
}
|
||||||
|
|
||||||
type mockEventServiceClient struct {
|
type mockEventServiceClient struct {
|
||||||
proto.EventServiceClient
|
mock.Mock
|
||||||
subscribeFunc func(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockEventServiceClient) Subscribe(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
func (m *mockEventServiceClient) Subscribe(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
|
||||||
return m.subscribeFunc(ctx, opts...)
|
args := m.Called(ctx, opts)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(proto.EventService_SubscribeClient), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockSubscribeClient struct {
|
type mockSubscribeClient struct {
|
||||||
|
mock.Mock
|
||||||
grpc.ClientStream
|
grpc.ClientStream
|
||||||
sendFunc func(*proto.Node) error
|
|
||||||
recvFunc func() (*proto.Events, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSubscribeClient) Send(n *proto.Node) error { return m.sendFunc(n) }
|
func (m *mockSubscribeClient) Send(n *proto.Node) error { return m.Called(n).Error(0) }
|
||||||
func (m *mockSubscribeClient) Recv() (*proto.Events, error) { return m.recvFunc() }
|
func (m *mockSubscribeClient) Recv() (*proto.Events, error) {
|
||||||
func (m *mockSubscribeClient) Context() context.Context { return context.Background() }
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(*proto.Events), args.Error(1)
|
||||||
|
}
|
||||||
|
func (m *mockSubscribeClient) Context() context.Context { return m.Called().Get(0).(context.Context) }
|
||||||
|
|
||||||
type mockUserServiceClient struct {
|
type mockUserServiceClient struct {
|
||||||
proto.UserServiceClient
|
mock.Mock
|
||||||
checkFunc func(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockUserServiceClient) Check(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) {
|
func (m *mockUserServiceClient) Check(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) {
|
||||||
return m.checkFunc(ctx, in, opts...)
|
args := m.Called(ctx, in, opts)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(*proto.CheckResponse), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockHealthClient struct {
|
type mockHealthClient struct {
|
||||||
grpc_health_v1.HealthClient
|
mock.Mock
|
||||||
checkFunc func(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockHealthClient) Check(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) {
|
func (m *mockHealthClient) Check(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) {
|
||||||
return m.checkFunc(ctx, in, opts...)
|
args := m.Called(ctx, in, opts)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(*grpc_health_v1.HealthCheckResponse), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHealthClient) Watch(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (grpc_health_v1.Health_WatchClient, error) {
|
||||||
|
args := m.Called(ctx, in, opts)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(grpc_health_v1.Health_WatchClient), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockHealthClient) List(ctx context.Context, in *grpc_health_v1.HealthListRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthListResponse, error) {
|
||||||
|
args := m.Called(ctx, in, opts)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Error(1)
|
||||||
|
}
|
||||||
|
return args.Get(0).(*grpc_health_v1.HealthListResponse), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProtoToTunnelType(t *testing.T) {
|
func TestProtoToTunnelType(t *testing.T) {
|
||||||
|
|||||||
@@ -1,40 +1,42 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockRequestHeader struct {
|
type mockRequestHeader struct {
|
||||||
headers map[string]string
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockRequestHeader) Value(key string) string {
|
func (m *mockRequestHeader) Value(key string) string {
|
||||||
return m.headers[key]
|
return m.Called(key).String(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockRequestHeader) Set(key string, value string) {
|
func (m *mockRequestHeader) Set(key string, value string) {
|
||||||
m.headers[key] = value
|
m.Called(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockRequestHeader) Remove(key string) {
|
func (m *mockRequestHeader) Remove(key string) {
|
||||||
delete(m.headers, key)
|
m.Called(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockRequestHeader) Finalize() []byte {
|
func (m *mockRequestHeader) Finalize() []byte {
|
||||||
return []byte{}
|
return m.Called().Get(0).([]byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockRequestHeader) Method() string {
|
func (m *mockRequestHeader) Method() string {
|
||||||
return ""
|
return m.Called().String(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockRequestHeader) Path() string {
|
func (m *mockRequestHeader) Path() string {
|
||||||
return ""
|
return m.Called().String(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockRequestHeader) Version() string {
|
func (m *mockRequestHeader) Version() string {
|
||||||
return ""
|
return m.Called().String(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestForwardedFor_HandleRequest(t *testing.T) {
|
func TestForwardedFor_HandleRequest(t *testing.T) {
|
||||||
@@ -73,23 +75,19 @@ func TestForwardedFor_HandleRequest(t *testing.T) {
|
|||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
ff := NewForwardedFor(tc.addr)
|
ff := NewForwardedFor(tc.addr)
|
||||||
reqHeader := &mockRequestHeader{headers: make(map[string]string)}
|
reqHeader := new(mockRequestHeader)
|
||||||
|
|
||||||
|
if !tc.expectError {
|
||||||
|
reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return()
|
||||||
|
}
|
||||||
|
|
||||||
err := ff.HandleRequest(reqHeader)
|
err := ff.HandleRequest(reqHeader)
|
||||||
|
|
||||||
if tc.expectError {
|
if tc.expectError {
|
||||||
if err == nil {
|
assert.Error(t, err)
|
||||||
t.Fatalf("expected error but got none")
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if err != nil {
|
assert.NoError(t, err)
|
||||||
t.Fatalf("unexpected error: %v", err)
|
reqHeader.AssertExpectations(t)
|
||||||
}
|
|
||||||
|
|
||||||
host := reqHeader.Value("X-Forwarded-For")
|
|
||||||
if host != tc.expectedHost {
|
|
||||||
t.Errorf("expected X-Forwarded-For header to be '%s', got '%s'", tc.expectedHost, host)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -121,10 +119,7 @@ func TestNewForwardedFor(t *testing.T) {
|
|||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
ff := NewForwardedFor(tc.addr)
|
ff := NewForwardedFor(tc.addr)
|
||||||
|
assert.Equal(t, tc.expectAddr.String(), ff.addr.String())
|
||||||
if ff.addr.String() != tc.expectAddr.String() {
|
|
||||||
t.Errorf("expected addr to be '%v', got '%v'", tc.expectAddr, ff.addr)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,48 +1,46 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockResponseHeader struct {
|
type mockResponseHeader struct {
|
||||||
headers map[string]string
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockResponseHeader) Value(key string) string {
|
func (m *mockResponseHeader) Value(key string) string {
|
||||||
return m.headers[key]
|
return m.Called(key).String(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockResponseHeader) Set(key string, value string) {
|
func (m *mockResponseHeader) Set(key string, value string) {
|
||||||
m.headers[key] = value
|
m.Called(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockResponseHeader) Remove(key string) {
|
func (m *mockResponseHeader) Remove(key string) {
|
||||||
delete(m.headers, key)
|
m.Called(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockResponseHeader) Finalize() []byte {
|
func (m *mockResponseHeader) Finalize() []byte {
|
||||||
return nil
|
return m.Called().Get(0).([]byte)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTunnelFingerprintHandleResponse(t *testing.T) {
|
func TestTunnelFingerprintHandleResponse(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
initialState map[string]string
|
|
||||||
expected map[string]string
|
expected map[string]string
|
||||||
body []byte
|
body []byte
|
||||||
wantErr error
|
wantErr error
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Sets Server Header",
|
name: "Sets Server Header",
|
||||||
initialState: map[string]string{},
|
|
||||||
expected: map[string]string{"Server": "Tunnel Please"},
|
expected: map[string]string{"Server": "Tunnel Please"},
|
||||||
body: []byte("Sample body"),
|
body: []byte("Sample body"),
|
||||||
wantErr: nil,
|
wantErr: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Overwrites Server Header",
|
name: "Overwrites Server Header",
|
||||||
initialState: map[string]string{"Server": "Old Value"},
|
|
||||||
expected: map[string]string{"Server": "Tunnel Please"},
|
expected: map[string]string{"Server": "Tunnel Please"},
|
||||||
body: nil,
|
body: nil,
|
||||||
wantErr: nil,
|
wantErr: nil,
|
||||||
@@ -51,26 +49,21 @@ func TestTunnelFingerprintHandleResponse(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
mockHeader := &mockResponseHeader{headers: tt.initialState}
|
mockHeader := new(mockResponseHeader)
|
||||||
|
for k, v := range tt.expected {
|
||||||
|
mockHeader.On("Set", k, v).Return()
|
||||||
|
}
|
||||||
|
|
||||||
tunnelFingerprint := NewTunnelFingerprint()
|
tunnelFingerprint := NewTunnelFingerprint()
|
||||||
|
|
||||||
err := tunnelFingerprint.HandleResponse(mockHeader, tt.body)
|
err := tunnelFingerprint.HandleResponse(mockHeader, tt.body)
|
||||||
if !errors.Is(err, tt.wantErr) {
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
t.Fatalf("unexpected error, got: %v, want: %v", err, tt.wantErr)
|
mockHeader.AssertExpectations(t)
|
||||||
}
|
|
||||||
|
|
||||||
for key, expectedValue := range tt.expected {
|
|
||||||
if val := mockHeader.Value(key); val != expectedValue {
|
|
||||||
t.Errorf("header[%q] = %q; want %q", key, val, expectedValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewTunnelFingerprint(t *testing.T) {
|
func TestNewTunnelFingerprint(t *testing.T) {
|
||||||
instance := NewTunnelFingerprint()
|
instance := NewTunnelFingerprint()
|
||||||
if instance == nil {
|
assert.NotNil(t, instance)
|
||||||
t.Errorf("NewTunnelFingerprint() = nil; want non-nil instance")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
+16
-17
@@ -1,6 +1,7 @@
|
|||||||
package port
|
package port
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,8 +21,10 @@ func TestAddRange(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
pm := New()
|
pm := New()
|
||||||
err := pm.AddRange(tt.startPort, tt.endPort)
|
err := pm.AddRange(tt.startPort, tt.endPort)
|
||||||
if (err != nil) != tt.wantErr {
|
if tt.wantErr {
|
||||||
t.Errorf("AddRange() error = %v, wantErr %v", err, tt.wantErr)
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -48,9 +51,8 @@ func TestUnassigned(t *testing.T) {
|
|||||||
_ = pm.SetStatus(k, v)
|
_ = pm.SetStatus(k, v)
|
||||||
}
|
}
|
||||||
got, gotOk := pm.Unassigned()
|
got, gotOk := pm.Unassigned()
|
||||||
if got != tt.want || gotOk != tt.wantOk {
|
assert.Equal(t, tt.want, got)
|
||||||
t.Errorf("Unassigned() got = %v, want %v, gotOk = %v, wantOk %v", got, tt.want, gotOk, tt.wantOk)
|
assert.Equal(t, tt.wantOk, gotOk)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -70,12 +72,12 @@ func TestSetStatus(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := pm.SetStatus(tt.port, tt.assigned); err != nil {
|
err := pm.SetStatus(tt.port, tt.assigned)
|
||||||
t.Errorf("SetStatus() error = %v", err)
|
assert.NoError(t, err)
|
||||||
}
|
|
||||||
if status, _ := pm.(*port).ports[tt.port]; status != tt.assigned {
|
status, ok := pm.(*port).ports[tt.port]
|
||||||
t.Errorf("SetStatus() failed, port %v has status %v, want %v", tt.port, status, tt.assigned)
|
assert.True(t, ok)
|
||||||
}
|
assert.Equal(t, tt.assigned, status)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -102,13 +104,10 @@ func TestClaim(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
got := pm.Claim(tt.port)
|
got := pm.Claim(tt.port)
|
||||||
if got != tt.want {
|
assert.Equal(t, tt.want, got)
|
||||||
t.Errorf("Claim() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
|
|
||||||
if finalState := pm.(*port).ports[tt.port]; finalState != true {
|
finalState := pm.(*port).ports[tt.port]
|
||||||
t.Errorf("Claim() did not update port %v status to 'assigned'", tt.port)
|
assert.True(t, finalState)
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,12 @@
|
|||||||
package random
|
package random
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type brainrotReader struct {
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *brainrotReader) Read(p []byte) (int, error) {
|
|
||||||
return 0, f.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRandom_String(t *testing.T) {
|
func TestRandom_String(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -32,20 +24,18 @@ func TestRandom_String(t *testing.T) {
|
|||||||
randomizer := New()
|
randomizer := New()
|
||||||
|
|
||||||
result, err := randomizer.String(tt.length)
|
result, err := randomizer.String(tt.length)
|
||||||
if (err != nil) != tt.wantErr {
|
if tt.wantErr {
|
||||||
t.Errorf("String() error = %v, wantErr %v", err, tt.wantErr)
|
assert.Error(t, err)
|
||||||
return
|
} else {
|
||||||
}
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, result, tt.length)
|
||||||
if !tt.wantErr && len(result) != tt.length {
|
|
||||||
t.Errorf("String() length = %v, want %v", len(result), tt.length)
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRandomWithFailingReader_String(t *testing.T) {
|
func TestRandomWithFailingReader_String(t *testing.T) {
|
||||||
errBrainrot := fmt.Errorf("you are not sigma enough")
|
errBrainrot := assert.AnError
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -54,7 +44,9 @@ func TestRandomWithFailingReader_String(t *testing.T) {
|
|||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "failing reader",
|
name: "failing reader",
|
||||||
reader: &brainrotReader{err: errBrainrot},
|
reader: func() io.Reader {
|
||||||
|
return &failingReader{err: errBrainrot}
|
||||||
|
}(),
|
||||||
expectErr: errBrainrot,
|
expectErr: errBrainrot,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -63,14 +55,16 @@ func TestRandomWithFailingReader_String(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
randomizer := &random{reader: tt.reader}
|
randomizer := &random{reader: tt.reader}
|
||||||
result, err := randomizer.String(20)
|
result, err := randomizer.String(20)
|
||||||
if !errors.Is(err, tt.expectErr) {
|
assert.ErrorIs(t, err, tt.expectErr)
|
||||||
t.Errorf("String() error = %v, wantErr %v", err, tt.expectErr)
|
assert.Empty(t, result)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != "" {
|
|
||||||
t.Errorf("String() result = %v, want an empty string due to error", result)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type failingReader struct {
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *failingReader) Read(p []byte) (int, error) {
|
||||||
|
return 0, f.err
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -15,47 +17,109 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockSession struct{ user string }
|
type mockSession struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockSession) Lifecycle() lifecycle.Lifecycle { return &mockLifecycle{user: m.user} }
|
func (m *mockSession) Lifecycle() lifecycle.Lifecycle {
|
||||||
func (m *mockSession) Interaction() interaction.Interaction {
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return args.Get(0).(lifecycle.Lifecycle)
|
||||||
|
}
|
||||||
|
func (m *mockSession) Interaction() interaction.Interaction {
|
||||||
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(interaction.Interaction)
|
||||||
|
}
|
||||||
func (m *mockSession) Forwarder() forwarder.Forwarder {
|
func (m *mockSession) Forwarder() forwarder.Forwarder {
|
||||||
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return args.Get(0).(forwarder.Forwarder)
|
||||||
|
}
|
||||||
func (m *mockSession) Slug() slug.Slug {
|
func (m *mockSession) Slug() slug.Slug {
|
||||||
return &mockSlug{}
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(slug.Slug)
|
||||||
}
|
}
|
||||||
func (m *mockSession) Detail() *types.Detail {
|
func (m *mockSession) Detail() *types.Detail {
|
||||||
|
args := m.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return args.Get(0).(*types.Detail)
|
||||||
|
}
|
||||||
|
|
||||||
type mockLifecycle struct{ user string }
|
type mockLifecycle struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
func (ml *mockLifecycle) Channel() ssh.Channel {
|
func (ml *mockLifecycle) Channel() ssh.Channel {
|
||||||
|
args := ml.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return args.Get(0).(ssh.Channel)
|
||||||
|
}
|
||||||
|
|
||||||
func (ml *mockLifecycle) Connection() ssh.Conn { return nil }
|
func (ml *mockLifecycle) Connection() ssh.Conn {
|
||||||
func (ml *mockLifecycle) PortRegistry() port.Port { return nil }
|
args := ml.Called()
|
||||||
func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { _ = channel }
|
if args.Get(0) == nil {
|
||||||
func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { _ = status }
|
return nil
|
||||||
func (ml *mockLifecycle) IsActive() bool { return false }
|
}
|
||||||
func (ml *mockLifecycle) StartedAt() time.Time { return time.Time{} }
|
return args.Get(0).(ssh.Conn)
|
||||||
func (ml *mockLifecycle) Close() error { return nil }
|
}
|
||||||
func (ml *mockLifecycle) User() string { return ml.user }
|
|
||||||
|
|
||||||
type mockSlug struct{}
|
func (ml *mockLifecycle) PortRegistry() port.Port {
|
||||||
|
args := ml.Called()
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return args.Get(0).(port.Port)
|
||||||
|
}
|
||||||
|
|
||||||
func (ms *mockSlug) Set(slug string) { _ = slug }
|
func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { ml.Called(channel) }
|
||||||
func (ms *mockSlug) String() string { return "" }
|
func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { ml.Called(status) }
|
||||||
|
func (ml *mockLifecycle) IsActive() bool { return ml.Called().Bool(0) }
|
||||||
|
func (ml *mockLifecycle) StartedAt() time.Time { return ml.Called().Get(0).(time.Time) }
|
||||||
|
func (ml *mockLifecycle) Close() error { return ml.Called().Error(0) }
|
||||||
|
func (ml *mockLifecycle) User() string { return ml.Called().String(0) }
|
||||||
|
|
||||||
|
type mockSlug struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ms *mockSlug) Set(slug string) { ms.Called(slug) }
|
||||||
|
func (ms *mockSlug) String() string { return ms.Called().String(0) }
|
||||||
|
|
||||||
|
func createMockSession(user ...string) *mockSession {
|
||||||
|
u := "user1"
|
||||||
|
if len(user) > 0 {
|
||||||
|
u = user[0]
|
||||||
|
}
|
||||||
|
m := new(mockSession)
|
||||||
|
ml := new(mockLifecycle)
|
||||||
|
ml.On("User").Return(u).Maybe()
|
||||||
|
m.On("Lifecycle").Return(ml).Maybe()
|
||||||
|
ms := new(mockSlug)
|
||||||
|
ms.On("Set", mock.Anything).Maybe()
|
||||||
|
m.On("Slug").Return(ms).Maybe()
|
||||||
|
m.On("Interaction").Return(nil).Maybe()
|
||||||
|
m.On("Forwarder").Return(nil).Maybe()
|
||||||
|
m.On("Detail").Return(nil).Maybe()
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewRegistry(t *testing.T) {
|
func TestNewRegistry(t *testing.T) {
|
||||||
r := NewRegistry()
|
r := NewRegistry()
|
||||||
if r == nil {
|
require.NotNil(t, r)
|
||||||
t.Fatal("NewRegistry returned nil")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRegistry_Get(t *testing.T) {
|
func TestRegistry_Get(t *testing.T) {
|
||||||
@@ -71,7 +135,7 @@ func TestRegistry_Get(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) {
|
setupFunc: func(r *registry) {
|
||||||
user := "user1"
|
user := "user1"
|
||||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: user}
|
session := createMockSession(user)
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -113,13 +177,8 @@ func TestRegistry_Get(t *testing.T) {
|
|||||||
|
|
||||||
session, err := r.Get(tt.key)
|
session, err := r.Get(tt.key)
|
||||||
|
|
||||||
if !errors.Is(err, tt.wantErr) {
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
t.Fatalf("expected error %v, got %v", tt.wantErr, err)
|
assert.Equal(t, tt.wantResult, session != nil)
|
||||||
}
|
|
||||||
|
|
||||||
if (session != nil) != tt.wantResult {
|
|
||||||
t.Fatalf("expected session existence to be %v, got %v", tt.wantResult, session != nil)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -138,7 +197,7 @@ func TestRegistry_GetWithUser(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) {
|
setupFunc: func(r *registry) {
|
||||||
user := "user1"
|
user := "user1"
|
||||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: user}
|
session := createMockSession()
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -183,13 +242,8 @@ func TestRegistry_GetWithUser(t *testing.T) {
|
|||||||
|
|
||||||
session, err := r.GetWithUser(tt.user, tt.key)
|
session, err := r.GetWithUser(tt.user, tt.key)
|
||||||
|
|
||||||
if !errors.Is(err, tt.wantErr) {
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
t.Fatalf("expected error %v, got %v", tt.wantErr, err)
|
assert.Equal(t, tt.wantResult, session != nil)
|
||||||
}
|
|
||||||
|
|
||||||
if (session != nil) != tt.wantResult {
|
|
||||||
t.Fatalf("expected session existence to be %v, got %v", tt.wantResult, session != nil)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -207,7 +261,7 @@ func TestRegistry_Update(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: "user1"}
|
session := createMockSession("user1")
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -226,7 +280,7 @@ func TestRegistry_Update(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: "user1"}
|
session := createMockSession()
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -247,7 +301,7 @@ func TestRegistry_Update(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP}
|
newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: "user1"}
|
session := createMockSession()
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -266,7 +320,7 @@ func TestRegistry_Update(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP}
|
newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: "user1"}
|
session := createMockSession()
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -285,7 +339,7 @@ func TestRegistry_Update(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||||
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: "user1"}
|
session := createMockSession()
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -304,7 +358,7 @@ func TestRegistry_Update(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||||
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: "user1"}
|
session := createMockSession()
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -323,7 +377,7 @@ func TestRegistry_Update(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
||||||
session := &mockSession{user: "user1"}
|
session := createMockSession()
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -342,7 +396,7 @@ func TestRegistry_Update(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
||||||
newKey := oldKey
|
newKey := oldKey
|
||||||
session := &mockSession{user: "user1"}
|
session := createMockSession()
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@@ -371,19 +425,15 @@ func TestRegistry_Update(t *testing.T) {
|
|||||||
oldKey, newKey := tt.setupFunc(r)
|
oldKey, newKey := tt.setupFunc(r)
|
||||||
|
|
||||||
err := r.Update(tt.user, oldKey, newKey)
|
err := r.Update(tt.user, oldKey, newKey)
|
||||||
if !errors.Is(err, tt.wantErr) {
|
assert.ErrorIs(t, err, tt.wantErr)
|
||||||
t.Fatalf("expected error %v, got %v", tt.wantErr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
defer r.mu.RUnlock()
|
defer r.mu.RUnlock()
|
||||||
if _, ok := r.byUser[tt.user][newKey]; !ok {
|
_, ok := r.byUser[tt.user][newKey]
|
||||||
t.Errorf("newKey not found in registry")
|
assert.True(t, ok, "newKey not found in registry")
|
||||||
}
|
_, ok = r.byUser[tt.user][oldKey]
|
||||||
if _, ok := r.byUser[tt.user][oldKey]; ok {
|
assert.False(t, ok, "oldKey still exists in registry")
|
||||||
t.Errorf("oldKey still exists in registry")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -410,7 +460,7 @@ func TestRegistry_Register(t *testing.T) {
|
|||||||
user: "user1",
|
user: "user1",
|
||||||
setupFunc: func(r *registry) Key {
|
setupFunc: func(r *registry) Key {
|
||||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: "user1"}
|
session := createMockSession()
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
r.byUser["user1"] = map[Key]Session{key: session}
|
r.byUser["user1"] = map[Key]Session{key: session}
|
||||||
@@ -426,7 +476,7 @@ func TestRegistry_Register(t *testing.T) {
|
|||||||
user: "user1",
|
user: "user1",
|
||||||
setupFunc: func(r *registry) Key {
|
setupFunc: func(r *registry) Key {
|
||||||
firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP}
|
firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: "user1"}
|
session := createMockSession()
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
r.byUser["user1"] = map[Key]Session{firstKey: session}
|
r.byUser["user1"] = map[Key]Session{firstKey: session}
|
||||||
r.slugIndex[firstKey] = "user1"
|
r.slugIndex[firstKey] = "user1"
|
||||||
@@ -450,22 +500,16 @@ func TestRegistry_Register(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
key := tt.setupFunc(r)
|
key := tt.setupFunc(r)
|
||||||
session := &mockSession{user: tt.user}
|
session := createMockSession()
|
||||||
|
|
||||||
ok := r.Register(key, session)
|
ok := r.Register(key, session)
|
||||||
if ok != tt.wantOK {
|
assert.Equal(t, tt.wantOK, ok)
|
||||||
t.Fatalf("expected success %v, got %v", tt.wantOK, ok)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
defer r.mu.RUnlock()
|
defer r.mu.RUnlock()
|
||||||
if r.byUser[tt.user][key] != session {
|
assert.Equal(t, session, r.byUser[tt.user][key], "session not stored in byUser")
|
||||||
t.Errorf("session not stored in byUser")
|
assert.Equal(t, tt.user, r.slugIndex[key], "slugIndex not updated")
|
||||||
}
|
|
||||||
if r.slugIndex[key] != tt.user {
|
|
||||||
t.Errorf("slugIndex not updated")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -492,8 +536,8 @@ func TestRegistry_GetAllSessionFromUser(t *testing.T) {
|
|||||||
key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
|
key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
r.byUser[user] = map[Key]Session{
|
r.byUser[user] = map[Key]Session{
|
||||||
key1: &mockSession{user: user},
|
key1: createMockSession(),
|
||||||
key2: &mockSession{user: user},
|
key2: createMockSession(),
|
||||||
}
|
}
|
||||||
r.mu.Unlock()
|
r.mu.Unlock()
|
||||||
return user
|
return user
|
||||||
@@ -511,9 +555,7 @@ func TestRegistry_GetAllSessionFromUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
user := tt.setupFunc(r)
|
user := tt.setupFunc(r)
|
||||||
sessions := r.GetAllSessionFromUser(user)
|
sessions := r.GetAllSessionFromUser(user)
|
||||||
if len(sessions) != tt.expectN {
|
assert.Len(t, sessions, tt.expectN)
|
||||||
t.Errorf("expected %d sessions, got %d", tt.expectN, len(sessions))
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -530,7 +572,7 @@ func TestRegistry_Remove(t *testing.T) {
|
|||||||
setupFunc: func(r *registry) (string, types.SessionKey) {
|
setupFunc: func(r *registry) (string, types.SessionKey) {
|
||||||
user := "user1"
|
user := "user1"
|
||||||
key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||||
session := &mockSession{user: user}
|
session := createMockSession()
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
r.byUser[user] = map[Key]Session{key: session}
|
r.byUser[user] = map[Key]Session{key: session}
|
||||||
r.slugIndex[key] = user
|
r.slugIndex[key] = user
|
||||||
@@ -538,15 +580,12 @@ func TestRegistry_Remove(t *testing.T) {
|
|||||||
return user, key
|
return user, key
|
||||||
},
|
},
|
||||||
verify: func(t *testing.T, r *registry, user string, key types.SessionKey) {
|
verify: func(t *testing.T, r *registry, user string, key types.SessionKey) {
|
||||||
if _, ok := r.byUser[user][key]; ok {
|
_, ok := r.byUser[user][key]
|
||||||
t.Errorf("expected key to be removed from byUser")
|
assert.False(t, ok, "expected key to be removed from byUser")
|
||||||
}
|
_, ok = r.slugIndex[key]
|
||||||
if _, ok := r.slugIndex[key]; ok {
|
assert.False(t, ok, "expected key to be removed from slugIndex")
|
||||||
t.Errorf("expected key to be removed from slugIndex")
|
_, ok = r.byUser[user]
|
||||||
}
|
assert.False(t, ok, "expected user to be removed from byUser map")
|
||||||
if _, ok := r.byUser[user]; ok {
|
|
||||||
t.Errorf("expected user to be removed from byUser map")
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -144,19 +144,6 @@ func (m *MockLifecycle) Close() error {
|
|||||||
return args.Error(0)
|
return args.Error(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockSSHConn struct {
|
|
||||||
ssh.Conn
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
|
||||||
args := m.Called(name, data)
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
||||||
}
|
|
||||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockSSHChannel struct {
|
type MockSSHChannel struct {
|
||||||
ssh.Channel
|
ssh.Channel
|
||||||
mock.Mock
|
mock.Mock
|
||||||
|
|||||||
@@ -73,11 +73,7 @@ func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (
|
|||||||
case resultChan <- channelResult{channel, reqs, err}:
|
case resultChan <- channelResult{channel, reqs, err}:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
if channel != nil {
|
if channel != nil {
|
||||||
err = channel.Close()
|
_ = channel.Close()
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to close unused channel: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
go ssh.DiscardRequests(reqs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -116,10 +112,7 @@ func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string)
|
|||||||
|
|
||||||
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||||
defer func() {
|
defer func() {
|
||||||
_, err := io.Copy(io.Discard, src)
|
_, _ = io.Copy(io.Discard, src)
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to discard connection: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package lifecycle
|
package lifecycle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@@ -65,10 +66,10 @@ func (m *MockForwarder) Listener() net.Listener {
|
|||||||
return args.Get(0).(net.Listener)
|
return args.Get(0).(net.Listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
args := m.Called(payload)
|
args := m.Called(ctx, origin)
|
||||||
if args.Get(0) == nil {
|
if args.Get(0) == nil {
|
||||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
return nil, nil, args.Error(2)
|
||||||
}
|
}
|
||||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||||
}
|
}
|
||||||
@@ -208,7 +209,8 @@ func TestLifecycle_SetStatus(t *testing.T) {
|
|||||||
|
|
||||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||||
|
|
||||||
assert.NotNil(t, mockLifecycle.StartedAt())
|
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||||
|
assert.True(t, mockLifecycle.IsActive())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLifecycle_IsActive(t *testing.T) {
|
func TestLifecycle_IsActive(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user