From 2bc20dd99154500e369eea94345e6e1d131c029d Mon Sep 17 00:00:00 2001 From: bagas Date: Wed, 21 Jan 2026 19:43:19 +0700 Subject: [PATCH] refactor(config): centralize env loading and enforce typed access - Centralize environment variable loading in config.MustLoad - Parse and validate all env vars once at initialization - Make config fields private and read-only - Remove public Getenv usage in favor of typed accessors - Improve validation and initialization order - Normalize enum naming to be idiomatic and avoid constant collisions --- internal/config/config.go | 76 +++++++++---- internal/config/loader.go | 170 +++++++++++++++++++++++++++++ internal/grpc/client/client.go | 16 +-- internal/http/header/header.go | 6 +- internal/http/header/request.go | 6 +- internal/registry/registry.go | 10 +- internal/transport/http.go | 4 +- internal/transport/httphandler.go | 32 +++--- internal/transport/https.go | 15 +-- internal/transport/tls.go | 45 ++------ main.go | 103 ++++++----------- server/server.go | 13 ++- session/forwarder/forwarder.go | 63 ++++++----- session/interaction/interaction.go | 21 ++-- session/interaction/model.go | 2 +- session/interaction/slug.go | 12 +- session/lifecycle/lifecycle.go | 14 +-- session/session.go | 36 +++--- types/types.go | 27 +++-- 19 files changed, 414 insertions(+), 257 deletions(-) create mode 100644 internal/config/loader.go diff --git a/internal/config/config.go b/internal/config/config.go index 45f1cc5..62e1aca 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,33 +1,63 @@ package config -import ( - "os" - "strconv" +import "tunnel_pls/types" - "github.com/joho/godotenv" -) +type Config interface { + Domain() string + SSHPort() string -func Load() error { - if _, err := os.Stat(".env"); err == nil { - return godotenv.Load(".env") - } - return nil + HTTPPort() string + HTTPSPort() string + + TLSEnabled() bool + TLSRedirect() bool + + ACMEEmail() string + CFAPIToken() string + ACMEStaging() bool + + AllowedPortsStart() uint16 + AllowedPortsEnd() uint16 + + BufferSize() int + + PprofEnabled() bool + PprofPort() string + + Mode() types.ServerMode + GRPCAddress() string + GRPCPort() string + NodeToken() string } -func Getenv(key, defaultValue string) string { - val := os.Getenv(key) - if val == "" { - val = defaultValue +func MustLoad() (Config, error) { + if err := loadEnvFile(); err != nil { + return nil, err } - return val + cfg, err := parse() + if err != nil { + return nil, err + } + + return cfg, nil } -func GetBufferSize() int { - sizeStr := Getenv("BUFFER_SIZE", "32768") - size, err := strconv.Atoi(sizeStr) - if err != nil || size < 4096 || size > 1048576 { - return 32768 - } - return size -} +func (c *config) Domain() string { return c.domain } +func (c *config) SSHPort() string { return c.sshPort } +func (c *config) HTTPPort() string { return c.httpPort } +func (c *config) HTTPSPort() string { return c.httpsPort } +func (c *config) TLSEnabled() bool { return c.tlsEnabled } +func (c *config) TLSRedirect() bool { return c.tlsRedirect } +func (c *config) ACMEEmail() string { return c.acmeEmail } +func (c *config) CFAPIToken() string { return c.cfAPIToken } +func (c *config) ACMEStaging() bool { return c.acmeStaging } +func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart } +func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd } +func (c *config) BufferSize() int { return c.bufferSize } +func (c *config) PprofEnabled() bool { return c.pprofEnabled } +func (c *config) PprofPort() string { return c.pprofPort } +func (c *config) Mode() types.ServerMode { return c.mode } +func (c *config) GRPCAddress() string { return c.grpcAddress } +func (c *config) GRPCPort() string { return c.grpcPort } +func (c *config) NodeToken() string { return c.nodeToken } diff --git a/internal/config/loader.go b/internal/config/loader.go new file mode 100644 index 0000000..cde9fd0 --- /dev/null +++ b/internal/config/loader.go @@ -0,0 +1,170 @@ +package config + +import ( + "fmt" + "log" + "os" + "strconv" + "strings" + "tunnel_pls/types" + + "github.com/joho/godotenv" +) + +type config struct { + domain string + sshPort string + + httpPort string + httpsPort string + + tlsEnabled bool + tlsRedirect bool + + acmeEmail string + cfAPIToken string + acmeStaging bool + + allowedPortsStart uint16 + allowedPortsEnd uint16 + + bufferSize int + + pprofEnabled bool + pprofPort string + + mode types.ServerMode + grpcAddress string + grpcPort string + nodeToken string +} + +func parse() (*config, error) { + mode, err := parseMode() + if err != nil { + return nil, err + } + + domain := getenv("DOMAIN", "localhost") + sshPort := getenv("PORT", "2200") + + httpPort := getenv("HTTP_PORT", "8080") + httpsPort := getenv("HTTPS_PORT", "8443") + + tlsEnabled := getenvBool("TLS_ENABLED", false) + tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false) + + acmeEmail := getenv("ACME_EMAIL", "admin@"+domain) + acmeStaging := getenvBool("ACME_STAGING", false) + + cfToken := getenv("CF_API_TOKEN", "") + if tlsEnabled && cfToken == "" { + return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled") + } + + start, end, err := parseAllowedPorts() + if err != nil { + return nil, err + } + + bufferSize := parseBufferSize() + + pprofEnabled := getenvBool("PPROF_ENABLED", false) + pprofPort := getenv("PPROF_PORT", "6060") + + grpcHost := getenv("GRPC_ADDRESS", "localhost") + grpcPort := getenv("GRPC_PORT", "8080") + + nodeToken := getenv("NODE_TOKEN", "") + if mode == types.ServerModeNODE && nodeToken == "" { + return nil, fmt.Errorf("NODE_TOKEN is required in node mode") + } + + return &config{ + domain: domain, + sshPort: sshPort, + httpPort: httpPort, + httpsPort: httpsPort, + tlsEnabled: tlsEnabled, + tlsRedirect: tlsRedirect, + acmeEmail: acmeEmail, + cfAPIToken: cfToken, + acmeStaging: acmeStaging, + allowedPortsStart: start, + allowedPortsEnd: end, + bufferSize: bufferSize, + pprofEnabled: pprofEnabled, + pprofPort: pprofPort, + mode: mode, + grpcAddress: grpcHost, + grpcPort: grpcPort, + nodeToken: nodeToken, + }, nil +} + +func loadEnvFile() error { + if _, err := os.Stat(".env"); err == nil { + return godotenv.Load(".env") + } + return nil +} + +func parseMode() (types.ServerMode, error) { + switch strings.ToLower(getenv("MODE", "standalone")) { + case "standalone": + return types.ServerModeSTANDALONE, nil + case "node": + return types.ServerModeNODE, nil + default: + return 0, fmt.Errorf("invalid MODE value") + } +} + +func parseAllowedPorts() (uint16, uint16, error) { + raw := getenv("ALLOWED_PORTS", "") + if raw == "" { + return 0, 0, nil + } + + parts := strings.Split(raw, "-") + if len(parts) != 2 { + return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format") + } + + start, err := strconv.ParseUint(parts[0], 10, 16) + if err != nil { + return 0, 0, err + } + + end, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return 0, 0, err + } + + return uint16(start), uint16(end), nil +} + +func parseBufferSize() int { + raw := getenv("BUFFER_SIZE", "32768") + size, err := strconv.Atoi(raw) + if err != nil || size < 4096 || size > 1048576 { + log.Println("Invalid BUFFER_SIZE, falling back to 4096") + return 4096 + } + return size +} + +func getenv(key, def string) string { + if v := os.Getenv(key); v != "" { + return v + } + return def +} + +func getenvBool(key string, def bool) bool { + val := os.Getenv(key) + if val == "" { + return def + } + return val == "true" +} diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index 0874afe..f2e0a1e 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -29,6 +29,7 @@ type Client interface { CheckServerHealth(ctx context.Context) error } type client struct { + config config.Config conn *grpc.ClientConn address string sessionRegistry registry.Registry @@ -37,7 +38,7 @@ type client struct { closing bool } -func New(address string, sessionRegistry registry.Registry) (Client, error) { +func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) { var opts []grpc.DialOption opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -66,6 +67,7 @@ func New(address string, sessionRegistry registry.Registry) (Client, error) { authorizeConnectionService := proto.NewUserServiceClient(conn) return &client{ + config: config, conn: conn, address: address, sessionRegistry: sessionRegistry, @@ -192,7 +194,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, oldSlug := slugEvent.GetOld() newSlug := slugEvent.GetNew() - userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP}) + userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}) if err != nil { return c.sendNode(subscribe, &proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, @@ -202,7 +204,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, }, "slug change failure response") } - if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil { + if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil { return c.sendNode(subscribe, &proto.Node{ Type: proto.EventType_SLUG_CHANGE_RESPONSE, Payload: &proto.Node_SlugEventResponse{ @@ -227,7 +229,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node for _, ses := range sessions { detail := ses.Detail() details = append(details, &proto.Detail{ - Node: config.Getenv("DOMAIN", "localhost"), + Node: c.config.Domain(), ForwardingType: detail.ForwardingType, Slug: detail.Slug, UserId: detail.UserID, @@ -299,11 +301,11 @@ func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) { switch t { case proto.TunnelType_HTTP: - return types.HTTP, nil + return types.TunnelTypeHTTP, nil case proto.TunnelType_TCP: - return types.TCP, nil + return types.TunnelTypeTCP, nil default: - return types.UNKNOWN, fmt.Errorf("unknown tunnel type received") + return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received") } } diff --git a/internal/http/header/header.go b/internal/http/header/header.go index a5e52b3..605f9ec 100644 --- a/internal/http/header/header.go +++ b/internal/http/header/header.go @@ -17,9 +17,9 @@ type RequestHeader interface { Set(key string, value string) Remove(key string) Finalize() []byte - GetMethod() string - GetPath() string - GetVersion() string + Method() string + Path() string + Version() string } type requestHeader struct { method string diff --git a/internal/http/header/request.go b/internal/http/header/request.go index b05f699..1fbe57a 100644 --- a/internal/http/header/request.go +++ b/internal/http/header/request.go @@ -32,15 +32,15 @@ func (req *requestHeader) Remove(key string) { delete(req.headers, key) } -func (req *requestHeader) GetMethod() string { +func (req *requestHeader) Method() string { return req.method } -func (req *requestHeader) GetPath() string { +func (req *requestHeader) Path() string { return req.path } -func (req *requestHeader) GetVersion() string { +func (req *requestHeader) Version() string { return req.version } diff --git a/internal/registry/registry.go b/internal/registry/registry.go index 22e590a..86898b0 100644 --- a/internal/registry/registry.go +++ b/internal/registry/registry.go @@ -47,12 +47,12 @@ func (r *registry) Get(key Key) (session Session, err error) { userID, ok := r.slugIndex[key] if !ok { - return nil, fmt.Errorf("Session not found") + return nil, fmt.Errorf("session not found") } client, ok := r.byUser[userID][key] if !ok { - return nil, fmt.Errorf("Session not found") + return nil, fmt.Errorf("session not found") } return client, nil } @@ -63,7 +63,7 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error client, ok := r.byUser[user][key] if !ok { - return nil, fmt.Errorf("Session not found") + return nil, fmt.Errorf("session not found") } return client, nil } @@ -73,7 +73,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { return fmt.Errorf("tunnel type cannot change") } - if newKey.Type != types.HTTP { + if newKey.Type != types.TunnelTypeHTTP { return fmt.Errorf("non http tunnel cannot change slug") } @@ -93,7 +93,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { } client, ok := r.byUser[user][oldKey] if !ok { - return fmt.Errorf("Session not found") + return fmt.Errorf("session not found") } delete(r.byUser[user], oldKey) diff --git a/internal/transport/http.go b/internal/transport/http.go index bf698ab..dd091c3 100644 --- a/internal/transport/http.go +++ b/internal/transport/http.go @@ -12,9 +12,9 @@ type httpServer struct { port string } -func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { +func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { return &httpServer{ - handler: newHTTPHandler(sessionRegistry, redirectTLS), + handler: newHTTPHandler(domain, sessionRegistry, redirectTLS), port: port, } } diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go index 0b22e48..b6f128d 100644 --- a/internal/transport/httphandler.go +++ b/internal/transport/httphandler.go @@ -4,12 +4,12 @@ import ( "bufio" "errors" "fmt" + "io" "log" "net" "net/http" "strings" "time" - "tunnel_pls/internal/config" "tunnel_pls/internal/http/header" "tunnel_pls/internal/http/stream" "tunnel_pls/internal/middleware" @@ -20,19 +20,21 @@ import ( ) type httpHandler struct { + domain string sessionRegistry registry.Registry redirectTLS bool } -func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { +func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { return &httpHandler{ + domain: domain, sessionRegistry: sessionRegistry, redirectTLS: redirectTLS, } } func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error { - _, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) + + _, err := conn.Write([]byte(fmt.Sprintf("TunnelTypeHTTP/1.1 %d Moved Permanently\r\n", status) + fmt.Sprintf("Location: %s", location) + "Content-Length: 0\r\n" + "Connection: close\r\n" + @@ -44,7 +46,7 @@ func (hh *httpHandler) redirect(conn net.Conn, status int, location string) erro } func (hh *httpHandler) badRequest(conn net.Conn) error { - if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { + if _, err := conn.Write([]byte("TunnelTypeHTTP/1.1 400 Bad Request\r\n\r\n")); err != nil { return err } return nil @@ -67,7 +69,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { } if hh.shouldRedirectToTLS(isTLS) { - _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost"))) + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain)) return } @@ -85,7 +87,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { defer func(hw stream.HTTP) { err = hw.Close() if err != nil { - log.Printf("Error closing HTTP stream: %v", err) + log.Printf("Error closing TunnelTypeHTTP stream: %v", err) } }(hw) hh.forwardRequest(hw, reqhf, sshSession) @@ -116,7 +118,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { } _, err := conn.Write([]byte( - "HTTP/1.1 200 OK\r\n" + + "TunnelTypeHTTP/1.1 200 OK\r\n" + "Content-Length: 0\r\n" + "Connection: close\r\n" + "Access-Control-Allow-Origin: *\r\n" + @@ -133,7 +135,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { func (hh *httpHandler) getSession(slug string) (registry.Session, error) { sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ Id: slug, - Type: types.HTTP, + Type: types.TunnelTypeHTTP, }) if err != nil { return nil, err @@ -143,17 +145,19 @@ func (hh *httpHandler) getSession(slug string) (registry.Session, error) { func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { channel, err := hh.openForwardedChannel(hw, sshSession) - defer func() { - err = channel.Close() - if err != nil { - log.Printf("Error closing forwarded channel: %v", err) - } - }() if err != nil { log.Printf("Failed to establish channel: %v", err) sshSession.Forwarder().WriteBadGatewayResponse(hw) return } + + defer func() { + err = channel.Close() + if err != nil && !errors.Is(err, io.EOF) { + log.Printf("Error closing forwarded channel: %v", err) + } + }() + hh.setupMiddlewares(hw) if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil { diff --git a/internal/transport/https.go b/internal/transport/https.go index 104aa15..88ffe27 100644 --- a/internal/transport/https.go +++ b/internal/transport/https.go @@ -9,28 +9,25 @@ import ( ) type https struct { + tlsConfig *tls.Config httpHandler *httpHandler domain string port string } -func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { +func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport { return &https{ - httpHandler: newHTTPHandler(sessionRegistry, redirectTLS), + tlsConfig: tlsConfig, + httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS), domain: domain, port: port, } } func (ht *https) Listen() (net.Listener, error) { - tlsConfig, err := NewTLSConfig(ht.domain) - if err != nil { - return nil, err - } - - return tls.Listen("tcp", ":"+ht.port, tlsConfig) - + return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig) } + func (ht *https) Serve(listener net.Listener) error { log.Printf("HTTPS server is starting on port %s", ht.port) for { diff --git a/internal/transport/tls.go b/internal/transport/tls.go index 0893b85..6824a54 100644 --- a/internal/transport/tls.go +++ b/internal/transport/tls.go @@ -26,7 +26,8 @@ type TLSManager interface { } type tlsManager struct { - domain string + config config.Config + certPath string keyPath string storagePath string @@ -42,7 +43,7 @@ type tlsManager struct { var globalTLSManager TLSManager var tlsManagerOnce sync.Once -func NewTLSConfig(domain string) (*tls.Config, error) { +func NewTLSConfig(config config.Config) (*tls.Config, error) { var initErr error tlsManagerOnce.Do(func() { @@ -51,7 +52,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) { storagePath := "certs/tls/certmagic" tm := &tlsManager{ - domain: domain, + config: config, certPath: certPath, keyPath: keyPath, storagePath: storagePath, @@ -66,14 +67,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) { tm.useCertMagic = false tm.startCertWatcher() } else { - if !isACMEConfigComplete() { - log.Printf("User certificates missing or invalid, and ACME configuration is incomplete") - log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable") - initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)") - return - } - - log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain) + log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain, config.Domain) if err := tm.initCertMagic(); err != nil { initErr = fmt.Errorf("failed to initialize CertMagic: %w", err) return @@ -91,11 +85,6 @@ func NewTLSConfig(domain string) (*tls.Config, error) { return globalTLSManager.getTLSConfig(), nil } -func isACMEConfigComplete() bool { - cfAPIToken := config.Getenv("CF_API_TOKEN", "") - return cfAPIToken != "" -} - func (tm *tlsManager) userCertsExistAndValid() bool { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { log.Printf("Certificate file not found: %s", tm.certPath) @@ -106,7 +95,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool { return false } - return ValidateCertDomains(tm.certPath, tm.domain) + return ValidateCertDomains(tm.certPath, tm.config.Domain()) } func ValidateCertDomains(certPath, domain string) bool { @@ -206,15 +195,9 @@ func (tm *tlsManager) startCertWatcher() { if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) { log.Printf("Certificate files changed, reloading...") - if !ValidateCertDomains(tm.certPath, tm.domain) { + if !ValidateCertDomains(tm.certPath, tm.config.Domain()) { log.Printf("New certificates don't cover required domains") - if !isACMEConfigComplete() { - log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)") - continue - } - - log.Printf("Switching to CertMagic for automatic certificate management") if err := tm.initCertMagic(); err != nil { log.Printf("Failed to initialize CertMagic: %v", err) continue @@ -241,16 +224,12 @@ func (tm *tlsManager) initCertMagic() error { return fmt.Errorf("failed to create cert storage directory: %w", err) } - acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain) - cfAPIToken := config.Getenv("CF_API_TOKEN", "") - acmeStaging := config.Getenv("ACME_STAGING", "false") == "true" - - if cfAPIToken == "" { + if tm.config.CFAPIToken() == "" { return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation") } cfProvider := &cloudflare.Provider{ - APIToken: cfAPIToken, + APIToken: tm.config.CFAPIToken(), } storage := &certmagic.FileStorage{Path: tm.storagePath} @@ -266,7 +245,7 @@ func (tm *tlsManager) initCertMagic() error { }) acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{ - Email: acmeEmail, + Email: tm.config.ACMEEmail(), Agreed: true, DNS01Solver: &certmagic.DNS01Solver{ DNSManager: certmagic.DNSManager{ @@ -275,7 +254,7 @@ func (tm *tlsManager) initCertMagic() error { }, }) - if acmeStaging { + if tm.config.ACMEStaging() { acmeIssuer.CA = certmagic.LetsEncryptStagingCA log.Printf("Using Let's Encrypt staging server") } else { @@ -286,7 +265,7 @@ func (tm *tlsManager) initCertMagic() error { magic.Issuers = []certmagic.Issuer{acmeIssuer} tm.magic = magic - domains := []string{tm.domain, "*." + tm.domain} + domains := []string{tm.config.Domain(), "*." + tm.config.Domain()} log.Printf("Requesting certificates for: %v", domains) ctx := context.Background() diff --git a/main.go b/main.go index 6510932..f897b46 100644 --- a/main.go +++ b/main.go @@ -9,8 +9,6 @@ import ( _ "net/http/pprof" "os" "os/signal" - "strconv" - "strings" "syscall" "time" "tunnel_pls/internal/config" @@ -21,6 +19,7 @@ import ( "tunnel_pls/internal/transport" "tunnel_pls/internal/version" "tunnel_pls/server" + "tunnel_pls/types" "golang.org/x/crypto/ssh" ) @@ -36,27 +35,12 @@ func main() { log.Printf("Starting %s", version.GetVersion()) - err := config.Load() + conf, err := config.MustLoad() if err != nil { log.Fatalf("Failed to load configuration: %s", err) return } - mode := strings.ToLower(config.Getenv("MODE", "standalone")) - isNodeMode := mode == "node" - - pprofEnabled := config.Getenv("PPROF_ENABLED", "false") - if pprofEnabled == "true" { - pprofPort := config.Getenv("PPROF_PORT", "6060") - go func() { - pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) - log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) - if err = http.ListenAndServe(pprofAddr, nil); err != nil { - log.Printf("pprof server error: %v", err) - } - }() - } - sshConfig := &ssh.ServerConfig{ NoClientAuth: true, ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()), @@ -88,16 +72,11 @@ func main() { signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) var grpcClient client.Client - if isNodeMode { - grpcHost := config.Getenv("GRPC_ADDRESS", "localhost") - grpcPort := config.Getenv("GRPC_PORT", "8080") - grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort) - nodeToken := config.Getenv("NODE_TOKEN", "") - if nodeToken == "" { - log.Fatalf("NODE_TOKEN is required in node mode") - } - grpcClient, err = client.New(grpcAddr, sessionRegistry) + if conf.Mode() == types.ServerModeNODE { + grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort()) + + grpcClient, err = client.New(conf, grpcAddr, sessionRegistry) if err != nil { log.Fatalf("failed to create grpc client: %v", err) } @@ -110,46 +89,15 @@ func main() { healthCancel() go func() { - identity := config.Getenv("DOMAIN", "localhost") - if err = grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil { + if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil { errChan <- fmt.Errorf("failed to subscribe to events: %w", err) } }() } - portManager := port.New() - rawRange := config.Getenv("ALLOWED_PORTS", "") - if rawRange != "" { - splitRange := strings.Split(rawRange, "-") - if len(splitRange) == 2 { - var start, end uint64 - start, err = strconv.ParseUint(splitRange[0], 10, 16) - if err != nil { - log.Fatalf("Failed to parse start port: %s", err) - } - - end, err = strconv.ParseUint(splitRange[1], 10, 16) - if err != nil { - log.Fatalf("Failed to parse end port: %s", err) - } - - if err = portManager.AddRange(uint16(start), uint16(end)); err != nil { - log.Fatalf("Failed to add port range: %s", err) - } - log.Printf("PortRegistry range configured: %d-%d", start, end) - } else { - log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange) - } - } - - tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true" - redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" - go func() { - httpPort := config.Getenv("HTTP_PORT", "8080") - var httpListener net.Listener - httpserver := transport.NewHTTPServer(httpPort, sessionRegistry, redirectTLS) + httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect()) httpListener, err = httpserver.Listen() if err != nil { errChan <- fmt.Errorf("failed to start http server: %w", err) @@ -162,37 +110,52 @@ func main() { } }() - if tlsEnabled { + if conf.TLSEnabled() { go func() { - httpsPort := config.Getenv("HTTPS_PORT", "8443") - domain := config.Getenv("DOMAIN", "localhost") - - var httpListener net.Listener - httpserver := transport.NewHTTPSServer(domain, httpsPort, sessionRegistry, redirectTLS) - httpListener, err = httpserver.Listen() + var httpsListener net.Listener + tlsConfig, _ := transport.NewTLSConfig(conf) + httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig) + httpsListener, err = httpsServer.Listen() if err != nil { errChan <- fmt.Errorf("failed to start http server: %w", err) return } - err = httpserver.Serve(httpListener) + err = httpsServer.Serve(httpsListener) if err != nil { errChan <- fmt.Errorf("error when serving http server: %w", err) return } }() } + + portManager := port.New() + err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd()) + if err != nil { + log.Fatalf("Failed to initialize port manager: %s", err) + return + } var app server.Server go func() { - sshPort := config.Getenv("PORT", "2200") - app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager, sshPort) + app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort()) if err != nil { errChan <- fmt.Errorf("failed to start server: %s", err) return } app.Start() + }() + if conf.PprofEnabled() { + go func() { + pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort()) + log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) + if err = http.ListenAndServe(pprofAddr, nil); err != nil { + log.Printf("pprof server error: %v", err) + } + }() + } + select { case err = <-errChan: log.Printf("error happen : %s", err) diff --git a/server/server.go b/server/server.go index 185d051..f47c579 100644 --- a/server/server.go +++ b/server/server.go @@ -7,6 +7,7 @@ import ( "log" "net" "time" + "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/port" "tunnel_pls/internal/registry" @@ -20,24 +21,26 @@ type Server interface { Close() error } type server struct { + config config.Config sshPort string sshListener net.Listener - config *ssh.ServerConfig + sshConfig *ssh.ServerConfig grpcClient client.Client sessionRegistry registry.Registry portRegistry port.Port } -func New(sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { +func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort)) if err != nil { return nil, err } return &server{ + config: config, sshPort: sshPort, sshListener: listener, - config: sshConfig, + sshConfig: sshConfig, grpcClient: grpcClient, sessionRegistry: sessionRegistry, portRegistry: portRegistry, @@ -66,7 +69,7 @@ func (s *server) Close() error { } func (s *server) handleConnection(conn net.Conn) { - sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) + sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig) if err != nil { log.Printf("failed to establish SSH connection: %v", err) err = conn.Close() @@ -92,7 +95,7 @@ func (s *server) handleConnection(conn net.Conn) { cancel() } log.Println("SSH connection established:", sshConn.User()) - sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) + sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) err = sshSession.Start() if err != nil { log.Printf("SSH session ended with error: %v", err) diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index ff2abde..c602565 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -18,37 +18,6 @@ import ( "golang.org/x/crypto/ssh" ) -var bufferPool = sync.Pool{ - New: func() interface{} { - bufSize := config.GetBufferSize() - return make([]byte, bufSize) - }, -} - -func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { - buf := bufferPool.Get().([]byte) - defer bufferPool.Put(buf) - return io.CopyBuffer(dst, src, buf) -} - -type forwarder struct { - listener net.Listener - tunnelType types.TunnelType - forwardedPort uint16 - slug slug.Slug - conn ssh.Conn -} - -func New(slug slug.Slug, conn ssh.Conn) Forwarder { - return &forwarder{ - listener: nil, - tunnelType: types.UNKNOWN, - forwardedPort: 0, - slug: slug, - conn: conn, - } -} - type Forwarder interface { SetType(tunnelType types.TunnelType) SetForwardedPort(port uint16) @@ -62,6 +31,36 @@ type Forwarder interface { WriteBadGatewayResponse(dst io.Writer) Close() error } +type forwarder struct { + listener net.Listener + tunnelType types.TunnelType + forwardedPort uint16 + slug slug.Slug + conn ssh.Conn + bufferPool sync.Pool +} + +func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder { + return &forwarder{ + listener: nil, + tunnelType: types.TunnelTypeUNKNOWN, + forwardedPort: 0, + slug: slug, + conn: conn, + bufferPool: sync.Pool{ + New: func() interface{} { + bufSize := config.BufferSize() + return make([]byte, bufSize) + }, + }, + } +} + +func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { + buf := f.bufferPool.Get().([]byte) + defer f.bufferPool.Put(buf) + return io.CopyBuffer(dst, src, buf) +} func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { type channelResult struct { @@ -107,7 +106,7 @@ func closeWriter(w io.Writer) error { func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error { var errs []error - _, err := copyWithBuffer(dst, src) + _, err := f.copyWithBuffer(dst, src) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err)) } diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 3c02dae..5f68102 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -18,9 +18,9 @@ import ( ) type Interaction interface { - Mode() types.Mode + Mode() types.InteractiveMode SetChannel(channel ssh.Channel) - SetMode(m types.Mode) + SetMode(m types.InteractiveMode) SetWH(w, h int) Start() Redraw() @@ -39,6 +39,7 @@ type Forwarder interface { type CloseFunc func() error type interaction struct { + config config.Config channel ssh.Channel slug slug.Slug forwarder Forwarder @@ -48,14 +49,14 @@ type interaction struct { program *tea.Program ctx context.Context cancel context.CancelFunc - mode types.Mode + mode types.InteractiveMode } -func (i *interaction) SetMode(m types.Mode) { +func (i *interaction) SetMode(m types.InteractiveMode) { i.mode = m } -func (i *interaction) Mode() types.Mode { +func (i *interaction) Mode() types.InteractiveMode { return i.mode } @@ -75,9 +76,10 @@ func (i *interaction) SetWH(w, h int) { } } -func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { +func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { ctx, cancel := context.WithCancel(context.Background()) return &interaction{ + config: config, channel: nil, slug: slug, forwarder: forwarder, @@ -174,14 +176,13 @@ func (m *model) View() string { } func (i *interaction) Start() { - if i.mode == types.HEADLESS { + if i.mode == types.InteractiveModeHEADLESS { return } lipgloss.SetColorProfile(termenv.TrueColor) - domain := config.Getenv("DOMAIN", "localhost") protocol := "http" - if config.Getenv("TLS_ENABLED", "false") == "true" { + if i.config.TLSEnabled() { protocol = "https" } @@ -209,7 +210,7 @@ func (i *interaction) Start() { ti.Width = 50 m := &model{ - domain: domain, + domain: i.config.Domain(), protocol: protocol, tunnelType: tunnelType, port: port, diff --git a/session/interaction/model.go b/session/interaction/model.go index 24b4d26..189b0a1 100644 --- a/session/interaction/model.go +++ b/session/interaction/model.go @@ -41,7 +41,7 @@ type model struct { } func (m *model) getTunnelURL() string { - if m.tunnelType == types.HTTP { + if m.tunnelType == types.TunnelTypeHTTP { return buildURL(m.protocol, m.interaction.slug.String(), m.domain) } return fmt.Sprintf("tcp://%s:%d", m.domain, m.port) diff --git a/session/interaction/slug.go b/session/interaction/slug.go index 6c6a97b..08c7c7d 100644 --- a/session/interaction/slug.go +++ b/session/interaction/slug.go @@ -15,7 +15,7 @@ import ( func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { var cmd tea.Cmd - if m.tunnelType != types.HTTP { + if m.tunnelType != types.TunnelTypeHTTP { m.editingSlug = false m.slugError = "" return m, tea.Batch(tea.ClearScreen, textinput.Blink) @@ -30,10 +30,10 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { inputValue := m.slugInput.Value() if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{ Id: m.interaction.slug.String(), - Type: types.HTTP, + Type: types.TunnelTypeHTTP, }, types.SessionKey{ Id: inputValue, - Type: types.HTTP, + Type: types.TunnelTypeHTTP, }); err != nil { m.slugError = err.Error() return m, nil @@ -130,7 +130,7 @@ func (m *model) slugView() string { b.WriteString(titleStyle.Render(title)) b.WriteString("\n\n") - if m.tunnelType != types.HTTP { + if m.tunnelType != types.TunnelTypeHTTP { warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60) warningBoxStyle := lipgloss.NewStyle(). Foreground(lipgloss.Color("#FFA500")). @@ -145,9 +145,9 @@ func (m *model) slugView() string { var warningText string if isVeryCompact { - warningText = "⚠️ TCP tunnels don't support custom subdomains." + warningText = "⚠️ TunnelTypeTCP tunnels don't support custom subdomains." } else { - warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization." + warningText = "⚠️ TunnelTypeTCP tunnels cannot have custom subdomains. Only TunnelTypeHTTP/HTTPS tunnels support subdomain customization." } b.WriteString(warningBoxStyle.Render(warningText)) b.WriteString("\n\n") diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 7a2fcaf..e4ce44f 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -24,7 +24,7 @@ type SessionRegistry interface { } type lifecycle struct { - status types.Status + status types.SessionStatus conn ssh.Conn channel ssh.Channel forwarder Forwarder @@ -37,7 +37,7 @@ type lifecycle struct { func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle { return &lifecycle{ - status: types.INITIALIZING, + status: types.SessionStatusINITIALIZING, conn: conn, channel: nil, forwarder: forwarder, @@ -54,7 +54,7 @@ type Lifecycle interface { PortRegistry() portUtil.Port User() string SetChannel(channel ssh.Channel) - SetStatus(status types.Status) + SetStatus(status types.SessionStatus) IsActive() bool StartedAt() time.Time Close() error @@ -74,9 +74,9 @@ func (l *lifecycle) SetChannel(channel ssh.Channel) { func (l *lifecycle) Connection() ssh.Conn { return l.conn } -func (l *lifecycle) SetStatus(status types.Status) { +func (l *lifecycle) SetStatus(status types.SessionStatus) { l.status = status - if status == types.RUNNING && l.startedAt.IsZero() { + if status == types.SessionStatusRUNNING && l.startedAt.IsZero() { l.startedAt = time.Now() } } @@ -112,7 +112,7 @@ func (l *lifecycle) Close() error { } l.sessionRegistry.Remove(key) - if tunnelType == types.TCP { + if tunnelType == types.TunnelTypeTCP { if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { firstErr = err } @@ -122,7 +122,7 @@ func (l *lifecycle) Close() error { } func (l *lifecycle) IsActive() bool { - return l.status == types.RUNNING + return l.status == types.SessionStatusRUNNING } func (l *lifecycle) StartedAt() time.Time { diff --git a/session/session.go b/session/session.go index d113084..65bbc54 100644 --- a/session/session.go +++ b/session/session.go @@ -37,6 +37,7 @@ type Session interface { } type session struct { + config config.Config initialReq <-chan *ssh.Request sshChan <-chan ssh.NewChannel lifecycle lifecycle.Lifecycle @@ -48,13 +49,14 @@ type session struct { var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { +func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { slugManager := slug.New() - forwarderManager := forwarder.New(slugManager, conn) + forwarderManager := forwarder.New(config, slugManager, conn) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) - interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) + interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) return &session{ + config: config, initialReq: initialReq, sshChan: sshChan, lifecycle: lifecycleManager, @@ -83,12 +85,12 @@ func (s *session) Slug() slug.Slug { func (s *session) Detail() *types.Detail { tunnelTypeMap := map[types.TunnelType]string{ - types.HTTP: "HTTP", - types.TCP: "TCP", + types.TunnelTypeHTTP: "TunnelTypeHTTP", + types.TunnelTypeTCP: "TunnelTypeTCP", } tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()] if !ok { - tunnelType = "UNKNOWN" + tunnelType = "TunnelTypeUNKNOWN" } return &types.Detail{ @@ -131,7 +133,7 @@ func (s *session) setupSessionMode() error { } return s.setupInteractiveMode(channel) case <-time.After(500 * time.Millisecond): - s.interaction.SetMode(types.HEADLESS) + s.interaction.SetMode(types.InteractiveModeHEADLESS) return nil } } @@ -152,13 +154,13 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error { s.lifecycle.SetChannel(ch) s.interaction.SetChannel(ch) - s.interaction.SetMode(types.INTERACTIVE) + s.interaction.SetMode(types.InteractiveModeINTERACTIVE) return nil } func (s *session) handleMissingForwardRequest() error { - err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) + err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain, s.config.SSHPort)) if err != nil { return err } @@ -169,8 +171,8 @@ func (s *session) handleMissingForwardRequest() error { } func (s *session) shouldRejectUnauthorized() bool { - return s.interaction.Mode() == types.HEADLESS && - config.Getenv("MODE", "standalone") == "standalone" && + return s.interaction.Mode() == types.InteractiveModeHEADLESS && + s.config.Mode() == types.ServerModeSTANDALONE && s.lifecycle.User() == "UNAUTHORIZED" } @@ -318,7 +320,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen s.forwarder.SetType(tunnelType) s.forwarder.SetForwardedPort(portToBind) s.slug.Set(slug) - s.lifecycle.SetStatus(types.RUNNING) + s.lifecycle.SetStatus(types.SessionStatusRUNNING) if listener != nil { s.forwarder.SetListener(listener) @@ -348,12 +350,12 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { if err != nil { return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err)) } - key := types.SessionKey{Id: randomString, Type: types.HTTP} + key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP} if !s.registry.Register(key, s) { return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString)) } - err = s.finalizeForwarding(req, portToBind, nil, types.HTTP, key.Id) + err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id) if err != nil { return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err)) } @@ -371,12 +373,12 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) } - key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} + key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP} if !s.registry.Register(key, s) { - return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TCP client with id: %s", key.Id)) + return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id)) } - err = s.finalizeForwarding(req, portToBind, listener, types.TCP, key.Id) + err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id) if err != nil { return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err)) } diff --git a/types/types.go b/types/types.go index b91dffb..8e7d1e5 100644 --- a/types/types.go +++ b/types/types.go @@ -2,26 +2,33 @@ package types import "time" -type Status int +type SessionStatus int const ( - INITIALIZING Status = iota - RUNNING + SessionStatusINITIALIZING SessionStatus = iota + SessionStatusRUNNING ) -type Mode int +type InteractiveMode int const ( - INTERACTIVE Mode = iota - HEADLESS + InteractiveModeINTERACTIVE InteractiveMode = iota + 1 + InteractiveModeHEADLESS ) type TunnelType int const ( - UNKNOWN TunnelType = iota - HTTP - TCP + TunnelTypeUNKNOWN TunnelType = iota + TunnelTypeHTTP + TunnelTypeTCP +) + +type ServerMode int + +const ( + ServerModeSTANDALONE = iota + 1 + ServerModeNODE ) type SessionKey struct { @@ -37,7 +44,7 @@ type Detail struct { StartedAt time.Time `json:"started_at,omitempty"` } -var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + +var BadGatewayResponse = []byte("TunnelTypeHTTP/1.1 502 Bad Gateway\r\n" + "Content-Length: 11\r\n" + "Content-Type: text/plain\r\n\r\n" + "Bad Gateway")