304 lines
7.2 KiB
Go
304 lines
7.2 KiB
Go
package lifecycle
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"testing"
|
|
"tunnel_pls/types"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type MockSessionRegistry struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockSessionRegistry) Remove(key types.SessionKey) {
|
|
m.Called(key)
|
|
}
|
|
|
|
type MockForwarder struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
|
args := m.Called(origin)
|
|
return args.Get(0).([]byte)
|
|
}
|
|
|
|
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
|
m.Called(dst, src)
|
|
}
|
|
|
|
func (m *MockForwarder) Close() error {
|
|
args := m.Called()
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockForwarder) TunnelType() types.TunnelType {
|
|
args := m.Called()
|
|
return args.Get(0).(types.TunnelType)
|
|
}
|
|
|
|
func (m *MockForwarder) ForwardedPort() uint16 {
|
|
args := m.Called()
|
|
return args.Get(0).(uint16)
|
|
}
|
|
|
|
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
|
|
m.Called(tunnelType)
|
|
}
|
|
|
|
func (m *MockForwarder) SetForwardedPort(port uint16) {
|
|
m.Called(port)
|
|
}
|
|
|
|
func (m *MockForwarder) SetListener(listener net.Listener) {
|
|
m.Called(listener)
|
|
}
|
|
|
|
func (m *MockForwarder) Listener() net.Listener {
|
|
args := m.Called()
|
|
return args.Get(0).(net.Listener)
|
|
}
|
|
|
|
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
|
args := m.Called(ctx, origin)
|
|
if args.Get(0) == nil {
|
|
return nil, nil, args.Error(2)
|
|
}
|
|
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
}
|
|
|
|
type MockPort struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
|
return m.Called(startPort, endPort).Error(0)
|
|
}
|
|
func (m *MockPort) Unassigned() (uint16, bool) {
|
|
args := m.Called()
|
|
var port uint16
|
|
if args.Get(0) != nil {
|
|
switch v := args.Get(0).(type) {
|
|
case int:
|
|
port = uint16(v)
|
|
case uint16:
|
|
port = v
|
|
case uint32:
|
|
port = uint16(v)
|
|
case int32:
|
|
port = uint16(v)
|
|
case float64:
|
|
port = uint16(v)
|
|
default:
|
|
port = uint16(args.Int(0))
|
|
}
|
|
}
|
|
return port, args.Bool(1)
|
|
}
|
|
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
|
return m.Called(port, assigned).Error(0)
|
|
}
|
|
func (m *MockPort) Claim(port uint16) bool {
|
|
return m.Called(port).Bool(0)
|
|
}
|
|
|
|
type MockSlug struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (ms *MockSlug) Set(slug string) {
|
|
ms.Called(slug)
|
|
}
|
|
func (ms *MockSlug) String() string {
|
|
return ms.Called().String(0)
|
|
}
|
|
|
|
type MockSSHConn struct {
|
|
ssh.Conn
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockSSHConn) Close() error {
|
|
args := m.Called()
|
|
return args.Error(0)
|
|
}
|
|
|
|
type MockSSHChannel struct {
|
|
ssh.Channel
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockSSHChannel) Close() error {
|
|
return m.Called().Error(0)
|
|
}
|
|
|
|
func TestNew(t *testing.T) {
|
|
mockSSHConn := new(MockSSHConn)
|
|
mockForwarder := &MockForwarder{}
|
|
mockSlug := &MockSlug{}
|
|
mockPort := &MockPort{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
|
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
|
|
|
assert.NotNil(t, mockLifecycle.Connection())
|
|
assert.NotNil(t, mockLifecycle.User())
|
|
assert.NotNil(t, mockLifecycle.PortRegistry())
|
|
assert.NotNil(t, mockLifecycle.StartedAt())
|
|
}
|
|
|
|
func TestLifecycle_User(t *testing.T) {
|
|
mockSSHConn := new(MockSSHConn)
|
|
mockForwarder := &MockForwarder{}
|
|
mockSlug := &MockSlug{}
|
|
mockPort := &MockPort{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
|
|
user := "mas-fuad"
|
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, user)
|
|
assert.Equal(t, user, mockLifecycle.User())
|
|
}
|
|
|
|
func TestLifecycle_SetChannel(t *testing.T) {
|
|
mockSSHConn := new(MockSSHConn)
|
|
mockForwarder := &MockForwarder{}
|
|
mockSlug := &MockSlug{}
|
|
mockPort := &MockPort{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
|
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
|
|
|
mockSSHChannel := &MockSSHChannel{}
|
|
|
|
mockLifecycle.SetChannel(mockSSHChannel)
|
|
|
|
assert.Equal(t, mockSSHChannel, mockLifecycle.Channel())
|
|
}
|
|
|
|
func TestLifecycle_SetStatus(t *testing.T) {
|
|
mockSSHConn := new(MockSSHConn)
|
|
mockForwarder := &MockForwarder{}
|
|
mockSlug := &MockSlug{}
|
|
mockPort := &MockPort{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
|
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
|
|
|
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
|
assert.True(t, mockLifecycle.IsActive())
|
|
}
|
|
|
|
func TestLifecycle_IsActive(t *testing.T) {
|
|
mockSSHConn := new(MockSSHConn)
|
|
mockForwarder := &MockForwarder{}
|
|
mockSlug := &MockSlug{}
|
|
mockPort := &MockPort{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
|
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
|
|
|
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
|
assert.True(t, mockLifecycle.IsActive())
|
|
}
|
|
|
|
func TestLifecycle_Close(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tunnelType types.TunnelType
|
|
connCloseErr error
|
|
channelCloseErr error
|
|
expectErr bool
|
|
alreadyClosed bool
|
|
}{
|
|
{
|
|
name: "Close HTTP forwarding success",
|
|
tunnelType: types.TunnelTypeHTTP,
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "Close TCP forwarding success",
|
|
tunnelType: types.TunnelTypeTCP,
|
|
expectErr: false,
|
|
},
|
|
{
|
|
name: "Close with conn close error",
|
|
tunnelType: types.TunnelTypeHTTP,
|
|
connCloseErr: errors.New("conn close error"),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "Close with channel close error",
|
|
tunnelType: types.TunnelTypeHTTP,
|
|
channelCloseErr: errors.New("channel close error"),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "Close when already closed",
|
|
tunnelType: types.TunnelTypeHTTP,
|
|
alreadyClosed: true,
|
|
expectErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
mockSSHConn := &MockSSHConn{}
|
|
mockSSHConn.On("Close").Return(tt.connCloseErr)
|
|
|
|
mockForwarder := &MockForwarder{}
|
|
mockForwarder.On("TunnelType").Return(tt.tunnelType)
|
|
if tt.tunnelType == types.TunnelTypeTCP {
|
|
mockForwarder.On("ForwardedPort").Return(uint16(8080))
|
|
mockForwarder.On("Close").Return(nil)
|
|
}
|
|
|
|
mockSlug := &MockSlug{}
|
|
mockSlug.On("String").Return("test-slug")
|
|
|
|
mockPort := &MockPort{}
|
|
if tt.tunnelType == types.TunnelTypeTCP {
|
|
mockPort.On("SetStatus", uint16(8080), false).Return(nil)
|
|
}
|
|
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
mockSessionRegistry.On("Remove", mock.Anything).Return()
|
|
|
|
mockSSHChannel := &MockSSHChannel{}
|
|
mockSSHChannel.On("Close").Return(tt.channelCloseErr)
|
|
|
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
|
|
|
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
|
mockLifecycle.SetChannel(mockSSHChannel)
|
|
|
|
if tt.alreadyClosed {
|
|
err := mockLifecycle.Close()
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
err := mockLifecycle.Close()
|
|
|
|
if tt.expectErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
assert.False(t, mockLifecycle.IsActive())
|
|
|
|
mockSSHConn.AssertExpectations(t)
|
|
mockForwarder.AssertExpectations(t)
|
|
mockSlug.AssertExpectations(t)
|
|
mockPort.AssertExpectations(t)
|
|
mockSessionRegistry.AssertExpectations(t)
|
|
mockSSHChannel.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|