fix(deps): update module github.com/caddyserver/certmagic to v0.25.1 - autoclosed #63
+275
-153
@@ -17,13 +17,22 @@ import (
|
|||||||
"github.com/libdns/cloudflare"
|
"github.com/libdns/cloudflare"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TLSManager interface {
|
func NewTLSConfig(config config.Config) (*tls.Config, error) {
|
||||||
userCertsExistAndValid() bool
|
var initErr error
|
||||||
loadUserCerts() error
|
|
||||||
startCertWatcher()
|
tlsManagerOnce.Do(func() {
|
||||||
initCertMagic() error
|
tm := createTLSManager(config)
|
||||||
getTLSConfig() *tls.Config
|
initErr = tm.initialize()
|
||||||
getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
if initErr == nil {
|
||||||
|
globalTLSManager = tm
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if initErr != nil {
|
||||||
|
return nil, initErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return globalTLSManager.getTLSConfig(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type tlsManager struct {
|
type tlsManager struct {
|
||||||
@@ -41,55 +50,60 @@ type tlsManager struct {
|
|||||||
useCertMagic bool
|
useCertMagic bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var globalTLSManager TLSManager
|
var globalTLSManager *tlsManager
|
||||||
var tlsManagerOnce sync.Once
|
var tlsManagerOnce sync.Once
|
||||||
|
|
||||||
func NewTLSConfig(config config.Config) (*tls.Config, error) {
|
func createTLSManager(cfg config.Config) *tlsManager {
|
||||||
var initErr error
|
storagePath := cfg.TLSStoragePath()
|
||||||
|
cleanBase := filepath.Clean(storagePath)
|
||||||
|
|
||||||
tlsManagerOnce.Do(func() {
|
return &tlsManager{
|
||||||
storagePath := config.TLSStoragePath()
|
config: cfg,
|
||||||
cleanBase := filepath.Clean(storagePath)
|
certPath: filepath.Join(cleanBase, "cert.pem"),
|
||||||
|
keyPath: filepath.Join(cleanBase, "privkey.pem"),
|
||||||
|
storagePath: filepath.Join(cleanBase, "certmagic"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
certPath := filepath.Join(cleanBase, "cert.pem")
|
func (tm *tlsManager) initialize() error {
|
||||||
keyPath := filepath.Join(cleanBase, "privkey.pem")
|
if tm.userCertsExistAndValid() {
|
||||||
storagePathCertMagic := filepath.Join(cleanBase, "certmagic")
|
return tm.initializeWithUserCerts()
|
||||||
|
}
|
||||||
|
return tm.initializeWithCertMagic()
|
||||||
|
}
|
||||||
|
|
||||||
tm := &tlsManager{
|
func (tm *tlsManager) initializeWithUserCerts() error {
|
||||||
config: config,
|
log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath)
|
||||||
certPath: certPath,
|
|
||||||
keyPath: keyPath,
|
|
||||||
storagePath: storagePathCertMagic,
|
|
||||||
}
|
|
||||||
|
|
||||||
if tm.userCertsExistAndValid() {
|
if err := tm.loadUserCerts(); err != nil {
|
||||||
log.Printf("Using user-provided certificates from %s and %s", certPath, keyPath)
|
return fmt.Errorf("failed to load user certificates: %w", err)
|
||||||
if err := tm.loadUserCerts(); err != nil {
|
|
||||||
initErr = fmt.Errorf("failed to load user certificates: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tm.useCertMagic = false
|
|
||||||
tm.startCertWatcher()
|
|
||||||
} else {
|
|
||||||
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain())
|
|
||||||
if err := tm.initCertMagic(); err != nil {
|
|
||||||
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
tm.useCertMagic = true
|
|
||||||
}
|
|
||||||
|
|
||||||
globalTLSManager = tm
|
|
||||||
})
|
|
||||||
|
|
||||||
if initErr != nil {
|
|
||||||
return nil, initErr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return globalTLSManager.getTLSConfig(), nil
|
tm.useCertMagic = false
|
||||||
|
tm.startCertWatcher()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) initializeWithCertMagic() error {
|
||||||
|
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic",
|
||||||
|
tm.config.Domain(), tm.config.Domain())
|
||||||
|
|
||||||
|
if err := tm.initCertMagic(); err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize CertMagic: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tm.useCertMagic = true
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *tlsManager) userCertsExistAndValid() bool {
|
func (tm *tlsManager) userCertsExistAndValid() bool {
|
||||||
|
if !tm.certFilesExist() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return validateCertDomains(tm.certPath, tm.config.Domain())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) certFilesExist() bool {
|
||||||
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
||||||
log.Printf("Certificate file not found: %s", tm.certPath)
|
log.Printf("Certificate file not found: %s", tm.certPath)
|
||||||
return false
|
return false
|
||||||
@@ -98,66 +112,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
|
|||||||
log.Printf("Key file not found: %s", tm.keyPath)
|
log.Printf("Key file not found: %s", tm.keyPath)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
return true
|
||||||
return ValidateCertDomains(tm.certPath, tm.config.Domain())
|
|
||||||
}
|
|
||||||
|
|
||||||
func ValidateCertDomains(certPath, domain string) bool {
|
|
||||||
certPEM, err := os.ReadFile(certPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to read certificate: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
block, _ := pem.Decode(certPEM)
|
|
||||||
if block == nil {
|
|
||||||
log.Printf("Failed to decode PEM block from certificate")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := x509.ParseCertificate(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to parse certificate: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().After(cert.NotAfter) {
|
|
||||||
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) {
|
|
||||||
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
var certDomains []string
|
|
||||||
if cert.Subject.CommonName != "" {
|
|
||||||
certDomains = append(certDomains, cert.Subject.CommonName)
|
|
||||||
}
|
|
||||||
certDomains = append(certDomains, cert.DNSNames...)
|
|
||||||
|
|
||||||
hasBase := false
|
|
||||||
hasWildcard := false
|
|
||||||
wildcardDomain := "*." + domain
|
|
||||||
|
|
||||||
for _, d := range certDomains {
|
|
||||||
if d == domain {
|
|
||||||
hasBase = true
|
|
||||||
}
|
|
||||||
if d == wildcardDomain {
|
|
||||||
hasWildcard = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasBase {
|
|
||||||
log.Printf("Certificate does not cover base domain: %s", domain)
|
|
||||||
}
|
|
||||||
if !hasWildcard {
|
|
||||||
log.Printf("Certificate does not cover wildcard domain: %s", wildcardDomain)
|
|
||||||
}
|
|
||||||
|
|
||||||
return hasBase && hasWildcard
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *tlsManager) loadUserCerts() error {
|
func (tm *tlsManager) loadUserCerts() error {
|
||||||
@@ -176,62 +131,34 @@ func (tm *tlsManager) loadUserCerts() error {
|
|||||||
|
|
||||||
func (tm *tlsManager) startCertWatcher() {
|
func (tm *tlsManager) startCertWatcher() {
|
||||||
go func() {
|
go func() {
|
||||||
var lastCertMod, lastKeyMod time.Time
|
watcher := newCertWatcher(tm)
|
||||||
|
watcher.watch()
|
||||||
if info, err := os.Stat(tm.certPath); err == nil {
|
|
||||||
lastCertMod = info.ModTime()
|
|
||||||
}
|
|
||||||
if info, err := os.Stat(tm.keyPath); err == nil {
|
|
||||||
lastKeyMod = info.ModTime()
|
|
||||||
}
|
|
||||||
|
|
||||||
ticker := time.NewTicker(30 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for range ticker.C {
|
|
||||||
certInfo, certErr := os.Stat(tm.certPath)
|
|
||||||
keyInfo, keyErr := os.Stat(tm.keyPath)
|
|
||||||
|
|
||||||
if certErr != nil || keyErr != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
|
|
||||||
log.Printf("Certificate files changed, reloading...")
|
|
||||||
|
|
||||||
if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
|
|
||||||
log.Printf("New certificates don't cover required domains")
|
|
||||||
|
|
||||||
if err := tm.initCertMagic(); err != nil {
|
|
||||||
log.Printf("Failed to initialize CertMagic: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
tm.useCertMagic = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tm.loadUserCerts(); err != nil {
|
|
||||||
log.Printf("Failed to reload certificates: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
lastCertMod = certInfo.ModTime()
|
|
||||||
lastKeyMod = keyInfo.ModTime()
|
|
||||||
log.Printf("Certificates reloaded successfully")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *tlsManager) initCertMagic() error {
|
func (tm *tlsManager) initCertMagic() error {
|
||||||
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
if err := tm.createStorageDirectory(); err != nil {
|
||||||
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if tm.config.CFAPIToken() == "" {
|
if tm.config.CFAPIToken() == "" {
|
||||||
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
magic := tm.createCertMagicConfig()
|
||||||
|
tm.magic = magic
|
||||||
|
|
||||||
|
return tm.obtainCertificates(magic)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) createStorageDirectory() error {
|
||||||
|
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) createCertMagicConfig() *certmagic.Config {
|
||||||
cfProvider := &cloudflare.Provider{
|
cfProvider := &cloudflare.Provider{
|
||||||
APIToken: tm.config.CFAPIToken(),
|
APIToken: tm.config.CFAPIToken(),
|
||||||
}
|
}
|
||||||
@@ -248,6 +175,13 @@ func (tm *tlsManager) initCertMagic() error {
|
|||||||
Storage: storage,
|
Storage: storage,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
acmeIssuer := tm.createACMEIssuer(magic, cfProvider)
|
||||||
|
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
||||||
|
|
||||||
|
return magic
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer {
|
||||||
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
||||||
Email: tm.config.ACMEEmail(),
|
Email: tm.config.ACMEEmail(),
|
||||||
Agreed: true,
|
Agreed: true,
|
||||||
@@ -266,9 +200,10 @@ func (tm *tlsManager) initCertMagic() error {
|
|||||||
log.Printf("Using Let's Encrypt production server")
|
log.Printf("Using Let's Encrypt production server")
|
||||||
}
|
}
|
||||||
|
|
||||||
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
return acmeIssuer
|
||||||
tm.magic = magic
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error {
|
||||||
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
|
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
|
||||||
log.Printf("Requesting certificates for: %v", domains)
|
log.Printf("Requesting certificates for: %v", domains)
|
||||||
|
|
||||||
@@ -311,3 +246,190 @@ func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certifica
|
|||||||
|
|
||||||
return tm.userCert, nil
|
return tm.userCert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateCertDomains(certPath, domain string) bool {
|
||||||
|
cert, err := loadAndParseCertificate(certPath)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isCertificateValid(cert) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return certCoversRequiredDomains(cert, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadAndParseCertificate(certPath string) (*x509.Certificate, error) {
|
||||||
|
certPEM, err := os.ReadFile(certPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to read certificate: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(certPEM)
|
||||||
|
if block == nil {
|
||||||
|
log.Printf("Failed to decode PEM block from certificate")
|
||||||
|
return nil, fmt.Errorf("failed to decode PEM block")
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to parse certificate: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCertificateValid(cert *x509.Certificate) bool {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if now.After(cert.NotAfter) {
|
||||||
|
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
thirtyDaysFromNow := now.Add(30 * 24 * time.Hour)
|
||||||
|
if thirtyDaysFromNow.After(cert.NotAfter) {
|
||||||
|
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool {
|
||||||
|
certDomains := extractCertDomains(cert)
|
||||||
|
hasBase, hasWildcard := checkDomainCoverage(certDomains, domain)
|
||||||
|
|
||||||
|
logDomainCoverage(hasBase, hasWildcard, domain)
|
||||||
|
return hasBase && hasWildcard
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractCertDomains(cert *x509.Certificate) []string {
|
||||||
|
var domains []string
|
||||||
|
if cert.Subject.CommonName != "" {
|
||||||
|
domains = append(domains, cert.Subject.CommonName)
|
||||||
|
}
|
||||||
|
domains = append(domains, cert.DNSNames...)
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) {
|
||||||
|
wildcardDomain := "*." + domain
|
||||||
|
|
||||||
|
for _, d := range certDomains {
|
||||||
|
if d == domain {
|
||||||
|
hasBase = true
|
||||||
|
}
|
||||||
|
if d == wildcardDomain {
|
||||||
|
hasWildcard = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return hasBase, hasWildcard
|
||||||
|
}
|
||||||
|
|
||||||
|
func logDomainCoverage(hasBase, hasWildcard bool, domain string) {
|
||||||
|
if !hasBase {
|
||||||
|
log.Printf("Certificate does not cover base domain: %s", domain)
|
||||||
|
}
|
||||||
|
if !hasWildcard {
|
||||||
|
log.Printf("Certificate does not cover wildcard domain: *.%s", domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type certWatcher struct {
|
||||||
|
tm *tlsManager
|
||||||
|
lastCertMod time.Time
|
||||||
|
lastKeyMod time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCertWatcher(tm *tlsManager) *certWatcher {
|
||||||
|
watcher := &certWatcher{tm: tm}
|
||||||
|
watcher.initializeModTimes()
|
||||||
|
return watcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *certWatcher) initializeModTimes() {
|
||||||
|
if info, err := os.Stat(cw.tm.certPath); err == nil {
|
||||||
|
cw.lastCertMod = info.ModTime()
|
||||||
|
}
|
||||||
|
if info, err := os.Stat(cw.tm.keyPath); err == nil {
|
||||||
|
cw.lastKeyMod = info.ModTime()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *certWatcher) watch() {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
if cw.checkAndReloadCerts() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *certWatcher) checkAndReloadCerts() bool {
|
||||||
|
certInfo, keyInfo, err := cw.getFileInfo()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cw.filesModified(certInfo, keyInfo) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return cw.handleCertificateChange(certInfo, keyInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) {
|
||||||
|
certInfo, certErr := os.Stat(cw.tm.certPath)
|
||||||
|
keyInfo, keyErr := os.Stat(cw.tm.keyPath)
|
||||||
|
|
||||||
|
if certErr != nil || keyErr != nil {
|
||||||
|
return nil, nil, fmt.Errorf("file stat error")
|
||||||
|
}
|
||||||
|
|
||||||
|
return certInfo, keyInfo, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool {
|
||||||
|
return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool {
|
||||||
|
log.Printf("Certificate files changed, reloading...")
|
||||||
|
|
||||||
|
if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) {
|
||||||
|
return cw.switchToCertMagic()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cw.tm.loadUserCerts(); err != nil {
|
||||||
|
log.Printf("Failed to reload certificates: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.updateModTimes(certInfo, keyInfo)
|
||||||
|
log.Printf("Certificates reloaded successfully")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *certWatcher) switchToCertMagic() bool {
|
||||||
|
log.Printf("New certificates don't cover required domains")
|
||||||
|
|
||||||
|
if err := cw.tm.initCertMagic(); err != nil {
|
||||||
|
log.Printf("Failed to initialize CertMagic: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.tm.useCertMagic = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) {
|
||||||
|
cw.lastCertMod = certInfo.ModTime()
|
||||||
|
cw.lastKeyMod = keyInfo.ModTime()
|
||||||
|
}
|
||||||
|
|||||||
+1081
-77
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user