7 Commits

Author SHA1 Message Date
78b7b894d9 chore(deps): update actions/checkout action to v6
SonarQube Scan / SonarQube Trigger (push) Successful in 49s
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 48s
2026-01-22 16:01:02 +00:00
d91eecb2a0 chore: Refactor and optimize project architecture
Docker Build and Push / build-and-push-tags (push) Has been skipped
SonarQube Scan / SonarQube Trigger (push) Successful in 54s
Docker Build and Push / build-and-push-branches (push) Successful in 12m17s
- Fix: Resolve goroutine deadlock on early connection close
- Refactor: Simplify Start() method, unify forwarding logic, and enhance HTTP handler modularity
- Improve: Connection handling, header parsing, and resource management
- Refactor: Centralize environment loading, enforce typed access, and cleanup config structure
- Enhance: SonarQube scan integration for CI
- Chore: Reorganize project layout and simplify lifecycle management
- Define reusable constants for registry errors

Reviewed-on: #74
2026-01-22 22:16:33 +07:00
961a905542 chore(restructure): refactor architecture, config, and lifecycle management
Docker Build and Push / build-and-push-tags (push) Has been skipped
SonarQube Scan / SonarQube Trigger (push) Successful in 44s
Docker Build and Push / build-and-push-branches (push) Successful in 11m16s
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 46s
- Reorganized internal packages and overall project structure
- Moved HTTP/HTTPS/TCP servers into the transport layer
- Decoupled server initialization from HTTP/HTTPS/TCP startup logic
- Separated HTTP parsing, streaming, middleware, and session registry concerns
- Refactored session and forwarder responsibilities for clearer ownership
- Centralized environment loading with validated, typed config access
- Made config immutable after initialization and normalized enum naming
- Improved resource lifecycle handling and error aggregation on shutdown
- Introduced reusable, package-level registry errors
- Added SonarQube scanning to CI pipeline

Reviewed-on: #73
2026-01-22 00:48:40 +07:00
634c8321ef refactor(registry): define reusable constant errors
SonarQube Scan / SonarQube Trigger (push) Successful in 52s
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 46s
- Introduced package-level error variables in registry to replace repeated fmt.Errorf calls
- Added errors like ErrSessionNotFound, ErrSlugInUse, ErrInvalidSlug, ErrForbiddenSlug, ErrSlugChangeNotAllowed, and ErrSlugUnchanged
2026-01-22 00:39:28 +07:00
9f4c24a3f3 refactor(lifecycle): reorder resource closing and simplify Close()
SonarQube Scan / SonarQube Trigger (push) Successful in 53s
- Close channel and connection first, then remove session
- Close forwarded port and forwarder at the end for TCP tunnels
- Aggregate all errors using errors.Join instead of failing early
2026-01-21 21:59:59 +07:00
1408b80917 ci: add sonarqube scan
SonarQube Scan / SonarQube Trigger (push) Successful in 48s
2026-01-21 21:24:57 +07:00
2bc20dd991 refactor(config): centralize env loading and enforce typed access
- Centralize environment variable loading in config.MustLoad
- Parse and validate all env vars once at initialization
- Make config fields private and read-only
- Remove public Getenv usage in favor of typed accessors
- Improve validation and initialization order
- Normalize enum naming to be idiomatic and avoid constant collisions
2026-01-21 19:43:19 +07:00
21 changed files with 460 additions and 277 deletions
+20
View File
@@ -0,0 +1,20 @@
on:
push:
pull_request:
types: [opened, synchronize, reopened]
name: SonarQube Scan
jobs:
sonarqube:
name: SonarQube Trigger
runs-on: ubuntu-latest
steps:
- name: Checking out
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: SonarQube Scan
uses: SonarSource/sonarqube-scan-action@v7.0.0
env:
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
+52 -22
View File
@@ -1,33 +1,63 @@
package config package config
import ( import "tunnel_pls/types"
"os"
"strconv"
"github.com/joho/godotenv" type Config interface {
) Domain() string
SSHPort() string
func Load() error { HTTPPort() string
if _, err := os.Stat(".env"); err == nil { HTTPSPort() string
return godotenv.Load(".env")
} TLSEnabled() bool
return nil TLSRedirect() bool
ACMEEmail() string
CFAPIToken() string
ACMEStaging() bool
AllowedPortsStart() uint16
AllowedPortsEnd() uint16
BufferSize() int
PprofEnabled() bool
PprofPort() string
Mode() types.ServerMode
GRPCAddress() string
GRPCPort() string
NodeToken() string
} }
func Getenv(key, defaultValue string) string { func MustLoad() (Config, error) {
val := os.Getenv(key) if err := loadEnvFile(); err != nil {
if val == "" { return nil, err
val = defaultValue
} }
return val cfg, err := parse()
if err != nil {
return nil, err
} }
func GetBufferSize() int { return cfg, nil
sizeStr := Getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(sizeStr)
if err != nil || size < 4096 || size > 1048576 {
return 32768
}
return size
} }
func (c *config) Domain() string { return c.domain }
func (c *config) SSHPort() string { return c.sshPort }
func (c *config) HTTPPort() string { return c.httpPort }
func (c *config) HTTPSPort() string { return c.httpsPort }
func (c *config) TLSEnabled() bool { return c.tlsEnabled }
func (c *config) TLSRedirect() bool { return c.tlsRedirect }
func (c *config) ACMEEmail() string { return c.acmeEmail }
func (c *config) CFAPIToken() string { return c.cfAPIToken }
func (c *config) ACMEStaging() bool { return c.acmeStaging }
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
func (c *config) BufferSize() int { return c.bufferSize }
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
func (c *config) PprofPort() string { return c.pprofPort }
func (c *config) Mode() types.ServerMode { return c.mode }
func (c *config) GRPCAddress() string { return c.grpcAddress }
func (c *config) GRPCPort() string { return c.grpcPort }
func (c *config) NodeToken() string { return c.nodeToken }
+170
View File
@@ -0,0 +1,170 @@
package config
import (
"fmt"
"log"
"os"
"strconv"
"strings"
"tunnel_pls/types"
"github.com/joho/godotenv"
)
type config struct {
domain string
sshPort string
httpPort string
httpsPort string
tlsEnabled bool
tlsRedirect bool
acmeEmail string
cfAPIToken string
acmeStaging bool
allowedPortsStart uint16
allowedPortsEnd uint16
bufferSize int
pprofEnabled bool
pprofPort string
mode types.ServerMode
grpcAddress string
grpcPort string
nodeToken string
}
func parse() (*config, error) {
mode, err := parseMode()
if err != nil {
return nil, err
}
domain := getenv("DOMAIN", "localhost")
sshPort := getenv("PORT", "2200")
httpPort := getenv("HTTP_PORT", "8080")
httpsPort := getenv("HTTPS_PORT", "8443")
tlsEnabled := getenvBool("TLS_ENABLED", false)
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
acmeStaging := getenvBool("ACME_STAGING", false)
cfToken := getenv("CF_API_TOKEN", "")
if tlsEnabled && cfToken == "" {
return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled")
}
start, end, err := parseAllowedPorts()
if err != nil {
return nil, err
}
bufferSize := parseBufferSize()
pprofEnabled := getenvBool("PPROF_ENABLED", false)
pprofPort := getenv("PPROF_PORT", "6060")
grpcHost := getenv("GRPC_ADDRESS", "localhost")
grpcPort := getenv("GRPC_PORT", "8080")
nodeToken := getenv("NODE_TOKEN", "")
if mode == types.ServerModeNODE && nodeToken == "" {
return nil, fmt.Errorf("NODE_TOKEN is required in node mode")
}
return &config{
domain: domain,
sshPort: sshPort,
httpPort: httpPort,
httpsPort: httpsPort,
tlsEnabled: tlsEnabled,
tlsRedirect: tlsRedirect,
acmeEmail: acmeEmail,
cfAPIToken: cfToken,
acmeStaging: acmeStaging,
allowedPortsStart: start,
allowedPortsEnd: end,
bufferSize: bufferSize,
pprofEnabled: pprofEnabled,
pprofPort: pprofPort,
mode: mode,
grpcAddress: grpcHost,
grpcPort: grpcPort,
nodeToken: nodeToken,
}, nil
}
func loadEnvFile() error {
if _, err := os.Stat(".env"); err == nil {
return godotenv.Load(".env")
}
return nil
}
func parseMode() (types.ServerMode, error) {
switch strings.ToLower(getenv("MODE", "standalone")) {
case "standalone":
return types.ServerModeSTANDALONE, nil
case "node":
return types.ServerModeNODE, nil
default:
return 0, fmt.Errorf("invalid MODE value")
}
}
func parseAllowedPorts() (uint16, uint16, error) {
raw := getenv("ALLOWED_PORTS", "")
if raw == "" {
return 0, 0, nil
}
parts := strings.Split(raw, "-")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format")
}
start, err := strconv.ParseUint(parts[0], 10, 16)
if err != nil {
return 0, 0, err
}
end, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return 0, 0, err
}
return uint16(start), uint16(end), nil
}
func parseBufferSize() int {
raw := getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(raw)
if err != nil || size < 4096 || size > 1048576 {
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
return 4096
}
return size
}
func getenv(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}
func getenvBool(key string, def bool) bool {
val := os.Getenv(key)
if val == "" {
return def
}
return val == "true"
}
+9 -7
View File
@@ -29,6 +29,7 @@ type Client interface {
CheckServerHealth(ctx context.Context) error CheckServerHealth(ctx context.Context) error
} }
type client struct { type client struct {
config config.Config
conn *grpc.ClientConn conn *grpc.ClientConn
address string address string
sessionRegistry registry.Registry sessionRegistry registry.Registry
@@ -37,7 +38,7 @@ type client struct {
closing bool closing bool
} }
func New(address string, sessionRegistry registry.Registry) (Client, error) { func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) {
var opts []grpc.DialOption var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
@@ -66,6 +67,7 @@ func New(address string, sessionRegistry registry.Registry) (Client, error) {
authorizeConnectionService := proto.NewUserServiceClient(conn) authorizeConnectionService := proto.NewUserServiceClient(conn)
return &client{ return &client{
config: config,
conn: conn, conn: conn,
address: address, address: address,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
@@ -192,7 +194,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
oldSlug := slugEvent.GetOld() oldSlug := slugEvent.GetOld()
newSlug := slugEvent.GetNew() newSlug := slugEvent.GetNew()
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP}) userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP})
if err != nil { if err != nil {
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
@@ -202,7 +204,7 @@ func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
}, "slug change failure response") }, "slug change failure response")
} }
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil { if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil {
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{ Payload: &proto.Node_SlugEventResponse{
@@ -227,7 +229,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node
for _, ses := range sessions { for _, ses := range sessions {
detail := ses.Detail() detail := ses.Detail()
details = append(details, &proto.Detail{ details = append(details, &proto.Detail{
Node: config.Getenv("DOMAIN", "localhost"), Node: c.config.Domain(),
ForwardingType: detail.ForwardingType, ForwardingType: detail.ForwardingType,
Slug: detail.Slug, Slug: detail.Slug,
UserId: detail.UserID, UserId: detail.UserID,
@@ -299,11 +301,11 @@ func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E
func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) { func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
switch t { switch t {
case proto.TunnelType_HTTP: case proto.TunnelType_HTTP:
return types.HTTP, nil return types.TunnelTypeHTTP, nil
case proto.TunnelType_TCP: case proto.TunnelType_TCP:
return types.TCP, nil return types.TunnelTypeTCP, nil
default: default:
return types.UNKNOWN, fmt.Errorf("unknown tunnel type received") return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received")
} }
} }
+3 -3
View File
@@ -17,9 +17,9 @@ type RequestHeader interface {
Set(key string, value string) Set(key string, value string)
Remove(key string) Remove(key string)
Finalize() []byte Finalize() []byte
GetMethod() string Method() string
GetPath() string Path() string
GetVersion() string Version() string
} }
type requestHeader struct { type requestHeader struct {
method string method string
+3 -3
View File
@@ -32,15 +32,15 @@ func (req *requestHeader) Remove(key string) {
delete(req.headers, key) delete(req.headers, key)
} }
func (req *requestHeader) GetMethod() string { func (req *requestHeader) Method() string {
return req.method return req.method
} }
func (req *requestHeader) GetPath() string { func (req *requestHeader) Path() string {
return req.path return req.path
} }
func (req *requestHeader) GetVersion() string { func (req *requestHeader) Version() string {
return req.version return req.version
} }
+19 -10
View File
@@ -34,6 +34,15 @@ type registry struct {
slugIndex map[Key]string slugIndex map[Key]string
} }
var (
ErrSessionNotFound = fmt.Errorf("session not found")
ErrSlugInUse = fmt.Errorf("slug already in use")
ErrInvalidSlug = fmt.Errorf("invalid slug")
ErrForbiddenSlug = fmt.Errorf("forbidden slug")
ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type")
ErrSlugUnchanged = fmt.Errorf("slug is unchanged")
)
func NewRegistry() Registry { func NewRegistry() Registry {
return &registry{ return &registry{
byUser: make(map[string]map[Key]Session), byUser: make(map[string]map[Key]Session),
@@ -47,12 +56,12 @@ func (r *registry) Get(key Key) (session Session, err error) {
userID, ok := r.slugIndex[key] userID, ok := r.slugIndex[key]
if !ok { if !ok {
return nil, fmt.Errorf("Session not found") return nil, ErrSessionNotFound
} }
client, ok := r.byUser[userID][key] client, ok := r.byUser[userID][key]
if !ok { if !ok {
return nil, fmt.Errorf("Session not found") return nil, ErrSessionNotFound
} }
return client, nil return client, nil
} }
@@ -63,37 +72,37 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error
client, ok := r.byUser[user][key] client, ok := r.byUser[user][key]
if !ok { if !ok {
return nil, fmt.Errorf("Session not found") return nil, ErrSessionNotFound
} }
return client, nil return client, nil
} }
func (r *registry) Update(user string, oldKey, newKey Key) error { func (r *registry) Update(user string, oldKey, newKey Key) error {
if oldKey.Type != newKey.Type { if oldKey.Type != newKey.Type {
return fmt.Errorf("tunnel type cannot change") return ErrSlugUnchanged
} }
if newKey.Type != types.HTTP { if newKey.Type != types.TunnelTypeHTTP {
return fmt.Errorf("non http tunnel cannot change slug") return ErrSlugChangeNotAllowed
} }
if isForbiddenSlug(newKey.Id) { if isForbiddenSlug(newKey.Id) {
return fmt.Errorf("this subdomain is reserved. Please choose a different one") return ErrForbiddenSlug
} }
if !isValidSlug(newKey.Id) { if !isValidSlug(newKey.Id) {
return fmt.Errorf("invalid subdomain. Follow the rules") return ErrInvalidSlug
} }
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
return fmt.Errorf("someone already uses this subdomain") return ErrSlugInUse
} }
client, ok := r.byUser[user][oldKey] client, ok := r.byUser[user][oldKey]
if !ok { if !ok {
return fmt.Errorf("Session not found") return ErrSessionNotFound
} }
delete(r.byUser[user], oldKey) delete(r.byUser[user], oldKey)
+2 -2
View File
@@ -12,9 +12,9 @@ type httpServer struct {
port string port string
} }
func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &httpServer{ return &httpServer{
handler: newHTTPHandler(sessionRegistry, redirectTLS), handler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
port: port, port: port,
} }
} }
+14 -10
View File
@@ -4,12 +4,12 @@ import (
"bufio" "bufio"
"errors" "errors"
"fmt" "fmt"
"io"
"log" "log"
"net" "net"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/http/header" "tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream" "tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware" "tunnel_pls/internal/middleware"
@@ -20,12 +20,14 @@ import (
) )
type httpHandler struct { type httpHandler struct {
domain string
sessionRegistry registry.Registry sessionRegistry registry.Registry
redirectTLS bool redirectTLS bool
} }
func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
return &httpHandler{ return &httpHandler{
domain: domain,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS, redirectTLS: redirectTLS,
} }
@@ -67,7 +69,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
} }
if hh.shouldRedirectToTLS(isTLS) { if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost"))) _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
return return
} }
@@ -133,7 +135,7 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
func (hh *httpHandler) getSession(slug string) (registry.Session, error) { func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug, Id: slug,
Type: types.HTTP, Type: types.TunnelTypeHTTP,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -143,17 +145,19 @@ func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
channel, err := hh.openForwardedChannel(hw, sshSession) channel, err := hh.openForwardedChannel(hw, sshSession)
defer func() {
err = channel.Close()
if err != nil {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
if err != nil { if err != nil {
log.Printf("Failed to establish channel: %v", err) log.Printf("Failed to establish channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw) sshSession.Forwarder().WriteBadGatewayResponse(hw)
return return
} }
defer func() {
err = channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
hh.setupMiddlewares(hw) hh.setupMiddlewares(hw)
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil { if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
+5 -8
View File
@@ -9,28 +9,25 @@ import (
) )
type https struct { type https struct {
tlsConfig *tls.Config
httpHandler *httpHandler httpHandler *httpHandler
domain string domain string
port string port string
} }
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport {
return &https{ return &https{
httpHandler: newHTTPHandler(sessionRegistry, redirectTLS), tlsConfig: tlsConfig,
httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
domain: domain, domain: domain,
port: port, port: port,
} }
} }
func (ht *https) Listen() (net.Listener, error) { func (ht *https) Listen() (net.Listener, error) {
tlsConfig, err := NewTLSConfig(ht.domain) return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig)
if err != nil {
return nil, err
} }
return tls.Listen("tcp", ":"+ht.port, tlsConfig)
}
func (ht *https) Serve(listener net.Listener) error { func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.port) log.Printf("HTTPS server is starting on port %s", ht.port)
for { for {
+12 -33
View File
@@ -26,7 +26,8 @@ type TLSManager interface {
} }
type tlsManager struct { type tlsManager struct {
domain string config config.Config
certPath string certPath string
keyPath string keyPath string
storagePath string storagePath string
@@ -42,7 +43,7 @@ type tlsManager struct {
var globalTLSManager TLSManager var globalTLSManager TLSManager
var tlsManagerOnce sync.Once var tlsManagerOnce sync.Once
func NewTLSConfig(domain string) (*tls.Config, error) { func NewTLSConfig(config config.Config) (*tls.Config, error) {
var initErr error var initErr error
tlsManagerOnce.Do(func() { tlsManagerOnce.Do(func() {
@@ -51,7 +52,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
storagePath := "certs/tls/certmagic" storagePath := "certs/tls/certmagic"
tm := &tlsManager{ tm := &tlsManager{
domain: domain, config: config,
certPath: certPath, certPath: certPath,
keyPath: keyPath, keyPath: keyPath,
storagePath: storagePath, storagePath: storagePath,
@@ -66,14 +67,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
tm.useCertMagic = false tm.useCertMagic = false
tm.startCertWatcher() tm.startCertWatcher()
} else { } else {
if !isACMEConfigComplete() { log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain())
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
return
}
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
if err := tm.initCertMagic(); err != nil { if err := tm.initCertMagic(); err != nil {
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err) initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
return return
@@ -91,11 +85,6 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
return globalTLSManager.getTLSConfig(), nil return globalTLSManager.getTLSConfig(), nil
} }
func isACMEConfigComplete() bool {
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
return cfAPIToken != ""
}
func (tm *tlsManager) userCertsExistAndValid() bool { func (tm *tlsManager) userCertsExistAndValid() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath) log.Printf("Certificate file not found: %s", tm.certPath)
@@ -106,7 +95,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
return false return false
} }
return ValidateCertDomains(tm.certPath, tm.domain) return ValidateCertDomains(tm.certPath, tm.config.Domain())
} }
func ValidateCertDomains(certPath, domain string) bool { func ValidateCertDomains(certPath, domain string) bool {
@@ -206,15 +195,9 @@ func (tm *tlsManager) startCertWatcher() {
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) { if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
log.Printf("Certificate files changed, reloading...") log.Printf("Certificate files changed, reloading...")
if !ValidateCertDomains(tm.certPath, tm.domain) { if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
log.Printf("New certificates don't cover required domains") log.Printf("New certificates don't cover required domains")
if !isACMEConfigComplete() {
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
continue
}
log.Printf("Switching to CertMagic for automatic certificate management")
if err := tm.initCertMagic(); err != nil { if err := tm.initCertMagic(); err != nil {
log.Printf("Failed to initialize CertMagic: %v", err) log.Printf("Failed to initialize CertMagic: %v", err)
continue continue
@@ -241,16 +224,12 @@ func (tm *tlsManager) initCertMagic() error {
return fmt.Errorf("failed to create cert storage directory: %w", err) return fmt.Errorf("failed to create cert storage directory: %w", err)
} }
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain) if tm.config.CFAPIToken() == "" {
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
if cfAPIToken == "" {
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation") return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
} }
cfProvider := &cloudflare.Provider{ cfProvider := &cloudflare.Provider{
APIToken: cfAPIToken, APIToken: tm.config.CFAPIToken(),
} }
storage := &certmagic.FileStorage{Path: tm.storagePath} storage := &certmagic.FileStorage{Path: tm.storagePath}
@@ -266,7 +245,7 @@ func (tm *tlsManager) initCertMagic() error {
}) })
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{ acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
Email: acmeEmail, Email: tm.config.ACMEEmail(),
Agreed: true, Agreed: true,
DNS01Solver: &certmagic.DNS01Solver{ DNS01Solver: &certmagic.DNS01Solver{
DNSManager: certmagic.DNSManager{ DNSManager: certmagic.DNSManager{
@@ -275,7 +254,7 @@ func (tm *tlsManager) initCertMagic() error {
}, },
}) })
if acmeStaging { if tm.config.ACMEStaging() {
acmeIssuer.CA = certmagic.LetsEncryptStagingCA acmeIssuer.CA = certmagic.LetsEncryptStagingCA
log.Printf("Using Let's Encrypt staging server") log.Printf("Using Let's Encrypt staging server")
} else { } else {
@@ -286,7 +265,7 @@ func (tm *tlsManager) initCertMagic() error {
magic.Issuers = []certmagic.Issuer{acmeIssuer} magic.Issuers = []certmagic.Issuer{acmeIssuer}
tm.magic = magic tm.magic = magic
domains := []string{tm.domain, "*." + tm.domain} domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
log.Printf("Requesting certificates for: %v", domains) log.Printf("Requesting certificates for: %v", domains)
ctx := context.Background() ctx := context.Background()
+54 -91
View File
@@ -9,8 +9,6 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
"strconv"
"strings"
"syscall" "syscall"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
@@ -21,6 +19,7 @@ import (
"tunnel_pls/internal/transport" "tunnel_pls/internal/transport"
"tunnel_pls/internal/version" "tunnel_pls/internal/version"
"tunnel_pls/server" "tunnel_pls/server"
"tunnel_pls/types"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -36,27 +35,12 @@ func main() {
log.Printf("Starting %s", version.GetVersion()) log.Printf("Starting %s", version.GetVersion())
err := config.Load() conf, err := config.MustLoad()
if err != nil { if err != nil {
log.Fatalf("Failed to load configuration: %s", err) log.Fatalf("Failed to load configuration: %s", err)
return return
} }
mode := strings.ToLower(config.Getenv("MODE", "standalone"))
isNodeMode := mode == "node"
pprofEnabled := config.Getenv("PPROF_ENABLED", "false")
if pprofEnabled == "true" {
pprofPort := config.Getenv("PPROF_PORT", "6060")
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
}
sshConfig := &ssh.ServerConfig{ sshConfig := &ssh.ServerConfig{
NoClientAuth: true, NoClientAuth: true,
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()), ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
@@ -88,16 +72,11 @@ func main() {
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
var grpcClient client.Client var grpcClient client.Client
if isNodeMode {
grpcHost := config.Getenv("GRPC_ADDRESS", "localhost")
grpcPort := config.Getenv("GRPC_PORT", "8080")
grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort)
nodeToken := config.Getenv("NODE_TOKEN", "")
if nodeToken == "" {
log.Fatalf("NODE_TOKEN is required in node mode")
}
grpcClient, err = client.New(grpcAddr, sessionRegistry) if conf.Mode() == types.ServerModeNODE {
grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort())
grpcClient, err = client.New(conf, grpcAddr, sessionRegistry)
if err != nil { if err != nil {
log.Fatalf("failed to create grpc client: %v", err) log.Fatalf("failed to create grpc client: %v", err)
} }
@@ -110,89 +89,73 @@ func main() {
healthCancel() healthCancel()
go func() { go func() {
identity := config.Getenv("DOMAIN", "localhost") if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
if err = grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil {
errChan <- fmt.Errorf("failed to subscribe to events: %w", err) errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
} }
}() }()
} }
go func() {
var httpListener net.Listener
httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect())
httpListener, err = httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpserver.Serve(httpListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
if conf.TLSEnabled() {
go func() {
var httpsListener net.Listener
tlsConfig, _ := transport.NewTLSConfig(conf)
httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig)
httpsListener, err = httpsServer.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpsServer.Serve(httpsListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
}
portManager := port.New() portManager := port.New()
rawRange := config.Getenv("ALLOWED_PORTS", "") err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd())
if rawRange != "" {
splitRange := strings.Split(rawRange, "-")
if len(splitRange) == 2 {
var start, end uint64
start, err = strconv.ParseUint(splitRange[0], 10, 16)
if err != nil { if err != nil {
log.Fatalf("Failed to parse start port: %s", err) log.Fatalf("Failed to initialize port manager: %s", err)
}
end, err = strconv.ParseUint(splitRange[1], 10, 16)
if err != nil {
log.Fatalf("Failed to parse end port: %s", err)
}
if err = portManager.AddRange(uint16(start), uint16(end)); err != nil {
log.Fatalf("Failed to add port range: %s", err)
}
log.Printf("PortRegistry range configured: %d-%d", start, end)
} else {
log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange)
}
}
tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true"
redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true"
go func() {
httpPort := config.Getenv("HTTP_PORT", "8080")
var httpListener net.Listener
httpserver := transport.NewHTTPServer(httpPort, sessionRegistry, redirectTLS)
httpListener, err = httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return return
} }
err = httpserver.Serve(httpListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
if tlsEnabled {
go func() {
httpsPort := config.Getenv("HTTPS_PORT", "8443")
domain := config.Getenv("DOMAIN", "localhost")
var httpListener net.Listener
httpserver := transport.NewHTTPSServer(domain, httpsPort, sessionRegistry, redirectTLS)
httpListener, err = httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpserver.Serve(httpListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
}
var app server.Server var app server.Server
go func() { go func() {
sshPort := config.Getenv("PORT", "2200") app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort())
app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager, sshPort)
if err != nil { if err != nil {
errChan <- fmt.Errorf("failed to start server: %s", err) errChan <- fmt.Errorf("failed to start server: %s", err)
return return
} }
app.Start() app.Start()
}() }()
if conf.PprofEnabled() {
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort())
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
}
select { select {
case err = <-errChan: case err = <-errChan:
log.Printf("error happen : %s", err) log.Printf("error happen : %s", err)
+8 -5
View File
@@ -7,6 +7,7 @@ import (
"log" "log"
"net" "net"
"time" "time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/port" "tunnel_pls/internal/port"
"tunnel_pls/internal/registry" "tunnel_pls/internal/registry"
@@ -20,24 +21,26 @@ type Server interface {
Close() error Close() error
} }
type server struct { type server struct {
config config.Config
sshPort string sshPort string
sshListener net.Listener sshListener net.Listener
config *ssh.ServerConfig sshConfig *ssh.ServerConfig
grpcClient client.Client grpcClient client.Client
sessionRegistry registry.Registry sessionRegistry registry.Registry
portRegistry port.Port portRegistry port.Port
} }
func New(sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort)) listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &server{ return &server{
config: config,
sshPort: sshPort, sshPort: sshPort,
sshListener: listener, sshListener: listener,
config: sshConfig, sshConfig: sshConfig,
grpcClient: grpcClient, grpcClient: grpcClient,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
portRegistry: portRegistry, portRegistry: portRegistry,
@@ -66,7 +69,7 @@ func (s *server) Close() error {
} }
func (s *server) handleConnection(conn net.Conn) { func (s *server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil { if err != nil {
log.Printf("failed to establish SSH connection: %v", err) log.Printf("failed to establish SSH connection: %v", err)
err = conn.Close() err = conn.Close()
@@ -92,7 +95,7 @@ func (s *server) handleConnection(conn net.Conn) {
cancel() cancel()
} }
log.Println("SSH connection established:", sshConn.User()) log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
err = sshSession.Start() err = sshSession.Start()
if err != nil { if err != nil {
log.Printf("SSH session ended with error: %v", err) log.Printf("SSH session ended with error: %v", err)
+31 -32
View File
@@ -18,37 +18,6 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
var bufferPool = sync.Pool{
New: func() interface{} {
bufSize := config.GetBufferSize()
return make([]byte, bufSize)
},
}
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slug slug.Slug
conn ssh.Conn
}
func New(slug slug.Slug, conn ssh.Conn) Forwarder {
return &forwarder{
listener: nil,
tunnelType: types.UNKNOWN,
forwardedPort: 0,
slug: slug,
conn: conn,
}
}
type Forwarder interface { type Forwarder interface {
SetType(tunnelType types.TunnelType) SetType(tunnelType types.TunnelType)
SetForwardedPort(port uint16) SetForwardedPort(port uint16)
@@ -62,6 +31,36 @@ type Forwarder interface {
WriteBadGatewayResponse(dst io.Writer) WriteBadGatewayResponse(dst io.Writer)
Close() error Close() error
} }
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slug slug.Slug
conn ssh.Conn
bufferPool sync.Pool
}
func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
return &forwarder{
listener: nil,
tunnelType: types.TunnelTypeUNKNOWN,
forwardedPort: 0,
slug: slug,
conn: conn,
bufferPool: sync.Pool{
New: func() interface{} {
bufSize := config.BufferSize()
return make([]byte, bufSize)
},
},
}
}
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := f.bufferPool.Get().([]byte)
defer f.bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
type channelResult struct { type channelResult struct {
@@ -107,7 +106,7 @@ func closeWriter(w io.Writer) error {
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error { func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
var errs []error var errs []error
_, err := copyWithBuffer(dst, src) _, err := f.copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err)) errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
} }
+11 -10
View File
@@ -18,9 +18,9 @@ import (
) )
type Interaction interface { type Interaction interface {
Mode() types.Mode Mode() types.InteractiveMode
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetMode(m types.Mode) SetMode(m types.InteractiveMode)
SetWH(w, h int) SetWH(w, h int)
Start() Start()
Redraw() Redraw()
@@ -39,6 +39,7 @@ type Forwarder interface {
type CloseFunc func() error type CloseFunc func() error
type interaction struct { type interaction struct {
config config.Config
channel ssh.Channel channel ssh.Channel
slug slug.Slug slug slug.Slug
forwarder Forwarder forwarder Forwarder
@@ -48,14 +49,14 @@ type interaction struct {
program *tea.Program program *tea.Program
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
mode types.Mode mode types.InteractiveMode
} }
func (i *interaction) SetMode(m types.Mode) { func (i *interaction) SetMode(m types.InteractiveMode) {
i.mode = m i.mode = m
} }
func (i *interaction) Mode() types.Mode { func (i *interaction) Mode() types.InteractiveMode {
return i.mode return i.mode
} }
@@ -75,9 +76,10 @@ func (i *interaction) SetWH(w, h int) {
} }
} }
func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &interaction{ return &interaction{
config: config,
channel: nil, channel: nil,
slug: slug, slug: slug,
forwarder: forwarder, forwarder: forwarder,
@@ -174,14 +176,13 @@ func (m *model) View() string {
} }
func (i *interaction) Start() { func (i *interaction) Start() {
if i.mode == types.HEADLESS { if i.mode == types.InteractiveModeHEADLESS {
return return
} }
lipgloss.SetColorProfile(termenv.TrueColor) lipgloss.SetColorProfile(termenv.TrueColor)
domain := config.Getenv("DOMAIN", "localhost")
protocol := "http" protocol := "http"
if config.Getenv("TLS_ENABLED", "false") == "true" { if i.config.TLSEnabled() {
protocol = "https" protocol = "https"
} }
@@ -209,7 +210,7 @@ func (i *interaction) Start() {
ti.Width = 50 ti.Width = 50
m := &model{ m := &model{
domain: domain, domain: i.config.Domain(),
protocol: protocol, protocol: protocol,
tunnelType: tunnelType, tunnelType: tunnelType,
port: port, port: port,
+1 -1
View File
@@ -41,7 +41,7 @@ type model struct {
} }
func (m *model) getTunnelURL() string { func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP { if m.tunnelType == types.TunnelTypeHTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain) return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
} }
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port) return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
+4 -4
View File
@@ -15,7 +15,7 @@ import (
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd var cmd tea.Cmd
if m.tunnelType != types.HTTP { if m.tunnelType != types.TunnelTypeHTTP {
m.editingSlug = false m.editingSlug = false
m.slugError = "" m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
@@ -30,10 +30,10 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
inputValue := m.slugInput.Value() inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{ if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
Id: m.interaction.slug.String(), Id: m.interaction.slug.String(),
Type: types.HTTP, Type: types.TunnelTypeHTTP,
}, types.SessionKey{ }, types.SessionKey{
Id: inputValue, Id: inputValue,
Type: types.HTTP, Type: types.TunnelTypeHTTP,
}); err != nil { }); err != nil {
m.slugError = err.Error() m.slugError = err.Error()
return m, nil return m, nil
@@ -130,7 +130,7 @@ func (m *model) slugView() string {
b.WriteString(titleStyle.Render(title)) b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n") b.WriteString("\n\n")
if m.tunnelType != types.HTTP { if m.tunnelType != types.TunnelTypeHTTP {
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60) warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle(). warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFA500")). Foreground(lipgloss.Color("#FFA500")).
+25 -29
View File
@@ -2,8 +2,6 @@ package lifecycle
import ( import (
"errors" "errors"
"io"
"net"
"time" "time"
portUtil "tunnel_pls/internal/port" portUtil "tunnel_pls/internal/port"
@@ -24,7 +22,7 @@ type SessionRegistry interface {
} }
type lifecycle struct { type lifecycle struct {
status types.Status status types.SessionStatus
conn ssh.Conn conn ssh.Conn
channel ssh.Channel channel ssh.Channel
forwarder Forwarder forwarder Forwarder
@@ -37,7 +35,7 @@ type lifecycle struct {
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle { func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
return &lifecycle{ return &lifecycle{
status: types.INITIALIZING, status: types.SessionStatusINITIALIZING,
conn: conn, conn: conn,
channel: nil, channel: nil,
forwarder: forwarder, forwarder: forwarder,
@@ -54,7 +52,7 @@ type Lifecycle interface {
PortRegistry() portUtil.Port PortRegistry() portUtil.Port
User() string User() string
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetStatus(status types.Status) SetStatus(status types.SessionStatus)
IsActive() bool IsActive() bool
StartedAt() time.Time StartedAt() time.Time
Close() error Close() error
@@ -74,35 +72,30 @@ func (l *lifecycle) SetChannel(channel ssh.Channel) {
func (l *lifecycle) Connection() ssh.Conn { func (l *lifecycle) Connection() ssh.Conn {
return l.conn return l.conn
} }
func (l *lifecycle) SetStatus(status types.Status) { func (l *lifecycle) SetStatus(status types.SessionStatus) {
l.status = status l.status = status
if status == types.RUNNING && l.startedAt.IsZero() { if status == types.SessionStatusRUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now() l.startedAt = time.Now()
} }
} }
func closeIfNotNil(c interface{ Close() error }) error {
if c != nil {
return c.Close()
}
return nil
}
func (l *lifecycle) Close() error { func (l *lifecycle) Close() error {
var firstErr error var errs []error
tunnelType := l.forwarder.TunnelType() tunnelType := l.forwarder.TunnelType()
if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := closeIfNotNil(l.channel); err != nil {
firstErr = err errs = append(errs, err)
} }
if l.channel != nil { if err := closeIfNotNil(l.conn); err != nil {
if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { errs = append(errs, err)
if firstErr == nil {
firstErr = err
}
}
}
if l.conn != nil {
if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
if firstErr == nil {
firstErr = err
}
}
} }
clientSlug := l.slug.String() clientSlug := l.slug.String()
@@ -112,17 +105,20 @@ func (l *lifecycle) Close() error {
} }
l.sessionRegistry.Remove(key) l.sessionRegistry.Remove(key)
if tunnelType == types.TCP { if tunnelType == types.TunnelTypeTCP {
if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil { if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil {
firstErr = err errs = append(errs, err)
}
if err := l.forwarder.Close(); err != nil {
errs = append(errs, err)
} }
} }
return firstErr return errors.Join(errs...)
} }
func (l *lifecycle) IsActive() bool { func (l *lifecycle) IsActive() bool {
return l.status == types.RUNNING return l.status == types.SessionStatusRUNNING
} }
func (l *lifecycle) StartedAt() time.Time { func (l *lifecycle) StartedAt() time.Time {
+19 -17
View File
@@ -37,6 +37,7 @@ type Session interface {
} }
type session struct { type session struct {
config config.Config
initialReq <-chan *ssh.Request initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel sshChan <-chan ssh.NewChannel
lifecycle lifecycle.Lifecycle lifecycle lifecycle.Lifecycle
@@ -48,13 +49,14 @@ type session struct {
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
slugManager := slug.New() slugManager := slug.New()
forwarderManager := forwarder.New(slugManager, conn) forwarderManager := forwarder.New(config, slugManager, conn)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
return &session{ return &session{
config: config,
initialReq: initialReq, initialReq: initialReq,
sshChan: sshChan, sshChan: sshChan,
lifecycle: lifecycleManager, lifecycle: lifecycleManager,
@@ -83,12 +85,12 @@ func (s *session) Slug() slug.Slug {
func (s *session) Detail() *types.Detail { func (s *session) Detail() *types.Detail {
tunnelTypeMap := map[types.TunnelType]string{ tunnelTypeMap := map[types.TunnelType]string{
types.HTTP: "HTTP", types.TunnelTypeHTTP: "TunnelTypeHTTP",
types.TCP: "TCP", types.TunnelTypeTCP: "TunnelTypeTCP",
} }
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()] tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
if !ok { if !ok {
tunnelType = "UNKNOWN" tunnelType = "TunnelTypeUNKNOWN"
} }
return &types.Detail{ return &types.Detail{
@@ -131,7 +133,7 @@ func (s *session) setupSessionMode() error {
} }
return s.setupInteractiveMode(channel) return s.setupInteractiveMode(channel)
case <-time.After(500 * time.Millisecond): case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.HEADLESS) s.interaction.SetMode(types.InteractiveModeHEADLESS)
return nil return nil
} }
} }
@@ -152,13 +154,13 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
s.lifecycle.SetChannel(ch) s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch) s.interaction.SetChannel(ch)
s.interaction.SetMode(types.INTERACTIVE) s.interaction.SetMode(types.InteractiveModeINTERACTIVE)
return nil return nil
} }
func (s *session) handleMissingForwardRequest() error { func (s *session) handleMissingForwardRequest() error {
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))) err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
if err != nil { if err != nil {
return err return err
} }
@@ -169,8 +171,8 @@ func (s *session) handleMissingForwardRequest() error {
} }
func (s *session) shouldRejectUnauthorized() bool { func (s *session) shouldRejectUnauthorized() bool {
return s.interaction.Mode() == types.HEADLESS && return s.interaction.Mode() == types.InteractiveModeHEADLESS &&
config.Getenv("MODE", "standalone") == "standalone" && s.config.Mode() == types.ServerModeSTANDALONE &&
s.lifecycle.User() == "UNAUTHORIZED" s.lifecycle.User() == "UNAUTHORIZED"
} }
@@ -318,7 +320,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
s.forwarder.SetType(tunnelType) s.forwarder.SetType(tunnelType)
s.forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(slug) s.slug.Set(slug)
s.lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.SessionStatusRUNNING)
if listener != nil { if listener != nil {
s.forwarder.SetListener(listener) s.forwarder.SetListener(listener)
@@ -348,12 +350,12 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
if err != nil { if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err)) return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
} }
key := types.SessionKey{Id: randomString, Type: types.HTTP} key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP}
if !s.registry.Register(key, s) { if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString)) return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
} }
err = s.finalizeForwarding(req, portToBind, nil, types.HTTP, key.Id) err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id)
if err != nil { if err != nil {
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err)) return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
} }
@@ -371,12 +373,12 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
} }
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP} key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
if !s.registry.Register(key, s) { if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TCP client with id: %s", key.Id)) return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id))
} }
err = s.finalizeForwarding(req, portToBind, listener, types.TCP, key.Id) err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id)
if err != nil { if err != nil {
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err)) return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
} }
+1
View File
@@ -0,0 +1 @@
sonar.projectKey=tunnel-please
+16 -9
View File
@@ -2,26 +2,33 @@ package types
import "time" import "time"
type Status int type SessionStatus int
const ( const (
INITIALIZING Status = iota SessionStatusINITIALIZING SessionStatus = iota
RUNNING SessionStatusRUNNING
) )
type Mode int type InteractiveMode int
const ( const (
INTERACTIVE Mode = iota InteractiveModeINTERACTIVE InteractiveMode = iota + 1
HEADLESS InteractiveModeHEADLESS
) )
type TunnelType int type TunnelType int
const ( const (
UNKNOWN TunnelType = iota TunnelTypeUNKNOWN TunnelType = iota
HTTP TunnelTypeHTTP
TCP TunnelTypeTCP
)
type ServerMode int
const (
ServerModeSTANDALONE = iota + 1
ServerModeNODE
) )
type SessionKey struct { type SessionKey struct {