From a3f6baa6aea8d54e77a40b27d183ec6bc3f43b45 Mon Sep 17 00:00:00 2001 From: bagas Date: Mon, 26 Jan 2026 18:55:59 +0700 Subject: [PATCH] test: check and handle error for testing --- internal/bootstrap/bootstrap_test.go | 438 ++++++++---------- internal/config/config_test.go | 28 +- internal/grpc/client/client_test.go | 4 +- internal/middleware/forwardedfor_test.go | 5 +- internal/middleware/tunnelfingerprint_test.go | 3 +- internal/port/port_test.go | 3 +- internal/registry/registry_test.go | 7 +- internal/transport/http_test.go | 14 +- internal/transport/httphandler_test.go | 53 +-- internal/transport/https_test.go | 12 +- internal/transport/tcp_test.go | 23 +- internal/transport/tls_test.go | 185 +++++--- server/server_test.go | 64 ++- session/forwarder/forwarder_test.go | 50 +- session/lifecycle/lifecycle_test.go | 21 +- session/session_test.go | 106 ++--- session/slug/slug_test.go | 5 - 17 files changed, 505 insertions(+), 516 deletions(-) diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go index f5b5c0c..2453cde 100644 --- a/internal/bootstrap/bootstrap_test.go +++ b/internal/bootstrap/bootstrap_test.go @@ -2,13 +2,7 @@ package bootstrap import ( "context" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" "fmt" - "math/big" "net" "net/http" _ "net/http/pprof" @@ -20,15 +14,11 @@ import ( "tunnel_pls/internal/config" "tunnel_pls/internal/port" "tunnel_pls/internal/registry" - "tunnel_pls/session/forwarder" - "tunnel_pls/session/interaction" - "tunnel_pls/session/lifecycle" "tunnel_pls/session/slug" "tunnel_pls/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" "google.golang.org/grpc" ) @@ -76,35 +66,6 @@ func (m *MockSessionRegistry) Slug() slug.Slug { return args.Get(0).(slug.Slug) } -type MockSession struct { - mock.Mock -} - -func (m *MockSession) Lifecycle() lifecycle.Lifecycle { - args := m.Called() - return args.Get(0).(lifecycle.Lifecycle) -} - -func (m *MockSession) Interaction() interaction.Interaction { - args := m.Called() - return args.Get(0).(interaction.Interaction) -} - -func (m *MockSession) Forwarder() forwarder.Forwarder { - args := m.Called() - return args.Get(0).(forwarder.Forwarder) -} - -func (m *MockSession) Slug() slug.Slug { - args := m.Called() - return args.Get(0).(slug.Slug) -} - -func (m *MockSession) Detail() *types.Detail { - args := m.Called() - return args.Get(0).(*types.Detail) -} - type MockRandom struct { mock.Mock } @@ -162,24 +123,24 @@ func (m *MockPort) AddRange(startPort, endPort uint16) error { } func (m *MockPort) Unassigned() (uint16, bool) { args := m.Called() - var port uint16 + var mPort uint16 if args.Get(0) != nil { switch v := args.Get(0).(type) { case int: - port = uint16(v) + mPort = uint16(v) case uint16: - port = v + mPort = v case uint32: - port = uint16(v) + mPort = uint16(v) case int32: - port = uint16(v) + mPort = uint16(v) case float64: - port = uint16(v) + mPort = uint16(v) default: - port = uint16(args.Int(0)) + mPort = uint16(args.Int(0)) } } - return port, args.Bool(1) + return mPort, args.Bool(1) } func (m *MockPort) SetStatus(port uint16, assigned bool) error { return m.Called(port, assigned).Error(0) @@ -281,49 +242,17 @@ func TestNew(t *testing.T) { } } -func generateTestCert(t *testing.T) (certPEM, keyPEM []byte) { - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - require.NoError(t, err) - - template := x509.Certificate{ - SerialNumber: big.NewInt(1), - Subject: pkix.Name{ - Organization: []string{"Test Co"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(time.Hour), - KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, - BasicConstraintsValid: true, - DNSNames: []string{"localhost"}, - IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, - } - - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) - require.NoError(t, err) - - certPEM = pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE", - Bytes: certDER, - }) - - keyPEM = pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(privateKey), - }) - - return certPEM, keyPEM -} - func randomAvailablePort() (string, error) { listener, err := net.Listen("tcp", "localhost:0") if err != nil { return "", err } - defer listener.Close() + defer func(listener net.Listener) { + _ = listener.Close() + }(listener) - port := listener.Addr().(*net.TCPAddr).Port - return strconv.Itoa(port), nil + mPort := listener.Addr().(*net.TCPAddr).Port + return strconv.Itoa(mPort), nil } func TestRun(t *testing.T) { @@ -346,81 +275,81 @@ func TestRun(t *testing.T) { { name: "successful run and termination", setupConfig: func() *MockConfig { - mock := &MockConfig{} - mock.On("KeyLoc").Return(keyLoc) - mock.On("Mode").Return(types.ServerModeSTANDALONE) - mock.On("Domain").Return("example.com") - mock.On("SSHPort").Return("0") - mock.On("HTTPPort").Return("0") - mock.On("HTTPSPort").Return("0") - mock.On("TLSEnabled").Return(false) - mock.On("TLSRedirect").Return(false) - mock.On("ACMEEmail").Return("test@example.com") - mock.On("CFAPIToken").Return("fake-token") - mock.On("ACMEStaging").Return(true) - mock.On("AllowedPortsStart").Return(uint16(1024)) - mock.On("AllowedPortsEnd").Return(uint16(65535)) - mock.On("BufferSize").Return(4096) - mock.On("PprofEnabled").Return(false) - mock.On("PprofPort").Return("0") - mock.On("GRPCAddress").Return("localhost") - mock.On("GRPCPort").Return("0") - mock.On("NodeToken").Return("fake-node-token") - return mock + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig }, expectError: false, }, { name: "error from SSH server invalid port", setupConfig: func() *MockConfig { - mock := &MockConfig{} - mock.On("KeyLoc").Return(keyLoc) - mock.On("Mode").Return(types.ServerModeSTANDALONE) - mock.On("Domain").Return("example.com") - mock.On("SSHPort").Return("invalid") - mock.On("HTTPPort").Return("0") - mock.On("HTTPSPort").Return("0") - mock.On("TLSEnabled").Return(false) - mock.On("TLSRedirect").Return(false) - mock.On("ACMEEmail").Return("test@example.com") - mock.On("CFAPIToken").Return("fake-token") - mock.On("ACMEStaging").Return(true) - mock.On("AllowedPortsStart").Return(uint16(1024)) - mock.On("AllowedPortsEnd").Return(uint16(65535)) - mock.On("BufferSize").Return(4096) - mock.On("PprofEnabled").Return(false) - mock.On("PprofPort").Return("0") - mock.On("GRPCAddress").Return("localhost") - mock.On("GRPCPort").Return("0") - mock.On("NodeToken").Return("fake-node-token") - return mock + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("invalid") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig }, expectError: true, }, { name: "error from HTTP server invalid port", setupConfig: func() *MockConfig { - mock := &MockConfig{} - mock.On("KeyLoc").Return(keyLoc) - mock.On("Mode").Return(types.ServerModeSTANDALONE) - mock.On("Domain").Return("example.com") - mock.On("SSHPort").Return("0") - mock.On("HTTPPort").Return("invalid") - mock.On("HTTPSPort").Return("0") - mock.On("TLSEnabled").Return(false) - mock.On("TLSRedirect").Return(false) - mock.On("ACMEEmail").Return("test@example.com") - mock.On("CFAPIToken").Return("fake-token") - mock.On("ACMEStaging").Return(true) - mock.On("AllowedPortsStart").Return(uint16(1024)) - mock.On("AllowedPortsEnd").Return(uint16(65535)) - mock.On("BufferSize").Return(4096) - mock.On("PprofEnabled").Return(false) - mock.On("PprofPort").Return("0") - mock.On("GRPCAddress").Return("localhost") - mock.On("GRPCPort").Return("0") - mock.On("NodeToken").Return("fake-node-token") - return mock + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("invalid") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig }, expectError: true, }, @@ -428,55 +357,55 @@ func TestRun(t *testing.T) { name: "error from HTTPS server invalid port", setupConfig: func() *MockConfig { tempDir := os.TempDir() - mock := &MockConfig{} - mock.On("KeyLoc").Return(keyLoc) - mock.On("Mode").Return(types.ServerModeSTANDALONE) - mock.On("Domain").Return("example.com") - mock.On("SSHPort").Return("0") - mock.On("HTTPPort").Return("0") - mock.On("HTTPSPort").Return("invalid") - mock.On("TLSEnabled").Return(true) - mock.On("TLSRedirect").Return(false) - mock.On("TLSStoragePath").Return(tempDir) - mock.On("ACMEEmail").Return("test@example.com") - mock.On("CFAPIToken").Return("fake-token") - mock.On("ACMEStaging").Return(true) - mock.On("AllowedPortsStart").Return(uint16(1024)) - mock.On("AllowedPortsEnd").Return(uint16(65535)) - mock.On("BufferSize").Return(4096) - mock.On("PprofEnabled").Return(false) - mock.On("PprofPort").Return("0") - mock.On("GRPCAddress").Return("localhost") - mock.On("GRPCPort").Return("0") - mock.On("NodeToken").Return("fake-node-token") - return mock + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("invalid") + mockConfig.On("TLSEnabled").Return(true) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("TLSStoragePath").Return(tempDir) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig }, expectError: true, }, { name: "grpc health check failed", setupConfig: func() *MockConfig { - mock := &MockConfig{} - mock.On("KeyLoc").Return(keyLoc) - mock.On("Mode").Return(types.ServerModeNODE) - mock.On("Domain").Return("example.com") - mock.On("SSHPort").Return("0") - mock.On("HTTPPort").Return("0") - mock.On("HTTPSPort").Return("0") - mock.On("TLSEnabled").Return(false) - mock.On("TLSRedirect").Return(false) - mock.On("ACMEEmail").Return("test@example.com") - mock.On("CFAPIToken").Return("fake-token") - mock.On("ACMEStaging").Return(true) - mock.On("AllowedPortsStart").Return(uint16(1024)) - mock.On("AllowedPortsEnd").Return(uint16(65535)) - mock.On("BufferSize").Return(4096) - mock.On("PprofEnabled").Return(false) - mock.On("PprofPort").Return("0") - mock.On("GRPCAddress").Return("localhost") - mock.On("GRPCPort").Return("invalid") - mock.On("NodeToken").Return("fake-node-token") - return mock + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("invalid") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig }, setupGrpcClient: func() *MockGRPCClient { mockGRPCClient := &MockGRPCClient{} @@ -488,54 +417,54 @@ func TestRun(t *testing.T) { { name: "successful run with pprof enabled", setupConfig: func() *MockConfig { - mock := &MockConfig{} + mockConfig := &MockConfig{} pprofPort, _ := randomAvailablePort() - mock.On("KeyLoc").Return(keyLoc) - mock.On("Mode").Return(types.ServerModeSTANDALONE) - mock.On("Domain").Return("example.com") - mock.On("SSHPort").Return("0") - mock.On("HTTPPort").Return("0") - mock.On("HTTPSPort").Return("0") - mock.On("TLSEnabled").Return(false) - mock.On("TLSRedirect").Return(false) - mock.On("ACMEEmail").Return("test@example.com") - mock.On("CFAPIToken").Return("fake-token") - mock.On("ACMEStaging").Return(true) - mock.On("AllowedPortsStart").Return(uint16(1024)) - mock.On("AllowedPortsEnd").Return(uint16(65535)) - mock.On("BufferSize").Return(4096) - mock.On("PprofEnabled").Return(true) - mock.On("PprofPort").Return(pprofPort) - mock.On("GRPCAddress").Return("localhost") - mock.On("GRPCPort").Return("0") - mock.On("NodeToken").Return("fake-node-token") - return mock + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(true) + mockConfig.On("PprofPort").Return(pprofPort) + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig }, expectError: false, }, { name: "successful run in NODE mode with signal", setupConfig: func() *MockConfig { - mock := &MockConfig{} - mock.On("KeyLoc").Return(keyLoc) - mock.On("Mode").Return(types.ServerModeNODE) - mock.On("Domain").Return("example.com") - mock.On("SSHPort").Return("0") - mock.On("HTTPPort").Return("0") - mock.On("HTTPSPort").Return("0") - mock.On("TLSEnabled").Return(false) - mock.On("TLSRedirect").Return(false) - mock.On("ACMEEmail").Return("test@example.com") - mock.On("CFAPIToken").Return("fake-token") - mock.On("ACMEStaging").Return(true) - mock.On("AllowedPortsStart").Return(uint16(1024)) - mock.On("AllowedPortsEnd").Return(uint16(65535)) - mock.On("BufferSize").Return(4096) - mock.On("PprofEnabled").Return(false) - mock.On("PprofPort").Return("0") - mock.On("GRPCAddress").Return("localhost") - mock.On("GRPCPort").Return("0") - mock.On("NodeToken").Return("fake-node-token") - return mock + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig }, setupGrpcClient: func() *MockGRPCClient { mockGRPCClient := &MockGRPCClient{} @@ -548,27 +477,27 @@ func TestRun(t *testing.T) { }, { name: "successful run in NODE mode with signal buf error when closing", setupConfig: func() *MockConfig { - mock := &MockConfig{} - mock.On("KeyLoc").Return(keyLoc) - mock.On("Mode").Return(types.ServerModeNODE) - mock.On("Domain").Return("example.com") - mock.On("SSHPort").Return("0") - mock.On("HTTPPort").Return("0") - mock.On("HTTPSPort").Return("0") - mock.On("TLSEnabled").Return(false) - mock.On("TLSRedirect").Return(false) - mock.On("ACMEEmail").Return("test@example.com") - mock.On("CFAPIToken").Return("fake-token") - mock.On("ACMEStaging").Return(true) - mock.On("AllowedPortsStart").Return(uint16(1024)) - mock.On("AllowedPortsEnd").Return(uint16(65535)) - mock.On("BufferSize").Return(4096) - mock.On("PprofEnabled").Return(false) - mock.On("PprofPort").Return("0") - mock.On("GRPCAddress").Return("localhost") - mock.On("GRPCPort").Return("0") - mock.On("NodeToken").Return("fake-node-token") - return mock + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig }, setupGrpcClient: func() *MockGRPCClient { mockGRPCClient := &MockGRPCClient{} @@ -613,7 +542,8 @@ func TestRun(t *testing.T) { resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort())) assert.NoError(t, err) assert.Equal(t, 200, resp.StatusCode) - resp.Body.Close() + err = resp.Body.Close() + assert.NoError(t, err) mockSignalChan <- os.Interrupt err = <-done assert.NoError(t, err) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index b102094..85f93f3 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -37,7 +37,8 @@ func TestGetenv(t *testing.T) { if tt.val != "" { t.Setenv(tt.key, tt.val) } else { - os.Unsetenv(tt.key) + err := os.Unsetenv(tt.key) + assert.NoError(t, err) } assert.Equal(t, tt.expected, getenv(tt.key, tt.def)) }) @@ -87,7 +88,8 @@ func TestGetenvBool(t *testing.T) { if tt.val != "" { t.Setenv(tt.key, tt.val) } else { - os.Unsetenv(tt.key) + err := os.Unsetenv(tt.key) + assert.NoError(t, err) } assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def)) }) @@ -113,7 +115,8 @@ func TestParseMode(t *testing.T) { if tt.mode != "" { t.Setenv("MODE", tt.mode) } else { - os.Unsetenv("MODE") + err := os.Unsetenv("MODE") + assert.NoError(t, err) } mode, err := parseMode() if tt.expectErr { @@ -148,7 +151,8 @@ func TestParseAllowedPorts(t *testing.T) { if tt.val != "" { t.Setenv("ALLOWED_PORTS", tt.val) } else { - os.Unsetenv("ALLOWED_PORTS") + err := os.Unsetenv("ALLOWED_PORTS") + assert.NoError(t, err) } start, end, err := parseAllowedPorts() if tt.expectErr { @@ -180,7 +184,8 @@ func TestParseBufferSize(t *testing.T) { if tt.val != "" { t.Setenv("BUFFER_SIZE", tt.val) } else { - os.Unsetenv("BUFFER_SIZE") + err := os.Unsetenv("BUFFER_SIZE") + assert.NoError(t, err) } size := parseBufferSize() assert.Equal(t, tt.expect, size) @@ -206,7 +211,8 @@ func TestParseHeaderSize(t *testing.T) { if tt.val != "" { t.Setenv("MAX_HEADER_SIZE", tt.val) } else { - os.Unsetenv("MAX_HEADER_SIZE") + err := os.Unsetenv("MAX_HEADER_SIZE") + assert.NoError(t, err) } size := parseHeaderSize() assert.Equal(t, tt.expect, size) @@ -358,7 +364,10 @@ func TestMustLoad(t *testing.T) { t.Run("loadEnvFile error", func(t *testing.T) { err := os.Mkdir(".env", 0755) assert.NoError(t, err) - defer os.Remove(".env") + defer func() { + err = os.Remove(".env") + assert.NoError(t, err) + }() cfg, err := MustLoad() assert.Error(t, err) @@ -378,7 +387,10 @@ func TestLoadEnvFile(t *testing.T) { t.Run("file exists", func(t *testing.T) { err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644) assert.NoError(t, err) - defer os.Remove(".env") + defer func() { + err = os.Remove(".env") + assert.NoError(t, err) + }() err = loadEnvFile() assert.NoError(t, err) diff --git a/internal/grpc/client/client_test.go b/internal/grpc/client/client_test.go index fb2147e..e69065d 100644 --- a/internal/grpc/client/client_test.go +++ b/internal/grpc/client/client_test.go @@ -744,7 +744,9 @@ func TestNew(t *testing.T) { if cli == nil { t.Fatal("New() returned nil client") } - defer cli.Close() + defer func(cli Client) { + _ = cli.Close() + }(cli) } type MockConfig struct { diff --git a/internal/middleware/forwardedfor_test.go b/internal/middleware/forwardedfor_test.go index 5a45dc0..49f9980 100644 --- a/internal/middleware/forwardedfor_test.go +++ b/internal/middleware/forwardedfor_test.go @@ -1,10 +1,11 @@ package middleware import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "net" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) type mockRequestHeader struct { diff --git a/internal/middleware/tunnelfingerprint_test.go b/internal/middleware/tunnelfingerprint_test.go index 21e8b15..0054d1e 100644 --- a/internal/middleware/tunnelfingerprint_test.go +++ b/internal/middleware/tunnelfingerprint_test.go @@ -1,9 +1,10 @@ package middleware import ( + "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "testing" ) type mockResponseHeader struct { diff --git a/internal/port/port_test.go b/internal/port/port_test.go index 56526b3..fcc64d3 100644 --- a/internal/port/port_test.go +++ b/internal/port/port_test.go @@ -1,8 +1,9 @@ package port import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) func TestAddRange(t *testing.T) { diff --git a/internal/registry/registry_test.go b/internal/registry/registry_test.go index 6e93ceb..484b4e9 100644 --- a/internal/registry/registry_test.go +++ b/internal/registry/registry_test.go @@ -1,9 +1,6 @@ package registry import ( - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" "sync" "testing" "time" @@ -14,6 +11,10 @@ import ( "tunnel_pls/session/slug" "tunnel_pls/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" ) diff --git a/internal/transport/http_test.go b/internal/transport/http_test.go index 3922eb0..cd3cf68 100644 --- a/internal/transport/http_test.go +++ b/internal/transport/http_test.go @@ -38,7 +38,8 @@ func TestHTTPServer_Listen(t *testing.T) { listener, err := srv.Listen() assert.NoError(t, err) assert.NotNil(t, listener) - listener.Close() + err = listener.Close() + assert.NoError(t, err) } func TestHTTPServer_Serve(t *testing.T) { @@ -54,7 +55,8 @@ func TestHTTPServer_Serve(t *testing.T) { go func() { time.Sleep(100 * time.Millisecond) - listener.Close() + err = listener.Close() + assert.NoError(t, err) }() err = srv.Serve(listener) @@ -102,8 +104,12 @@ func TestHTTPServer_Serve_Success(t *testing.T) { _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n")) time.Sleep(100 * time.Millisecond) - conn.Close() - listener.Close() + err = conn.Close() + assert.NoError(t, err) + + err = listener.Close() + assert.NoError(t, err) + } type mockListener struct { diff --git a/internal/transport/httphandler_test.go b/internal/transport/httphandler_test.go index b30d9d5..6801b22 100644 --- a/internal/transport/httphandler_test.go +++ b/internal/transport/httphandler_test.go @@ -10,7 +10,6 @@ import ( "sync" "testing" "time" - "tunnel_pls/internal/port" "tunnel_pls/internal/registry" "tunnel_pls/session/forwarder" "tunnel_pls/session/interaction" @@ -97,53 +96,6 @@ func (m *MockSession) Detail() *types.Detail { return args.Get(0).(*types.Detail) } -type MockLifecycle struct { - mock.Mock -} - -func (m *MockLifecycle) Channel() ssh.Channel { - args := m.Called() - return args.Get(0).(ssh.Channel) -} - -func (m *MockLifecycle) Connection() ssh.Conn { - args := m.Called() - return args.Get(0).(ssh.Conn) -} - -func (m *MockLifecycle) PortRegistry() port.Port { - args := m.Called() - return args.Get(0).(port.Port) -} - -func (m *MockLifecycle) User() string { - args := m.Called() - return args.String(0) -} - -func (m *MockLifecycle) SetChannel(channel ssh.Channel) { - m.Called(channel) -} - -func (m *MockLifecycle) SetStatus(status types.SessionStatus) { - m.Called(status) -} - -func (m *MockLifecycle) IsActive() bool { - args := m.Called() - return args.Bool(0) -} - -func (m *MockLifecycle) StartedAt() time.Time { - args := m.Called() - return args.Get(0).(time.Time) -} - -func (m *MockLifecycle) Close() error { - args := m.Called() - return args.Error(0) -} - type MockSSHChannel struct { ssh.Channel mock.Mock @@ -678,7 +630,10 @@ func TestHandler(t *testing.T) { } if clientConn != nil { - defer clientConn.Close() + defer func(clientConn net.Conn) { + err := clientConn.Close() + assert.NoError(t, err) + }(clientConn) } remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") diff --git a/internal/transport/https_test.go b/internal/transport/https_test.go index cf09592..6081d97 100644 --- a/internal/transport/https_test.go +++ b/internal/transport/https_test.go @@ -46,7 +46,8 @@ func TestHTTPSServer_Listen(t *testing.T) { return } assert.NotNil(t, listener) - listener.Close() + err = listener.Close() + assert.NoError(t, err) } func TestHTTPSServer_Serve(t *testing.T) { @@ -62,7 +63,8 @@ func TestHTTPSServer_Serve(t *testing.T) { go func() { time.Sleep(100 * time.Millisecond) - listener.Close() + err = listener.Close() + assert.NoError(t, err) }() err = srv.Serve(listener) @@ -111,6 +113,8 @@ func TestHTTPSServer_Serve_Success(t *testing.T) { _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n")) time.Sleep(100 * time.Millisecond) - conn.Close() - listener.Close() + err = conn.Close() + assert.NoError(t, err) + err = listener.Close() + assert.NoError(t, err) } diff --git a/internal/transport/tcp_test.go b/internal/transport/tcp_test.go index 409e6f1..c4c4963 100644 --- a/internal/transport/tcp_test.go +++ b/internal/transport/tcp_test.go @@ -32,7 +32,8 @@ func TestTCPServer_Listen(t *testing.T) { listener, err := srv.Listen() assert.NoError(t, err) assert.NotNil(t, listener) - listener.Close() + err = listener.Close() + assert.NoError(t, err) } func TestTCPServer_Serve(t *testing.T) { @@ -44,7 +45,8 @@ func TestTCPServer_Serve(t *testing.T) { go func() { time.Sleep(100 * time.Millisecond) - listener.Close() + err = listener.Close() + assert.NoError(t, err) }() err = srv.Serve(listener) @@ -84,9 +86,10 @@ func TestTCPServer_Serve_Success(t *testing.T) { assert.NoError(t, err) time.Sleep(100 * time.Millisecond) - conn.Close() - listener.Close() - + err = conn.Close() + assert.NoError(t, err) + err = listener.Close() + assert.NoError(t, err) mf.AssertExpectations(t) } @@ -95,7 +98,10 @@ func TestTCPServer_handleTcp_Success(t *testing.T) { srv := NewTCPServer(0, mf).(*tcp) serverConn, clientConn := net.Pipe() - defer clientConn.Close() + defer func(clientConn net.Conn) { + err := clientConn.Close() + assert.NoError(t, err) + }(clientConn) reqs := make(chan *ssh.Request) mockChannel := new(MockSSHChannel) @@ -127,7 +133,10 @@ func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) { srv := NewTCPServer(0, mf).(*tcp) serverConn, clientConn := net.Pipe() - defer clientConn.Close() + defer func(clientConn net.Conn) { + err := clientConn.Close() + assert.NoError(t, err) + }(clientConn) mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) diff --git a/internal/transport/tls_test.go b/internal/transport/tls_test.go index 12e656d..0c5510c 100644 --- a/internal/transport/tls_test.go +++ b/internal/transport/tls_test.go @@ -80,13 +80,15 @@ func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, so assert.NoError(t, err) err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) assert.NoError(t, err) - certOut.Close() + err = certOut.Close() + assert.NoError(t, err) keyOut, err := os.CreateTemp("", "key*.pem") assert.NoError(t, err) err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) assert.NoError(t, err) - keyOut.Close() + err = keyOut.Close() + assert.NoError(t, err) return certOut.Name(), keyOut.Name() } @@ -98,7 +100,8 @@ func setupTestDir(t *testing.T) string { assert.NoError(t, err) t.Cleanup(func() { - os.RemoveAll(tmpDir) + err = os.RemoveAll(tmpDir) + assert.NoError(t, err) }) return tmpDir @@ -126,8 +129,11 @@ func TestValidateCertDomains(t *testing.T) { assert.NoError(t, err) _, err = tmpFile.WriteString("not a pem") assert.NoError(t, err) - tmpFile.Close() - return tmpFile.Name(), func() { os.Remove(tmpFile.Name()) } + err = tmpFile.Close() + assert.NoError(t, err) + return tmpFile.Name(), func() { + _ = os.Remove(tmpFile.Name()) + } }, domain: "example.com", expected: false, @@ -137,8 +143,8 @@ func TestValidateCertDomains(t *testing.T) { setup: func(t *testing.T) (string, func()) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) return certPath, func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) } }, domain: "example.com", @@ -149,8 +155,8 @@ func TestValidateCertDomains(t *testing.T) { setup: func(t *testing.T) (string, func()) { certPath, keyPath := createTestCert(t, "example.com", true, true, false) return certPath, func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) } }, domain: "example.com", @@ -161,8 +167,8 @@ func TestValidateCertDomains(t *testing.T) { setup: func(t *testing.T) (string, func()) { certPath, keyPath := createTestCert(t, "example.com", true, false, true) return certPath, func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) } }, domain: "example.com", @@ -173,8 +179,8 @@ func TestValidateCertDomains(t *testing.T) { setup: func(t *testing.T) (string, func()) { certPath, keyPath := createTestCert(t, "example.com", false, false, false) return certPath, func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) } }, domain: "example.com", @@ -205,8 +211,8 @@ func TestLoadAndParseCertificate(t *testing.T) { setup: func(t *testing.T) (string, func()) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) return certPath, func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) } }, wantError: false, @@ -275,8 +281,14 @@ func TestIsCertificateValid(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { certPath, keyPath := createTestCert(t, "example.com", true, tt.expired, tt.soon) - defer os.Remove(certPath) - defer os.Remove(keyPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) cert, err := loadAndParseCertificate(certPath) assert.NoError(t, err) @@ -289,8 +301,14 @@ func TestIsCertificateValid(t *testing.T) { func TestExtractCertDomains(t *testing.T) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) - defer os.Remove(certPath) - defer os.Remove(keyPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) cert, err := loadAndParseCertificate(certPath) assert.NoError(t, err) @@ -381,8 +399,8 @@ func TestTLSManager_getCertificate(t *testing.T) { setup: func(t *testing.T) *tlsManager { certPath, keyPath := createTestCert(t, "example.com", true, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) cert, err := tls.LoadX509KeyPair(certPath, keyPath) @@ -447,8 +465,9 @@ func TestTLSManager_userCertsExistAndValid(t *testing.T) { mockCfg.On("Domain").Return("example.com") certPath, keyPath := createTestCert(t, "example.com", true, false, false) - t.Cleanup(func() { os.Remove(certPath) }) - os.Remove(keyPath) + t.Cleanup(func() { _ = os.Remove(certPath) }) + err := os.Remove(keyPath) + assert.NoError(t, err) return &tlsManager{ config: mockCfg, @@ -471,8 +490,14 @@ func TestTLSManager_userCertsExistAndValid(t *testing.T) { func TestTLSManager_certFilesExist(t *testing.T) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) - defer os.Remove(certPath) - defer os.Remove(keyPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) tm := &tlsManager{ certPath: certPath, @@ -494,8 +519,8 @@ func TestTLSManager_loadUserCerts(t *testing.T) { setup: func(t *testing.T) *tlsManager { certPath, keyPath := createTestCert(t, "example.com", true, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) return &tlsManager{ @@ -550,8 +575,14 @@ func TestCreateTLSManager(t *testing.T) { func TestNewCertWatcher(t *testing.T) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) - defer os.Remove(certPath) - defer os.Remove(keyPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) mockCfg := &MockConfig{} @@ -571,8 +602,14 @@ func TestNewCertWatcher(t *testing.T) { func TestCertWatcher_filesModified(t *testing.T) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) - defer os.Remove(certPath) - defer os.Remove(keyPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) mockCfg := &MockConfig{} @@ -600,8 +637,14 @@ func TestCertWatcher_filesModified(t *testing.T) { func TestCertWatcher_updateModTimes(t *testing.T) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) - defer os.Remove(certPath) - defer os.Remove(keyPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) mockCfg := &MockConfig{} @@ -637,8 +680,8 @@ func TestCertWatcher_getFileInfo(t *testing.T) { setup: func(t *testing.T) *tlsManager { certPath, keyPath := createTestCert(t, "example.com", true, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) return &tlsManager{ @@ -657,8 +700,9 @@ func TestCertWatcher_getFileInfo(t *testing.T) { name: "missing cert file", setup: func(t *testing.T) *tlsManager { certPath, keyPath := createTestCert(t, "example.com", true, false, false) - os.Remove(certPath) - t.Cleanup(func() { os.Remove(keyPath) }) + err := os.Remove(certPath) + assert.NoError(t, err) + t.Cleanup(func() { _ = os.Remove(keyPath) }) return &tlsManager{ config: &MockConfig{}, @@ -672,8 +716,9 @@ func TestCertWatcher_getFileInfo(t *testing.T) { name: "missing key file", setup: func(t *testing.T) *tlsManager { certPath, keyPath := createTestCert(t, "example.com", true, false, false) - os.Remove(keyPath) - t.Cleanup(func() { os.Remove(certPath) }) + err := os.Remove(keyPath) + assert.NoError(t, err) + t.Cleanup(func() { _ = os.Remove(certPath) }) return &tlsManager{ config: &MockConfig{}, @@ -729,8 +774,8 @@ func TestCertWatcher_checkAndReloadCerts(t *testing.T) { setup: func(t *testing.T) (*tlsManager, *certWatcher) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) tm := &tlsManager{ @@ -747,8 +792,8 @@ func TestCertWatcher_checkAndReloadCerts(t *testing.T) { setup: func(t *testing.T) (*tlsManager, *certWatcher) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) mockCfg := &MockConfig{} @@ -792,8 +837,8 @@ func TestCertWatcher_handleCertificateChange(t *testing.T) { setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) mockCfg := &MockConfig{} @@ -819,8 +864,8 @@ func TestCertWatcher_handleCertificateChange(t *testing.T) { setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) { certPath, keyPath := createTestCert(t, "example.com", false, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) tmpDir := setupTestDir(t) @@ -850,8 +895,8 @@ func TestCertWatcher_handleCertificateChange(t *testing.T) { setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) mockCfg := &MockConfig{} @@ -946,8 +991,8 @@ func TestCertWatcher_watch(t *testing.T) { setup: func(t *testing.T) (*tlsManager, *certWatcher) { certPath, keyPath := createTestCert(t, "example.com", false, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) tmpDir := setupTestDir(t) @@ -976,8 +1021,8 @@ func TestCertWatcher_watch(t *testing.T) { setup: func(t *testing.T) (*tlsManager, *certWatcher) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) mockCfg := &MockConfig{} @@ -1006,8 +1051,14 @@ func TestCertWatcher_watch(t *testing.T) { func TestCertWatcher_watch_Integration(t *testing.T) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) - defer os.Remove(certPath) - defer os.Remove(keyPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) mockCfg := &MockConfig{} mockCfg.On("Domain").Return("example.com") @@ -1029,8 +1080,14 @@ func TestCertWatcher_watch_Integration(t *testing.T) { time.Sleep(50 * time.Millisecond) newCertPath, newKeyPath := createTestCert(t, "example.com", true, false, false) - defer os.Remove(newCertPath) - defer os.Remove(newKeyPath) + defer func(name string) { + err = os.Remove(name) + assert.NoError(t, err) + }(newCertPath) + defer func(name string) { + err = os.Remove(name) + assert.NoError(t, err) + }(newKeyPath) newCertData, err := os.ReadFile(newCertPath) assert.NoError(t, err) @@ -1066,8 +1123,8 @@ func TestNewTLSConfig(t *testing.T) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) t.Cleanup(func() { - os.Remove(certPath) - os.Remove(keyPath) + _ = os.Remove(certPath) + _ = os.Remove(keyPath) }) certData, err := os.ReadFile(certPath) @@ -1140,8 +1197,14 @@ func TestNewTLSConfig_Singleton(t *testing.T) { tmpDir := setupTestDir(t) certPath, keyPath := createTestCert(t, "example.com", true, false, false) - defer os.Remove(certPath) - defer os.Remove(keyPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) certData, err := os.ReadFile(certPath) assert.NoError(t, err) diff --git a/server/server_test.go b/server/server_test.go index 9b7dd88..a4d5c74 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -181,14 +181,6 @@ func (m *MockListener) Addr() net.Addr { return m.Called().Get(0).(net.Addr) } -type MockSession struct { - mock.Mock -} - -func (m *MockSession) Start() error { - return m.Called().Error(0) -} - func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) { key, _ := rsa.GenerateKey(rand.Reader, 2048) signer, _ := ssh.NewSignerFromKey(key) @@ -244,7 +236,10 @@ func TestNew(t *testing.T) { t.Fatal(err) } port := l.Addr().(*net.TCPAddr).Port - defer l.Close() + defer func(l net.Listener) { + err = l.Close() + assert.NoError(t, err) + }(l) s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port)) assert.Error(t, err) @@ -360,7 +355,9 @@ func TestStart(t *testing.T) { go s.Start() time.Sleep(50 * time.Millisecond) - clientConn.Close() + err := clientConn.Close() + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) mockListener.AssertExpectations(t) @@ -394,7 +391,9 @@ func TestStart(t *testing.T) { go s.Start() time.Sleep(50 * time.Millisecond) - clientConn.Close() + err := clientConn.Close() + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) mockListener.AssertExpectations(t) @@ -423,7 +422,9 @@ func TestHandleConnection(t *testing.T) { portRegistry: mockPort, } - clientConn.Close() + err := clientConn.Close() + assert.NoError(t, err) + s.handleConnection(serverConn) }) @@ -489,7 +490,10 @@ func TestHandleConnection(t *testing.T) { if err != nil { t.Fatal(err) } - defer listener.Close() + defer func(listener net.Listener) { + err = listener.Close() + assert.NoError(t, err) + }(listener) serverAddr := listener.Addr().String() @@ -529,7 +533,10 @@ func TestHandleConnection(t *testing.T) { t.Logf("Client dial failed: %v", err) return } - defer client.Close() + defer func(client *ssh.Client) { + err = client.Close() + assert.NoError(t, err) + }(client) type forwardPayload struct { BindAddr string @@ -578,7 +585,10 @@ func TestHandleConnection(t *testing.T) { if err != nil { t.Fatal(err) } - defer listener.Close() + defer func(listener net.Listener) { + err = listener.Close() + assert.NoError(t, err) + }(listener) serverAddr := listener.Addr().String() @@ -618,7 +628,10 @@ func TestHandleConnection(t *testing.T) { t.Logf("Client dial failed: %v", err) return } - defer client.Close() + defer func(client *ssh.Client) { + err = client.Close() + assert.NoError(t, err) + }(client) type forwardPayload struct { BindAddr string @@ -667,7 +680,10 @@ func TestHandleConnection(t *testing.T) { if err != nil { t.Fatal(err) } - defer listener.Close() + defer func(listener net.Listener) { + err = listener.Close() + assert.NoError(t, err) + }(listener) serverAddr := listener.Addr().String() @@ -707,7 +723,9 @@ func TestHandleConnection(t *testing.T) { t.Logf("Client dial failed: %v", err) return } - defer client.Close() + defer func(client *ssh.Client) { + _ = client.Close() + }(client) type forwardPayload struct { BindAddr string @@ -762,7 +780,8 @@ func TestHandleConnection(t *testing.T) { done <- true }() - clientConn.Close() + err := clientConn.Close() + assert.NoError(t, err) select { case <-done: @@ -824,9 +843,9 @@ func TestIntegration(t *testing.T) { go s.Start() time.Sleep(50 * time.Millisecond) - conn1Client.Close() + _ = conn1Client.Close() time.Sleep(50 * time.Millisecond) - conn2Client.Close() + _ = conn2Client.Close() time.Sleep(100 * time.Millisecond) mockListener.AssertExpectations(t) @@ -843,7 +862,8 @@ func TestErrorHandling(t *testing.T) { sshConfig, _ := getTestSSHConfig() serverConn, clientConn := net.Pipe() - clientConn.Close() + err := clientConn.Close() + assert.NoError(t, err) s := &server{ randomizer: mockRandom, diff --git a/session/forwarder/forwarder_test.go b/session/forwarder/forwarder_test.go index a06506e..c3e1284 100644 --- a/session/forwarder/forwarder_test.go +++ b/session/forwarder/forwarder_test.go @@ -219,8 +219,14 @@ func (p *pipeConn) Write(b []byte) (int, error) { } func (p *pipeConn) Close() error { - p.reader.Close() - p.writer.Close() + err := p.reader.Close() + if err != nil { + return err + } + err = p.writer.Close() + if err != nil { + return err + } return nil } @@ -499,7 +505,8 @@ func TestOpenForwardedChannel(t *testing.T) { OriginAddr string OriginPort uint32 } - ssh.Unmarshal(capturedData, &payload) + err = ssh.Unmarshal(capturedData, &payload) + assert.NoError(t, err) assert.Equal(t, tt.wantDestAddr, payload.DestAddr) assert.Equal(t, tt.wantDestPort, payload.DestPort) assert.Equal(t, tt.wantOrigAddr, payload.OriginAddr) @@ -662,8 +669,8 @@ func TestCreateForwardedTCPIPPayload(t *testing.T) { OriginPort uint32 } - ssh.Unmarshal(payload, &decoded) - + err := ssh.Unmarshal(payload, &decoded) + assert.NoError(t, err) assert.Equal(t, tt.wantDestAddr, decoded.DestAddr) assert.Equal(t, tt.wantDestPort, decoded.DestPort) assert.Equal(t, tt.wantOriginAddr, decoded.OriginAddr) @@ -1056,9 +1063,7 @@ func TestCopyWithBuffer(t *testing.T) { } if buf, ok := dst.(*bytes.Buffer); ok && !tt.wantErr { - if _, ok := src.(io.Reader); ok { - assert.Equal(t, tt.wantBytesCount, int64(buf.Len())) - } + assert.Equal(t, tt.wantBytesCount, int64(buf.Len())) } if mr, ok := src.(*mockReader); ok { @@ -1276,7 +1281,10 @@ func TestSetListener(t *testing.T) { listener := tt.setupListener() if listener != nil { - defer listener.Close() + defer func(listener net.Listener) { + err := listener.Close() + assert.NoError(t, err) + }(listener) } assert.Nil(t, forwarder.Listener()) @@ -1318,7 +1326,10 @@ func TestListener(t *testing.T) { listener := tt.setupListener() if listener != nil { - defer listener.Close() + defer func(listener net.Listener) { + err := listener.Close() + assert.NoError(t, err) + }(listener) forwarder.SetListener(listener) } @@ -1356,7 +1367,8 @@ func TestClose(t *testing.T) { setupListener: func() net.Listener { listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) - listener.Close() + err = listener.Close() + assert.NoError(t, err) return listener }, wantErr: true, @@ -1477,8 +1489,10 @@ func TestHandleConnectionWithErrors(t *testing.T) { return newPipePair() }, simulateErr: func(channel *testChannelPeer, dst *pipeConn) { - channel.CloseWrite() - dst.CloseWrite() + err := channel.CloseWrite() + assert.NoError(t, err) + err = dst.CloseWrite() + assert.NoError(t, err) }, }, { @@ -1491,10 +1505,14 @@ func TestHandleConnectionWithErrors(t *testing.T) { return newPipePair() }, simulateErr: func(channel *testChannelPeer, dst *pipeConn) { - dst.Close() + err := dst.Close() + assert.NoError(t, err) time.Sleep(10 * time.Millisecond) - channel.Write([]byte("test")) - channel.CloseWrite() + write, err := channel.Write([]byte("test")) + assert.NotZero(t, write) + assert.NoError(t, err) + err = channel.CloseWrite() + assert.NoError(t, err) }, }, } diff --git a/session/lifecycle/lifecycle_test.go b/session/lifecycle/lifecycle_test.go index 4f4335b..608b3a8 100644 --- a/session/lifecycle/lifecycle_test.go +++ b/session/lifecycle/lifecycle_test.go @@ -139,24 +139,6 @@ func (m *MockSSHChannel) Close() error { return m.Called().Error(0) } -type mockNewChannel struct { - ssh.NewChannel - mock.Mock -} - -func (m *mockNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) { - args := m.Called() - return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) -} - -func (m *MockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { - args := m.Called(name, data) - if args.Get(0) == nil { - return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) - } - return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) -} - func TestNew(t *testing.T) { mockSSHConn := new(MockSSHConn) mockForwarder := &MockForwarder{} @@ -297,7 +279,8 @@ func TestLifecycle_Close(t *testing.T) { mockLifecycle.SetChannel(mockSSHChannel) if tt.alreadyClosed { - mockLifecycle.Close() + err := mockLifecycle.Close() + assert.NoError(t, err) } err := mockLifecycle.Close() diff --git a/session/session_test.go b/session/session_test.go index 0e87834..8a89125 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -7,13 +7,13 @@ import ( "encoding/binary" "encoding/pem" "fmt" + "io" "net" "strconv" "strings" "testing" "time" "tunnel_pls/internal/config" - portUtil "tunnel_pls/internal/port" "tunnel_pls/internal/registry" "tunnel_pls/session/lifecycle" "tunnel_pls/types" @@ -122,25 +122,6 @@ func (m *mockSSHConn) User() string { return m.Called().String(0) } -type mockSSHChannel struct { - ssh.Channel - mock.Mock -} - -func (m *mockSSHChannel) Close() error { - return m.Called().Error(0) -} - -type mockNewChannel struct { - ssh.NewChannel - mock.Mock -} - -func (m *mockNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) { - args := m.Called() - return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) -} - func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, sChans <-chan ssh.NewChannel, cConn ssh.Conn, cleanup func()) { l, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -193,7 +174,8 @@ func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, s if newChan.ChannelType() == "session" { continue } - newChan.Reject(ssh.Prohibited, "") + err = newChan.Reject(ssh.Prohibited, "") + assert.NoError(t, err) } }() @@ -205,9 +187,9 @@ func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, s } return sConnObj, sReqsChan, sChansChan, cConnObj, func() { - cConnObj.Close() - sConnObj.Close() - l.Close() + _ = cConnObj.Close() + _ = sConnObj.Close() + _ = l.Close() } } @@ -330,7 +312,9 @@ func TestHandleGlobalRequest(t *testing.T) { }) } - cConn.Close() + err := cConn.Close() + assert.NoError(t, err) + select { case <-done: case <-time.After(2 * time.Second): @@ -505,11 +489,13 @@ func TestStart_Table(t *testing.T) { time.Sleep(200 * time.Millisecond) _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) time.Sleep(200 * time.Millisecond) - ch.Write([]byte("q")) + write, err := ch.Write([]byte("q")) + assert.NoError(t, err) + assert.NotZero(t, write) time.Sleep(100 * time.Millisecond) - ch.Close() + _ = ch.Close() } - cConn.Close() + _ = cConn.Close() }() err := s.Start() @@ -530,9 +516,13 @@ func TestStart_Table(t *testing.T) { go func() { time.Sleep(600 * time.Millisecond) - _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + _, _, err := cConn.SendRequest("tcpip-forward", true, payload) + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) - cConn.Close() + err = cConn.Close() + assert.NoError(t, err) + }() err := s.Start() @@ -545,7 +535,7 @@ func TestStart_Table(t *testing.T) { go func() { time.Sleep(1200 * time.Millisecond) - cConn.Close() + _ = cConn.Close() }() err := s.Start() @@ -554,11 +544,11 @@ func TestStart_Table(t *testing.T) { }) t.Run("Unauthorized Headless", func(t *testing.T) { - s, conf, cConn, cleanup := setup(t) + _, conf, cConn, cleanup := setup(t) defer cleanup() conf.User = "UNAUTHORIZED" - s = New(conf).(*session) + s := New(conf).(*session) payload := make([]byte, 4+9+4) binary.BigEndian.PutUint32(payload[0:4], 9) @@ -738,14 +728,17 @@ func TestForwardingFailures(t *testing.T) { binary.BigEndian.PutUint32(payload[13:17], 80) go func() { - _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + _, _, err := cConn.SendRequest("tcpip-forward", true, payload) + assert.Error(t, err, io.EOF) }() req := <-sReqs - cConn.Close() + err := cConn.Close() + assert.NoError(t, err) + time.Sleep(50 * time.Millisecond) - err := s.HandleTCPIPForward(req) + err = s.HandleTCPIPForward(req) assert.Error(t, err) }) @@ -759,7 +752,10 @@ func TestForwardingFailures(t *testing.T) { if err != nil { t.Fatal(err) } - defer l.Close() + defer func(l net.Listener) { + err = l.Close() + assert.NoError(t, err) + }(l) _, portStr, _ := net.SplitHostPort(l.Addr().String()) port, _ := strconv.Atoi(portStr) @@ -1120,10 +1116,12 @@ func TestDenyForwardingRequest_Full(t *testing.T) { s, _, _, sReqs, cConn, cleanup := setup(t) defer cleanup() req := getReq(t, cConn, sReqs) - cConn.Close() + err := cConn.Close() + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) - err := s.denyForwardingRequest(req, nil, nil, assert.AnError.Error()) + err = s.denyForwardingRequest(req, nil, nil, assert.AnError.Error()) assert.Error(t, err, assert.AnError) }) } @@ -1183,7 +1181,10 @@ func TestHandleTCPForward_Failures(t *testing.T) { if err != nil { t.Fatal(err) } - defer l.Close() + defer func(l net.Listener) { + err = l.Close() + assert.NoError(t, err) + }(l) port := uint16(l.Addr().(*net.TCPAddr).Port) err = s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", port) @@ -1213,10 +1214,11 @@ func TestHandleTCPForward_Failures(t *testing.T) { mPort.On("Claim", mock.Anything).Return(true) mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) req := getReq(t, cConn, sReqs) - cConn.Close() + err := cConn.Close() + assert.NoError(t, err) time.Sleep(100 * time.Millisecond) - err := s.HandleTCPForward(req, "localhost", 0) + err = s.HandleTCPForward(req, "localhost", 0) if err == nil { t.Error("expected error, got nil") } else if !strings.Contains(err.Error(), "Failed to finalize forwarding") { @@ -1318,7 +1320,9 @@ func TestHandleGlobalRequest_Failures(t *testing.T) { }) } - cConn.Close() + err := cConn.Close() + assert.NoError(t, err) + select { case <-done: case <-time.After(2 * time.Second): @@ -1354,19 +1358,3 @@ type mockCloser struct { } func (m *mockCloser) Close() error { return m.err } - -type mockLifecycle struct { - lifecycle.Lifecycle - closeErr error - conn ssh.Conn - user string -} - -func (m *mockLifecycle) Close() error { return m.closeErr } -func (m *mockLifecycle) Connection() ssh.Conn { return m.conn } -func (m *mockLifecycle) User() string { return m.user } -func (m *mockLifecycle) IsActive() bool { return false } -func (m *mockLifecycle) PortRegistry() portUtil.Port { return nil } -func (m *mockLifecycle) SetChannel(ch ssh.Channel) {} -func (m *mockLifecycle) SetStatus(status types.SessionStatus) {} -func (m *mockLifecycle) StartedAt() time.Time { return time.Time{} } diff --git a/session/slug/slug_test.go b/session/slug/slug_test.go index 3e192af..c7af138 100644 --- a/session/slug/slug_test.go +++ b/session/slug/slug_test.go @@ -83,11 +83,6 @@ func (suite *SlugTestSuite) TestMultipleSet() { assert.Equal(suite.T(), "", suite.slug.String()) } -func TestSlugInterface(t *testing.T) { - var _ Slug = (*slug)(nil) - var _ Slug = New() -} - func TestSlugIsolation(t *testing.T) { slug1 := New() slug2 := New()