test: check and handle error for testing

This commit is contained in:
2026-01-26 18:55:59 +07:00
parent 6def82a095
commit a3f6baa6ae
17 changed files with 505 additions and 516 deletions
+184 -254
View File
@@ -2,13 +2,7 @@ package bootstrap
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/http"
_ "net/http/pprof"
@@ -20,15 +14,11 @@ import (
"tunnel_pls/internal/config"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
@@ -76,35 +66,6 @@ func (m *MockSessionRegistry) Slug() slug.Slug {
return args.Get(0).(slug.Slug)
}
type MockSession struct {
mock.Mock
}
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
args := m.Called()
return args.Get(0).(lifecycle.Lifecycle)
}
func (m *MockSession) Interaction() interaction.Interaction {
args := m.Called()
return args.Get(0).(interaction.Interaction)
}
func (m *MockSession) Forwarder() forwarder.Forwarder {
args := m.Called()
return args.Get(0).(forwarder.Forwarder)
}
func (m *MockSession) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
func (m *MockSession) Detail() *types.Detail {
args := m.Called()
return args.Get(0).(*types.Detail)
}
type MockRandom struct {
mock.Mock
}
@@ -162,24 +123,24 @@ func (m *MockPort) AddRange(startPort, endPort uint16) error {
}
func (m *MockPort) Unassigned() (uint16, bool) {
args := m.Called()
var port uint16
var mPort uint16
if args.Get(0) != nil {
switch v := args.Get(0).(type) {
case int:
port = uint16(v)
mPort = uint16(v)
case uint16:
port = v
mPort = v
case uint32:
port = uint16(v)
mPort = uint16(v)
case int32:
port = uint16(v)
mPort = uint16(v)
case float64:
port = uint16(v)
mPort = uint16(v)
default:
port = uint16(args.Int(0))
mPort = uint16(args.Int(0))
}
}
return port, args.Bool(1)
return mPort, args.Bool(1)
}
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
@@ -281,49 +242,17 @@ func TestNew(t *testing.T) {
}
}
func generateTestCert(t *testing.T) (certPEM, keyPEM []byte) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Co"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
certPEM = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
return certPEM, keyPEM
}
func randomAvailablePort() (string, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return "", err
}
defer listener.Close()
defer func(listener net.Listener) {
_ = listener.Close()
}(listener)
port := listener.Addr().(*net.TCPAddr).Port
return strconv.Itoa(port), nil
mPort := listener.Addr().(*net.TCPAddr).Port
return strconv.Itoa(mPort), nil
}
func TestRun(t *testing.T) {
@@ -346,81 +275,81 @@ func TestRun(t *testing.T) {
{
name: "successful run and termination",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: false,
},
{
name: "error from SSH server invalid port",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("invalid")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("invalid")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: true,
},
{
name: "error from HTTP server invalid port",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("invalid")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("invalid")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: true,
},
@@ -428,55 +357,55 @@ func TestRun(t *testing.T) {
name: "error from HTTPS server invalid port",
setupConfig: func() *MockConfig {
tempDir := os.TempDir()
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("invalid")
mock.On("TLSEnabled").Return(true)
mock.On("TLSRedirect").Return(false)
mock.On("TLSStoragePath").Return(tempDir)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("invalid")
mockConfig.On("TLSEnabled").Return(true)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("TLSStoragePath").Return(tempDir)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: true,
},
{
name: "grpc health check failed",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeNODE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("invalid")
mock.On("NodeToken").Return("fake-node-token")
return mock
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("invalid")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
@@ -488,54 +417,54 @@ func TestRun(t *testing.T) {
{
name: "successful run with pprof enabled",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mockConfig := &MockConfig{}
pprofPort, _ := randomAvailablePort()
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(true)
mock.On("PprofPort").Return(pprofPort)
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(true)
mockConfig.On("PprofPort").Return(pprofPort)
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: false,
}, {
name: "successful run in NODE mode with signal",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeNODE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
@@ -548,27 +477,27 @@ func TestRun(t *testing.T) {
}, {
name: "successful run in NODE mode with signal buf error when closing",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeNODE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
@@ -613,7 +542,8 @@ func TestRun(t *testing.T) {
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
resp.Body.Close()
err = resp.Body.Close()
assert.NoError(t, err)
mockSignalChan <- os.Interrupt
err = <-done
assert.NoError(t, err)
+20 -8
View File
@@ -37,7 +37,8 @@ func TestGetenv(t *testing.T) {
if tt.val != "" {
t.Setenv(tt.key, tt.val)
} else {
os.Unsetenv(tt.key)
err := os.Unsetenv(tt.key)
assert.NoError(t, err)
}
assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
})
@@ -87,7 +88,8 @@ func TestGetenvBool(t *testing.T) {
if tt.val != "" {
t.Setenv(tt.key, tt.val)
} else {
os.Unsetenv(tt.key)
err := os.Unsetenv(tt.key)
assert.NoError(t, err)
}
assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
})
@@ -113,7 +115,8 @@ func TestParseMode(t *testing.T) {
if tt.mode != "" {
t.Setenv("MODE", tt.mode)
} else {
os.Unsetenv("MODE")
err := os.Unsetenv("MODE")
assert.NoError(t, err)
}
mode, err := parseMode()
if tt.expectErr {
@@ -148,7 +151,8 @@ func TestParseAllowedPorts(t *testing.T) {
if tt.val != "" {
t.Setenv("ALLOWED_PORTS", tt.val)
} else {
os.Unsetenv("ALLOWED_PORTS")
err := os.Unsetenv("ALLOWED_PORTS")
assert.NoError(t, err)
}
start, end, err := parseAllowedPorts()
if tt.expectErr {
@@ -180,7 +184,8 @@ func TestParseBufferSize(t *testing.T) {
if tt.val != "" {
t.Setenv("BUFFER_SIZE", tt.val)
} else {
os.Unsetenv("BUFFER_SIZE")
err := os.Unsetenv("BUFFER_SIZE")
assert.NoError(t, err)
}
size := parseBufferSize()
assert.Equal(t, tt.expect, size)
@@ -206,7 +211,8 @@ func TestParseHeaderSize(t *testing.T) {
if tt.val != "" {
t.Setenv("MAX_HEADER_SIZE", tt.val)
} else {
os.Unsetenv("MAX_HEADER_SIZE")
err := os.Unsetenv("MAX_HEADER_SIZE")
assert.NoError(t, err)
}
size := parseHeaderSize()
assert.Equal(t, tt.expect, size)
@@ -358,7 +364,10 @@ func TestMustLoad(t *testing.T) {
t.Run("loadEnvFile error", func(t *testing.T) {
err := os.Mkdir(".env", 0755)
assert.NoError(t, err)
defer os.Remove(".env")
defer func() {
err = os.Remove(".env")
assert.NoError(t, err)
}()
cfg, err := MustLoad()
assert.Error(t, err)
@@ -378,7 +387,10 @@ func TestLoadEnvFile(t *testing.T) {
t.Run("file exists", func(t *testing.T) {
err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
assert.NoError(t, err)
defer os.Remove(".env")
defer func() {
err = os.Remove(".env")
assert.NoError(t, err)
}()
err = loadEnvFile()
assert.NoError(t, err)
+3 -1
View File
@@ -744,7 +744,9 @@ func TestNew(t *testing.T) {
if cli == nil {
t.Fatal("New() returned nil client")
}
defer cli.Close()
defer func(cli Client) {
_ = cli.Close()
}(cli)
}
type MockConfig struct {
+3 -2
View File
@@ -1,10 +1,11 @@
package middleware
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type mockRequestHeader struct {
@@ -1,9 +1,10 @@
package middleware
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
)
type mockResponseHeader struct {
+2 -1
View File
@@ -1,8 +1,9 @@
package port
import (
"github.com/stretchr/testify/assert"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddRange(t *testing.T) {
+4 -3
View File
@@ -1,9 +1,6 @@
package registry
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"sync"
"testing"
"time"
@@ -14,6 +11,10 @@ import (
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
+10 -4
View File
@@ -38,7 +38,8 @@ func TestHTTPServer_Listen(t *testing.T) {
listener, err := srv.Listen()
assert.NoError(t, err)
assert.NotNil(t, listener)
listener.Close()
err = listener.Close()
assert.NoError(t, err)
}
func TestHTTPServer_Serve(t *testing.T) {
@@ -54,7 +55,8 @@ func TestHTTPServer_Serve(t *testing.T) {
go func() {
time.Sleep(100 * time.Millisecond)
listener.Close()
err = listener.Close()
assert.NoError(t, err)
}()
err = srv.Serve(listener)
@@ -102,8 +104,12 @@ func TestHTTPServer_Serve_Success(t *testing.T) {
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
time.Sleep(100 * time.Millisecond)
conn.Close()
listener.Close()
err = conn.Close()
assert.NoError(t, err)
err = listener.Close()
assert.NoError(t, err)
}
type mockListener struct {
+4 -49
View File
@@ -10,7 +10,6 @@ import (
"sync"
"testing"
"time"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
@@ -97,53 +96,6 @@ func (m *MockSession) Detail() *types.Detail {
return args.Get(0).(*types.Detail)
}
type MockLifecycle struct {
mock.Mock
}
func (m *MockLifecycle) Channel() ssh.Channel {
args := m.Called()
return args.Get(0).(ssh.Channel)
}
func (m *MockLifecycle) Connection() ssh.Conn {
args := m.Called()
return args.Get(0).(ssh.Conn)
}
func (m *MockLifecycle) PortRegistry() port.Port {
args := m.Called()
return args.Get(0).(port.Port)
}
func (m *MockLifecycle) User() string {
args := m.Called()
return args.String(0)
}
func (m *MockLifecycle) SetChannel(channel ssh.Channel) {
m.Called(channel)
}
func (m *MockLifecycle) SetStatus(status types.SessionStatus) {
m.Called(status)
}
func (m *MockLifecycle) IsActive() bool {
args := m.Called()
return args.Bool(0)
}
func (m *MockLifecycle) StartedAt() time.Time {
args := m.Called()
return args.Get(0).(time.Time)
}
func (m *MockLifecycle) Close() error {
args := m.Called()
return args.Error(0)
}
type MockSSHChannel struct {
ssh.Channel
mock.Mock
@@ -678,7 +630,10 @@ func TestHandler(t *testing.T) {
}
if clientConn != nil {
defer clientConn.Close()
defer func(clientConn net.Conn) {
err := clientConn.Close()
assert.NoError(t, err)
}(clientConn)
}
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
+8 -4
View File
@@ -46,7 +46,8 @@ func TestHTTPSServer_Listen(t *testing.T) {
return
}
assert.NotNil(t, listener)
listener.Close()
err = listener.Close()
assert.NoError(t, err)
}
func TestHTTPSServer_Serve(t *testing.T) {
@@ -62,7 +63,8 @@ func TestHTTPSServer_Serve(t *testing.T) {
go func() {
time.Sleep(100 * time.Millisecond)
listener.Close()
err = listener.Close()
assert.NoError(t, err)
}()
err = srv.Serve(listener)
@@ -111,6 +113,8 @@ func TestHTTPSServer_Serve_Success(t *testing.T) {
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
time.Sleep(100 * time.Millisecond)
conn.Close()
listener.Close()
err = conn.Close()
assert.NoError(t, err)
err = listener.Close()
assert.NoError(t, err)
}
+16 -7
View File
@@ -32,7 +32,8 @@ func TestTCPServer_Listen(t *testing.T) {
listener, err := srv.Listen()
assert.NoError(t, err)
assert.NotNil(t, listener)
listener.Close()
err = listener.Close()
assert.NoError(t, err)
}
func TestTCPServer_Serve(t *testing.T) {
@@ -44,7 +45,8 @@ func TestTCPServer_Serve(t *testing.T) {
go func() {
time.Sleep(100 * time.Millisecond)
listener.Close()
err = listener.Close()
assert.NoError(t, err)
}()
err = srv.Serve(listener)
@@ -84,9 +86,10 @@ func TestTCPServer_Serve_Success(t *testing.T) {
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
conn.Close()
listener.Close()
err = conn.Close()
assert.NoError(t, err)
err = listener.Close()
assert.NoError(t, err)
mf.AssertExpectations(t)
}
@@ -95,7 +98,10 @@ func TestTCPServer_handleTcp_Success(t *testing.T) {
srv := NewTCPServer(0, mf).(*tcp)
serverConn, clientConn := net.Pipe()
defer clientConn.Close()
defer func(clientConn net.Conn) {
err := clientConn.Close()
assert.NoError(t, err)
}(clientConn)
reqs := make(chan *ssh.Request)
mockChannel := new(MockSSHChannel)
@@ -127,7 +133,10 @@ func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
srv := NewTCPServer(0, mf).(*tcp)
serverConn, clientConn := net.Pipe()
defer clientConn.Close()
defer func(clientConn net.Conn) {
err := clientConn.Close()
assert.NoError(t, err)
}(clientConn)
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
+124 -61
View File
@@ -80,13 +80,15 @@ func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, so
assert.NoError(t, err)
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
assert.NoError(t, err)
certOut.Close()
err = certOut.Close()
assert.NoError(t, err)
keyOut, err := os.CreateTemp("", "key*.pem")
assert.NoError(t, err)
err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
assert.NoError(t, err)
keyOut.Close()
err = keyOut.Close()
assert.NoError(t, err)
return certOut.Name(), keyOut.Name()
}
@@ -98,7 +100,8 @@ func setupTestDir(t *testing.T) string {
assert.NoError(t, err)
t.Cleanup(func() {
os.RemoveAll(tmpDir)
err = os.RemoveAll(tmpDir)
assert.NoError(t, err)
})
return tmpDir
@@ -126,8 +129,11 @@ func TestValidateCertDomains(t *testing.T) {
assert.NoError(t, err)
_, err = tmpFile.WriteString("not a pem")
assert.NoError(t, err)
tmpFile.Close()
return tmpFile.Name(), func() { os.Remove(tmpFile.Name()) }
err = tmpFile.Close()
assert.NoError(t, err)
return tmpFile.Name(), func() {
_ = os.Remove(tmpFile.Name())
}
},
domain: "example.com",
expected: false,
@@ -137,8 +143,8 @@ func TestValidateCertDomains(t *testing.T) {
setup: func(t *testing.T) (string, func()) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
return certPath, func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
}
},
domain: "example.com",
@@ -149,8 +155,8 @@ func TestValidateCertDomains(t *testing.T) {
setup: func(t *testing.T) (string, func()) {
certPath, keyPath := createTestCert(t, "example.com", true, true, false)
return certPath, func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
}
},
domain: "example.com",
@@ -161,8 +167,8 @@ func TestValidateCertDomains(t *testing.T) {
setup: func(t *testing.T) (string, func()) {
certPath, keyPath := createTestCert(t, "example.com", true, false, true)
return certPath, func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
}
},
domain: "example.com",
@@ -173,8 +179,8 @@ func TestValidateCertDomains(t *testing.T) {
setup: func(t *testing.T) (string, func()) {
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
return certPath, func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
}
},
domain: "example.com",
@@ -205,8 +211,8 @@ func TestLoadAndParseCertificate(t *testing.T) {
setup: func(t *testing.T) (string, func()) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
return certPath, func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
}
},
wantError: false,
@@ -275,8 +281,14 @@ func TestIsCertificateValid(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, tt.expired, tt.soon)
defer os.Remove(certPath)
defer os.Remove(keyPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(certPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(keyPath)
cert, err := loadAndParseCertificate(certPath)
assert.NoError(t, err)
@@ -289,8 +301,14 @@ func TestIsCertificateValid(t *testing.T) {
func TestExtractCertDomains(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(certPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(keyPath)
cert, err := loadAndParseCertificate(certPath)
assert.NoError(t, err)
@@ -381,8 +399,8 @@ func TestTLSManager_getCertificate(t *testing.T) {
setup: func(t *testing.T) *tlsManager {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
@@ -447,8 +465,9 @@ func TestTLSManager_userCertsExistAndValid(t *testing.T) {
mockCfg.On("Domain").Return("example.com")
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
t.Cleanup(func() { os.Remove(certPath) })
os.Remove(keyPath)
t.Cleanup(func() { _ = os.Remove(certPath) })
err := os.Remove(keyPath)
assert.NoError(t, err)
return &tlsManager{
config: mockCfg,
@@ -471,8 +490,14 @@ func TestTLSManager_userCertsExistAndValid(t *testing.T) {
func TestTLSManager_certFilesExist(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(certPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(keyPath)
tm := &tlsManager{
certPath: certPath,
@@ -494,8 +519,8 @@ func TestTLSManager_loadUserCerts(t *testing.T) {
setup: func(t *testing.T) *tlsManager {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
return &tlsManager{
@@ -550,8 +575,14 @@ func TestCreateTLSManager(t *testing.T) {
func TestNewCertWatcher(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(certPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(keyPath)
mockCfg := &MockConfig{}
@@ -571,8 +602,14 @@ func TestNewCertWatcher(t *testing.T) {
func TestCertWatcher_filesModified(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(certPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(keyPath)
mockCfg := &MockConfig{}
@@ -600,8 +637,14 @@ func TestCertWatcher_filesModified(t *testing.T) {
func TestCertWatcher_updateModTimes(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(certPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(keyPath)
mockCfg := &MockConfig{}
@@ -637,8 +680,8 @@ func TestCertWatcher_getFileInfo(t *testing.T) {
setup: func(t *testing.T) *tlsManager {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
return &tlsManager{
@@ -657,8 +700,9 @@ func TestCertWatcher_getFileInfo(t *testing.T) {
name: "missing cert file",
setup: func(t *testing.T) *tlsManager {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
os.Remove(certPath)
t.Cleanup(func() { os.Remove(keyPath) })
err := os.Remove(certPath)
assert.NoError(t, err)
t.Cleanup(func() { _ = os.Remove(keyPath) })
return &tlsManager{
config: &MockConfig{},
@@ -672,8 +716,9 @@ func TestCertWatcher_getFileInfo(t *testing.T) {
name: "missing key file",
setup: func(t *testing.T) *tlsManager {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
os.Remove(keyPath)
t.Cleanup(func() { os.Remove(certPath) })
err := os.Remove(keyPath)
assert.NoError(t, err)
t.Cleanup(func() { _ = os.Remove(certPath) })
return &tlsManager{
config: &MockConfig{},
@@ -729,8 +774,8 @@ func TestCertWatcher_checkAndReloadCerts(t *testing.T) {
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
tm := &tlsManager{
@@ -747,8 +792,8 @@ func TestCertWatcher_checkAndReloadCerts(t *testing.T) {
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
mockCfg := &MockConfig{}
@@ -792,8 +837,8 @@ func TestCertWatcher_handleCertificateChange(t *testing.T) {
setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
mockCfg := &MockConfig{}
@@ -819,8 +864,8 @@ func TestCertWatcher_handleCertificateChange(t *testing.T) {
setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
tmpDir := setupTestDir(t)
@@ -850,8 +895,8 @@ func TestCertWatcher_handleCertificateChange(t *testing.T) {
setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
mockCfg := &MockConfig{}
@@ -946,8 +991,8 @@ func TestCertWatcher_watch(t *testing.T) {
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
tmpDir := setupTestDir(t)
@@ -976,8 +1021,8 @@ func TestCertWatcher_watch(t *testing.T) {
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
mockCfg := &MockConfig{}
@@ -1006,8 +1051,14 @@ func TestCertWatcher_watch(t *testing.T) {
func TestCertWatcher_watch_Integration(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(certPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(keyPath)
mockCfg := &MockConfig{}
mockCfg.On("Domain").Return("example.com")
@@ -1029,8 +1080,14 @@ func TestCertWatcher_watch_Integration(t *testing.T) {
time.Sleep(50 * time.Millisecond)
newCertPath, newKeyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(newCertPath)
defer os.Remove(newKeyPath)
defer func(name string) {
err = os.Remove(name)
assert.NoError(t, err)
}(newCertPath)
defer func(name string) {
err = os.Remove(name)
assert.NoError(t, err)
}(newKeyPath)
newCertData, err := os.ReadFile(newCertPath)
assert.NoError(t, err)
@@ -1066,8 +1123,8 @@ func TestNewTLSConfig(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
t.Cleanup(func() {
os.Remove(certPath)
os.Remove(keyPath)
_ = os.Remove(certPath)
_ = os.Remove(keyPath)
})
certData, err := os.ReadFile(certPath)
@@ -1140,8 +1197,14 @@ func TestNewTLSConfig_Singleton(t *testing.T) {
tmpDir := setupTestDir(t)
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(certPath)
defer func(name string) {
err := os.Remove(name)
assert.NoError(t, err)
}(keyPath)
certData, err := os.ReadFile(certPath)
assert.NoError(t, err)