436 lines
9.9 KiB
Go
436 lines
9.9 KiB
Go
package transport
|
|
|
|
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"tunnel_pls/internal/config"

	"github.com/caddyserver/certmagic"
	"github.com/libdns/cloudflare"
)
|
|
|
|
func NewTLSConfig(config config.Config) (*tls.Config, error) {
|
|
var initErr error
|
|
|
|
tlsManagerOnce.Do(func() {
|
|
tm := createTLSManager(config)
|
|
initErr = tm.initialize()
|
|
if initErr == nil {
|
|
globalTLSManager = tm
|
|
}
|
|
})
|
|
|
|
if initErr != nil {
|
|
return nil, initErr
|
|
}
|
|
|
|
return globalTLSManager.getTLSConfig(), nil
|
|
}
|
|
|
|
type tlsManager struct {
|
|
config config.Config
|
|
|
|
certPath string
|
|
keyPath string
|
|
storagePath string
|
|
|
|
userCert *tls.Certificate
|
|
userCertMu sync.RWMutex
|
|
|
|
magic *certmagic.Config
|
|
|
|
useCertMagic bool
|
|
}
|
|
|
|
var globalTLSManager *tlsManager
|
|
var tlsManagerOnce sync.Once
|
|
|
|
func createTLSManager(cfg config.Config) *tlsManager {
|
|
storagePath := cfg.TLSStoragePath()
|
|
cleanBase := filepath.Clean(storagePath)
|
|
|
|
return &tlsManager{
|
|
config: cfg,
|
|
certPath: filepath.Join(cleanBase, "cert.pem"),
|
|
keyPath: filepath.Join(cleanBase, "privkey.pem"),
|
|
storagePath: filepath.Join(cleanBase, "certmagic"),
|
|
}
|
|
}
|
|
|
|
func (tm *tlsManager) initialize() error {
|
|
if tm.userCertsExistAndValid() {
|
|
return tm.initializeWithUserCerts()
|
|
}
|
|
return tm.initializeWithCertMagic()
|
|
}
|
|
|
|
func (tm *tlsManager) initializeWithUserCerts() error {
|
|
log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath)
|
|
|
|
if err := tm.loadUserCerts(); err != nil {
|
|
return fmt.Errorf("failed to load user certificates: %w", err)
|
|
}
|
|
|
|
tm.useCertMagic = false
|
|
tm.startCertWatcher()
|
|
return nil
|
|
}
|
|
|
|
func (tm *tlsManager) initializeWithCertMagic() error {
|
|
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic",
|
|
tm.config.Domain(), tm.config.Domain())
|
|
|
|
if err := tm.initCertMagic(); err != nil {
|
|
return fmt.Errorf("failed to initialize CertMagic: %w", err)
|
|
}
|
|
|
|
tm.useCertMagic = true
|
|
return nil
|
|
}
|
|
|
|
func (tm *tlsManager) userCertsExistAndValid() bool {
|
|
if !tm.certFilesExist() {
|
|
return false
|
|
}
|
|
return validateCertDomains(tm.certPath, tm.config.Domain())
|
|
}
|
|
|
|
func (tm *tlsManager) certFilesExist() bool {
|
|
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
|
log.Printf("Certificate file not found: %s", tm.certPath)
|
|
return false
|
|
}
|
|
if _, err := os.Stat(tm.keyPath); os.IsNotExist(err) {
|
|
log.Printf("Key file not found: %s", tm.keyPath)
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (tm *tlsManager) loadUserCerts() error {
|
|
cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tm.userCertMu.Lock()
|
|
tm.userCert = &cert
|
|
tm.userCertMu.Unlock()
|
|
|
|
log.Printf("Loaded user certificates successfully")
|
|
return nil
|
|
}
|
|
|
|
func (tm *tlsManager) startCertWatcher() {
|
|
go func() {
|
|
watcher := newCertWatcher(tm)
|
|
watcher.watch()
|
|
}()
|
|
}
|
|
|
|
func (tm *tlsManager) initCertMagic() error {
|
|
if err := tm.createStorageDirectory(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if tm.config.CFAPIToken() == "" {
|
|
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
|
}
|
|
|
|
magic := tm.createCertMagicConfig()
|
|
tm.magic = magic
|
|
|
|
return tm.obtainCertificates(magic)
|
|
}
|
|
|
|
func (tm *tlsManager) createStorageDirectory() error {
|
|
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
|
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (tm *tlsManager) createCertMagicConfig() *certmagic.Config {
|
|
cfProvider := &cloudflare.Provider{
|
|
APIToken: tm.config.CFAPIToken(),
|
|
}
|
|
|
|
storage := &certmagic.FileStorage{Path: tm.storagePath}
|
|
|
|
cache := certmagic.NewCache(certmagic.CacheOptions{
|
|
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
|
|
return tm.magic, nil
|
|
},
|
|
})
|
|
|
|
magic := certmagic.New(cache, certmagic.Config{
|
|
Storage: storage,
|
|
})
|
|
|
|
acmeIssuer := tm.createACMEIssuer(magic, cfProvider)
|
|
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
|
|
|
return magic
|
|
}
|
|
|
|
func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer {
|
|
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
|
Email: tm.config.ACMEEmail(),
|
|
Agreed: true,
|
|
DNS01Solver: &certmagic.DNS01Solver{
|
|
DNSManager: certmagic.DNSManager{
|
|
DNSProvider: cfProvider,
|
|
},
|
|
},
|
|
})
|
|
|
|
if tm.config.ACMEStaging() {
|
|
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
|
|
log.Printf("Using Let's Encrypt staging server")
|
|
} else {
|
|
acmeIssuer.CA = certmagic.LetsEncryptProductionCA
|
|
log.Printf("Using Let's Encrypt production server")
|
|
}
|
|
|
|
return acmeIssuer
|
|
}
|
|
|
|
func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error {
|
|
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
|
|
log.Printf("Requesting certificates for: %v", domains)
|
|
|
|
ctx := context.Background()
|
|
if err := magic.ManageSync(ctx, domains); err != nil {
|
|
return fmt.Errorf("failed to obtain certificates: %w", err)
|
|
}
|
|
|
|
log.Printf("Certificates obtained successfully for %v", domains)
|
|
return nil
|
|
}
|
|
|
|
func (tm *tlsManager) getTLSConfig() *tls.Config {
|
|
return &tls.Config{
|
|
GetCertificate: tm.getCertificate,
|
|
|
|
MinVersion: tls.VersionTLS13,
|
|
MaxVersion: tls.VersionTLS13,
|
|
|
|
CurvePreferences: []tls.CurveID{
|
|
tls.X25519,
|
|
},
|
|
|
|
SessionTicketsDisabled: false,
|
|
ClientAuth: tls.NoClientCert,
|
|
}
|
|
}
|
|
|
|
func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
if tm.useCertMagic {
|
|
return tm.magic.GetCertificate(hello)
|
|
}
|
|
|
|
tm.userCertMu.RLock()
|
|
defer tm.userCertMu.RUnlock()
|
|
|
|
if tm.userCert == nil {
|
|
return nil, fmt.Errorf("no certificate available")
|
|
}
|
|
|
|
return tm.userCert, nil
|
|
}
|
|
|
|
func validateCertDomains(certPath, domain string) bool {
|
|
cert, err := loadAndParseCertificate(certPath)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
if !isCertificateValid(cert) {
|
|
return false
|
|
}
|
|
|
|
return certCoversRequiredDomains(cert, domain)
|
|
}
|
|
|
|
// loadAndParseCertificate reads the PEM file at certPath and returns the first
// certificate it contains.
//
// PEM blocks of any type other than "CERTIFICATE" (for example a private key
// bundled ahead of the certificate) are skipped. This matches how
// tls.LoadX509KeyPair reads the same file, so files it accepts are not
// rejected here.
//
// Failures are logged and returned. They are: the os.ReadFile error, an error
// when no CERTIFICATE block is present, or the x509 parse error.
func loadAndParseCertificate(certPath string) (*x509.Certificate, error) {
	certPEM, err := os.ReadFile(certPath)
	if err != nil {
		log.Printf("Failed to read certificate: %v", err)
		return nil, err
	}

	rest := certPEM
	for {
		var block *pem.Block
		block, rest = pem.Decode(rest)
		if block == nil {
			log.Printf("No CERTIFICATE PEM block found in %s", certPath)
			return nil, fmt.Errorf("no CERTIFICATE PEM block found in %s", certPath)
		}
		if block.Type != "CERTIFICATE" {
			continue
		}

		cert, err := x509.ParseCertificate(block.Bytes)
		if err != nil {
			log.Printf("Failed to parse certificate: %v", err)
			return nil, err
		}
		return cert, nil
	}
}
|
|
|
|
// isCertificateValid reports whether cert is usable now and will remain so
// for at least the renewal window.
//
// It returns false (and logs why) when the certificate is not yet valid, has
// expired, or expires within the window. In those cases the caller falls back
// to CertMagic, which handles renewal.
func isCertificateValid(cert *x509.Certificate) bool {
	// Certificates expiring sooner than this are handed over to CertMagic.
	const renewalWindow = 30 * 24 * time.Hour

	now := time.Now()

	if now.Before(cert.NotBefore) {
		log.Printf("Certificate is not yet valid (NotBefore: %v)", cert.NotBefore)
		return false
	}

	if now.After(cert.NotAfter) {
		log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
		return false
	}

	if now.Add(renewalWindow).After(cert.NotAfter) {
		log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
		return false
	}

	return true
}
|
|
|
|
func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool {
|
|
certDomains := extractCertDomains(cert)
|
|
hasBase, hasWildcard := checkDomainCoverage(certDomains, domain)
|
|
|
|
logDomainCoverage(hasBase, hasWildcard, domain)
|
|
return hasBase && hasWildcard
|
|
}
|
|
|
|
// extractCertDomains returns the DNS names cert is valid for, taken from its
// Subject Alternative Name extension.
//
// The Subject Common Name is intentionally not included. Go's crypto/x509
// (since Go 1.15) and current browsers do not match hostnames against it, so
// a name present only in the CN would not be usable by TLS clients.
//
// The result is a fresh slice that does not alias cert.DNSNames. It is nil
// when the certificate has no DNS SANs.
func extractCertDomains(cert *x509.Certificate) []string {
	return append([]string(nil), cert.DNSNames...)
}
|
|
|
|
// checkDomainCoverage reports whether certDomains contains domain (hasBase)
// and "*."+domain (hasWildcard).
//
// Names are compared case-insensitively because DNS names are
// case-insensitive (RFC 4343). A SAN such as "Example.COM" must therefore
// match a configured "example.com".
func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) {
	wildcardDomain := "*." + domain

	for _, d := range certDomains {
		if strings.EqualFold(d, domain) {
			hasBase = true
		}
		if strings.EqualFold(d, wildcardDomain) {
			hasWildcard = true
		}
	}

	return hasBase, hasWildcard
}
|
|
|
|
// logDomainCoverage emits one log line for each required name (the base
// domain, then *.domain) that a certificate does not cover.
func logDomainCoverage(hasBase, hasWildcard bool, domain string) {
	requirements := []struct {
		covered bool
		format  string
	}{
		{hasBase, "Certificate does not cover base domain: %s"},
		{hasWildcard, "Certificate does not cover wildcard domain: *.%s"},
	}

	for _, r := range requirements {
		if !r.covered {
			log.Printf(r.format, domain)
		}
	}
}
|
|
|
|
type certWatcher struct {
|
|
tm *tlsManager
|
|
lastCertMod time.Time
|
|
lastKeyMod time.Time
|
|
}
|
|
|
|
func newCertWatcher(tm *tlsManager) *certWatcher {
|
|
watcher := &certWatcher{tm: tm}
|
|
watcher.initializeModTimes()
|
|
return watcher
|
|
}
|
|
|
|
func (cw *certWatcher) initializeModTimes() {
|
|
if info, err := os.Stat(cw.tm.certPath); err == nil {
|
|
cw.lastCertMod = info.ModTime()
|
|
}
|
|
if info, err := os.Stat(cw.tm.keyPath); err == nil {
|
|
cw.lastKeyMod = info.ModTime()
|
|
}
|
|
}
|
|
|
|
func (cw *certWatcher) watch() {
|
|
ticker := time.NewTicker(30 * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
if cw.checkAndReloadCerts() {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (cw *certWatcher) checkAndReloadCerts() bool {
|
|
certInfo, keyInfo, err := cw.getFileInfo()
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
if !cw.filesModified(certInfo, keyInfo) {
|
|
return false
|
|
}
|
|
|
|
return cw.handleCertificateChange(certInfo, keyInfo)
|
|
}
|
|
|
|
func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) {
|
|
certInfo, certErr := os.Stat(cw.tm.certPath)
|
|
keyInfo, keyErr := os.Stat(cw.tm.keyPath)
|
|
|
|
if certErr != nil || keyErr != nil {
|
|
return nil, nil, fmt.Errorf("file stat error")
|
|
}
|
|
|
|
return certInfo, keyInfo, nil
|
|
}
|
|
|
|
func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool {
|
|
return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod)
|
|
}
|
|
|
|
func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool {
|
|
log.Printf("Certificate files changed, reloading...")
|
|
|
|
if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) {
|
|
return cw.switchToCertMagic()
|
|
}
|
|
|
|
if err := cw.tm.loadUserCerts(); err != nil {
|
|
log.Printf("Failed to reload certificates: %v", err)
|
|
return false
|
|
}
|
|
|
|
cw.updateModTimes(certInfo, keyInfo)
|
|
log.Printf("Certificates reloaded successfully")
|
|
return false
|
|
}
|
|
|
|
func (cw *certWatcher) switchToCertMagic() bool {
|
|
log.Printf("New certificates don't cover required domains")
|
|
|
|
if err := cw.tm.initCertMagic(); err != nil {
|
|
log.Printf("Failed to initialize CertMagic: %v", err)
|
|
return false
|
|
}
|
|
|
|
cw.tm.useCertMagic = true
|
|
return true
|
|
}
|
|
|
|
func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) {
|
|
cw.lastCertMod = certInfo.ModTime()
|
|
cw.lastKeyMod = keyInfo.ModTime()
|
|
}
|