diff --git a/.gitea/workflows/build.yml b/.gitea/workflows/build.yml index 3cf1d5e..8db73b5 100644 --- a/.gitea/workflows/build.yml +++ b/.gitea/workflows/build.yml @@ -5,6 +5,14 @@ on: branches: - main - staging + paths: + - '**.go' + - 'go.mod' + - 'go.sum' + - 'Dockerfile' + - 'Dockerfile.*' + - '.dockerignore' + - '.gitea/workflows/build.yml' jobs: build-and-push: diff --git a/.gitea/workflows/renovate.yml b/.gitea/workflows/renovate.yml index e5f26b0..49008b4 100644 --- a/.gitea/workflows/renovate.yml +++ b/.gitea/workflows/renovate.yml @@ -17,4 +17,5 @@ jobs: env: RENOVATE_CONFIG_FILE: ${{ gitea.workspace }}/renovate-config.js LOG_LEVEL: "debug" - RENOVATE_TOKEN: ${{ secrets.RENOVATE_TOKEN }} \ No newline at end of file + RENOVATE_TOKEN: ${{ secrets.RENOVATE_TOKEN }} + GITHUB_COM_TOKEN: ${{ secrets.COM_TOKEN }} \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 83d9b84..9884c52 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,7 @@ COPY . . RUN --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ - CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \ + CGO_ENABLED=0 GOOS=linux \ go build -trimpath \ -ldflags="-w -s" \ -o /app/tunnel_pls \ diff --git a/README.md b/README.md index c7ddf33..dbcf475 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,6 @@ A lightweight SSH-based tunnel server written in Go that enables secure TCP and - SSH interactive session with real-time command handling - Custom subdomain management for HTTP tunnels -- Active connection control with drop functionality - Dual protocol support: HTTP and TCP tunnels - Real-time connection monitoring ## Requirements @@ -116,6 +115,110 @@ go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30 go tool pprof http://localhost:6060/debug/pprof/heap ``` +## Docker Deployment + +Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`. + +### Configuration Options + +#### 1. Root with Host Networking (RECOMMENDED) + +**File:** `docker-compose.root.yml` + +**Advantages:** +- Full TCP port forwarding support (ports 40000-41000) +- Direct binding to privileged ports (80, 443, 2200) +- Best performance with no NAT overhead +- Maximum flexibility for all tunnel types +- No port mapping limitations + +**Use Case:** Production deployments where you need unrestricted TCP forwarding and maximum performance. + +**Deploy:** +```bash +docker-compose -f docker-compose.root.yml up -d +``` + +#### 2. Standard (HTTP/HTTPS Only) + +**File:** `docker-compose.standard.yml` + +**Advantages:** +- Runs with unprivileged user (more secure) +- Standard port mappings (2200, 80, 443) +- Simple and predictable networking +- TCP port forwarding disabled (`ALLOWED_PORTS=none`) + +**Use Case:** Deployments where you only need HTTP/HTTPS tunneling without custom TCP port forwarding. + +**Deploy:** +```bash +docker-compose -f docker-compose.standard.yml up -d +``` + +#### 3. Limited TCP Forwarding + +**File:** `docker-compose.tcp.yml` + +**Advantages:** +- Runs with unprivileged user (more secure) +- Standard port mappings (2200, 80, 443) +- Limited TCP forwarding (ports 30000-31000) +- Controlled port range exposure + +**Use Case:** Deployments where you need both HTTP/HTTPS tunneling and limited TCP forwarding within a specific port range. + +**Deploy:** +```bash +docker-compose -f docker-compose.tcp.yml up -d +``` + +### Quick Start + +1. **Choose your configuration** based on your requirements +2. **Edit the environment variables** in the chosen compose file: + - `DOMAIN`: Your domain name (e.g., `example.com`) + - `ACME_EMAIL`: Your email for Let's Encrypt + - `CF_API_TOKEN`: Your Cloudflare API token (if using automatic TLS) +3. **Deploy:** + ```bash + docker-compose -f docker-compose.root.yml up -d + ``` +4. **Check logs:** + ```bash + docker-compose -f docker-compose.root.yml logs -f + ``` +5. **Stop the service:** + ```bash + docker-compose -f docker-compose.root.yml down + ``` + +### Volume Management + +All configurations use a named volume `certs` for persistent storage: +- SSH keys: `/app/certs/ssh/` +- TLS certificates: `/app/certs/tls/` + +To backup certificates: +```bash +docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar czf /backup/certs-backup.tar.gz -C /data . +``` + +To restore certificates: +```bash +docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar xzf /backup/certs-backup.tar.gz -C /data +``` + +### Recommendation + +**Use `docker-compose.root.yml`** for production deployments if you need: +- Full TCP port forwarding capabilities +- Any port range configuration +- Direct port binding without mapping overhead +- Maximum performance and flexibility + +This is the recommended configuration for most use cases as it provides the complete feature set without limitations. + ## Contributing Contributions are welcome! diff --git a/docker-compose.root.yml b/docker-compose.root.yml new file mode 100644 index 0000000..273aaec --- /dev/null +++ b/docker-compose.root.yml @@ -0,0 +1,37 @@ +version: '3.8' + +services: + tunnel-please: + image: git.fossy.my.id/bagas/tunnel-please:latest + container_name: tunnel-please-root + user: root + network_mode: host + restart: unless-stopped + volumes: + - certs:/app/certs + environment: + DOMAIN: example.com + PORT: 2200 + HTTP_PORT: 8080 + HTTPS_PORT: 8443 + TLS_ENABLED: "true" + TLS_REDIRECT: "true" + ACME_EMAIL: admin@example.com + CF_API_TOKEN: your_cloudflare_api_token_here + ACME_STAGING: "false" + CORS_LIST: http://localhost:3000,https://example.com + ALLOWED_PORTS: 40000-41000 + BUFFER_SIZE: 32768 + PPROF_ENABLED: "false" + PPROF_PORT: 6060 + healthcheck: + test: ["CMD", "/bin/sh", "-c", "netstat -tln | grep -q :2200"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 10s + +volumes: + certs: + driver: local + diff --git a/docker-compose.standard.yml b/docker-compose.standard.yml new file mode 100644 index 0000000..13ce33c --- /dev/null +++ b/docker-compose.standard.yml @@ -0,0 +1,39 @@ +version: '3.8' + +services: + tunnel-please: + image: git.fossy.my.id/bagas/tunnel-please:latest + container_name: tunnel-please-standard + restart: unless-stopped + ports: + - "2200:2200" + - "80:8080" + - "443:8443" + volumes: + - certs:/app/certs + environment: + DOMAIN: example.com + PORT: 2200 + HTTP_PORT: 8080 + HTTPS_PORT: 8443 + TLS_ENABLED: "true" + TLS_REDIRECT: "true" + ACME_EMAIL: admin@example.com + CF_API_TOKEN: your_cloudflare_api_token_here + ACME_STAGING: "false" + CORS_LIST: http://localhost:3000,https://example.com + ALLOWED_PORTS: none + BUFFER_SIZE: 32768 + PPROF_ENABLED: "false" + PPROF_PORT: 6060 + healthcheck: + test: ["CMD", "/bin/sh", "-c", "netstat -tln | grep -q :2200"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 10s + +volumes: + certs: + driver: local + diff --git a/docker-compose.tcp.yml b/docker-compose.tcp.yml new file mode 100644 index 0000000..dfbf8fb --- /dev/null +++ b/docker-compose.tcp.yml @@ -0,0 +1,40 @@ +version: '3.8' + +services: + tunnel-please: + image: git.fossy.my.id/bagas/tunnel-please:latest + container_name: tunnel-please-tcp + restart: unless-stopped + ports: + - "2200:2200" + - "80:8080" + - "443:8443" + - "30000-31000:30000-31000" + volumes: + - certs:/app/certs + environment: + DOMAIN: example.com + PORT: 2200 + HTTP_PORT: 8080 + HTTPS_PORT: 8443 + TLS_ENABLED: "true" + TLS_REDIRECT: "true" + ACME_EMAIL: admin@example.com + CF_API_TOKEN: your_cloudflare_api_token_here + ACME_STAGING: "false" + CORS_LIST: http://localhost:3000,https://example.com + ALLOWED_PORTS: 30000-31000 + BUFFER_SIZE: 32768 + PPROF_ENABLED: "false" + PPROF_PORT: 6060 + healthcheck: + test: ["CMD", "/bin/sh", "-c", "netstat -tln | grep -q :2200"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 10s + +volumes: + certs: + driver: local + diff --git a/renovate-config.js b/renovate-config.js index 883f95e..212cb36 100644 --- a/renovate-config.js +++ b/renovate-config.js @@ -1,6 +1,6 @@ module.exports = { "endpoint": "https://git.fossy.my.id/api/v1", - "gitAuthor": "Renovate Bot ", + "gitAuthor": "Renovate-Clanker ", "platform": "gitea", "onboardingConfigFileName": "renovate.json", "autodiscover": true, diff --git a/renovate.json b/renovate.json index 426bd59..a8d3f45 100644 --- a/renovate.json +++ b/renovate.json @@ -10,7 +10,10 @@ "pin", "digest" ], - "automerge": true + "automerge": true, + "baseBranches": [ + "staging" + ] } ] } diff --git a/server/handler.go b/server/handler.go index 68b92dd..494e7f5 100644 --- a/server/handler.go +++ b/server/handler.go @@ -9,7 +9,7 @@ import ( ) func (s *Server) handleConnection(conn net.Conn) { - sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.Config) + sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) if err != nil { log.Printf("failed to establish SSH connection: %v", err) err := conn.Close() diff --git a/server/header.go b/server/header.go index 0b36a2c..ec0c224 100644 --- a/server/header.go +++ b/server/header.go @@ -14,21 +14,38 @@ type HeaderManager interface { Finalize() []byte } -type ResponseHeaderFactory struct { +type ResponseHeaderManager interface { + Get(key string) string + Set(key string, value string) + Remove(key string) + Finalize() []byte +} + +type RequestHeaderManager interface { + Get(key string) string + Set(key string, value string) + Remove(key string) + Finalize() []byte + GetMethod() string + GetPath() string + GetVersion() string +} + +type responseHeaderFactory struct { startLine []byte headers map[string]string } -type RequestHeaderFactory struct { - Method string - Path string - Version string +type requestHeaderFactory struct { + method string + path string + version string startLine []byte headers map[string]string } -func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) { - header := &RequestHeaderFactory{ +func NewRequestHeaderFactory(br *bufio.Reader) (RequestHeaderManager, error) { + header := &requestHeaderFactory{ headers: make(map[string]string), } @@ -44,9 +61,9 @@ func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) { return nil, fmt.Errorf("invalid request line") } - header.Method = parts[0] - header.Path = parts[1] - header.Version = parts[2] + header.method = parts[0] + header.path = parts[1] + header.version = parts[2] for { line, err := br.ReadString('\n') @@ -69,8 +86,8 @@ func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) { return header, nil } -func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory { - header := &ResponseHeaderFactory{ +func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager { + header := &responseHeaderFactory{ startLine: nil, headers: make(map[string]string), } @@ -96,19 +113,19 @@ func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory { return header } -func (resp *ResponseHeaderFactory) Get(key string) string { +func (resp *responseHeaderFactory) Get(key string) string { return resp.headers[key] } -func (resp *ResponseHeaderFactory) Set(key string, value string) { +func (resp *responseHeaderFactory) Set(key string, value string) { resp.headers[key] = value } -func (resp *ResponseHeaderFactory) Remove(key string) { +func (resp *responseHeaderFactory) Remove(key string) { delete(resp.headers, key) } -func (resp *ResponseHeaderFactory) Finalize() []byte { +func (resp *responseHeaderFactory) Finalize() []byte { var buf bytes.Buffer buf.Write(resp.startLine) @@ -125,7 +142,7 @@ func (resp *ResponseHeaderFactory) Finalize() []byte { return buf.Bytes() } -func (req *RequestHeaderFactory) Get(key string) string { +func (req *requestHeaderFactory) Get(key string) string { val, ok := req.headers[key] if !ok { return "" @@ -133,15 +150,27 @@ func (req *RequestHeaderFactory) Get(key string) string { return val } -func (req *RequestHeaderFactory) Set(key string, value string) { +func (req *requestHeaderFactory) Set(key string, value string) { req.headers[key] = value } -func (req *RequestHeaderFactory) Remove(key string) { +func (req *requestHeaderFactory) Remove(key string) { delete(req.headers, key) } -func (req *RequestHeaderFactory) Finalize() []byte { +func (req *requestHeaderFactory) GetMethod() string { + return req.method +} + +func (req *requestHeaderFactory) GetPath() string { + return req.path +} + +func (req *requestHeaderFactory) GetVersion() string { + return req.version +} + +func (req *requestHeaderFactory) Finalize() []byte { var buf bytes.Buffer buf.Write(req.startLine) diff --git a/server/http.go b/server/http.go index 0ca4e23..1f08123 100644 --- a/server/http.go +++ b/server/http.go @@ -20,25 +20,63 @@ import ( type Interaction interface { SendMessage(message string) } -type CustomWriter struct { - RemoteAddr net.Addr + +type HTTPWriter interface { + io.Reader + io.Writer + SetInteraction(interaction Interaction) + AddInteraction(interaction Interaction) + GetRemoteAddr() net.Addr + GetWriter() io.Writer + AddResponseMiddleware(mw ResponseMiddleware) + AddRequestStartMiddleware(mw RequestMiddleware) + SetRequestHeader(header RequestHeaderManager) + GetRequestStartMiddleware() []RequestMiddleware +} + +type customWriter struct { + remoteAddr net.Addr writer io.Writer reader io.Reader headerBuf []byte buf []byte - respHeader *ResponseHeaderFactory - reqHeader *RequestHeaderFactory + respHeader ResponseHeaderManager + reqHeader RequestHeaderManager interaction Interaction respMW []ResponseMiddleware reqStartMW []RequestMiddleware reqEndMW []RequestMiddleware } -func (cw *CustomWriter) SetInteraction(interaction Interaction) { +func (cw *customWriter) SetInteraction(interaction Interaction) { cw.interaction = interaction } -func (cw *CustomWriter) Read(p []byte) (int, error) { +func (cw *customWriter) GetRemoteAddr() net.Addr { + return cw.remoteAddr +} + +func (cw *customWriter) GetWriter() io.Writer { + return cw.writer +} + +func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) { + cw.respMW = append(cw.respMW, mw) +} + +func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) { + cw.reqStartMW = append(cw.reqStartMW, mw) +} + +func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) { + cw.reqHeader = header +} + +func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware { + return cw.reqStartMW +} + +func (cw *customWriter) Read(p []byte) (int, error) { tmp := make([]byte, len(p)) read, err := cw.reader.Read(tmp) if read == 0 && err != nil { @@ -95,9 +133,9 @@ func (cw *CustomWriter) Read(p []byte) (int, error) { return n, nil } -func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter { - return &CustomWriter{ - RemoteAddr: remoteAddr, +func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter { + return &customWriter{ + remoteAddr: remoteAddr, writer: writer, reader: reader, buf: make([]byte, 0, 4096), @@ -129,7 +167,7 @@ func isHTTPHeader(buf []byte) bool { return true } -func (cw *CustomWriter) Write(p []byte) (int, error) { +func (cw *customWriter) Write(p []byte) (int, error) { if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" { cw.respHeader = nil } @@ -186,7 +224,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { return len(p), nil } -func (cw *CustomWriter) AddInteraction(interaction Interaction) { +func (cw *customWriter) AddInteraction(interaction Interaction) { cw.interaction = interaction } @@ -292,13 +330,13 @@ func Handler(conn net.Conn) { return } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - cw.SetInteraction(sshSession.Interaction) + cw.SetInteraction(sshSession.GetInteraction()) forwardRequest(cw, reqhf, sshSession) return } -func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { - payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr) +func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) { + payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr()) type channelResult struct { channel ssh.Channel @@ -308,7 +346,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS resultChan := make(chan channelResult, 1) go func() { - channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) + channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload) resultChan <- channelResult{channel, reqs, err} }() @@ -319,29 +357,28 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS case result := <-resultChan: if result.err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) - sshSession.Forwarder.WriteBadGatewayResponse(cw.writer) + sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter()) return } channel = result.channel reqs = result.reqs case <-time.After(5 * time.Second): log.Printf("Timeout opening forwarded-tcpip channel") - sshSession.Forwarder.WriteBadGatewayResponse(cw.writer) + sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter()) return } go ssh.DiscardRequests(reqs) fingerprintMiddleware := NewTunnelFingerprint() - forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr) + forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr()) - cw.respMW = append(cw.respMW, fingerprintMiddleware) - cw.reqStartMW = append(cw.reqStartMW, forwardedForMiddleware) - cw.reqEndMW = nil - cw.reqHeader = initialRequest + cw.AddResponseMiddleware(fingerprintMiddleware) + cw.AddRequestStartMiddleware(forwardedForMiddleware) + cw.SetRequestHeader(initialRequest) - for _, m := range cw.reqStartMW { - if err := m.HandleRequest(cw.reqHeader); err != nil { + for _, m := range cw.GetRequestStartMiddleware() { + if err := m.HandleRequest(initialRequest); err != nil { log.Printf("Error handling request: %v", err) return } @@ -353,6 +390,6 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS return } - sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr) + sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr()) return } diff --git a/server/https.go b/server/https.go index fc08424..2a09c91 100644 --- a/server/https.go +++ b/server/https.go @@ -104,7 +104,7 @@ func HandlerTLS(conn net.Conn) { return } cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) - cw.SetInteraction(sshSession.Interaction) + cw.SetInteraction(sshSession.GetInteraction()) forwardRequest(cw, reqhf, sshSession) return } diff --git a/server/middleware.go b/server/middleware.go index acbd8bd..ee6ca1a 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -5,11 +5,11 @@ import ( ) type RequestMiddleware interface { - HandleRequest(header *RequestHeaderFactory) error + HandleRequest(header RequestHeaderManager) error } type ResponseMiddleware interface { - HandleResponse(header *ResponseHeaderFactory, body []byte) error + HandleResponse(header ResponseHeaderManager, body []byte) error } type TunnelFingerprint struct{} @@ -18,16 +18,11 @@ func NewTunnelFingerprint() *TunnelFingerprint { return &TunnelFingerprint{} } -func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error { +func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error { header.Set("Server", "Tunnel Please") return nil } -type RequestLogger struct { - interaction Interaction - remoteAddr net.Addr -} - type ForwardedFor struct { addr net.Addr } @@ -36,7 +31,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor { return &ForwardedFor{addr: addr} } -func (ff *ForwardedFor) HandleRequest(header *RequestHeaderFactory) error { +func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error { host, _, err := net.SplitHostPort(ff.addr.String()) if err != nil { return err diff --git a/server/server.go b/server/server.go index 8fb85b0..7f03f7c 100644 --- a/server/server.go +++ b/server/server.go @@ -11,9 +11,21 @@ import ( ) type Server struct { - Conn *net.Listener - Config *ssh.ServerConfig - HttpServer *http.Server + conn *net.Listener + config *ssh.ServerConfig + httpServer *http.Server +} + +func (s *Server) GetConn() *net.Listener { + return s.conn +} + +func (s *Server) GetConfig() *ssh.ServerConfig { + return s.config +} + +func (s *Server) GetHttpServer() *http.Server { + return s.httpServer } func NewServer(config *ssh.ServerConfig) *Server { @@ -33,15 +45,15 @@ func NewServer(config *ssh.ServerConfig) *Server { log.Fatalf("failed to start http server: %v", err) } return &Server{ - Conn: &listener, - Config: config, + conn: &listener, + config: config, } } func (s *Server) Start() { log.Println("SSH server is starting on port 2200...") for { - conn, err := (*s.Conn).Accept() + conn, err := (*s.conn).Accept() if err != nil { log.Printf("failed to accept connection: %v", err) continue diff --git a/server/tls.go b/server/tls.go index 1eb6ac8..bc69150 100644 --- a/server/tls.go +++ b/server/tls.go @@ -16,7 +16,16 @@ import ( "github.com/libdns/cloudflare" ) -type TLSManager struct { +type TLSManager interface { + userCertsExistAndValid() bool + loadUserCerts() error + startCertWatcher() + initCertMagic() error + getTLSConfig() *tls.Config + getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) +} + +type tlsManager struct { domain string certPath string keyPath string @@ -30,7 +39,7 @@ type TLSManager struct { useCertMagic bool } -var tlsManager *TLSManager +var globalTLSManager TLSManager var tlsManagerOnce sync.Once func NewTLSConfig(domain string) (*tls.Config, error) { @@ -41,7 +50,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) { keyPath := "certs/tls/privkey.pem" storagePath := "certs/tls/certmagic" - tm := &TLSManager{ + tm := &tlsManager{ domain: domain, certPath: certPath, keyPath: keyPath, @@ -72,14 +81,14 @@ func NewTLSConfig(domain string) (*tls.Config, error) { tm.useCertMagic = true } - tlsManager = tm + globalTLSManager = tm }) if initErr != nil { return nil, initErr } - return tlsManager.getTLSConfig(), nil + return globalTLSManager.getTLSConfig(), nil } func isACMEConfigComplete() bool { @@ -87,7 +96,7 @@ func isACMEConfigComplete() bool { return cfAPIToken != "" } -func (tm *TLSManager) userCertsExistAndValid() bool { +func (tm *tlsManager) userCertsExistAndValid() bool { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { log.Printf("Certificate file not found: %s", tm.certPath) return false @@ -158,7 +167,7 @@ func ValidateCertDomains(certPath, domain string) bool { return hasBase && hasWildcard } -func (tm *TLSManager) loadUserCerts() error { +func (tm *tlsManager) loadUserCerts() error { cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath) if err != nil { return err @@ -172,7 +181,7 @@ func (tm *TLSManager) loadUserCerts() error { return nil } -func (tm *TLSManager) startCertWatcher() { +func (tm *tlsManager) startCertWatcher() { go func() { var lastCertMod, lastKeyMod time.Time @@ -227,7 +236,7 @@ func (tm *TLSManager) startCertWatcher() { }() } -func (tm *TLSManager) initCertMagic() error { +func (tm *tlsManager) initCertMagic() error { if err := os.MkdirAll(tm.storagePath, 0700); err != nil { return fmt.Errorf("failed to create cert storage directory: %w", err) } @@ -289,14 +298,14 @@ func (tm *TLSManager) initCertMagic() error { return nil } -func (tm *TLSManager) getTLSConfig() *tls.Config { +func (tm *tlsManager) getTLSConfig() *tls.Config { return &tls.Config{ GetCertificate: tm.getCertificate, MinVersion: tls.VersionTLS12, } } -func (tm *TLSManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { if tm.useCertMagic { return tm.magic.GetCertificate(hello) } diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index c993183..250a005 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -31,11 +31,21 @@ func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { } type Forwarder struct { - Listener net.Listener - TunnelType types.TunnelType - ForwardedPort uint16 - SlugManager slug.Manager - Lifecycle Lifecycle + listener net.Listener + tunnelType types.TunnelType + forwardedPort uint16 + slugManager slug.Manager + lifecycle Lifecycle +} + +func NewForwarder(slugManager slug.Manager) *Forwarder { + return &Forwarder{ + listener: nil, + tunnelType: "", + forwardedPort: 0, + slugManager: slugManager, + lifecycle: nil, + } } type Lifecycle interface { @@ -58,7 +68,7 @@ type ForwardingController interface { } func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { - f.Lifecycle = lifecycle + f.lifecycle = lifecycle } func (f *Forwarder) AcceptTCPConnections() { @@ -90,7 +100,7 @@ func (f *Forwarder) AcceptTCPConnections() { resultChan := make(chan channelResult, 1) go func() { - channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) + channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) resultChan <- channelResult{channel, reqs, err} }() @@ -164,27 +174,27 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA } func (f *Forwarder) SetType(tunnelType types.TunnelType) { - f.TunnelType = tunnelType + f.tunnelType = tunnelType } func (f *Forwarder) GetTunnelType() types.TunnelType { - return f.TunnelType + return f.tunnelType } func (f *Forwarder) GetForwardedPort() uint16 { - return f.ForwardedPort + return f.forwardedPort } func (f *Forwarder) SetForwardedPort(port uint16) { - f.ForwardedPort = port + f.forwardedPort = port } func (f *Forwarder) SetListener(listener net.Listener) { - f.Listener = listener + f.listener = listener } func (f *Forwarder) GetListener() net.Listener { - return f.Listener + return f.listener } func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { @@ -197,7 +207,7 @@ func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { func (f *Forwarder) Close() error { if f.GetListener() != nil { - return f.Listener.Close() + return f.listener.Close() } return nil } diff --git a/session/handler.go b/session/handler.go index d536b51..04b1c87 100644 --- a/session/handler.go +++ b/session/handler.go @@ -49,7 +49,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { log.Println("Failed to reply to request:", err) return } - err = s.Lifecycle.Close() + err = s.lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -59,13 +59,13 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { var rawPortToBind uint32 if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { log.Println("Failed to read port from payload:", err) - s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) + s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) return } - err = s.Lifecycle.Close() + err = s.lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -73,13 +73,13 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { } if rawPortToBind > 65535 { - s.Interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) + s.interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) return } - err = s.Lifecycle.Close() + err = s.lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -89,13 +89,13 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { portToBind := uint16(rawPortToBind) if isBlockedPort(portToBind) { - s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) + s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) return } - err = s.Lifecycle.Close() + err = s.lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -110,26 +110,26 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) { unassign, success := portUtil.Default.GetUnassignedPort() portToBind = unassign if !success { - s.Interaction.SendMessage("No available port\r\n") + s.interaction.SendMessage("No available port\r\n") err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) return } - err = s.Lifecycle.Close() + err = s.lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } return } } else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse { - s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) + s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) err := req.Reply(false, nil) if err != nil { log.Println("Failed to reply to request:", err) return } - err = s.Lifecycle.Close() + err = s.lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -193,21 +193,21 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) { return } - s.Forwarder.SetType(types.HTTP) - s.Forwarder.SetForwardedPort(portToBind) - s.SlugManager.Set(slug) - s.Interaction.SendMessage("\033[H\033[2J") - s.Interaction.ShowWelcomeMessage() - s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) - s.Lifecycle.SetStatus(types.RUNNING) - s.Interaction.HandleUserInput() + s.forwarder.SetType(types.HTTP) + s.forwarder.SetForwardedPort(portToBind) + s.slugManager.Set(slug) + s.interaction.SendMessage("\033[H\033[2J") + s.interaction.ShowWelcomeMessage() + s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) + s.lifecycle.SetStatus(types.RUNNING) + s.interaction.HandleUserInput() } func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { log.Printf("Requested forwarding on %s:%d", addr, portToBind) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) if err != nil { - s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) + s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { log.Printf("Failed to reset port status: %v", setErr) } @@ -216,7 +216,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind log.Println("Failed to reply to request:", err) return } - err = s.Lifecycle.Close() + err = s.lifecycle.Close() if err != nil { log.Printf("failed to close session: %v", err) } @@ -253,15 +253,15 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind return } - s.Forwarder.SetType(types.TCP) - s.Forwarder.SetListener(listener) - s.Forwarder.SetForwardedPort(portToBind) - s.Interaction.SendMessage("\033[H\033[2J") - s.Interaction.ShowWelcomeMessage() - s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("DOMAIN", "localhost"), s.Forwarder.GetForwardedPort())) - s.Lifecycle.SetStatus(types.RUNNING) - go s.Forwarder.AcceptTCPConnections() - s.Interaction.HandleUserInput() + s.forwarder.SetType(types.TCP) + s.forwarder.SetListener(listener) + s.forwarder.SetForwardedPort(portToBind) + s.interaction.SendMessage("\033[H\033[2J") + s.interaction.ShowWelcomeMessage() + s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("DOMAIN", "localhost"), s.forwarder.GetForwardedPort())) + s.lifecycle.SetStatus(types.RUNNING) + go s.forwarder.AcceptTCPConnections() + s.interaction.HandleUserInput() } func generateUniqueSlug() string { diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index d8cf4c7..e20577c 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -42,21 +42,37 @@ type Forwarder interface { } type Interaction struct { - InputLength int - CommandBuffer *bytes.Buffer - InteractiveMode bool - InteractionType types.InteractionType - EditSlug string + inputLength int + commandBuffer *bytes.Buffer + interactiveMode bool + interactionType types.InteractionType + editSlug string channel ssh.Channel - SlugManager slug.Manager - Forwarder Forwarder - Lifecycle Lifecycle + slugManager slug.Manager + forwarder Forwarder + lifecycle Lifecycle pendingExit bool updateClientSlug func(oldSlug, newSlug string) bool } +func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction { + return &Interaction{ + inputLength: 0, + commandBuffer: bytes.NewBuffer(make([]byte, 0, 20)), + interactiveMode: false, + interactionType: "", + editSlug: "", + channel: nil, + slugManager: slugManager, + forwarder: forwarder, + lifecycle: nil, + pendingExit: false, + updateClientSlug: nil, + } +} + func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { - i.Lifecycle = lifecycle + i.lifecycle = lifecycle } func (i *Interaction) SetChannel(channel ssh.Channel) { @@ -77,7 +93,7 @@ func (i *Interaction) SendMessage(message string) { func (i *Interaction) HandleUserInput() { buf := make([]byte, 1) - i.InteractiveMode = false + i.interactiveMode = false for { n, err := i.channel.Read(buf) @@ -99,7 +115,7 @@ func (i *Interaction) handleReadError(err error) { } func (i *Interaction) processCharacter(char byte) { - if i.InteractiveMode { + if i.interactiveMode { i.handleInteractiveMode(char) return } @@ -113,7 +129,7 @@ func (i *Interaction) processCharacter(char byte) { } func (i *Interaction) handleInteractiveMode(char byte) { - switch i.InteractionType { + switch i.interactionType { case types.Slug: i.HandleSlugEditMode(char) } @@ -123,7 +139,7 @@ func (i *Interaction) handleExitSequence(char byte) bool { if char == ctrlC { if i.pendingExit { i.SendMessage("Closing connection...\r\n") - if err := i.Lifecycle.Close(); err != nil { + if err := i.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) } return true @@ -147,37 +163,37 @@ func (i *Interaction) handleNonInteractiveInput(char byte) { i.handleBackspace() case char == forwardSlash: i.handleCommandStart() - case i.CommandBuffer.Len() > 0: + case i.commandBuffer.Len() > 0: i.handleCommandInput(char) case char == enterChar: i.SendMessage(clearLine) default: - i.InputLength++ + i.inputLength++ } } func (i *Interaction) handleBackspace() { - if i.InputLength > 0 { + if i.inputLength > 0 { i.SendMessage(backspaceSeq) } - if i.CommandBuffer.Len() > 0 { - i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) + if i.commandBuffer.Len() > 0 { + i.commandBuffer.Truncate(i.commandBuffer.Len() - 1) } } func (i *Interaction) handleCommandStart() { - i.CommandBuffer.Reset() - i.CommandBuffer.WriteByte(forwardSlash) + i.commandBuffer.Reset() + i.commandBuffer.WriteByte(forwardSlash) } func (i *Interaction) handleCommandInput(char byte) { if char == enterChar { i.SendMessage(clearLine) - i.HandleCommand(i.CommandBuffer.String()) + i.HandleCommand(i.commandBuffer.String()) return } - i.CommandBuffer.WriteByte(char) - i.InputLength++ + i.commandBuffer.WriteByte(char) + i.inputLength++ } func (i *Interaction) HandleSlugEditMode(char byte) { @@ -194,15 +210,15 @@ func (i *Interaction) HandleSlugEditMode(char byte) { } func (i *Interaction) handleSlugBackspace() { - if len(i.EditSlug) > 0 { - i.EditSlug = i.EditSlug[:len(i.EditSlug)-1] + if len(i.editSlug) > 0 { + i.editSlug = i.editSlug[:len(i.editSlug)-1] i.refreshSlugDisplay() } } func (i *Interaction) appendToSlug(char byte) { - if isValidSlugChar(char) { - i.EditSlug += string(char) + if len(i.editSlug) < maxSlugLength { + i.editSlug += string(char) i.refreshSlugDisplay() } } @@ -210,16 +226,16 @@ func (i *Interaction) appendToSlug(char byte) { func (i *Interaction) refreshSlugDisplay() { domain := utils.Getenv("DOMAIN", "localhost") i.SendMessage(clearToLineEnd) - i.SendMessage("➤ " + i.EditSlug + "." + domain) + i.SendMessage("➤ " + i.editSlug + "." + domain) } func (i *Interaction) HandleSlugSave() { i.SendMessage(clearScreen) switch { - case isForbiddenSlug(i.EditSlug): + case isForbiddenSlug(i.editSlug): i.showForbiddenSlugMessage() - case !isValidSlug(i.EditSlug): + case !isValidSlug(i.editSlug): i.showInvalidSlugMessage() default: i.updateSlug() @@ -230,8 +246,8 @@ func (i *Interaction) HandleSlugSave() { } func (i *Interaction) updateSlug() { - oldSlug := i.SlugManager.Get() - newSlug := i.EditSlug + oldSlug := i.slugManager.Get() + newSlug := i.editSlug if !i.updateClientSlug(oldSlug, newSlug) { i.HandleSlugUpdateError() @@ -262,8 +278,8 @@ func (i *Interaction) returnToMainScreen() { i.SendMessage(clearScreen) i.ShowWelcomeMessage() i.ShowForwardingMessage() - i.InteractiveMode = false - i.CommandBuffer.Reset() + i.interactiveMode = false + i.commandBuffer.Reset() } func (i *Interaction) HandleSlugCancel() { @@ -271,8 +287,8 @@ func (i *Interaction) HandleSlugCancel() { i.SendMessage("\r\n\r\n⚠️ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n") i.SendMessage("Press any key to continue...\r\n") - i.InteractiveMode = false - i.InteractionType = "" + i.interactiveMode = false + i.interactionType = "" i.WaitForKeyPress() i.SendMessage(clearScreen) @@ -289,7 +305,7 @@ func (i *Interaction) HandleSlugUpdateError() { time.Sleep(1 * time.Second) } - if err := i.Lifecycle.Close(); err != nil { + if err := i.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) } } @@ -308,12 +324,12 @@ func (i *Interaction) HandleCommand(command string) { i.SendMessage("Unknown command\r\n") } - i.CommandBuffer.Reset() + i.commandBuffer.Reset() } func (i *Interaction) handleByeCommand() { i.SendMessage("Closing connection...\r\n") - if err := i.Lifecycle.Close(); err != nil { + if err := i.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) } } @@ -329,32 +345,32 @@ func (i *Interaction) handleClearCommand() { } func (i *Interaction) handleSlugCommand() { - if i.Forwarder.GetTunnelType() != types.HTTP { - i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains\r\n", i.Forwarder.GetTunnelType())) + if i.forwarder.GetTunnelType() != types.HTTP { + i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains\r\n", i.forwarder.GetTunnelType())) return } - i.InteractiveMode = true - i.InteractionType = types.Slug - i.EditSlug = i.SlugManager.Get() + i.interactiveMode = true + i.interactionType = types.Slug + i.editSlug = i.slugManager.Get() i.SendMessage(clearScreen) i.DisplaySlugEditor() domain := utils.Getenv("DOMAIN", "localhost") - i.SendMessage("➤ " + i.EditSlug + "." + domain) + i.SendMessage("➤ " + i.editSlug + "." + domain) } func (i *Interaction) ShowForwardingMessage() { domain := utils.Getenv("DOMAIN", "localhost") - if i.Forwarder.GetTunnelType() == types.HTTP { + if i.forwarder.GetTunnelType() == types.HTTP { protocol := "http" if utils.Getenv("TLS_ENABLED", "false") == "true" { protocol = "https" } - i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) + i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.slugManager.Get(), domain)) } else { - i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort())) + i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.forwarder.GetForwardedPort())) } } @@ -385,7 +401,7 @@ func (i *Interaction) ShowWelcomeMessage() { func (i *Interaction) DisplaySlugEditor() { domain := utils.Getenv("DOMAIN", "localhost") - fullDomain := i.SlugManager.Get() + "." + domain + fullDomain := i.slugManager.Get() + "." + domain contentLine := " ║ Current: " + fullDomain boxWidth := calculateBoxWidth(contentLine) diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index 11106f8..8ba1ac6 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -22,16 +22,27 @@ type Forwarder interface { } type Lifecycle struct { - Status types.Status - Conn ssh.Conn - Channel ssh.Channel - - Interaction Interaction - Forwarder Forwarder - SlugManager slug.Manager + status types.Status + conn ssh.Conn + channel ssh.Channel + interaction Interaction + forwarder Forwarder + slugManager slug.Manager unregisterClient func(slug string) } +func NewLifecycle(conn ssh.Conn, interaction Interaction, forwarder Forwarder, slugManager slug.Manager) *Lifecycle { + return &Lifecycle{ + status: "", + conn: conn, + channel: nil, + interaction: interaction, + forwarder: forwarder, + slugManager: slugManager, + unregisterClient: nil, + } +} + func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) { l.unregisterClient = unregisterClient } @@ -46,46 +57,46 @@ type SessionLifecycle interface { } func (l *Lifecycle) GetChannel() ssh.Channel { - return l.Channel + return l.channel } func (l *Lifecycle) SetChannel(channel ssh.Channel) { - l.Channel = channel + l.channel = channel } func (l *Lifecycle) GetConnection() ssh.Conn { - return l.Conn + return l.conn } func (l *Lifecycle) SetStatus(status types.Status) { - l.Status = status + l.status = status } func (l *Lifecycle) Close() error { - err := l.Forwarder.Close() + err := l.forwarder.Close() if err != nil && !errors.Is(err, net.ErrClosed) { return err } - if l.Channel != nil { - err := l.Channel.Close() + if l.channel != nil { + err := l.channel.Close() if err != nil && !errors.Is(err, io.EOF) { return err } } - if l.Conn != nil { - err := l.Conn.Close() + if l.conn != nil { + err := l.conn.Close() if err != nil && !errors.Is(err, net.ErrClosed) { return err } } - clientSlug := l.SlugManager.Get() + clientSlug := l.slugManager.Get() if clientSlug != "" { l.unregisterClient(clientSlug) } - if l.Forwarder.GetTunnelType() == types.TCP { - err := portUtil.Default.SetPortStatus(l.Forwarder.GetForwardedPort(), false) + if l.forwarder.GetTunnelType() == types.TCP { + err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false) if err != nil { return err } diff --git a/session/session.go b/session/session.go index 1d23994..7f50e2d 100644 --- a/session/session.go +++ b/session/session.go @@ -1,7 +1,6 @@ package session import ( - "bytes" "fmt" "log" "sync" @@ -28,36 +27,33 @@ type Session interface { } type SSHSession struct { - Lifecycle lifecycle.SessionLifecycle - Interaction interaction.Controller - Forwarder forwarder.ForwardingController - SlugManager slug.Manager + lifecycle lifecycle.SessionLifecycle + interaction interaction.Controller + forwarder forwarder.ForwardingController + slugManager slug.Manager +} + +func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle { + return s.lifecycle +} + +func (s *SSHSession) GetInteraction() interaction.Controller { + return s.interaction +} + +func (s *SSHSession) GetForwarder() forwarder.ForwardingController { + return s.forwarder +} + +func (s *SSHSession) GetSlugManager() slug.Manager { + return s.slugManager } func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { slugManager := slug.NewManager() - forwarderManager := &forwarder.Forwarder{ - Listener: nil, - TunnelType: "", - ForwardedPort: 0, - SlugManager: slugManager, - } - interactionManager := &interaction.Interaction{ - CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)), - InteractiveMode: false, - EditSlug: "", - SlugManager: slugManager, - Forwarder: forwarderManager, - Lifecycle: nil, - } - lifecycleManager := &lifecycle.Lifecycle{ - Status: "", - Conn: conn, - Channel: nil, - Interaction: interactionManager, - Forwarder: forwarderManager, - SlugManager: slugManager, - } + forwarderManager := forwarder.NewForwarder(slugManager) + interactionManager := interaction.NewInteraction(slugManager, forwarderManager) + lifecycleManager := lifecycle.NewLifecycle(conn, interactionManager, forwarderManager, slugManager) interactionManager.SetLifecycle(lifecycleManager) interactionManager.SetSlugModificator(updateClientSlug) @@ -65,10 +61,10 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan lifecycleManager.SetUnregisterClient(unregisterClient) session := &SSHSession{ - Lifecycle: lifecycleManager, - Interaction: interactionManager, - Forwarder: forwarderManager, - SlugManager: slugManager, + lifecycle: lifecycleManager, + interaction: interactionManager, + forwarder: forwarderManager, + slugManager: slugManager, } var once sync.Once @@ -79,13 +75,13 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan continue } once.Do(func() { - session.Lifecycle.SetChannel(ch) - session.Interaction.SetChannel(ch) + session.lifecycle.SetChannel(ch) + session.interaction.SetChannel(ch) tcpipReq := session.waitForTCPIPForward(forwardingReq) if tcpipReq == nil { - session.Interaction.SendMessage(fmt.Sprintf("Port forwarding request not received.\r\nEnsure you ran the correct command with -R flag.\r\nExample: ssh %s -p %s -R 80:localhost:3000\r\nFor more details, visit https://tunnl.live.\r\n\r\n", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200"))) - if err := session.Lifecycle.Close(); err != nil { + session.interaction.SendMessage(fmt.Sprintf("Port forwarding request not received.\r\nEnsure you ran the correct command with -R flag.\r\nExample: ssh %s -p %s -R 80:localhost:3000\r\nFor more details, visit https://tunnl.live.\r\n\r\n", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200"))) + if err := session.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) } return @@ -94,7 +90,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan }) go session.HandleGlobalRequest(reqs) } - if err := session.Lifecycle.Close(); err != nil { + if err := session.lifecycle.Close(); err != nil { log.Printf("failed to close session: %v", err) } } @@ -134,7 +130,7 @@ func updateClientSlug(oldSlug, newSlug string) bool { } delete(Clients, oldSlug) - client.SlugManager.Set(newSlug) + client.slugManager.Set(newSlug) Clients[newSlug] = client return true }