diff --git a/internal/transport/tls.go b/internal/transport/tls.go index 584dec4..4d62e60 100644 --- a/internal/transport/tls.go +++ b/internal/transport/tls.go @@ -17,13 +17,22 @@ import ( "github.com/libdns/cloudflare" ) -type TLSManager interface { - userCertsExistAndValid() bool - loadUserCerts() error - startCertWatcher() - initCertMagic() error - getTLSConfig() *tls.Config - getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) +func NewTLSConfig(config config.Config) (*tls.Config, error) { + var initErr error + + tlsManagerOnce.Do(func() { + tm := createTLSManager(config) + initErr = tm.initialize() + if initErr == nil { + globalTLSManager = tm + } + }) + + if initErr != nil { + return nil, initErr + } + + return globalTLSManager.getTLSConfig(), nil } type tlsManager struct { @@ -41,55 +50,60 @@ type tlsManager struct { useCertMagic bool } -var globalTLSManager TLSManager +var globalTLSManager *tlsManager var tlsManagerOnce sync.Once -func NewTLSConfig(config config.Config) (*tls.Config, error) { - var initErr error +func createTLSManager(cfg config.Config) *tlsManager { + storagePath := cfg.TLSStoragePath() + cleanBase := filepath.Clean(storagePath) - tlsManagerOnce.Do(func() { - storagePath := config.TLSStoragePath() - cleanBase := filepath.Clean(storagePath) + return &tlsManager{ + config: cfg, + certPath: filepath.Join(cleanBase, "cert.pem"), + keyPath: filepath.Join(cleanBase, "privkey.pem"), + storagePath: filepath.Join(cleanBase, "certmagic"), + } +} - certPath := filepath.Join(cleanBase, "cert.pem") - keyPath := filepath.Join(cleanBase, "privkey.pem") - storagePathCertMagic := filepath.Join(cleanBase, "certmagic") +func (tm *tlsManager) initialize() error { + if tm.userCertsExistAndValid() { + return tm.initializeWithUserCerts() + } + return tm.initializeWithCertMagic() +} - tm := &tlsManager{ - config: config, - certPath: certPath, - keyPath: keyPath, - storagePath: storagePathCertMagic, - } +func (tm *tlsManager) initializeWithUserCerts() error { + log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath) - if tm.userCertsExistAndValid() { - log.Printf("Using user-provided certificates from %s and %s", certPath, keyPath) - if err := tm.loadUserCerts(); err != nil { - initErr = fmt.Errorf("failed to load user certificates: %w", err) - return - } - tm.useCertMagic = false - tm.startCertWatcher() - } else { - log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain()) - if err := tm.initCertMagic(); err != nil { - initErr = fmt.Errorf("failed to initialize CertMagic: %w", err) - return - } - tm.useCertMagic = true - } - - globalTLSManager = tm - }) - - if initErr != nil { - return nil, initErr + if err := tm.loadUserCerts(); err != nil { + return fmt.Errorf("failed to load user certificates: %w", err) } - return globalTLSManager.getTLSConfig(), nil + tm.useCertMagic = false + tm.startCertWatcher() + return nil +} + +func (tm *tlsManager) initializeWithCertMagic() error { + log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", + tm.config.Domain(), tm.config.Domain()) + + if err := tm.initCertMagic(); err != nil { + return fmt.Errorf("failed to initialize CertMagic: %w", err) + } + + tm.useCertMagic = true + return nil } func (tm *tlsManager) userCertsExistAndValid() bool { + if !tm.certFilesExist() { + return false + } + return validateCertDomains(tm.certPath, tm.config.Domain()) +} + +func (tm *tlsManager) certFilesExist() bool { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { log.Printf("Certificate file not found: %s", tm.certPath) return false @@ -98,66 +112,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool { log.Printf("Key file not found: %s", tm.keyPath) return false } - - return ValidateCertDomains(tm.certPath, tm.config.Domain()) -} - -func ValidateCertDomains(certPath, domain string) bool { - certPEM, err := os.ReadFile(certPath) - if err != nil { - log.Printf("Failed to read certificate: %v", err) - return false - } - - block, _ := pem.Decode(certPEM) - if block == nil { - log.Printf("Failed to decode PEM block from certificate") - return false - } - - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - log.Printf("Failed to parse certificate: %v", err) - return false - } - - if time.Now().After(cert.NotAfter) { - log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter) - return false - } - - if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) { - log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter) - return false - } - - var certDomains []string - if cert.Subject.CommonName != "" { - certDomains = append(certDomains, cert.Subject.CommonName) - } - certDomains = append(certDomains, cert.DNSNames...) - - hasBase := false - hasWildcard := false - wildcardDomain := "*." + domain - - for _, d := range certDomains { - if d == domain { - hasBase = true - } - if d == wildcardDomain { - hasWildcard = true - } - } - - if !hasBase { - log.Printf("Certificate does not cover base domain: %s", domain) - } - if !hasWildcard { - log.Printf("Certificate does not cover wildcard domain: %s", wildcardDomain) - } - - return hasBase && hasWildcard + return true } func (tm *tlsManager) loadUserCerts() error { @@ -176,62 +131,34 @@ func (tm *tlsManager) loadUserCerts() error { func (tm *tlsManager) startCertWatcher() { go func() { - var lastCertMod, lastKeyMod time.Time - - if info, err := os.Stat(tm.certPath); err == nil { - lastCertMod = info.ModTime() - } - if info, err := os.Stat(tm.keyPath); err == nil { - lastKeyMod = info.ModTime() - } - - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for range ticker.C { - certInfo, certErr := os.Stat(tm.certPath) - keyInfo, keyErr := os.Stat(tm.keyPath) - - if certErr != nil || keyErr != nil { - continue - } - - if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) { - log.Printf("Certificate files changed, reloading...") - - if !ValidateCertDomains(tm.certPath, tm.config.Domain()) { - log.Printf("New certificates don't cover required domains") - - if err := tm.initCertMagic(); err != nil { - log.Printf("Failed to initialize CertMagic: %v", err) - continue - } - tm.useCertMagic = true - return - } - - if err := tm.loadUserCerts(); err != nil { - log.Printf("Failed to reload certificates: %v", err) - continue - } - - lastCertMod = certInfo.ModTime() - lastKeyMod = keyInfo.ModTime() - log.Printf("Certificates reloaded successfully") - } - } + watcher := newCertWatcher(tm) + watcher.watch() }() } func (tm *tlsManager) initCertMagic() error { - if err := os.MkdirAll(tm.storagePath, 0700); err != nil { - return fmt.Errorf("failed to create cert storage directory: %w", err) + if err := tm.createStorageDirectory(); err != nil { + return err } if tm.config.CFAPIToken() == "" { return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation") } + magic := tm.createCertMagicConfig() + tm.magic = magic + + return tm.obtainCertificates(magic) +} + +func (tm *tlsManager) createStorageDirectory() error { + if err := os.MkdirAll(tm.storagePath, 0700); err != nil { + return fmt.Errorf("failed to create cert storage directory: %w", err) + } + return nil +} + +func (tm *tlsManager) createCertMagicConfig() *certmagic.Config { cfProvider := &cloudflare.Provider{ APIToken: tm.config.CFAPIToken(), } @@ -248,6 +175,13 @@ func (tm *tlsManager) initCertMagic() error { Storage: storage, }) + acmeIssuer := tm.createACMEIssuer(magic, cfProvider) + magic.Issuers = []certmagic.Issuer{acmeIssuer} + + return magic +} + +func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer { acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{ Email: tm.config.ACMEEmail(), Agreed: true, @@ -266,9 +200,10 @@ func (tm *tlsManager) initCertMagic() error { log.Printf("Using Let's Encrypt production server") } - magic.Issuers = []certmagic.Issuer{acmeIssuer} - tm.magic = magic + return acmeIssuer +} +func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error { domains := []string{tm.config.Domain(), "*." + tm.config.Domain()} log.Printf("Requesting certificates for: %v", domains) @@ -311,3 +246,190 @@ func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certifica return tm.userCert, nil } + +func validateCertDomains(certPath, domain string) bool { + cert, err := loadAndParseCertificate(certPath) + if err != nil { + return false + } + + if !isCertificateValid(cert) { + return false + } + + return certCoversRequiredDomains(cert, domain) +} + +func loadAndParseCertificate(certPath string) (*x509.Certificate, error) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + log.Printf("Failed to read certificate: %v", err) + return nil, err + } + + block, _ := pem.Decode(certPEM) + if block == nil { + log.Printf("Failed to decode PEM block from certificate") + return nil, fmt.Errorf("failed to decode PEM block") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + log.Printf("Failed to parse certificate: %v", err) + return nil, err + } + + return cert, nil +} + +func isCertificateValid(cert *x509.Certificate) bool { + now := time.Now() + + if now.After(cert.NotAfter) { + log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter) + return false + } + + thirtyDaysFromNow := now.Add(30 * 24 * time.Hour) + if thirtyDaysFromNow.After(cert.NotAfter) { + log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter) + return false + } + + return true +} + +func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool { + certDomains := extractCertDomains(cert) + hasBase, hasWildcard := checkDomainCoverage(certDomains, domain) + + logDomainCoverage(hasBase, hasWildcard, domain) + return hasBase && hasWildcard +} + +func extractCertDomains(cert *x509.Certificate) []string { + var domains []string + if cert.Subject.CommonName != "" { + domains = append(domains, cert.Subject.CommonName) + } + domains = append(domains, cert.DNSNames...) + return domains +} + +func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) { + wildcardDomain := "*." + domain + + for _, d := range certDomains { + if d == domain { + hasBase = true + } + if d == wildcardDomain { + hasWildcard = true + } + } + + return hasBase, hasWildcard +} + +func logDomainCoverage(hasBase, hasWildcard bool, domain string) { + if !hasBase { + log.Printf("Certificate does not cover base domain: %s", domain) + } + if !hasWildcard { + log.Printf("Certificate does not cover wildcard domain: *.%s", domain) + } +} + +type certWatcher struct { + tm *tlsManager + lastCertMod time.Time + lastKeyMod time.Time +} + +func newCertWatcher(tm *tlsManager) *certWatcher { + watcher := &certWatcher{tm: tm} + watcher.initializeModTimes() + return watcher +} + +func (cw *certWatcher) initializeModTimes() { + if info, err := os.Stat(cw.tm.certPath); err == nil { + cw.lastCertMod = info.ModTime() + } + if info, err := os.Stat(cw.tm.keyPath); err == nil { + cw.lastKeyMod = info.ModTime() + } +} + +func (cw *certWatcher) watch() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if cw.checkAndReloadCerts() { + return + } + } +} + +func (cw *certWatcher) checkAndReloadCerts() bool { + certInfo, keyInfo, err := cw.getFileInfo() + if err != nil { + return false + } + + if !cw.filesModified(certInfo, keyInfo) { + return false + } + + return cw.handleCertificateChange(certInfo, keyInfo) +} + +func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) { + certInfo, certErr := os.Stat(cw.tm.certPath) + keyInfo, keyErr := os.Stat(cw.tm.keyPath) + + if certErr != nil || keyErr != nil { + return nil, nil, fmt.Errorf("file stat error") + } + + return certInfo, keyInfo, nil +} + +func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool { + return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod) +} + +func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool { + log.Printf("Certificate files changed, reloading...") + + if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) { + return cw.switchToCertMagic() + } + + if err := cw.tm.loadUserCerts(); err != nil { + log.Printf("Failed to reload certificates: %v", err) + return false + } + + cw.updateModTimes(certInfo, keyInfo) + log.Printf("Certificates reloaded successfully") + return false +} + +func (cw *certWatcher) switchToCertMagic() bool { + log.Printf("New certificates don't cover required domains") + + if err := cw.tm.initCertMagic(); err != nil { + log.Printf("Failed to initialize CertMagic: %v", err) + return false + } + + cw.tm.useCertMagic = true + return true +} + +func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) { + cw.lastCertMod = certInfo.ModTime() + cw.lastKeyMod = keyInfo.ModTime() +} diff --git a/internal/transport/tls_test.go b/internal/transport/tls_test.go index 17e7214..1518469 100644 --- a/internal/transport/tls_test.go +++ b/internal/transport/tls_test.go @@ -9,8 +9,11 @@ import ( "encoding/pem" "math/big" "os" + "path/filepath" + "sync" "testing" "time" + "tunnel_pls/internal/config" "tunnel_pls/types" "github.com/stretchr/testify/assert" @@ -39,58 +42,14 @@ func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(type func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } func (m *MockConfig) NodeToken() string { return m.Called().String(0) } - -func TestValidateCertDomains_NotFound(t *testing.T) { - result := ValidateCertDomains("nonexistent.pem", "example.com") - assert.False(t, result) -} - -func TestValidateCertDomains_InvalidPEM(t *testing.T) { - tmpFile, err := os.CreateTemp("", "invalid*.pem") - assert.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - _, _ = tmpFile.WriteString("not a pem") - tmpFile.Close() - - result := ValidateCertDomains(tmpFile.Name(), "example.com") - assert.False(t, result) -} - -func TestTLSManager_getTLSConfig(t *testing.T) { - tm := &tlsManager{ - useCertMagic: false, - } - cfg := tm.getTLSConfig() - assert.NotNil(t, cfg) - assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion) -} - -func TestTLSManager_getCertificate_Magic(t *testing.T) { - tm := &tlsManager{ - useCertMagic: true, - } - hello := &tls.ClientHelloInfo{} - assert.Panics(t, func() { - _, _ = tm.getCertificate(hello) - }) -} - -func TestTLSManager_userCertsExistAndValid(t *testing.T) { - tm := &tlsManager{ - certPath: "nonexistent.pem", - keyPath: "nonexistent.key", - } - assert.False(t, tm.userCertsExistAndValid()) - - keyFile, _ := os.CreateTemp("", "key*.pem") - defer os.Remove(keyFile.Name()) - tm.keyPath = keyFile.Name() - assert.False(t, tm.userCertsExistAndValid()) -} +func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) } +func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, soon bool) (string, string) { - priv, _ := rsa.GenerateKey(rand.Reader, 2048) + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) notAfter := time.Now().Add(365 * 24 * time.Hour) if expired { @@ -116,54 +75,400 @@ func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, so derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) assert.NoError(t, err) - certOut, _ := os.CreateTemp("", "cert*.pem") - _ = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + certOut, err := os.CreateTemp("", "cert*.pem") + assert.NoError(t, err) + err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + assert.NoError(t, err) certOut.Close() - keyOut, _ := os.CreateTemp("", "key*.pem") - _ = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + keyOut, err := os.CreateTemp("", "key*.pem") + assert.NoError(t, err) + err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + assert.NoError(t, err) keyOut.Close() return certOut.Name(), keyOut.Name() } -func TestValidateCertDomains_Success(t *testing.T) { +func setupTestDir(t *testing.T) string { + t.Helper() + + tmpDir, err := os.MkdirTemp("", "tls-test-*") + assert.NoError(t, err) + + t.Cleanup(func() { + os.RemoveAll(tmpDir) + }) + + return tmpDir +} + +func TestValidateCertDomains(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) (certPath string, cleanup func()) + domain string + expected bool + }{ + { + name: "file not found", + setup: func(t *testing.T) (string, func()) { + return "nonexistent.pem", func() {} + }, + domain: "example.com", + expected: false, + }, + { + name: "invalid PEM", + setup: func(t *testing.T) (string, func()) { + tmpFile, err := os.CreateTemp("", "invalid*.pem") + assert.NoError(t, err) + _, err = tmpFile.WriteString("not a pem") + assert.NoError(t, err) + tmpFile.Close() + return tmpFile.Name(), func() { os.Remove(tmpFile.Name()) } + }, + domain: "example.com", + expected: false, + }, + { + name: "valid cert with wildcard", + setup: func(t *testing.T) (string, func()) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + return certPath, func() { + os.Remove(certPath) + os.Remove(keyPath) + } + }, + domain: "example.com", + expected: true, + }, + { + name: "expired cert", + setup: func(t *testing.T) (string, func()) { + certPath, keyPath := createTestCert(t, "example.com", true, true, false) + return certPath, func() { + os.Remove(certPath) + os.Remove(keyPath) + } + }, + domain: "example.com", + expected: false, + }, + { + name: "cert expiring soon", + setup: func(t *testing.T) (string, func()) { + certPath, keyPath := createTestCert(t, "example.com", true, false, true) + return certPath, func() { + os.Remove(certPath) + os.Remove(keyPath) + } + }, + domain: "example.com", + expected: false, + }, + { + name: "missing wildcard", + setup: func(t *testing.T) (string, func()) { + certPath, keyPath := createTestCert(t, "example.com", false, false, false) + return certPath, func() { + os.Remove(certPath) + os.Remove(keyPath) + } + }, + domain: "example.com", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + certPath, cleanup := tt.setup(t) + defer cleanup() + + result := validateCertDomains(certPath, tt.domain) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestLoadAndParseCertificate(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) (certPath string, cleanup func()) + wantError bool + validate func(t *testing.T, cert *x509.Certificate) + }{ + { + name: "success", + setup: func(t *testing.T) (string, func()) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + return certPath, func() { + os.Remove(certPath) + os.Remove(keyPath) + } + }, + wantError: false, + validate: func(t *testing.T, cert *x509.Certificate) { + assert.Equal(t, "example.com", cert.Subject.CommonName) + }, + }, + { + name: "file not found", + setup: func(t *testing.T) (string, func()) { + return "nonexistent.pem", func() {} + }, + wantError: true, + validate: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + certPath, cleanup := tt.setup(t) + defer cleanup() + + cert, err := loadAndParseCertificate(certPath) + + if tt.wantError { + assert.Error(t, err) + assert.Nil(t, cert) + } else { + assert.NoError(t, err) + assert.NotNil(t, cert) + if tt.validate != nil { + tt.validate(t, cert) + } + } + }) + } +} + +func TestIsCertificateValid(t *testing.T) { + tests := []struct { + name string + expired bool + soon bool + expected bool + }{ + { + name: "valid certificate", + expired: false, + soon: false, + expected: true, + }, + { + name: "expired certificate", + expired: true, + soon: false, + expected: false, + }, + { + name: "expiring soon", + expired: false, + soon: true, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, tt.expired, tt.soon) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + cert, err := loadAndParseCertificate(certPath) + assert.NoError(t, err) + + result := isCertificateValid(cert) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestExtractCertDomains(t *testing.T) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) defer os.Remove(certPath) defer os.Remove(keyPath) - result := ValidateCertDomains(certPath, "example.com") - assert.True(t, result) + cert, err := loadAndParseCertificate(certPath) + assert.NoError(t, err) + + domains := extractCertDomains(cert) + assert.Contains(t, domains, "example.com") + assert.Contains(t, domains, "*.example.com") } -func TestValidateCertDomains_Expired(t *testing.T) { - certPath, keyPath := createTestCert(t, "example.com", true, true, false) - defer os.Remove(certPath) - defer os.Remove(keyPath) +func TestCheckDomainCoverage(t *testing.T) { + tests := []struct { + name string + certDomains []string + domain string + wantBase bool + wantWildcard bool + }{ + { + name: "both covered", + certDomains: []string{"example.com", "*.example.com"}, + domain: "example.com", + wantBase: true, + wantWildcard: true, + }, + { + name: "only base", + certDomains: []string{"example.com"}, + domain: "example.com", + wantBase: true, + wantWildcard: false, + }, + { + name: "only wildcard", + certDomains: []string{"*.example.com"}, + domain: "example.com", + wantBase: false, + wantWildcard: true, + }, + { + name: "neither", + certDomains: []string{"other.com"}, + domain: "example.com", + wantBase: false, + wantWildcard: false, + }, + } - result := ValidateCertDomains(certPath, "example.com") - assert.False(t, result) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hasBase, hasWildcard := checkDomainCoverage(tt.certDomains, tt.domain) + assert.Equal(t, tt.wantBase, hasBase) + assert.Equal(t, tt.wantWildcard, hasWildcard) + }) + } } -func TestValidateCertDomains_ExpiringSoon(t *testing.T) { - certPath, keyPath := createTestCert(t, "example.com", true, false, true) - defer os.Remove(certPath) - defer os.Remove(keyPath) - - result := ValidateCertDomains(certPath, "example.com") - assert.False(t, result) +func TestTLSManager_getTLSConfig(t *testing.T) { + tm := &tlsManager{ + useCertMagic: false, + } + cfg := tm.getTLSConfig() + assert.NotNil(t, cfg) + assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion) + assert.Equal(t, uint16(tls.VersionTLS13), cfg.MaxVersion) + assert.NotNil(t, cfg.GetCertificate) } -func TestValidateCertDomains_MissingWildcard(t *testing.T) { - certPath, keyPath := createTestCert(t, "example.com", false, false, false) - defer os.Remove(certPath) - defer os.Remove(keyPath) +func TestTLSManager_getCertificate(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) *tlsManager + wantError bool + errorContains string + }{ + { + name: "no certificate available", + setup: func(t *testing.T) *tlsManager { + return &tlsManager{ + useCertMagic: false, + userCert: nil, + } + }, + wantError: true, + errorContains: "no certificate available", + }, + { + name: "with user certificate", + setup: func(t *testing.T) *tlsManager { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) - result := ValidateCertDomains(certPath, "example.com") - assert.False(t, result) + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + assert.NoError(t, err) + + return &tlsManager{ + useCertMagic: false, + userCert: &cert, + } + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.setup(t) + hello := &tls.ClientHelloInfo{ + ServerName: "example.com", + } + + cert, err := tm.getCertificate(hello) + + if tt.wantError { + assert.Error(t, err) + assert.Nil(t, cert) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, cert) + } + }) + } } -func TestTLSManager_loadUserCerts_Success(t *testing.T) { +func TestTLSManager_userCertsExistAndValid(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) *tlsManager + expected bool + }{ + { + name: "no files", + setup: func(t *testing.T) *tlsManager { + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + return &tlsManager{ + config: mockCfg, + certPath: "nonexistent.pem", + keyPath: "nonexistent.key", + } + }, + expected: false, + }, + { + name: "missing key file", + setup: func(t *testing.T) *tlsManager { + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { os.Remove(certPath) }) + os.Remove(keyPath) + + return &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.setup(t) + result := tm.userCertsExistAndValid() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTLSManager_certFilesExist(t *testing.T) { certPath, keyPath := createTestCert(t, "example.com", true, false, false) defer os.Remove(certPath) defer os.Remove(keyPath) @@ -172,7 +477,706 @@ func TestTLSManager_loadUserCerts_Success(t *testing.T) { certPath: certPath, keyPath: keyPath, } + + result := tm.certFilesExist() + assert.True(t, result) +} + +func TestTLSManager_loadUserCerts(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) *tlsManager + wantError bool + }{ + { + name: "success", + setup: func(t *testing.T) *tlsManager { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) + + return &tlsManager{ + certPath: certPath, + keyPath: keyPath, + } + }, + wantError: false, + }, + { + name: "invalid path", + setup: func(t *testing.T) *tlsManager { + return &tlsManager{ + certPath: "nonexistent.pem", + keyPath: "nonexistent.key", + } + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.setup(t) + err := tm.loadUserCerts() + + if tt.wantError { + assert.Error(t, err) + assert.Nil(t, tm.userCert) + } else { + assert.NoError(t, err) + assert.NotNil(t, tm.userCert) + } + }) + } +} + +func TestCreateTLSManager(t *testing.T) { + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("TLSStoragePath").Return(tmpDir) + + tm := createTLSManager(mockCfg) + + assert.NotNil(t, tm) + assert.Equal(t, mockCfg, tm.config) + assert.Equal(t, filepath.Join(tmpDir, "cert.pem"), tm.certPath) + assert.Equal(t, filepath.Join(tmpDir, "privkey.pem"), tm.keyPath) + assert.Equal(t, filepath.Join(tmpDir, "certmagic"), tm.storagePath) +} + +func TestNewCertWatcher(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + mockCfg := &MockConfig{} + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + watcher := newCertWatcher(tm) + + assert.NotNil(t, watcher) + assert.Equal(t, tm, watcher.tm) + assert.False(t, watcher.lastCertMod.IsZero()) + assert.False(t, watcher.lastKeyMod.IsZero()) +} + +func TestCertWatcher_filesModified(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + mockCfg := &MockConfig{} + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + watcher := newCertWatcher(tm) + + certInfo, err := os.Stat(certPath) + assert.NoError(t, err) + keyInfo, err := os.Stat(keyPath) + assert.NoError(t, err) + + result := watcher.filesModified(certInfo, keyInfo) + assert.False(t, result) + + watcher.lastCertMod = time.Now().Add(-1 * time.Hour) + + result = watcher.filesModified(certInfo, keyInfo) + assert.True(t, result) +} + +func TestCertWatcher_updateModTimes(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + mockCfg := &MockConfig{} + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + watcher := newCertWatcher(tm) + + certInfo, err := os.Stat(certPath) + assert.NoError(t, err) + keyInfo, err := os.Stat(keyPath) + assert.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + watcher.updateModTimes(certInfo, keyInfo) + + assert.Equal(t, certInfo.ModTime(), watcher.lastCertMod) + assert.Equal(t, keyInfo.ModTime(), watcher.lastKeyMod) +} + +func TestCertWatcher_getFileInfo(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) *tlsManager + wantError bool + validate func(t *testing.T, certInfo, keyInfo os.FileInfo) + }{ + { + name: "success", + setup: func(t *testing.T) *tlsManager { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) + + return &tlsManager{ + config: &MockConfig{}, + certPath: certPath, + keyPath: keyPath, + } + }, + wantError: false, + validate: func(t *testing.T, certInfo, keyInfo os.FileInfo) { + assert.NotNil(t, certInfo) + assert.NotNil(t, keyInfo) + }, + }, + { + name: "missing cert file", + setup: func(t *testing.T) *tlsManager { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + os.Remove(certPath) + t.Cleanup(func() { os.Remove(keyPath) }) + + return &tlsManager{ + config: &MockConfig{}, + certPath: certPath, + keyPath: keyPath, + } + }, + wantError: true, + }, + { + name: "missing key file", + setup: func(t *testing.T) *tlsManager { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + os.Remove(keyPath) + t.Cleanup(func() { os.Remove(certPath) }) + + return &tlsManager{ + config: &MockConfig{}, + certPath: certPath, + keyPath: keyPath, + } + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.setup(t) + watcher := newCertWatcher(tm) + + certInfo, keyInfo, err := watcher.getFileInfo() + + if tt.wantError { + assert.Error(t, err) + assert.Nil(t, certInfo) + assert.Nil(t, keyInfo) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, certInfo, keyInfo) + } + } + }) + } +} + +func TestCertWatcher_checkAndReloadCerts(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) (*tlsManager, *certWatcher) + expected bool + }{ + { + name: "file error", + setup: func(t *testing.T) (*tlsManager, *certWatcher) { + tm := &tlsManager{ + config: &MockConfig{}, + certPath: "nonexistent.pem", + keyPath: "nonexistent.key", + } + return tm, newCertWatcher(tm) + }, + expected: false, + }, + { + name: "no modification", + setup: func(t *testing.T) (*tlsManager, *certWatcher) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) + + tm := &tlsManager{ + config: &MockConfig{}, + certPath: certPath, + keyPath: keyPath, + } + return tm, newCertWatcher(tm) + }, + expected: false, + }, + { + name: "with modification", + setup: func(t *testing.T) (*tlsManager, *certWatcher) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + err := tm.loadUserCerts() + assert.NoError(t, err) + + watcher := newCertWatcher(tm) + watcher.lastCertMod = time.Now().Add(-1 * time.Hour) + + return tm, watcher + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, watcher := tt.setup(t) + result := watcher.checkAndReloadCerts() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCertWatcher_handleCertificateChange(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) + expected bool + }{ + { + name: "successful reload", + setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + watcher := newCertWatcher(tm) + + certInfo, _ := os.Stat(certPath) + keyInfo, _ := os.Stat(keyPath) + + return tm, watcher, certInfo, keyInfo + }, + expected: false, + }, + { + name: "invalid cert triggers certmagic", + setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) { + certPath, keyPath := createTestCert(t, "example.com", false, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) + + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + mockCfg.On("CFAPIToken").Return("") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + storagePath: tmpDir, + } + + watcher := newCertWatcher(tm) + + certInfo, _ := os.Stat(certPath) + keyInfo, _ := os.Stat(keyPath) + + return tm, watcher, certInfo, keyInfo + }, + expected: false, + }, + { + name: "load error", + setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: "nonexistent.key", + } + + watcher := newCertWatcher(tm) + + certInfo, _ := os.Stat(certPath) + keyInfo, _ := os.Stat(keyPath) + + return tm, watcher, certInfo, keyInfo + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, watcher, certInfo, keyInfo := tt.setup(t) + result := watcher.handleCertificateChange(certInfo, keyInfo) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCertWatcher_switchToCertMagic(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) *tlsManager + expected bool + }{ + { + name: "with staging token", + setup: func(t *testing.T) *tlsManager { + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + mockCfg.On("CFAPIToken").Return("test-token") + mockCfg.On("ACMEEmail").Return("test@example.com") + mockCfg.On("ACMEStaging").Return(true) + + return &tlsManager{ + config: mockCfg, + storagePath: tmpDir, + } + }, + expected: false, + }, + { + name: "missing token", + setup: func(t *testing.T) *tlsManager { + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + mockCfg.On("CFAPIToken").Return("") + + return &tlsManager{ + config: mockCfg, + storagePath: tmpDir, + } + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.setup(t) + watcher := newCertWatcher(tm) + result := watcher.switchToCertMagic() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCertWatcher_watch(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) (*tlsManager, *certWatcher) + expected bool + }{ + { + name: "exits on certmagic switch attempt", + setup: func(t *testing.T) (*tlsManager, *certWatcher) { + certPath, keyPath := createTestCert(t, "example.com", false, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) + + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + mockCfg.On("CFAPIToken").Return("") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + storagePath: tmpDir, + } + + watcher := newCertWatcher(tm) + watcher.lastCertMod = time.Now().Add(-1 * time.Hour) + watcher.lastKeyMod = time.Now().Add(-1 * time.Hour) + + return tm, watcher + }, + expected: false, + }, + { + name: "continues on no modification", + setup: func(t *testing.T) (*tlsManager, *certWatcher) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + return tm, newCertWatcher(tm) + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, watcher := tt.setup(t) + result := watcher.checkAndReloadCerts() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCertWatcher_watch_Integration(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + err := tm.loadUserCerts() assert.NoError(t, err) + initialCert := tm.userCert + + watcher := newCertWatcher(tm) + + go watcher.watch() + + time.Sleep(50 * time.Millisecond) + + newCertPath, newKeyPath := createTestCert(t, "example.com", true, false, false) + defer os.Remove(newCertPath) + defer os.Remove(newKeyPath) + + newCertData, err := os.ReadFile(newCertPath) + assert.NoError(t, err) + newKeyData, err := os.ReadFile(newKeyPath) + assert.NoError(t, err) + + err = os.WriteFile(certPath, newCertData, 0644) + assert.NoError(t, err) + err = os.WriteFile(keyPath, newKeyData, 0644) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + assert.NotNil(t, tm.userCert) + assert.Equal(t, initialCert, tm.userCert) +} + +func TestNewTLSConfig(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) config.Config + wantError bool + errorMsg string + validate func(t *testing.T, cfg *tls.Config) + }{ + { + name: "with valid user certs", + setup: func(t *testing.T) config.Config { + globalTLSManager = nil + tlsManagerOnce = sync.Once{} + + tmpDir := setupTestDir(t) + + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + os.Remove(certPath) + os.Remove(keyPath) + }) + + certData, err := os.ReadFile(certPath) + assert.NoError(t, err) + keyData, err := os.ReadFile(keyPath) + assert.NoError(t, err) + + err = os.WriteFile(filepath.Join(tmpDir, "cert.pem"), certData, 0644) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, "privkey.pem"), keyData, 0644) + assert.NoError(t, err) + + mockCfg := &MockConfig{} + mockCfg.On("TLSStoragePath").Return(tmpDir) + mockCfg.On("Domain").Return("example.com") + + return mockCfg + }, + wantError: false, + validate: func(t *testing.T, cfg *tls.Config) { + assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion) + assert.NotNil(t, cfg.GetCertificate) + }, + }, + { + name: "missing certs requires certmagic", + setup: func(t *testing.T) config.Config { + globalTLSManager = nil + tlsManagerOnce = sync.Once{} + + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("TLSStoragePath").Return(tmpDir) + mockCfg.On("Domain").Return("example.com") + mockCfg.On("CFAPIToken").Return("") + + return mockCfg + }, + wantError: true, + errorMsg: "CF_API_TOKEN", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.setup(t) + tlsConfig, err := NewTLSConfig(cfg) + + if tt.wantError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, tlsConfig) + if tt.validate != nil { + tt.validate(t, tlsConfig) + } + } + }) + } +} + +func TestNewTLSConfig_Singleton(t *testing.T) { + globalTLSManager = nil + tlsManagerOnce = sync.Once{} + + tmpDir := setupTestDir(t) + + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + certData, err := os.ReadFile(certPath) + assert.NoError(t, err) + keyData, err := os.ReadFile(keyPath) + assert.NoError(t, err) + + err = os.WriteFile(filepath.Join(tmpDir, "cert.pem"), certData, 0644) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, "privkey.pem"), keyData, 0644) + assert.NoError(t, err) + + mockCfg := &MockConfig{} + mockCfg.On("TLSStoragePath").Return(tmpDir) + mockCfg.On("Domain").Return("example.com") + + tlsConfig1, err1 := NewTLSConfig(mockCfg) + tlsConfig2, err2 := NewTLSConfig(mockCfg) + + assert.NoError(t, err1) + assert.NoError(t, err2) + assert.NotNil(t, tlsConfig1) + assert.NotNil(t, tlsConfig2) + + assert.Equal(t, tlsConfig1.MinVersion, tlsConfig2.MinVersion) + assert.Equal(t, tlsConfig1.MaxVersion, tlsConfig2.MaxVersion) + assert.Equal(t, tlsConfig1.SessionTicketsDisabled, tlsConfig2.SessionTicketsDisabled) + assert.Equal(t, tlsConfig1.ClientAuth, tlsConfig2.ClientAuth) + + hello := &tls.ClientHelloInfo{ServerName: "example.com"} + cert1, err1 := tlsConfig1.GetCertificate(hello) + cert2, err2 := tlsConfig2.GetCertificate(hello) + + assert.NoError(t, err1) + assert.NoError(t, err2) + assert.NotNil(t, cert1) + assert.NotNil(t, cert2) + + assert.Equal(t, cert1, cert2) }