From 04c9ddbc136d8b072c27886cf7b699b4342d0e19 Mon Sep 17 00:00:00 2001 From: bagas Date: Sat, 24 Jan 2026 20:28:30 +0700 Subject: [PATCH] test(lifecycle): add unit tests for lifecycle behavior --- internal/registry/registry_test.go | 4 + internal/transport/httphandler_test.go | 5 + session/lifecycle/lifecycle.go | 10 +- session/lifecycle/lifecycle_test.go | 322 +++++++++++++++++++++++++ 4 files changed, 338 insertions(+), 3 deletions(-) create mode 100644 session/lifecycle/lifecycle_test.go diff --git a/internal/registry/registry_test.go b/internal/registry/registry_test.go index 9b3d026..ce3c14b 100644 --- a/internal/registry/registry_test.go +++ b/internal/registry/registry_test.go @@ -33,6 +33,10 @@ func (m *mockSession) Detail() *types.Detail { type mockLifecycle struct{ user string } +func (ml *mockLifecycle) Channel() ssh.Channel { + return nil +} + func (ml *mockLifecycle) Connection() ssh.Conn { return nil } func (ml *mockLifecycle) PortRegistry() port.Port { return nil } func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { _ = channel } diff --git a/internal/transport/httphandler_test.go b/internal/transport/httphandler_test.go index 3159366..984f9b1 100644 --- a/internal/transport/httphandler_test.go +++ b/internal/transport/httphandler_test.go @@ -100,6 +100,11 @@ type MockLifecycle struct { mock.Mock } +func (m *MockLifecycle) Channel() ssh.Channel { + args := m.Called() + return args.Get(0).(ssh.Channel) +} + func (m *MockLifecycle) Connection() ssh.Conn { args := m.Called() return args.Get(0).(ssh.Conn) diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 234bff8..2d9dd3e 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -54,6 +54,7 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti type Lifecycle interface { Connection() ssh.Conn + Channel() ssh.Channel PortRegistry() portUtil.Port User() string SetChannel(channel ssh.Channel) @@ -74,16 +75,19 @@ func (l *lifecycle) User() string { func (l *lifecycle) SetChannel(channel ssh.Channel) { l.channel = channel } + +func (l *lifecycle) Channel() ssh.Channel { + return l.channel +} + func (l *lifecycle) Connection() ssh.Conn { return l.conn } + func (l *lifecycle) SetStatus(status types.SessionStatus) { l.mu.Lock() defer l.mu.Unlock() l.status = status - if status == types.SessionStatusRUNNING && l.startedAt.IsZero() { - l.startedAt = time.Now() - } } func (l *lifecycle) IsActive() bool { diff --git a/session/lifecycle/lifecycle_test.go b/session/lifecycle/lifecycle_test.go new file mode 100644 index 0000000..5a8deb0 --- /dev/null +++ b/session/lifecycle/lifecycle_test.go @@ -0,0 +1,322 @@ +package lifecycle + +import ( + "errors" + "io" + "net" + "testing" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/ssh" +) + +type MockSessionRegistry struct { + mock.Mock +} + +func (m *MockSessionRegistry) Remove(key types.SessionKey) { + m.Called(key) +} + +type MockForwarder struct { + mock.Mock +} + +func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { + args := m.Called(origin) + return args.Get(0).([]byte) +} + +func (m *MockForwarder) WriteBadGatewayResponse(dst io.Writer) { + m.Called(dst) +} + +func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { + m.Called(dst, src) +} + +func (m *MockForwarder) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockForwarder) TunnelType() types.TunnelType { + args := m.Called() + return args.Get(0).(types.TunnelType) +} + +func (m *MockForwarder) ForwardedPort() uint16 { + args := m.Called() + return args.Get(0).(uint16) +} + +func (m *MockForwarder) SetType(tunnelType types.TunnelType) { + m.Called(tunnelType) +} + +func (m *MockForwarder) SetForwardedPort(port uint16) { + m.Called(port) +} + +func (m *MockForwarder) SetListener(listener net.Listener) { + m.Called(listener) +} + +func (m *MockForwarder) Listener() net.Listener { + args := m.Called() + return args.Get(0).(net.Listener) +} + +func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(payload) + if args.Get(0) == nil { + return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) + } + return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) +} + +type MockPort struct { + mock.Mock +} + +func (m *MockPort) AddRange(startPort, endPort uint16) error { + return m.Called(startPort, endPort).Error(0) +} +func (m *MockPort) Unassigned() (uint16, bool) { + args := m.Called() + var port uint16 + if args.Get(0) != nil { + switch v := args.Get(0).(type) { + case int: + port = uint16(v) + case uint16: + port = v + case uint32: + port = uint16(v) + case int32: + port = uint16(v) + case float64: + port = uint16(v) + default: + port = uint16(args.Int(0)) + } + } + return port, args.Bool(1) +} +func (m *MockPort) SetStatus(port uint16, assigned bool) error { + return m.Called(port, assigned).Error(0) +} +func (m *MockPort) Claim(port uint16) bool { + return m.Called(port).Bool(0) +} + +type MockSlug struct { + mock.Mock +} + +func (ms *MockSlug) Set(slug string) { + ms.Called(slug) +} +func (ms *MockSlug) String() string { + return ms.Called().String(0) +} + +type MockSSHConn struct { + ssh.Conn + mock.Mock +} + +func (m *MockSSHConn) Close() error { + args := m.Called() + return args.Error(0) +} + +type MockSSHChannel struct { + ssh.Channel + mock.Mock +} + +func (m *MockSSHChannel) Close() error { + return m.Called().Error(0) +} + +type mockNewChannel struct { + ssh.NewChannel + mock.Mock +} + +func (m *mockNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called() + return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) +} + +func (m *MockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(name, data) + if args.Get(0) == nil { + return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) + } + return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) +} + +func TestNew(t *testing.T) { + mockSSHConn := new(MockSSHConn) + mockForwarder := &MockForwarder{} + mockSlug := &MockSlug{} + mockPort := &MockPort{} + mockSessionRegistry := &MockSessionRegistry{} + + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") + + assert.NotNil(t, mockLifecycle.Connection()) + assert.NotNil(t, mockLifecycle.User()) + assert.NotNil(t, mockLifecycle.PortRegistry()) + assert.NotNil(t, mockLifecycle.StartedAt()) +} + +func TestLifecycle_User(t *testing.T) { + mockSSHConn := new(MockSSHConn) + mockForwarder := &MockForwarder{} + mockSlug := &MockSlug{} + mockPort := &MockPort{} + mockSessionRegistry := &MockSessionRegistry{} + + user := "mas-fuad" + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, user) + assert.Equal(t, user, mockLifecycle.User()) +} + +func TestLifecycle_SetChannel(t *testing.T) { + mockSSHConn := new(MockSSHConn) + mockForwarder := &MockForwarder{} + mockSlug := &MockSlug{} + mockPort := &MockPort{} + mockSessionRegistry := &MockSessionRegistry{} + + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") + + mockSSHChannel := &MockSSHChannel{} + + mockLifecycle.SetChannel(mockSSHChannel) + + assert.Equal(t, mockSSHChannel, mockLifecycle.Channel()) +} + +func TestLifecycle_SetStatus(t *testing.T) { + mockSSHConn := new(MockSSHConn) + mockForwarder := &MockForwarder{} + mockSlug := &MockSlug{} + mockPort := &MockPort{} + mockSessionRegistry := &MockSessionRegistry{} + + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") + + assert.NotNil(t, mockLifecycle.StartedAt()) +} + +func TestLifecycle_IsActive(t *testing.T) { + mockSSHConn := new(MockSSHConn) + mockForwarder := &MockForwarder{} + mockSlug := &MockSlug{} + mockPort := &MockPort{} + mockSessionRegistry := &MockSessionRegistry{} + + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") + + mockLifecycle.SetStatus(types.SessionStatusRUNNING) + assert.True(t, mockLifecycle.IsActive()) +} + +func TestLifecycle_Close(t *testing.T) { + tests := []struct { + name string + tunnelType types.TunnelType + connCloseErr error + channelCloseErr error + expectErr bool + alreadyClosed bool + }{ + { + name: "Close HTTP forwarding success", + tunnelType: types.TunnelTypeHTTP, + expectErr: false, + }, + { + name: "Close TCP forwarding success", + tunnelType: types.TunnelTypeTCP, + expectErr: false, + }, + { + name: "Close with conn close error", + tunnelType: types.TunnelTypeHTTP, + connCloseErr: errors.New("conn close error"), + expectErr: true, + }, + { + name: "Close with channel close error", + tunnelType: types.TunnelTypeHTTP, + channelCloseErr: errors.New("channel close error"), + expectErr: true, + }, + { + name: "Close when already closed", + tunnelType: types.TunnelTypeHTTP, + alreadyClosed: true, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSSHConn := &MockSSHConn{} + mockSSHConn.On("Close").Return(tt.connCloseErr) + + mockForwarder := &MockForwarder{} + mockForwarder.On("TunnelType").Return(tt.tunnelType) + if tt.tunnelType == types.TunnelTypeTCP { + mockForwarder.On("ForwardedPort").Return(uint16(8080)) + mockForwarder.On("Close").Return(nil) + } + + mockSlug := &MockSlug{} + mockSlug.On("String").Return("test-slug") + + mockPort := &MockPort{} + if tt.tunnelType == types.TunnelTypeTCP { + mockPort.On("SetStatus", uint16(8080), false).Return(nil) + } + + mockSessionRegistry := &MockSessionRegistry{} + mockSessionRegistry.On("Remove", mock.Anything).Return() + + mockSSHChannel := &MockSSHChannel{} + mockSSHChannel.On("Close").Return(tt.channelCloseErr) + + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") + + mockLifecycle.SetStatus(types.SessionStatusRUNNING) + mockLifecycle.SetChannel(mockSSHChannel) + + if tt.alreadyClosed { + mockLifecycle.Close() + } + + err := mockLifecycle.Close() + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.False(t, mockLifecycle.IsActive()) + + mockSSHConn.AssertExpectations(t) + mockForwarder.AssertExpectations(t) + mockSlug.AssertExpectations(t) + mockPort.AssertExpectations(t) + mockSessionRegistry.AssertExpectations(t) + mockSSHChannel.AssertExpectations(t) + }) + } +}