test: check and handle error for testing
SonarQube Scan / SonarQube Trigger (push) Successful in 3m35s
SonarQube Scan / SonarQube Trigger (push) Successful in 3m35s
This commit is contained in:
@@ -2,13 +2,7 @@ package bootstrap
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
@@ -20,15 +14,11 @@ import (
|
|||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
"tunnel_pls/internal/port"
|
"tunnel_pls/internal/port"
|
||||||
"tunnel_pls/internal/registry"
|
"tunnel_pls/internal/registry"
|
||||||
"tunnel_pls/session/forwarder"
|
|
||||||
"tunnel_pls/session/interaction"
|
|
||||||
"tunnel_pls/session/lifecycle"
|
|
||||||
"tunnel_pls/session/slug"
|
"tunnel_pls/session/slug"
|
||||||
"tunnel_pls/types"
|
"tunnel_pls/types"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -76,35 +66,6 @@ func (m *MockSessionRegistry) Slug() slug.Slug {
|
|||||||
return args.Get(0).(slug.Slug)
|
return args.Get(0).(slug.Slug)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockSession struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(lifecycle.Lifecycle)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Interaction() interaction.Interaction {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(interaction.Interaction)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Forwarder() forwarder.Forwarder {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(forwarder.Forwarder)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Slug() slug.Slug {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(slug.Slug)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Detail() *types.Detail {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(*types.Detail)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockRandom struct {
|
type MockRandom struct {
|
||||||
mock.Mock
|
mock.Mock
|
||||||
}
|
}
|
||||||
@@ -162,24 +123,24 @@ func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
|||||||
}
|
}
|
||||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
func (m *MockPort) Unassigned() (uint16, bool) {
|
||||||
args := m.Called()
|
args := m.Called()
|
||||||
var port uint16
|
var mPort uint16
|
||||||
if args.Get(0) != nil {
|
if args.Get(0) != nil {
|
||||||
switch v := args.Get(0).(type) {
|
switch v := args.Get(0).(type) {
|
||||||
case int:
|
case int:
|
||||||
port = uint16(v)
|
mPort = uint16(v)
|
||||||
case uint16:
|
case uint16:
|
||||||
port = v
|
mPort = v
|
||||||
case uint32:
|
case uint32:
|
||||||
port = uint16(v)
|
mPort = uint16(v)
|
||||||
case int32:
|
case int32:
|
||||||
port = uint16(v)
|
mPort = uint16(v)
|
||||||
case float64:
|
case float64:
|
||||||
port = uint16(v)
|
mPort = uint16(v)
|
||||||
default:
|
default:
|
||||||
port = uint16(args.Int(0))
|
mPort = uint16(args.Int(0))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return port, args.Bool(1)
|
return mPort, args.Bool(1)
|
||||||
}
|
}
|
||||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
||||||
return m.Called(port, assigned).Error(0)
|
return m.Called(port, assigned).Error(0)
|
||||||
@@ -281,49 +242,17 @@ func TestNew(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateTestCert(t *testing.T) (certPEM, keyPEM []byte) {
|
|
||||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
template := x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
Subject: pkix.Name{
|
|
||||||
Organization: []string{"Test Co"},
|
|
||||||
},
|
|
||||||
NotBefore: time.Now(),
|
|
||||||
NotAfter: time.Now().Add(time.Hour),
|
|
||||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
DNSNames: []string{"localhost"},
|
|
||||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
|
||||||
}
|
|
||||||
|
|
||||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
certPEM = pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "CERTIFICATE",
|
|
||||||
Bytes: certDER,
|
|
||||||
})
|
|
||||||
|
|
||||||
keyPEM = pem.EncodeToMemory(&pem.Block{
|
|
||||||
Type: "RSA PRIVATE KEY",
|
|
||||||
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
|
||||||
})
|
|
||||||
|
|
||||||
return certPEM, keyPEM
|
|
||||||
}
|
|
||||||
|
|
||||||
func randomAvailablePort() (string, error) {
|
func randomAvailablePort() (string, error) {
|
||||||
listener, err := net.Listen("tcp", "localhost:0")
|
listener, err := net.Listen("tcp", "localhost:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer func(listener net.Listener) {
|
||||||
|
_ = listener.Close()
|
||||||
|
}(listener)
|
||||||
|
|
||||||
port := listener.Addr().(*net.TCPAddr).Port
|
mPort := listener.Addr().(*net.TCPAddr).Port
|
||||||
return strconv.Itoa(port), nil
|
return strconv.Itoa(mPort), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRun(t *testing.T) {
|
func TestRun(t *testing.T) {
|
||||||
@@ -346,81 +275,81 @@ func TestRun(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful run and termination",
|
name: "successful run and termination",
|
||||||
setupConfig: func() *MockConfig {
|
setupConfig: func() *MockConfig {
|
||||||
mock := &MockConfig{}
|
mockConfig := &MockConfig{}
|
||||||
mock.On("KeyLoc").Return(keyLoc)
|
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||||
mock.On("Mode").Return(types.ServerModeSTANDALONE)
|
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||||
mock.On("Domain").Return("example.com")
|
mockConfig.On("Domain").Return("example.com")
|
||||||
mock.On("SSHPort").Return("0")
|
mockConfig.On("SSHPort").Return("0")
|
||||||
mock.On("HTTPPort").Return("0")
|
mockConfig.On("HTTPPort").Return("0")
|
||||||
mock.On("HTTPSPort").Return("0")
|
mockConfig.On("HTTPSPort").Return("0")
|
||||||
mock.On("TLSEnabled").Return(false)
|
mockConfig.On("TLSEnabled").Return(false)
|
||||||
mock.On("TLSRedirect").Return(false)
|
mockConfig.On("TLSRedirect").Return(false)
|
||||||
mock.On("ACMEEmail").Return("test@example.com")
|
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||||
mock.On("CFAPIToken").Return("fake-token")
|
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||||
mock.On("ACMEStaging").Return(true)
|
mockConfig.On("ACMEStaging").Return(true)
|
||||||
mock.On("AllowedPortsStart").Return(uint16(1024))
|
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||||
mock.On("AllowedPortsEnd").Return(uint16(65535))
|
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||||
mock.On("BufferSize").Return(4096)
|
mockConfig.On("BufferSize").Return(4096)
|
||||||
mock.On("PprofEnabled").Return(false)
|
mockConfig.On("PprofEnabled").Return(false)
|
||||||
mock.On("PprofPort").Return("0")
|
mockConfig.On("PprofPort").Return("0")
|
||||||
mock.On("GRPCAddress").Return("localhost")
|
mockConfig.On("GRPCAddress").Return("localhost")
|
||||||
mock.On("GRPCPort").Return("0")
|
mockConfig.On("GRPCPort").Return("0")
|
||||||
mock.On("NodeToken").Return("fake-node-token")
|
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||||
return mock
|
return mockConfig
|
||||||
},
|
},
|
||||||
expectError: false,
|
expectError: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error from SSH server invalid port",
|
name: "error from SSH server invalid port",
|
||||||
setupConfig: func() *MockConfig {
|
setupConfig: func() *MockConfig {
|
||||||
mock := &MockConfig{}
|
mockConfig := &MockConfig{}
|
||||||
mock.On("KeyLoc").Return(keyLoc)
|
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||||
mock.On("Mode").Return(types.ServerModeSTANDALONE)
|
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||||
mock.On("Domain").Return("example.com")
|
mockConfig.On("Domain").Return("example.com")
|
||||||
mock.On("SSHPort").Return("invalid")
|
mockConfig.On("SSHPort").Return("invalid")
|
||||||
mock.On("HTTPPort").Return("0")
|
mockConfig.On("HTTPPort").Return("0")
|
||||||
mock.On("HTTPSPort").Return("0")
|
mockConfig.On("HTTPSPort").Return("0")
|
||||||
mock.On("TLSEnabled").Return(false)
|
mockConfig.On("TLSEnabled").Return(false)
|
||||||
mock.On("TLSRedirect").Return(false)
|
mockConfig.On("TLSRedirect").Return(false)
|
||||||
mock.On("ACMEEmail").Return("test@example.com")
|
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||||
mock.On("CFAPIToken").Return("fake-token")
|
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||||
mock.On("ACMEStaging").Return(true)
|
mockConfig.On("ACMEStaging").Return(true)
|
||||||
mock.On("AllowedPortsStart").Return(uint16(1024))
|
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||||
mock.On("AllowedPortsEnd").Return(uint16(65535))
|
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||||
mock.On("BufferSize").Return(4096)
|
mockConfig.On("BufferSize").Return(4096)
|
||||||
mock.On("PprofEnabled").Return(false)
|
mockConfig.On("PprofEnabled").Return(false)
|
||||||
mock.On("PprofPort").Return("0")
|
mockConfig.On("PprofPort").Return("0")
|
||||||
mock.On("GRPCAddress").Return("localhost")
|
mockConfig.On("GRPCAddress").Return("localhost")
|
||||||
mock.On("GRPCPort").Return("0")
|
mockConfig.On("GRPCPort").Return("0")
|
||||||
mock.On("NodeToken").Return("fake-node-token")
|
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||||
return mock
|
return mockConfig
|
||||||
},
|
},
|
||||||
expectError: true,
|
expectError: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "error from HTTP server invalid port",
|
name: "error from HTTP server invalid port",
|
||||||
setupConfig: func() *MockConfig {
|
setupConfig: func() *MockConfig {
|
||||||
mock := &MockConfig{}
|
mockConfig := &MockConfig{}
|
||||||
mock.On("KeyLoc").Return(keyLoc)
|
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||||
mock.On("Mode").Return(types.ServerModeSTANDALONE)
|
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||||
mock.On("Domain").Return("example.com")
|
mockConfig.On("Domain").Return("example.com")
|
||||||
mock.On("SSHPort").Return("0")
|
mockConfig.On("SSHPort").Return("0")
|
||||||
mock.On("HTTPPort").Return("invalid")
|
mockConfig.On("HTTPPort").Return("invalid")
|
||||||
mock.On("HTTPSPort").Return("0")
|
mockConfig.On("HTTPSPort").Return("0")
|
||||||
mock.On("TLSEnabled").Return(false)
|
mockConfig.On("TLSEnabled").Return(false)
|
||||||
mock.On("TLSRedirect").Return(false)
|
mockConfig.On("TLSRedirect").Return(false)
|
||||||
mock.On("ACMEEmail").Return("test@example.com")
|
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||||
mock.On("CFAPIToken").Return("fake-token")
|
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||||
mock.On("ACMEStaging").Return(true)
|
mockConfig.On("ACMEStaging").Return(true)
|
||||||
mock.On("AllowedPortsStart").Return(uint16(1024))
|
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||||
mock.On("AllowedPortsEnd").Return(uint16(65535))
|
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||||
mock.On("BufferSize").Return(4096)
|
mockConfig.On("BufferSize").Return(4096)
|
||||||
mock.On("PprofEnabled").Return(false)
|
mockConfig.On("PprofEnabled").Return(false)
|
||||||
mock.On("PprofPort").Return("0")
|
mockConfig.On("PprofPort").Return("0")
|
||||||
mock.On("GRPCAddress").Return("localhost")
|
mockConfig.On("GRPCAddress").Return("localhost")
|
||||||
mock.On("GRPCPort").Return("0")
|
mockConfig.On("GRPCPort").Return("0")
|
||||||
mock.On("NodeToken").Return("fake-node-token")
|
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||||
return mock
|
return mockConfig
|
||||||
},
|
},
|
||||||
expectError: true,
|
expectError: true,
|
||||||
},
|
},
|
||||||
@@ -428,55 +357,55 @@ func TestRun(t *testing.T) {
|
|||||||
name: "error from HTTPS server invalid port",
|
name: "error from HTTPS server invalid port",
|
||||||
setupConfig: func() *MockConfig {
|
setupConfig: func() *MockConfig {
|
||||||
tempDir := os.TempDir()
|
tempDir := os.TempDir()
|
||||||
mock := &MockConfig{}
|
mockConfig := &MockConfig{}
|
||||||
mock.On("KeyLoc").Return(keyLoc)
|
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||||
mock.On("Mode").Return(types.ServerModeSTANDALONE)
|
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||||
mock.On("Domain").Return("example.com")
|
mockConfig.On("Domain").Return("example.com")
|
||||||
mock.On("SSHPort").Return("0")
|
mockConfig.On("SSHPort").Return("0")
|
||||||
mock.On("HTTPPort").Return("0")
|
mockConfig.On("HTTPPort").Return("0")
|
||||||
mock.On("HTTPSPort").Return("invalid")
|
mockConfig.On("HTTPSPort").Return("invalid")
|
||||||
mock.On("TLSEnabled").Return(true)
|
mockConfig.On("TLSEnabled").Return(true)
|
||||||
mock.On("TLSRedirect").Return(false)
|
mockConfig.On("TLSRedirect").Return(false)
|
||||||
mock.On("TLSStoragePath").Return(tempDir)
|
mockConfig.On("TLSStoragePath").Return(tempDir)
|
||||||
mock.On("ACMEEmail").Return("test@example.com")
|
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||||
mock.On("CFAPIToken").Return("fake-token")
|
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||||
mock.On("ACMEStaging").Return(true)
|
mockConfig.On("ACMEStaging").Return(true)
|
||||||
mock.On("AllowedPortsStart").Return(uint16(1024))
|
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||||
mock.On("AllowedPortsEnd").Return(uint16(65535))
|
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||||
mock.On("BufferSize").Return(4096)
|
mockConfig.On("BufferSize").Return(4096)
|
||||||
mock.On("PprofEnabled").Return(false)
|
mockConfig.On("PprofEnabled").Return(false)
|
||||||
mock.On("PprofPort").Return("0")
|
mockConfig.On("PprofPort").Return("0")
|
||||||
mock.On("GRPCAddress").Return("localhost")
|
mockConfig.On("GRPCAddress").Return("localhost")
|
||||||
mock.On("GRPCPort").Return("0")
|
mockConfig.On("GRPCPort").Return("0")
|
||||||
mock.On("NodeToken").Return("fake-node-token")
|
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||||
return mock
|
return mockConfig
|
||||||
},
|
},
|
||||||
expectError: true,
|
expectError: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "grpc health check failed",
|
name: "grpc health check failed",
|
||||||
setupConfig: func() *MockConfig {
|
setupConfig: func() *MockConfig {
|
||||||
mock := &MockConfig{}
|
mockConfig := &MockConfig{}
|
||||||
mock.On("KeyLoc").Return(keyLoc)
|
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||||
mock.On("Mode").Return(types.ServerModeNODE)
|
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||||
mock.On("Domain").Return("example.com")
|
mockConfig.On("Domain").Return("example.com")
|
||||||
mock.On("SSHPort").Return("0")
|
mockConfig.On("SSHPort").Return("0")
|
||||||
mock.On("HTTPPort").Return("0")
|
mockConfig.On("HTTPPort").Return("0")
|
||||||
mock.On("HTTPSPort").Return("0")
|
mockConfig.On("HTTPSPort").Return("0")
|
||||||
mock.On("TLSEnabled").Return(false)
|
mockConfig.On("TLSEnabled").Return(false)
|
||||||
mock.On("TLSRedirect").Return(false)
|
mockConfig.On("TLSRedirect").Return(false)
|
||||||
mock.On("ACMEEmail").Return("test@example.com")
|
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||||
mock.On("CFAPIToken").Return("fake-token")
|
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||||
mock.On("ACMEStaging").Return(true)
|
mockConfig.On("ACMEStaging").Return(true)
|
||||||
mock.On("AllowedPortsStart").Return(uint16(1024))
|
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||||
mock.On("AllowedPortsEnd").Return(uint16(65535))
|
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||||
mock.On("BufferSize").Return(4096)
|
mockConfig.On("BufferSize").Return(4096)
|
||||||
mock.On("PprofEnabled").Return(false)
|
mockConfig.On("PprofEnabled").Return(false)
|
||||||
mock.On("PprofPort").Return("0")
|
mockConfig.On("PprofPort").Return("0")
|
||||||
mock.On("GRPCAddress").Return("localhost")
|
mockConfig.On("GRPCAddress").Return("localhost")
|
||||||
mock.On("GRPCPort").Return("invalid")
|
mockConfig.On("GRPCPort").Return("invalid")
|
||||||
mock.On("NodeToken").Return("fake-node-token")
|
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||||
return mock
|
return mockConfig
|
||||||
},
|
},
|
||||||
setupGrpcClient: func() *MockGRPCClient {
|
setupGrpcClient: func() *MockGRPCClient {
|
||||||
mockGRPCClient := &MockGRPCClient{}
|
mockGRPCClient := &MockGRPCClient{}
|
||||||
@@ -488,54 +417,54 @@ func TestRun(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "successful run with pprof enabled",
|
name: "successful run with pprof enabled",
|
||||||
setupConfig: func() *MockConfig {
|
setupConfig: func() *MockConfig {
|
||||||
mock := &MockConfig{}
|
mockConfig := &MockConfig{}
|
||||||
pprofPort, _ := randomAvailablePort()
|
pprofPort, _ := randomAvailablePort()
|
||||||
mock.On("KeyLoc").Return(keyLoc)
|
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||||
mock.On("Mode").Return(types.ServerModeSTANDALONE)
|
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||||
mock.On("Domain").Return("example.com")
|
mockConfig.On("Domain").Return("example.com")
|
||||||
mock.On("SSHPort").Return("0")
|
mockConfig.On("SSHPort").Return("0")
|
||||||
mock.On("HTTPPort").Return("0")
|
mockConfig.On("HTTPPort").Return("0")
|
||||||
mock.On("HTTPSPort").Return("0")
|
mockConfig.On("HTTPSPort").Return("0")
|
||||||
mock.On("TLSEnabled").Return(false)
|
mockConfig.On("TLSEnabled").Return(false)
|
||||||
mock.On("TLSRedirect").Return(false)
|
mockConfig.On("TLSRedirect").Return(false)
|
||||||
mock.On("ACMEEmail").Return("test@example.com")
|
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||||
mock.On("CFAPIToken").Return("fake-token")
|
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||||
mock.On("ACMEStaging").Return(true)
|
mockConfig.On("ACMEStaging").Return(true)
|
||||||
mock.On("AllowedPortsStart").Return(uint16(1024))
|
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||||
mock.On("AllowedPortsEnd").Return(uint16(65535))
|
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||||
mock.On("BufferSize").Return(4096)
|
mockConfig.On("BufferSize").Return(4096)
|
||||||
mock.On("PprofEnabled").Return(true)
|
mockConfig.On("PprofEnabled").Return(true)
|
||||||
mock.On("PprofPort").Return(pprofPort)
|
mockConfig.On("PprofPort").Return(pprofPort)
|
||||||
mock.On("GRPCAddress").Return("localhost")
|
mockConfig.On("GRPCAddress").Return("localhost")
|
||||||
mock.On("GRPCPort").Return("0")
|
mockConfig.On("GRPCPort").Return("0")
|
||||||
mock.On("NodeToken").Return("fake-node-token")
|
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||||
return mock
|
return mockConfig
|
||||||
},
|
},
|
||||||
expectError: false,
|
expectError: false,
|
||||||
}, {
|
}, {
|
||||||
name: "successful run in NODE mode with signal",
|
name: "successful run in NODE mode with signal",
|
||||||
setupConfig: func() *MockConfig {
|
setupConfig: func() *MockConfig {
|
||||||
mock := &MockConfig{}
|
mockConfig := &MockConfig{}
|
||||||
mock.On("KeyLoc").Return(keyLoc)
|
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||||
mock.On("Mode").Return(types.ServerModeNODE)
|
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||||
mock.On("Domain").Return("example.com")
|
mockConfig.On("Domain").Return("example.com")
|
||||||
mock.On("SSHPort").Return("0")
|
mockConfig.On("SSHPort").Return("0")
|
||||||
mock.On("HTTPPort").Return("0")
|
mockConfig.On("HTTPPort").Return("0")
|
||||||
mock.On("HTTPSPort").Return("0")
|
mockConfig.On("HTTPSPort").Return("0")
|
||||||
mock.On("TLSEnabled").Return(false)
|
mockConfig.On("TLSEnabled").Return(false)
|
||||||
mock.On("TLSRedirect").Return(false)
|
mockConfig.On("TLSRedirect").Return(false)
|
||||||
mock.On("ACMEEmail").Return("test@example.com")
|
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||||
mock.On("CFAPIToken").Return("fake-token")
|
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||||
mock.On("ACMEStaging").Return(true)
|
mockConfig.On("ACMEStaging").Return(true)
|
||||||
mock.On("AllowedPortsStart").Return(uint16(1024))
|
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||||
mock.On("AllowedPortsEnd").Return(uint16(65535))
|
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||||
mock.On("BufferSize").Return(4096)
|
mockConfig.On("BufferSize").Return(4096)
|
||||||
mock.On("PprofEnabled").Return(false)
|
mockConfig.On("PprofEnabled").Return(false)
|
||||||
mock.On("PprofPort").Return("0")
|
mockConfig.On("PprofPort").Return("0")
|
||||||
mock.On("GRPCAddress").Return("localhost")
|
mockConfig.On("GRPCAddress").Return("localhost")
|
||||||
mock.On("GRPCPort").Return("0")
|
mockConfig.On("GRPCPort").Return("0")
|
||||||
mock.On("NodeToken").Return("fake-node-token")
|
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||||
return mock
|
return mockConfig
|
||||||
},
|
},
|
||||||
setupGrpcClient: func() *MockGRPCClient {
|
setupGrpcClient: func() *MockGRPCClient {
|
||||||
mockGRPCClient := &MockGRPCClient{}
|
mockGRPCClient := &MockGRPCClient{}
|
||||||
@@ -548,27 +477,27 @@ func TestRun(t *testing.T) {
|
|||||||
}, {
|
}, {
|
||||||
name: "successful run in NODE mode with signal buf error when closing",
|
name: "successful run in NODE mode with signal buf error when closing",
|
||||||
setupConfig: func() *MockConfig {
|
setupConfig: func() *MockConfig {
|
||||||
mock := &MockConfig{}
|
mockConfig := &MockConfig{}
|
||||||
mock.On("KeyLoc").Return(keyLoc)
|
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||||
mock.On("Mode").Return(types.ServerModeNODE)
|
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||||
mock.On("Domain").Return("example.com")
|
mockConfig.On("Domain").Return("example.com")
|
||||||
mock.On("SSHPort").Return("0")
|
mockConfig.On("SSHPort").Return("0")
|
||||||
mock.On("HTTPPort").Return("0")
|
mockConfig.On("HTTPPort").Return("0")
|
||||||
mock.On("HTTPSPort").Return("0")
|
mockConfig.On("HTTPSPort").Return("0")
|
||||||
mock.On("TLSEnabled").Return(false)
|
mockConfig.On("TLSEnabled").Return(false)
|
||||||
mock.On("TLSRedirect").Return(false)
|
mockConfig.On("TLSRedirect").Return(false)
|
||||||
mock.On("ACMEEmail").Return("test@example.com")
|
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||||
mock.On("CFAPIToken").Return("fake-token")
|
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||||
mock.On("ACMEStaging").Return(true)
|
mockConfig.On("ACMEStaging").Return(true)
|
||||||
mock.On("AllowedPortsStart").Return(uint16(1024))
|
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||||
mock.On("AllowedPortsEnd").Return(uint16(65535))
|
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||||
mock.On("BufferSize").Return(4096)
|
mockConfig.On("BufferSize").Return(4096)
|
||||||
mock.On("PprofEnabled").Return(false)
|
mockConfig.On("PprofEnabled").Return(false)
|
||||||
mock.On("PprofPort").Return("0")
|
mockConfig.On("PprofPort").Return("0")
|
||||||
mock.On("GRPCAddress").Return("localhost")
|
mockConfig.On("GRPCAddress").Return("localhost")
|
||||||
mock.On("GRPCPort").Return("0")
|
mockConfig.On("GRPCPort").Return("0")
|
||||||
mock.On("NodeToken").Return("fake-node-token")
|
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||||
return mock
|
return mockConfig
|
||||||
},
|
},
|
||||||
setupGrpcClient: func() *MockGRPCClient {
|
setupGrpcClient: func() *MockGRPCClient {
|
||||||
mockGRPCClient := &MockGRPCClient{}
|
mockGRPCClient := &MockGRPCClient{}
|
||||||
@@ -613,7 +542,8 @@ func TestRun(t *testing.T) {
|
|||||||
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
|
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 200, resp.StatusCode)
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
resp.Body.Close()
|
err = resp.Body.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
mockSignalChan <- os.Interrupt
|
mockSignalChan <- os.Interrupt
|
||||||
err = <-done
|
err = <-done
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
@@ -37,7 +37,8 @@ func TestGetenv(t *testing.T) {
|
|||||||
if tt.val != "" {
|
if tt.val != "" {
|
||||||
t.Setenv(tt.key, tt.val)
|
t.Setenv(tt.key, tt.val)
|
||||||
} else {
|
} else {
|
||||||
os.Unsetenv(tt.key)
|
err := os.Unsetenv(tt.key)
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
|
assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
|
||||||
})
|
})
|
||||||
@@ -87,7 +88,8 @@ func TestGetenvBool(t *testing.T) {
|
|||||||
if tt.val != "" {
|
if tt.val != "" {
|
||||||
t.Setenv(tt.key, tt.val)
|
t.Setenv(tt.key, tt.val)
|
||||||
} else {
|
} else {
|
||||||
os.Unsetenv(tt.key)
|
err := os.Unsetenv(tt.key)
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
|
assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
|
||||||
})
|
})
|
||||||
@@ -113,7 +115,8 @@ func TestParseMode(t *testing.T) {
|
|||||||
if tt.mode != "" {
|
if tt.mode != "" {
|
||||||
t.Setenv("MODE", tt.mode)
|
t.Setenv("MODE", tt.mode)
|
||||||
} else {
|
} else {
|
||||||
os.Unsetenv("MODE")
|
err := os.Unsetenv("MODE")
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
mode, err := parseMode()
|
mode, err := parseMode()
|
||||||
if tt.expectErr {
|
if tt.expectErr {
|
||||||
@@ -148,7 +151,8 @@ func TestParseAllowedPorts(t *testing.T) {
|
|||||||
if tt.val != "" {
|
if tt.val != "" {
|
||||||
t.Setenv("ALLOWED_PORTS", tt.val)
|
t.Setenv("ALLOWED_PORTS", tt.val)
|
||||||
} else {
|
} else {
|
||||||
os.Unsetenv("ALLOWED_PORTS")
|
err := os.Unsetenv("ALLOWED_PORTS")
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
start, end, err := parseAllowedPorts()
|
start, end, err := parseAllowedPorts()
|
||||||
if tt.expectErr {
|
if tt.expectErr {
|
||||||
@@ -180,7 +184,8 @@ func TestParseBufferSize(t *testing.T) {
|
|||||||
if tt.val != "" {
|
if tt.val != "" {
|
||||||
t.Setenv("BUFFER_SIZE", tt.val)
|
t.Setenv("BUFFER_SIZE", tt.val)
|
||||||
} else {
|
} else {
|
||||||
os.Unsetenv("BUFFER_SIZE")
|
err := os.Unsetenv("BUFFER_SIZE")
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
size := parseBufferSize()
|
size := parseBufferSize()
|
||||||
assert.Equal(t, tt.expect, size)
|
assert.Equal(t, tt.expect, size)
|
||||||
@@ -206,7 +211,8 @@ func TestParseHeaderSize(t *testing.T) {
|
|||||||
if tt.val != "" {
|
if tt.val != "" {
|
||||||
t.Setenv("MAX_HEADER_SIZE", tt.val)
|
t.Setenv("MAX_HEADER_SIZE", tt.val)
|
||||||
} else {
|
} else {
|
||||||
os.Unsetenv("MAX_HEADER_SIZE")
|
err := os.Unsetenv("MAX_HEADER_SIZE")
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
size := parseHeaderSize()
|
size := parseHeaderSize()
|
||||||
assert.Equal(t, tt.expect, size)
|
assert.Equal(t, tt.expect, size)
|
||||||
@@ -358,7 +364,10 @@ func TestMustLoad(t *testing.T) {
|
|||||||
t.Run("loadEnvFile error", func(t *testing.T) {
|
t.Run("loadEnvFile error", func(t *testing.T) {
|
||||||
err := os.Mkdir(".env", 0755)
|
err := os.Mkdir(".env", 0755)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer os.Remove(".env")
|
defer func() {
|
||||||
|
err = os.Remove(".env")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
cfg, err := MustLoad()
|
cfg, err := MustLoad()
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -378,7 +387,10 @@ func TestLoadEnvFile(t *testing.T) {
|
|||||||
t.Run("file exists", func(t *testing.T) {
|
t.Run("file exists", func(t *testing.T) {
|
||||||
err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
|
err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
defer os.Remove(".env")
|
defer func() {
|
||||||
|
err = os.Remove(".env")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
err = loadEnvFile()
|
err = loadEnvFile()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
@@ -744,7 +744,9 @@ func TestNew(t *testing.T) {
|
|||||||
if cli == nil {
|
if cli == nil {
|
||||||
t.Fatal("New() returned nil client")
|
t.Fatal("New() returned nil client")
|
||||||
}
|
}
|
||||||
defer cli.Close()
|
defer func(cli Client) {
|
||||||
|
_ = cli.Close()
|
||||||
|
}(cli)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockConfig struct {
|
type MockConfig struct {
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockRequestHeader struct {
|
type mockRequestHeader struct {
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"testing"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockResponseHeader struct {
|
type mockResponseHeader struct {
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
package port
|
package port
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestAddRange(t *testing.T) {
|
func TestAddRange(t *testing.T) {
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -14,6 +11,10 @@ import (
|
|||||||
"tunnel_pls/session/slug"
|
"tunnel_pls/session/slug"
|
||||||
"tunnel_pls/types"
|
"tunnel_pls/types"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -38,7 +38,8 @@ func TestHTTPServer_Listen(t *testing.T) {
|
|||||||
listener, err := srv.Listen()
|
listener, err := srv.Listen()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, listener)
|
assert.NotNil(t, listener)
|
||||||
listener.Close()
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPServer_Serve(t *testing.T) {
|
func TestHTTPServer_Serve(t *testing.T) {
|
||||||
@@ -54,7 +55,8 @@ func TestHTTPServer_Serve(t *testing.T) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
listener.Close()
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = srv.Serve(listener)
|
err = srv.Serve(listener)
|
||||||
@@ -102,8 +104,12 @@ func TestHTTPServer_Serve_Success(t *testing.T) {
|
|||||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
conn.Close()
|
err = conn.Close()
|
||||||
listener.Close()
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockListener struct {
|
type mockListener struct {
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
"tunnel_pls/internal/port"
|
|
||||||
"tunnel_pls/internal/registry"
|
"tunnel_pls/internal/registry"
|
||||||
"tunnel_pls/session/forwarder"
|
"tunnel_pls/session/forwarder"
|
||||||
"tunnel_pls/session/interaction"
|
"tunnel_pls/session/interaction"
|
||||||
@@ -97,53 +96,6 @@ func (m *MockSession) Detail() *types.Detail {
|
|||||||
return args.Get(0).(*types.Detail)
|
return args.Get(0).(*types.Detail)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockLifecycle struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockLifecycle) Channel() ssh.Channel {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(ssh.Channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockLifecycle) Connection() ssh.Conn {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(ssh.Conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockLifecycle) PortRegistry() port.Port {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(port.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockLifecycle) User() string {
|
|
||||||
args := m.Called()
|
|
||||||
return args.String(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockLifecycle) SetChannel(channel ssh.Channel) {
|
|
||||||
m.Called(channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockLifecycle) SetStatus(status types.SessionStatus) {
|
|
||||||
m.Called(status)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockLifecycle) IsActive() bool {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Bool(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockLifecycle) StartedAt() time.Time {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(time.Time)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockLifecycle) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockSSHChannel struct {
|
type MockSSHChannel struct {
|
||||||
ssh.Channel
|
ssh.Channel
|
||||||
mock.Mock
|
mock.Mock
|
||||||
@@ -678,7 +630,10 @@ func TestHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if clientConn != nil {
|
if clientConn != nil {
|
||||||
defer clientConn.Close()
|
defer func(clientConn net.Conn) {
|
||||||
|
err := clientConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(clientConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||||
|
|||||||
@@ -46,7 +46,8 @@ func TestHTTPSServer_Listen(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
assert.NotNil(t, listener)
|
assert.NotNil(t, listener)
|
||||||
listener.Close()
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPSServer_Serve(t *testing.T) {
|
func TestHTTPSServer_Serve(t *testing.T) {
|
||||||
@@ -62,7 +63,8 @@ func TestHTTPSServer_Serve(t *testing.T) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
listener.Close()
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = srv.Serve(listener)
|
err = srv.Serve(listener)
|
||||||
@@ -111,6 +113,8 @@ func TestHTTPSServer_Serve_Success(t *testing.T) {
|
|||||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
conn.Close()
|
err = conn.Close()
|
||||||
listener.Close()
|
assert.NoError(t, err)
|
||||||
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ func TestTCPServer_Listen(t *testing.T) {
|
|||||||
listener, err := srv.Listen()
|
listener, err := srv.Listen()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.NotNil(t, listener)
|
assert.NotNil(t, listener)
|
||||||
listener.Close()
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTCPServer_Serve(t *testing.T) {
|
func TestTCPServer_Serve(t *testing.T) {
|
||||||
@@ -44,7 +45,8 @@ func TestTCPServer_Serve(t *testing.T) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
listener.Close()
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = srv.Serve(listener)
|
err = srv.Serve(listener)
|
||||||
@@ -84,9 +86,10 @@ func TestTCPServer_Serve_Success(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
conn.Close()
|
err = conn.Close()
|
||||||
listener.Close()
|
assert.NoError(t, err)
|
||||||
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
mf.AssertExpectations(t)
|
mf.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -95,7 +98,10 @@ func TestTCPServer_handleTcp_Success(t *testing.T) {
|
|||||||
srv := NewTCPServer(0, mf).(*tcp)
|
srv := NewTCPServer(0, mf).(*tcp)
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
serverConn, clientConn := net.Pipe()
|
||||||
defer clientConn.Close()
|
defer func(clientConn net.Conn) {
|
||||||
|
err := clientConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(clientConn)
|
||||||
|
|
||||||
reqs := make(chan *ssh.Request)
|
reqs := make(chan *ssh.Request)
|
||||||
mockChannel := new(MockSSHChannel)
|
mockChannel := new(MockSSHChannel)
|
||||||
@@ -127,7 +133,10 @@ func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
|
|||||||
srv := NewTCPServer(0, mf).(*tcp)
|
srv := NewTCPServer(0, mf).(*tcp)
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
serverConn, clientConn := net.Pipe()
|
||||||
defer clientConn.Close()
|
defer func(clientConn net.Conn) {
|
||||||
|
err := clientConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(clientConn)
|
||||||
|
|
||||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||||
|
|
||||||
|
|||||||
+124
-61
@@ -80,13 +80,15 @@ func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, so
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
certOut.Close()
|
err = certOut.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
keyOut, err := os.CreateTemp("", "key*.pem")
|
keyOut, err := os.CreateTemp("", "key*.pem")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
keyOut.Close()
|
err = keyOut.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
return certOut.Name(), keyOut.Name()
|
return certOut.Name(), keyOut.Name()
|
||||||
}
|
}
|
||||||
@@ -98,7 +100,8 @@ func setupTestDir(t *testing.T) string {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.RemoveAll(tmpDir)
|
err = os.RemoveAll(tmpDir)
|
||||||
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
return tmpDir
|
return tmpDir
|
||||||
@@ -126,8 +129,11 @@ func TestValidateCertDomains(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
_, err = tmpFile.WriteString("not a pem")
|
_, err = tmpFile.WriteString("not a pem")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
tmpFile.Close()
|
err = tmpFile.Close()
|
||||||
return tmpFile.Name(), func() { os.Remove(tmpFile.Name()) }
|
assert.NoError(t, err)
|
||||||
|
return tmpFile.Name(), func() {
|
||||||
|
_ = os.Remove(tmpFile.Name())
|
||||||
|
}
|
||||||
},
|
},
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
expected: false,
|
expected: false,
|
||||||
@@ -137,8 +143,8 @@ func TestValidateCertDomains(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (string, func()) {
|
setup: func(t *testing.T) (string, func()) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
return certPath, func() {
|
return certPath, func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
@@ -149,8 +155,8 @@ func TestValidateCertDomains(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (string, func()) {
|
setup: func(t *testing.T) (string, func()) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, true, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, true, false)
|
||||||
return certPath, func() {
|
return certPath, func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
@@ -161,8 +167,8 @@ func TestValidateCertDomains(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (string, func()) {
|
setup: func(t *testing.T) (string, func()) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, true)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, true)
|
||||||
return certPath, func() {
|
return certPath, func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
@@ -173,8 +179,8 @@ func TestValidateCertDomains(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (string, func()) {
|
setup: func(t *testing.T) (string, func()) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
|
||||||
return certPath, func() {
|
return certPath, func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
domain: "example.com",
|
domain: "example.com",
|
||||||
@@ -205,8 +211,8 @@ func TestLoadAndParseCertificate(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (string, func()) {
|
setup: func(t *testing.T) (string, func()) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
return certPath, func() {
|
return certPath, func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
wantError: false,
|
wantError: false,
|
||||||
@@ -275,8 +281,14 @@ func TestIsCertificateValid(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, tt.expired, tt.soon)
|
certPath, keyPath := createTestCert(t, "example.com", true, tt.expired, tt.soon)
|
||||||
defer os.Remove(certPath)
|
defer func(name string) {
|
||||||
defer os.Remove(keyPath)
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(certPath)
|
||||||
|
defer func(name string) {
|
||||||
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(keyPath)
|
||||||
|
|
||||||
cert, err := loadAndParseCertificate(certPath)
|
cert, err := loadAndParseCertificate(certPath)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@@ -289,8 +301,14 @@ func TestIsCertificateValid(t *testing.T) {
|
|||||||
|
|
||||||
func TestExtractCertDomains(t *testing.T) {
|
func TestExtractCertDomains(t *testing.T) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
defer os.Remove(certPath)
|
defer func(name string) {
|
||||||
defer os.Remove(keyPath)
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(certPath)
|
||||||
|
defer func(name string) {
|
||||||
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(keyPath)
|
||||||
|
|
||||||
cert, err := loadAndParseCertificate(certPath)
|
cert, err := loadAndParseCertificate(certPath)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@@ -381,8 +399,8 @@ func TestTLSManager_getCertificate(t *testing.T) {
|
|||||||
setup: func(t *testing.T) *tlsManager {
|
setup: func(t *testing.T) *tlsManager {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||||
@@ -447,8 +465,9 @@ func TestTLSManager_userCertsExistAndValid(t *testing.T) {
|
|||||||
mockCfg.On("Domain").Return("example.com")
|
mockCfg.On("Domain").Return("example.com")
|
||||||
|
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
t.Cleanup(func() { os.Remove(certPath) })
|
t.Cleanup(func() { _ = os.Remove(certPath) })
|
||||||
os.Remove(keyPath)
|
err := os.Remove(keyPath)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
return &tlsManager{
|
return &tlsManager{
|
||||||
config: mockCfg,
|
config: mockCfg,
|
||||||
@@ -471,8 +490,14 @@ func TestTLSManager_userCertsExistAndValid(t *testing.T) {
|
|||||||
|
|
||||||
func TestTLSManager_certFilesExist(t *testing.T) {
|
func TestTLSManager_certFilesExist(t *testing.T) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
defer os.Remove(certPath)
|
defer func(name string) {
|
||||||
defer os.Remove(keyPath)
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(certPath)
|
||||||
|
defer func(name string) {
|
||||||
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(keyPath)
|
||||||
|
|
||||||
tm := &tlsManager{
|
tm := &tlsManager{
|
||||||
certPath: certPath,
|
certPath: certPath,
|
||||||
@@ -494,8 +519,8 @@ func TestTLSManager_loadUserCerts(t *testing.T) {
|
|||||||
setup: func(t *testing.T) *tlsManager {
|
setup: func(t *testing.T) *tlsManager {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
return &tlsManager{
|
return &tlsManager{
|
||||||
@@ -550,8 +575,14 @@ func TestCreateTLSManager(t *testing.T) {
|
|||||||
|
|
||||||
func TestNewCertWatcher(t *testing.T) {
|
func TestNewCertWatcher(t *testing.T) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
defer os.Remove(certPath)
|
defer func(name string) {
|
||||||
defer os.Remove(keyPath)
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(certPath)
|
||||||
|
defer func(name string) {
|
||||||
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(keyPath)
|
||||||
|
|
||||||
mockCfg := &MockConfig{}
|
mockCfg := &MockConfig{}
|
||||||
|
|
||||||
@@ -571,8 +602,14 @@ func TestNewCertWatcher(t *testing.T) {
|
|||||||
|
|
||||||
func TestCertWatcher_filesModified(t *testing.T) {
|
func TestCertWatcher_filesModified(t *testing.T) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
defer os.Remove(certPath)
|
defer func(name string) {
|
||||||
defer os.Remove(keyPath)
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(certPath)
|
||||||
|
defer func(name string) {
|
||||||
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(keyPath)
|
||||||
|
|
||||||
mockCfg := &MockConfig{}
|
mockCfg := &MockConfig{}
|
||||||
|
|
||||||
@@ -600,8 +637,14 @@ func TestCertWatcher_filesModified(t *testing.T) {
|
|||||||
|
|
||||||
func TestCertWatcher_updateModTimes(t *testing.T) {
|
func TestCertWatcher_updateModTimes(t *testing.T) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
defer os.Remove(certPath)
|
defer func(name string) {
|
||||||
defer os.Remove(keyPath)
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(certPath)
|
||||||
|
defer func(name string) {
|
||||||
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(keyPath)
|
||||||
|
|
||||||
mockCfg := &MockConfig{}
|
mockCfg := &MockConfig{}
|
||||||
|
|
||||||
@@ -637,8 +680,8 @@ func TestCertWatcher_getFileInfo(t *testing.T) {
|
|||||||
setup: func(t *testing.T) *tlsManager {
|
setup: func(t *testing.T) *tlsManager {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
return &tlsManager{
|
return &tlsManager{
|
||||||
@@ -657,8 +700,9 @@ func TestCertWatcher_getFileInfo(t *testing.T) {
|
|||||||
name: "missing cert file",
|
name: "missing cert file",
|
||||||
setup: func(t *testing.T) *tlsManager {
|
setup: func(t *testing.T) *tlsManager {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
os.Remove(certPath)
|
err := os.Remove(certPath)
|
||||||
t.Cleanup(func() { os.Remove(keyPath) })
|
assert.NoError(t, err)
|
||||||
|
t.Cleanup(func() { _ = os.Remove(keyPath) })
|
||||||
|
|
||||||
return &tlsManager{
|
return &tlsManager{
|
||||||
config: &MockConfig{},
|
config: &MockConfig{},
|
||||||
@@ -672,8 +716,9 @@ func TestCertWatcher_getFileInfo(t *testing.T) {
|
|||||||
name: "missing key file",
|
name: "missing key file",
|
||||||
setup: func(t *testing.T) *tlsManager {
|
setup: func(t *testing.T) *tlsManager {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
os.Remove(keyPath)
|
err := os.Remove(keyPath)
|
||||||
t.Cleanup(func() { os.Remove(certPath) })
|
assert.NoError(t, err)
|
||||||
|
t.Cleanup(func() { _ = os.Remove(certPath) })
|
||||||
|
|
||||||
return &tlsManager{
|
return &tlsManager{
|
||||||
config: &MockConfig{},
|
config: &MockConfig{},
|
||||||
@@ -729,8 +774,8 @@ func TestCertWatcher_checkAndReloadCerts(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
|
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
tm := &tlsManager{
|
tm := &tlsManager{
|
||||||
@@ -747,8 +792,8 @@ func TestCertWatcher_checkAndReloadCerts(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
|
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
mockCfg := &MockConfig{}
|
mockCfg := &MockConfig{}
|
||||||
@@ -792,8 +837,8 @@ func TestCertWatcher_handleCertificateChange(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
|
setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
mockCfg := &MockConfig{}
|
mockCfg := &MockConfig{}
|
||||||
@@ -819,8 +864,8 @@ func TestCertWatcher_handleCertificateChange(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
|
setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
tmpDir := setupTestDir(t)
|
tmpDir := setupTestDir(t)
|
||||||
@@ -850,8 +895,8 @@ func TestCertWatcher_handleCertificateChange(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
|
setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
mockCfg := &MockConfig{}
|
mockCfg := &MockConfig{}
|
||||||
@@ -946,8 +991,8 @@ func TestCertWatcher_watch(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
|
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
tmpDir := setupTestDir(t)
|
tmpDir := setupTestDir(t)
|
||||||
@@ -976,8 +1021,8 @@ func TestCertWatcher_watch(t *testing.T) {
|
|||||||
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
|
setup: func(t *testing.T) (*tlsManager, *certWatcher) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
mockCfg := &MockConfig{}
|
mockCfg := &MockConfig{}
|
||||||
@@ -1006,8 +1051,14 @@ func TestCertWatcher_watch(t *testing.T) {
|
|||||||
|
|
||||||
func TestCertWatcher_watch_Integration(t *testing.T) {
|
func TestCertWatcher_watch_Integration(t *testing.T) {
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
defer os.Remove(certPath)
|
defer func(name string) {
|
||||||
defer os.Remove(keyPath)
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(certPath)
|
||||||
|
defer func(name string) {
|
||||||
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(keyPath)
|
||||||
|
|
||||||
mockCfg := &MockConfig{}
|
mockCfg := &MockConfig{}
|
||||||
mockCfg.On("Domain").Return("example.com")
|
mockCfg.On("Domain").Return("example.com")
|
||||||
@@ -1029,8 +1080,14 @@ func TestCertWatcher_watch_Integration(t *testing.T) {
|
|||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
newCertPath, newKeyPath := createTestCert(t, "example.com", true, false, false)
|
newCertPath, newKeyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
defer os.Remove(newCertPath)
|
defer func(name string) {
|
||||||
defer os.Remove(newKeyPath)
|
err = os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(newCertPath)
|
||||||
|
defer func(name string) {
|
||||||
|
err = os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(newKeyPath)
|
||||||
|
|
||||||
newCertData, err := os.ReadFile(newCertPath)
|
newCertData, err := os.ReadFile(newCertPath)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@@ -1066,8 +1123,8 @@ func TestNewTLSConfig(t *testing.T) {
|
|||||||
|
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
os.Remove(certPath)
|
_ = os.Remove(certPath)
|
||||||
os.Remove(keyPath)
|
_ = os.Remove(keyPath)
|
||||||
})
|
})
|
||||||
|
|
||||||
certData, err := os.ReadFile(certPath)
|
certData, err := os.ReadFile(certPath)
|
||||||
@@ -1140,8 +1197,14 @@ func TestNewTLSConfig_Singleton(t *testing.T) {
|
|||||||
tmpDir := setupTestDir(t)
|
tmpDir := setupTestDir(t)
|
||||||
|
|
||||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||||
defer os.Remove(certPath)
|
defer func(name string) {
|
||||||
defer os.Remove(keyPath)
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(certPath)
|
||||||
|
defer func(name string) {
|
||||||
|
err := os.Remove(name)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(keyPath)
|
||||||
|
|
||||||
certData, err := os.ReadFile(certPath)
|
certData, err := os.ReadFile(certPath)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
+42
-22
@@ -181,14 +181,6 @@ func (m *MockListener) Addr() net.Addr {
|
|||||||
return m.Called().Get(0).(net.Addr)
|
return m.Called().Get(0).(net.Addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockSession struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Start() error {
|
|
||||||
return m.Called().Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
|
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
|
||||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
signer, _ := ssh.NewSignerFromKey(key)
|
signer, _ := ssh.NewSignerFromKey(key)
|
||||||
@@ -244,7 +236,10 @@ func TestNew(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
port := l.Addr().(*net.TCPAddr).Port
|
port := l.Addr().(*net.TCPAddr).Port
|
||||||
defer l.Close()
|
defer func(l net.Listener) {
|
||||||
|
err = l.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(l)
|
||||||
|
|
||||||
s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
|
s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -360,7 +355,9 @@ func TestStart(t *testing.T) {
|
|||||||
go s.Start()
|
go s.Start()
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
clientConn.Close()
|
err := clientConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
mockListener.AssertExpectations(t)
|
mockListener.AssertExpectations(t)
|
||||||
@@ -394,7 +391,9 @@ func TestStart(t *testing.T) {
|
|||||||
go s.Start()
|
go s.Start()
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
clientConn.Close()
|
err := clientConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
mockListener.AssertExpectations(t)
|
mockListener.AssertExpectations(t)
|
||||||
@@ -423,7 +422,9 @@ func TestHandleConnection(t *testing.T) {
|
|||||||
portRegistry: mockPort,
|
portRegistry: mockPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
clientConn.Close()
|
err := clientConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
s.handleConnection(serverConn)
|
s.handleConnection(serverConn)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -489,7 +490,10 @@ func TestHandleConnection(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer func(listener net.Listener) {
|
||||||
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(listener)
|
||||||
|
|
||||||
serverAddr := listener.Addr().String()
|
serverAddr := listener.Addr().String()
|
||||||
|
|
||||||
@@ -529,7 +533,10 @@ func TestHandleConnection(t *testing.T) {
|
|||||||
t.Logf("Client dial failed: %v", err)
|
t.Logf("Client dial failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer func(client *ssh.Client) {
|
||||||
|
err = client.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(client)
|
||||||
|
|
||||||
type forwardPayload struct {
|
type forwardPayload struct {
|
||||||
BindAddr string
|
BindAddr string
|
||||||
@@ -578,7 +585,10 @@ func TestHandleConnection(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer func(listener net.Listener) {
|
||||||
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(listener)
|
||||||
|
|
||||||
serverAddr := listener.Addr().String()
|
serverAddr := listener.Addr().String()
|
||||||
|
|
||||||
@@ -618,7 +628,10 @@ func TestHandleConnection(t *testing.T) {
|
|||||||
t.Logf("Client dial failed: %v", err)
|
t.Logf("Client dial failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer func(client *ssh.Client) {
|
||||||
|
err = client.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(client)
|
||||||
|
|
||||||
type forwardPayload struct {
|
type forwardPayload struct {
|
||||||
BindAddr string
|
BindAddr string
|
||||||
@@ -667,7 +680,10 @@ func TestHandleConnection(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer func(listener net.Listener) {
|
||||||
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(listener)
|
||||||
|
|
||||||
serverAddr := listener.Addr().String()
|
serverAddr := listener.Addr().String()
|
||||||
|
|
||||||
@@ -707,7 +723,9 @@ func TestHandleConnection(t *testing.T) {
|
|||||||
t.Logf("Client dial failed: %v", err)
|
t.Logf("Client dial failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer client.Close()
|
defer func(client *ssh.Client) {
|
||||||
|
_ = client.Close()
|
||||||
|
}(client)
|
||||||
|
|
||||||
type forwardPayload struct {
|
type forwardPayload struct {
|
||||||
BindAddr string
|
BindAddr string
|
||||||
@@ -762,7 +780,8 @@ func TestHandleConnection(t *testing.T) {
|
|||||||
done <- true
|
done <- true
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientConn.Close()
|
err := clientConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
@@ -824,9 +843,9 @@ func TestIntegration(t *testing.T) {
|
|||||||
go s.Start()
|
go s.Start()
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
conn1Client.Close()
|
_ = conn1Client.Close()
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
conn2Client.Close()
|
_ = conn2Client.Close()
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
mockListener.AssertExpectations(t)
|
mockListener.AssertExpectations(t)
|
||||||
@@ -843,7 +862,8 @@ func TestErrorHandling(t *testing.T) {
|
|||||||
sshConfig, _ := getTestSSHConfig()
|
sshConfig, _ := getTestSSHConfig()
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
serverConn, clientConn := net.Pipe()
|
||||||
clientConn.Close()
|
err := clientConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
s := &server{
|
s := &server{
|
||||||
randomizer: mockRandom,
|
randomizer: mockRandom,
|
||||||
|
|||||||
@@ -219,8 +219,14 @@ func (p *pipeConn) Write(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *pipeConn) Close() error {
|
func (p *pipeConn) Close() error {
|
||||||
p.reader.Close()
|
err := p.reader.Close()
|
||||||
p.writer.Close()
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = p.writer.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -499,7 +505,8 @@ func TestOpenForwardedChannel(t *testing.T) {
|
|||||||
OriginAddr string
|
OriginAddr string
|
||||||
OriginPort uint32
|
OriginPort uint32
|
||||||
}
|
}
|
||||||
ssh.Unmarshal(capturedData, &payload)
|
err = ssh.Unmarshal(capturedData, &payload)
|
||||||
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, tt.wantDestAddr, payload.DestAddr)
|
assert.Equal(t, tt.wantDestAddr, payload.DestAddr)
|
||||||
assert.Equal(t, tt.wantDestPort, payload.DestPort)
|
assert.Equal(t, tt.wantDestPort, payload.DestPort)
|
||||||
assert.Equal(t, tt.wantOrigAddr, payload.OriginAddr)
|
assert.Equal(t, tt.wantOrigAddr, payload.OriginAddr)
|
||||||
@@ -662,8 +669,8 @@ func TestCreateForwardedTCPIPPayload(t *testing.T) {
|
|||||||
OriginPort uint32
|
OriginPort uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
ssh.Unmarshal(payload, &decoded)
|
err := ssh.Unmarshal(payload, &decoded)
|
||||||
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, tt.wantDestAddr, decoded.DestAddr)
|
assert.Equal(t, tt.wantDestAddr, decoded.DestAddr)
|
||||||
assert.Equal(t, tt.wantDestPort, decoded.DestPort)
|
assert.Equal(t, tt.wantDestPort, decoded.DestPort)
|
||||||
assert.Equal(t, tt.wantOriginAddr, decoded.OriginAddr)
|
assert.Equal(t, tt.wantOriginAddr, decoded.OriginAddr)
|
||||||
@@ -1056,10 +1063,8 @@ func TestCopyWithBuffer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if buf, ok := dst.(*bytes.Buffer); ok && !tt.wantErr {
|
if buf, ok := dst.(*bytes.Buffer); ok && !tt.wantErr {
|
||||||
if _, ok := src.(io.Reader); ok {
|
|
||||||
assert.Equal(t, tt.wantBytesCount, int64(buf.Len()))
|
assert.Equal(t, tt.wantBytesCount, int64(buf.Len()))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if mr, ok := src.(*mockReader); ok {
|
if mr, ok := src.(*mockReader); ok {
|
||||||
mr.AssertExpectations(t)
|
mr.AssertExpectations(t)
|
||||||
@@ -1276,7 +1281,10 @@ func TestSetListener(t *testing.T) {
|
|||||||
|
|
||||||
listener := tt.setupListener()
|
listener := tt.setupListener()
|
||||||
if listener != nil {
|
if listener != nil {
|
||||||
defer listener.Close()
|
defer func(listener net.Listener) {
|
||||||
|
err := listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Nil(t, forwarder.Listener())
|
assert.Nil(t, forwarder.Listener())
|
||||||
@@ -1318,7 +1326,10 @@ func TestListener(t *testing.T) {
|
|||||||
|
|
||||||
listener := tt.setupListener()
|
listener := tt.setupListener()
|
||||||
if listener != nil {
|
if listener != nil {
|
||||||
defer listener.Close()
|
defer func(listener net.Listener) {
|
||||||
|
err := listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(listener)
|
||||||
forwarder.SetListener(listener)
|
forwarder.SetListener(listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1356,7 +1367,8 @@ func TestClose(t *testing.T) {
|
|||||||
setupListener: func() net.Listener {
|
setupListener: func() net.Listener {
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
listener.Close()
|
err = listener.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
return listener
|
return listener
|
||||||
},
|
},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
@@ -1477,8 +1489,10 @@ func TestHandleConnectionWithErrors(t *testing.T) {
|
|||||||
return newPipePair()
|
return newPipePair()
|
||||||
},
|
},
|
||||||
simulateErr: func(channel *testChannelPeer, dst *pipeConn) {
|
simulateErr: func(channel *testChannelPeer, dst *pipeConn) {
|
||||||
channel.CloseWrite()
|
err := channel.CloseWrite()
|
||||||
dst.CloseWrite()
|
assert.NoError(t, err)
|
||||||
|
err = dst.CloseWrite()
|
||||||
|
assert.NoError(t, err)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -1491,10 +1505,14 @@ func TestHandleConnectionWithErrors(t *testing.T) {
|
|||||||
return newPipePair()
|
return newPipePair()
|
||||||
},
|
},
|
||||||
simulateErr: func(channel *testChannelPeer, dst *pipeConn) {
|
simulateErr: func(channel *testChannelPeer, dst *pipeConn) {
|
||||||
dst.Close()
|
err := dst.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
time.Sleep(10 * time.Millisecond)
|
time.Sleep(10 * time.Millisecond)
|
||||||
channel.Write([]byte("test"))
|
write, err := channel.Write([]byte("test"))
|
||||||
channel.CloseWrite()
|
assert.NotZero(t, write)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = channel.CloseWrite()
|
||||||
|
assert.NoError(t, err)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -139,24 +139,6 @@ func (m *MockSSHChannel) Close() error {
|
|||||||
return m.Called().Error(0)
|
return m.Called().Error(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockNewChannel struct {
|
|
||||||
ssh.NewChannel
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
|
||||||
args := m.Called(name, data)
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
||||||
}
|
|
||||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
func TestNew(t *testing.T) {
|
||||||
mockSSHConn := new(MockSSHConn)
|
mockSSHConn := new(MockSSHConn)
|
||||||
mockForwarder := &MockForwarder{}
|
mockForwarder := &MockForwarder{}
|
||||||
@@ -297,7 +279,8 @@ func TestLifecycle_Close(t *testing.T) {
|
|||||||
mockLifecycle.SetChannel(mockSSHChannel)
|
mockLifecycle.SetChannel(mockSSHChannel)
|
||||||
|
|
||||||
if tt.alreadyClosed {
|
if tt.alreadyClosed {
|
||||||
mockLifecycle.Close()
|
err := mockLifecycle.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := mockLifecycle.Close()
|
err := mockLifecycle.Close()
|
||||||
|
|||||||
+47
-59
@@ -7,13 +7,13 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
portUtil "tunnel_pls/internal/port"
|
|
||||||
"tunnel_pls/internal/registry"
|
"tunnel_pls/internal/registry"
|
||||||
"tunnel_pls/session/lifecycle"
|
"tunnel_pls/session/lifecycle"
|
||||||
"tunnel_pls/types"
|
"tunnel_pls/types"
|
||||||
@@ -122,25 +122,6 @@ func (m *mockSSHConn) User() string {
|
|||||||
return m.Called().String(0)
|
return m.Called().String(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockSSHChannel struct {
|
|
||||||
ssh.Channel
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockSSHChannel) Close() error {
|
|
||||||
return m.Called().Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockNewChannel struct {
|
|
||||||
ssh.NewChannel
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, sChans <-chan ssh.NewChannel, cConn ssh.Conn, cleanup func()) {
|
func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, sChans <-chan ssh.NewChannel, cConn ssh.Conn, cleanup func()) {
|
||||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -193,7 +174,8 @@ func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, s
|
|||||||
if newChan.ChannelType() == "session" {
|
if newChan.ChannelType() == "session" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
newChan.Reject(ssh.Prohibited, "")
|
err = newChan.Reject(ssh.Prohibited, "")
|
||||||
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -205,9 +187,9 @@ func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
return sConnObj, sReqsChan, sChansChan, cConnObj, func() {
|
return sConnObj, sReqsChan, sChansChan, cConnObj, func() {
|
||||||
cConnObj.Close()
|
_ = cConnObj.Close()
|
||||||
sConnObj.Close()
|
_ = sConnObj.Close()
|
||||||
l.Close()
|
_ = l.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,7 +312,9 @@ func TestHandleGlobalRequest(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
cConn.Close()
|
err := cConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
@@ -505,11 +489,13 @@ func TestStart_Table(t *testing.T) {
|
|||||||
time.Sleep(200 * time.Millisecond)
|
time.Sleep(200 * time.Millisecond)
|
||||||
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
||||||
time.Sleep(200 * time.Millisecond)
|
time.Sleep(200 * time.Millisecond)
|
||||||
ch.Write([]byte("q"))
|
write, err := ch.Write([]byte("q"))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotZero(t, write)
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
ch.Close()
|
_ = ch.Close()
|
||||||
}
|
}
|
||||||
cConn.Close()
|
_ = cConn.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := s.Start()
|
err := s.Start()
|
||||||
@@ -530,9 +516,13 @@ func TestStart_Table(t *testing.T) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(600 * time.Millisecond)
|
time.Sleep(600 * time.Millisecond)
|
||||||
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
_, _, err := cConn.SendRequest("tcpip-forward", true, payload)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
cConn.Close()
|
err = cConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := s.Start()
|
err := s.Start()
|
||||||
@@ -545,7 +535,7 @@ func TestStart_Table(t *testing.T) {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
time.Sleep(1200 * time.Millisecond)
|
time.Sleep(1200 * time.Millisecond)
|
||||||
cConn.Close()
|
_ = cConn.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := s.Start()
|
err := s.Start()
|
||||||
@@ -554,11 +544,11 @@ func TestStart_Table(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Unauthorized Headless", func(t *testing.T) {
|
t.Run("Unauthorized Headless", func(t *testing.T) {
|
||||||
s, conf, cConn, cleanup := setup(t)
|
_, conf, cConn, cleanup := setup(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
conf.User = "UNAUTHORIZED"
|
conf.User = "UNAUTHORIZED"
|
||||||
s = New(conf).(*session)
|
s := New(conf).(*session)
|
||||||
|
|
||||||
payload := make([]byte, 4+9+4)
|
payload := make([]byte, 4+9+4)
|
||||||
binary.BigEndian.PutUint32(payload[0:4], 9)
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
||||||
@@ -738,14 +728,17 @@ func TestForwardingFailures(t *testing.T) {
|
|||||||
binary.BigEndian.PutUint32(payload[13:17], 80)
|
binary.BigEndian.PutUint32(payload[13:17], 80)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
_, _, err := cConn.SendRequest("tcpip-forward", true, payload)
|
||||||
|
assert.Error(t, err, io.EOF)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
req := <-sReqs
|
req := <-sReqs
|
||||||
cConn.Close()
|
err := cConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
err := s.HandleTCPIPForward(req)
|
err = s.HandleTCPIPForward(req)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -759,7 +752,10 @@ func TestForwardingFailures(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer func(l net.Listener) {
|
||||||
|
err = l.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(l)
|
||||||
_, portStr, _ := net.SplitHostPort(l.Addr().String())
|
_, portStr, _ := net.SplitHostPort(l.Addr().String())
|
||||||
port, _ := strconv.Atoi(portStr)
|
port, _ := strconv.Atoi(portStr)
|
||||||
|
|
||||||
@@ -1120,10 +1116,12 @@ func TestDenyForwardingRequest_Full(t *testing.T) {
|
|||||||
s, _, _, sReqs, cConn, cleanup := setup(t)
|
s, _, _, sReqs, cConn, cleanup := setup(t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
req := getReq(t, cConn, sReqs)
|
req := getReq(t, cConn, sReqs)
|
||||||
cConn.Close()
|
err := cConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
err := s.denyForwardingRequest(req, nil, nil, assert.AnError.Error())
|
err = s.denyForwardingRequest(req, nil, nil, assert.AnError.Error())
|
||||||
assert.Error(t, err, assert.AnError)
|
assert.Error(t, err, assert.AnError)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -1183,7 +1181,10 @@ func TestHandleTCPForward_Failures(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer l.Close()
|
defer func(l net.Listener) {
|
||||||
|
err = l.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}(l)
|
||||||
port := uint16(l.Addr().(*net.TCPAddr).Port)
|
port := uint16(l.Addr().(*net.TCPAddr).Port)
|
||||||
|
|
||||||
err = s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", port)
|
err = s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", port)
|
||||||
@@ -1213,10 +1214,11 @@ func TestHandleTCPForward_Failures(t *testing.T) {
|
|||||||
mPort.On("Claim", mock.Anything).Return(true)
|
mPort.On("Claim", mock.Anything).Return(true)
|
||||||
mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
||||||
req := getReq(t, cConn, sReqs)
|
req := getReq(t, cConn, sReqs)
|
||||||
cConn.Close()
|
err := cConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
err := s.HandleTCPForward(req, "localhost", 0)
|
err = s.HandleTCPForward(req, "localhost", 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("expected error, got nil")
|
t.Error("expected error, got nil")
|
||||||
} else if !strings.Contains(err.Error(), "Failed to finalize forwarding") {
|
} else if !strings.Contains(err.Error(), "Failed to finalize forwarding") {
|
||||||
@@ -1318,7 +1320,9 @@ func TestHandleGlobalRequest_Failures(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
cConn.Close()
|
err := cConn.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case <-done:
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
@@ -1354,19 +1358,3 @@ type mockCloser struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockCloser) Close() error { return m.err }
|
func (m *mockCloser) Close() error { return m.err }
|
||||||
|
|
||||||
type mockLifecycle struct {
|
|
||||||
lifecycle.Lifecycle
|
|
||||||
closeErr error
|
|
||||||
conn ssh.Conn
|
|
||||||
user string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockLifecycle) Close() error { return m.closeErr }
|
|
||||||
func (m *mockLifecycle) Connection() ssh.Conn { return m.conn }
|
|
||||||
func (m *mockLifecycle) User() string { return m.user }
|
|
||||||
func (m *mockLifecycle) IsActive() bool { return false }
|
|
||||||
func (m *mockLifecycle) PortRegistry() portUtil.Port { return nil }
|
|
||||||
func (m *mockLifecycle) SetChannel(ch ssh.Channel) {}
|
|
||||||
func (m *mockLifecycle) SetStatus(status types.SessionStatus) {}
|
|
||||||
func (m *mockLifecycle) StartedAt() time.Time { return time.Time{} }
|
|
||||||
|
|||||||
@@ -83,11 +83,6 @@ func (suite *SlugTestSuite) TestMultipleSet() {
|
|||||||
assert.Equal(suite.T(), "", suite.slug.String())
|
assert.Equal(suite.T(), "", suite.slug.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSlugInterface(t *testing.T) {
|
|
||||||
var _ Slug = (*slug)(nil)
|
|
||||||
var _ Slug = New()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSlugIsolation(t *testing.T) {
|
func TestSlugIsolation(t *testing.T) {
|
||||||
slug1 := New()
|
slug1 := New()
|
||||||
slug2 := New()
|
slug2 := New()
|
||||||
|
|||||||
Reference in New Issue
Block a user