diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index 77ed6e8..bd8645f 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -28,27 +28,35 @@ type Bootstrap struct { Config config.Config SessionRegistry registry.Registry Port port.Port + GrpcClient client.Client + ErrChan chan error + SignalChan chan os.Signal } -func New() (*Bootstrap, error) { - conf, err := config.MustLoad() +func New(config config.Config, port port.Port) (*Bootstrap, error) { + randomizer := random.New() + sessionRegistry := registry.NewRegistry() + + if err := port.AddRange(config.AllowedPortsStart(), config.AllowedPortsEnd()); err != nil { + return nil, err + } + + grpcClient, err := client.New(config, sessionRegistry) if err != nil { return nil, err } - randomizer := random.New() - sessionRegistry := registry.NewRegistry() - - portManager := port.New() - if err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd()); err != nil { - return nil, err - } + errChan := make(chan error, 5) + signalChan := make(chan os.Signal, 1) return &Bootstrap{ Randomizer: randomizer, - Config: conf, + Config: config, SessionRegistry: sessionRegistry, - Port: portManager, + Port: port, + GrpcClient: grpcClient, + ErrChan: errChan, + SignalChan: signalChan, }, nil } @@ -73,25 +81,20 @@ func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) { return sshCfg, nil } -func startGRPCClient(ctx context.Context, conf config.Config, registry registry.Registry, errChan chan<- error) (client.Client, error) { - grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort()) - grpcClient, err := client.New(conf, grpcAddr, registry) - if err != nil { - return nil, err - } +func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, errChan chan<- error) error { healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second) defer healthCancel() - if err = grpcClient.CheckServerHealth(healthCtx); err != nil { - return nil, fmt.Errorf("gRPC health check failed: %w", err) + if err := b.GrpcClient.CheckServerHealth(healthCtx); err != nil { + return fmt.Errorf("gRPC health check failed: %w", err) } go func() { - if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil { + if err := b.GrpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil { errChan <- fmt.Errorf("failed to subscribe to events: %w", err) } }() - return grpcClient, nil + return nil } func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) { @@ -115,7 +118,7 @@ func startHTTPSServer(conf config.Config, registry registry.Registry, errChan ch httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), registry, conf.TLSRedirect(), tlsCfg) ln, err := httpsServer.Listen() if err != nil { - errChan <- fmt.Errorf("failed to start https server: %w", err) + errChan <- fmt.Errorf("failed to create TLS config: %w", err) return } if err = httpsServer.Serve(ln); err != nil { @@ -123,25 +126,25 @@ func startHTTPSServer(conf config.Config, registry registry.Registry, errChan ch } } -func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, sshPort string) error { - sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, sshPort) +func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, errChan chan<- error) { + sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, conf.SSHPort()) if err != nil { - return err + errChan <- err + return } sshServer.Start() - return sshServer.Close() + errChan <- sshServer.Close() } -func startPprof(pprofPort string) { +func startPprof(pprofPort string, errChan chan<- error) { pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) if err := http.ListenAndServe(pprofAddr, nil); err != nil { - log.Printf("pprof server error: %v", err) + errChan <- fmt.Errorf("pprof server error: %v", err) } } - func (b *Bootstrap) Run() error { sshConfig, err := newSSHConfig(b.Config.KeyLoc()) if err != nil { @@ -151,13 +154,10 @@ func (b *Bootstrap) Run() error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - errChan := make(chan error, 5) - shutdownChan := make(chan os.Signal, 1) - signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) + signal.Notify(b.SignalChan, os.Interrupt, syscall.SIGTERM) - var grpcClient client.Client if b.Config.Mode() == types.ServerModeNODE { - grpcClient, err = startGRPCClient(ctx, b.Config, b.SessionRegistry, errChan) + err = b.startGRPCClient(ctx, b.Config, b.ErrChan) if err != nil { return fmt.Errorf("failed to start gRPC client: %w", err) } @@ -166,31 +166,29 @@ func (b *Bootstrap) Run() error { if err != nil { log.Printf("failed to close gRPC client") } - }(grpcClient) + }(b.GrpcClient) } - go startHTTPServer(b.Config, b.SessionRegistry, errChan) + go startHTTPServer(b.Config, b.SessionRegistry, b.ErrChan) if b.Config.TLSEnabled() { - go startHTTPSServer(b.Config, b.SessionRegistry, errChan) + go startHTTPSServer(b.Config, b.SessionRegistry, b.ErrChan) } go func() { - if err = startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, grpcClient, b.Port, b.Config.SSHPort()); err != nil { - errChan <- fmt.Errorf("SSH server error: %w", err) - } + startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, b.GrpcClient, b.Port, b.ErrChan) }() if b.Config.PprofEnabled() { - go startPprof(b.Config.PprofPort()) + go startPprof(b.Config.PprofPort(), b.ErrChan) } log.Println("All services started successfully") select { - case err = <-errChan: + case err = <-b.ErrChan: return fmt.Errorf("service error: %w", err) - case sig := <-shutdownChan: + case sig := <-b.SignalChan: log.Printf("Received signal %s, initiating graceful shutdown", sig) cancel() return nil diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go new file mode 100644 index 0000000..13778c7 --- /dev/null +++ b/internal/bootstrap/bootstrap_test.go @@ -0,0 +1,627 @@ +package bootstrap + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/http" + _ "net/http/pprof" + "os" + "path/filepath" + "strconv" + "testing" + "time" + "tunnel_pls/internal/config" + "tunnel_pls/internal/port" + "tunnel_pls/internal/registry" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +type MockSessionRegistry struct { + mock.Mock +} + +func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) { + args := m.Called(key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} + +func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) { + args := m.Called(user, key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} + +func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error { + args := m.Called(user, oldKey, newKey) + return args.Error(0) +} + +func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool { + args := m.Called(key, session) + return args.Bool(0) +} + +func (m *MockSessionRegistry) Remove(key registry.Key) { + m.Called(key) +} + +func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session { + args := m.Called(user) + return args.Get(0).([]registry.Session) +} + +func (m *MockSessionRegistry) Slug() slug.Slug { + args := m.Called() + return args.Get(0).(slug.Slug) +} + +type MockSession struct { + mock.Mock +} + +func (m *MockSession) Lifecycle() lifecycle.Lifecycle { + args := m.Called() + return args.Get(0).(lifecycle.Lifecycle) +} + +func (m *MockSession) Interaction() interaction.Interaction { + args := m.Called() + return args.Get(0).(interaction.Interaction) +} + +func (m *MockSession) Forwarder() forwarder.Forwarder { + args := m.Called() + return args.Get(0).(forwarder.Forwarder) +} + +func (m *MockSession) Slug() slug.Slug { + args := m.Called() + return args.Get(0).(slug.Slug) +} + +func (m *MockSession) Detail() *types.Detail { + args := m.Called() + return args.Get(0).(*types.Detail) +} + +type MockRandom struct { + mock.Mock +} + +func (m *MockRandom) String(length int) (string, error) { + args := m.Called(length) + return args.String(0), args.Error(1) +} + +type MockConfig struct { + mock.Mock +} + +func (m *MockConfig) Domain() string { return m.Called().String(0) } +func (m *MockConfig) SSHPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) } +func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) } +func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) } +func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) } +func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } +func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) PprofPort() string { return m.Called().String(0) } +func (m *MockConfig) Mode() types.ServerMode { + args := m.Called() + if args.Get(0) == nil { + return 0 + } + switch v := args.Get(0).(type) { + case types.ServerMode: + return v + case int: + return types.ServerMode(v) + default: + return types.ServerMode(args.Int(0)) + } +} +func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *MockConfig) NodeToken() string { return m.Called().String(0) } +func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } + +type MockPort struct { + mock.Mock +} + +func (m *MockPort) AddRange(startPort, endPort uint16) error { + return m.Called(startPort, endPort).Error(0) +} +func (m *MockPort) Unassigned() (uint16, bool) { + args := m.Called() + var port uint16 + if args.Get(0) != nil { + switch v := args.Get(0).(type) { + case int: + port = uint16(v) + case uint16: + port = v + case uint32: + port = uint16(v) + case int32: + port = uint16(v) + case float64: + port = uint16(v) + default: + port = uint16(args.Int(0)) + } + } + return port, args.Bool(1) +} +func (m *MockPort) SetStatus(port uint16, assigned bool) error { + return m.Called(port, assigned).Error(0) +} +func (m *MockPort) Claim(port uint16) bool { + return m.Called(port).Bool(0) +} + +type MockGRPCClient struct { + mock.Mock +} + +func (m *MockGRPCClient) ClientConn() *grpc.ClientConn { + args := m.Called() + return args.Get(0).(*grpc.ClientConn) +} + +func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) { + m.Called() + return +} + +func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error { + args := m.Called(ctx, domain, token) + return args.Error(0) +} + +func (m *MockGRPCClient) Close() error { + args := m.Called() + return args.Error(0) +} + +func TestNew(t *testing.T) { + tests := []struct { + name string + setupConfig func() config.Config + setupPort func() port.Port + wantErr bool + errContains string + }{ + { + name: "Success New with default value", + wantErr: false, + }, + { + name: "Error when AddRange fails", + setupPort: func() port.Port { + mockPort := &MockPort{} + mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range")) + return mockPort + }, + wantErr: true, + errContains: "invalid port range", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var mockPort port.Port + if tt.setupPort != nil { + mockPort = tt.setupPort() + } else { + mockPort = port.New() + } + + var mockConfig config.Config + if tt.setupConfig != nil { + mockConfig = tt.setupConfig() + } else { + var err error + mockConfig, err = config.MustLoad() + assert.NoError(t, err) + } + + bootstrap, err := New(mockConfig, mockPort) + + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + assert.Nil(t, bootstrap) + } else { + assert.NoError(t, err) + assert.NotNil(t, bootstrap) + assert.NotNil(t, bootstrap.Randomizer) + assert.NotNil(t, bootstrap.SessionRegistry) + assert.NotNil(t, bootstrap.Config) + assert.NotNil(t, bootstrap.Port) + assert.NotNil(t, bootstrap.ErrChan) + assert.NotNil(t, bootstrap.SignalChan) + } + }) + } +} + +func generateTestCert(t *testing.T) (certPEM, keyPEM []byte) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test Co"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + require.NoError(t, err) + + certPEM = pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: certDER, + }) + + keyPEM = pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + return certPEM, keyPEM +} + +func randomAvailablePort() (string, error) { + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", err + } + defer listener.Close() + + port := listener.Addr().(*net.TCPAddr).Port + return strconv.Itoa(port), nil +} + +func TestRun(t *testing.T) { + mockRandom := &MockRandom{} + mockErrChan := make(chan error, 1) + mockSignalChan := make(chan os.Signal, 1) + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + tmpDir := t.TempDir() + keyLoc := filepath.Join(tmpDir, "key.key") + + tests := []struct { + name string + setupConfig func() *MockConfig + setupGrpcClient func() *MockGRPCClient + needCerts bool + expectError bool + }{ + { + name: "successful run and termination", + setupConfig: func() *MockConfig { + mock := &MockConfig{} + mock.On("KeyLoc").Return(keyLoc) + mock.On("Mode").Return(types.ServerModeSTANDALONE) + mock.On("Domain").Return("example.com") + mock.On("SSHPort").Return("0") + mock.On("HTTPPort").Return("0") + mock.On("HTTPSPort").Return("0") + mock.On("TLSEnabled").Return(false) + mock.On("TLSRedirect").Return(false) + mock.On("ACMEEmail").Return("test@example.com") + mock.On("CFAPIToken").Return("fake-token") + mock.On("ACMEStaging").Return(true) + mock.On("AllowedPortsStart").Return(uint16(1024)) + mock.On("AllowedPortsEnd").Return(uint16(65535)) + mock.On("BufferSize").Return(4096) + mock.On("PprofEnabled").Return(false) + mock.On("PprofPort").Return("0") + mock.On("GRPCAddress").Return("localhost") + mock.On("GRPCPort").Return("0") + mock.On("NodeToken").Return("fake-node-token") + return mock + }, + expectError: false, + }, + { + name: "error from SSH server invalid port", + setupConfig: func() *MockConfig { + mock := &MockConfig{} + mock.On("KeyLoc").Return(keyLoc) + mock.On("Mode").Return(types.ServerModeSTANDALONE) + mock.On("Domain").Return("example.com") + mock.On("SSHPort").Return("invalid") + mock.On("HTTPPort").Return("0") + mock.On("HTTPSPort").Return("0") + mock.On("TLSEnabled").Return(false) + mock.On("TLSRedirect").Return(false) + mock.On("ACMEEmail").Return("test@example.com") + mock.On("CFAPIToken").Return("fake-token") + mock.On("ACMEStaging").Return(true) + mock.On("AllowedPortsStart").Return(uint16(1024)) + mock.On("AllowedPortsEnd").Return(uint16(65535)) + mock.On("BufferSize").Return(4096) + mock.On("PprofEnabled").Return(false) + mock.On("PprofPort").Return("0") + mock.On("GRPCAddress").Return("localhost") + mock.On("GRPCPort").Return("0") + mock.On("NodeToken").Return("fake-node-token") + return mock + }, + expectError: true, + }, + { + name: "error from HTTP server invalid port", + setupConfig: func() *MockConfig { + mock := &MockConfig{} + mock.On("KeyLoc").Return(keyLoc) + mock.On("Mode").Return(types.ServerModeSTANDALONE) + mock.On("Domain").Return("example.com") + mock.On("SSHPort").Return("0") + mock.On("HTTPPort").Return("invalid") + mock.On("HTTPSPort").Return("0") + mock.On("TLSEnabled").Return(false) + mock.On("TLSRedirect").Return(false) + mock.On("ACMEEmail").Return("test@example.com") + mock.On("CFAPIToken").Return("fake-token") + mock.On("ACMEStaging").Return(true) + mock.On("AllowedPortsStart").Return(uint16(1024)) + mock.On("AllowedPortsEnd").Return(uint16(65535)) + mock.On("BufferSize").Return(4096) + mock.On("PprofEnabled").Return(false) + mock.On("PprofPort").Return("0") + mock.On("GRPCAddress").Return("localhost") + mock.On("GRPCPort").Return("0") + mock.On("NodeToken").Return("fake-node-token") + return mock + }, + expectError: true, + }, + { + name: "error from HTTPS server invalid port", + setupConfig: func() *MockConfig { + tempDir := os.TempDir() + mock := &MockConfig{} + mock.On("KeyLoc").Return(keyLoc) + mock.On("Mode").Return(types.ServerModeSTANDALONE) + mock.On("Domain").Return("example.com") + mock.On("SSHPort").Return("0") + mock.On("HTTPPort").Return("0") + mock.On("HTTPSPort").Return("invalid") + mock.On("TLSEnabled").Return(true) + mock.On("TLSRedirect").Return(false) + mock.On("TLSStoragePath").Return(tempDir) + mock.On("ACMEEmail").Return("test@example.com") + mock.On("CFAPIToken").Return("fake-token") + mock.On("ACMEStaging").Return(true) + mock.On("AllowedPortsStart").Return(uint16(1024)) + mock.On("AllowedPortsEnd").Return(uint16(65535)) + mock.On("BufferSize").Return(4096) + mock.On("PprofEnabled").Return(false) + mock.On("PprofPort").Return("0") + mock.On("GRPCAddress").Return("localhost") + mock.On("GRPCPort").Return("0") + mock.On("NodeToken").Return("fake-node-token") + return mock + }, + expectError: true, + }, + { + name: "grpc health check failed", + setupConfig: func() *MockConfig { + mock := &MockConfig{} + mock.On("KeyLoc").Return(keyLoc) + mock.On("Mode").Return(types.ServerModeNODE) + mock.On("Domain").Return("example.com") + mock.On("SSHPort").Return("0") + mock.On("HTTPPort").Return("0") + mock.On("HTTPSPort").Return("0") + mock.On("TLSEnabled").Return(false) + mock.On("TLSRedirect").Return(false) + mock.On("ACMEEmail").Return("test@example.com") + mock.On("CFAPIToken").Return("fake-token") + mock.On("ACMEStaging").Return(true) + mock.On("AllowedPortsStart").Return(uint16(1024)) + mock.On("AllowedPortsEnd").Return(uint16(65535)) + mock.On("BufferSize").Return(4096) + mock.On("PprofEnabled").Return(false) + mock.On("PprofPort").Return("0") + mock.On("GRPCAddress").Return("localhost") + mock.On("GRPCPort").Return("invalid") + mock.On("NodeToken").Return("fake-node-token") + return mock + }, + setupGrpcClient: func() *MockGRPCClient { + mockGRPCClient := &MockGRPCClient{} + mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed")) + return mockGRPCClient + }, + expectError: true, + }, + { + name: "successful run with pprof enabled", + setupConfig: func() *MockConfig { + mock := &MockConfig{} + pprofPort, _ := randomAvailablePort() + mock.On("KeyLoc").Return(keyLoc) + mock.On("Mode").Return(types.ServerModeSTANDALONE) + mock.On("Domain").Return("example.com") + mock.On("SSHPort").Return("0") + mock.On("HTTPPort").Return("0") + mock.On("HTTPSPort").Return("0") + mock.On("TLSEnabled").Return(false) + mock.On("TLSRedirect").Return(false) + mock.On("ACMEEmail").Return("test@example.com") + mock.On("CFAPIToken").Return("fake-token") + mock.On("ACMEStaging").Return(true) + mock.On("AllowedPortsStart").Return(uint16(1024)) + mock.On("AllowedPortsEnd").Return(uint16(65535)) + mock.On("BufferSize").Return(4096) + mock.On("PprofEnabled").Return(true) + mock.On("PprofPort").Return(pprofPort) + mock.On("GRPCAddress").Return("localhost") + mock.On("GRPCPort").Return("0") + mock.On("NodeToken").Return("fake-node-token") + return mock + }, + expectError: false, + }, { + name: "successful run in NODE mode with signal", + setupConfig: func() *MockConfig { + mock := &MockConfig{} + mock.On("KeyLoc").Return(keyLoc) + mock.On("Mode").Return(types.ServerModeNODE) + mock.On("Domain").Return("example.com") + mock.On("SSHPort").Return("0") + mock.On("HTTPPort").Return("0") + mock.On("HTTPSPort").Return("0") + mock.On("TLSEnabled").Return(false) + mock.On("TLSRedirect").Return(false) + mock.On("ACMEEmail").Return("test@example.com") + mock.On("CFAPIToken").Return("fake-token") + mock.On("ACMEStaging").Return(true) + mock.On("AllowedPortsStart").Return(uint16(1024)) + mock.On("AllowedPortsEnd").Return(uint16(65535)) + mock.On("BufferSize").Return(4096) + mock.On("PprofEnabled").Return(false) + mock.On("PprofPort").Return("0") + mock.On("GRPCAddress").Return("localhost") + mock.On("GRPCPort").Return("0") + mock.On("NodeToken").Return("fake-node-token") + return mock + }, + setupGrpcClient: func() *MockGRPCClient { + mockGRPCClient := &MockGRPCClient{} + mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil) + mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockGRPCClient.On("Close").Return(nil) + return mockGRPCClient + }, + expectError: false, + }, { + name: "successful run in NODE mode with signal buf error when closing", + setupConfig: func() *MockConfig { + mock := &MockConfig{} + mock.On("KeyLoc").Return(keyLoc) + mock.On("Mode").Return(types.ServerModeNODE) + mock.On("Domain").Return("example.com") + mock.On("SSHPort").Return("0") + mock.On("HTTPPort").Return("0") + mock.On("HTTPSPort").Return("0") + mock.On("TLSEnabled").Return(false) + mock.On("TLSRedirect").Return(false) + mock.On("ACMEEmail").Return("test@example.com") + mock.On("CFAPIToken").Return("fake-token") + mock.On("ACMEStaging").Return(true) + mock.On("AllowedPortsStart").Return(uint16(1024)) + mock.On("AllowedPortsEnd").Return(uint16(65535)) + mock.On("BufferSize").Return(4096) + mock.On("PprofEnabled").Return(false) + mock.On("PprofPort").Return("0") + mock.On("GRPCAddress").Return("localhost") + mock.On("GRPCPort").Return("0") + mock.On("NodeToken").Return("fake-node-token") + return mock + }, + setupGrpcClient: func() *MockGRPCClient { + mockGRPCClient := &MockGRPCClient{} + mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil) + mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy")) + return mockGRPCClient + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockConfig := tt.setupConfig() + mockGRPCClient := &MockGRPCClient{} + bootstrap := &Bootstrap{ + Randomizer: mockRandom, + Config: mockConfig, + SessionRegistry: mockSessionRegistry, + Port: mockPort, + ErrChan: mockErrChan, + SignalChan: mockSignalChan, + GrpcClient: mockGRPCClient, + } + + if tt.setupGrpcClient != nil { + bootstrap.GrpcClient = tt.setupGrpcClient() + } + + done := make(chan error, 1) + go func() { + done <- bootstrap.Run() + }() + + if tt.expectError { + err := <-done + assert.Error(t, err) + } else if tt.name == "successful run with pprof enabled" { + time.Sleep(200 * time.Millisecond) + fmt.Println(mockConfig.PprofPort()) + resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort())) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + resp.Body.Close() + mockSignalChan <- os.Interrupt + err = <-done + assert.NoError(t, err) + } else { + time.Sleep(time.Second) + mockSignalChan <- os.Interrupt + err := <-done + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 5c21abf..19bbc49 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -13,6 +13,7 @@ type Config interface { TLSEnabled() bool TLSRedirect() bool + TLSStoragePath() string ACMEEmail() string CFAPIToken() string @@ -52,6 +53,7 @@ func (c *config) HTTPSPort() string { return c.httpsPort } func (c *config) KeyLoc() string { return c.keyLoc } func (c *config) TLSEnabled() bool { return c.tlsEnabled } func (c *config) TLSRedirect() bool { return c.tlsRedirect } +func (c *config) TLSStoragePath() string { return c.tlsStoragePath } func (c *config) ACMEEmail() string { return c.acmeEmail } func (c *config) CFAPIToken() string { return c.cfAPIToken } func (c *config) ACMEStaging() bool { return c.acmeStaging } diff --git a/internal/config/loader.go b/internal/config/loader.go index ebccf3f..aeb32bb 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -20,12 +20,12 @@ type config struct { keyLoc string - tlsEnabled bool - tlsRedirect bool - - acmeEmail string - cfAPIToken string - acmeStaging bool + tlsEnabled bool + tlsRedirect bool + tlsStoragePath string + acmeEmail string + cfAPIToken string + acmeStaging bool allowedPortsStart uint16 allowedPortsEnd uint16 diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index ade89a0..3adcc57 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -44,7 +44,9 @@ var ( initialBackoff = time.Second ) -func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) { +func New(config config.Config, sessionRegistry registry.Registry) (Client, error) { + address := fmt.Sprintf("%s:%s", config.GRPCAddress(), config.GRPCPort()) + var opts []grpc.DialOption opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) diff --git a/internal/grpc/client/client_test.go b/internal/grpc/client/client_test.go index 1d3b315..3964cf0 100644 --- a/internal/grpc/client/client_test.go +++ b/internal/grpc/client/client_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "tunnel_pls/internal/config" "tunnel_pls/internal/registry" "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" @@ -16,6 +15,7 @@ import ( "tunnel_pls/types" proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" + "github.com/stretchr/testify/mock" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/health/grpc_health_v1" @@ -382,7 +382,8 @@ func TestProcessEventStream(t *testing.T) { mockReg.getFunc = func(key registry.Key) (registry.Session, error) { return nil, errors.New("fail") } mockReg.getWithUserFunc = func(user string, key registry.Key) (registry.Session, error) { return nil, errors.New("fail") } c.sessionRegistry = mockReg - c.config = &mockConfig{domain: "test.com"} + c.config = &MockConfig{} + c.config.(*MockConfig).On("Domain").Return("test.com") mockStream.sendFunc = func(n *proto.Node) error { return nil } err := c.processEventStream(mockStream) @@ -541,7 +542,8 @@ func TestHandleSlugChange(t *testing.T) { func TestHandleGetSessions(t *testing.T) { mockReg := &mockRegistry{} mockStream := &mockSubscribeClient{} - mockCfg := &mockConfig{domain: "test.com"} + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("test.com") c := &client{sessionRegistry: mockReg, config: mockCfg} evt := &proto.Events{ @@ -840,8 +842,11 @@ func TestNew_Error(t *testing.T) { return nil, errors.New("dial fail") } defer func() { grpcNewClient = old }() + mockConfig := &MockConfig{} - cli, err := New(&mockConfig{}, "localhost:1234", &mockRegistry{}) + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("1234") + cli, err := New(mockConfig, &mockRegistry{}) if err == nil || err.Error() != "failed to connect to gRPC server at localhost:1234: dial fail" { t.Errorf("expected dial fail error, got %v", err) } @@ -851,10 +856,11 @@ func TestNew_Error(t *testing.T) { } func TestNew(t *testing.T) { - mockCfg := &mockConfig{} + mockConfig := &MockConfig{} mockReg := &mockRegistry{} - - cli, err := New(mockCfg, "localhost:1234", mockReg) + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("1234") + cli, err := New(mockConfig, mockReg) if err != nil { t.Errorf("New() error = %v", err) } @@ -864,12 +870,30 @@ func TestNew(t *testing.T) { defer cli.Close() } -type mockConfig struct { - config.Config - domain string +type MockConfig struct { + mock.Mock } -func (m *mockConfig) Domain() string { return m.domain } +func (m *MockConfig) Domain() string { return m.Called().String(0) } +func (m *MockConfig) SSHPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) } +func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) } +func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) } +func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) } +func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } +func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) PprofPort() string { return m.Called().String(0) } +func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } +func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *MockConfig) NodeToken() string { return m.Called().String(0) } +func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } type mockRegistry struct { registry.Registry diff --git a/internal/transport/tls.go b/internal/transport/tls.go index 877afb4..584dec4 100644 --- a/internal/transport/tls.go +++ b/internal/transport/tls.go @@ -8,6 +8,7 @@ import ( "fmt" "log" "os" + "path/filepath" "sync" "time" "tunnel_pls/internal/config" @@ -47,15 +48,18 @@ func NewTLSConfig(config config.Config) (*tls.Config, error) { var initErr error tlsManagerOnce.Do(func() { - certPath := "certs/tls/cert.pem" - keyPath := "certs/tls/privkey.pem" - storagePath := "certs/tls/certmagic" + storagePath := config.TLSStoragePath() + cleanBase := filepath.Clean(storagePath) + + certPath := filepath.Join(cleanBase, "cert.pem") + keyPath := filepath.Join(cleanBase, "privkey.pem") + storagePathCertMagic := filepath.Join(cleanBase, "certmagic") tm := &tlsManager{ config: config, certPath: certPath, keyPath: keyPath, - storagePath: storagePath, + storagePath: storagePathCertMagic, } if tm.userCertsExistAndValid() { diff --git a/main.go b/main.go index be8b510..a908903 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,8 @@ import ( "log" "os" "tunnel_pls/internal/bootstrap" + "tunnel_pls/internal/config" + "tunnel_pls/internal/port" "tunnel_pls/internal/version" ) @@ -18,7 +20,12 @@ func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) log.Printf("Starting %s", version.GetVersion()) - boot, err := bootstrap.New() + conf, err := config.MustLoad() + if err != nil { + log.Fatalf("Config load error: %v", err) + } + + boot, err := bootstrap.New(conf, port.New()) if err != nil { log.Fatalf("Startup error: %v", err) } diff --git a/server/server_test.go b/server/server_test.go index de54f18..c35073a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -27,29 +27,30 @@ func (m *mockRandom) String(length int) (string, error) { return args.String(0), args.Error(1) } -type mockConfig struct { +type MockConfig struct { mock.Mock } -func (m *mockConfig) Domain() string { return m.Called().String(0) } -func (m *mockConfig) SSHPort() string { return m.Called().String(0) } -func (m *mockConfig) HTTPPort() string { return m.Called().String(0) } -func (m *mockConfig) HTTPSPort() string { return m.Called().String(0) } -func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) } -func (m *mockConfig) TLSRedirect() bool { return m.Called().Bool(0) } -func (m *mockConfig) ACMEEmail() string { return m.Called().String(0) } -func (m *mockConfig) CFAPIToken() string { return m.Called().String(0) } -func (m *mockConfig) ACMEStaging() bool { return m.Called().Bool(0) } -func (m *mockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } -func (m *mockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } -func (m *mockConfig) BufferSize() int { return m.Called().Int(0) } -func (m *mockConfig) PprofEnabled() bool { return m.Called().Bool(0) } -func (m *mockConfig) PprofPort() string { return m.Called().String(0) } -func (m *mockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } -func (m *mockConfig) GRPCAddress() string { return m.Called().String(0) } -func (m *mockConfig) GRPCPort() string { return m.Called().String(0) } -func (m *mockConfig) NodeToken() string { return m.Called().String(0) } -func (m *mockConfig) KeyLoc() string { return m.Called().String(0) } +func (m *MockConfig) Domain() string { return m.Called().String(0) } +func (m *MockConfig) SSHPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) } +func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) } +func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) } +func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) } +func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } +func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) PprofPort() string { return m.Called().String(0) } +func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } +func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *MockConfig) NodeToken() string { return m.Called().String(0) } +func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } type mockRegistry struct { mock.Mock @@ -169,7 +170,7 @@ func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) { func TestNew(t *testing.T) { mr := new(mockRandom) - mc := new(mockConfig) + mc := new(MockConfig) mreg := new(mockRegistry) mg := new(mockGrpcClient) mp := new(mockPort) @@ -222,7 +223,7 @@ func TestNew(t *testing.T) { func TestClose(t *testing.T) { mr := new(mockRandom) - mc := new(mockConfig) + mc := new(MockConfig) mreg := new(mockRegistry) mg := new(mockGrpcClient) mp := new(mockPort) @@ -238,7 +239,7 @@ func TestClose(t *testing.T) { func TestStart(t *testing.T) { mr := new(mockRandom) - mc := new(mockConfig) + mc := new(MockConfig) mreg := new(mockRegistry) mg := new(mockGrpcClient) mp := new(mockPort)