feat(testing): add comprehensive test coverage and code quality improvements #76

Merged
bagas merged 47 commits from feat/testing into staging 2026-01-27 16:36:40 +07:00
9 changed files with 754 additions and 89 deletions
Showing only changes of commit 2f5c44ff01 - Show all commits
+41 -43
View File
@@ -28,27 +28,35 @@ type Bootstrap struct {
Config config.Config Config config.Config
SessionRegistry registry.Registry SessionRegistry registry.Registry
Port port.Port Port port.Port
GrpcClient client.Client
ErrChan chan error
SignalChan chan os.Signal
} }
func New() (*Bootstrap, error) { func New(config config.Config, port port.Port) (*Bootstrap, error) {
conf, err := config.MustLoad() randomizer := random.New()
sessionRegistry := registry.NewRegistry()
if err := port.AddRange(config.AllowedPortsStart(), config.AllowedPortsEnd()); err != nil {
return nil, err
}
grpcClient, err := client.New(config, sessionRegistry)
if err != nil { if err != nil {
return nil, err return nil, err
} }
randomizer := random.New() errChan := make(chan error, 5)
sessionRegistry := registry.NewRegistry() signalChan := make(chan os.Signal, 1)
portManager := port.New()
if err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd()); err != nil {
return nil, err
}
return &Bootstrap{ return &Bootstrap{
Randomizer: randomizer, Randomizer: randomizer,
Config: conf, Config: config,
SessionRegistry: sessionRegistry, SessionRegistry: sessionRegistry,
Port: portManager, Port: port,
GrpcClient: grpcClient,
ErrChan: errChan,
SignalChan: signalChan,
}, nil }, nil
} }
@@ -73,25 +81,20 @@ func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) {
return sshCfg, nil return sshCfg, nil
} }
func startGRPCClient(ctx context.Context, conf config.Config, registry registry.Registry, errChan chan<- error) (client.Client, error) { func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, errChan chan<- error) error {
grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort())
grpcClient, err := client.New(conf, grpcAddr, registry)
if err != nil {
return nil, err
}
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second) healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
defer healthCancel() defer healthCancel()
if err = grpcClient.CheckServerHealth(healthCtx); err != nil { if err := b.GrpcClient.CheckServerHealth(healthCtx); err != nil {
return nil, fmt.Errorf("gRPC health check failed: %w", err) return fmt.Errorf("gRPC health check failed: %w", err)
} }
go func() { go func() {
if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil { if err := b.GrpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
errChan <- fmt.Errorf("failed to subscribe to events: %w", err) errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
} }
}() }()
return grpcClient, nil return nil
} }
func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) { func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
@@ -115,7 +118,7 @@ func startHTTPSServer(conf config.Config, registry registry.Registry, errChan ch
httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), registry, conf.TLSRedirect(), tlsCfg) httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), registry, conf.TLSRedirect(), tlsCfg)
ln, err := httpsServer.Listen() ln, err := httpsServer.Listen()
if err != nil { if err != nil {
errChan <- fmt.Errorf("failed to start https server: %w", err) errChan <- fmt.Errorf("failed to create TLS config: %w", err)
return return
} }
if err = httpsServer.Serve(ln); err != nil { if err = httpsServer.Serve(ln); err != nil {
@@ -123,25 +126,25 @@ func startHTTPSServer(conf config.Config, registry registry.Registry, errChan ch
} }
} }
func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, sshPort string) error { func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, errChan chan<- error) {
sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, sshPort) sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, conf.SSHPort())
if err != nil { if err != nil {
return err errChan <- err
return
} }
sshServer.Start() sshServer.Start()
return sshServer.Close() errChan <- sshServer.Close()
} }
func startPprof(pprofPort string) { func startPprof(pprofPort string, errChan chan<- error) {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil { if err := http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err) errChan <- fmt.Errorf("pprof server error: %v", err)
} }
} }
func (b *Bootstrap) Run() error { func (b *Bootstrap) Run() error {
sshConfig, err := newSSHConfig(b.Config.KeyLoc()) sshConfig, err := newSSHConfig(b.Config.KeyLoc())
if err != nil { if err != nil {
@@ -151,13 +154,10 @@ func (b *Bootstrap) Run() error {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
errChan := make(chan error, 5) signal.Notify(b.SignalChan, os.Interrupt, syscall.SIGTERM)
shutdownChan := make(chan os.Signal, 1)
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
var grpcClient client.Client
if b.Config.Mode() == types.ServerModeNODE { if b.Config.Mode() == types.ServerModeNODE {
grpcClient, err = startGRPCClient(ctx, b.Config, b.SessionRegistry, errChan) err = b.startGRPCClient(ctx, b.Config, b.ErrChan)
if err != nil { if err != nil {
return fmt.Errorf("failed to start gRPC client: %w", err) return fmt.Errorf("failed to start gRPC client: %w", err)
} }
@@ -166,31 +166,29 @@ func (b *Bootstrap) Run() error {
if err != nil { if err != nil {
log.Printf("failed to close gRPC client") log.Printf("failed to close gRPC client")
} }
}(grpcClient) }(b.GrpcClient)
} }
go startHTTPServer(b.Config, b.SessionRegistry, errChan) go startHTTPServer(b.Config, b.SessionRegistry, b.ErrChan)
if b.Config.TLSEnabled() { if b.Config.TLSEnabled() {
go startHTTPSServer(b.Config, b.SessionRegistry, errChan) go startHTTPSServer(b.Config, b.SessionRegistry, b.ErrChan)
} }
go func() { go func() {
if err = startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, grpcClient, b.Port, b.Config.SSHPort()); err != nil { startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, b.GrpcClient, b.Port, b.ErrChan)
errChan <- fmt.Errorf("SSH server error: %w", err)
}
}() }()
if b.Config.PprofEnabled() { if b.Config.PprofEnabled() {
go startPprof(b.Config.PprofPort()) go startPprof(b.Config.PprofPort(), b.ErrChan)
} }
log.Println("All services started successfully") log.Println("All services started successfully")
select { select {
case err = <-errChan: case err = <-b.ErrChan:
return fmt.Errorf("service error: %w", err) return fmt.Errorf("service error: %w", err)
case sig := <-shutdownChan: case sig := <-b.SignalChan:
log.Printf("Received signal %s, initiating graceful shutdown", sig) log.Printf("Received signal %s, initiating graceful shutdown", sig)
cancel() cancel()
return nil return nil
+627
View File
@@ -0,0 +1,627 @@
package bootstrap
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"net/http"
_ "net/http/pprof"
"os"
"path/filepath"
"strconv"
"testing"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
args := m.Called(user, oldKey, newKey)
return args.Error(0)
}
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
args := m.Called(key, session)
return args.Bool(0)
}
func (m *MockSessionRegistry) Remove(key registry.Key) {
m.Called(key)
}
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
args := m.Called(user)
return args.Get(0).([]registry.Session)
}
func (m *MockSessionRegistry) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
type MockSession struct {
mock.Mock
}
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
args := m.Called()
return args.Get(0).(lifecycle.Lifecycle)
}
func (m *MockSession) Interaction() interaction.Interaction {
args := m.Called()
return args.Get(0).(interaction.Interaction)
}
func (m *MockSession) Forwarder() forwarder.Forwarder {
args := m.Called()
return args.Get(0).(forwarder.Forwarder)
}
func (m *MockSession) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
func (m *MockSession) Detail() *types.Detail {
args := m.Called()
return args.Get(0).(*types.Detail)
}
type MockRandom struct {
mock.Mock
}
func (m *MockRandom) String(length int) (string, error) {
args := m.Called(length)
return args.String(0), args.Error(1)
}
type MockConfig struct {
mock.Mock
}
func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode {
args := m.Called()
if args.Get(0) == nil {
return 0
}
switch v := args.Get(0).(type) {
case types.ServerMode:
return v
case int:
return types.ServerMode(v)
default:
return types.ServerMode(args.Int(0))
}
}
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type MockPort struct {
mock.Mock
}
func (m *MockPort) AddRange(startPort, endPort uint16) error {
return m.Called(startPort, endPort).Error(0)
}
func (m *MockPort) Unassigned() (uint16, bool) {
args := m.Called()
var port uint16
if args.Get(0) != nil {
switch v := args.Get(0).(type) {
case int:
port = uint16(v)
case uint16:
port = v
case uint32:
port = uint16(v)
case int32:
port = uint16(v)
case float64:
port = uint16(v)
default:
port = uint16(args.Int(0))
}
}
return port, args.Bool(1)
}
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
}
func (m *MockPort) Claim(port uint16) bool {
return m.Called(port).Bool(0)
}
type MockGRPCClient struct {
mock.Mock
}
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
args := m.Called()
return args.Get(0).(*grpc.ClientConn)
}
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
m.Called()
return
}
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
args := m.Called(ctx, domain, token)
return args.Error(0)
}
func (m *MockGRPCClient) Close() error {
args := m.Called()
return args.Error(0)
}
func TestNew(t *testing.T) {
tests := []struct {
name string
setupConfig func() config.Config
setupPort func() port.Port
wantErr bool
errContains string
}{
{
name: "Success New with default value",
wantErr: false,
},
{
name: "Error when AddRange fails",
setupPort: func() port.Port {
mockPort := &MockPort{}
mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range"))
return mockPort
},
wantErr: true,
errContains: "invalid port range",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var mockPort port.Port
if tt.setupPort != nil {
mockPort = tt.setupPort()
} else {
mockPort = port.New()
}
var mockConfig config.Config
if tt.setupConfig != nil {
mockConfig = tt.setupConfig()
} else {
var err error
mockConfig, err = config.MustLoad()
assert.NoError(t, err)
}
bootstrap, err := New(mockConfig, mockPort)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, bootstrap)
} else {
assert.NoError(t, err)
assert.NotNil(t, bootstrap)
assert.NotNil(t, bootstrap.Randomizer)
assert.NotNil(t, bootstrap.SessionRegistry)
assert.NotNil(t, bootstrap.Config)
assert.NotNil(t, bootstrap.Port)
assert.NotNil(t, bootstrap.ErrChan)
assert.NotNil(t, bootstrap.SignalChan)
}
})
}
}
func generateTestCert(t *testing.T) (certPEM, keyPEM []byte) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"Test Co"},
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
DNSNames: []string{"localhost"},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
require.NoError(t, err)
certPEM = pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDER,
})
keyPEM = pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
})
return certPEM, keyPEM
}
func randomAvailablePort() (string, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return "", err
}
defer listener.Close()
port := listener.Addr().(*net.TCPAddr).Port
return strconv.Itoa(port), nil
}
func TestRun(t *testing.T) {
mockRandom := &MockRandom{}
mockErrChan := make(chan error, 1)
mockSignalChan := make(chan os.Signal, 1)
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
tmpDir := t.TempDir()
keyLoc := filepath.Join(tmpDir, "key.key")
tests := []struct {
name string
setupConfig func() *MockConfig
setupGrpcClient func() *MockGRPCClient
needCerts bool
expectError bool
}{
{
name: "successful run and termination",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
expectError: false,
},
{
name: "error from SSH server invalid port",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("invalid")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
expectError: true,
},
{
name: "error from HTTP server invalid port",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("invalid")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
expectError: true,
},
{
name: "error from HTTPS server invalid port",
setupConfig: func() *MockConfig {
tempDir := os.TempDir()
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("invalid")
mock.On("TLSEnabled").Return(true)
mock.On("TLSRedirect").Return(false)
mock.On("TLSStoragePath").Return(tempDir)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
expectError: true,
},
{
name: "grpc health check failed",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeNODE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("invalid")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed"))
return mockGRPCClient
},
expectError: true,
},
{
name: "successful run with pprof enabled",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
pprofPort, _ := randomAvailablePort()
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeSTANDALONE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(true)
mock.On("PprofPort").Return(pprofPort)
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
expectError: false,
}, {
name: "successful run in NODE mode with signal",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeNODE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockGRPCClient.On("Close").Return(nil)
return mockGRPCClient
},
expectError: false,
}, {
name: "successful run in NODE mode with signal buf error when closing",
setupConfig: func() *MockConfig {
mock := &MockConfig{}
mock.On("KeyLoc").Return(keyLoc)
mock.On("Mode").Return(types.ServerModeNODE)
mock.On("Domain").Return("example.com")
mock.On("SSHPort").Return("0")
mock.On("HTTPPort").Return("0")
mock.On("HTTPSPort").Return("0")
mock.On("TLSEnabled").Return(false)
mock.On("TLSRedirect").Return(false)
mock.On("ACMEEmail").Return("test@example.com")
mock.On("CFAPIToken").Return("fake-token")
mock.On("ACMEStaging").Return(true)
mock.On("AllowedPortsStart").Return(uint16(1024))
mock.On("AllowedPortsEnd").Return(uint16(65535))
mock.On("BufferSize").Return(4096)
mock.On("PprofEnabled").Return(false)
mock.On("PprofPort").Return("0")
mock.On("GRPCAddress").Return("localhost")
mock.On("GRPCPort").Return("0")
mock.On("NodeToken").Return("fake-node-token")
return mock
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy"))
return mockGRPCClient
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockConfig := tt.setupConfig()
mockGRPCClient := &MockGRPCClient{}
bootstrap := &Bootstrap{
Randomizer: mockRandom,
Config: mockConfig,
SessionRegistry: mockSessionRegistry,
Port: mockPort,
ErrChan: mockErrChan,
SignalChan: mockSignalChan,
GrpcClient: mockGRPCClient,
}
if tt.setupGrpcClient != nil {
bootstrap.GrpcClient = tt.setupGrpcClient()
}
done := make(chan error, 1)
go func() {
done <- bootstrap.Run()
}()
if tt.expectError {
err := <-done
assert.Error(t, err)
} else if tt.name == "successful run with pprof enabled" {
time.Sleep(200 * time.Millisecond)
fmt.Println(mockConfig.PprofPort())
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
resp.Body.Close()
mockSignalChan <- os.Interrupt
err = <-done
assert.NoError(t, err)
} else {
time.Sleep(time.Second)
mockSignalChan <- os.Interrupt
err := <-done
assert.NoError(t, err)
}
})
}
}
+2
View File
@@ -13,6 +13,7 @@ type Config interface {
TLSEnabled() bool TLSEnabled() bool
TLSRedirect() bool TLSRedirect() bool
TLSStoragePath() string
ACMEEmail() string ACMEEmail() string
CFAPIToken() string CFAPIToken() string
@@ -52,6 +53,7 @@ func (c *config) HTTPSPort() string { return c.httpsPort }
func (c *config) KeyLoc() string { return c.keyLoc } func (c *config) KeyLoc() string { return c.keyLoc }
func (c *config) TLSEnabled() bool { return c.tlsEnabled } func (c *config) TLSEnabled() bool { return c.tlsEnabled }
func (c *config) TLSRedirect() bool { return c.tlsRedirect } func (c *config) TLSRedirect() bool { return c.tlsRedirect }
func (c *config) TLSStoragePath() string { return c.tlsStoragePath }
func (c *config) ACMEEmail() string { return c.acmeEmail } func (c *config) ACMEEmail() string { return c.acmeEmail }
func (c *config) CFAPIToken() string { return c.cfAPIToken } func (c *config) CFAPIToken() string { return c.cfAPIToken }
func (c *config) ACMEStaging() bool { return c.acmeStaging } func (c *config) ACMEStaging() bool { return c.acmeStaging }
+6 -6
View File
@@ -20,12 +20,12 @@ type config struct {
keyLoc string keyLoc string
tlsEnabled bool tlsEnabled bool
tlsRedirect bool tlsRedirect bool
tlsStoragePath string
acmeEmail string acmeEmail string
cfAPIToken string cfAPIToken string
acmeStaging bool acmeStaging bool
allowedPortsStart uint16 allowedPortsStart uint16
allowedPortsEnd uint16 allowedPortsEnd uint16
+3 -1
View File
@@ -44,7 +44,9 @@ var (
initialBackoff = time.Second initialBackoff = time.Second
) )
func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) { func New(config config.Config, sessionRegistry registry.Registry) (Client, error) {
address := fmt.Sprintf("%s:%s", config.GRPCAddress(), config.GRPCPort())
var opts []grpc.DialOption var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
+35 -11
View File
@@ -8,7 +8,6 @@ import (
"testing" "testing"
"time" "time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/registry" "tunnel_pls/internal/registry"
"tunnel_pls/session/interaction" "tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle" "tunnel_pls/session/lifecycle"
@@ -16,6 +15,7 @@ import (
"tunnel_pls/types" "tunnel_pls/types"
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/health/grpc_health_v1"
@@ -382,7 +382,8 @@ func TestProcessEventStream(t *testing.T) {
mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") } mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") }
mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return nil, errors.New("fail") } mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return nil, errors.New("fail") }
c.sessionRegistry = mockReg c.sessionRegistry = mockReg
c.config = &mockConfig{domain: "test.com"} c.config = &MockConfig{}
c.config.(*MockConfig).On("Domain").Return("test.com")
mockStream.sendFunc = func(n *proto.Node) error { return nil } mockStream.sendFunc = func(n *proto.Node) error { return nil }
err := c.processEventStream(mockStream) err := c.processEventStream(mockStream)
@@ -541,7 +542,8 @@ func TestHandleSlugChange(t *testing.T) {
func TestHandleGetSessions(t *testing.T) { func TestHandleGetSessions(t *testing.T) {
mockReg := &mockRegistry{} mockReg := &mockRegistry{}
mockStream := &mockSubscribeClient{} mockStream := &mockSubscribeClient{}
mockCfg := &mockConfig{domain: "test.com"} mockCfg := &MockConfig{}
mockCfg.On("Domain").Return("test.com")
c := &client{sessionRegistry: mockReg, config: mockCfg} c := &client{sessionRegistry: mockReg, config: mockCfg}
evt := &proto.Events{ evt := &proto.Events{
@@ -840,8 +842,11 @@ func TestNew_Error(t *testing.T) {
return nil, errors.New("dial fail") return nil, errors.New("dial fail")
} }
defer func() { grpcNewClient = old }() defer func() { grpcNewClient = old }()
mockConfig := &MockConfig{}
cli, err := New(&mockConfig{}, "localhost:1234", &mockRegistry{}) mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("1234")
cli, err := New(mockConfig, &mockRegistry{})
if err == nil || err.Error() != "failed to connect to gRPC server at localhost:1234: dial fail" { if err == nil || err.Error() != "failed to connect to gRPC server at localhost:1234: dial fail" {
t.Errorf("expected dial fail error, got %v", err) t.Errorf("expected dial fail error, got %v", err)
} }
@@ -851,10 +856,11 @@ func TestNew_Error(t *testing.T) {
} }
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
mockCfg := &mockConfig{} mockConfig := &MockConfig{}
mockReg := &mockRegistry{} mockReg := &mockRegistry{}
mockConfig.On("GRPCAddress").Return("localhost")
cli, err := New(mockCfg, "localhost:1234", mockReg) mockConfig.On("GRPCPort").Return("1234")
cli, err := New(mockConfig, mockReg)
if err != nil { if err != nil {
t.Errorf("New() error = %v", err) t.Errorf("New() error = %v", err)
} }
@@ -864,12 +870,30 @@ func TestNew(t *testing.T) {
defer cli.Close() defer cli.Close()
} }
type mockConfig struct { type MockConfig struct {
config.Config mock.Mock
domain string
} }
func (m *mockConfig) Domain() string { return m.domain } func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type mockRegistry struct { type mockRegistry struct {
registry.Registry registry.Registry
+8 -4
View File
@@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"path/filepath"
"sync" "sync"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
@@ -47,15 +48,18 @@ func NewTLSConfig(config config.Config) (*tls.Config, error) {
var initErr error var initErr error
tlsManagerOnce.Do(func() { tlsManagerOnce.Do(func() {
certPath := "certs/tls/cert.pem" storagePath := config.TLSStoragePath()
keyPath := "certs/tls/privkey.pem" cleanBase := filepath.Clean(storagePath)
storagePath := "certs/tls/certmagic"
certPath := filepath.Join(cleanBase, "cert.pem")
keyPath := filepath.Join(cleanBase, "privkey.pem")
storagePathCertMagic := filepath.Join(cleanBase, "certmagic")
tm := &tlsManager{ tm := &tlsManager{
config: config, config: config,
certPath: certPath, certPath: certPath,
keyPath: keyPath, keyPath: keyPath,
storagePath: storagePath, storagePath: storagePathCertMagic,
} }
if tm.userCertsExistAndValid() { if tm.userCertsExistAndValid() {
+8 -1
View File
@@ -5,6 +5,8 @@ import (
"log" "log"
"os" "os"
"tunnel_pls/internal/bootstrap" "tunnel_pls/internal/bootstrap"
"tunnel_pls/internal/config"
"tunnel_pls/internal/port"
"tunnel_pls/internal/version" "tunnel_pls/internal/version"
) )
@@ -18,7 +20,12 @@ func main() {
log.SetFlags(log.LstdFlags | log.Lshortfile) log.SetFlags(log.LstdFlags | log.Lshortfile)
log.Printf("Starting %s", version.GetVersion()) log.Printf("Starting %s", version.GetVersion())
boot, err := bootstrap.New() conf, err := config.MustLoad()
if err != nil {
log.Fatalf("Config load error: %v", err)
}
boot, err := bootstrap.New(conf, port.New())
if err != nil { if err != nil {
log.Fatalf("Startup error: %v", err) log.Fatalf("Startup error: %v", err)
} }
+24 -23
View File
@@ -27,29 +27,30 @@ func (m *mockRandom) String(length int) (string, error) {
return args.String(0), args.Error(1) return args.String(0), args.Error(1)
} }
type mockConfig struct { type MockConfig struct {
mock.Mock mock.Mock
} }
func (m *mockConfig) Domain() string { return m.Called().String(0) } func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *mockConfig) SSHPort() string { return m.Called().String(0) } func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *mockConfig) HTTPPort() string { return m.Called().String(0) } func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *mockConfig) HTTPSPort() string { return m.Called().String(0) } func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *mockConfig) TLSRedirect() bool { return m.Called().Bool(0) } func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *mockConfig) ACMEEmail() string { return m.Called().String(0) } func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *mockConfig) CFAPIToken() string { return m.Called().String(0) } func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *mockConfig) ACMEStaging() bool { return m.Called().Bool(0) } func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *mockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *mockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *mockConfig) BufferSize() int { return m.Called().Int(0) } func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *mockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *mockConfig) PprofPort() string { return m.Called().String(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *mockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *mockConfig) GRPCAddress() string { return m.Called().String(0) } func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
func (m *mockConfig) GRPCPort() string { return m.Called().String(0) } func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *mockConfig) NodeToken() string { return m.Called().String(0) } func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *mockConfig) KeyLoc() string { return m.Called().String(0) } func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type mockRegistry struct { type mockRegistry struct {
mock.Mock mock.Mock
@@ -169,7 +170,7 @@ func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
mr := new(mockRandom) mr := new(mockRandom)
mc := new(mockConfig) mc := new(MockConfig)
mreg := new(mockRegistry) mreg := new(mockRegistry)
mg := new(mockGrpcClient) mg := new(mockGrpcClient)
mp := new(mockPort) mp := new(mockPort)
@@ -222,7 +223,7 @@ func TestNew(t *testing.T) {
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
mr := new(mockRandom) mr := new(mockRandom)
mc := new(mockConfig) mc := new(MockConfig)
mreg := new(mockRegistry) mreg := new(mockRegistry)
mg := new(mockGrpcClient) mg := new(mockGrpcClient)
mp := new(mockPort) mp := new(mockPort)
@@ -238,7 +239,7 @@ func TestClose(t *testing.T) {
func TestStart(t *testing.T) { func TestStart(t *testing.T) {
mr := new(mockRandom) mr := new(mockRandom)
mc := new(mockConfig) mc := new(MockConfig)
mreg := new(mockRegistry) mreg := new(mockRegistry)
mg := new(mockGrpcClient) mg := new(mockGrpcClient)
mp := new(mockPort) mp := new(mockPort)