test(bootstrap): add unit tests for initial bootstrap behavior

This commit is contained in:
2026-01-24 15:42:30 +07:00
parent d0e052524c
commit 2f5c44ff01
9 changed files with 754 additions and 89 deletions
+41 -43
View File
@@ -28,27 +28,35 @@ type Bootstrap struct {
Config config.Config
SessionRegistry registry.Registry
Port port.Port
GrpcClient client.Client
ErrChan chan error
SignalChan chan os.Signal
}
func New() (*Bootstrap, error) {
conf, err := config.MustLoad()
func New(config config.Config, port port.Port) (*Bootstrap, error) {
randomizer := random.New()
sessionRegistry := registry.NewRegistry()
if err := port.AddRange(config.AllowedPortsStart(), config.AllowedPortsEnd()); err != nil {
return nil, err
}
grpcClient, err := client.New(config, sessionRegistry)
if err != nil {
return nil, err
}
randomizer := random.New()
sessionRegistry := registry.NewRegistry()
portManager := port.New()
if err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd()); err != nil {
return nil, err
}
errChan := make(chan error, 5)
signalChan := make(chan os.Signal, 1)
return &Bootstrap{
Randomizer: randomizer,
Config: conf,
Config: config,
SessionRegistry: sessionRegistry,
Port: portManager,
Port: port,
GrpcClient: grpcClient,
ErrChan: errChan,
SignalChan: signalChan,
}, nil
}
@@ -73,25 +81,20 @@ func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) {
return sshCfg, nil
}
func startGRPCClient(ctx context.Context, conf config.Config, registry registry.Registry, errChan chan<- error) (client.Client, error) {
grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort())
grpcClient, err := client.New(conf, grpcAddr, registry)
if err != nil {
return nil, err
}
func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, errChan chan<- error) error {
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
defer healthCancel()
if err = grpcClient.CheckServerHealth(healthCtx); err != nil {
return nil, fmt.Errorf("gRPC health check failed: %w", err)
if err := b.GrpcClient.CheckServerHealth(healthCtx); err != nil {
return fmt.Errorf("gRPC health check failed: %w", err)
}
go func() {
if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
if err := b.GrpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
}
}()
return grpcClient, nil
return nil
}
func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
@@ -115,7 +118,7 @@ func startHTTPSServer(conf config.Config, registry registry.Registry, errChan ch
httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), registry, conf.TLSRedirect(), tlsCfg)
ln, err := httpsServer.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start https server: %w", err)
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
return
}
if err = httpsServer.Serve(ln); err != nil {
@@ -123,25 +126,25 @@ func startHTTPSServer(conf config.Config, registry registry.Registry, errChan ch
}
}
func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, sshPort string) error {
sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, sshPort)
func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, errChan chan<- error) {
sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, conf.SSHPort())
if err != nil {
return err
errChan <- err
return
}
sshServer.Start()
return sshServer.Close()
errChan <- sshServer.Close()
}
func startPprof(pprofPort string) {
func startPprof(pprofPort string, errChan chan<- error) {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
errChan <- fmt.Errorf("pprof server error: %v", err)
}
}
func (b *Bootstrap) Run() error {
sshConfig, err := newSSHConfig(b.Config.KeyLoc())
if err != nil {
@@ -151,13 +154,10 @@ func (b *Bootstrap) Run() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errChan := make(chan error, 5)
shutdownChan := make(chan os.Signal, 1)
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
signal.Notify(b.SignalChan, os.Interrupt, syscall.SIGTERM)
var grpcClient client.Client
if b.Config.Mode() == types.ServerModeNODE {
grpcClient, err = startGRPCClient(ctx, b.Config, b.SessionRegistry, errChan)
err = b.startGRPCClient(ctx, b.Config, b.ErrChan)
if err != nil {
return fmt.Errorf("failed to start gRPC client: %w", err)
}
@@ -166,31 +166,29 @@ func (b *Bootstrap) Run() error {
if err != nil {
log.Printf("failed to close gRPC client")
}
}(grpcClient)
}(b.GrpcClient)
}
go startHTTPServer(b.Config, b.SessionRegistry, errChan)
go startHTTPServer(b.Config, b.SessionRegistry, b.ErrChan)
if b.Config.TLSEnabled() {
go startHTTPSServer(b.Config, b.SessionRegistry, errChan)
go startHTTPSServer(b.Config, b.SessionRegistry, b.ErrChan)
}
go func() {
if err = startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, grpcClient, b.Port, b.Config.SSHPort()); err != nil {
errChan <- fmt.Errorf("SSH server error: %w", err)
}
startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, b.GrpcClient, b.Port, b.ErrChan)
}()
if b.Config.PprofEnabled() {
go startPprof(b.Config.PprofPort())
go startPprof(b.Config.PprofPort(), b.ErrChan)
}
log.Println("All services started successfully")
select {
case err = <-errChan:
case err = <-b.ErrChan:
return fmt.Errorf("service error: %w", err)
case sig := <-shutdownChan:
case sig := <-b.SignalChan:
log.Printf("Received signal %s, initiating graceful shutdown", sig)
cancel()
return nil
+627
View File
@@ -0,0 +1,627 @@
package bootstrap
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/http"
_ "net/http/pprof"
"os"
"path/filepath"
"strconv"
"testing"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
args := m.Called(user, oldKey, newKey)
return args.Error(0)
}
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
args := m.Called(key, session)
return args.Bool(0)
}
func (m *MockSessionRegistry) Remove(key registry.Key) {
m.Called(key)
}
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
args := m.Called(user)
return args.Get(0).([]registry.Session)
}
func (m *MockSessionRegistry) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
type MockSession struct {
mock.Mock
}
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
args := m.Called()
return args.Get(0).(lifecycle.Lifecycle)
}
func (m *MockSession) Interaction() interaction.Interaction {
args := m.Called()
return args.Get(0).(interaction.Interaction)
}
func (m *MockSession) Forwarder() forwarder.Forwarder {
args := m.Called()
return args.Get(0).(forwarder.Forwarder)
}
func (m *MockSession) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
func (m *MockSession) Detail() *types.Detail {
args := m.Called()
return args.Get(0).(*types.Detail)
}
type MockRandom struct {
mock.Mock
}
func (m *MockRandom) String(length int) (string, error) {
args := m.Called(length)
return args.String(0), args.Error(1)
}
type MockConfig struct {
mock.Mock
}
func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode {
args := m.Called()
if args.Get(0) == nil {
return 0
}
switch v := args.Get(0).(type) {
case types.ServerMode:
return v
case int:
return types.ServerMode(v)
default:
return types.ServerMode(args.Int(0))
}
}
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type MockPort struct {
mock.Mock
}
func (m *MockPort) AddRange(startPort, endPort uint16) error {
return m.Called(startPort, endPort).Error(0)
}
func (m *MockPort) Unassigned() (uint16, bool) {
args := m.Called()
var port uint16
if args.Get(0) != nil {
switch v := args.Get(0).(type) {
case int:
port = uint16(v)
case uint16:
port = v
case uint32:
port = uint16(v)
case int32:
port = uint16(v)
case float64:
port = uint16(v)
default:
port = uint16(args.Int(0))
}
}
return port, args.Bool(1)
}
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
}
func (m *MockPort) Claim(port uint16) bool {
return m.Called(port).Bool(0)
}
type MockGRPCClient struct {
mock.Mock
}
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
args := m.Called()
return args.Get(0).(*grpc.ClientConn)
}
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
m.Called()
return
}
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
args := m.Called(ctx, domain, token)
return args.Error(0)
}
func (m *MockGRPCClient) Close() error {
args := m.Called()
return args.Error(0)
}
func TestNew(t *testing.T) {
tests := []struct {
name string
setupConfig func() config.Config
setupPort func() port.Port
wantErr bool
errContains string
}{
{
name: "Success New with default value",
wantErr: false,
},
{
name: "Error when AddRange fails",
setupPort: func() port.Port {
mockPort := &MockPort{}
mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range"))
return mockPort
},
wantErr: true,
errContains: "invalid port range",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var mockPort port.Port
if tt.setupPort != nil {
mockPort = tt.setupPort()
} else {
mockPort = port.New()
}
var mockConfig config.Config
if tt.setupConfig != nil {
mockConfig = tt.setupConfig()
} else {
var err error
mockConfig, err = config.MustLoad()
assert.NoError(t, err)
}
bootstrap, err := New(mockConfig, mockPort)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, bootstrap)
} else {
assert.NoError(t, err)
assert.NotNil(t, bootstrap)
assert.NotNil(t, bootstrap.Randomizer)
assert.NotNil(t, bootstrap.SessionRegistry)
assert.NotNil(t, bootstrap.Config)
assert.NotNil(t, bootstrap.Port)
assert.NotNil(t, bootstrap.ErrChan)
assert.NotNil(t, bootstrap.SignalChan)
}
})
}
}
func generateTestCert(t *testing.T) (certPEM, keyPEM []byte) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Co"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
certPEM = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
return certPEM, keyPEM
}
func randomAvailablePort() (string, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return "", err
}
defer listener.Close()
port := listener.Addr().(*net.TCPAddr).Port
return strconv.Itoa(port), nil
}
func TestRun(t *testing.T) {
mockRandom := &MockRandom{}
mockErrChan := make(chan error, 1)
mockSignalChan := make(chan os.Signal, 1)
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
tmpDir := t.TempDir()
keyLoc := filepath.Join(tmpDir, "key.key")
tests := []struct {
name string
setupConfig func() *MockConfig
setupGrpcClient func() *MockGRPCClient
needCerts bool
expectError bool
}{
{
name: "successful run and termination",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
expectError: false,
},
{
name: "error from SSH server invalid port",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("invalid")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
expectError: true,
},
{
name: "error from HTTP server invalid port",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("invalid")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
expectError: true,
},
{
name: "error from HTTPS server invalid port",
setupConfig: func() *MockConfig {
tempDir := os.TempDir()
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("invalid")
mock.On("TLSEnabled").Return(true)
mock.On("TLSRedirect").Return(false)
mock.On("TLSStoragePath").Return(tempDir)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
expectError: true,
},
{
name: "grpc health check failed",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeNODE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("invalid")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed"))
return mockGRPCClient
},
expectError: true,
},
{
name: "successful run with pprof enabled",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
pprofPort, _ := randomAvailablePort()
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(true)
mock.On("PprofPort").Return(pprofPort)
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
expectError: false,
}, {
name: "successful run in NODE mode with signal",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeNODE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockGRPCClient.On("Close").Return(nil)
return mockGRPCClient
},
expectError: false,
}, {
name: "successful run in NODE mode with signal buf error when closing",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeNODE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy"))
return mockGRPCClient
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockConfig := tt.setupConfig()
mockGRPCClient := &MockGRPCClient{}
bootstrap := &Bootstrap{
Randomizer: mockRandom,
Config: mockConfig,
SessionRegistry: mockSessionRegistry,
Port: mockPort,
ErrChan: mockErrChan,
SignalChan: mockSignalChan,
GrpcClient: mockGRPCClient,
}
if tt.setupGrpcClient != nil {
bootstrap.GrpcClient = tt.setupGrpcClient()
}
done := make(chan error, 1)
go func() {
done <- bootstrap.Run()
}()
if tt.expectError {
err := <-done
assert.Error(t, err)
} else if tt.name == "successful run with pprof enabled" {
time.Sleep(200 * time.Millisecond)
fmt.Println(mockConfig.PprofPort())
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
resp.Body.Close()
mockSignalChan <- os.Interrupt
err = <-done
assert.NoError(t, err)
} else {
time.Sleep(time.Second)
mockSignalChan <- os.Interrupt
err := <-done
assert.NoError(t, err)
}
})
}
}
+2
View File
@@ -13,6 +13,7 @@ type Config interface {
TLSEnabled() bool
TLSRedirect() bool
TLSStoragePath() string
ACMEEmail() string
CFAPIToken() string
@@ -52,6 +53,7 @@ func (c *config) HTTPSPort() string { return c.httpsPort }
func (c *config) KeyLoc() string { return c.keyLoc }
func (c *config) TLSEnabled() bool { return c.tlsEnabled }
func (c *config) TLSRedirect() bool { return c.tlsRedirect }
func (c *config) TLSStoragePath() string { return c.tlsStoragePath }
func (c *config) ACMEEmail() string { return c.acmeEmail }
func (c *config) CFAPIToken() string { return c.cfAPIToken }
func (c *config) ACMEStaging() bool { return c.acmeStaging }
+6 -6
View File
@@ -20,12 +20,12 @@ type config struct {
keyLoc string
tlsEnabled bool
tlsRedirect bool
acmeEmail string
cfAPIToken string
acmeStaging bool
tlsEnabled bool
tlsRedirect bool
tlsStoragePath string
acmeEmail string
cfAPIToken string
acmeStaging bool
allowedPortsStart uint16
allowedPortsEnd uint16
+3 -1
View File
@@ -44,7 +44,9 @@ var (
initialBackoff = time.Second
)
func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) {
func New(config config.Config, sessionRegistry registry.Registry) (Client, error) {
address := fmt.Sprintf("%s:%s", config.GRPCAddress(), config.GRPCPort())
var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
+35 -11
View File
@@ -8,7 +8,6 @@ import (
"testing"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/registry"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
@@ -16,6 +15,7 @@ import (
"tunnel_pls/types"
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/health/grpc_health_v1"
@@ -382,7 +382,8 @@ func TestProcessEventStream(t *testing.T) {
mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") }
mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return nil, errors.New("fail") }
c.sessionRegistry = mockReg
c.config = &mockConfig{domain: "test.com"}
c.config = &MockConfig{}
c.config.(*MockConfig).On("Domain").Return("test.com")
mockStream.sendFunc = func(n *proto.Node) error { return nil }
err := c.processEventStream(mockStream)
@@ -541,7 +542,8 @@ func TestHandleSlugChange(t *testing.T) {
func TestHandleGetSessions(t *testing.T) {
mockReg := &mockRegistry{}
mockStream := &mockSubscribeClient{}
mockCfg := &mockConfig{domain: "test.com"}
mockCfg := &MockConfig{}
mockCfg.On("Domain").Return("test.com")
c := &client{sessionRegistry: mockReg, config: mockCfg}
evt := &proto.Events{
@@ -840,8 +842,11 @@ func TestNew_Error(t *testing.T) {
return nil, errors.New("dial fail")
}
defer func() { grpcNewClient = old }()
mockConfig := &MockConfig{}
cli, err := New(&mockConfig{}, "localhost:1234", &mockRegistry{})
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("1234")
cli, err := New(mockConfig, &mockRegistry{})
if err == nil || err.Error() != "failed to connect to gRPC server at localhost:1234: dial fail" {
t.Errorf("expected dial fail error, got %v", err)
}
@@ -851,10 +856,11 @@ func TestNew_Error(t *testing.T) {
}
func TestNew(t *testing.T) {
mockCfg := &mockConfig{}
mockConfig := &MockConfig{}
mockReg := &mockRegistry{}
cli, err := New(mockCfg, "localhost:1234", mockReg)
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("1234")
cli, err := New(mockConfig, mockReg)
if err != nil {
t.Errorf("New() error = %v", err)
}
@@ -864,12 +870,30 @@ func TestNew(t *testing.T) {
defer cli.Close()
}
type mockConfig struct {
config.Config
domain string
type MockConfig struct {
mock.Mock
}
func (m *mockConfig) Domain() string { return m.domain }
func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type mockRegistry struct {
registry.Registry
+8 -4
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"log"
"os"
"path/filepath"
"sync"
"time"
"tunnel_pls/internal/config"
@@ -47,15 +48,18 @@ func NewTLSConfig(config config.Config) (*tls.Config, error) {
var initErr error
tlsManagerOnce.Do(func() {
certPath := "certs/tls/cert.pem"
keyPath := "certs/tls/privkey.pem"
storagePath := "certs/tls/certmagic"
storagePath := config.TLSStoragePath()
cleanBase := filepath.Clean(storagePath)
certPath := filepath.Join(cleanBase, "cert.pem")
keyPath := filepath.Join(cleanBase, "privkey.pem")
storagePathCertMagic := filepath.Join(cleanBase, "certmagic")
tm := &tlsManager{
config: config,
certPath: certPath,
keyPath: keyPath,
storagePath: storagePath,
storagePath: storagePathCertMagic,
}
if tm.userCertsExistAndValid() {
+8 -1
View File
@@ -5,6 +5,8 @@ import (
"log"
"os"
"tunnel_pls/internal/bootstrap"
"tunnel_pls/internal/config"
"tunnel_pls/internal/port"
"tunnel_pls/internal/version"
)
@@ -18,7 +20,12 @@ func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile)
log.Printf("Starting %s", version.GetVersion())
boot, err := bootstrap.New()
conf, err := config.MustLoad()
if err != nil {
log.Fatalf("Config load error: %v", err)
}
boot, err := bootstrap.New(conf, port.New())
if err != nil {
log.Fatalf("Startup error: %v", err)
}
+24 -23
View File
@@ -27,29 +27,30 @@ func (m *mockRandom) String(length int) (string, error) {
return args.String(0), args.Error(1)
}
type mockConfig struct {
type MockConfig struct {
mock.Mock
}
func (m *mockConfig) Domain() string { return m.Called().String(0) }
func (m *mockConfig) SSHPort() string { return m.Called().String(0) }
func (m *mockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *mockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *mockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *mockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *mockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *mockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *mockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *mockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *mockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *mockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *mockConfig) PprofPort() string { return m.Called().String(0) }
func (m *mockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
func (m *mockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *mockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *mockConfig) NodeToken() string { return m.Called().String(0) }
func (m *mockConfig) KeyLoc() string { return m.Called().String(0) }
func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type mockRegistry struct {
mock.Mock
@@ -169,7 +170,7 @@ func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
func TestNew(t *testing.T) {
mr := new(mockRandom)
mc := new(mockConfig)
mc := new(MockConfig)
mreg := new(mockRegistry)
mg := new(mockGrpcClient)
mp := new(mockPort)
@@ -222,7 +223,7 @@ func TestNew(t *testing.T) {
func TestClose(t *testing.T) {
mr := new(mockRandom)
mc := new(mockConfig)
mc := new(MockConfig)
mreg := new(mockRegistry)
mg := new(mockGrpcClient)
mp := new(mockPort)
@@ -238,7 +239,7 @@ func TestClose(t *testing.T) {
func TestStart(t *testing.T) {
mr := new(mockRandom)
mc := new(mockConfig)
mc := new(MockConfig)
mreg := new(mockRegistry)
mg := new(mockGrpcClient)
mp := new(mockPort)