feat(testing): add comprehensive test coverage and code quality improvements #76

Merged
bagas merged 47 commits from feat/testing into staging 2026-01-27 16:36:40 +07:00
Showing only changes of commit 8fee8bf92e - Show all commits
+640 -51
View File
@@ -10,6 +10,7 @@ import (
"testing"
"time"
"tunnel_pls/internal/registry"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
@@ -18,11 +19,11 @@ import (
"google.golang.org/grpc"
)
type mockRandom struct {
type MockRandom struct {
mock.Mock
}
func (m *mockRandom) String(length int) (string, error) {
func (m *MockRandom) String(length int) (string, error) {
args := m.Called(length)
return args.String(0), args.Error(1)
}
@@ -46,17 +47,30 @@ func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
func (m *MockConfig) Mode() types.ServerMode {
args := m.Called()
if args.Get(0) == nil {
return 0
}
switch v := args.Get(0).(type) {
case types.ServerMode:
return v
case int:
return types.ServerMode(v)
default:
return types.ServerMode(args.Int(0))
}
}
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type mockRegistry struct {
type MockSessionRegistry struct {
mock.Mock
}
func (m *mockRegistry) Get(key types.SessionKey) (registry.Session, error) {
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
@@ -64,7 +78,7 @@ func (m *mockRegistry) Get(key types.SessionKey) (registry.Session, error) {
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *mockRegistry) GetWithUser(user string, key types.SessionKey) (registry.Session, error) {
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
@@ -72,77 +86,85 @@ func (m *mockRegistry) GetWithUser(user string, key types.SessionKey) (registry.
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *mockRegistry) Register(key types.SessionKey, session registry.Session) bool {
return m.Called(key, session).Bool(0)
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
args := m.Called(user, oldKey, newKey)
return args.Error(0)
}
func (m *mockRegistry) Update(user string, oldKey types.SessionKey, newKey types.SessionKey) error {
return m.Called(user, oldKey, newKey).Error(0)
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
args := m.Called(key, session)
return args.Bool(0)
}
func (m *mockRegistry) Remove(key types.SessionKey) {
func (m *MockSessionRegistry) Remove(key registry.Key) {
m.Called(key)
}
func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session {
return m.Called(user).Get(0).([]registry.Session)
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
args := m.Called(user)
return args.Get(0).([]registry.Session)
}
type mockGrpcClient struct {
func (m *MockSessionRegistry) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
type MockGRPCClient struct {
mock.Mock
}
func (m *mockGrpcClient) SubscribeEvents(ctx context.Context, identity string, authToken string) error {
return m.Called(ctx, identity, authToken).Error(0)
}
func (m *mockGrpcClient) ClientConn() *grpc.ClientConn {
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(*grpc.ClientConn)
}
func (m *mockGrpcClient) AuthorizeConn(ctx context.Context, token string) (bool, string, error) {
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
args := m.Called(ctx, token)
return args.Bool(0), args.String(1), args.Error(2)
}
func (m *mockGrpcClient) CheckServerHealth(ctx context.Context) error {
return m.Called(ctx).Error(0)
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *mockGrpcClient) Close() error {
return m.Called().Error(0)
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
args := m.Called(ctx, domain, token)
return args.Error(0)
}
type mockPort struct {
func (m *MockGRPCClient) Close() error {
args := m.Called()
return args.Error(0)
}
type MockPort struct {
mock.Mock
}
func (m *mockPort) AddRange(startPort, endPort uint16) error {
func (m *MockPort) AddRange(startPort, endPort uint16) error {
return m.Called(startPort, endPort).Error(0)
}
func (m *mockPort) Unassigned() (uint16, bool) {
func (m *MockPort) Unassigned() (uint16, bool) {
args := m.Called()
return uint16(args.Int(0)), args.Bool(1)
}
func (m *mockPort) SetStatus(port uint16, assigned bool) error {
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
}
func (m *mockPort) Claim(port uint16) bool {
func (m *MockPort) Claim(port uint16) bool {
return m.Called(port).Bool(0)
}
type mockListener struct {
type MockListener struct {
mock.Mock
}
func (m *mockListener) Accept() (net.Conn, error) {
func (m *MockListener) Accept() (net.Conn, error) {
args := m.Called()
if args.Get(0) == nil {
return nil, args.Error(1)
@@ -150,14 +172,22 @@ func (m *mockListener) Accept() (net.Conn, error) {
return args.Get(0).(net.Conn), args.Error(1)
}
func (m *mockListener) Close() error {
func (m *MockListener) Close() error {
return m.Called().Error(0)
}
func (m *mockListener) Addr() net.Addr {
func (m *MockListener) Addr() net.Addr {
return m.Called().Get(0).(net.Addr)
}
type MockSession struct {
mock.Mock
}
func (m *MockSession) Start() error {
return m.Called().Error(0)
}
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
key, _ := rsa.GenerateKey(rand.Reader, 2048)
signer, _ := ssh.NewSignerFromKey(key)
@@ -169,11 +199,11 @@ func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
}
func TestNew(t *testing.T) {
mr := new(mockRandom)
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(mockRegistry)
mg := new(mockGrpcClient)
mp := new(mockPort)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
tests := []struct {
@@ -222,27 +252,46 @@ func TestNew(t *testing.T) {
}
func TestClose(t *testing.T) {
mr := new(mockRandom)
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(mockRegistry)
mg := new(mockGrpcClient)
mp := new(mockPort)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
t.Run("successful close", func(t *testing.T) {
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
err := s.Close()
assert.NoError(t, err)
})
err = s.Close()
t.Run("close already closed listener", func(t *testing.T) {
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
_ = s.Close()
err := s.Close()
assert.Error(t, err)
})
t.Run("close with nil listener", func(t *testing.T) {
s := &server{
sshListener: nil,
}
defer func() {
if r := recover(); r != nil {
assert.NotNil(t, r)
}
}()
_ = s.Close()
t.Fatal("expected panic for nil listener")
})
}
func TestStart(t *testing.T) {
mr := new(mockRandom)
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(mockRegistry)
mg := new(mockGrpcClient)
mp := new(mockPort)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
t.Run("normal stop", func(t *testing.T) {
@@ -254,8 +303,8 @@ func TestStart(t *testing.T) {
s.Start()
})
t.Run("accept error", func(t *testing.T) {
ml := new(mockListener)
t.Run("accept error - temporary error continues loop", func(t *testing.T) {
ml := new(MockListener)
s := &server{
sshListener: ml,
sshPort: "0",
@@ -267,4 +316,544 @@ func TestStart(t *testing.T) {
s.Start()
ml.AssertExpectations(t)
})
t.Run("accept error - immediate close", func(t *testing.T) {
ml := new(MockListener)
s := &server{
sshListener: ml,
sshPort: "0",
}
ml.On("Accept").Return(nil, net.ErrClosed).Once()
s.Start()
ml.AssertExpectations(t)
})
t.Run("accept success - connection fails SSH handshake", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
mockListener := &MockListener{}
mockListener.On("Accept").Return(serverConn, nil).Once()
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshListener: mockListener,
sshConfig: sshConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
go s.Start()
time.Sleep(50 * time.Millisecond)
clientConn.Close()
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
})
t.Run("accept success - valid SSH connection without auth", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
mockListener := &MockListener{}
mockListener.On("Accept").Return(serverConn, nil).Once()
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshListener: mockListener,
sshConfig: sshConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
go s.Start()
time.Sleep(50 * time.Millisecond)
clientConn.Close()
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
})
}
func TestHandleConnection(t *testing.T) {
t.Run("SSH handshake fails - connection closed", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: sshConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
clientConn.Close()
s.handleConnection(serverConn)
})
// SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS
// GONNA IMPLEMENT THIS UNIT TEST LATER
//t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) {
// mockRandom := &MockRandom{}
// mockConfig := &MockConfig{}
// mockSessionRegistry := &MockSessionRegistry{}
// mockGrpcClient := &MockGRPCClient{}
// mockPort := &MockPort{}
//
// sshConfig, _ := getTestSSHConfig()
//
// serverConn, clientConn := net.Pipe()
//
// s := &server{
// randomizer: mockRandom,
// config: mockConfig,
// sshPort: "0",
// sshConfig: sshConfig,
// grpcClient: mockGrpcClient,
// sessionRegistry: mockSessionRegistry,
// portRegistry: mockPort,
// }
//
// done := make(chan bool, 1)
//
// go func() {
// s.handleConnection(serverConn)
// done <- true
// }()
//
// go func() {
// clientConn.Write([]byte("invalid ssh protocol\n"))
// clientConn.Close()
// }()
//
// select {
// case <-done:
// case <-time.After(1 * time.Second):
// t.Fatal("handleConnection did not complete in time")
// }
//})
t.Run("SSH connection established without gRPC client", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
mockConfig.On("Domain").Return("test.com")
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("SSHPort").Return("2200")
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
serverAddr := listener.Addr().String()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
s.handleConnection(conn)
done <- true
}()
time.Sleep(50 * time.Millisecond)
clientConfig := &ssh.ClientConfig{
User: "testuser",
Auth: []ssh.AuthMethod{ssh.Password("password")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
go func() {
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
if err != nil {
t.Logf("Client dial failed: %v", err)
return
}
defer client.Close()
type forwardPayload struct {
BindAddr string
BindPort uint32
}
payload := ssh.Marshal(forwardPayload{
BindAddr: "localhost",
BindPort: 80,
})
_, _, err = client.SendRequest("tcpip-forward", true, payload)
if err != nil {
t.Logf("Forward request failed: %v", err)
}
time.Sleep(500 * time.Millisecond)
}()
select {
case <-done:
t.Log("handleConnection completed")
case <-time.After(5 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
t.Run("SSH connection established with gRPC authorization", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
mockConfig.On("Domain").Return("test.com")
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("SSHPort").Return("2200")
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
serverAddr := listener.Addr().String()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
s.handleConnection(conn)
done <- true
}()
time.Sleep(50 * time.Millisecond)
clientConfig := &ssh.ClientConfig{
User: "testuser",
Auth: []ssh.AuthMethod{ssh.Password("password")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
go func() {
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
if err != nil {
t.Logf("Client dial failed: %v", err)
return
}
defer client.Close()
type forwardPayload struct {
BindAddr string
BindPort uint32
}
payload := ssh.Marshal(forwardPayload{
BindAddr: "localhost",
BindPort: 80,
})
_, _, err = client.SendRequest("tcpip-forward", true, payload)
if err != nil {
t.Logf("Forward request failed: %v", err)
}
time.Sleep(500 * time.Millisecond)
}()
select {
case <-done:
mockGrpcClient.AssertExpectations(t)
case <-time.After(5 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
t.Run("SSH connection with gRPC authorization error", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
mockConfig.On("Domain").Return("test.com")
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("SSHPort").Return("2200")
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer listener.Close()
serverAddr := listener.Addr().String()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
s.handleConnection(conn)
done <- true
}()
time.Sleep(50 * time.Millisecond)
clientConfig := &ssh.ClientConfig{
User: "testuser",
Auth: []ssh.AuthMethod{ssh.Password("password")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
go func() {
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
if err != nil {
t.Logf("Client dial failed: %v", err)
return
}
defer client.Close()
type forwardPayload struct {
BindAddr string
BindPort uint32
}
payload := ssh.Marshal(forwardPayload{
BindAddr: "localhost",
BindPort: 8080,
})
_, _, err = client.SendRequest("tcpip-forward", true, payload)
if err != nil {
t.Logf("Forward request failed: %v", err)
}
time.Sleep(500 * time.Millisecond)
}()
select {
case <-done:
mockGrpcClient.AssertExpectations(t)
case <-time.After(5 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
t.Run("connection cleanup on close", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
s.handleConnection(serverConn)
done <- true
}()
clientConn.Close()
select {
case <-done:
case <-time.After(1 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
}
func TestIntegration(t *testing.T) {
t.Run("full server lifecycle", func(t *testing.T) {
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
s, err := New(mr, mc, sc, mreg, mg, mp, "0")
assert.NoError(t, err)
assert.NotNil(t, s)
go func() {
time.Sleep(100 * time.Millisecond)
err := s.Close()
assert.NoError(t, err)
}()
s.Start()
})
t.Run("multiple connections", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
conn1Server, conn1Client := net.Pipe()
conn2Server, conn2Client := net.Pipe()
mockListener := &MockListener{}
mockListener.On("Accept").Return(conn1Server, nil).Once()
mockListener.On("Accept").Return(conn2Server, nil).Once()
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshListener: mockListener,
sshConfig: sshConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
go s.Start()
time.Sleep(50 * time.Millisecond)
conn1Client.Close()
time.Sleep(50 * time.Millisecond)
conn2Client.Close()
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
})
}
func TestErrorHandling(t *testing.T) {
t.Run("write error during SSH handshake", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
clientConn.Close()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: sshConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
s.handleConnection(serverConn)
})
}