refactor(config): centralize env loading and enforce typed access

- Centralize environment variable loading in config.MustLoad
- Parse and validate all env vars once at initialization
- Make config fields private and read-only
- Remove public Getenv usage in favor of typed accessors
- Improve validation and initialization order
- Normalize enum naming to be idiomatic and avoid constant collisions
This commit is contained in:
2026-01-21 19:43:19 +07:00
parent 1e12373359
commit 2bc20dd991
19 changed files with 414 additions and 257 deletions
+2 -2
View File
@@ -12,9 +12,9 @@ type httpServer struct {
port string
}
func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &httpServer{
handler: newHTTPHandler(sessionRegistry, redirectTLS),
handler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
port: port,
}
}
+18 -14
View File
@@ -4,12 +4,12 @@ import (
"bufio"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware"
@@ -20,19 +20,21 @@ import (
)
type httpHandler struct {
domain string
sessionRegistry registry.Registry
redirectTLS bool
}
func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
return &httpHandler{
domain: domain,
sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS,
}
}
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
_, err := conn.Write([]byte(fmt.Sprintf("TunnelTypeHTTP/1.1 %d Moved Permanently\r\n", status) +
fmt.Sprintf("Location: %s", location) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
@@ -44,7 +46,7 @@ func (hh *httpHandler) redirect(conn net.Conn, status int, location string) erro
}
func (hh *httpHandler) badRequest(conn net.Conn) error {
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
if _, err := conn.Write([]byte("TunnelTypeHTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
return err
}
return nil
@@ -67,7 +69,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
}
if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")))
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
return
}
@@ -85,7 +87,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
defer func(hw stream.HTTP) {
err = hw.Close()
if err != nil {
log.Printf("Error closing HTTP stream: %v", err)
log.Printf("Error closing TunnelTypeHTTP stream: %v", err)
}
}(hw)
hh.forwardRequest(hw, reqhf, sshSession)
@@ -116,7 +118,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
}
_, err := conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"TunnelTypeHTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
@@ -133,7 +135,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
Type: types.TunnelTypeHTTP,
})
if err != nil {
return nil, err
@@ -143,17 +145,19 @@ func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
channel, err := hh.openForwardedChannel(hw, sshSession)
defer func() {
err = channel.Close()
if err != nil {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
if err != nil {
log.Printf("Failed to establish channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw)
return
}
defer func() {
err = channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
hh.setupMiddlewares(hw)
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
+6 -9
View File
@@ -9,28 +9,25 @@ import (
)
type https struct {
tlsConfig *tls.Config
httpHandler *httpHandler
domain string
port string
}
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport {
return &https{
httpHandler: newHTTPHandler(sessionRegistry, redirectTLS),
tlsConfig: tlsConfig,
httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
domain: domain,
port: port,
}
}
func (ht *https) Listen() (net.Listener, error) {
tlsConfig, err := NewTLSConfig(ht.domain)
if err != nil {
return nil, err
}
return tls.Listen("tcp", ":"+ht.port, tlsConfig)
return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig)
}
func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.port)
for {
+12 -33
View File
@@ -26,7 +26,8 @@ type TLSManager interface {
}
type tlsManager struct {
domain string
config config.Config
certPath string
keyPath string
storagePath string
@@ -42,7 +43,7 @@ type tlsManager struct {
var globalTLSManager TLSManager
var tlsManagerOnce sync.Once
func NewTLSConfig(domain string) (*tls.Config, error) {
func NewTLSConfig(config config.Config) (*tls.Config, error) {
var initErr error
tlsManagerOnce.Do(func() {
@@ -51,7 +52,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
storagePath := "certs/tls/certmagic"
tm := &tlsManager{
domain: domain,
config: config,
certPath: certPath,
keyPath: keyPath,
storagePath: storagePath,
@@ -66,14 +67,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
tm.useCertMagic = false
tm.startCertWatcher()
} else {
if !isACMEConfigComplete() {
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
return
}
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain, config.Domain)
if err := tm.initCertMagic(); err != nil {
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
return
@@ -91,11 +85,6 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
return globalTLSManager.getTLSConfig(), nil
}
func isACMEConfigComplete() bool {
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
return cfAPIToken != ""
}
func (tm *tlsManager) userCertsExistAndValid() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath)
@@ -106,7 +95,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
return false
}
return ValidateCertDomains(tm.certPath, tm.domain)
return ValidateCertDomains(tm.certPath, tm.config.Domain())
}
func ValidateCertDomains(certPath, domain string) bool {
@@ -206,15 +195,9 @@ func (tm *tlsManager) startCertWatcher() {
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
log.Printf("Certificate files changed, reloading...")
if !ValidateCertDomains(tm.certPath, tm.domain) {
if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
log.Printf("New certificates don't cover required domains")
if !isACMEConfigComplete() {
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
continue
}
log.Printf("Switching to CertMagic for automatic certificate management")
if err := tm.initCertMagic(); err != nil {
log.Printf("Failed to initialize CertMagic: %v", err)
continue
@@ -241,16 +224,12 @@ func (tm *tlsManager) initCertMagic() error {
return fmt.Errorf("failed to create cert storage directory: %w", err)
}
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain)
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
if cfAPIToken == "" {
if tm.config.CFAPIToken() == "" {
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
}
cfProvider := &cloudflare.Provider{
APIToken: cfAPIToken,
APIToken: tm.config.CFAPIToken(),
}
storage := &certmagic.FileStorage{Path: tm.storagePath}
@@ -266,7 +245,7 @@ func (tm *tlsManager) initCertMagic() error {
})
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
Email: acmeEmail,
Email: tm.config.ACMEEmail(),
Agreed: true,
DNS01Solver: &certmagic.DNS01Solver{
DNSManager: certmagic.DNSManager{
@@ -275,7 +254,7 @@ func (tm *tlsManager) initCertMagic() error {
},
})
if acmeStaging {
if tm.config.ACMEStaging() {
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
log.Printf("Using Let's Encrypt staging server")
} else {
@@ -286,7 +265,7 @@ func (tm *tlsManager) initCertMagic() error {
magic.Issuers = []certmagic.Issuer{acmeIssuer}
tm.magic = magic
domains := []string{tm.domain, "*." + tm.domain}
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
log.Printf("Requesting certificates for: %v", domains)
ctx := context.Background()