fix(deps): update module github.com/caddyserver/certmagic to v0.25.1 - autoclosed #63
+640
-51
@@ -10,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -18,11 +19,11 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type mockRandom struct {
|
||||
type MockRandom struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockRandom) String(length int) (string, error) {
|
||||
func (m *MockRandom) String(length int) (string, error) {
|
||||
args := m.Called(length)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
@@ -46,17 +47,30 @@ func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0
|
||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
|
||||
func (m *MockConfig) Mode() types.ServerMode {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return 0
|
||||
}
|
||||
switch v := args.Get(0).(type) {
|
||||
case types.ServerMode:
|
||||
return v
|
||||
case int:
|
||||
return types.ServerMode(v)
|
||||
default:
|
||||
return types.ServerMode(args.Int(0))
|
||||
}
|
||||
}
|
||||
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
||||
|
||||
type mockRegistry struct {
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Get(key types.SessionKey) (registry.Session, error) {
|
||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
@@ -64,7 +78,7 @@ func (m *mockRegistry) Get(key types.SessionKey) (registry.Session, error) {
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetWithUser(user string, key types.SessionKey) (registry.Session, error) {
|
||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
@@ -72,77 +86,85 @@ func (m *mockRegistry) GetWithUser(user string, key types.SessionKey) (registry.
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Register(key types.SessionKey, session registry.Session) bool {
|
||||
return m.Called(key, session).Bool(0)
|
||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||
args := m.Called(user, oldKey, newKey)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Update(user string, oldKey types.SessionKey, newKey types.SessionKey) error {
|
||||
return m.Called(user, oldKey, newKey).Error(0)
|
||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||
args := m.Called(key, session)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Remove(key types.SessionKey) {
|
||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
return m.Called(user).Get(0).([]registry.Session)
|
||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
args := m.Called(user)
|
||||
return args.Get(0).([]registry.Session)
|
||||
}
|
||||
|
||||
type mockGrpcClient struct {
|
||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
type MockGRPCClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockGrpcClient) SubscribeEvents(ctx context.Context, identity string, authToken string) error {
|
||||
return m.Called(ctx, identity, authToken).Error(0)
|
||||
}
|
||||
|
||||
func (m *mockGrpcClient) ClientConn() *grpc.ClientConn {
|
||||
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(*grpc.ClientConn)
|
||||
}
|
||||
|
||||
func (m *mockGrpcClient) AuthorizeConn(ctx context.Context, token string) (bool, string, error) {
|
||||
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
||||
args := m.Called(ctx, token)
|
||||
return args.Bool(0), args.String(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *mockGrpcClient) CheckServerHealth(ctx context.Context) error {
|
||||
return m.Called(ctx).Error(0)
|
||||
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockGrpcClient) Close() error {
|
||||
return m.Called().Error(0)
|
||||
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
|
||||
args := m.Called(ctx, domain, token)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type mockPort struct {
|
||||
func (m *MockGRPCClient) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockPort struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockPort) AddRange(startPort, endPort uint16) error {
|
||||
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
||||
return m.Called(startPort, endPort).Error(0)
|
||||
}
|
||||
|
||||
func (m *mockPort) Unassigned() (uint16, bool) {
|
||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
||||
args := m.Called()
|
||||
return uint16(args.Int(0)), args.Bool(1)
|
||||
}
|
||||
|
||||
func (m *mockPort) SetStatus(port uint16, assigned bool) error {
|
||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
||||
return m.Called(port, assigned).Error(0)
|
||||
}
|
||||
|
||||
func (m *mockPort) Claim(port uint16) bool {
|
||||
func (m *MockPort) Claim(port uint16) bool {
|
||||
return m.Called(port).Bool(0)
|
||||
}
|
||||
|
||||
type mockListener struct {
|
||||
type MockListener struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockListener) Accept() (net.Conn, error) {
|
||||
func (m *MockListener) Accept() (net.Conn, error) {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
@@ -150,14 +172,22 @@ func (m *mockListener) Accept() (net.Conn, error) {
|
||||
return args.Get(0).(net.Conn), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockListener) Close() error {
|
||||
func (m *MockListener) Close() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
func (m *mockListener) Addr() net.Addr {
|
||||
func (m *MockListener) Addr() net.Addr {
|
||||
return m.Called().Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSession) Start() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
|
||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
signer, _ := ssh.NewSignerFromKey(key)
|
||||
@@ -169,11 +199,11 @@ func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
mr := new(mockRandom)
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(mockRegistry)
|
||||
mg := new(mockGrpcClient)
|
||||
mp := new(mockPort)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
tests := []struct {
|
||||
@@ -222,27 +252,46 @@ func TestNew(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
mr := new(mockRandom)
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(mockRegistry)
|
||||
mg := new(mockGrpcClient)
|
||||
mp := new(mockPort)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
t.Run("successful close", func(t *testing.T) {
|
||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
err := s.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
err = s.Close()
|
||||
t.Run("close already closed listener", func(t *testing.T) {
|
||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
_ = s.Close()
|
||||
err := s.Close()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("close with nil listener", func(t *testing.T) {
|
||||
s := &server{
|
||||
sshListener: nil,
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
assert.NotNil(t, r)
|
||||
}
|
||||
}()
|
||||
_ = s.Close()
|
||||
t.Fatal("expected panic for nil listener")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
mr := new(mockRandom)
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(mockRegistry)
|
||||
mg := new(mockGrpcClient)
|
||||
mp := new(mockPort)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
t.Run("normal stop", func(t *testing.T) {
|
||||
@@ -254,8 +303,8 @@ func TestStart(t *testing.T) {
|
||||
s.Start()
|
||||
})
|
||||
|
||||
t.Run("accept error", func(t *testing.T) {
|
||||
ml := new(mockListener)
|
||||
t.Run("accept error - temporary error continues loop", func(t *testing.T) {
|
||||
ml := new(MockListener)
|
||||
s := &server{
|
||||
sshListener: ml,
|
||||
sshPort: "0",
|
||||
@@ -267,4 +316,544 @@ func TestStart(t *testing.T) {
|
||||
s.Start()
|
||||
ml.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("accept error - immediate close", func(t *testing.T) {
|
||||
ml := new(MockListener)
|
||||
s := &server{
|
||||
sshListener: ml,
|
||||
sshPort: "0",
|
||||
}
|
||||
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s.Start()
|
||||
ml.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("accept success - connection fails SSH handshake", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
mockListener := &MockListener{}
|
||||
mockListener.On("Accept").Return(serverConn, nil).Once()
|
||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshListener: mockListener,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
clientConn.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("accept success - valid SSH connection without auth", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
mockListener := &MockListener{}
|
||||
mockListener.On("Accept").Return(serverConn, nil).Once()
|
||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshListener: mockListener,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
clientConn.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleConnection(t *testing.T) {
|
||||
t.Run("SSH handshake fails - connection closed", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
clientConn.Close()
|
||||
s.handleConnection(serverConn)
|
||||
})
|
||||
|
||||
// SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS
|
||||
// GONNA IMPLEMENT THIS UNIT TEST LATER
|
||||
|
||||
//t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) {
|
||||
// mockRandom := &MockRandom{}
|
||||
// mockConfig := &MockConfig{}
|
||||
// mockSessionRegistry := &MockSessionRegistry{}
|
||||
// mockGrpcClient := &MockGRPCClient{}
|
||||
// mockPort := &MockPort{}
|
||||
//
|
||||
// sshConfig, _ := getTestSSHConfig()
|
||||
//
|
||||
// serverConn, clientConn := net.Pipe()
|
||||
//
|
||||
// s := &server{
|
||||
// randomizer: mockRandom,
|
||||
// config: mockConfig,
|
||||
// sshPort: "0",
|
||||
// sshConfig: sshConfig,
|
||||
// grpcClient: mockGrpcClient,
|
||||
// sessionRegistry: mockSessionRegistry,
|
||||
// portRegistry: mockPort,
|
||||
// }
|
||||
//
|
||||
// done := make(chan bool, 1)
|
||||
//
|
||||
// go func() {
|
||||
// s.handleConnection(serverConn)
|
||||
// done <- true
|
||||
// }()
|
||||
//
|
||||
// go func() {
|
||||
// clientConn.Write([]byte("invalid ssh protocol\n"))
|
||||
// clientConn.Close()
|
||||
// }()
|
||||
//
|
||||
// select {
|
||||
// case <-done:
|
||||
// case <-time.After(1 * time.Second):
|
||||
// t.Fatal("handleConnection did not complete in time")
|
||||
// }
|
||||
//})
|
||||
|
||||
t.Run("SSH connection established without gRPC client", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
mockConfig.On("Domain").Return("test.com")
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("SSHPort").Return("2200")
|
||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
serverAddr := listener.Addr().String()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleConnection(conn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
||||
if err != nil {
|
||||
t.Logf("Client dial failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
type forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
payload := ssh.Marshal(forwardPayload{
|
||||
BindAddr: "localhost",
|
||||
BindPort: 80,
|
||||
})
|
||||
|
||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
||||
if err != nil {
|
||||
t.Logf("Forward request failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("handleConnection completed")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SSH connection established with gRPC authorization", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
|
||||
mockConfig.On("Domain").Return("test.com")
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("SSHPort").Return("2200")
|
||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
serverAddr := listener.Addr().String()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleConnection(conn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
||||
if err != nil {
|
||||
t.Logf("Client dial failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
type forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
payload := ssh.Marshal(forwardPayload{
|
||||
BindAddr: "localhost",
|
||||
BindPort: 80,
|
||||
})
|
||||
|
||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
||||
if err != nil {
|
||||
t.Logf("Forward request failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
mockGrpcClient.AssertExpectations(t)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SSH connection with gRPC authorization error", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
|
||||
mockConfig.On("Domain").Return("test.com")
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("SSHPort").Return("2200")
|
||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
serverAddr := listener.Addr().String()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleConnection(conn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
||||
if err != nil {
|
||||
t.Logf("Client dial failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
type forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
payload := ssh.Marshal(forwardPayload{
|
||||
BindAddr: "localhost",
|
||||
BindPort: 8080,
|
||||
})
|
||||
|
||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
||||
if err != nil {
|
||||
t.Logf("Forward request failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
mockGrpcClient.AssertExpectations(t)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("connection cleanup on close", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
s.handleConnection(serverConn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
clientConn.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
t.Run("full server lifecycle", func(t *testing.T) {
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
s, err := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err := s.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
s.Start()
|
||||
})
|
||||
|
||||
t.Run("multiple connections", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
conn1Server, conn1Client := net.Pipe()
|
||||
conn2Server, conn2Client := net.Pipe()
|
||||
|
||||
mockListener := &MockListener{}
|
||||
mockListener.On("Accept").Return(conn1Server, nil).Once()
|
||||
mockListener.On("Accept").Return(conn2Server, nil).Once()
|
||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshListener: mockListener,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
conn1Client.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
conn2Client.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorHandling(t *testing.T) {
|
||||
t.Run("write error during SSH handshake", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
clientConn.Close()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
s.handleConnection(serverConn)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user