diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..21cb4fb --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,35 @@ +package config + +import ( + "log" + "os" + "strconv" + + "github.com/joho/godotenv" +) + +func init() { + if _, err := os.Stat(".env"); err == nil { + if err := godotenv.Load(".env"); err != nil { + log.Printf("Warning: Failed to load .env file: %s", err) + } + } +} + +func Getenv(key, defaultValue string) string { + val := os.Getenv(key) + if val == "" { + val = defaultValue + } + + return val +} + +func GetBufferSize() int { + sizeStr := Getenv("BUFFER_SIZE", "32768") + size, err := strconv.Atoi(sizeStr) + if err != nil || size < 4096 || size > 1048576 { + return 32768 + } + return size +} diff --git a/utils/utils.go b/internal/key/key.go similarity index 58% rename from utils/utils.go rename to internal/key/key.go index 52637be..659abe3 100644 --- a/utils/utils.go +++ b/internal/key/key.go @@ -1,4 +1,4 @@ -package utils +package key import ( "crypto/rand" @@ -6,54 +6,12 @@ import ( "crypto/x509" "encoding/pem" "log" - mathrand "math/rand" "os" "path/filepath" - "strconv" - "strings" - "time" - "github.com/joho/godotenv" "golang.org/x/crypto/ssh" ) -func init() { - if _, err := os.Stat(".env"); err == nil { - if err := godotenv.Load(".env"); err != nil { - log.Printf("Warning: Failed to load .env file: %s", err) - } - } -} - -func GenerateRandomString(length int) string { - const charset = "abcdefghijklmnopqrstuvwxyz" - seededRand := mathrand.New(mathrand.NewSource(time.Now().UnixNano() + int64(mathrand.Intn(9999)))) - var result strings.Builder - for i := 0; i < length; i++ { - randomIndex := seededRand.Intn(len(charset)) - result.WriteString(string(charset[randomIndex])) - } - return result.String() -} - -func Getenv(key, defaultValue string) string { - val := os.Getenv(key) - if val == "" { - val = defaultValue - } - - return val -} - -func GetBufferSize() int { - sizeStr := Getenv("BUFFER_SIZE", "32768") - size, err := strconv.Atoi(sizeStr) - if err != nil || size < 4096 || size > 1048576 { - return 32768 - } - return size -} - func GenerateSSHKeyIfNotExist(keyPath string) error { if _, err := os.Stat(keyPath); err == nil { log.Printf("SSH key already exists at %s", keyPath) diff --git a/internal/port/port.go b/internal/port/port.go index 68e185a..8eb17b9 100644 --- a/internal/port/port.go +++ b/internal/port/port.go @@ -6,7 +6,7 @@ import ( "strconv" "strings" "sync" - "tunnel_pls/utils" + "tunnel_pls/internal/config" ) type Manager interface { @@ -28,7 +28,7 @@ var Default Manager = &manager{ } func init() { - rawRange := utils.Getenv("ALLOWED_PORTS", "") + rawRange := config.Getenv("ALLOWED_PORTS", "") if rawRange == "" { return } diff --git a/internal/random/random.go b/internal/random/random.go new file mode 100644 index 0000000..a67c9bf --- /dev/null +++ b/internal/random/random.go @@ -0,0 +1,18 @@ +package random + +import ( + mathrand "math/rand" + "strings" + "time" +) + +func GenerateRandomString(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyz" + seededRand := mathrand.New(mathrand.NewSource(time.Now().UnixNano() + int64(mathrand.Intn(9999)))) + var result strings.Builder + for i := 0; i < length; i++ { + randomIndex := seededRand.Intn(len(charset)) + result.WriteString(string(charset[randomIndex])) + } + return result.String() +} diff --git a/main.go b/main.go index 4b5c496..8198f92 100644 --- a/main.go +++ b/main.go @@ -6,8 +6,9 @@ import ( "net/http" _ "net/http/pprof" "os" + "tunnel_pls/internal/config" + "tunnel_pls/internal/key" "tunnel_pls/server" - "tunnel_pls/utils" "tunnel_pls/version" "golang.org/x/crypto/ssh" @@ -24,9 +25,9 @@ func main() { log.Printf("Starting %s", version.GetVersion()) - pprofEnabled := utils.Getenv("PPROF_ENABLED", "false") + pprofEnabled := config.Getenv("PPROF_ENABLED", "false") if pprofEnabled == "true" { - pprofPort := utils.Getenv("PPROF_PORT", "6060") + pprofPort := config.Getenv("PPROF_PORT", "6060") go func() { pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) @@ -42,7 +43,7 @@ func main() { } sshKeyPath := "certs/ssh/id_rsa" - if err := utils.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { + if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { log.Fatalf("Failed to generate SSH key: %s", err) } diff --git a/server/http.go b/server/http.go index 4cdaaf5..9c2e506 100644 --- a/server/http.go +++ b/server/http.go @@ -11,8 +11,8 @@ import ( "regexp" "strings" "time" + "tunnel_pls/internal/config" "tunnel_pls/session" - "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) @@ -231,12 +231,12 @@ func (cw *customWriter) AddInteraction(interaction Interaction) { var redirectTLS = false func NewHTTPServer() error { - httpPort := utils.Getenv("HTTP_PORT", "8080") + httpPort := config.Getenv("HTTP_PORT", "8080") listener, err := net.Listen("tcp", ":"+httpPort) if err != nil { return errors.New("Error listening: " + err.Error()) } - if utils.Getenv("TLS_ENABLED", "false") == "true" && utils.Getenv("TLS_REDIRECT", "false") == "true" { + if config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" { redirectTLS = true } go func() { @@ -288,7 +288,7 @@ func Handler(conn net.Conn) { if redirectTLS { _, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" + - fmt.Sprintf("Location: https://%s.%s/\r\n", slug, utils.Getenv("DOMAIN", "localhost")) + + fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")) + "Content-Length: 0\r\n" + "Connection: close\r\n" + "\r\n")) diff --git a/server/https.go b/server/https.go index 50342d2..55849cf 100644 --- a/server/https.go +++ b/server/https.go @@ -8,13 +8,13 @@ import ( "log" "net" "strings" + "tunnel_pls/internal/config" "tunnel_pls/session" - "tunnel_pls/utils" ) func NewHTTPSServer() error { - domain := utils.Getenv("DOMAIN", "localhost") - httpsPort := utils.Getenv("HTTPS_PORT", "8443") + domain := config.Getenv("DOMAIN", "localhost") + httpsPort := config.Getenv("HTTPS_PORT", "8443") tlsConfig, err := NewTLSConfig(domain) if err != nil { diff --git a/server/server.go b/server/server.go index 7f03f7c..531b3d7 100644 --- a/server/server.go +++ b/server/server.go @@ -5,7 +5,7 @@ import ( "log" "net" "net/http" - "tunnel_pls/utils" + "tunnel_pls/internal/config" "golang.org/x/crypto/ssh" ) @@ -28,13 +28,13 @@ func (s *Server) GetHttpServer() *http.Server { return s.httpServer } -func NewServer(config *ssh.ServerConfig) *Server { - listener, err := net.Listen("tcp", fmt.Sprintf(":%s", utils.Getenv("PORT", "2200"))) +func NewServer(sshConfig *ssh.ServerConfig) *Server { + listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) if err != nil { log.Fatalf("failed to listen on port 2200: %v", err) return nil } - if utils.Getenv("TLS_ENABLED", "false") == "true" { + if config.Getenv("TLS_ENABLED", "false") == "true" { err = NewHTTPSServer() if err != nil { log.Fatalf("failed to start https server: %v", err) @@ -46,7 +46,7 @@ func NewServer(config *ssh.ServerConfig) *Server { } return &Server{ conn: &listener, - config: config, + config: sshConfig, } } diff --git a/server/tls.go b/server/tls.go index bc69150..5933026 100644 --- a/server/tls.go +++ b/server/tls.go @@ -10,7 +10,7 @@ import ( "os" "sync" "time" - "tunnel_pls/utils" + "tunnel_pls/internal/config" "github.com/caddyserver/certmagic" "github.com/libdns/cloudflare" @@ -92,7 +92,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) { } func isACMEConfigComplete() bool { - cfAPIToken := utils.Getenv("CF_API_TOKEN", "") + cfAPIToken := config.Getenv("CF_API_TOKEN", "") return cfAPIToken != "" } @@ -241,9 +241,9 @@ func (tm *tlsManager) initCertMagic() error { return fmt.Errorf("failed to create cert storage directory: %w", err) } - acmeEmail := utils.Getenv("ACME_EMAIL", "admin@"+tm.domain) - cfAPIToken := utils.Getenv("CF_API_TOKEN", "") - acmeStaging := utils.Getenv("ACME_STAGING", "false") == "true" + acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain) + cfAPIToken := config.Getenv("CF_API_TOKEN", "") + acmeStaging := config.Getenv("ACME_STAGING", "false") == "true" if cfAPIToken == "" { return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation") diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 250a005..4558533 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -10,16 +10,16 @@ import ( "strconv" "sync" "time" + "tunnel_pls/internal/config" "tunnel_pls/session/slug" "tunnel_pls/types" - "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) var bufferPool = sync.Pool{ New: func() interface{} { - bufSize := utils.GetBufferSize() + bufSize := config.GetBufferSize() return make([]byte, bufSize) }, } diff --git a/session/handler.go b/session/handler.go index e9b7fce..d4c808c 100644 --- a/session/handler.go +++ b/session/handler.go @@ -7,10 +7,9 @@ import ( "log" "net" portUtil "tunnel_pls/internal/port" + "tunnel_pls/internal/random" "tunnel_pls/types" - "tunnel_pls/utils" - "golang.org/x/crypto/ssh" ) @@ -276,7 +275,7 @@ func generateUniqueSlug() string { maxAttempts := 5 for i := 0; i < maxAttempts; i++ { - slug := utils.GenerateRandomString(20) + slug := random.GenerateRandomString(20) clientsMutex.RLock() _, exists := Clients[slug] diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 9dcac7d..93d6060 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -6,9 +6,10 @@ import ( "log" "strings" "time" + "tunnel_pls/internal/config" + "tunnel_pls/internal/random" "tunnel_pls/session/slug" "tunnel_pls/types" - "tunnel_pls/utils" "github.com/charmbracelet/bubbles/help" "github.com/charmbracelet/bubbles/key" @@ -722,9 +723,9 @@ func (m model) View() string { func (i *Interaction) Start() { lipgloss.SetColorProfile(termenv.TrueColor) - domain := utils.Getenv("DOMAIN", "localhost") + domain := config.Getenv("DOMAIN", "localhost") protocol := "http" - if utils.Getenv("TLS_ENABLED", "false") == "true" { + if config.Getenv("TLS_ENABLED", "false") == "true" { protocol = "https" } @@ -811,7 +812,7 @@ func buildURL(protocol, subdomain, domain string) string { } func generateRandomSubdomain() string { - return utils.GenerateRandomString(20) + return random.GenerateRandomString(20) } func isValidSlug(slug string) bool { diff --git a/session/session.go b/session/session.go index 9c515db..db5fc27 100644 --- a/session/session.go +++ b/session/session.go @@ -4,11 +4,11 @@ import ( "log" "sync" "time" + "tunnel_pls/internal/config" "tunnel_pls/session/forwarder" "tunnel_pls/session/interaction" "tunnel_pls/session/lifecycle" "tunnel_pls/session/slug" - "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) @@ -79,7 +79,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan tcpipReq := session.waitForTCPIPForward(forwardingReq) if tcpipReq == nil { - log.Printf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200")) + log.Printf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")) if err := session.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) }