refactor(transport): reduce cognitive complexity and clean up public API
This commit is contained in:
+275
-153
@@ -17,13 +17,22 @@ import (
|
||||
"github.com/libdns/cloudflare"
|
||||
)
|
||||
|
||||
type TLSManager interface {
|
||||
userCertsExistAndValid() bool
|
||||
loadUserCerts() error
|
||||
startCertWatcher()
|
||||
initCertMagic() error
|
||||
getTLSConfig() *tls.Config
|
||||
getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
func NewTLSConfig(config config.Config) (*tls.Config, error) {
|
||||
var initErr error
|
||||
|
||||
tlsManagerOnce.Do(func() {
|
||||
tm := createTLSManager(config)
|
||||
initErr = tm.initialize()
|
||||
if initErr == nil {
|
||||
globalTLSManager = tm
|
||||
}
|
||||
})
|
||||
|
||||
if initErr != nil {
|
||||
return nil, initErr
|
||||
}
|
||||
|
||||
return globalTLSManager.getTLSConfig(), nil
|
||||
}
|
||||
|
||||
type tlsManager struct {
|
||||
@@ -41,55 +50,60 @@ type tlsManager struct {
|
||||
useCertMagic bool
|
||||
}
|
||||
|
||||
var globalTLSManager TLSManager
|
||||
var globalTLSManager *tlsManager
|
||||
var tlsManagerOnce sync.Once
|
||||
|
||||
func NewTLSConfig(config config.Config) (*tls.Config, error) {
|
||||
var initErr error
|
||||
func createTLSManager(cfg config.Config) *tlsManager {
|
||||
storagePath := cfg.TLSStoragePath()
|
||||
cleanBase := filepath.Clean(storagePath)
|
||||
|
||||
tlsManagerOnce.Do(func() {
|
||||
storagePath := config.TLSStoragePath()
|
||||
cleanBase := filepath.Clean(storagePath)
|
||||
return &tlsManager{
|
||||
config: cfg,
|
||||
certPath: filepath.Join(cleanBase, "cert.pem"),
|
||||
keyPath: filepath.Join(cleanBase, "privkey.pem"),
|
||||
storagePath: filepath.Join(cleanBase, "certmagic"),
|
||||
}
|
||||
}
|
||||
|
||||
certPath := filepath.Join(cleanBase, "cert.pem")
|
||||
keyPath := filepath.Join(cleanBase, "privkey.pem")
|
||||
storagePathCertMagic := filepath.Join(cleanBase, "certmagic")
|
||||
func (tm *tlsManager) initialize() error {
|
||||
if tm.userCertsExistAndValid() {
|
||||
return tm.initializeWithUserCerts()
|
||||
}
|
||||
return tm.initializeWithCertMagic()
|
||||
}
|
||||
|
||||
tm := &tlsManager{
|
||||
config: config,
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
storagePath: storagePathCertMagic,
|
||||
}
|
||||
func (tm *tlsManager) initializeWithUserCerts() error {
|
||||
log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath)
|
||||
|
||||
if tm.userCertsExistAndValid() {
|
||||
log.Printf("Using user-provided certificates from %s and %s", certPath, keyPath)
|
||||
if err := tm.loadUserCerts(); err != nil {
|
||||
initErr = fmt.Errorf("failed to load user certificates: %w", err)
|
||||
return
|
||||
}
|
||||
tm.useCertMagic = false
|
||||
tm.startCertWatcher()
|
||||
} else {
|
||||
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain())
|
||||
if err := tm.initCertMagic(); err != nil {
|
||||
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
|
||||
return
|
||||
}
|
||||
tm.useCertMagic = true
|
||||
}
|
||||
|
||||
globalTLSManager = tm
|
||||
})
|
||||
|
||||
if initErr != nil {
|
||||
return nil, initErr
|
||||
if err := tm.loadUserCerts(); err != nil {
|
||||
return fmt.Errorf("failed to load user certificates: %w", err)
|
||||
}
|
||||
|
||||
return globalTLSManager.getTLSConfig(), nil
|
||||
tm.useCertMagic = false
|
||||
tm.startCertWatcher()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) initializeWithCertMagic() error {
|
||||
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic",
|
||||
tm.config.Domain(), tm.config.Domain())
|
||||
|
||||
if err := tm.initCertMagic(); err != nil {
|
||||
return fmt.Errorf("failed to initialize CertMagic: %w", err)
|
||||
}
|
||||
|
||||
tm.useCertMagic = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) userCertsExistAndValid() bool {
|
||||
if !tm.certFilesExist() {
|
||||
return false
|
||||
}
|
||||
return validateCertDomains(tm.certPath, tm.config.Domain())
|
||||
}
|
||||
|
||||
func (tm *tlsManager) certFilesExist() bool {
|
||||
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
||||
log.Printf("Certificate file not found: %s", tm.certPath)
|
||||
return false
|
||||
@@ -98,66 +112,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
|
||||
log.Printf("Key file not found: %s", tm.keyPath)
|
||||
return false
|
||||
}
|
||||
|
||||
return ValidateCertDomains(tm.certPath, tm.config.Domain())
|
||||
}
|
||||
|
||||
func ValidateCertDomains(certPath, domain string) bool {
|
||||
certPEM, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read certificate: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
log.Printf("Failed to decode PEM block from certificate")
|
||||
return false
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse certificate: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().After(cert.NotAfter) {
|
||||
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) {
|
||||
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
var certDomains []string
|
||||
if cert.Subject.CommonName != "" {
|
||||
certDomains = append(certDomains, cert.Subject.CommonName)
|
||||
}
|
||||
certDomains = append(certDomains, cert.DNSNames...)
|
||||
|
||||
hasBase := false
|
||||
hasWildcard := false
|
||||
wildcardDomain := "*." + domain
|
||||
|
||||
for _, d := range certDomains {
|
||||
if d == domain {
|
||||
hasBase = true
|
||||
}
|
||||
if d == wildcardDomain {
|
||||
hasWildcard = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasBase {
|
||||
log.Printf("Certificate does not cover base domain: %s", domain)
|
||||
}
|
||||
if !hasWildcard {
|
||||
log.Printf("Certificate does not cover wildcard domain: %s", wildcardDomain)
|
||||
}
|
||||
|
||||
return hasBase && hasWildcard
|
||||
return true
|
||||
}
|
||||
|
||||
func (tm *tlsManager) loadUserCerts() error {
|
||||
@@ -176,62 +131,34 @@ func (tm *tlsManager) loadUserCerts() error {
|
||||
|
||||
func (tm *tlsManager) startCertWatcher() {
|
||||
go func() {
|
||||
var lastCertMod, lastKeyMod time.Time
|
||||
|
||||
if info, err := os.Stat(tm.certPath); err == nil {
|
||||
lastCertMod = info.ModTime()
|
||||
}
|
||||
if info, err := os.Stat(tm.keyPath); err == nil {
|
||||
lastKeyMod = info.ModTime()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
certInfo, certErr := os.Stat(tm.certPath)
|
||||
keyInfo, keyErr := os.Stat(tm.keyPath)
|
||||
|
||||
if certErr != nil || keyErr != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
|
||||
log.Printf("Certificate files changed, reloading...")
|
||||
|
||||
if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
|
||||
log.Printf("New certificates don't cover required domains")
|
||||
|
||||
if err := tm.initCertMagic(); err != nil {
|
||||
log.Printf("Failed to initialize CertMagic: %v", err)
|
||||
continue
|
||||
}
|
||||
tm.useCertMagic = true
|
||||
return
|
||||
}
|
||||
|
||||
if err := tm.loadUserCerts(); err != nil {
|
||||
log.Printf("Failed to reload certificates: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
lastCertMod = certInfo.ModTime()
|
||||
lastKeyMod = keyInfo.ModTime()
|
||||
log.Printf("Certificates reloaded successfully")
|
||||
}
|
||||
}
|
||||
watcher := newCertWatcher(tm)
|
||||
watcher.watch()
|
||||
}()
|
||||
}
|
||||
|
||||
func (tm *tlsManager) initCertMagic() error {
|
||||
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
||||
if err := tm.createStorageDirectory(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if tm.config.CFAPIToken() == "" {
|
||||
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
||||
}
|
||||
|
||||
magic := tm.createCertMagicConfig()
|
||||
tm.magic = magic
|
||||
|
||||
return tm.obtainCertificates(magic)
|
||||
}
|
||||
|
||||
func (tm *tlsManager) createStorageDirectory() error {
|
||||
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) createCertMagicConfig() *certmagic.Config {
|
||||
cfProvider := &cloudflare.Provider{
|
||||
APIToken: tm.config.CFAPIToken(),
|
||||
}
|
||||
@@ -248,6 +175,13 @@ func (tm *tlsManager) initCertMagic() error {
|
||||
Storage: storage,
|
||||
})
|
||||
|
||||
acmeIssuer := tm.createACMEIssuer(magic, cfProvider)
|
||||
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
||||
|
||||
return magic
|
||||
}
|
||||
|
||||
func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer {
|
||||
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
||||
Email: tm.config.ACMEEmail(),
|
||||
Agreed: true,
|
||||
@@ -266,9 +200,10 @@ func (tm *tlsManager) initCertMagic() error {
|
||||
log.Printf("Using Let's Encrypt production server")
|
||||
}
|
||||
|
||||
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
||||
tm.magic = magic
|
||||
return acmeIssuer
|
||||
}
|
||||
|
||||
func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error {
|
||||
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
|
||||
log.Printf("Requesting certificates for: %v", domains)
|
||||
|
||||
@@ -311,3 +246,190 @@ func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certifica
|
||||
|
||||
return tm.userCert, nil
|
||||
}
|
||||
|
||||
func validateCertDomains(certPath, domain string) bool {
|
||||
cert, err := loadAndParseCertificate(certPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !isCertificateValid(cert) {
|
||||
return false
|
||||
}
|
||||
|
||||
return certCoversRequiredDomains(cert, domain)
|
||||
}
|
||||
|
||||
func loadAndParseCertificate(certPath string) (*x509.Certificate, error) {
|
||||
certPEM, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read certificate: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
log.Printf("Failed to decode PEM block from certificate")
|
||||
return nil, fmt.Errorf("failed to decode PEM block")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse certificate: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func isCertificateValid(cert *x509.Certificate) bool {
|
||||
now := time.Now()
|
||||
|
||||
if now.After(cert.NotAfter) {
|
||||
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
thirtyDaysFromNow := now.Add(30 * 24 * time.Hour)
|
||||
if thirtyDaysFromNow.After(cert.NotAfter) {
|
||||
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool {
|
||||
certDomains := extractCertDomains(cert)
|
||||
hasBase, hasWildcard := checkDomainCoverage(certDomains, domain)
|
||||
|
||||
logDomainCoverage(hasBase, hasWildcard, domain)
|
||||
return hasBase && hasWildcard
|
||||
}
|
||||
|
||||
func extractCertDomains(cert *x509.Certificate) []string {
|
||||
var domains []string
|
||||
if cert.Subject.CommonName != "" {
|
||||
domains = append(domains, cert.Subject.CommonName)
|
||||
}
|
||||
domains = append(domains, cert.DNSNames...)
|
||||
return domains
|
||||
}
|
||||
|
||||
func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) {
|
||||
wildcardDomain := "*." + domain
|
||||
|
||||
for _, d := range certDomains {
|
||||
if d == domain {
|
||||
hasBase = true
|
||||
}
|
||||
if d == wildcardDomain {
|
||||
hasWildcard = true
|
||||
}
|
||||
}
|
||||
|
||||
return hasBase, hasWildcard
|
||||
}
|
||||
|
||||
func logDomainCoverage(hasBase, hasWildcard bool, domain string) {
|
||||
if !hasBase {
|
||||
log.Printf("Certificate does not cover base domain: %s", domain)
|
||||
}
|
||||
if !hasWildcard {
|
||||
log.Printf("Certificate does not cover wildcard domain: *.%s", domain)
|
||||
}
|
||||
}
|
||||
|
||||
type certWatcher struct {
|
||||
tm *tlsManager
|
||||
lastCertMod time.Time
|
||||
lastKeyMod time.Time
|
||||
}
|
||||
|
||||
func newCertWatcher(tm *tlsManager) *certWatcher {
|
||||
watcher := &certWatcher{tm: tm}
|
||||
watcher.initializeModTimes()
|
||||
return watcher
|
||||
}
|
||||
|
||||
func (cw *certWatcher) initializeModTimes() {
|
||||
if info, err := os.Stat(cw.tm.certPath); err == nil {
|
||||
cw.lastCertMod = info.ModTime()
|
||||
}
|
||||
if info, err := os.Stat(cw.tm.keyPath); err == nil {
|
||||
cw.lastKeyMod = info.ModTime()
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *certWatcher) watch() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if cw.checkAndReloadCerts() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *certWatcher) checkAndReloadCerts() bool {
|
||||
certInfo, keyInfo, err := cw.getFileInfo()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !cw.filesModified(certInfo, keyInfo) {
|
||||
return false
|
||||
}
|
||||
|
||||
return cw.handleCertificateChange(certInfo, keyInfo)
|
||||
}
|
||||
|
||||
func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) {
|
||||
certInfo, certErr := os.Stat(cw.tm.certPath)
|
||||
keyInfo, keyErr := os.Stat(cw.tm.keyPath)
|
||||
|
||||
if certErr != nil || keyErr != nil {
|
||||
return nil, nil, fmt.Errorf("file stat error")
|
||||
}
|
||||
|
||||
return certInfo, keyInfo, nil
|
||||
}
|
||||
|
||||
func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool {
|
||||
return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod)
|
||||
}
|
||||
|
||||
func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool {
|
||||
log.Printf("Certificate files changed, reloading...")
|
||||
|
||||
if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) {
|
||||
return cw.switchToCertMagic()
|
||||
}
|
||||
|
||||
if err := cw.tm.loadUserCerts(); err != nil {
|
||||
log.Printf("Failed to reload certificates: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
cw.updateModTimes(certInfo, keyInfo)
|
||||
log.Printf("Certificates reloaded successfully")
|
||||
return false
|
||||
}
|
||||
|
||||
func (cw *certWatcher) switchToCertMagic() bool {
|
||||
log.Printf("New certificates don't cover required domains")
|
||||
|
||||
if err := cw.tm.initCertMagic(); err != nil {
|
||||
log.Printf("Failed to initialize CertMagic: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
cw.tm.useCertMagic = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) {
|
||||
cw.lastCertMod = certInfo.ModTime()
|
||||
cw.lastKeyMod = keyInfo.ModTime()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user