package bootstrap

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"math/big"
	"net"
	"net/http"
	// Registers the /debug/pprof/ handlers on http.DefaultServeMux; the pprof
	// case in TestRun probes that path.
	_ "net/http/pprof"
	"os"
	"path/filepath"
	"strconv"
	"testing"
	"time"

	"tunnel_pls/internal/config"
	"tunnel_pls/internal/port"
	"tunnel_pls/internal/registry"
	"tunnel_pls/session/forwarder"
	"tunnel_pls/session/interaction"
	"tunnel_pls/session/lifecycle"
	"tunnel_pls/session/slug"
	"tunnel_pls/types"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
)

// testRunTimeout bounds how long TestRun waits for Bootstrap.Run to return.
// It is deliberately generous: it only exists to turn a hung Run into a
// readable test failure instead of a stalled `go test`.
const testRunTimeout = 60 * time.Second

// mockUint16Arg converts a value returned from a testify expectation into a
// uint16 so tests may write either Return(1024) or Return(uint16(1024)).
// A nil value yields 0; any other unsupported type panics, mirroring how
// testify's typed accessors report a misconfigured expectation.
func mockUint16Arg(v any) uint16 {
	switch n := v.(type) {
	case nil:
		return 0
	case uint16:
		return n
	case int:
		return uint16(n)
	case uint32:
		return uint16(n)
	case int32:
		return uint16(n)
	case float64:
		return uint16(n)
	default:
		panic(fmt.Sprintf("mock: cannot convert %T (%v) to uint16", v, v))
	}
}

// MockSessionRegistry is a testify mock of the session registry that
// Bootstrap holds in its SessionRegistry field.
type MockSessionRegistry struct {
	mock.Mock
}

// Get returns the configured session for key; a nil first return value is
// passed through as a nil registry.Session.
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
	args := m.Called(key)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(registry.Session), args.Error(1)
}

// GetWithUser returns the configured session for (user, key); a nil first
// return value is passed through as a nil registry.Session.
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
	args := m.Called(user, key)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(registry.Session), args.Error(1)
}

func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
	args := m.Called(user, oldKey, newKey)
	return args.Error(0)
}

func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
	args := m.Called(key, session)
	return args.Bool(0)
}

func (m *MockSessionRegistry) Remove(key registry.Key) {
	m.Called(key)
}

// GetAllSessionFromUser returns the configured slice; Return(nil) yields a
// nil slice instead of panicking on the type assertion.
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
	args := m.Called(user)
	if args.Get(0) == nil {
		return nil
	}
	return args.Get(0).([]registry.Session)
}

func (m *MockSessionRegistry) Slug() slug.Slug {
	args := m.Called()
	return args.Get(0).(slug.Slug)
}

// MockSession is a testify mock of registry.Session.
type MockSession struct {
	mock.Mock
}

func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
	args := m.Called()
	return args.Get(0).(lifecycle.Lifecycle)
}

func (m *MockSession) Interaction() interaction.Interaction {
	args := m.Called()
	return args.Get(0).(interaction.Interaction)
}

func (m *MockSession) Forwarder() forwarder.Forwarder {
	args := m.Called()
	return args.Get(0).(forwarder.Forwarder)
}

func (m *MockSession) Slug() slug.Slug {
	args := m.Called()
	return args.Get(0).(slug.Slug)
}

// Detail returns the configured *types.Detail; Return(nil) yields nil.
func (m *MockSession) Detail() *types.Detail {
	args := m.Called()
	if args.Get(0) == nil {
		return nil
	}
	return args.Get(0).(*types.Detail)
}

// MockRandom is a testify mock of the random string generator that
// Bootstrap holds in its Randomizer field.
type MockRandom struct {
	mock.Mock
}

func (m *MockRandom) String(length int) (string, error) {
	args := m.Called(length)
	return args.String(0), args.Error(1)
}

// MockConfig is a testify mock of the configuration that Bootstrap holds in
// its Config field. Every accessor takes no arguments, so expectations are
// registered as On("<MethodName>").Return(value).
type MockConfig struct {
	mock.Mock
}

func (m *MockConfig) Domain() string         { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string        { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string       { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string      { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool       { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool      { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string      { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string     { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool      { return m.Called().Bool(0) }

// AllowedPortsStart accepts an int or a uint16 from the expectation. Using
// args.Int here would panic on the uint16 values the tests below supply.
func (m *MockConfig) AllowedPortsStart() uint16 { return mockUint16Arg(m.Called().Get(0)) }

// AllowedPortsEnd accepts an int or a uint16 from the expectation.
func (m *MockConfig) AllowedPortsEnd() uint16 { return mockUint16Arg(m.Called().Get(0)) }

func (m *MockConfig) BufferSize() int     { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool  { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string   { return m.Called().String(0) }
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string    { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string   { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string      { return m.Called().String(0) }

// Mode accepts either a types.ServerMode or a plain int from the
// expectation; Return(nil) yields the zero mode.
func (m *MockConfig) Mode() types.ServerMode {
	args := m.Called()
	if args.Get(0) == nil {
		return 0
	}
	switch v := args.Get(0).(type) {
	case types.ServerMode:
		return v
	case int:
		return types.ServerMode(v)
	default:
		return types.ServerMode(args.Int(0))
	}
}

// MockPort is a testify mock of port.Port.
type MockPort struct {
	mock.Mock
}

func (m *MockPort) AddRange(startPort, endPort uint16) error {
	return m.Called(startPort, endPort).Error(0)
}

// Unassigned returns the configured port (any numeric type accepted by
// mockUint16Arg; nil means 0) and the configured found flag.
func (m *MockPort) Unassigned() (uint16, bool) {
	args := m.Called()
	return mockUint16Arg(args.Get(0)), args.Bool(1)
}

func (m *MockPort) SetStatus(port uint16, assigned bool) error {
	return m.Called(port, assigned).Error(0)
}

func (m *MockPort) Claim(port uint16) bool {
	return m.Called(port).Bool(0)
}

// MockGRPCClient is a testify mock of the gRPC client that Bootstrap holds in
// its GrpcClient field.
type MockGRPCClient struct {
	mock.Mock
}

// ClientConn returns the configured connection; Return(nil) yields nil.
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
	args := m.Called()
	if args.Get(0) == nil {
		return nil
	}
	return args.Get(0).(*grpc.ClientConn)
}

// AuthorizeConn records (ctx, token) and returns the configured
// (authorized, user, err) triple, e.g.
// On("AuthorizeConn", mock.Anything, "tok").Return(true, "alice", nil).
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
	args := m.Called(ctx, token)
	return args.Bool(0), args.String(1), args.Error(2)
}

func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
	args := m.Called(ctx)
	return args.Error(0)
}

func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
	args := m.Called(ctx, domain, token)
	return args.Error(0)
}

func (m *MockGRPCClient) Close() error {
	args := m.Called()
	return args.Error(0)
}

// TestNew checks that New populates every Bootstrap field on success and
// surfaces the error returned by Port.AddRange.
func TestNew(t *testing.T) {
	tests := []struct {
		name        string
		setupConfig func() config.Config
		setupPort   func() port.Port
		wantErr     bool
		errContains string
	}{
		{
			name:    "Success New with default value",
			wantErr: false,
		},
		{
			name: "Error when AddRange fails",
			setupPort: func() port.Port {
				mockPort := &MockPort{}
				mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range"))
				return mockPort
			},
			wantErr:     true,
			errContains: "invalid port range",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var p port.Port
			if tt.setupPort != nil {
				p = tt.setupPort()
			} else {
				p = port.New()
			}

			var cfg config.Config
			if tt.setupConfig != nil {
				cfg = tt.setupConfig()
			} else {
				var err error
				cfg, err = config.MustLoad()
				// Calling New with an unloaded config would only obscure the real failure.
				require.NoError(t, err)
			}

			b, err := New(cfg, p)
			if tt.wantErr {
				assert.Error(t, err)
				if tt.errContains != "" {
					assert.Contains(t, err.Error(), tt.errContains)
				}
				assert.Nil(t, b)
				return
			}

			require.NoError(t, err)
			require.NotNil(t, b)
			assert.NotNil(t, b.Randomizer)
			assert.NotNil(t, b.SessionRegistry)
			assert.NotNil(t, b.Config)
			assert.NotNil(t, b.Port)
			assert.NotNil(t, b.ErrChan)
			assert.NotNil(t, b.SignalChan)
		})
	}
}

// generateTestCert returns a PEM-encoded self-signed RSA-2048 certificate
// valid for one hour for "localhost" and 127.0.0.1, together with its
// PKCS#1 PEM-encoded private key.
func generateTestCert(t *testing.T) (certPEM, keyPEM []byte) {
	t.Helper()

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(t, err)

	template := x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject: pkix.Name{
			Organization: []string{"Test Co"},
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              []string{"localhost"},
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}

	certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
	require.NoError(t, err)

	certPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certDER,
	})
	keyPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	})
	return certPEM, keyPEM
}

// randomAvailablePort asks the OS for a free TCP port on localhost and
// returns it as a decimal string. The listener is closed before returning,
// so another process could grab the port before the caller binds it.
func randomAvailablePort() (string, error) {
	listener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		return "", err
	}
	defer listener.Close()

	p := listener.Addr().(*net.TCPAddr).Port
	return strconv.Itoa(p), nil
}

// newRunMockConfig builds a MockConfig that answers every accessor with the
// defaults shared by the TestRun cases (standalone mode, ephemeral "0" ports,
// TLS and pprof disabled). overrides maps a method name to the value it
// should return instead and may also add methods with no default
// (e.g. "TLSStoragePath").
func newRunMockConfig(keyLoc string, overrides map[string]any) *MockConfig {
	values := map[string]any{
		"KeyLoc":            keyLoc,
		"Mode":              types.ServerModeSTANDALONE,
		"Domain":            "example.com",
		"SSHPort":           "0",
		"HTTPPort":          "0",
		"HTTPSPort":         "0",
		"TLSEnabled":        false,
		"TLSRedirect":       false,
		"ACMEEmail":         "test@example.com",
		"CFAPIToken":        "fake-token",
		"ACMEStaging":       true,
		"AllowedPortsStart": uint16(1024),
		"AllowedPortsEnd":   uint16(65535),
		"BufferSize":        4096,
		"PprofEnabled":      false,
		"PprofPort":         "0",
		"GRPCAddress":       "localhost",
		"GRPCPort":          "0",
		"NodeToken":         "fake-node-token",
	}
	// Merge before registering: testify matches the first registered
	// expectation, so registering an override after a default would not
	// take effect.
	for method, value := range overrides {
		values[method] = value
	}

	cfg := &MockConfig{}
	for method, value := range values {
		cfg.On(method).Return(value)
	}
	return cfg
}

// waitForRun returns the result Bootstrap.Run sent on done, failing the test
// if nothing arrives within testRunTimeout.
func waitForRun(t *testing.T, done <-chan error) error {
	t.Helper()
	select {
	case err := <-done:
		return err
	case <-time.After(testRunTimeout):
		t.Fatalf("Bootstrap.Run did not return within %v", testRunTimeout)
		return nil
	}
}

// assertPprofServing polls http://localhost:<pprofPort>/debug/pprof/ until it
// answers 200 OK. It records a test failure if that does not happen within
// five seconds, but does not stop the test, so the caller can still shut
// Run down.
func assertPprofServing(t *testing.T, pprofPort string) {
	t.Helper()
	client := &http.Client{Timeout: time.Second}
	endpoint := fmt.Sprintf("http://localhost:%s/debug/pprof/", pprofPort)
	assert.Eventually(t, func() bool {
		resp, err := client.Get(endpoint)
		if err != nil {
			return false
		}
		defer resp.Body.Close()
		return resp.StatusCode == http.StatusOK
	}, 5*time.Second, 200*time.Millisecond, "pprof endpoint %s never returned 200", endpoint)
}

// TestRun drives Bootstrap.Run through startup failures (invalid listener
// ports, failed gRPC health check) and through clean shutdown triggered by
// os.Interrupt on SignalChan.
func TestRun(t *testing.T) {
	mockRandom := &MockRandom{}
	mockSessionRegistry := &MockSessionRegistry{}
	mockPort := &MockPort{}
	tmpDir := t.TempDir()
	keyLoc := filepath.Join(tmpDir, "key.key")

	tests := []struct {
		name            string
		setupConfig     func(t *testing.T) *MockConfig
		setupGrpcClient func() *MockGRPCClient
		expectError     bool
		checkPprof      bool
	}{
		{
			name: "successful run and termination",
			setupConfig: func(*testing.T) *MockConfig {
				return newRunMockConfig(keyLoc, nil)
			},
			expectError: false,
		},
		{
			name: "error from SSH server invalid port",
			setupConfig: func(*testing.T) *MockConfig {
				return newRunMockConfig(keyLoc, map[string]any{"SSHPort": "invalid"})
			},
			expectError: true,
		},
		{
			name: "error from HTTP server invalid port",
			setupConfig: func(*testing.T) *MockConfig {
				return newRunMockConfig(keyLoc, map[string]any{"HTTPPort": "invalid"})
			},
			expectError: true,
		},
		{
			name: "error from HTTPS server invalid port",
			setupConfig: func(*testing.T) *MockConfig {
				return newRunMockConfig(keyLoc, map[string]any{
					"HTTPSPort":      "invalid",
					"TLSEnabled":     true,
					"TLSStoragePath": os.TempDir(),
				})
			},
			expectError: true,
		},
		{
			name: "grpc health check failed",
			setupConfig: func(*testing.T) *MockConfig {
				return newRunMockConfig(keyLoc, map[string]any{
					"Mode":     types.ServerModeNODE,
					"GRPCPort": "invalid",
				})
			},
			setupGrpcClient: func() *MockGRPCClient {
				mockGRPCClient := &MockGRPCClient{}
				mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed"))
				return mockGRPCClient
			},
			expectError: true,
		},
		{
			name: "successful run with pprof enabled",
			setupConfig: func(t *testing.T) *MockConfig {
				pprofPort, err := randomAvailablePort()
				require.NoError(t, err)
				return newRunMockConfig(keyLoc, map[string]any{
					"PprofEnabled": true,
					"PprofPort":    pprofPort,
				})
			},
			expectError: false,
			checkPprof:  true,
		},
		{
			name: "successful run in NODE mode with signal",
			setupConfig: func(*testing.T) *MockConfig {
				return newRunMockConfig(keyLoc, map[string]any{"Mode": types.ServerModeNODE})
			},
			setupGrpcClient: func() *MockGRPCClient {
				mockGRPCClient := &MockGRPCClient{}
				mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
				mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
				mockGRPCClient.On("Close").Return(nil)
				return mockGRPCClient
			},
			expectError: false,
		},
		{
			name: "successful run in NODE mode with signal buf error when closing",
			setupConfig: func(*testing.T) *MockConfig {
				return newRunMockConfig(keyLoc, map[string]any{"Mode": types.ServerModeNODE})
			},
			setupGrpcClient: func() *MockGRPCClient {
				mockGRPCClient := &MockGRPCClient{}
				mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
				mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
				mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy"))
				return mockGRPCClient
			},
			expectError: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mockConfig := tt.setupConfig(t)

			// Fresh channels per subtest, so a stray error or signal from one
			// case cannot be consumed by the next.
			signalChan := make(chan os.Signal, 1)
			b := &Bootstrap{
				Randomizer:      mockRandom,
				Config:          mockConfig,
				SessionRegistry: mockSessionRegistry,
				Port:            mockPort,
				ErrChan:         make(chan error, 1),
				SignalChan:      signalChan,
				GrpcClient:      &MockGRPCClient{},
			}
			if tt.setupGrpcClient != nil {
				b.GrpcClient = tt.setupGrpcClient()
			}

			done := make(chan error, 1)
			go func() { done <- b.Run() }()

			if tt.expectError {
				assert.Error(t, waitForRun(t, done))
				return
			}

			if tt.checkPprof {
				assertPprofServing(t, mockConfig.PprofPort())
			} else {
				// NOTE: fixed delay to let Run start its listeners; no readiness
				// hook on Bootstrap is visible from this file.
				time.Sleep(time.Second)
			}

			signalChan <- os.Interrupt
			assert.NoError(t, waitForRun(t, done))
		})
	}
}