diff --git a/server/server_test.go b/server/server_test.go index c35073a..572fa1f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" "tunnel_pls/internal/registry" + "tunnel_pls/session/slug" "tunnel_pls/types" "github.com/stretchr/testify/assert" @@ -18,11 +19,11 @@ import ( "google.golang.org/grpc" ) -type mockRandom struct { +type MockRandom struct { mock.Mock } -func (m *mockRandom) String(length int) (string, error) { +func (m *MockRandom) String(length int) (string, error) { args := m.Called(length) return args.String(0), args.Error(1) } @@ -46,17 +47,30 @@ func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0 func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) } -func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } -func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } -func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } -func (m *MockConfig) NodeToken() string { return m.Called().String(0) } -func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } +func (m *MockConfig) Mode() types.ServerMode { + args := m.Called() + if args.Get(0) == nil { + return 0 + } + switch v := args.Get(0).(type) { + case types.ServerMode: + return v + case int: + return types.ServerMode(v) + default: + return types.ServerMode(args.Int(0)) + } +} +func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *MockConfig) NodeToken() string { return m.Called().String(0) } +func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } -type mockRegistry struct { +type MockSessionRegistry struct { mock.Mock } -func (m *mockRegistry) Get(key types.SessionKey) (registry.Session, error) { +func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) { args := m.Called(key) if args.Get(0) == nil { return nil, args.Error(1) @@ -64,7 +78,7 @@ func (m *mockRegistry) Get(key types.SessionKey) (registry.Session, error) { return args.Get(0).(registry.Session), args.Error(1) } -func (m *mockRegistry) GetWithUser(user string, key types.SessionKey) (registry.Session, error) { +func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) { args := m.Called(user, key) if args.Get(0) == nil { return nil, args.Error(1) @@ -72,77 +86,85 @@ func (m *mockRegistry) GetWithUser(user string, key types.SessionKey) (registry. return args.Get(0).(registry.Session), args.Error(1) } -func (m *mockRegistry) Register(key types.SessionKey, session registry.Session) bool { - return m.Called(key, session).Bool(0) +func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error { + args := m.Called(user, oldKey, newKey) + return args.Error(0) } -func (m *mockRegistry) Update(user string, oldKey types.SessionKey, newKey types.SessionKey) error { - return m.Called(user, oldKey, newKey).Error(0) +func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool { + args := m.Called(key, session) + return args.Bool(0) } -func (m *mockRegistry) Remove(key types.SessionKey) { +func (m *MockSessionRegistry) Remove(key registry.Key) { m.Called(key) } -func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session { - return m.Called(user).Get(0).([]registry.Session) +func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session { + args := m.Called(user) + return args.Get(0).([]registry.Session) } -type mockGrpcClient struct { +func (m *MockSessionRegistry) Slug() slug.Slug { + args := m.Called() + return args.Get(0).(slug.Slug) +} + +type MockGRPCClient struct { mock.Mock } -func (m *mockGrpcClient) SubscribeEvents(ctx context.Context, identity string, authToken string) error { - return m.Called(ctx, identity, authToken).Error(0) -} - -func (m *mockGrpcClient) ClientConn() *grpc.ClientConn { +func (m *MockGRPCClient) ClientConn() *grpc.ClientConn { args := m.Called() - if args.Get(0) == nil { - return nil - } return args.Get(0).(*grpc.ClientConn) } -func (m *mockGrpcClient) AuthorizeConn(ctx context.Context, token string) (bool, string, error) { +func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) { args := m.Called(ctx, token) return args.Bool(0), args.String(1), args.Error(2) } -func (m *mockGrpcClient) CheckServerHealth(ctx context.Context) error { - return m.Called(ctx).Error(0) +func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) } -func (m *mockGrpcClient) Close() error { - return m.Called().Error(0) +func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error { + args := m.Called(ctx, domain, token) + return args.Error(0) } -type mockPort struct { +func (m *MockGRPCClient) Close() error { + args := m.Called() + return args.Error(0) +} + +type MockPort struct { mock.Mock } -func (m *mockPort) AddRange(startPort, endPort uint16) error { +func (m *MockPort) AddRange(startPort, endPort uint16) error { return m.Called(startPort, endPort).Error(0) } -func (m *mockPort) Unassigned() (uint16, bool) { +func (m *MockPort) Unassigned() (uint16, bool) { args := m.Called() return uint16(args.Int(0)), args.Bool(1) } -func (m *mockPort) SetStatus(port uint16, assigned bool) error { +func (m *MockPort) SetStatus(port uint16, assigned bool) error { return m.Called(port, assigned).Error(0) } -func (m *mockPort) Claim(port uint16) bool { +func (m *MockPort) Claim(port uint16) bool { return m.Called(port).Bool(0) } -type mockListener struct { +type MockListener struct { mock.Mock } -func (m *mockListener) Accept() (net.Conn, error) { +func (m *MockListener) Accept() (net.Conn, error) { args := m.Called() if args.Get(0) == nil { return nil, args.Error(1) @@ -150,14 +172,22 @@ func (m *mockListener) Accept() (net.Conn, error) { return args.Get(0).(net.Conn), args.Error(1) } -func (m *mockListener) Close() error { +func (m *MockListener) Close() error { return m.Called().Error(0) } -func (m *mockListener) Addr() net.Addr { +func (m *MockListener) Addr() net.Addr { return m.Called().Get(0).(net.Addr) } +type MockSession struct { + mock.Mock +} + +func (m *MockSession) Start() error { + return m.Called().Error(0) +} + func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) { key, _ := rsa.GenerateKey(rand.Reader, 2048) signer, _ := ssh.NewSignerFromKey(key) @@ -169,11 +199,11 @@ func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) { } func TestNew(t *testing.T) { - mr := new(mockRandom) + mr := new(MockRandom) mc := new(MockConfig) - mreg := new(mockRegistry) - mg := new(mockGrpcClient) - mp := new(mockPort) + mreg := new(MockSessionRegistry) + mg := new(MockGRPCClient) + mp := new(MockPort) sc, _ := getTestSSHConfig() tests := []struct { @@ -222,27 +252,46 @@ func TestNew(t *testing.T) { } func TestClose(t *testing.T) { - mr := new(mockRandom) + mr := new(MockRandom) mc := new(MockConfig) - mreg := new(mockRegistry) - mg := new(mockGrpcClient) - mp := new(mockPort) + mreg := new(MockSessionRegistry) + mg := new(MockGRPCClient) + mp := new(MockPort) sc, _ := getTestSSHConfig() - s, _ := New(mr, mc, sc, mreg, mg, mp, "0") - err := s.Close() - assert.NoError(t, err) + t.Run("successful close", func(t *testing.T) { + s, _ := New(mr, mc, sc, mreg, mg, mp, "0") + err := s.Close() + assert.NoError(t, err) + }) - err = s.Close() - assert.Error(t, err) + t.Run("close already closed listener", func(t *testing.T) { + s, _ := New(mr, mc, sc, mreg, mg, mp, "0") + _ = s.Close() + err := s.Close() + assert.Error(t, err) + }) + + t.Run("close with nil listener", func(t *testing.T) { + s := &server{ + sshListener: nil, + } + defer func() { + if r := recover(); r != nil { + assert.NotNil(t, r) + } + }() + _ = s.Close() + t.Fatal("expected panic for nil listener") + }) } func TestStart(t *testing.T) { - mr := new(mockRandom) + mr := new(MockRandom) mc := new(MockConfig) - mreg := new(mockRegistry) - mg := new(mockGrpcClient) - mp := new(mockPort) + mreg := new(MockSessionRegistry) + mg := new(MockGRPCClient) + mp := new(MockPort) sc, _ := getTestSSHConfig() t.Run("normal stop", func(t *testing.T) { @@ -254,8 +303,8 @@ func TestStart(t *testing.T) { s.Start() }) - t.Run("accept error", func(t *testing.T) { - ml := new(mockListener) + t.Run("accept error - temporary error continues loop", func(t *testing.T) { + ml := new(MockListener) s := &server{ sshListener: ml, sshPort: "0", @@ -267,4 +316,544 @@ func TestStart(t *testing.T) { s.Start() ml.AssertExpectations(t) }) + + t.Run("accept error - immediate close", func(t *testing.T) { + ml := new(MockListener) + s := &server{ + sshListener: ml, + sshPort: "0", + } + + ml.On("Accept").Return(nil, net.ErrClosed).Once() + + s.Start() + ml.AssertExpectations(t) + }) + + t.Run("accept success - connection fails SSH handshake", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockGrpcClient := &MockGRPCClient{} + mockPort := &MockPort{} + + sshConfig, _ := getTestSSHConfig() + + serverConn, clientConn := net.Pipe() + + mockListener := &MockListener{} + mockListener.On("Accept").Return(serverConn, nil).Once() + mockListener.On("Accept").Return(nil, net.ErrClosed).Once() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshListener: mockListener, + sshConfig: sshConfig, + grpcClient: mockGrpcClient, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + go s.Start() + + time.Sleep(50 * time.Millisecond) + clientConn.Close() + time.Sleep(100 * time.Millisecond) + + mockListener.AssertExpectations(t) + }) + + t.Run("accept success - valid SSH connection without auth", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + sshConfig, _ := getTestSSHConfig() + + serverConn, clientConn := net.Pipe() + + mockListener := &MockListener{} + mockListener.On("Accept").Return(serverConn, nil).Once() + mockListener.On("Accept").Return(nil, net.ErrClosed).Once() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshListener: mockListener, + sshConfig: sshConfig, + grpcClient: nil, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + go s.Start() + + time.Sleep(50 * time.Millisecond) + clientConn.Close() + time.Sleep(100 * time.Millisecond) + + mockListener.AssertExpectations(t) + }) +} + +func TestHandleConnection(t *testing.T) { + t.Run("SSH handshake fails - connection closed", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockGrpcClient := &MockGRPCClient{} + mockPort := &MockPort{} + + sshConfig, _ := getTestSSHConfig() + + serverConn, clientConn := net.Pipe() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: sshConfig, + grpcClient: mockGrpcClient, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + clientConn.Close() + s.handleConnection(serverConn) + }) + + // SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS + // GONNA IMPLEMENT THIS UNIT TEST LATER + + //t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) { + // mockRandom := &MockRandom{} + // mockConfig := &MockConfig{} + // mockSessionRegistry := &MockSessionRegistry{} + // mockGrpcClient := &MockGRPCClient{} + // mockPort := &MockPort{} + // + // sshConfig, _ := getTestSSHConfig() + // + // serverConn, clientConn := net.Pipe() + // + // s := &server{ + // randomizer: mockRandom, + // config: mockConfig, + // sshPort: "0", + // sshConfig: sshConfig, + // grpcClient: mockGrpcClient, + // sessionRegistry: mockSessionRegistry, + // portRegistry: mockPort, + // } + // + // done := make(chan bool, 1) + // + // go func() { + // s.handleConnection(serverConn) + // done <- true + // }() + // + // go func() { + // clientConn.Write([]byte("invalid ssh protocol\n")) + // clientConn.Close() + // }() + // + // select { + // case <-done: + // case <-time.After(1 * time.Second): + // t.Fatal("handleConnection did not complete in time") + // } + //}) + + t.Run("SSH connection established without gRPC client", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + serverConfig, _ := getTestSSHConfig() + + mockConfig.On("Domain").Return("test.com") + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("SSHPort").Return("2200") + mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil) + mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + mockSessionRegistry.On("Remove", mock.Anything).Return(nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + serverAddr := listener.Addr().String() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: serverConfig, + grpcClient: nil, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + done := make(chan bool, 1) + + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + s.handleConnection(conn) + done <- true + }() + + time.Sleep(50 * time.Millisecond) + + clientConfig := &ssh.ClientConfig{ + User: "testuser", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + go func() { + client, err := ssh.Dial("tcp", serverAddr, clientConfig) + if err != nil { + t.Logf("Client dial failed: %v", err) + return + } + defer client.Close() + + type forwardPayload struct { + BindAddr string + BindPort uint32 + } + + payload := ssh.Marshal(forwardPayload{ + BindAddr: "localhost", + BindPort: 80, + }) + + _, _, err = client.SendRequest("tcpip-forward", true, payload) + if err != nil { + t.Logf("Forward request failed: %v", err) + } + + time.Sleep(500 * time.Millisecond) + }() + + select { + case <-done: + t.Log("handleConnection completed") + case <-time.After(5 * time.Second): + t.Fatal("handleConnection did not complete in time") + } + }) + + t.Run("SSH connection established with gRPC authorization", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockGrpcClient := &MockGRPCClient{} + mockPort := &MockPort{} + + serverConfig, _ := getTestSSHConfig() + + mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil) + mockConfig.On("Domain").Return("test.com") + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("SSHPort").Return("2200") + mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil) + mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + mockSessionRegistry.On("Remove", mock.Anything).Return(nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + serverAddr := listener.Addr().String() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: serverConfig, + grpcClient: mockGrpcClient, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + done := make(chan bool, 1) + + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + s.handleConnection(conn) + done <- true + }() + + time.Sleep(50 * time.Millisecond) + + clientConfig := &ssh.ClientConfig{ + User: "testuser", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + go func() { + client, err := ssh.Dial("tcp", serverAddr, clientConfig) + if err != nil { + t.Logf("Client dial failed: %v", err) + return + } + defer client.Close() + + type forwardPayload struct { + BindAddr string + BindPort uint32 + } + + payload := ssh.Marshal(forwardPayload{ + BindAddr: "localhost", + BindPort: 80, + }) + + _, _, err = client.SendRequest("tcpip-forward", true, payload) + if err != nil { + t.Logf("Forward request failed: %v", err) + } + + time.Sleep(500 * time.Millisecond) + }() + + select { + case <-done: + mockGrpcClient.AssertExpectations(t) + case <-time.After(5 * time.Second): + t.Fatal("handleConnection did not complete in time") + } + }) + + t.Run("SSH connection with gRPC authorization error", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockGrpcClient := &MockGRPCClient{} + mockPort := &MockPort{} + + serverConfig, _ := getTestSSHConfig() + + mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil) + mockConfig.On("Domain").Return("test.com") + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("SSHPort").Return("2200") + mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil) + mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + mockSessionRegistry.On("Remove", mock.Anything).Return(nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + serverAddr := listener.Addr().String() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: serverConfig, + grpcClient: mockGrpcClient, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + done := make(chan bool, 1) + + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + s.handleConnection(conn) + done <- true + }() + + time.Sleep(50 * time.Millisecond) + + clientConfig := &ssh.ClientConfig{ + User: "testuser", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + go func() { + client, err := ssh.Dial("tcp", serverAddr, clientConfig) + if err != nil { + t.Logf("Client dial failed: %v", err) + return + } + defer client.Close() + + type forwardPayload struct { + BindAddr string + BindPort uint32 + } + + payload := ssh.Marshal(forwardPayload{ + BindAddr: "localhost", + BindPort: 8080, + }) + + _, _, err = client.SendRequest("tcpip-forward", true, payload) + if err != nil { + t.Logf("Forward request failed: %v", err) + } + + time.Sleep(500 * time.Millisecond) + }() + + select { + case <-done: + mockGrpcClient.AssertExpectations(t) + case <-time.After(5 * time.Second): + t.Fatal("handleConnection did not complete in time") + } + }) + + t.Run("connection cleanup on close", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + serverConfig, _ := getTestSSHConfig() + + serverConn, clientConn := net.Pipe() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: serverConfig, + grpcClient: nil, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + done := make(chan bool, 1) + + go func() { + s.handleConnection(serverConn) + done <- true + }() + + clientConn.Close() + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("handleConnection did not complete in time") + } + }) +} + +func TestIntegration(t *testing.T) { + t.Run("full server lifecycle", func(t *testing.T) { + mr := new(MockRandom) + mc := new(MockConfig) + mreg := new(MockSessionRegistry) + mg := new(MockGRPCClient) + mp := new(MockPort) + sc, _ := getTestSSHConfig() + + s, err := New(mr, mc, sc, mreg, mg, mp, "0") + assert.NoError(t, err) + assert.NotNil(t, s) + + go func() { + time.Sleep(100 * time.Millisecond) + err := s.Close() + assert.NoError(t, err) + }() + + s.Start() + }) + + t.Run("multiple connections", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + sshConfig, _ := getTestSSHConfig() + + conn1Server, conn1Client := net.Pipe() + conn2Server, conn2Client := net.Pipe() + + mockListener := &MockListener{} + mockListener.On("Accept").Return(conn1Server, nil).Once() + mockListener.On("Accept").Return(conn2Server, nil).Once() + mockListener.On("Accept").Return(nil, net.ErrClosed).Once() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshListener: mockListener, + sshConfig: sshConfig, + grpcClient: nil, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + go s.Start() + + time.Sleep(50 * time.Millisecond) + conn1Client.Close() + time.Sleep(50 * time.Millisecond) + conn2Client.Close() + time.Sleep(100 * time.Millisecond) + + mockListener.AssertExpectations(t) + }) +} + +func TestErrorHandling(t *testing.T) { + t.Run("write error during SSH handshake", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + sshConfig, _ := getTestSSHConfig() + + serverConn, clientConn := net.Pipe() + clientConn.Close() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: sshConfig, + grpcClient: nil, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + s.handleConnection(serverConn) + }) }