refactor(config): centralize env loading and enforce typed access

- Centralize environment variable loading in config.MustLoad
- Parse and validate all env vars once at initialization
- Make config fields private and read-only
- Remove public Getenv usage in favor of typed accessors
- Improve validation and initialization order
- Normalize enum naming to be idiomatic and avoid constant collisions
This commit is contained in:
2026-01-21 19:43:19 +07:00
parent 1e12373359
commit 2bc20dd991
19 changed files with 414 additions and 257 deletions
+53 -23
View File
@@ -1,33 +1,63 @@
package config
import (
"os"
"strconv"
import "tunnel_pls/types"
"github.com/joho/godotenv"
)
type Config interface {
Domain() string
SSHPort() string
func Load() error {
if _, err := os.Stat(".env"); err == nil {
return godotenv.Load(".env")
}
return nil
HTTPPort() string
HTTPSPort() string
TLSEnabled() bool
TLSRedirect() bool
ACMEEmail() string
CFAPIToken() string
ACMEStaging() bool
AllowedPortsStart() uint16
AllowedPortsEnd() uint16
BufferSize() int
PprofEnabled() bool
PprofPort() string
Mode() types.ServerMode
GRPCAddress() string
GRPCPort() string
NodeToken() string
}
func Getenv(key, defaultValue string) string {
val := os.Getenv(key)
if val == "" {
val = defaultValue
func MustLoad() (Config, error) {
if err := loadEnvFile(); err != nil {
return nil, err
}
return val
cfg, err := parse()
if err != nil {
return nil, err
}
return cfg, nil
}
func GetBufferSize() int {
sizeStr := Getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(sizeStr)
if err != nil || size < 4096 || size > 1048576 {
return 32768
}
return size
}
func (c *config) Domain() string { return c.domain }
func (c *config) SSHPort() string { return c.sshPort }
func (c *config) HTTPPort() string { return c.httpPort }
func (c *config) HTTPSPort() string { return c.httpsPort }
func (c *config) TLSEnabled() bool { return c.tlsEnabled }
func (c *config) TLSRedirect() bool { return c.tlsRedirect }
func (c *config) ACMEEmail() string { return c.acmeEmail }
func (c *config) CFAPIToken() string { return c.cfAPIToken }
func (c *config) ACMEStaging() bool { return c.acmeStaging }
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
func (c *config) BufferSize() int { return c.bufferSize }
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
func (c *config) PprofPort() string { return c.pprofPort }
func (c *config) Mode() types.ServerMode { return c.mode }
func (c *config) GRPCAddress() string { return c.grpcAddress }
func (c *config) GRPCPort() string { return c.grpcPort }
func (c *config) NodeToken() string { return c.nodeToken }
+170
View File
@@ -0,0 +1,170 @@
package config
import (
"fmt"
"log"
"os"
"strconv"
"strings"
"tunnel_pls/types"
"github.com/joho/godotenv"
)
type config struct {
domain string
sshPort string
httpPort string
httpsPort string
tlsEnabled bool
tlsRedirect bool
acmeEmail string
cfAPIToken string
acmeStaging bool
allowedPortsStart uint16
allowedPortsEnd uint16
bufferSize int
pprofEnabled bool
pprofPort string
mode types.ServerMode
grpcAddress string
grpcPort string
nodeToken string
}
func parse() (*config, error) {
mode, err := parseMode()
if err != nil {
return nil, err
}
domain := getenv("DOMAIN", "localhost")
sshPort := getenv("PORT", "2200")
httpPort := getenv("HTTP_PORT", "8080")
httpsPort := getenv("HTTPS_PORT", "8443")
tlsEnabled := getenvBool("TLS_ENABLED", false)
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
acmeStaging := getenvBool("ACME_STAGING", false)
cfToken := getenv("CF_API_TOKEN", "")
if tlsEnabled && cfToken == "" {
return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled")
}
start, end, err := parseAllowedPorts()
if err != nil {
return nil, err
}
bufferSize := parseBufferSize()
pprofEnabled := getenvBool("PPROF_ENABLED", false)
pprofPort := getenv("PPROF_PORT", "6060")
grpcHost := getenv("GRPC_ADDRESS", "localhost")
grpcPort := getenv("GRPC_PORT", "8080")
nodeToken := getenv("NODE_TOKEN", "")
if mode == types.ServerModeNODE && nodeToken == "" {
return nil, fmt.Errorf("NODE_TOKEN is required in node mode")
}
return &config{
domain: domain,
sshPort: sshPort,
httpPort: httpPort,
httpsPort: httpsPort,
tlsEnabled: tlsEnabled,
tlsRedirect: tlsRedirect,
acmeEmail: acmeEmail,
cfAPIToken: cfToken,
acmeStaging: acmeStaging,
allowedPortsStart: start,
allowedPortsEnd: end,
bufferSize: bufferSize,
pprofEnabled: pprofEnabled,
pprofPort: pprofPort,
mode: mode,
grpcAddress: grpcHost,
grpcPort: grpcPort,
nodeToken: nodeToken,
}, nil
}
func loadEnvFile() error {
if _, err := os.Stat(".env"); err == nil {
return godotenv.Load(".env")
}
return nil
}
func parseMode() (types.ServerMode, error) {
switch strings.ToLower(getenv("MODE", "standalone")) {
case "standalone":
return types.ServerModeSTANDALONE, nil
case "node":
return types.ServerModeNODE, nil
default:
return 0, fmt.Errorf("invalid MODE value")
}
}
func parseAllowedPorts() (uint16, uint16, error) {
raw := getenv("ALLOWED_PORTS", "")
if raw == "" {
return 0, 0, nil
}
parts := strings.Split(raw, "-")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format")
}
start, err := strconv.ParseUint(parts[0], 10, 16)
if err != nil {
return 0, 0, err
}
end, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return 0, 0, err
}
return uint16(start), uint16(end), nil
}
func parseBufferSize() int {
raw := getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(raw)
if err != nil || size < 4096 || size > 1048576 {
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
return 4096
}
return size
}
func getenv(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}
func getenvBool(key string, def bool) bool {
val := os.Getenv(key)
if val == "" {
return def
}
return val == "true"
}
+9 -7
View File
@@ -29,6 +29,7 @@ type Client interface {
CheckServerHealth(ctx context.Context) error
}
type client struct {
config config.Config
conn *grpc.ClientConn
address string
sessionRegistry registry.Registry
@@ -37,7 +38,7 @@ type client struct {
closing bool
}
func New(address string, sessionRegistry registry.Registry) (Client, error) {
func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) {
var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
@@ -66,6 +67,7 @@ func New(address string, sessionRegistry registry.Registry) (Client, error) {
authorizeConnectionService := proto.NewUserServiceClient(conn)
return &client{
config: config,
conn: conn,
address: address,
sessionRegistry: sessionRegistry,
@@ -192,7 +194,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
oldSlug := slugEvent.GetOld()
newSlug := slugEvent.GetNew()
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP})
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP})
if err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
@@ -202,7 +204,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
}, "slug change failure response")
}
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil {
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
@@ -227,7 +229,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node
for _, ses := range sessions {
detail := ses.Detail()
details = append(details, &proto.Detail{
Node: config.Getenv("DOMAIN", "localhost"),
Node: c.config.Domain(),
ForwardingType: detail.ForwardingType,
Slug: detail.Slug,
UserId: detail.UserID,
@@ -299,11 +301,11 @@ func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E
func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
switch t {
case proto.TunnelType_HTTP:
return types.HTTP, nil
return types.TunnelTypeHTTP, nil
case proto.TunnelType_TCP:
return types.TCP, nil
return types.TunnelTypeTCP, nil
default:
return types.UNKNOWN, fmt.Errorf("unknown tunnel type received")
return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received")
}
}
+3 -3
View File
@@ -17,9 +17,9 @@ type RequestHeader interface {
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
Method() string
Path() string
Version() string
}
type requestHeader struct {
method string
+3 -3
View File
@@ -32,15 +32,15 @@ func (req *requestHeader) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeader) GetMethod() string {
func (req *requestHeader) Method() string {
return req.method
}
func (req *requestHeader) GetPath() string {
func (req *requestHeader) Path() string {
return req.path
}
func (req *requestHeader) GetVersion() string {
func (req *requestHeader) Version() string {
return req.version
}
+5 -5
View File
@@ -47,12 +47,12 @@ func (r *registry) Get(key Key) (session Session, err error) {
userID, ok := r.slugIndex[key]
if !ok {
return nil, fmt.Errorf("Session not found")
return nil, fmt.Errorf("session not found")
}
client, ok := r.byUser[userID][key]
if !ok {
return nil, fmt.Errorf("Session not found")
return nil, fmt.Errorf("session not found")
}
return client, nil
}
@@ -63,7 +63,7 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error
client, ok := r.byUser[user][key]
if !ok {
return nil, fmt.Errorf("Session not found")
return nil, fmt.Errorf("session not found")
}
return client, nil
}
@@ -73,7 +73,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
return fmt.Errorf("tunnel type cannot change")
}
if newKey.Type != types.HTTP {
if newKey.Type != types.TunnelTypeHTTP {
return fmt.Errorf("non http tunnel cannot change slug")
}
@@ -93,7 +93,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
}
client, ok := r.byUser[user][oldKey]
if !ok {
return fmt.Errorf("Session not found")
return fmt.Errorf("session not found")
}
delete(r.byUser[user], oldKey)
+2 -2
View File
@@ -12,9 +12,9 @@ type httpServer struct {
port string
}
func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &httpServer{
handler: newHTTPHandler(sessionRegistry, redirectTLS),
handler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
port: port,
}
}
+18 -14
View File
@@ -4,12 +4,12 @@ import (
"bufio"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware"
@@ -20,19 +20,21 @@ import (
)
type httpHandler struct {
domain string
sessionRegistry registry.Registry
redirectTLS bool
}
func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
return &httpHandler{
domain: domain,
sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS,
}
}
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
_, err := conn.Write([]byte(fmt.Sprintf("TunnelTypeHTTP/1.1 %d Moved Permanently\r\n", status) +
fmt.Sprintf("Location: %s", location) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
@@ -44,7 +46,7 @@ func (hh *httpHandler) redirect(conn net.Conn, status int, location string) erro
}
func (hh *httpHandler) badRequest(conn net.Conn) error {
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
if _, err := conn.Write([]byte("TunnelTypeHTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
return err
}
return nil
@@ -67,7 +69,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
}
if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")))
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
return
}
@@ -85,7 +87,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
defer func(hw stream.HTTP) {
err = hw.Close()
if err != nil {
log.Printf("Error closing HTTP stream: %v", err)
log.Printf("Error closing TunnelTypeHTTP stream: %v", err)
}
}(hw)
hh.forwardRequest(hw, reqhf, sshSession)
@@ -116,7 +118,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
}
_, err := conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"TunnelTypeHTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
@@ -133,7 +135,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
Type: types.TunnelTypeHTTP,
})
if err != nil {
return nil, err
@@ -143,17 +145,19 @@ func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
channel, err := hh.openForwardedChannel(hw, sshSession)
defer func() {
err = channel.Close()
if err != nil {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
if err != nil {
log.Printf("Failed to establish channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw)
return
}
defer func() {
err = channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
hh.setupMiddlewares(hw)
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
+6 -9
View File
@@ -9,28 +9,25 @@ import (
)
type https struct {
tlsConfig *tls.Config
httpHandler *httpHandler
domain string
port string
}
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport {
return &https{
httpHandler: newHTTPHandler(sessionRegistry, redirectTLS),
tlsConfig: tlsConfig,
httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
domain: domain,
port: port,
}
}
func (ht *https) Listen() (net.Listener, error) {
tlsConfig, err := NewTLSConfig(ht.domain)
if err != nil {
return nil, err
}
return tls.Listen("tcp", ":"+ht.port, tlsConfig)
return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig)
}
func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.port)
for {
+12 -33
View File
@@ -26,7 +26,8 @@ type TLSManager interface {
}
type tlsManager struct {
domain string
config config.Config
certPath string
keyPath string
storagePath string
@@ -42,7 +43,7 @@ type tlsManager struct {
var globalTLSManager TLSManager
var tlsManagerOnce sync.Once
func NewTLSConfig(domain string) (*tls.Config, error) {
func NewTLSConfig(config config.Config) (*tls.Config, error) {
var initErr error
tlsManagerOnce.Do(func() {
@@ -51,7 +52,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
storagePath := "certs/tls/certmagic"
tm := &tlsManager{
domain: domain,
config: config,
certPath: certPath,
keyPath: keyPath,
storagePath: storagePath,
@@ -66,14 +67,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
tm.useCertMagic = false
tm.startCertWatcher()
} else {
if !isACMEConfigComplete() {
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
return
}
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain, config.Domain)
if err := tm.initCertMagic(); err != nil {
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
return
@@ -91,11 +85,6 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
return globalTLSManager.getTLSConfig(), nil
}
func isACMEConfigComplete() bool {
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
return cfAPIToken != ""
}
func (tm *tlsManager) userCertsExistAndValid() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath)
@@ -106,7 +95,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
return false
}
return ValidateCertDomains(tm.certPath, tm.domain)
return ValidateCertDomains(tm.certPath, tm.config.Domain())
}
func ValidateCertDomains(certPath, domain string) bool {
@@ -206,15 +195,9 @@ func (tm *tlsManager) startCertWatcher() {
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
log.Printf("Certificate files changed, reloading...")
if !ValidateCertDomains(tm.certPath, tm.domain) {
if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
log.Printf("New certificates don't cover required domains")
if !isACMEConfigComplete() {
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
continue
}
log.Printf("Switching to CertMagic for automatic certificate management")
if err := tm.initCertMagic(); err != nil {
log.Printf("Failed to initialize CertMagic: %v", err)
continue
@@ -241,16 +224,12 @@ func (tm *tlsManager) initCertMagic() error {
return fmt.Errorf("failed to create cert storage directory: %w", err)
}
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain)
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
if cfAPIToken == "" {
if tm.config.CFAPIToken() == "" {
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
}
cfProvider := &cloudflare.Provider{
APIToken: cfAPIToken,
APIToken: tm.config.CFAPIToken(),
}
storage := &certmagic.FileStorage{Path: tm.storagePath}
@@ -266,7 +245,7 @@ func (tm *tlsManager) initCertMagic() error {
})
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
Email: acmeEmail,
Email: tm.config.ACMEEmail(),
Agreed: true,
DNS01Solver: &certmagic.DNS01Solver{
DNSManager: certmagic.DNSManager{
@@ -275,7 +254,7 @@ func (tm *tlsManager) initCertMagic() error {
},
})
if acmeStaging {
if tm.config.ACMEStaging() {
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
log.Printf("Using Let's Encrypt staging server")
} else {
@@ -286,7 +265,7 @@ func (tm *tlsManager) initCertMagic() error {
magic.Issuers = []certmagic.Issuer{acmeIssuer}
tm.magic = magic
domains := []string{tm.domain, "*." + tm.domain}
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
log.Printf("Requesting certificates for: %v", domains)
ctx := context.Background()