fix(deps): update module github.com/caddyserver/certmagic to v0.25.1 - autoclosed #63
@@ -33,6 +33,10 @@ func (m *mockSession) Detail() *types.Detail {
|
|||||||
|
|
||||||
type mockLifecycle struct{ user string }
|
type mockLifecycle struct{ user string }
|
||||||
|
|
||||||
|
func (ml *mockLifecycle) Channel() ssh.Channel {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (ml *mockLifecycle) Connection() ssh.Conn { return nil }
|
func (ml *mockLifecycle) Connection() ssh.Conn { return nil }
|
||||||
func (ml *mockLifecycle) PortRegistry() port.Port { return nil }
|
func (ml *mockLifecycle) PortRegistry() port.Port { return nil }
|
||||||
func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { _ = channel }
|
func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { _ = channel }
|
||||||
|
|||||||
@@ -100,6 +100,11 @@ type MockLifecycle struct {
|
|||||||
mock.Mock
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockLifecycle) Channel() ssh.Channel {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Get(0).(ssh.Channel)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *MockLifecycle) Connection() ssh.Conn {
|
func (m *MockLifecycle) Connection() ssh.Conn {
|
||||||
args := m.Called()
|
args := m.Called()
|
||||||
return args.Get(0).(ssh.Conn)
|
return args.Get(0).(ssh.Conn)
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti
|
|||||||
|
|
||||||
type Lifecycle interface {
|
type Lifecycle interface {
|
||||||
Connection() ssh.Conn
|
Connection() ssh.Conn
|
||||||
|
Channel() ssh.Channel
|
||||||
PortRegistry() portUtil.Port
|
PortRegistry() portUtil.Port
|
||||||
User() string
|
User() string
|
||||||
SetChannel(channel ssh.Channel)
|
SetChannel(channel ssh.Channel)
|
||||||
@@ -74,16 +75,19 @@ func (l *lifecycle) User() string {
|
|||||||
func (l *lifecycle) SetChannel(channel ssh.Channel) {
|
func (l *lifecycle) SetChannel(channel ssh.Channel) {
|
||||||
l.channel = channel
|
l.channel = channel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *lifecycle) Channel() ssh.Channel {
|
||||||
|
return l.channel
|
||||||
|
}
|
||||||
|
|
||||||
func (l *lifecycle) Connection() ssh.Conn {
|
func (l *lifecycle) Connection() ssh.Conn {
|
||||||
return l.conn
|
return l.conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *lifecycle) SetStatus(status types.SessionStatus) {
|
func (l *lifecycle) SetStatus(status types.SessionStatus) {
|
||||||
l.mu.Lock()
|
l.mu.Lock()
|
||||||
defer l.mu.Unlock()
|
defer l.mu.Unlock()
|
||||||
l.status = status
|
l.status = status
|
||||||
if status == types.SessionStatusRUNNING && l.startedAt.IsZero() {
|
|
||||||
l.startedAt = time.Now()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *lifecycle) IsActive() bool {
|
func (l *lifecycle) IsActive() bool {
|
||||||
|
|||||||
@@ -0,0 +1,322 @@
|
|||||||
|
package lifecycle
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"tunnel_pls/types"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockSessionRegistry struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSessionRegistry) Remove(key types.SessionKey) {
|
||||||
|
m.Called(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockForwarder struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||||
|
args := m.Called(origin)
|
||||||
|
return args.Get(0).([]byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||||
|
m.Called(dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||||
|
m.Called(dst, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) Close() error {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) TunnelType() types.TunnelType {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Get(0).(types.TunnelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) ForwardedPort() uint16 {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Get(0).(uint16)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
|
||||||
|
m.Called(tunnelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) SetForwardedPort(port uint16) {
|
||||||
|
m.Called(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) SetListener(listener net.Listener) {
|
||||||
|
m.Called(listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) Listener() net.Listener {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Get(0).(net.Listener)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
|
args := m.Called(payload)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||||
|
}
|
||||||
|
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockPort struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
||||||
|
return m.Called(startPort, endPort).Error(0)
|
||||||
|
}
|
||||||
|
func (m *MockPort) Unassigned() (uint16, bool) {
|
||||||
|
args := m.Called()
|
||||||
|
var port uint16
|
||||||
|
if args.Get(0) != nil {
|
||||||
|
switch v := args.Get(0).(type) {
|
||||||
|
case int:
|
||||||
|
port = uint16(v)
|
||||||
|
case uint16:
|
||||||
|
port = v
|
||||||
|
case uint32:
|
||||||
|
port = uint16(v)
|
||||||
|
case int32:
|
||||||
|
port = uint16(v)
|
||||||
|
case float64:
|
||||||
|
port = uint16(v)
|
||||||
|
default:
|
||||||
|
port = uint16(args.Int(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return port, args.Bool(1)
|
||||||
|
}
|
||||||
|
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
||||||
|
return m.Called(port, assigned).Error(0)
|
||||||
|
}
|
||||||
|
func (m *MockPort) Claim(port uint16) bool {
|
||||||
|
return m.Called(port).Bool(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockSlug struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ms *MockSlug) Set(slug string) {
|
||||||
|
ms.Called(slug)
|
||||||
|
}
|
||||||
|
func (ms *MockSlug) String() string {
|
||||||
|
return ms.Called().String(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockSSHConn struct {
|
||||||
|
ssh.Conn
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSSHConn) Close() error {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockSSHChannel struct {
|
||||||
|
ssh.Channel
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSSHChannel) Close() error {
|
||||||
|
return m.Called().Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockNewChannel struct {
|
||||||
|
ssh.NewChannel
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||||
|
args := m.Called(name, data)
|
||||||
|
if args.Get(0) == nil {
|
||||||
|
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||||
|
}
|
||||||
|
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNew(t *testing.T) {
|
||||||
|
mockSSHConn := new(MockSSHConn)
|
||||||
|
mockForwarder := &MockForwarder{}
|
||||||
|
mockSlug := &MockSlug{}
|
||||||
|
mockPort := &MockPort{}
|
||||||
|
mockSessionRegistry := &MockSessionRegistry{}
|
||||||
|
|
||||||
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||||
|
|
||||||
|
assert.NotNil(t, mockLifecycle.Connection())
|
||||||
|
assert.NotNil(t, mockLifecycle.User())
|
||||||
|
assert.NotNil(t, mockLifecycle.PortRegistry())
|
||||||
|
assert.NotNil(t, mockLifecycle.StartedAt())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLifecycle_User(t *testing.T) {
|
||||||
|
mockSSHConn := new(MockSSHConn)
|
||||||
|
mockForwarder := &MockForwarder{}
|
||||||
|
mockSlug := &MockSlug{}
|
||||||
|
mockPort := &MockPort{}
|
||||||
|
mockSessionRegistry := &MockSessionRegistry{}
|
||||||
|
|
||||||
|
user := "mas-fuad"
|
||||||
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, user)
|
||||||
|
assert.Equal(t, user, mockLifecycle.User())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLifecycle_SetChannel(t *testing.T) {
|
||||||
|
mockSSHConn := new(MockSSHConn)
|
||||||
|
mockForwarder := &MockForwarder{}
|
||||||
|
mockSlug := &MockSlug{}
|
||||||
|
mockPort := &MockPort{}
|
||||||
|
mockSessionRegistry := &MockSessionRegistry{}
|
||||||
|
|
||||||
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||||
|
|
||||||
|
mockSSHChannel := &MockSSHChannel{}
|
||||||
|
|
||||||
|
mockLifecycle.SetChannel(mockSSHChannel)
|
||||||
|
|
||||||
|
assert.Equal(t, mockSSHChannel, mockLifecycle.Channel())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLifecycle_SetStatus(t *testing.T) {
|
||||||
|
mockSSHConn := new(MockSSHConn)
|
||||||
|
mockForwarder := &MockForwarder{}
|
||||||
|
mockSlug := &MockSlug{}
|
||||||
|
mockPort := &MockPort{}
|
||||||
|
mockSessionRegistry := &MockSessionRegistry{}
|
||||||
|
|
||||||
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||||
|
|
||||||
|
assert.NotNil(t, mockLifecycle.StartedAt())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLifecycle_IsActive(t *testing.T) {
|
||||||
|
mockSSHConn := new(MockSSHConn)
|
||||||
|
mockForwarder := &MockForwarder{}
|
||||||
|
mockSlug := &MockSlug{}
|
||||||
|
mockPort := &MockPort{}
|
||||||
|
mockSessionRegistry := &MockSessionRegistry{}
|
||||||
|
|
||||||
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||||
|
|
||||||
|
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||||
|
assert.True(t, mockLifecycle.IsActive())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLifecycle_Close(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tunnelType types.TunnelType
|
||||||
|
connCloseErr error
|
||||||
|
channelCloseErr error
|
||||||
|
expectErr bool
|
||||||
|
alreadyClosed bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Close HTTP forwarding success",
|
||||||
|
tunnelType: types.TunnelTypeHTTP,
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Close TCP forwarding success",
|
||||||
|
tunnelType: types.TunnelTypeTCP,
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Close with conn close error",
|
||||||
|
tunnelType: types.TunnelTypeHTTP,
|
||||||
|
connCloseErr: errors.New("conn close error"),
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Close with channel close error",
|
||||||
|
tunnelType: types.TunnelTypeHTTP,
|
||||||
|
channelCloseErr: errors.New("channel close error"),
|
||||||
|
expectErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Close when already closed",
|
||||||
|
tunnelType: types.TunnelTypeHTTP,
|
||||||
|
alreadyClosed: true,
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mockSSHConn := &MockSSHConn{}
|
||||||
|
mockSSHConn.On("Close").Return(tt.connCloseErr)
|
||||||
|
|
||||||
|
mockForwarder := &MockForwarder{}
|
||||||
|
mockForwarder.On("TunnelType").Return(tt.tunnelType)
|
||||||
|
if tt.tunnelType == types.TunnelTypeTCP {
|
||||||
|
mockForwarder.On("ForwardedPort").Return(uint16(8080))
|
||||||
|
mockForwarder.On("Close").Return(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSlug := &MockSlug{}
|
||||||
|
mockSlug.On("String").Return("test-slug")
|
||||||
|
|
||||||
|
mockPort := &MockPort{}
|
||||||
|
if tt.tunnelType == types.TunnelTypeTCP {
|
||||||
|
mockPort.On("SetStatus", uint16(8080), false).Return(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSessionRegistry := &MockSessionRegistry{}
|
||||||
|
mockSessionRegistry.On("Remove", mock.Anything).Return()
|
||||||
|
|
||||||
|
mockSSHChannel := &MockSSHChannel{}
|
||||||
|
mockSSHChannel.On("Close").Return(tt.channelCloseErr)
|
||||||
|
|
||||||
|
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||||
|
|
||||||
|
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||||
|
mockLifecycle.SetChannel(mockSSHChannel)
|
||||||
|
|
||||||
|
if tt.alreadyClosed {
|
||||||
|
mockLifecycle.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := mockLifecycle.Close()
|
||||||
|
|
||||||
|
if tt.expectErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.False(t, mockLifecycle.IsActive())
|
||||||
|
|
||||||
|
mockSSHConn.AssertExpectations(t)
|
||||||
|
mockForwarder.AssertExpectations(t)
|
||||||
|
mockSlug.AssertExpectations(t)
|
||||||
|
mockPort.AssertExpectations(t)
|
||||||
|
mockSessionRegistry.AssertExpectations(t)
|
||||||
|
mockSSHChannel.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user