diff --git a/.gitea/workflows/sonarqube.yml b/.gitea/workflows/sonarqube.yml
index 9c672ac..57625a2 100644
--- a/.gitea/workflows/sonarqube.yml
+++ b/.gitea/workflows/sonarqube.yml
@@ -13,8 +13,58 @@ jobs:
uses: actions/checkout@v4
with:
fetch-depth: 0
+
+ - name: Set up Go
+ uses: actions/setup-go@v6
+ with:
+ go-version: '1.25.5'
+ cache: false
+
+ - name: Install dependencies
+ run: go mod tidy
+
+ - name: Run go vet
+ run: go vet ./... 2>&1 | tee vet-results.txt
+
+ - name: Run tests with coverage
+ run: |
+ go test ./... -v -coverprofile=coverage
+
+ - name: Run GolangCI-Lint Analysis
+ uses: golangci/golangci-lint-action@v9
+ with:
+ skip-cache: true
+ version: v2.6
+ args: >
+ --issues-exit-code=0
+ --output.text.path=stdout
+ --output.checkstyle.path=golangci-lint-report.xml
+
+ - name: Set SonarQube project key
+ run: |
+ BRANCH_NAME=${GITHUB_REF#refs/heads/}
+ if [ "$BRANCH_NAME" = "main" ]; then
+ SONAR_PROJECT_KEY="tunnel-please"
+ else
+ BRANCH_KEY=${BRANCH_NAME//\//-}
+ SONAR_PROJECT_KEY="tunnel-please-$BRANCH_KEY"
+ fi
+ echo "SONAR_PROJECT_KEY=tunnel-please-$BRANCH_KEY" >> $GITHUB_ENV
+ echo "Using SonarQube Project Key: $SONAR_PROJECT_KEY"
+
- name: SonarQube Scan
uses: SonarSource/sonarqube-scan-action@v7.0.0
env:
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
+ with:
+ args: >
+ -Dsonar.projectKey=${{ env.SONAR_PROJECT_KEY }}
+ -Dsonar.go.coverage.reportPaths=coverage
+ -Dsonar.test.inclusions=**/*_test.go
+ -Dsonar.test.exclusions=**/vendor/**
+ -Dsonar.exclusions=**/*_test.go,**/vendor/**,**/golangci-lint-report.xml
+ -Dsonar.go.govet.reportPaths=vet-results.txt
+ -Dsonar.go.golangci-lint.reportPaths=golangci-lint-report.xml
+ -Dsonar.sources=./
+ -Dsonar.tests=./
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index dc40a4f..fd6e5af 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,6 @@ id_rsa*
.env
tmp
certs
-app
\ No newline at end of file
+app
+coverage
+test-results.json
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index c1c452d..6b8b765 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -22,7 +22,10 @@ RUN --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \
CGO_ENABLED=0 GOOS=linux \
go build -trimpath \
- -ldflags="-w -s -X tunnel_pls/version.Version=${VERSION} -X tunnel_pls/version.BuildDate=${BUILD_DATE} -X tunnel_pls/version.Commit=${COMMIT}" \
+ -ldflags="-w -s \
+ -X tunnel_pls/internal/version.Version=${VERSION} \
+ -X tunnel_pls/internal/version.BuildDate=${BUILD_DATE} \
+ -X tunnel_pls/internal/version.Commit=${COMMIT}" \
-o /app/tunnel_pls \
.
diff --git a/README.md b/README.md
index 474f430..628efbe 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,22 @@
+
+
+

+
# Tunnel Please
-A lightweight SSH-based tunnel server written in Go that enables secure TCP and HTTP forwarding with an interactive terminal interface for managing connections and custom subdomains.
+A lightweight SSH-based tunnel server
+
+
+
+[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
+[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
+[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
+[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
+[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
+[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
+[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
+
+
## Features
@@ -17,108 +33,32 @@ A lightweight SSH-based tunnel server written in Go that enables secure TCP and
The following environment variables can be configured in the `.env` file:
-| Variable | Description | Default | Required |
-|----------|-------------|---------|----------|
-| `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
-| `PORT` | SSH server port | `2200` | No |
-| `HTTP_PORT` | HTTP server port | `8080` | No |
-| `HTTPS_PORT` | HTTPS server port | `8443` | No |
-| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
-| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
-| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@` | No |
-| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | - | Yes (if auto-cert) |
-| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
-| `CORS_LIST` | Comma-separated list of allowed CORS origins | - | No |
-| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
-| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
-| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
-| `PPROF_PORT` | Port for pprof server | `6060` | No |
-| `MODE` | Runtime mode: `standalone` (default, no gRPC/auth) or `node` (enable gRPC + auth) | `standalone` | No |
-| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No |
-| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No |
-| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | - (required in `node`) | Yes (node mode) |
+| Variable | Description | Default | Required |
+|---------------------|-----------------------------------------------------------------------------|-------------------------|---------------------|
+| `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
+| `PORT` | SSH server port | `2200` | No |
+| `HTTP_PORT` | HTTP server port | `8080` | No |
+| `HTTPS_PORT` | HTTPS server port | `8443` | No |
+| `KEY_LOC` | Path to the private key file | `certs/privkey.pem` | No |
+| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
+| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
+| `TLS_STORAGE_PATH` | Path to store TLS certificates | `certs/tls/` | No |
+| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@` | No |
+| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | `-` | Yes (if auto-cert) |
+| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
+| `CORS_LIST` | Comma-separated list of allowed CORS origins | `-` | No |
+| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
+| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
+| `MAX_HEADER_SIZE` | Maximum size of HTTP headers in bytes (4096-131072) | `4096` | No |
+| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
+| `PPROF_PORT` | Port for pprof server | `6060` | No |
+| `MODE` | Runtime mode: `standalone` or `node` | `standalone` | No |
+| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No |
+| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No |
+| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | `-` | Yes (node mode) |
**Note:** All environment variables now use UPPERCASE naming. The application includes sensible defaults for all variables, so you can run it without a `.env` file for basic functionality.
-### Automatic TLS Certificate Management
-
-The server supports automatic TLS certificate generation and renewal using [CertMagic](https://github.com/caddyserver/certmagic) with Cloudflare DNS-01 challenge. This is required for wildcard certificate support (`*.yourdomain.com`).
-
-**Certificate Storage:**
-- TLS certificates are stored in `certs/tls/` (relative to application directory)
-- User-provided certificates: `certs/tls/cert.pem` and `certs/tls/privkey.pem`
-- CertMagic automatic certificates: `certs/tls/certmagic/`
-- SSH keys are stored separately in `certs/ssh/`
-
-**How it works:**
-1. If user-provided certificates exist at `certs/tls/cert.pem` and `certs/tls/privkey.pem` and cover both `DOMAIN` and `*.DOMAIN`, they will be used
-2. If certificates are missing, expired, expiring within 30 days, or don't cover the required domains, CertMagic will automatically obtain new certificates from Let's Encrypt
-3. Certificates are automatically renewed before expiration
-4. User-provided certificates support hot-reload (changes detected every 30 seconds)
-
-**Cloudflare API Token Setup:**
-
-To use automatic certificate generation, you need a Cloudflare API token with the following permissions:
-
-1. Go to [Cloudflare Dashboard](https://dash.cloudflare.com/profile/api-tokens)
-2. Click "Create Token"
-3. Use "Create Custom Token" with these permissions:
- - **Zone → Zone → Read** (for all zones or specific zone)
- - **Zone → DNS → Edit** (for all zones or specific zone)
-4. Copy the token and set it as `CF_API_TOKEN` environment variable
-
-**Example configuration for automatic certificates:**
-```env
-DOMAIN=example.com
-TLS_ENABLED=true
-CF_API_TOKEN=your_cloudflare_api_token_here
-ACME_EMAIL=admin@example.com
-# ACME_STAGING=true # Uncomment for testing to avoid rate limits
-```
-
-### SSH Key Auto-Generation
-
-The application will automatically generate a new 4096-bit RSA key pair at `certs/ssh/id_rsa` if it doesn't exist. This makes it easier to get started without manually creating SSH keys. SSH keys are stored separately from TLS certificates.
-
-### Memory Optimization
-
-The application uses a buffer pool with controlled buffer sizes to prevent excessive memory usage under high concurrent loads. The `BUFFER_SIZE` environment variable controls the size of buffers used for io.Copy operations:
-
-- **Default:** 32768 bytes (32 KB) - Good balance for most scenarios
-- **Minimum:** 4096 bytes (4 KB) - Lower memory usage, more CPU overhead
-- **Maximum:** 1048576 bytes (1 MB) - Higher throughput, more memory usage
-
-**Recommended settings based on load:**
-- **Low traffic (<100 concurrent):** `BUFFER_SIZE=32768` (default)
-- **High traffic (>100 concurrent):** `BUFFER_SIZE=16384` or `BUFFER_SIZE=8192`
-- **Very high traffic (>1000 concurrent):** `BUFFER_SIZE=8192` or `BUFFER_SIZE=4096`
-
-The buffer pool reuses buffers across connections, preventing memory fragmentation and reducing garbage collection pressure.
-
-### Profiling with pprof
-
-To enable profiling for performance analysis:
-
-1. Set `PPROF_ENABLED=true` in your `.env` file
-2. Optionally set `PPROF_PORT` to your desired port (default: 6060)
-3. Access profiling data at `http://localhost:6060/debug/pprof/`
-
-Common pprof endpoints:
-- `/debug/pprof/` - Index page with available profiles
-- `/debug/pprof/heap` - Memory allocation profile
-- `/debug/pprof/goroutine` - Stack traces of all current goroutines
-- `/debug/pprof/profile` - CPU profile (30-second sample by default)
-- `/debug/pprof/trace` - Execution trace
-
-Example usage with `go tool pprof`:
-```bash
-# Analyze CPU profile
-go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
-
-# Analyze memory heap
-go tool pprof http://localhost:6060/debug/pprof/heap
-```
-
## Docker Deployment
Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`.
@@ -197,22 +137,6 @@ docker-compose -f docker-compose.tcp.yml up -d
docker-compose -f docker-compose.root.yml down
```
-### Volume Management
-
-All configurations use a named volume `certs` for persistent storage:
-- SSH keys: `/app/certs/ssh/`
-- TLS certificates: `/app/certs/tls/`
-
-To backup certificates:
-```bash
-docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar czf /backup/certs-backup.tar.gz -C /data .
-```
-
-To restore certificates:
-```bash
-docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar xzf /backup/certs-backup.tar.gz -C /data
-```
-
### Recommendation
**Use `docker-compose.root.yml`** for production deployments if you need:
diff --git a/docs/images/gopher.png b/docs/images/gopher.png
new file mode 100644
index 0000000..bdc8fec
Binary files /dev/null and b/docs/images/gopher.png differ
diff --git a/go.mod b/go.mod
index 958657e..214f1c5 100644
--- a/go.mod
+++ b/go.mod
@@ -11,6 +11,7 @@ require (
github.com/joho/godotenv v1.5.1
github.com/libdns/cloudflare v0.2.2
github.com/muesli/termenv v0.16.0
+ github.com/stretchr/testify v1.8.1
golang.org/x/crypto v0.47.0
google.golang.org/grpc v1.78.0
google.golang.org/protobuf v1.36.11
@@ -27,6 +28,7 @@ require (
github.com/clipperhouse/displaywidth v0.6.2 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
+ github.com/davecgh/go-spew v1.1.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/libdns/libdns v1.1.1 // indirect
@@ -38,8 +40,10 @@ require (
github.com/miekg/dns v1.1.69 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
+ github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/sahilm/fuzzy v0.1.1 // indirect
+ github.com/stretchr/objx v0.5.0 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/zeebo/blake3 v0.2.4 // indirect
go.uber.org/multierr v1.11.0 // indirect
@@ -52,4 +56,5 @@ require (
golang.org/x/text v0.33.0 // indirect
golang.org/x/tools v0.40.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
+ gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index 11912af..4356e9d 100644
--- a/go.sum
+++ b/go.sum
@@ -32,6 +32,7 @@ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfa
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
@@ -80,6 +81,12 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
@@ -138,5 +145,8 @@ google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go
new file mode 100644
index 0000000..6c3a1f9
--- /dev/null
+++ b/internal/bootstrap/bootstrap.go
@@ -0,0 +1,196 @@
+package bootstrap
+
+import (
+ "context"
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+ "tunnel_pls/internal/config"
+ "tunnel_pls/internal/grpc/client"
+ "tunnel_pls/internal/key"
+ "tunnel_pls/internal/port"
+ "tunnel_pls/internal/random"
+ "tunnel_pls/internal/registry"
+ "tunnel_pls/internal/transport"
+ "tunnel_pls/internal/version"
+ "tunnel_pls/server"
+ "tunnel_pls/types"
+
+ "golang.org/x/crypto/ssh"
+)
+
+type Bootstrap struct {
+ Randomizer random.Random
+ Config config.Config
+ SessionRegistry registry.Registry
+ Port port.Port
+ GrpcClient client.Client
+ ErrChan chan error
+ SignalChan chan os.Signal
+}
+
+func New(config config.Config, port port.Port) (*Bootstrap, error) {
+ randomizer := random.New()
+ sessionRegistry := registry.NewRegistry()
+
+ if err := port.AddRange(config.AllowedPortsStart(), config.AllowedPortsEnd()); err != nil {
+ return nil, err
+ }
+
+ grpcClient, err := client.New(config, sessionRegistry)
+ if err != nil {
+ return nil, err
+ }
+
+ errChan := make(chan error, 5)
+ signalChan := make(chan os.Signal, 1)
+
+ return &Bootstrap{
+ Randomizer: randomizer,
+ Config: config,
+ SessionRegistry: sessionRegistry,
+ Port: port,
+ GrpcClient: grpcClient,
+ ErrChan: errChan,
+ SignalChan: signalChan,
+ }, nil
+}
+
+func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) {
+ sshCfg := &ssh.ServerConfig{
+ NoClientAuth: true,
+ ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
+ }
+
+ if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
+ return nil, fmt.Errorf("generate ssh key: %w", err)
+ }
+ privateBytes, err := os.ReadFile(sshKeyPath)
+ if err != nil {
+ return nil, fmt.Errorf("read private key: %w", err)
+ }
+ private, err := ssh.ParsePrivateKey(privateBytes)
+ if err != nil {
+ return nil, fmt.Errorf("parse private key: %w", err)
+ }
+ sshCfg.AddHostKey(private)
+ return sshCfg, nil
+}
+
+func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, errChan chan<- error) error {
+ healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
+ defer healthCancel()
+ if err := b.GrpcClient.CheckServerHealth(healthCtx); err != nil {
+ return fmt.Errorf("gRPC health check failed: %w", err)
+ }
+
+ go func() {
+ if err := b.GrpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
+ errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
+ }
+ }()
+
+ return nil
+}
+
+func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
+ httpserver := transport.NewHTTPServer(conf, registry)
+ ln, err := httpserver.Listen()
+ if err != nil {
+ errChan <- fmt.Errorf("failed to start http server: %w", err)
+ return
+ }
+ if err = httpserver.Serve(ln); err != nil {
+ errChan <- fmt.Errorf("error when serving http server: %w", err)
+ }
+}
+
+func startHTTPSServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
+ tlsCfg, err := transport.NewTLSConfig(conf)
+ if err != nil {
+ errChan <- fmt.Errorf("failed to create TLS config: %w", err)
+ return
+ }
+ httpsServer := transport.NewHTTPSServer(conf, registry, tlsCfg)
+ ln, err := httpsServer.Listen()
+ if err != nil {
+ errChan <- fmt.Errorf("failed to create TLS config: %w", err)
+ return
+ }
+ if err = httpsServer.Serve(ln); err != nil {
+ errChan <- fmt.Errorf("error when serving https server: %w", err)
+ }
+}
+
+func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, errChan chan<- error) {
+ sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, conf.SSHPort())
+ if err != nil {
+ errChan <- err
+ return
+ }
+
+ sshServer.Start()
+
+ errChan <- sshServer.Close()
+}
+
+func startPprof(pprofPort string, errChan chan<- error) {
+ pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
+ log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
+ if err := http.ListenAndServe(pprofAddr, nil); err != nil {
+ errChan <- fmt.Errorf("pprof server error: %v", err)
+ }
+}
+func (b *Bootstrap) Run() error {
+ sshConfig, err := newSSHConfig(b.Config.KeyLoc())
+ if err != nil {
+ return fmt.Errorf("failed to create SSH config: %w", err)
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ signal.Notify(b.SignalChan, os.Interrupt, syscall.SIGTERM)
+
+ if b.Config.Mode() == types.ServerModeNODE {
+ err = b.startGRPCClient(ctx, b.Config, b.ErrChan)
+ if err != nil {
+ return fmt.Errorf("failed to start gRPC client: %w", err)
+ }
+ defer func(grpcClient client.Client) {
+ err = grpcClient.Close()
+ if err != nil {
+ log.Printf("failed to close gRPC client")
+ }
+ }(b.GrpcClient)
+ }
+
+ go startHTTPServer(b.Config, b.SessionRegistry, b.ErrChan)
+
+ if b.Config.TLSEnabled() {
+ go startHTTPSServer(b.Config, b.SessionRegistry, b.ErrChan)
+ }
+
+ go func() {
+ startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, b.GrpcClient, b.Port, b.ErrChan)
+ }()
+
+ if b.Config.PprofEnabled() {
+ go startPprof(b.Config.PprofPort(), b.ErrChan)
+ }
+
+ log.Println("All services started successfully")
+
+ select {
+ case err = <-b.ErrChan:
+ return fmt.Errorf("service error: %w", err)
+ case sig := <-b.SignalChan:
+ log.Printf("Received signal %s, initiating graceful shutdown", sig)
+ cancel()
+ return nil
+ }
+}
diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go
new file mode 100644
index 0000000..2453cde
--- /dev/null
+++ b/internal/bootstrap/bootstrap_test.go
@@ -0,0 +1,558 @@
+package bootstrap
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ _ "net/http/pprof"
+ "os"
+ "path/filepath"
+ "strconv"
+ "testing"
+ "time"
+ "tunnel_pls/internal/config"
+ "tunnel_pls/internal/port"
+ "tunnel_pls/internal/registry"
+ "tunnel_pls/session/slug"
+ "tunnel_pls/types"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "google.golang.org/grpc"
+)
+
+type MockSessionRegistry struct {
+ mock.Mock
+}
+
+func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
+ args := m.Called(key)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(registry.Session), args.Error(1)
+}
+
+func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
+ args := m.Called(user, key)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(registry.Session), args.Error(1)
+}
+
+func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
+ args := m.Called(user, oldKey, newKey)
+ return args.Error(0)
+}
+
+func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
+ args := m.Called(key, session)
+ return args.Bool(0)
+}
+
+func (m *MockSessionRegistry) Remove(key registry.Key) {
+ m.Called(key)
+}
+
+func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
+ args := m.Called(user)
+ return args.Get(0).([]registry.Session)
+}
+
+func (m *MockSessionRegistry) Slug() slug.Slug {
+ args := m.Called()
+ return args.Get(0).(slug.Slug)
+}
+
+type MockRandom struct {
+ mock.Mock
+}
+
+func (m *MockRandom) String(length int) (string, error) {
+ args := m.Called(length)
+ return args.String(0), args.Error(1)
+}
+
+type MockConfig struct {
+ mock.Mock
+}
+
+func (m *MockConfig) Domain() string { return m.Called().String(0) }
+func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
+func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
+func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
+func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
+func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
+func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
+func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
+func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
+func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
+func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
+func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
+func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
+func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
+func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
+func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
+func (m *MockConfig) Mode() types.ServerMode {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return 0
+ }
+ switch v := args.Get(0).(type) {
+ case types.ServerMode:
+ return v
+ case int:
+ return types.ServerMode(v)
+ default:
+ return types.ServerMode(args.Int(0))
+ }
+}
+func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
+func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
+func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
+func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
+
+type MockPort struct {
+ mock.Mock
+}
+
+func (m *MockPort) AddRange(startPort, endPort uint16) error {
+ return m.Called(startPort, endPort).Error(0)
+}
+func (m *MockPort) Unassigned() (uint16, bool) {
+ args := m.Called()
+ var mPort uint16
+ if args.Get(0) != nil {
+ switch v := args.Get(0).(type) {
+ case int:
+ mPort = uint16(v)
+ case uint16:
+ mPort = v
+ case uint32:
+ mPort = uint16(v)
+ case int32:
+ mPort = uint16(v)
+ case float64:
+ mPort = uint16(v)
+ default:
+ mPort = uint16(args.Int(0))
+ }
+ }
+ return mPort, args.Bool(1)
+}
+func (m *MockPort) SetStatus(port uint16, assigned bool) error {
+ return m.Called(port, assigned).Error(0)
+}
+func (m *MockPort) Claim(port uint16) bool {
+ return m.Called(port).Bool(0)
+}
+
+type MockGRPCClient struct {
+ mock.Mock
+}
+
+func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
+ args := m.Called()
+ return args.Get(0).(*grpc.ClientConn)
+}
+
+func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
+ args := m.Called(ctx, token)
+ return args.Bool(0), args.String(1), args.Error(2)
+}
+
+func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
+ args := m.Called(ctx)
+ return args.Error(0)
+}
+
+func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
+ args := m.Called(ctx, domain, token)
+ return args.Error(0)
+}
+
+func (m *MockGRPCClient) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+func TestNew(t *testing.T) {
+ tests := []struct {
+ name string
+ setupConfig func() config.Config
+ setupPort func() port.Port
+ wantErr bool
+ errContains string
+ }{
+ {
+ name: "Success New with default value",
+ wantErr: false,
+ },
+ {
+ name: "Error when AddRange fails",
+ setupPort: func() port.Port {
+ mockPort := &MockPort{}
+ mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range"))
+ return mockPort
+ },
+ wantErr: true,
+ errContains: "invalid port range",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var mockPort port.Port
+ if tt.setupPort != nil {
+ mockPort = tt.setupPort()
+ } else {
+ mockPort = port.New()
+ }
+
+ var mockConfig config.Config
+ if tt.setupConfig != nil {
+ mockConfig = tt.setupConfig()
+ } else {
+ var err error
+ mockConfig, err = config.MustLoad()
+ assert.NoError(t, err)
+ }
+
+ bootstrap, err := New(mockConfig, mockPort)
+
+ if tt.wantErr {
+ assert.Error(t, err)
+ if tt.errContains != "" {
+ assert.Contains(t, err.Error(), tt.errContains)
+ }
+ assert.Nil(t, bootstrap)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, bootstrap)
+ assert.NotNil(t, bootstrap.Randomizer)
+ assert.NotNil(t, bootstrap.SessionRegistry)
+ assert.NotNil(t, bootstrap.Config)
+ assert.NotNil(t, bootstrap.Port)
+ assert.NotNil(t, bootstrap.ErrChan)
+ assert.NotNil(t, bootstrap.SignalChan)
+ }
+ })
+ }
+}
+
+func randomAvailablePort() (string, error) {
+ listener, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ return "", err
+ }
+ defer func(listener net.Listener) {
+ _ = listener.Close()
+ }(listener)
+
+ mPort := listener.Addr().(*net.TCPAddr).Port
+ return strconv.Itoa(mPort), nil
+}
+
+func TestRun(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockErrChan := make(chan error, 1)
+ mockSignalChan := make(chan os.Signal, 1)
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockPort := &MockPort{}
+
+ tmpDir := t.TempDir()
+ keyLoc := filepath.Join(tmpDir, "key.key")
+
+ tests := []struct {
+ name string
+ setupConfig func() *MockConfig
+ setupGrpcClient func() *MockGRPCClient
+ needCerts bool
+ expectError bool
+ }{
+ {
+ name: "successful run and termination",
+ setupConfig: func() *MockConfig {
+ mockConfig := &MockConfig{}
+ mockConfig.On("KeyLoc").Return(keyLoc)
+ mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("SSHPort").Return("0")
+ mockConfig.On("HTTPPort").Return("0")
+ mockConfig.On("HTTPSPort").Return("0")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockConfig.On("TLSRedirect").Return(false)
+ mockConfig.On("ACMEEmail").Return("test@example.com")
+ mockConfig.On("CFAPIToken").Return("fake-token")
+ mockConfig.On("ACMEStaging").Return(true)
+ mockConfig.On("AllowedPortsStart").Return(uint16(1024))
+ mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
+ mockConfig.On("BufferSize").Return(4096)
+ mockConfig.On("PprofEnabled").Return(false)
+ mockConfig.On("PprofPort").Return("0")
+ mockConfig.On("GRPCAddress").Return("localhost")
+ mockConfig.On("GRPCPort").Return("0")
+ mockConfig.On("NodeToken").Return("fake-node-token")
+ return mockConfig
+ },
+ expectError: false,
+ },
+ {
+ name: "error from SSH server invalid port",
+ setupConfig: func() *MockConfig {
+ mockConfig := &MockConfig{}
+ mockConfig.On("KeyLoc").Return(keyLoc)
+ mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("SSHPort").Return("invalid")
+ mockConfig.On("HTTPPort").Return("0")
+ mockConfig.On("HTTPSPort").Return("0")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockConfig.On("TLSRedirect").Return(false)
+ mockConfig.On("ACMEEmail").Return("test@example.com")
+ mockConfig.On("CFAPIToken").Return("fake-token")
+ mockConfig.On("ACMEStaging").Return(true)
+ mockConfig.On("AllowedPortsStart").Return(uint16(1024))
+ mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
+ mockConfig.On("BufferSize").Return(4096)
+ mockConfig.On("PprofEnabled").Return(false)
+ mockConfig.On("PprofPort").Return("0")
+ mockConfig.On("GRPCAddress").Return("localhost")
+ mockConfig.On("GRPCPort").Return("0")
+ mockConfig.On("NodeToken").Return("fake-node-token")
+ return mockConfig
+ },
+ expectError: true,
+ },
+ {
+ name: "error from HTTP server invalid port",
+ setupConfig: func() *MockConfig {
+ mockConfig := &MockConfig{}
+ mockConfig.On("KeyLoc").Return(keyLoc)
+ mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("SSHPort").Return("0")
+ mockConfig.On("HTTPPort").Return("invalid")
+ mockConfig.On("HTTPSPort").Return("0")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockConfig.On("TLSRedirect").Return(false)
+ mockConfig.On("ACMEEmail").Return("test@example.com")
+ mockConfig.On("CFAPIToken").Return("fake-token")
+ mockConfig.On("ACMEStaging").Return(true)
+ mockConfig.On("AllowedPortsStart").Return(uint16(1024))
+ mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
+ mockConfig.On("BufferSize").Return(4096)
+ mockConfig.On("PprofEnabled").Return(false)
+ mockConfig.On("PprofPort").Return("0")
+ mockConfig.On("GRPCAddress").Return("localhost")
+ mockConfig.On("GRPCPort").Return("0")
+ mockConfig.On("NodeToken").Return("fake-node-token")
+ return mockConfig
+ },
+ expectError: true,
+ },
+ {
+ name: "error from HTTPS server invalid port",
+ setupConfig: func() *MockConfig {
+ tempDir := os.TempDir()
+ mockConfig := &MockConfig{}
+ mockConfig.On("KeyLoc").Return(keyLoc)
+ mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("SSHPort").Return("0")
+ mockConfig.On("HTTPPort").Return("0")
+ mockConfig.On("HTTPSPort").Return("invalid")
+ mockConfig.On("TLSEnabled").Return(true)
+ mockConfig.On("TLSRedirect").Return(false)
+ mockConfig.On("TLSStoragePath").Return(tempDir)
+ mockConfig.On("ACMEEmail").Return("test@example.com")
+ mockConfig.On("CFAPIToken").Return("fake-token")
+ mockConfig.On("ACMEStaging").Return(true)
+ mockConfig.On("AllowedPortsStart").Return(uint16(1024))
+ mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
+ mockConfig.On("BufferSize").Return(4096)
+ mockConfig.On("PprofEnabled").Return(false)
+ mockConfig.On("PprofPort").Return("0")
+ mockConfig.On("GRPCAddress").Return("localhost")
+ mockConfig.On("GRPCPort").Return("0")
+ mockConfig.On("NodeToken").Return("fake-node-token")
+ return mockConfig
+ },
+ expectError: true,
+ },
+ {
+ name: "grpc health check failed",
+ setupConfig: func() *MockConfig {
+ mockConfig := &MockConfig{}
+ mockConfig.On("KeyLoc").Return(keyLoc)
+ mockConfig.On("Mode").Return(types.ServerModeNODE)
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("SSHPort").Return("0")
+ mockConfig.On("HTTPPort").Return("0")
+ mockConfig.On("HTTPSPort").Return("0")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockConfig.On("TLSRedirect").Return(false)
+ mockConfig.On("ACMEEmail").Return("test@example.com")
+ mockConfig.On("CFAPIToken").Return("fake-token")
+ mockConfig.On("ACMEStaging").Return(true)
+ mockConfig.On("AllowedPortsStart").Return(uint16(1024))
+ mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
+ mockConfig.On("BufferSize").Return(4096)
+ mockConfig.On("PprofEnabled").Return(false)
+ mockConfig.On("PprofPort").Return("0")
+ mockConfig.On("GRPCAddress").Return("localhost")
+ mockConfig.On("GRPCPort").Return("invalid")
+ mockConfig.On("NodeToken").Return("fake-node-token")
+ return mockConfig
+ },
+ setupGrpcClient: func() *MockGRPCClient {
+ mockGRPCClient := &MockGRPCClient{}
+ mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed"))
+ return mockGRPCClient
+ },
+ expectError: true,
+ },
+ {
+ name: "successful run with pprof enabled",
+ setupConfig: func() *MockConfig {
+ mockConfig := &MockConfig{}
+ pprofPort, _ := randomAvailablePort()
+ mockConfig.On("KeyLoc").Return(keyLoc)
+ mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("SSHPort").Return("0")
+ mockConfig.On("HTTPPort").Return("0")
+ mockConfig.On("HTTPSPort").Return("0")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockConfig.On("TLSRedirect").Return(false)
+ mockConfig.On("ACMEEmail").Return("test@example.com")
+ mockConfig.On("CFAPIToken").Return("fake-token")
+ mockConfig.On("ACMEStaging").Return(true)
+ mockConfig.On("AllowedPortsStart").Return(uint16(1024))
+ mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
+ mockConfig.On("BufferSize").Return(4096)
+ mockConfig.On("PprofEnabled").Return(true)
+ mockConfig.On("PprofPort").Return(pprofPort)
+ mockConfig.On("GRPCAddress").Return("localhost")
+ mockConfig.On("GRPCPort").Return("0")
+ mockConfig.On("NodeToken").Return("fake-node-token")
+ return mockConfig
+ },
+ expectError: false,
+ }, {
+ name: "successful run in NODE mode with signal",
+ setupConfig: func() *MockConfig {
+ mockConfig := &MockConfig{}
+ mockConfig.On("KeyLoc").Return(keyLoc)
+ mockConfig.On("Mode").Return(types.ServerModeNODE)
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("SSHPort").Return("0")
+ mockConfig.On("HTTPPort").Return("0")
+ mockConfig.On("HTTPSPort").Return("0")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockConfig.On("TLSRedirect").Return(false)
+ mockConfig.On("ACMEEmail").Return("test@example.com")
+ mockConfig.On("CFAPIToken").Return("fake-token")
+ mockConfig.On("ACMEStaging").Return(true)
+ mockConfig.On("AllowedPortsStart").Return(uint16(1024))
+ mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
+ mockConfig.On("BufferSize").Return(4096)
+ mockConfig.On("PprofEnabled").Return(false)
+ mockConfig.On("PprofPort").Return("0")
+ mockConfig.On("GRPCAddress").Return("localhost")
+ mockConfig.On("GRPCPort").Return("0")
+ mockConfig.On("NodeToken").Return("fake-node-token")
+ return mockConfig
+ },
+ setupGrpcClient: func() *MockGRPCClient {
+ mockGRPCClient := &MockGRPCClient{}
+ mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
+ mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
+ mockGRPCClient.On("Close").Return(nil)
+ return mockGRPCClient
+ },
+ expectError: false,
+ }, {
+ name: "successful run in NODE mode with signal buf error when closing",
+ setupConfig: func() *MockConfig {
+ mockConfig := &MockConfig{}
+ mockConfig.On("KeyLoc").Return(keyLoc)
+ mockConfig.On("Mode").Return(types.ServerModeNODE)
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("SSHPort").Return("0")
+ mockConfig.On("HTTPPort").Return("0")
+ mockConfig.On("HTTPSPort").Return("0")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockConfig.On("TLSRedirect").Return(false)
+ mockConfig.On("ACMEEmail").Return("test@example.com")
+ mockConfig.On("CFAPIToken").Return("fake-token")
+ mockConfig.On("ACMEStaging").Return(true)
+ mockConfig.On("AllowedPortsStart").Return(uint16(1024))
+ mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
+ mockConfig.On("BufferSize").Return(4096)
+ mockConfig.On("PprofEnabled").Return(false)
+ mockConfig.On("PprofPort").Return("0")
+ mockConfig.On("GRPCAddress").Return("localhost")
+ mockConfig.On("GRPCPort").Return("0")
+ mockConfig.On("NodeToken").Return("fake-node-token")
+ return mockConfig
+ },
+ setupGrpcClient: func() *MockGRPCClient {
+ mockGRPCClient := &MockGRPCClient{}
+ mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
+ mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
+ mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy"))
+ return mockGRPCClient
+ },
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockConfig := tt.setupConfig()
+ mockGRPCClient := &MockGRPCClient{}
+ bootstrap := &Bootstrap{
+ Randomizer: mockRandom,
+ Config: mockConfig,
+ SessionRegistry: mockSessionRegistry,
+ Port: mockPort,
+ ErrChan: mockErrChan,
+ SignalChan: mockSignalChan,
+ GrpcClient: mockGRPCClient,
+ }
+
+ if tt.setupGrpcClient != nil {
+ bootstrap.GrpcClient = tt.setupGrpcClient()
+ }
+
+ done := make(chan error, 1)
+ go func() {
+ done <- bootstrap.Run()
+ }()
+
+ if tt.expectError {
+ err := <-done
+ assert.Error(t, err)
+ } else if tt.name == "successful run with pprof enabled" {
+ time.Sleep(200 * time.Millisecond)
+ fmt.Println(mockConfig.PprofPort())
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
+ assert.NoError(t, err)
+ assert.Equal(t, 200, resp.StatusCode)
+ err = resp.Body.Close()
+ assert.NoError(t, err)
+ mockSignalChan <- os.Interrupt
+ err = <-done
+ assert.NoError(t, err)
+ } else {
+ time.Sleep(time.Second)
+ mockSignalChan <- os.Interrupt
+ err := <-done
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index 62e1aca..3e6c9e1 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -9,8 +9,11 @@ type Config interface {
HTTPPort() string
HTTPSPort() string
+ KeyLoc() string
+
TLSEnabled() bool
TLSRedirect() bool
+ TLSStoragePath() string
ACMEEmail() string
CFAPIToken() string
@@ -20,6 +23,7 @@ type Config interface {
AllowedPortsEnd() uint16
BufferSize() int
+ HeaderSize() int
PprofEnabled() bool
PprofPort() string
@@ -47,14 +51,17 @@ func (c *config) Domain() string { return c.domain }
func (c *config) SSHPort() string { return c.sshPort }
func (c *config) HTTPPort() string { return c.httpPort }
func (c *config) HTTPSPort() string { return c.httpsPort }
+func (c *config) KeyLoc() string { return c.keyLoc }
func (c *config) TLSEnabled() bool { return c.tlsEnabled }
func (c *config) TLSRedirect() bool { return c.tlsRedirect }
+func (c *config) TLSStoragePath() string { return c.tlsStoragePath }
func (c *config) ACMEEmail() string { return c.acmeEmail }
func (c *config) CFAPIToken() string { return c.cfAPIToken }
func (c *config) ACMEStaging() bool { return c.acmeStaging }
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
func (c *config) BufferSize() int { return c.bufferSize }
+func (c *config) HeaderSize() int { return c.headerSize }
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
func (c *config) PprofPort() string { return c.pprofPort }
func (c *config) Mode() types.ServerMode { return c.mode }
diff --git a/internal/config/config_test.go b/internal/config/config_test.go
new file mode 100644
index 0000000..85f93f3
--- /dev/null
+++ b/internal/config/config_test.go
@@ -0,0 +1,405 @@
+package config
+
+import (
+ "os"
+ "testing"
+ "tunnel_pls/types"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestGetenv(t *testing.T) {
+ tests := []struct {
+ name string
+ key string
+ val string
+ def string
+ expected string
+ }{
+ {
+ name: "returns existing env",
+ key: "TEST_ENV_EXIST",
+ val: "value",
+ def: "default",
+ expected: "value",
+ },
+ {
+ name: "returns default when env missing",
+ key: "TEST_ENV_MISSING",
+ val: "",
+ def: "default",
+ expected: "default",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.val != "" {
+ t.Setenv(tt.key, tt.val)
+ } else {
+ err := os.Unsetenv(tt.key)
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
+ })
+ }
+}
+
+func TestGetenvBool(t *testing.T) {
+ tests := []struct {
+ name string
+ key string
+ val string
+ def bool
+ expected bool
+ }{
+ {
+ name: "returns true when env is true",
+ key: "TEST_BOOL_TRUE",
+ val: "true",
+ def: false,
+ expected: true,
+ },
+ {
+ name: "returns false when env is false",
+ key: "TEST_BOOL_FALSE",
+ val: "false",
+ def: true,
+ expected: false,
+ },
+ {
+ name: "returns default when env missing",
+ key: "TEST_BOOL_MISSING",
+ val: "",
+ def: true,
+ expected: true,
+ },
+ {
+ name: "returns false when env is not true",
+ key: "TEST_BOOL_INVALID",
+ val: "yes",
+ def: true,
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.val != "" {
+ t.Setenv(tt.key, tt.val)
+ } else {
+ err := os.Unsetenv(tt.key)
+ assert.NoError(t, err)
+ }
+ assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
+ })
+ }
+}
+
+func TestParseMode(t *testing.T) {
+ tests := []struct {
+ name string
+ mode string
+ expect types.ServerMode
+ expectErr bool
+ }{
+ {"standalone", "standalone", types.ServerModeSTANDALONE, false},
+ {"node", "node", types.ServerModeNODE, false},
+ {"uppercase", "STANDALONE", types.ServerModeSTANDALONE, false},
+ {"invalid", "invalid", 0, true},
+ {"empty (default)", "", types.ServerModeSTANDALONE, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.mode != "" {
+ t.Setenv("MODE", tt.mode)
+ } else {
+ err := os.Unsetenv("MODE")
+ assert.NoError(t, err)
+ }
+ mode, err := parseMode()
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expect, mode)
+ }
+ })
+ }
+}
+
+func TestParseAllowedPorts(t *testing.T) {
+ tests := []struct {
+ name string
+ val string
+ start uint16
+ end uint16
+ expectErr bool
+ }{
+ {"valid range", "1000-2000", 1000, 2000, false},
+ {"empty", "", 0, 0, false},
+ {"invalid format - no dash", "1000", 0, 0, true},
+ {"invalid format - too many dashes", "1000-2000-3000", 0, 0, true},
+ {"invalid start port", "abc-2000", 0, 0, true},
+ {"invalid end port", "1000-abc", 0, 0, true},
+ {"out of range start", "70000-80000", 0, 0, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.val != "" {
+ t.Setenv("ALLOWED_PORTS", tt.val)
+ } else {
+ err := os.Unsetenv("ALLOWED_PORTS")
+ assert.NoError(t, err)
+ }
+ start, end, err := parseAllowedPorts()
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.start, start)
+ assert.Equal(t, tt.end, end)
+ }
+ })
+ }
+}
+
+func TestParseBufferSize(t *testing.T) {
+ tests := []struct {
+ name string
+ val string
+ expect int
+ }{
+ {"valid size", "8192", 8192},
+ {"default size", "", 32768},
+ {"too small", "1024", 4096},
+ {"too large", "2000000", 4096},
+ {"invalid format", "abc", 4096},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.val != "" {
+ t.Setenv("BUFFER_SIZE", tt.val)
+ } else {
+ err := os.Unsetenv("BUFFER_SIZE")
+ assert.NoError(t, err)
+ }
+ size := parseBufferSize()
+ assert.Equal(t, tt.expect, size)
+ })
+ }
+}
+
+func TestParseHeaderSize(t *testing.T) {
+ tests := []struct {
+ name string
+ val string
+ expect int
+ }{
+ {"valid size", "8192", 8192},
+ {"default size", "", 4096},
+ {"too small", "1024", 4096},
+ {"too large", "2000000", 4096},
+ {"invalid format", "abc", 4096},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if tt.val != "" {
+ t.Setenv("MAX_HEADER_SIZE", tt.val)
+ } else {
+ err := os.Unsetenv("MAX_HEADER_SIZE")
+ assert.NoError(t, err)
+ }
+ size := parseHeaderSize()
+ assert.Equal(t, tt.expect, size)
+ })
+ }
+}
+
+func TestParse(t *testing.T) {
+ tests := []struct {
+ name string
+ envs map[string]string
+ expectErr bool
+ }{
+ {
+ name: "minimal valid config",
+ envs: map[string]string{
+ "DOMAIN": "example.com",
+ },
+ expectErr: false,
+ },
+ {
+ name: "TLS enabled without token",
+ envs: map[string]string{
+ "TLS_ENABLED": "true",
+ },
+ expectErr: true,
+ },
+ {
+ name: "TLS enabled with token",
+ envs: map[string]string{
+ "TLS_ENABLED": "true",
+ "CF_API_TOKEN": "secret",
+ },
+ expectErr: false,
+ },
+ {
+ name: "Node mode without token",
+ envs: map[string]string{
+ "MODE": "node",
+ },
+ expectErr: true,
+ },
+ {
+ name: "Node mode with token",
+ envs: map[string]string{
+ "MODE": "node",
+ "NODE_TOKEN": "token",
+ },
+ expectErr: false,
+ },
+ {
+ name: "invalid mode",
+ envs: map[string]string{
+ "MODE": "invalid",
+ },
+ expectErr: true,
+ },
+ {
+ name: "invalid allowed ports",
+ envs: map[string]string{
+ "ALLOWED_PORTS": "1000",
+ },
+ expectErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ os.Clearenv()
+ for k, v := range tt.envs {
+ t.Setenv(k, v)
+ }
+ cfg, err := parse()
+ if tt.expectErr {
+ assert.Error(t, err)
+ assert.Nil(t, cfg)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, cfg)
+ }
+ })
+ }
+}
+
+func TestGetters(t *testing.T) {
+ envs := map[string]string{
+ "DOMAIN": "example.com",
+ "PORT": "2222",
+ "HTTP_PORT": "80",
+ "HTTPS_PORT": "443",
+ "KEY_LOC": "certs/ssh/id_rsa",
+ "TLS_ENABLED": "true",
+ "TLS_REDIRECT": "true",
+ "TLS_STORAGE_PATH": "certs/tls/",
+ "ACME_EMAIL": "test@example.com",
+ "CF_API_TOKEN": "token",
+ "ACME_STAGING": "true",
+ "ALLOWED_PORTS": "1000-2000",
+ "BUFFER_SIZE": "16384",
+ "MAX_HEADER_SIZE": "4096",
+ "PPROF_ENABLED": "true",
+ "PPROF_PORT": "7070",
+ "MODE": "standalone",
+ "GRPC_ADDRESS": "127.0.0.1",
+ "GRPC_PORT": "9090",
+ "NODE_TOKEN": "ntoken",
+ }
+
+ os.Clearenv()
+ for k, v := range envs {
+ t.Setenv(k, v)
+ }
+
+ cfg, err := parse()
+ assert.NoError(t, err)
+
+ assert.Equal(t, "example.com", cfg.Domain())
+ assert.Equal(t, "2222", cfg.SSHPort())
+ assert.Equal(t, "80", cfg.HTTPPort())
+ assert.Equal(t, "443", cfg.HTTPSPort())
+ assert.Equal(t, "certs/ssh/id_rsa", cfg.KeyLoc())
+ assert.Equal(t, true, cfg.TLSEnabled())
+ assert.Equal(t, true, cfg.TLSRedirect())
+ assert.Equal(t, "certs/tls/", cfg.TLSStoragePath())
+ assert.Equal(t, "test@example.com", cfg.ACMEEmail())
+ assert.Equal(t, "token", cfg.CFAPIToken())
+ assert.Equal(t, true, cfg.ACMEStaging())
+ assert.Equal(t, uint16(1000), cfg.AllowedPortsStart())
+ assert.Equal(t, uint16(2000), cfg.AllowedPortsEnd())
+ assert.Equal(t, 16384, cfg.BufferSize())
+ assert.Equal(t, 4096, cfg.HeaderSize())
+ assert.Equal(t, true, cfg.PprofEnabled())
+ assert.Equal(t, "7070", cfg.PprofPort())
+ assert.Equal(t, types.ServerMode(types.ServerModeSTANDALONE), cfg.Mode())
+ assert.Equal(t, "127.0.0.1", cfg.GRPCAddress())
+ assert.Equal(t, "9090", cfg.GRPCPort())
+ assert.Equal(t, "ntoken", cfg.NodeToken())
+}
+
+func TestMustLoad(t *testing.T) {
+ t.Run("success", func(t *testing.T) {
+ os.Clearenv()
+ t.Setenv("DOMAIN", "example.com")
+ cfg, err := MustLoad()
+ assert.NoError(t, err)
+ assert.NotNil(t, cfg)
+ })
+
+ t.Run("loadEnvFile error", func(t *testing.T) {
+ err := os.Mkdir(".env", 0755)
+ assert.NoError(t, err)
+ defer func() {
+ err = os.Remove(".env")
+ assert.NoError(t, err)
+ }()
+
+ cfg, err := MustLoad()
+ assert.Error(t, err)
+ assert.Nil(t, cfg)
+ })
+
+ t.Run("parse error", func(t *testing.T) {
+ os.Clearenv()
+ t.Setenv("MODE", "invalid")
+ cfg, err := MustLoad()
+ assert.Error(t, err)
+ assert.Nil(t, cfg)
+ })
+}
+
+func TestLoadEnvFile(t *testing.T) {
+ t.Run("file exists", func(t *testing.T) {
+ err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
+ assert.NoError(t, err)
+ defer func() {
+ err = os.Remove(".env")
+ assert.NoError(t, err)
+ }()
+
+ err = loadEnvFile()
+ assert.NoError(t, err)
+ assert.Equal(t, "true", os.Getenv("TEST_ENV_FILE"))
+ })
+
+ t.Run("file missing", func(t *testing.T) {
+ _ = os.Remove(".env")
+ err := loadEnvFile()
+ assert.NoError(t, err)
+ })
+}
diff --git a/internal/config/loader.go b/internal/config/loader.go
index cde9fd0..5cbfe1f 100644
--- a/internal/config/loader.go
+++ b/internal/config/loader.go
@@ -18,18 +18,21 @@ type config struct {
httpPort string
httpsPort string
- tlsEnabled bool
- tlsRedirect bool
+ keyLoc string
- acmeEmail string
- cfAPIToken string
- acmeStaging bool
+ tlsEnabled bool
+ tlsRedirect bool
+ tlsStoragePath string
+ acmeEmail string
+ cfAPIToken string
+ acmeStaging bool
allowedPortsStart uint16
allowedPortsEnd uint16
bufferSize int
-
+ headerSize int
+
pprofEnabled bool
pprofPort string
@@ -51,8 +54,11 @@ func parse() (*config, error) {
httpPort := getenv("HTTP_PORT", "8080")
httpsPort := getenv("HTTPS_PORT", "8443")
+ keyLoc := getenv("KEY_LOC", "certs/privkey.pem")
+
tlsEnabled := getenvBool("TLS_ENABLED", false)
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
+ tlsStoragePath := getenv("TLS_STORAGE_PATH", "certs/tls/")
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
acmeStaging := getenvBool("ACME_STAGING", false)
@@ -68,6 +74,7 @@ func parse() (*config, error) {
}
bufferSize := parseBufferSize()
+ headerSize := parseHeaderSize()
pprofEnabled := getenvBool("PPROF_ENABLED", false)
pprofPort := getenv("PPROF_PORT", "6060")
@@ -85,14 +92,17 @@ func parse() (*config, error) {
sshPort: sshPort,
httpPort: httpPort,
httpsPort: httpsPort,
+ keyLoc: keyLoc,
tlsEnabled: tlsEnabled,
tlsRedirect: tlsRedirect,
+ tlsStoragePath: tlsStoragePath,
acmeEmail: acmeEmail,
cfAPIToken: cfToken,
acmeStaging: acmeStaging,
allowedPortsStart: start,
allowedPortsEnd: end,
bufferSize: bufferSize,
+ headerSize: headerSize,
pprofEnabled: pprofEnabled,
pprofPort: pprofPort,
mode: mode,
@@ -154,6 +164,16 @@ func parseBufferSize() int {
return size
}
+func parseHeaderSize() int {
+ raw := getenv("MAX_HEADER_SIZE", "4096")
+ size, err := strconv.Atoi(raw)
+ if err != nil || size < 4096 || size > 131072 {
+ log.Println("Invalid BUFFER_SIZE, falling back to 4096")
+ return 4096
+ }
+ return size
+}
+
func getenv(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go
index f2e0a1e..3adcc57 100644
--- a/internal/grpc/client/client.go
+++ b/internal/grpc/client/client.go
@@ -38,7 +38,15 @@ type client struct {
closing bool
}
-func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) {
+var (
+ grpcNewClient = grpc.NewClient
+ healthNewHealthClient = grpc_health_v1.NewHealthClient
+ initialBackoff = time.Second
+)
+
+func New(config config.Config, sessionRegistry registry.Registry) (Client, error) {
+ address := fmt.Sprintf("%s:%s", config.GRPCAddress(), config.GRPCPort())
+
var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
@@ -58,7 +66,7 @@ func New(config config.Config, address string, sessionRegistry registry.Registry
),
)
- conn, err := grpc.NewClient(address, opts...)
+ conn, err := grpcNewClient(address, opts...)
if err != nil {
return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err)
}
@@ -77,85 +85,100 @@ func New(config config.Config, address string, sessionRegistry registry.Registry
}
func (c *client) SubscribeEvents(ctx context.Context, identity, authToken string) error {
- const (
- baseBackoff = time.Second
- maxBackoff = 30 * time.Second
- )
-
- backoff := baseBackoff
- wait := func() error {
- if backoff <= 0 {
- return nil
- }
- select {
- case <-time.After(backoff):
- return nil
- case <-ctx.Done():
- return ctx.Err()
- }
- }
- growBackoff := func() {
- backoff *= 2
- if backoff > maxBackoff {
- backoff = maxBackoff
- }
- }
+ backoff := initialBackoff
for {
- subscribe, err := c.eventService.Subscribe(ctx)
- if err != nil {
- if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
- return err
- }
- if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated {
- return err
- }
- if err = wait(); err != nil {
- return err
- }
- growBackoff()
- log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
- continue
+ if err := c.subscribeAndProcess(ctx, identity, authToken, &backoff); err != nil {
+ return err
}
+ }
+}
- err = subscribe.Send(&proto.Node{
- Type: proto.EventType_AUTHENTICATION,
- Payload: &proto.Node_AuthEvent{
- AuthEvent: &proto.Authentication{
- Identity: identity,
- AuthToken: authToken,
- },
+func (c *client) subscribeAndProcess(ctx context.Context, identity, authToken string, backoff *time.Duration) error {
+ subscribe, err := c.eventService.Subscribe(ctx)
+ if err != nil {
+ return c.handleSubscribeError(ctx, err, backoff)
+ }
+
+ err = subscribe.Send(&proto.Node{
+ Type: proto.EventType_AUTHENTICATION,
+ Payload: &proto.Node_AuthEvent{
+ AuthEvent: &proto.Authentication{
+ Identity: identity,
+ AuthToken: authToken,
},
- })
+ },
+ })
- if err != nil {
- log.Println("Authentication failed to send to gRPC server:", err)
- if c.isConnectionError(err) {
- if err = wait(); err != nil {
- return err
- }
- growBackoff()
- continue
- }
- return err
- }
- log.Println("Authentication Successfully sent to gRPC server")
- backoff = baseBackoff
+ if err != nil {
+ return c.handleAuthError(ctx, err, backoff)
+ }
- if err = c.processEventStream(subscribe); err != nil {
- if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
- return err
- }
- if c.isConnectionError(err) {
- log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
- if err = wait(); err != nil {
- return err
- }
- growBackoff()
- continue
- }
- return err
- }
+ log.Println("Authentication Successfully sent to gRPC server")
+ *backoff = time.Second
+
+ return c.handleStreamError(ctx, c.processEventStream(subscribe), backoff)
+}
+
+func (c *client) handleSubscribeError(ctx context.Context, err error, backoff *time.Duration) error {
+ if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
+ return err
+ }
+ if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated {
+ return err
+ }
+ if err = c.wait(ctx, *backoff); err != nil {
+ return err
+ }
+ c.growBackoff(backoff)
+ log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
+ return nil
+}
+
+func (c *client) handleAuthError(ctx context.Context, err error, backoff *time.Duration) error {
+ log.Println("Authentication failed to send to gRPC server:", err)
+ if !c.isConnectionError(err) {
+ return err
+ }
+ if err := c.wait(ctx, *backoff); err != nil {
+ return err
+ }
+ c.growBackoff(backoff)
+ return nil
+}
+
+func (c *client) handleStreamError(ctx context.Context, err error, backoff *time.Duration) error {
+ if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
+ return err
+ }
+ if !c.isConnectionError(err) {
+ return err
+ }
+ log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
+ if err := c.wait(ctx, *backoff); err != nil {
+ return err
+ }
+ c.growBackoff(backoff)
+ return nil
+}
+
+func (c *client) wait(ctx context.Context, duration time.Duration) error {
+ if duration <= 0 {
+ return nil
+ }
+ select {
+ case <-time.After(duration):
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
+func (c *client) growBackoff(backoff *time.Duration) {
+ const maxBackoff = 30 * time.Second
+ *backoff *= 2
+ if *backoff > maxBackoff {
+ *backoff = maxBackoff
}
}
@@ -191,35 +214,20 @@ func (c *client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, pr
func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
slugEvent := evt.GetSlugEvent()
user := slugEvent.GetUser()
- oldSlug := slugEvent.GetOld()
- newSlug := slugEvent.GetNew()
+ oldKey := types.SessionKey{Id: slugEvent.GetOld(), Type: types.TunnelTypeHTTP}
+ newKey := types.SessionKey{Id: slugEvent.GetNew(), Type: types.TunnelTypeHTTP}
- userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP})
+ userSession, err := c.sessionRegistry.Get(oldKey)
if err != nil {
- return c.sendNode(subscribe, &proto.Node{
- Type: proto.EventType_SLUG_CHANGE_RESPONSE,
- Payload: &proto.Node_SlugEventResponse{
- SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()},
- },
- }, "slug change failure response")
+ return c.sendSlugChangeResponse(subscribe, false, err.Error())
}
- if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil {
- return c.sendNode(subscribe, &proto.Node{
- Type: proto.EventType_SLUG_CHANGE_RESPONSE,
- Payload: &proto.Node_SlugEventResponse{
- SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()},
- },
- }, "slug change failure response")
+ if err = c.sessionRegistry.Update(user, oldKey, newKey); err != nil {
+ return c.sendSlugChangeResponse(subscribe, false, err.Error())
}
userSession.Interaction().Redraw()
- return c.sendNode(subscribe, &proto.Node{
- Type: proto.EventType_SLUG_CHANGE_RESPONSE,
- Payload: &proto.Node_SlugEventResponse{
- SlugEventResponse: &proto.SlugChangeEventResponse{Success: true, Message: ""},
- },
- }, "slug change success response")
+ return c.sendSlugChangeResponse(subscribe, true, "")
}
func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
@@ -238,12 +246,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node
})
}
- return c.sendNode(subscribe, &proto.Node{
- Type: proto.EventType_GET_SESSIONS,
- Payload: &proto.Node_GetSessionsEvent{
- GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
- },
- }, "send get sessions response")
+ return c.sendGetSessionsResponse(subscribe, details)
}
func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
@@ -253,39 +256,46 @@ func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto
tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType())
if err != nil {
- return c.sendNode(subscribe, &proto.Node{
- Type: proto.EventType_TERMINATE_SESSION,
- Payload: &proto.Node_TerminateSessionEventResponse{
- TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
- },
- }, "terminate session invalid tunnel type")
+ return c.sendTerminateSessionResponse(subscribe, false, err.Error())
}
userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType})
if err != nil {
- return c.sendNode(subscribe, &proto.Node{
- Type: proto.EventType_TERMINATE_SESSION,
- Payload: &proto.Node_TerminateSessionEventResponse{
- TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
- },
- }, "terminate session fetch failed")
+ return c.sendTerminateSessionResponse(subscribe, false, err.Error())
}
if err = userSession.Lifecycle().Close(); err != nil {
- return c.sendNode(subscribe, &proto.Node{
- Type: proto.EventType_TERMINATE_SESSION,
- Payload: &proto.Node_TerminateSessionEventResponse{
- TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
- },
- }, "terminate session close failed")
+ return c.sendTerminateSessionResponse(subscribe, false, err.Error())
}
+ return c.sendTerminateSessionResponse(subscribe, true, "")
+}
+
+func (c *client) sendSlugChangeResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
+ return c.sendNode(subscribe, &proto.Node{
+ Type: proto.EventType_SLUG_CHANGE_RESPONSE,
+ Payload: &proto.Node_SlugEventResponse{
+ SlugEventResponse: &proto.SlugChangeEventResponse{Success: success, Message: message},
+ },
+ }, "slug change response")
+}
+
+func (c *client) sendGetSessionsResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], details []*proto.Detail) error {
+ return c.sendNode(subscribe, &proto.Node{
+ Type: proto.EventType_GET_SESSIONS,
+ Payload: &proto.Node_GetSessionsEvent{
+ GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
+ },
+ }, "send get sessions response")
+}
+
+func (c *client) sendTerminateSessionResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
- TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: true, Message: ""},
+ TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: success, Message: message},
},
- }, "terminate session success response")
+ }, "terminate session response")
}
func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error {
@@ -326,7 +336,7 @@ func (c *client) AuthorizeConn(ctx context.Context, token string) (authorized bo
}
func (c *client) CheckServerHealth(ctx context.Context) error {
- healthClient := grpc_health_v1.NewHealthClient(c.ClientConn())
+ healthClient := healthNewHealthClient(c.ClientConn())
resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{
Service: "",
})
diff --git a/internal/grpc/client/client_test.go b/internal/grpc/client/client_test.go
new file mode 100644
index 0000000..e69065d
--- /dev/null
+++ b/internal/grpc/client/client_test.go
@@ -0,0 +1,1078 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "testing"
+ "time"
+
+ "tunnel_pls/internal/port"
+ "tunnel_pls/internal/registry"
+ "tunnel_pls/session/forwarder"
+ "tunnel_pls/session/interaction"
+ "tunnel_pls/session/lifecycle"
+ "tunnel_pls/session/slug"
+ "tunnel_pls/types"
+
+ proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "golang.org/x/crypto/ssh"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/health/grpc_health_v1"
+ "google.golang.org/grpc/status"
+)
+
+func TestClient_ClientConn(t *testing.T) {
+ conn := &grpc.ClientConn{}
+ c := &client{conn: conn}
+ if c.ClientConn() != conn {
+ t.Errorf("ClientConn() did not return expected connection")
+ }
+}
+
+func TestClient_Close(t *testing.T) {
+ c := &client{}
+ if err := c.Close(); err != nil {
+ t.Errorf("Close() on nil connection returned error: %v", err)
+ }
+}
+
+func TestAuthorizeConn(t *testing.T) {
+ mockUserSvc := &mockUserServiceClient{}
+ c := &client{authorizeConnectionService: mockUserSvc}
+
+ tests := []struct {
+ name string
+ token string
+ mockResp *proto.CheckResponse
+ mockErr error
+ wantAuth bool
+ wantUser string
+ wantErr bool
+ }{
+ {
+ name: "Success",
+ token: "valid",
+ mockResp: &proto.CheckResponse{Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED, User: "mas-fuad"},
+ wantAuth: true,
+ wantUser: "mas-fuad",
+ wantErr: false,
+ },
+ {
+ name: "Unauthorized",
+ token: "invalid",
+ mockResp: &proto.CheckResponse{Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED},
+ wantAuth: false,
+ wantUser: "UNAUTHORIZED",
+ wantErr: false,
+ },
+ {
+ name: "Error",
+ token: "error",
+ mockErr: errors.New("grpc error"),
+ wantAuth: false,
+ wantUser: "UNAUTHORIZED",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockUserSvc.On("Check", mock.Anything, &proto.CheckRequest{AuthToken: tt.token}, mock.Anything).Return(tt.mockResp, tt.mockErr).Once()
+
+ auth, user, err := c.AuthorizeConn(context.Background(), tt.token)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("AuthorizeConn() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ assert.Equal(t, tt.wantAuth, auth)
+ assert.Equal(t, tt.wantUser, user)
+ mockUserSvc.AssertExpectations(t)
+ })
+ }
+}
+
+func TestHandleSubscribeError(t *testing.T) {
+ c := &client{}
+ ctx := context.Background()
+ canceledCtx, cancel := context.WithCancel(ctx)
+ cancel()
+
+ tests := []struct {
+ name string
+ ctx context.Context
+ err error
+ backoff time.Duration
+ wantErr bool
+ wantB time.Duration
+ }{
+ {
+ name: "ContextCanceled",
+ ctx: canceledCtx,
+ err: context.Canceled,
+ backoff: time.Second,
+ wantErr: true,
+ },
+ {
+ name: "GrpcCanceled",
+ ctx: ctx,
+ err: status.Error(codes.Canceled, "canceled"),
+ backoff: time.Second,
+ wantErr: true,
+ },
+ {
+ name: "CtxErrSet",
+ ctx: canceledCtx,
+ err: errors.New("other error"),
+ backoff: time.Second,
+ wantErr: true,
+ },
+ {
+ name: "Unauthenticated",
+ ctx: ctx,
+ err: status.Error(codes.Unauthenticated, "unauth"),
+ backoff: time.Second,
+ wantErr: true,
+ },
+ {
+ name: "ConnectionError",
+ ctx: ctx,
+ err: status.Error(codes.Unavailable, "unavailable"),
+ backoff: time.Second,
+ wantErr: false,
+ wantB: 2 * time.Second,
+ },
+ {
+ name: "NonConnectionError",
+ ctx: ctx,
+ err: status.Error(codes.Internal, "internal"),
+ backoff: time.Second,
+ wantErr: true,
+ },
+ {
+ name: "WaitCanceled",
+ ctx: canceledCtx,
+ err: status.Error(codes.Unavailable, "unavailable"),
+ backoff: time.Second,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ backoff := tt.backoff
+ err := c.handleSubscribeError(tt.ctx, tt.err, &backoff)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("handleSubscribeError() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if !tt.wantErr && backoff != tt.wantB {
+ t.Errorf("handleSubscribeError() backoff = %v, want %v", backoff, tt.wantB)
+ }
+ })
+ }
+
+ t.Run("WaitCanceledReal", func(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ backoff := 50 * time.Millisecond
+ go func() {
+ time.Sleep(10 * time.Millisecond)
+ cancel()
+ }()
+ err := c.handleSubscribeError(ctx, status.Error(codes.Unavailable, "unavailable"), &backoff)
+ if err == nil {
+ t.Errorf("expected error from wait")
+ }
+ })
+}
+
+func TestHandleStreamError(t *testing.T) {
+ c := &client{}
+ ctx := context.Background()
+ canceledCtx, cancel := context.WithCancel(ctx)
+ cancel()
+
+ tests := []struct {
+ name string
+ ctx context.Context
+ err error
+ backoff time.Duration
+ wantErr bool
+ wantB time.Duration
+ }{
+ {
+ name: "ContextCanceled",
+ ctx: canceledCtx,
+ err: context.Canceled,
+ backoff: time.Second,
+ wantErr: true,
+ },
+ {
+ name: "GrpcCanceled",
+ ctx: ctx,
+ err: status.Error(codes.Canceled, "canceled"),
+ backoff: time.Second,
+ wantErr: true,
+ },
+ {
+ name: "CtxErrSet",
+ ctx: canceledCtx,
+ err: errors.New("other error"),
+ backoff: time.Second,
+ wantErr: true,
+ },
+ {
+ name: "ConnectionError",
+ ctx: ctx,
+ err: status.Error(codes.Unavailable, "unavailable"),
+ backoff: time.Second,
+ wantErr: false,
+ wantB: 2 * time.Second,
+ },
+ {
+ name: "NonConnectionError",
+ ctx: ctx,
+ err: status.Error(codes.Internal, "internal"),
+ backoff: time.Second,
+ wantErr: true,
+ },
+ {
+ name: "WaitCanceled",
+ ctx: canceledCtx,
+ err: status.Error(codes.Unavailable, "unavailable"),
+ backoff: time.Second,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ backoff := tt.backoff
+ err := c.handleStreamError(tt.ctx, tt.err, &backoff)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("handleStreamError() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if !tt.wantErr && backoff != tt.wantB {
+ t.Errorf("handleStreamError() backoff = %v, want %v", backoff, tt.wantB)
+ }
+ })
+ }
+
+ t.Run("WaitCanceledReal", func(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ backoff := 50 * time.Millisecond
+ go func() {
+ time.Sleep(10 * time.Millisecond)
+ cancel()
+ }()
+ err := c.handleStreamError(ctx, status.Error(codes.Unavailable, "unavailable"), &backoff)
+ if err == nil {
+ t.Errorf("expected error from wait")
+ }
+ })
+}
+
+func TestHandleAuthError(t *testing.T) {
+ c := &client{}
+ ctx := context.Background()
+
+ tests := []struct {
+ name string
+ err error
+ backoff time.Duration
+ wantErr bool
+ wantB time.Duration
+ }{
+ {
+ name: "ConnectionError",
+ err: status.Error(codes.Unavailable, "unavailable"),
+ backoff: time.Second,
+ wantErr: false,
+ wantB: 2 * time.Second,
+ },
+ {
+ name: "NonConnectionError",
+ err: status.Error(codes.Internal, "internal"),
+ backoff: time.Second,
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ backoff := tt.backoff
+ err := c.handleAuthError(ctx, tt.err, &backoff)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("handleAuthError() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if !tt.wantErr && backoff != tt.wantB {
+ t.Errorf("handleAuthError() backoff = %v, want %v", backoff, tt.wantB)
+ }
+ })
+ }
+}
+
+func TestHandleAuthError_WaitFail(t *testing.T) {
+ c := &client{}
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ backoff := time.Second
+ err := c.handleAuthError(ctx, status.Error(codes.Unavailable, "unavailable"), &backoff)
+ if err == nil {
+ t.Errorf("expected error when wait fails")
+ }
+}
+
+func TestProcessEventStream(t *testing.T) {
+ c := &client{}
+
+ t.Run("UnknownEventType", func(t *testing.T) {
+ mockStream := &mockSubscribeClient{}
+ mockStream.On("Recv").Return(&proto.Events{Type: proto.EventType(999)}, nil).Once()
+ mockStream.On("Recv").Return(nil, io.EOF).Once()
+
+ err := c.processEventStream(mockStream)
+ assert.ErrorIs(t, err, io.EOF)
+ })
+
+ t.Run("DispatchSuccess", func(t *testing.T) {
+ events := []proto.EventType{
+ proto.EventType_SLUG_CHANGE,
+ proto.EventType_GET_SESSIONS,
+ proto.EventType_TERMINATE_SESSION,
+ }
+
+ for _, et := range events {
+ t.Run(et.String(), func(t *testing.T) {
+ mockStream := &mockSubscribeClient{}
+ payload := &proto.Events{Type: et}
+ switch et {
+ case proto.EventType_SLUG_CHANGE:
+ payload.Payload = &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}}
+ case proto.EventType_GET_SESSIONS:
+ payload.Payload = &proto.Events_GetSessionsEvent{GetSessionsEvent: &proto.GetSessionsEvent{}}
+ case proto.EventType_TERMINATE_SESSION:
+ payload.Payload = &proto.Events_TerminateSessionEvent{TerminateSessionEvent: &proto.TerminateSessionEvent{}}
+ }
+
+ mockStream.On("Recv").Return(payload, nil).Once()
+ mockStream.On("Recv").Return(nil, io.EOF).Once()
+
+ mockReg := &mockRegistry{}
+ c.sessionRegistry = mockReg
+ mCfg := &MockConfig{}
+ c.config = mCfg
+ mCfg.On("Domain").Return("test.com").Maybe()
+
+ switch et {
+ case proto.EventType_SLUG_CHANGE:
+ mockReg.On("Get", mock.Anything).Return(nil, errors.New("fail")).Once()
+ case proto.EventType_GET_SESSIONS:
+ mockReg.On("GetAllSessionFromUser", mock.Anything).Return(nil).Once()
+ case proto.EventType_TERMINATE_SESSION:
+ mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(nil, errors.New("fail")).Once()
+ }
+ mockStream.On("Send", mock.Anything).Return(nil).Once()
+
+ err := c.processEventStream(mockStream)
+ assert.ErrorIs(t, err, io.EOF)
+ })
+ }
+ })
+
+ t.Run("HandlerError", func(t *testing.T) {
+ mockStream := &mockSubscribeClient{}
+ mockStream.On("Recv").Return(&proto.Events{
+ Type: proto.EventType_SLUG_CHANGE,
+ Payload: &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}},
+ }, nil).Once()
+
+ mockReg := &mockRegistry{}
+ mockReg.On("Get", mock.Anything).Return(nil, errors.New("fail")).Once()
+ c.sessionRegistry = mockReg
+
+ expectedErr := status.Error(codes.Unavailable, "send fail")
+ mockStream.On("Send", mock.Anything).Return(expectedErr).Once()
+
+ err := c.processEventStream(mockStream)
+ assert.Equal(t, expectedErr, err)
+ })
+}
+
+func TestSendNode(t *testing.T) {
+ c := &client{}
+
+ t.Run("Success", func(t *testing.T) {
+ mockStream := &mockSubscribeClient{}
+ mockStream.On("Send", mock.Anything).Return(nil).Once()
+ err := c.sendNode(mockStream, &proto.Node{}, "context")
+ assert.NoError(t, err)
+ mockStream.AssertExpectations(t)
+ })
+
+ t.Run("ConnectionError", func(t *testing.T) {
+ mockStream := &mockSubscribeClient{}
+ expectedErr := status.Error(codes.Unavailable, "fail")
+ mockStream.On("Send", mock.Anything).Return(expectedErr).Once()
+ err := c.sendNode(mockStream, &proto.Node{}, "context")
+ assert.ErrorIs(t, err, expectedErr)
+ mockStream.AssertExpectations(t)
+ })
+
+ t.Run("OtherError", func(t *testing.T) {
+ mockStream := &mockSubscribeClient{}
+ mockStream.On("Send", mock.Anything).Return(status.Error(codes.Internal, "fail")).Once()
+ err := c.sendNode(mockStream, &proto.Node{}, "context")
+ assert.NoError(t, err)
+ mockStream.AssertExpectations(t)
+ })
+}
+
+func TestHandleSlugChange(t *testing.T) {
+ mockReg := &mockRegistry{}
+ mockStream := &mockSubscribeClient{}
+ c := &client{sessionRegistry: mockReg}
+
+ evt := &proto.Events{
+ Payload: &proto.Events_SlugEvent{
+ SlugEvent: &proto.SlugChangeEvent{
+ User: "mas-fuad",
+ Old: "old-slug",
+ New: "new-slug",
+ },
+ },
+ }
+
+ t.Run("Success", func(t *testing.T) {
+ mockSess := &mockSession{}
+ mockInter := &mockInteraction{}
+ mockSess.On("Interaction").Return(mockInter).Once()
+ mockInter.On("Redraw").Return().Once()
+
+ mockReg.On("Get", types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}).Return(mockSess, nil).Once()
+ mockReg.On("Update", "mas-fuad", types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}, types.SessionKey{Id: "new-slug", Type: types.TunnelTypeHTTP}).Return(nil).Once()
+
+ mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
+ return n.Type == proto.EventType_SLUG_CHANGE_RESPONSE && n.GetSlugEventResponse().Success
+ })).Return(nil).Once()
+
+ err := c.handleSlugChange(mockStream, evt)
+ assert.NoError(t, err)
+ mockReg.AssertExpectations(t)
+ mockStream.AssertExpectations(t)
+ mockInter.AssertExpectations(t)
+ })
+
+ t.Run("SessionNotFound", func(t *testing.T) {
+ mockReg.On("Get", mock.Anything).Return(nil, errors.New("not found")).Once()
+ mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
+ return !n.GetSlugEventResponse().Success && n.GetSlugEventResponse().Message == "not found"
+ })).Return(nil).Once()
+
+ err := c.handleSlugChange(mockStream, evt)
+ assert.NoError(t, err)
+ mockReg.AssertExpectations(t)
+ mockStream.AssertExpectations(t)
+ })
+
+ t.Run("UpdateError", func(t *testing.T) {
+ mockSess := &mockSession{}
+ mockReg.On("Get", mock.Anything).Return(mockSess, nil).Once()
+ mockReg.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("update fail")).Once()
+ mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
+ return !n.GetSlugEventResponse().Success && n.GetSlugEventResponse().Message == "update fail"
+ })).Return(nil).Once()
+
+ err := c.handleSlugChange(mockStream, evt)
+ assert.NoError(t, err)
+ mockReg.AssertExpectations(t)
+ mockStream.AssertExpectations(t)
+ })
+}
+
+func TestHandleGetSessions(t *testing.T) {
+ mockReg := &mockRegistry{}
+ mockStream := &mockSubscribeClient{}
+ mockCfg := &MockConfig{}
+ c := &client{sessionRegistry: mockReg, config: mockCfg}
+
+ evt := &proto.Events{
+ Payload: &proto.Events_GetSessionsEvent{
+ GetSessionsEvent: &proto.GetSessionsEvent{
+ Identity: "mas-fuad",
+ },
+ },
+ }
+
+ t.Run("Success", func(t *testing.T) {
+ now := time.Now()
+ mockSess := &mockSession{}
+ mockSess.On("Detail").Return(&types.Detail{
+ ForwardingType: "http",
+ Slug: "myslug",
+ UserID: "mas-fuad",
+ Active: true,
+ StartedAt: now,
+ }).Once()
+
+ mockReg.On("GetAllSessionFromUser", "mas-fuad").Return([]registry.Session{mockSess}).Once()
+ mockCfg.On("Domain").Return("test.com").Once()
+
+ mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
+ if n.Type != proto.EventType_GET_SESSIONS {
+ return false
+ }
+ details := n.GetGetSessionsEvent().Details
+ return len(details) == 1 && details[0].Slug == "myslug"
+ })).Return(nil).Once()
+
+ err := c.handleGetSessions(mockStream, evt)
+ assert.NoError(t, err)
+ mockReg.AssertExpectations(t)
+ mockStream.AssertExpectations(t)
+ mockCfg.AssertExpectations(t)
+ })
+}
+
+func TestHandleTerminateSession(t *testing.T) {
+ mockReg := &mockRegistry{}
+ mockStream := &mockSubscribeClient{}
+ c := &client{sessionRegistry: mockReg}
+
+ evt := &proto.Events{
+ Payload: &proto.Events_TerminateSessionEvent{
+ TerminateSessionEvent: &proto.TerminateSessionEvent{
+ User: "mas-fuad",
+ Slug: "myslug",
+ TunnelType: proto.TunnelType_HTTP,
+ },
+ },
+ }
+
+ t.Run("Success", func(t *testing.T) {
+ mockSess := &mockSession{}
+ mockLife := &mockLifecycle{}
+ mockSess.On("Lifecycle").Return(mockLife).Once()
+ mockLife.On("Close").Return(nil).Once()
+
+ mockReg.On("GetWithUser", "mas-fuad", types.SessionKey{Id: "myslug", Type: types.TunnelTypeHTTP}).Return(mockSess, nil).Once()
+
+ mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
+ return n.GetTerminateSessionEventResponse().Success
+ })).Return(nil).Once()
+
+ err := c.handleTerminateSession(mockStream, evt)
+ assert.NoError(t, err)
+ mockReg.AssertExpectations(t)
+ mockStream.AssertExpectations(t)
+ mockLife.AssertExpectations(t)
+ })
+
+ t.Run("TunnelTypeUnknown", func(t *testing.T) {
+ badEvt := &proto.Events{
+ Payload: &proto.Events_TerminateSessionEvent{
+ TerminateSessionEvent: &proto.TerminateSessionEvent{
+ TunnelType: proto.TunnelType(999),
+ },
+ },
+ }
+ mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
+ resp := n.GetTerminateSessionEventResponse()
+ return !resp.Success && resp.Message != ""
+ })).Return(nil).Once()
+
+ err := c.handleTerminateSession(mockStream, badEvt)
+ assert.NoError(t, err)
+ mockStream.AssertExpectations(t)
+ })
+
+ t.Run("SessionNotFound", func(t *testing.T) {
+ mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(nil, errors.New("not found")).Once()
+ mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
+ resp := n.GetTerminateSessionEventResponse()
+ return !resp.Success && resp.Message == "not found"
+ })).Return(nil).Once()
+
+ err := c.handleTerminateSession(mockStream, evt)
+ assert.NoError(t, err)
+ mockReg.AssertExpectations(t)
+ mockStream.AssertExpectations(t)
+ })
+
+ t.Run("CloseError", func(t *testing.T) {
+ mockSess := &mockSession{}
+ mockLife := &mockLifecycle{}
+ mockSess.On("Lifecycle").Return(mockLife).Once()
+ mockLife.On("Close").Return(errors.New("close fail")).Once()
+ mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(mockSess, nil).Once()
+
+ mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool {
+ resp := n.GetTerminateSessionEventResponse()
+ return !resp.Success && resp.Message == "close fail"
+ })).Return(nil).Once()
+
+ err := c.handleTerminateSession(mockStream, evt)
+ assert.NoError(t, err)
+ mockReg.AssertExpectations(t)
+ mockStream.AssertExpectations(t)
+ mockLife.AssertExpectations(t)
+ })
+}
+
+func TestSubscribeAndProcess(t *testing.T) {
+ mockEventSvc := &mockEventServiceClient{}
+ c := &client{eventService: mockEventSvc}
+ ctx := context.Background()
+ backoff := time.Second
+
+ t.Run("SubscribeError", func(t *testing.T) {
+ expectedErr := status.Error(codes.Unauthenticated, "unauth")
+ mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, expectedErr).Once()
+ err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
+ assert.ErrorIs(t, err, expectedErr)
+ })
+
+ t.Run("AuthSendError", func(t *testing.T) {
+ mockStream := &mockSubscribeClient{}
+ mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(mockStream, nil).Once()
+ expectedErr := status.Error(codes.Internal, "send fail")
+ mockStream.On("Send", mock.Anything).Return(expectedErr).Once()
+ err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
+ assert.ErrorIs(t, err, expectedErr)
+ })
+
+ t.Run("StreamError", func(t *testing.T) {
+ mockStream := &mockSubscribeClient{}
+ mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(mockStream, nil).Once()
+ mockStream.On("Send", mock.Anything).Return(nil).Once()
+ expectedErr := status.Error(codes.Internal, "stream fail")
+ mockStream.On("Recv").Return(nil, expectedErr).Once()
+ err := c.subscribeAndProcess(ctx, "id", "token", &backoff)
+ assert.ErrorIs(t, err, expectedErr)
+ })
+}
+
+func TestSubscribeEvents(t *testing.T) {
+ mockEventSvc := &mockEventServiceClient{}
+ c := &client{eventService: mockEventSvc}
+
+ t.Run("ReturnsOnError", func(t *testing.T) {
+ expectedErr := errors.New("fatal error")
+ mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, expectedErr).Once()
+ err := c.SubscribeEvents(context.Background(), "id", "token")
+ assert.ErrorIs(t, err, expectedErr)
+ })
+
+ t.Run("RetryLoop", func(t *testing.T) {
+ oldB := initialBackoff
+ initialBackoff = 5 * time.Millisecond
+ defer func() { initialBackoff = oldB }()
+
+ ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ defer cancel()
+
+ mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, status.Error(codes.Unavailable, "unavailable"))
+
+ err := c.SubscribeEvents(ctx, "id", "token")
+ assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled))
+ mockEventSvc.AssertExpectations(t)
+ })
+}
+
+func TestCheckServerHealth(t *testing.T) {
+ mockHealth := &mockHealthClient{}
+ old := healthNewHealthClient
+ healthNewHealthClient = func(cc grpc.ClientConnInterface) grpc_health_v1.HealthClient {
+ return mockHealth
+ }
+ defer func() { healthNewHealthClient = old }()
+
+ c := &client{}
+
+ t.Run("Success", func(t *testing.T) {
+ mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(&grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil).Once()
+ err := c.CheckServerHealth(context.Background())
+ assert.NoError(t, err)
+ mockHealth.AssertExpectations(t)
+ })
+
+ t.Run("Error", func(t *testing.T) {
+ mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("health fail")).Once()
+ err := c.CheckServerHealth(context.Background())
+ assert.ErrorContains(t, err, "health check failed: health fail")
+ mockHealth.AssertExpectations(t)
+ })
+
+ t.Run("NotServing", func(t *testing.T) {
+ mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(&grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING}, nil).Once()
+ err := c.CheckServerHealth(context.Background())
+ assert.ErrorContains(t, err, "server not serving: NOT_SERVING")
+ mockHealth.AssertExpectations(t)
+ })
+}
+
+func TestNew_Error(t *testing.T) {
+ old := grpcNewClient
+ grpcNewClient = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
+ return nil, errors.New("dial fail")
+ }
+ defer func() { grpcNewClient = old }()
+ mockConfig := &MockConfig{}
+
+ mockConfig.On("GRPCAddress").Return("localhost")
+ mockConfig.On("GRPCPort").Return("1234")
+ cli, err := New(mockConfig, &mockRegistry{})
+ if err == nil || err.Error() != "failed to connect to gRPC server at localhost:1234: dial fail" {
+ t.Errorf("expected dial fail error, got %v", err)
+ }
+ if cli != nil {
+ t.Errorf("expected nil client")
+ }
+}
+
+func TestNew(t *testing.T) {
+ mockConfig := &MockConfig{}
+ mockReg := &mockRegistry{}
+ mockConfig.On("GRPCAddress").Return("localhost")
+ mockConfig.On("GRPCPort").Return("1234")
+ cli, err := New(mockConfig, mockReg)
+ if err != nil {
+ t.Errorf("New() error = %v", err)
+ }
+ if cli == nil {
+ t.Fatal("New() returned nil client")
+ }
+ defer func(cli Client) {
+ _ = cli.Close()
+ }(cli)
+}
+
+type MockConfig struct {
+ mock.Mock
+}
+
+func (m *MockConfig) Domain() string { return m.Called().String(0) }
+func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
+func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
+func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
+func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
+func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
+func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
+func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
+func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
+func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
+func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
+func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
+func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
+func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
+func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
+func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
+func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
+func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
+func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
+func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
+func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
+
+type mockRegistry struct {
+ mock.Mock
+}
+
+func (m *mockRegistry) Get(key registry.Key) (registry.Session, error) {
+ args := m.Called(key)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(registry.Session), args.Error(1)
+}
+func (m *mockRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
+ args := m.Called(user, key)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(registry.Session), args.Error(1)
+}
+func (m *mockRegistry) Update(user string, oldKey, newKey registry.Key) error {
+ return m.Called(user, oldKey, newKey).Error(0)
+}
+func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session {
+ args := m.Called(user)
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).([]registry.Session)
+}
+func (m *mockRegistry) Register(key registry.Key, session registry.Session) bool {
+ return m.Called(key, session).Bool(0)
+}
+func (m *mockRegistry) Remove(key registry.Key) {
+ m.Called(key)
+}
+
+type mockSession struct {
+ mock.Mock
+}
+
+func (m *mockSession) Lifecycle() lifecycle.Lifecycle {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(lifecycle.Lifecycle)
+}
+func (m *mockSession) Interaction() interaction.Interaction {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(interaction.Interaction)
+}
+func (m *mockSession) Detail() *types.Detail {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(*types.Detail)
+}
+func (m *mockSession) Slug() slug.Slug {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(slug.Slug)
+}
+func (m *mockSession) Forwarder() forwarder.Forwarder {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(forwarder.Forwarder)
+}
+
+type mockInteraction struct {
+ mock.Mock
+}
+
+func (m *mockInteraction) Start() { m.Called() }
+func (m *mockInteraction) Stop() { m.Called() }
+func (m *mockInteraction) Redraw() { m.Called() }
+func (m *mockInteraction) SetWH(w, h int) { m.Called(w, h) }
+func (m *mockInteraction) SetChannel(channel ssh.Channel) { m.Called(channel) }
+func (m *mockInteraction) SetMode(mode types.InteractiveMode) { m.Called(mode) }
+func (m *mockInteraction) Mode() types.InteractiveMode {
+ return m.Called().Get(0).(types.InteractiveMode)
+}
+func (m *mockInteraction) Send(message string) error { return m.Called(message).Error(0) }
+
+type mockLifecycle struct {
+ mock.Mock
+}
+
+func (m *mockLifecycle) Close() error { return m.Called().Error(0) }
+func (m *mockLifecycle) Channel() ssh.Channel {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(ssh.Channel)
+}
+func (m *mockLifecycle) Connection() ssh.Conn {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(ssh.Conn)
+}
+func (m *mockLifecycle) User() string { return m.Called().String(0) }
+func (m *mockLifecycle) SetChannel(channel ssh.Channel) { m.Called(channel) }
+func (m *mockLifecycle) SetStatus(status types.SessionStatus) { m.Called(status) }
+func (m *mockLifecycle) IsActive() bool { return m.Called().Bool(0) }
+func (m *mockLifecycle) StartedAt() time.Time { return m.Called().Get(0).(time.Time) }
+func (m *mockLifecycle) PortRegistry() port.Port {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(port.Port)
+}
+
+type mockEventServiceClient struct {
+ mock.Mock
+}
+
+func (m *mockEventServiceClient) Subscribe(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) {
+ args := m.Called(ctx, opts)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(proto.EventService_SubscribeClient), args.Error(1)
+}
+
+type mockSubscribeClient struct {
+ mock.Mock
+ grpc.ClientStream
+}
+
+func (m *mockSubscribeClient) Send(n *proto.Node) error { return m.Called(n).Error(0) }
+func (m *mockSubscribeClient) Recv() (*proto.Events, error) {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(*proto.Events), args.Error(1)
+}
+func (m *mockSubscribeClient) Context() context.Context { return m.Called().Get(0).(context.Context) }
+
+type mockUserServiceClient struct {
+ mock.Mock
+}
+
+func (m *mockUserServiceClient) Check(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) {
+ args := m.Called(ctx, in, opts)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(*proto.CheckResponse), args.Error(1)
+}
+
+type mockHealthClient struct {
+ mock.Mock
+}
+
+func (m *mockHealthClient) Check(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) {
+ args := m.Called(ctx, in, opts)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(*grpc_health_v1.HealthCheckResponse), args.Error(1)
+}
+
+func (m *mockHealthClient) Watch(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (grpc_health_v1.Health_WatchClient, error) {
+ args := m.Called(ctx, in, opts)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(grpc_health_v1.Health_WatchClient), args.Error(1)
+}
+
+func (m *mockHealthClient) List(ctx context.Context, in *grpc_health_v1.HealthListRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthListResponse, error) {
+ args := m.Called(ctx, in, opts)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(*grpc_health_v1.HealthListResponse), args.Error(1)
+}
+
+func TestProtoToTunnelType(t *testing.T) {
+ c := &client{}
+ tests := []struct {
+ name string
+ input proto.TunnelType
+ want types.TunnelType
+ wantErr bool
+ }{
+ {"HTTP", proto.TunnelType_HTTP, types.TunnelTypeHTTP, false},
+ {"TCP", proto.TunnelType_TCP, types.TunnelTypeTCP, false},
+ {"Unknown", proto.TunnelType(999), types.TunnelTypeUNKNOWN, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got, err := c.protoToTunnelType(tt.input)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("protoToTunnelType() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if got != tt.want {
+ t.Errorf("protoToTunnelType() got = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIsConnectionError(t *testing.T) {
+ c := &client{}
+ tests := []struct {
+ name string
+ closing bool
+ err error
+ want bool
+ }{
+ {"NilError", false, nil, false},
+ {"Closing", true, io.EOF, false},
+ {"EOF", false, io.EOF, true},
+ {"Unavailable", false, status.Error(codes.Unavailable, "unavailable"), true},
+ {"Canceled", false, status.Error(codes.Canceled, "canceled"), true},
+ {"DeadlineExceeded", false, status.Error(codes.DeadlineExceeded, "deadline"), true},
+ {"Internal", false, status.Error(codes.Internal, "internal"), false},
+ {"WrappedEOF", false, errors.New("wrapped: " + io.EOF.Error()), false},
+ }
+
+ tests[7].err = fmt.Errorf("wrapped: %w", io.EOF)
+ tests[7].want = true
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ c.closing = tt.closing
+ if got := c.isConnectionError(tt.err); got != tt.want {
+ t.Errorf("isConnectionError() = %v, want %v for error %v", got, tt.want, tt.err)
+ }
+ })
+ }
+}
+
+func TestGrowBackoff(t *testing.T) {
+ c := &client{}
+ tests := []struct {
+ name string
+ initial time.Duration
+ want time.Duration
+ }{
+ {"NormalGrow", time.Second, 2 * time.Second},
+ {"MaxLimit", 20 * time.Second, 30 * time.Second},
+ {"AlreadyAtMax", 30 * time.Second, 30 * time.Second},
+ {"OverMax", 40 * time.Second, 30 * time.Second},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ backoff := tt.initial
+ c.growBackoff(&backoff)
+ if backoff != tt.want {
+ t.Errorf("growBackoff() = %v, want %v", backoff, tt.want)
+ }
+ })
+ }
+}
+
+func TestWait(t *testing.T) {
+ c := &client{}
+
+ t.Run("ZeroDuration", func(t *testing.T) {
+ err := c.wait(context.Background(), 0)
+ if err != nil {
+ t.Errorf("wait() zero duration error = %v", err)
+ }
+ })
+
+ t.Run("ContextCanceled", func(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ err := c.wait(ctx, time.Minute)
+ if !errors.Is(err, context.Canceled) {
+ t.Errorf("wait() context canceled error = %v", err)
+ }
+ })
+
+ t.Run("Timeout", func(t *testing.T) {
+ start := time.Now()
+ err := c.wait(context.Background(), 10*time.Millisecond)
+ if err != nil {
+ t.Errorf("wait() timeout error = %v", err)
+ }
+ if time.Since(start) < 10*time.Millisecond {
+ t.Errorf("wait() returned too early")
+ }
+ })
+}
diff --git a/internal/http/header/header_test.go b/internal/http/header/header_test.go
new file mode 100644
index 0000000..b3f9228
--- /dev/null
+++ b/internal/http/header/header_test.go
@@ -0,0 +1,227 @@
+package header
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestNewRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ data []byte
+ expectErr bool
+ errContains string
+ expectMethod string
+ expectPath string
+ expectVersion string
+ expectHeaders map[string]string
+ }{
+ {
+ name: "success",
+ data: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\nX-Custom: value\r\n\r\n"),
+ expectErr: false,
+ expectMethod: "GET",
+ expectPath: "/path",
+ expectVersion: "HTTP/1.1",
+ expectHeaders: map[string]string{
+ "Host": "example.com",
+ "X-Custom": "value",
+ },
+ },
+ {
+ name: "no CRLF in start line",
+ data: []byte("GET /path HTTP/1.1"),
+ expectErr: true,
+ errContains: "no CRLF found in start line",
+ },
+ {
+ name: "invalid start line - missing method",
+ data: []byte("INVALID\r\n\r\n"),
+ expectErr: true,
+ errContains: "invalid start line: missing method",
+ },
+ {
+ name: "invalid start line - missing version",
+ data: []byte("GET /path\r\n\r\n"),
+ expectErr: true,
+ errContains: "invalid start line: missing version",
+ },
+ {
+ name: "invalid start line - multiple spaces",
+ data: []byte("GET /path HTTP/1.1\r\n\r\n"),
+ expectErr: false,
+ expectMethod: "GET",
+ expectPath: "",
+ expectVersion: "/path HTTP/1.1",
+ expectHeaders: map[string]string{},
+ },
+ {
+ name: "start line with trailing space",
+ data: []byte("GET / HTTP/1.1 \r\n\r\n"),
+ expectErr: false,
+ expectMethod: "GET",
+ expectPath: "/",
+ expectVersion: "HTTP/1.1 ",
+ expectHeaders: map[string]string{},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req, err := NewRequest(tt.data)
+ if tt.expectErr {
+ assert.Error(t, err)
+ if tt.errContains != "" {
+ assert.Contains(t, err.Error(), tt.errContains)
+ }
+ assert.Nil(t, req)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, req)
+ assert.Equal(t, tt.expectMethod, req.Method())
+ assert.Equal(t, tt.expectPath, req.Path())
+ assert.Equal(t, tt.expectVersion, req.Version())
+ for k, v := range tt.expectHeaders {
+ assert.Equal(t, v, req.Value(k))
+ }
+ }
+ })
+ }
+}
+
+func TestRequestHeaderMethods(t *testing.T) {
+ data := []byte("GET / HTTP/1.1\r\nHost: original\r\n\r\n")
+ req, _ := NewRequest(data)
+
+ req.Set("Host", "updated")
+ req.Set("X-New", "new-value")
+ assert.Equal(t, "updated", req.Value("Host"))
+ assert.Equal(t, "new-value", req.Value("X-New"))
+
+ assert.Equal(t, "", req.Value("Non-Existent"))
+
+ req.Remove("X-New")
+ assert.Equal(t, "", req.Value("X-New"))
+
+ final := req.Finalize()
+ assert.Contains(t, string(final), "GET / HTTP/1.1\r\n")
+ assert.Contains(t, string(final), "Host: updated\r\n")
+ assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
+}
+
+func TestNewResponse(t *testing.T) {
+ tests := []struct {
+ name string
+ data []byte
+ expectErr bool
+ errContains string
+ expectHeaders map[string]string
+ }{
+ {
+ name: "success",
+ data: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"),
+ expectErr: false,
+ expectHeaders: map[string]string{
+ "Content-Length": "0",
+ },
+ },
+ {
+ name: "invalid response - no CRLF",
+ data: []byte("HTTP/1.1 200 OK"),
+ expectErr: true,
+ errContains: "no CRLF found in start line",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ resp, err := NewResponse(tt.data)
+ if tt.expectErr {
+ assert.Error(t, err)
+ if tt.errContains != "" {
+ assert.Contains(t, err.Error(), tt.errContains)
+ }
+ assert.Nil(t, resp)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, resp)
+ for k, v := range tt.expectHeaders {
+ assert.Equal(t, v, resp.Value(k))
+ }
+ }
+ })
+ }
+}
+
+func TestResponseHeaderMethods(t *testing.T) {
+ data := []byte("HTTP/1.1 200 OK\r\nServer: old\r\n\r\n")
+ resp, _ := NewResponse(data)
+
+ resp.Set("Server", "new")
+ resp.Set("X-Res", "val")
+ assert.Equal(t, "new", resp.Value("Server"))
+ assert.Equal(t, "val", resp.Value("X-Res"))
+
+ resp.Remove("X-Res")
+ assert.Equal(t, "", resp.Value("X-Res"))
+
+ final := resp.Finalize()
+ assert.Contains(t, string(final), "HTTP/1.1 200 OK\r\n")
+ assert.Contains(t, string(final), "Server: new\r\n")
+ assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
+}
+
+func TestSetRemainingHeaders(t *testing.T) {
+ tests := []struct {
+ name string
+ data []byte
+ initialHeaders map[string]string
+ expectHeaders map[string]string
+ }{
+ {
+ name: "various header formats",
+ data: []byte("K1: V1\r\nK2:V2\r\n K3 : V3 \r\nNoColon\r\n\r\n"),
+ expectHeaders: map[string]string{
+ "K1": "V1",
+ "K2": "V2",
+ "K3": "V3",
+ },
+ },
+ {
+ name: "no trailing CRLF",
+ data: []byte("K1: V1"),
+ expectHeaders: map[string]string{
+ "K1": "V1",
+ },
+ },
+ {
+ name: "empty lines",
+ data: []byte("\r\nK1: V1"),
+ expectHeaders: map[string]string{},
+ },
+ {
+ name: "headers with only colon",
+ data: []byte(": value\r\nkey:\r\n"),
+ expectHeaders: map[string]string{
+ "": "value",
+ "key": "",
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ req := &requestHeader{headers: make(map[string]string)}
+ if tt.initialHeaders != nil {
+ req.headers = tt.initialHeaders
+ }
+ setRemainingHeaders(tt.data, req)
+ assert.Equal(t, len(tt.expectHeaders), len(req.headers))
+ for k, v := range tt.expectHeaders {
+ assert.Equal(t, v, req.headers[k])
+ }
+ })
+ }
+}
diff --git a/internal/http/header/parser.go b/internal/http/header/parser.go
index 861c49e..9a58d59 100644
--- a/internal/http/header/parser.go
+++ b/internal/http/header/parser.go
@@ -1,7 +1,6 @@
package header
import (
- "bufio"
"bytes"
"fmt"
)
@@ -36,31 +35,6 @@ func setRemainingHeaders(remaining []byte, header interface {
}
}
-func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) {
- header := &requestHeader{
- headers: make(map[string]string, 16),
- }
-
- lineEnd := bytes.Index(headerData, []byte("\r\n"))
- if lineEnd == -1 {
- return nil, fmt.Errorf("invalid request: no CRLF found in start line")
- }
-
- startLine := headerData[:lineEnd]
- header.startLine = startLine
- var err error
- header.method, header.path, header.version, err = parseStartLine(startLine)
- if err != nil {
- return nil, err
- }
-
- remaining := headerData[lineEnd+2:]
-
- setRemainingHeaders(remaining, header)
-
- return header, nil
-}
-
func parseStartLine(startLine []byte) (method, path, version string, err error) {
firstSpace := bytes.IndexByte(startLine, ' ')
if firstSpace == -1 {
@@ -80,51 +54,6 @@ func parseStartLine(startLine []byte) (method, path, version string, err error)
return method, path, version, nil
}
-func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) {
- header := &requestHeader{
- headers: make(map[string]string, 16),
- }
-
- startLineBytes, err := br.ReadSlice('\n')
- if err != nil {
- return nil, err
- }
-
- startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
- header.startLine = make([]byte, len(startLineBytes))
- copy(header.startLine, startLineBytes)
-
- header.method, header.path, header.version, err = parseStartLine(header.startLine)
- if err != nil {
- return nil, err
- }
-
- for {
- lineBytes, err := br.ReadSlice('\n')
- if err != nil {
- return nil, err
- }
-
- lineBytes = bytes.TrimRight(lineBytes, "\r\n")
-
- if len(lineBytes) == 0 {
- break
- }
-
- colonIdx := bytes.IndexByte(lineBytes, ':')
- if colonIdx == -1 {
- continue
- }
-
- key := bytes.TrimSpace(lineBytes[:colonIdx])
- value := bytes.TrimSpace(lineBytes[colonIdx+1:])
-
- header.headers[string(key)] = string(value)
- }
-
- return header, nil
-}
-
func finalize(startLine []byte, headers map[string]string) []byte {
size := len(startLine) + 2
for key, val := range headers {
diff --git a/internal/http/header/request.go b/internal/http/header/request.go
index 1fbe57a..e35f169 100644
--- a/internal/http/header/request.go
+++ b/internal/http/header/request.go
@@ -1,19 +1,33 @@
package header
import (
- "bufio"
+ "bytes"
"fmt"
)
-func NewRequest(r interface{}) (RequestHeader, error) {
- switch v := r.(type) {
- case []byte:
- return parseHeadersFromBytes(v)
- case *bufio.Reader:
- return parseHeadersFromReader(v)
- default:
- return nil, fmt.Errorf("unsupported type: %T", r)
+func NewRequest(headerData []byte) (RequestHeader, error) {
+ header := &requestHeader{
+ headers: make(map[string]string, 16),
}
+
+ lineEnd := bytes.Index(headerData, []byte("\r\n"))
+ if lineEnd == -1 {
+ return nil, fmt.Errorf("invalid request: no CRLF found in start line")
+ }
+
+ startLine := headerData[:lineEnd]
+ header.startLine = startLine
+ var err error
+ header.method, header.path, header.version, err = parseStartLine(startLine)
+ if err != nil {
+ return nil, err
+ }
+
+ remaining := headerData[lineEnd+2:]
+
+ setRemainingHeaders(remaining, header)
+
+ return header, nil
}
func (req *requestHeader) Value(key string) string {
diff --git a/internal/http/stream/stream.go b/internal/http/stream/stream.go
index 97d2752..d339474 100644
--- a/internal/http/stream/stream.go
+++ b/internal/http/stream/stream.go
@@ -30,7 +30,6 @@ type http struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
- headerBuf []byte
buf []byte
respHeader header.ResponseHeader
reqHeader header.RequestHeader
@@ -72,7 +71,10 @@ func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
}
func (hs *http) Close() error {
- return hs.writer.(io.Closer).Close()
+ if closer, ok := hs.writer.(io.Closer); ok {
+ return closer.Close()
+ }
+ return nil
}
func (hs *http) CloseWrite() error {
diff --git a/internal/http/stream/stream_test.go b/internal/http/stream/stream_test.go
new file mode 100644
index 0000000..60c9d65
--- /dev/null
+++ b/internal/http/stream/stream_test.go
@@ -0,0 +1,765 @@
+package stream
+
+import (
+ "bytes"
+ "io"
+ "strings"
+ "testing"
+
+ "tunnel_pls/internal/http/header"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+type MockAddr struct {
+ mock.Mock
+}
+
+func (m *MockAddr) String() string {
+ args := m.Called()
+ return args.String(0)
+}
+
+func (m *MockAddr) Network() string {
+ args := m.Called()
+ return args.String(0)
+}
+
+type MockRequestMiddleware struct {
+ mock.Mock
+}
+
+func (m *MockRequestMiddleware) HandleRequest(h header.RequestHeader) error {
+ args := m.Called(h)
+ return args.Error(0)
+}
+
+type MockResponseMiddleware struct {
+ mock.Mock
+}
+
+func (m *MockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error {
+ args := m.Called(h, body)
+ return args.Error(0)
+}
+
+type MockReadWriter struct {
+ mock.Mock
+ bytes.Buffer
+}
+
+func (m *MockReadWriter) Read(p []byte) (int, error) {
+ args := m.Called(p)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *MockReadWriter) Write(p []byte) (int, error) {
+ args := m.Called(p)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *MockReadWriter) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+func (m *MockReadWriter) CloseWrite() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+type MockReadWriterOnlyCloser struct {
+ mock.Mock
+ bytes.Buffer
+}
+
+func (m *MockReadWriterOnlyCloser) Read(p []byte) (int, error) {
+ args := m.Called(p)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *MockReadWriterOnlyCloser) Write(p []byte) (int, error) {
+ args := m.Called(p)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *MockReadWriterOnlyCloser) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+type MockWriterOnly struct {
+ mock.Mock
+}
+
+func (m *MockWriterOnly) Write(p []byte) (int, error) {
+ args := m.Called(p)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *MockWriterOnly) Read(p []byte) (int, error) {
+ args := m.Called(p)
+ return args.Int(0), args.Error(1)
+}
+
+type MockReader struct {
+ mock.Mock
+}
+
+func (m *MockReader) Read(p []byte) (int, error) {
+ args := m.Called(p)
+ return args.Int(0), args.Error(1)
+}
+
+type MockWriter struct {
+ mock.Mock
+}
+
+func (m *MockWriter) Write(p []byte) (int, error) {
+ ret := m.Called(p)
+
+ var n int
+ var err error
+
+ switch v := ret.Get(0).(type) {
+ case func([]byte) int:
+ n = v(p)
+ case int:
+ n = v
+ default:
+ n = len(p)
+ }
+
+ switch v := ret.Get(1).(type) {
+ case func([]byte) error:
+ err = v(p)
+ case error:
+ err = v
+ default:
+ err = nil
+ }
+
+ return n, err
+}
+
+func (m *MockWriter) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+func TestHTTPMethods(t *testing.T) {
+ addr := new(MockAddr)
+ addr.On("String").Return("1.2.3.4:1234")
+
+ rw := new(MockReadWriter)
+ hs := New(rw, rw, addr)
+
+ assert.Equal(t, addr, hs.RemoteAddr())
+
+ reqMW := new(MockRequestMiddleware)
+ hs.UseRequestMiddleware(reqMW)
+ assert.Equal(t, 1, len(hs.RequestMiddlewares()))
+ assert.Equal(t, reqMW, hs.RequestMiddlewares()[0])
+
+ respMW := new(MockResponseMiddleware)
+ hs.UseResponseMiddleware(respMW)
+ assert.Equal(t, 1, len(hs.ResponseMiddlewares()))
+ assert.Equal(t, respMW, hs.ResponseMiddlewares()[0])
+
+ reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
+ hs.SetRequestHeader(reqH)
+}
+
+func TestApplyMiddlewares(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(HTTP, *MockRequestMiddleware, *MockResponseMiddleware)
+ apply func(HTTP, header.RequestHeader, header.ResponseHeader) error
+ verify func(*testing.T, header.RequestHeader, header.ResponseHeader)
+ expectErr bool
+ }{
+ {
+ name: "apply request middleware success",
+ setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
+ reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
+ h := args.Get(0).(header.RequestHeader)
+ h.Set("X-Middleware", "true")
+ }).Return(nil)
+ hs.UseRequestMiddleware(reqMW)
+ },
+ apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
+ return hs.ApplyRequestMiddlewares(reqH)
+ },
+ verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
+ assert.Equal(t, "true", reqH.Value("X-Middleware"))
+ },
+ },
+ {
+ name: "apply response middleware success",
+ setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
+ respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
+ h := args.Get(0).(header.ResponseHeader)
+ h.Set("X-Resp-Middleware", "true")
+ }).Return(nil)
+ hs.UseResponseMiddleware(respMW)
+ },
+ apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
+ return hs.ApplyResponseMiddlewares(respH, []byte("body"))
+ },
+ verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
+ assert.Equal(t, "true", respH.Value("X-Resp-Middleware"))
+ },
+ },
+ {
+ name: "apply request middleware error",
+ setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
+ reqMW.On("HandleRequest", mock.Anything).Return(assert.AnError)
+ hs.UseRequestMiddleware(reqMW)
+ },
+ apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
+ return hs.ApplyRequestMiddlewares(reqH)
+ },
+ expectErr: true,
+ },
+ {
+ name: "apply response middleware error",
+ setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
+ respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(assert.AnError)
+ hs.UseResponseMiddleware(respMW)
+ },
+ apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
+ return hs.ApplyResponseMiddlewares(respH, []byte("body"))
+ },
+ expectErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
+ respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n"))
+
+ addr := new(MockAddr)
+ addr.On("String").Return("1.2.3.4:1234")
+
+ rw := new(MockReadWriter)
+ hs := New(rw, rw, addr)
+
+ reqMW := new(MockRequestMiddleware)
+ respMW := new(MockResponseMiddleware)
+ tt.setup(hs, reqMW, respMW)
+
+ err := tt.apply(hs, reqH, respH)
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ if tt.verify != nil {
+ tt.verify(t, reqH, respH)
+ }
+ }
+
+ reqMW.AssertExpectations(t)
+ respMW.AssertExpectations(t)
+ })
+ }
+}
+
+func TestCloseMethods(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func() (io.Writer, io.Reader)
+ op func(HTTP) error
+ verify func(*testing.T, io.Writer)
+ }{
+ {
+ name: "Close success",
+ setup: func() (io.Writer, io.Reader) {
+ rw := new(MockReadWriter)
+ rw.On("Close").Return(nil)
+ return rw, rw
+ },
+ op: func(hs HTTP) error { return hs.Close() },
+ verify: func(t *testing.T, w io.Writer) {
+ w.(*MockReadWriter).AssertCalled(t, "Close")
+ },
+ },
+ {
+ name: "CloseWrite with CloseWrite implementation",
+ setup: func() (io.Writer, io.Reader) {
+ rw := new(MockReadWriter)
+ rw.On("CloseWrite").Return(nil)
+ return rw, rw
+ },
+ op: func(hs HTTP) error { return hs.CloseWrite() },
+ verify: func(t *testing.T, w io.Writer) {
+ w.(*MockReadWriter).AssertCalled(t, "CloseWrite")
+ },
+ },
+ {
+ name: "CloseWrite fallback to Close",
+ setup: func() (io.Writer, io.Reader) {
+ rw := new(MockReadWriterOnlyCloser)
+ rw.On("Close").Return(nil)
+ return rw, rw
+ },
+ op: func(hs HTTP) error { return hs.CloseWrite() },
+ verify: func(t *testing.T, w io.Writer) {
+ w.(*MockReadWriterOnlyCloser).AssertCalled(t, "Close")
+ },
+ },
+ {
+ name: "Close with No Closer",
+ setup: func() (io.Writer, io.Reader) {
+ w := new(MockWriterOnly)
+ r := new(MockReader)
+ return w, r
+ },
+ op: func(hs HTTP) error { return hs.Close() },
+ },
+ {
+ name: "CloseWrite with No CloseWrite and No Closer",
+ setup: func() (io.Writer, io.Reader) {
+ w := new(MockWriterOnly)
+ r := new(MockReader)
+ return w, r
+ },
+ op: func(hs HTTP) error { return hs.CloseWrite() },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ addr := new(MockAddr)
+ addr.On("String").Return("1.2.3.4:1234")
+
+ w, r := tt.setup()
+ hs := New(w, r, addr)
+
+ assert.NotPanics(t, func() {
+ err := tt.op(hs)
+ assert.NoError(t, err)
+ })
+
+ if tt.verify != nil {
+ tt.verify(t, w)
+ }
+ })
+ }
+}
+
+func TestSplitHeaderAndBody(t *testing.T) {
+ tests := []struct {
+ name string
+ data []byte
+ delimiterIdx int
+ expectHeader []byte
+ expectBody []byte
+ }{
+ {
+ name: "standard",
+ data: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nBodyContent"),
+ delimiterIdx: 31,
+ expectHeader: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"),
+ expectBody: []byte("BodyContent"),
+ },
+ {
+ name: "empty body",
+ data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
+ delimiterIdx: 15,
+ expectHeader: []byte("HTTP/1.1 200 OK\r\n\r\n"),
+ expectBody: []byte(""),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ h, b := splitHeaderAndBody(tt.data, tt.delimiterIdx)
+ assert.Equal(t, tt.expectHeader, h)
+ assert.Equal(t, tt.expectBody, b)
+ })
+ }
+}
+
+func TestIsHTTPHeader(t *testing.T) {
+ tests := []struct {
+ name string
+ buf []byte
+ expect bool
+ }{
+ {
+ name: "valid request",
+ buf: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n"),
+ expect: true,
+ },
+ {
+ name: "valid response",
+ buf: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"),
+ expect: true,
+ },
+ {
+ name: "invalid start line",
+ buf: []byte("NOT_HTTP /path\r\nHost: example.com\r\n\r\n"),
+ expect: false,
+ },
+ {
+ name: "invalid header line (no colon)",
+ buf: []byte("GET / HTTP/1.1\r\nInvalidHeaderLine\r\n\r\n"),
+ expect: false,
+ },
+ {
+ name: "invalid header line (colon at 0)",
+ buf: []byte("GET / HTTP/1.1\r\n: value\r\n\r\n"),
+ expect: false,
+ },
+ {
+ name: "empty header section",
+ buf: []byte("GET / HTTP/1.1\r\n\r\n"),
+ expect: true,
+ },
+ {
+ name: "multiple headers",
+ buf: []byte("GET / HTTP/1.1\r\nH1: V1\r\nH2: V2\r\n\r\n"),
+ expect: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := isHTTPHeader(tt.buf)
+ assert.Equal(t, tt.expect, result)
+ })
+ }
+}
+
+func TestRead(t *testing.T) {
+ tests := []struct {
+ name string
+ input []byte
+ readLen int
+ expectContent string
+ expectRead int
+ expectErr bool
+ middlewareErr error
+ isHTTP bool
+ }{
+ {
+ name: "valid http request",
+ input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\nBody"),
+ readLen: 100,
+ expectContent: "Body",
+ expectRead: 54,
+ isHTTP: true,
+ },
+ {
+ name: "non-http data",
+ input: []byte("Some random data\r\n\r\nMore data"),
+ readLen: 100,
+ expectContent: "Some random data\r\n\r\nMore data",
+ expectRead: 29,
+ isHTTP: false,
+ },
+ {
+ name: "no delimiter",
+ input: []byte("Partial data without delimiter"),
+ readLen: 100,
+ expectContent: "Partial data without delimiter",
+ expectRead: 30,
+ isHTTP: false,
+ },
+ {
+ name: "middleware error",
+ input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\n"),
+ readLen: 100,
+ middlewareErr: assert.AnError,
+ expectErr: true,
+ isHTTP: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ addr := new(MockAddr)
+ addr.On("String").Return("1.2.3.4:1234")
+
+ reader := new(MockReader)
+ writer := new(MockWriterOnly)
+
+ if tt.expectErr || tt.name == "valid http request" {
+ reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
+ p := args.Get(0).([]byte)
+ copy(p, tt.input)
+ }).Return(len(tt.input), io.EOF).Once()
+ } else {
+ reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
+ p := args.Get(0).([]byte)
+ copy(p, tt.input)
+ }).Return(len(tt.input), nil).Once()
+ }
+
+ hs := New(writer, reader, addr)
+
+ reqMW := new(MockRequestMiddleware)
+ if tt.isHTTP {
+ if tt.middlewareErr != nil {
+ reqMW.On("HandleRequest", mock.Anything).Return(tt.middlewareErr)
+ } else {
+ reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
+ h := args.Get(0).(header.RequestHeader)
+ h.Set("X-Middleware", "true")
+ }).Return(nil)
+ }
+ }
+ hs.UseRequestMiddleware(reqMW)
+
+ p := make([]byte, tt.readLen)
+ n, err := hs.Read(p)
+
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expectRead, n)
+ if tt.name == "valid http request" {
+ content := string(p[:n])
+ assert.Contains(t, content, "GET / HTTP/1.1\r\n")
+ assert.Contains(t, content, "Host: test\r\n")
+ assert.Contains(t, content, "X-Middleware: true\r\n")
+ assert.True(t, bytes.HasSuffix(p[:n], []byte("\r\n\r\nBody")))
+ } else {
+ assert.Equal(t, tt.expectContent, string(p[:n]))
+ }
+ }
+
+ if tt.isHTTP {
+ reqMW.AssertExpectations(t)
+ }
+ reader.AssertExpectations(t)
+ })
+ }
+}
+
+func TestWrite(t *testing.T) {
+ tests := []struct {
+ name string
+ writes [][]byte
+ expectWritten string
+ expectErr bool
+ middlewareErr error
+ isHTTP bool
+ }{
+ {
+ name: "valid http response in one write",
+ writes: [][]byte{
+ []byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"),
+ },
+ expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
+ isHTTP: true,
+ },
+ {
+ name: "valid http response in multiple writes",
+ writes: [][]byte{
+ []byte("HTTP/1.1 200 OK\r\n"),
+ []byte("Content-Length: 4\r\n\r\n"),
+ []byte("Body"),
+ },
+ expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
+ isHTTP: true,
+ },
+ {
+ name: "non-http data",
+ writes: [][]byte{
+ []byte("Random data with delimiter\r\n\r\nFlush"),
+ },
+ expectWritten: "Random data with delimiter\r\n\r\nFlush",
+ isHTTP: false,
+ },
+ {
+ name: "bypass buffering",
+ writes: [][]byte{
+ []byte("HTTP/1.1 200 OK\r\n\r\n"),
+ []byte("HTTP/1.1 200 OK\r\n\r\n"),
+ },
+ expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" +
+ "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n",
+ isHTTP: true,
+ },
+ {
+ name: "middleware error",
+ writes: [][]byte{
+ []byte("HTTP/1.1 200 OK\r\n\r\n"),
+ },
+ middlewareErr: assert.AnError,
+ expectErr: true,
+ isHTTP: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ addr := new(MockAddr)
+ addr.On("String").Return("1.2.3.4:1234")
+
+ var writtenData bytes.Buffer
+ writer := new(MockWriter)
+
+ writer.On("Write", mock.Anything).Run(func(args mock.Arguments) {
+ p := args.Get(0).([]byte)
+ writtenData.Write(p)
+ }).Return(func(p []byte) int {
+ return len(p)
+ }, nil)
+
+ reader := new(MockReader)
+ hs := New(writer, reader, addr)
+
+ respMW := new(MockResponseMiddleware)
+ if tt.isHTTP {
+ if tt.middlewareErr != nil {
+ respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(tt.middlewareErr)
+ } else {
+ respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
+ h := args.Get(0).(header.ResponseHeader)
+ h.Set("X-Resp-Middleware", "true")
+ }).Return(nil)
+ }
+ }
+ hs.UseResponseMiddleware(respMW)
+
+ var totalN int
+ var err error
+ for _, w := range tt.writes {
+ var n int
+ n, err = hs.Write(w)
+ if err != nil {
+ break
+ }
+ totalN += n
+ }
+
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ written := writtenData.String()
+ if strings.HasPrefix(tt.expectWritten, "HTTP/") {
+ assert.Contains(t, written, "HTTP/1.1 200 OK\r\n")
+ assert.Contains(t, written, "X-Resp-Middleware: true\r\n")
+ if strings.Contains(tt.expectWritten, "Content-Length: 4") {
+ assert.Contains(t, written, "Content-Length: 4\r\n")
+ }
+ assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n"))
+ } else {
+ assert.Equal(t, tt.expectWritten, written)
+ }
+ }
+
+ if tt.isHTTP {
+ respMW.AssertExpectations(t)
+ }
+ if tt.middlewareErr == nil {
+ writer.AssertExpectations(t)
+ }
+ })
+ }
+}
+
+func TestWriteErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func() (io.Writer, io.Reader)
+ data []byte
+ }{
+ {
+ name: "write error in writeHeaderAndBody",
+ setup: func() (io.Writer, io.Reader) {
+ writer := new(MockWriter)
+ writer.On("Write", mock.Anything).Return(0, assert.AnError)
+ reader := new(MockReader)
+ return writer, reader
+ },
+ data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
+ },
+ {
+ name: "write error in writeHeaderAndBody second write",
+ setup: func() (io.Writer, io.Reader) {
+ writer := new(MockWriter)
+ writer.On("Write", mock.Anything).Return(len([]byte("HTTP/1.1 200 OK\r\n\r\n")), nil).Once()
+ writer.On("Write", mock.Anything).Return(0, assert.AnError).Once()
+ reader := new(MockReader)
+ return writer, reader
+ },
+ data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"),
+ },
+ {
+ name: "write error in writeRawBuffer",
+ setup: func() (io.Writer, io.Reader) {
+ writer := new(MockWriter)
+ writer.On("Write", mock.Anything).Return(0, assert.AnError)
+ reader := new(MockReader)
+ return writer, reader
+ },
+ data: []byte("Not HTTP\r\n\r\nFlush"),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ addr := new(MockAddr)
+ addr.On("String").Return("1.2.3.4:1234")
+
+ w, r := tt.setup()
+ hs := New(w, r, addr)
+
+ _, err := hs.Write(tt.data)
+ assert.Error(t, err)
+
+ w.(*MockWriter).AssertExpectations(t)
+ })
+ }
+}
+
+func TestReadEOF(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func() io.Reader
+ expectN int
+ expectErr error
+ expectContent string
+ }{
+ {
+ name: "read eof",
+ setup: func() io.Reader {
+ reader := new(MockReader)
+ reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
+ p := args.Get(0).([]byte)
+ copy(p, "data")
+ }).Return(4, io.EOF)
+ return reader
+ },
+ expectN: 4,
+ expectErr: io.EOF,
+ expectContent: "data",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ addr := new(MockAddr)
+ addr.On("String").Return("1.2.3.4:1234")
+
+ reader := tt.setup()
+ hs := New(nil, reader, addr)
+
+ p := make([]byte, 100)
+ n, err := hs.Read(p)
+
+ assert.Equal(t, tt.expectN, n)
+ assert.Equal(t, tt.expectErr, err)
+ assert.Equal(t, tt.expectContent, string(p[:n]))
+
+ reader.(*MockReader).AssertExpectations(t)
+ })
+ }
+}
diff --git a/internal/key/key.go b/internal/key/key.go
index 659abe3..b1c387a 100644
--- a/internal/key/key.go
+++ b/internal/key/key.go
@@ -5,6 +5,8 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
+ "errors"
+ "io"
"log"
"os"
"path/filepath"
@@ -12,7 +14,20 @@ import (
"golang.org/x/crypto/ssh"
)
+var (
+ rsaGenerateKey = rsa.GenerateKey
+ pemEncode = pem.Encode
+ sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
+ return ssh.NewPublicKey(key)
+ }
+ pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
+ return w.Write(data)
+ }
+ osOpenFile = os.OpenFile
+)
+
func GenerateSSHKeyIfNotExist(keyPath string) error {
+ var errGroup = make([]error, 0)
if _, err := os.Stat(keyPath); err == nil {
log.Printf("SSH key already exists at %s", keyPath)
return nil
@@ -20,7 +35,7 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
log.Printf("SSH key not found at %s, generating new key pair...", keyPath)
- privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
+ privateKey, err := rsaGenerateKey(rand.Reader, 4096)
if err != nil {
return err
}
@@ -35,33 +50,37 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
return err
}
- privateKeyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
+ privateKeyFile, err := osOpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
- defer privateKeyFile.Close()
+ defer func(privateKeyFile *os.File) {
+ errGroup = append(errGroup, privateKeyFile.Close())
+ }(privateKeyFile)
- if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil {
+ if err := pemEncode(privateKeyFile, privateKeyPEM); err != nil {
return err
}
- publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
+ publicKey, err := sshNewPublicKey(&privateKey.PublicKey)
if err != nil {
return err
}
pubKeyPath := keyPath + ".pub"
- pubKeyFile, err := os.OpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
+ pubKeyFile, err := osOpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return err
}
- defer pubKeyFile.Close()
+ defer func(pubKeyFile *os.File) {
+ errGroup = append(errGroup, pubKeyFile.Close())
+ }(pubKeyFile)
- _, err = pubKeyFile.Write(ssh.MarshalAuthorizedKey(publicKey))
+ _, err = pubKeyWrite(pubKeyFile, ssh.MarshalAuthorizedKey(publicKey))
if err != nil {
return err
}
log.Printf("SSH key pair generated successfully at %s and %s", keyPath, pubKeyPath)
- return nil
+ return errors.Join(errGroup...)
}
diff --git a/internal/key/key_test.go b/internal/key/key_test.go
new file mode 100644
index 0000000..d28c33b
--- /dev/null
+++ b/internal/key/key_test.go
@@ -0,0 +1,235 @@
+package key
+
+import (
+ "crypto/rsa"
+ "encoding/pem"
+ "errors"
+ "io"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "golang.org/x/crypto/ssh"
+)
+
+func TestGenerateSSHKeyIfNotExist(t *testing.T) {
+ tempDir := t.TempDir()
+
+ tests := []struct {
+ name string
+ setup func(t *testing.T, tempDir string) string
+ mockSetup func() func()
+ wantErr bool
+ errStr string
+ verify func(t *testing.T, keyPath string)
+ }{
+ {
+ name: "GenerateNewKey",
+ setup: func(t *testing.T, tempDir string) string {
+ return filepath.Join(tempDir, "id_rsa")
+ },
+ verify: func(t *testing.T, keyPath string) {
+ pubKeyPath := keyPath + ".pub"
+ if _, err := os.Stat(keyPath); os.IsNotExist(err) {
+ t.Errorf("Private key file not created")
+ }
+ if _, err := os.Stat(pubKeyPath); os.IsNotExist(err) {
+ t.Errorf("Public key file not created")
+ }
+ privateKeyBytes, err := os.ReadFile(keyPath)
+ if err != nil {
+ t.Fatalf("Failed to read private key: %v", err)
+ }
+ if _, err = ssh.ParseRawPrivateKey(privateKeyBytes); err != nil {
+ t.Errorf("Failed to parse private key: %v", err)
+ }
+ publicKeyBytes, err := os.ReadFile(pubKeyPath)
+ if err != nil {
+ t.Fatalf("Failed to read public key: %v", err)
+ }
+ if _, _, _, _, err = ssh.ParseAuthorizedKey(publicKeyBytes); err != nil {
+ t.Errorf("Failed to parse public key: %v", err)
+ }
+ },
+ },
+ {
+ name: "DoNotOverwriteExistingKey",
+ setup: func(t *testing.T, tempDir string) string {
+ keyPath := filepath.Join(tempDir, "existing_id_rsa")
+ dummyPrivate := "dummy private"
+ dummyPublic := "dummy public"
+ if err := os.WriteFile(keyPath, []byte(dummyPrivate), 0600); err != nil {
+ t.Fatalf("Failed to create dummy private key: %v", err)
+ }
+ if err := os.WriteFile(keyPath+".pub", []byte(dummyPublic), 0644); err != nil {
+ t.Fatalf("Failed to create dummy public key: %v", err)
+ }
+ return keyPath
+ },
+ verify: func(t *testing.T, keyPath string) {
+ gotPrivate, _ := os.ReadFile(keyPath)
+ if string(gotPrivate) != "dummy private" {
+ t.Errorf("Private key was overwritten")
+ }
+ gotPublic, _ := os.ReadFile(keyPath + ".pub")
+ if string(gotPublic) != "dummy public" {
+ t.Errorf("Public key was overwritten")
+ }
+ },
+ },
+ {
+ name: "CreateNestedDirectories",
+ setup: func(t *testing.T, tempDir string) string {
+ return filepath.Join(tempDir, "nested", "dir", "id_rsa")
+ },
+ verify: func(t *testing.T, keyPath string) {
+ if _, err := os.Stat(keyPath); os.IsNotExist(err) {
+ t.Errorf("Private key file not created in nested directory")
+ }
+ },
+ },
+ {
+ name: "FailureMkdirAll",
+ setup: func(t *testing.T, tempDir string) string {
+ dirPath := filepath.Join(tempDir, "file_as_dir")
+ if err := os.WriteFile(dirPath, []byte("not a dir"), 0644); err != nil {
+ t.Fatalf("Failed to create file: %v", err)
+ }
+ return filepath.Join(dirPath, "id_rsa")
+ },
+ wantErr: true,
+ },
+ {
+ name: "PrivateExistsPublicMissing",
+ setup: func(t *testing.T, tempDir string) string {
+ keyPath := filepath.Join(tempDir, "partial_id_rsa")
+ if err := os.WriteFile(keyPath, []byte("private"), 0600); err != nil {
+ t.Fatalf("Failed to create private key: %v", err)
+ }
+ return keyPath
+ },
+ verify: func(t *testing.T, keyPath string) {
+ if _, err := os.Stat(keyPath + ".pub"); !os.IsNotExist(err) {
+ t.Errorf("Public key should NOT have been created if private key existed")
+ }
+ },
+ },
+ {
+ name: "FailureRSAGenerateKey",
+ setup: func(t *testing.T, tempDir string) string {
+ return filepath.Join(tempDir, "fail_rsa")
+ },
+ mockSetup: func() func() {
+ old := rsaGenerateKey
+ rsaGenerateKey = func(random io.Reader, bits int) (*rsa.PrivateKey, error) {
+ return nil, errors.New("rsa error")
+ }
+ return func() { rsaGenerateKey = old }
+ },
+ wantErr: true,
+ errStr: "rsa error",
+ },
+ {
+ name: "FailureOpenFilePrivate",
+ setup: func(t *testing.T, tempDir string) string {
+ return filepath.Join(tempDir, "fail_open_private")
+ },
+ mockSetup: func() func() {
+ old := osOpenFile
+ osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
+ return nil, errors.New("open error")
+ }
+ return func() { osOpenFile = old }
+ },
+ wantErr: true,
+ errStr: "open error",
+ },
+ {
+ name: "FailurePemEncode",
+ setup: func(t *testing.T, tempDir string) string {
+ return filepath.Join(tempDir, "fail_pem")
+ },
+ mockSetup: func() func() {
+ old := pemEncode
+ pemEncode = func(out io.Writer, b *pem.Block) error {
+ return errors.New("pem error")
+ }
+ return func() { pemEncode = old }
+ },
+ wantErr: true,
+ errStr: "pem error",
+ },
+ {
+ name: "FailureSSHNewPublicKey",
+ setup: func(t *testing.T, tempDir string) string {
+ return filepath.Join(tempDir, "fail_ssh")
+ },
+ mockSetup: func() func() {
+ old := sshNewPublicKey
+ sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
+ return nil, errors.New("ssh error")
+ }
+ return func() { sshNewPublicKey = old }
+ },
+ wantErr: true,
+ errStr: "ssh error",
+ },
+ {
+ name: "FailureOpenFilePublic",
+ setup: func(t *testing.T, tempDir string) string {
+ return filepath.Join(tempDir, "fail_open_public")
+ },
+ mockSetup: func() func() {
+ old := osOpenFile
+ osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
+ if filepath.Ext(name) == ".pub" {
+ return nil, errors.New("open pub error")
+ }
+ return os.OpenFile(name, flag, perm)
+ }
+ return func() { osOpenFile = old }
+ },
+ wantErr: true,
+ errStr: "open pub error",
+ },
+ {
+ name: "FailurePubKeyWrite",
+ setup: func(t *testing.T, tempDir string) string {
+ return filepath.Join(tempDir, "fail_write")
+ },
+ mockSetup: func() func() {
+ old := pubKeyWrite
+ pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
+ return 0, errors.New("write error")
+ }
+ return func() { pubKeyWrite = old }
+ },
+ wantErr: true,
+ errStr: "write error",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ keyPath := tt.setup(t, tempDir)
+ if tt.mockSetup != nil {
+ cleanup := tt.mockSetup()
+ defer cleanup()
+ }
+
+ err := GenerateSSHKeyIfNotExist(keyPath)
+
+ if (err != nil) != tt.wantErr {
+ t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if tt.wantErr && tt.errStr != "" && err != nil && err.Error() != tt.errStr {
+ t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErrStr %v", err, tt.errStr)
+ }
+
+ if tt.verify != nil {
+ tt.verify(t, keyPath)
+ }
+ })
+ }
+}
diff --git a/internal/middleware/forwardedfor_test.go b/internal/middleware/forwardedfor_test.go
new file mode 100644
index 0000000..49f9980
--- /dev/null
+++ b/internal/middleware/forwardedfor_test.go
@@ -0,0 +1,126 @@
+package middleware
+
+import (
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+type mockRequestHeader struct {
+ mock.Mock
+}
+
+func (m *mockRequestHeader) Value(key string) string {
+ return m.Called(key).String(0)
+}
+
+func (m *mockRequestHeader) Set(key string, value string) {
+ m.Called(key, value)
+}
+
+func (m *mockRequestHeader) Remove(key string) {
+ m.Called(key)
+}
+
+func (m *mockRequestHeader) Finalize() []byte {
+ return m.Called().Get(0).([]byte)
+}
+
+func (m *mockRequestHeader) Method() string {
+ return m.Called().String(0)
+}
+
+func (m *mockRequestHeader) Path() string {
+ return m.Called().String(0)
+}
+
+func (m *mockRequestHeader) Version() string {
+ return m.Called().String(0)
+}
+
+func TestForwardedFor_HandleRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ addr net.Addr
+ expectedHost string
+ expectError bool
+ }{
+ {
+ name: "valid IPv4 address",
+ addr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 8080},
+ expectedHost: "192.168.1.100",
+ expectError: false,
+ },
+ {
+ name: "valid IPv6 address",
+ addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 8080},
+ expectedHost: "2001:db8::ff00:42:8329",
+ expectError: false,
+ },
+ {
+ name: "invalid address format",
+ addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
+ expectedHost: "",
+ expectError: true,
+ },
+ {
+ name: "valid IPv4 address with port",
+ addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
+ expectedHost: "127.0.0.1",
+ expectError: false,
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ ff := NewForwardedFor(tc.addr)
+ reqHeader := new(mockRequestHeader)
+
+ if !tc.expectError {
+ reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return()
+ }
+
+ err := ff.HandleRequest(reqHeader)
+
+ if tc.expectError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ reqHeader.AssertExpectations(t)
+ }
+ })
+ }
+}
+
+func TestNewForwardedFor(t *testing.T) {
+ tests := []struct {
+ name string
+ addr net.Addr
+ expectAddr net.Addr
+ }{
+ {
+ name: "IPv4 address",
+ addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
+ expectAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
+ },
+ {
+ name: "IPv6 address",
+ addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
+ expectAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
+ },
+ {
+ name: "Unix address",
+ addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
+ expectAddr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
+ },
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ ff := NewForwardedFor(tc.addr)
+ assert.Equal(t, tc.expectAddr.String(), ff.addr.String())
+ })
+ }
+}
diff --git a/internal/middleware/tunnelfingerprint_test.go b/internal/middleware/tunnelfingerprint_test.go
new file mode 100644
index 0000000..0054d1e
--- /dev/null
+++ b/internal/middleware/tunnelfingerprint_test.go
@@ -0,0 +1,70 @@
+package middleware
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+type mockResponseHeader struct {
+ mock.Mock
+}
+
+func (m *mockResponseHeader) Value(key string) string {
+ return m.Called(key).String(0)
+}
+
+func (m *mockResponseHeader) Set(key string, value string) {
+ m.Called(key, value)
+}
+
+func (m *mockResponseHeader) Remove(key string) {
+ m.Called(key)
+}
+
+func (m *mockResponseHeader) Finalize() []byte {
+ return m.Called().Get(0).([]byte)
+}
+
+func TestTunnelFingerprintHandleResponse(t *testing.T) {
+ tests := []struct {
+ name string
+ expected map[string]string
+ body []byte
+ wantErr error
+ }{
+ {
+ name: "Sets Server Header",
+ expected: map[string]string{"Server": "Tunnel Please"},
+ body: []byte("Sample body"),
+ wantErr: nil,
+ },
+ {
+ name: "Overwrites Server Header",
+ expected: map[string]string{"Server": "Tunnel Please"},
+ body: nil,
+ wantErr: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockHeader := new(mockResponseHeader)
+ for k, v := range tt.expected {
+ mockHeader.On("Set", k, v).Return()
+ }
+
+ tunnelFingerprint := NewTunnelFingerprint()
+
+ err := tunnelFingerprint.HandleResponse(mockHeader, tt.body)
+ assert.ErrorIs(t, err, tt.wantErr)
+ mockHeader.AssertExpectations(t)
+ })
+ }
+}
+
+func TestNewTunnelFingerprint(t *testing.T) {
+ instance := NewTunnelFingerprint()
+ assert.NotNil(t, instance)
+}
diff --git a/internal/port/port_test.go b/internal/port/port_test.go
new file mode 100644
index 0000000..fcc64d3
--- /dev/null
+++ b/internal/port/port_test.go
@@ -0,0 +1,114 @@
+package port
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestAddRange(t *testing.T) {
+ tests := []struct {
+ name string
+ startPort uint16
+ endPort uint16
+ wantErr bool
+ }{
+ {"normal range", 1000, 1002, false},
+ {"invalid range", 2000, 1999, true},
+ {"single port range", 3000, 3000, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ pm := New()
+ err := pm.AddRange(tt.startPort, tt.endPort)
+ if tt.wantErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestUnassigned(t *testing.T) {
+ pm := New()
+ _ = pm.AddRange(1000, 1002)
+
+ tests := []struct {
+ name string
+ status map[uint16]bool
+ want uint16
+ wantOk bool
+ }{
+ {"all unassigned", map[uint16]bool{1000: false, 1001: false, 1002: false}, 1000, true},
+ {"some assigned", map[uint16]bool{1000: true, 1001: false, 1002: true}, 1001, true},
+ {"all assigned", map[uint16]bool{1000: true, 1001: true, 1002: true}, 0, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ for k, v := range tt.status {
+ _ = pm.SetStatus(k, v)
+ }
+ got, gotOk := pm.Unassigned()
+ assert.Equal(t, tt.want, got)
+ assert.Equal(t, tt.wantOk, gotOk)
+ })
+ }
+}
+
+func TestSetStatus(t *testing.T) {
+ pm := New()
+ _ = pm.AddRange(1000, 1002)
+
+ tests := []struct {
+ name string
+ port uint16
+ assigned bool
+ }{
+ {"assign port 1000", 1000, true},
+ {"unassign port 1001", 1001, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := pm.SetStatus(tt.port, tt.assigned)
+ assert.NoError(t, err)
+
+ status, ok := pm.(*port).ports[tt.port]
+ assert.True(t, ok)
+ assert.Equal(t, tt.assigned, status)
+ })
+ }
+}
+
+func TestClaim(t *testing.T) {
+ pm := New()
+ _ = pm.AddRange(1000, 1002)
+
+ tests := []struct {
+ name string
+ port uint16
+ status bool
+ want bool
+ }{
+ {"claim unassigned port", 1000, false, true},
+ {"claim already assigned port", 1001, true, false},
+ {"claim non-existent port", 5000, false, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if _, exists := pm.(*port).ports[tt.port]; exists {
+ _ = pm.SetStatus(tt.port, tt.status)
+ }
+
+ got := pm.Claim(tt.port)
+ assert.Equal(t, tt.want, got)
+
+ finalState := pm.(*port).ports[tt.port]
+ assert.True(t, finalState)
+ })
+ }
+}
diff --git a/internal/random/random.go b/internal/random/random.go
index 929cc7b..cb793d4 100644
--- a/internal/random/random.go
+++ b/internal/random/random.go
@@ -1,12 +1,35 @@
package random
-import "crypto/rand"
+import (
+ "crypto/rand"
+ "fmt"
+ "io"
+)
-func GenerateRandomString(length int) (string, error) {
+var (
+ ErrInvalidLength = fmt.Errorf("invalid length")
+)
+
+type Random interface {
+ String(length int) (string, error)
+}
+
+type random struct {
+ reader io.Reader
+}
+
+func New() Random {
+ return &random{reader: rand.Reader}
+}
+
+func (ran *random) String(length int) (string, error) {
+ if length < 0 {
+ return "", ErrInvalidLength
+ }
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, length)
- if _, err := rand.Read(b); err != nil {
+ if _, err := ran.reader.Read(b); err != nil {
return "", err
}
diff --git a/internal/random/random_test.go b/internal/random/random_test.go
new file mode 100644
index 0000000..e0cd512
--- /dev/null
+++ b/internal/random/random_test.go
@@ -0,0 +1,70 @@
+package random
+
+import (
+ "io"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestRandom_String(t *testing.T) {
+ tests := []struct {
+ name string
+ length int
+ wantErr bool
+ }{
+ {"ValidLengthZero", 0, false},
+ {"ValidPositiveLength", 10, false},
+ {"NegativeLength", -1, true},
+ {"VeryLargeLength", 1_000_000, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ randomizer := New()
+
+ result, err := randomizer.String(tt.length)
+ if tt.wantErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ assert.Len(t, result, tt.length)
+ }
+ })
+ }
+}
+
+func TestRandomWithFailingReader_String(t *testing.T) {
+ errBrainrot := assert.AnError
+
+ tests := []struct {
+ name string
+ reader io.Reader
+ expectErr error
+ }{
+ {
+ name: "failing reader",
+ reader: func() io.Reader {
+ return &failingReader{err: errBrainrot}
+ }(),
+ expectErr: errBrainrot,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ randomizer := &random{reader: tt.reader}
+ result, err := randomizer.String(20)
+ assert.ErrorIs(t, err, tt.expectErr)
+ assert.Empty(t, result)
+ })
+ }
+}
+
+type failingReader struct {
+ err error
+}
+
+func (f *failingReader) Read(p []byte) (int, error) {
+ return 0, f.err
+}
diff --git a/internal/registry/registry.go b/internal/registry/registry.go
index 89cac48..e12ea0b 100644
--- a/internal/registry/registry.go
+++ b/internal/registry/registry.go
@@ -94,12 +94,13 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
return ErrInvalidSlug
}
- r.mu.Lock()
- defer r.mu.Unlock()
-
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
return ErrSlugInUse
}
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
client, ok := r.byUser[user][oldKey]
if !ok {
return ErrSessionNotFound
@@ -111,9 +112,6 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
client.Slug().Set(newKey.Id)
r.slugIndex[newKey] = user
- if r.byUser[user] == nil {
- r.byUser[user] = make(map[Key]Session)
- }
r.byUser[user][newKey] = client
return nil
}
diff --git a/internal/registry/registry_test.go b/internal/registry/registry_test.go
new file mode 100644
index 0000000..484b4e9
--- /dev/null
+++ b/internal/registry/registry_test.go
@@ -0,0 +1,695 @@
+package registry
+
+import (
+ "sync"
+ "testing"
+ "time"
+ "tunnel_pls/internal/port"
+ "tunnel_pls/session/forwarder"
+ "tunnel_pls/session/interaction"
+ "tunnel_pls/session/lifecycle"
+ "tunnel_pls/session/slug"
+ "tunnel_pls/types"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+
+ "golang.org/x/crypto/ssh"
+)
+
+type mockSession struct {
+ mock.Mock
+}
+
+func (m *mockSession) Lifecycle() lifecycle.Lifecycle {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(lifecycle.Lifecycle)
+}
+func (m *mockSession) Interaction() interaction.Interaction {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(interaction.Interaction)
+}
+func (m *mockSession) Forwarder() forwarder.Forwarder {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(forwarder.Forwarder)
+}
+func (m *mockSession) Slug() slug.Slug {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(slug.Slug)
+}
+func (m *mockSession) Detail() *types.Detail {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(*types.Detail)
+}
+
+type mockLifecycle struct {
+ mock.Mock
+}
+
+func (ml *mockLifecycle) Channel() ssh.Channel {
+ args := ml.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(ssh.Channel)
+}
+
+func (ml *mockLifecycle) Connection() ssh.Conn {
+ args := ml.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(ssh.Conn)
+}
+
+func (ml *mockLifecycle) PortRegistry() port.Port {
+ args := ml.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(port.Port)
+}
+
+func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { ml.Called(channel) }
+func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { ml.Called(status) }
+func (ml *mockLifecycle) IsActive() bool { return ml.Called().Bool(0) }
+func (ml *mockLifecycle) StartedAt() time.Time { return ml.Called().Get(0).(time.Time) }
+func (ml *mockLifecycle) Close() error { return ml.Called().Error(0) }
+func (ml *mockLifecycle) User() string { return ml.Called().String(0) }
+
+type mockSlug struct {
+ mock.Mock
+}
+
+func (ms *mockSlug) Set(slug string) { ms.Called(slug) }
+func (ms *mockSlug) String() string { return ms.Called().String(0) }
+
+func createMockSession(user ...string) *mockSession {
+ u := "user1"
+ if len(user) > 0 {
+ u = user[0]
+ }
+ m := new(mockSession)
+ ml := new(mockLifecycle)
+ ml.On("User").Return(u).Maybe()
+ m.On("Lifecycle").Return(ml).Maybe()
+ ms := new(mockSlug)
+ ms.On("Set", mock.Anything).Maybe()
+ m.On("Slug").Return(ms).Maybe()
+ m.On("Interaction").Return(nil).Maybe()
+ m.On("Forwarder").Return(nil).Maybe()
+ m.On("Detail").Return(nil).Maybe()
+ return m
+}
+
+func TestNewRegistry(t *testing.T) {
+ r := NewRegistry()
+ require.NotNil(t, r)
+}
+
+func TestRegistry_Get(t *testing.T) {
+ tests := []struct {
+ name string
+ setupFunc func(r *registry)
+ key types.SessionKey
+ wantErr error
+ wantResult bool
+ }{
+ {
+ name: "session found",
+ setupFunc: func(r *registry) {
+ user := "user1"
+ key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
+ session := createMockSession(user)
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.byUser[user] = map[types.SessionKey]Session{
+ key: session,
+ }
+ r.slugIndex[key] = user
+ },
+ key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
+ wantErr: nil,
+ wantResult: true,
+ },
+ {
+ name: "session not found in slugIndex",
+ setupFunc: func(r *registry) {},
+ key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
+ wantErr: ErrSessionNotFound,
+ },
+ {
+ name: "session not found in byUser",
+ setupFunc: func(r *registry) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
+ },
+ key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
+ wantErr: ErrSessionNotFound,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := ®istry{
+ byUser: make(map[string]map[types.SessionKey]Session),
+ slugIndex: make(map[types.SessionKey]string),
+ mu: sync.RWMutex{},
+ }
+ tt.setupFunc(r)
+
+ session, err := r.Get(tt.key)
+
+ assert.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.wantResult, session != nil)
+ })
+ }
+}
+
+func TestRegistry_GetWithUser(t *testing.T) {
+ tests := []struct {
+ name string
+ setupFunc func(r *registry)
+ user string
+ key types.SessionKey
+ wantErr error
+ wantResult bool
+ }{
+ {
+ name: "session found",
+ setupFunc: func(r *registry) {
+ user := "user1"
+ key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
+ session := createMockSession()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.byUser[user] = map[types.SessionKey]Session{
+ key: session,
+ }
+ r.slugIndex[key] = user
+ },
+ user: "user1",
+ key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
+ wantErr: nil,
+ wantResult: true,
+ },
+ {
+ name: "session not found in slugIndex",
+ setupFunc: func(r *registry) {},
+ user: "user1",
+ key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
+ wantErr: ErrSessionNotFound,
+ },
+ {
+ name: "session not found in byUser",
+ setupFunc: func(r *registry) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
+ },
+ user: "user1",
+ key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
+ wantErr: ErrSessionNotFound,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := ®istry{
+ byUser: make(map[string]map[types.SessionKey]Session),
+ slugIndex: make(map[types.SessionKey]string),
+ mu: sync.RWMutex{},
+ }
+ tt.setupFunc(r)
+
+ session, err := r.GetWithUser(tt.user, tt.key)
+
+ assert.ErrorIs(t, err, tt.wantErr)
+ assert.Equal(t, tt.wantResult, session != nil)
+ })
+ }
+}
+
+func TestRegistry_Update(t *testing.T) {
+ tests := []struct {
+ name string
+ user string
+ setupFunc func(r *registry) (oldKey, newKey types.SessionKey)
+ wantErr error
+ }{
+ {
+ name: "change slug success",
+ user: "user1",
+ setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
+ oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
+ newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
+ session := createMockSession("user1")
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.byUser["user1"] = map[types.SessionKey]Session{
+ oldKey: session,
+ }
+ r.slugIndex[oldKey] = "user1"
+
+ return oldKey, newKey
+ },
+ wantErr: nil,
+ },
+ {
+ name: "change slug to already used slug",
+ user: "user1",
+ setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
+ oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
+ newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
+ session := createMockSession()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.byUser["user1"] = map[types.SessionKey]Session{
+ oldKey: session,
+ newKey: session,
+ }
+ r.slugIndex[oldKey] = "user1"
+ r.slugIndex[newKey] = "user1"
+
+ return oldKey, newKey
+ },
+ wantErr: ErrSlugInUse,
+ },
+ {
+ name: "change slug to forbidden slug",
+ user: "user1",
+ setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
+ oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
+ newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP}
+ session := createMockSession()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.byUser["user1"] = map[types.SessionKey]Session{
+ oldKey: session,
+ }
+ r.slugIndex[oldKey] = "user1"
+
+ return oldKey, newKey
+ },
+ wantErr: ErrForbiddenSlug,
+ },
+ {
+ name: "change slug to invalid slug",
+ user: "user1",
+ setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
+ oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
+ newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP}
+ session := createMockSession()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.byUser["user1"] = map[types.SessionKey]Session{
+ oldKey: session,
+ }
+ r.slugIndex[oldKey] = "user1"
+
+ return oldKey, newKey
+ },
+ wantErr: ErrInvalidSlug,
+ },
+ {
+ name: "change slug but session not found",
+ user: "user2",
+ setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
+ oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
+ newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
+ session := createMockSession()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.byUser["user1"] = map[types.SessionKey]Session{
+ types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
+ }
+ r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
+
+ return oldKey, newKey
+ },
+ wantErr: ErrSessionNotFound,
+ },
+ {
+ name: "change slug but session is not in the map",
+ user: "user2",
+ setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
+ oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
+ newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
+ session := createMockSession()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.byUser["user1"] = map[types.SessionKey]Session{
+ types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
+ }
+ r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
+
+ return oldKey, newKey
+ },
+ wantErr: ErrSessionNotFound,
+ },
+ {
+ name: "change slug with same slug",
+ user: "user1",
+ setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
+ oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
+ newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
+ session := createMockSession()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.byUser["user1"] = map[types.SessionKey]Session{
+ oldKey: session,
+ }
+ r.slugIndex[oldKey] = "user1"
+
+ return oldKey, newKey
+ },
+ wantErr: ErrSlugUnchanged,
+ },
+ {
+ name: "tcp tunnel cannot change slug",
+ user: "user1",
+ setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
+ oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
+ newKey := oldKey
+ session := createMockSession()
+
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.byUser["user1"] = map[types.SessionKey]Session{
+ oldKey: session,
+ }
+ r.slugIndex[oldKey] = "user1"
+
+ return oldKey, newKey
+ },
+ wantErr: ErrSlugChangeNotAllowed,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ r := ®istry{
+ byUser: make(map[string]map[types.SessionKey]Session),
+ slugIndex: make(map[types.SessionKey]string),
+ mu: sync.RWMutex{},
+ }
+
+ oldKey, newKey := tt.setupFunc(r)
+
+ err := r.Update(tt.user, oldKey, newKey)
+ assert.ErrorIs(t, err, tt.wantErr)
+
+ if err == nil {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ _, ok := r.byUser[tt.user][newKey]
+ assert.True(t, ok, "newKey not found in registry")
+ _, ok = r.byUser[tt.user][oldKey]
+ assert.False(t, ok, "oldKey still exists in registry")
+ }
+ })
+ }
+}
+
+func TestRegistry_Register(t *testing.T) {
+ tests := []struct {
+ name string
+ user string
+ setupFunc func(r *registry) Key
+ wantOK bool
+ }{
+ {
+ name: "register new key successfully",
+ user: "user1",
+ setupFunc: func(r *registry) Key {
+ key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
+ return key
+ },
+ wantOK: true,
+ },
+ {
+ name: "register already existing key fails",
+ user: "user1",
+ setupFunc: func(r *registry) Key {
+ key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
+ session := createMockSession()
+
+ r.mu.Lock()
+ r.byUser["user1"] = map[Key]Session{key: session}
+ r.slugIndex[key] = "user1"
+ r.mu.Unlock()
+
+ return key
+ },
+ wantOK: false,
+ },
+ {
+ name: "register multiple keys for same user",
+ user: "user1",
+ setupFunc: func(r *registry) Key {
+ firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP}
+ session := createMockSession()
+ r.mu.Lock()
+ r.byUser["user1"] = map[Key]Session{firstKey: session}
+ r.slugIndex[firstKey] = "user1"
+ r.mu.Unlock()
+
+ return types.SessionKey{Id: "second", Type: types.TunnelTypeHTTP}
+ },
+ wantOK: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+
+ r := ®istry{
+ byUser: make(map[string]map[Key]Session),
+ slugIndex: make(map[Key]string),
+ mu: sync.RWMutex{},
+ }
+
+ key := tt.setupFunc(r)
+ session := createMockSession()
+
+ ok := r.Register(key, session)
+ assert.Equal(t, tt.wantOK, ok)
+
+ if ok {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ assert.Equal(t, session, r.byUser[tt.user][key], "session not stored in byUser")
+ assert.Equal(t, tt.user, r.slugIndex[key], "slugIndex not updated")
+ }
+ })
+ }
+}
+
+func TestRegistry_GetAllSessionFromUser(t *testing.T) {
+ tests := []struct {
+ name string
+ setupFunc func(r *registry) string
+ expectN int
+ }{
+ {
+ name: "user has no sessions",
+ setupFunc: func(r *registry) string {
+ return "user1"
+ },
+ expectN: 0,
+ },
+ {
+ name: "user has multiple sessions",
+ setupFunc: func(r *registry) string {
+ user := "user1"
+ key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
+ key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
+ r.mu.Lock()
+ r.byUser[user] = map[Key]Session{
+ key1: createMockSession(),
+ key2: createMockSession(),
+ }
+ r.mu.Unlock()
+ return user
+ },
+ expectN: 2,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := ®istry{
+ byUser: make(map[string]map[Key]Session),
+ slugIndex: make(map[Key]string),
+ mu: sync.RWMutex{},
+ }
+ user := tt.setupFunc(r)
+ sessions := r.GetAllSessionFromUser(user)
+ assert.Len(t, sessions, tt.expectN)
+ })
+ }
+}
+
+func TestRegistry_Remove(t *testing.T) {
+ tests := []struct {
+ name string
+ setupFunc func(r *registry) (string, types.SessionKey)
+ key types.SessionKey
+ verify func(*testing.T, *registry, string, types.SessionKey)
+ }{
+ {
+ name: "remove existing key",
+ setupFunc: func(r *registry) (string, types.SessionKey) {
+ user := "user1"
+ key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
+ session := createMockSession()
+ r.mu.Lock()
+ r.byUser[user] = map[Key]Session{key: session}
+ r.slugIndex[key] = user
+ r.mu.Unlock()
+ return user, key
+ },
+ verify: func(t *testing.T, r *registry, user string, key types.SessionKey) {
+ _, ok := r.byUser[user][key]
+ assert.False(t, ok, "expected key to be removed from byUser")
+ _, ok = r.slugIndex[key]
+ assert.False(t, ok, "expected key to be removed from slugIndex")
+ _, ok = r.byUser[user]
+ assert.False(t, ok, "expected user to be removed from byUser map")
+ },
+ },
+ {
+ name: "remove non-existing key",
+ setupFunc: func(r *registry) (string, types.SessionKey) {
+ return "", types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP}
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r := ®istry{
+ byUser: make(map[string]map[Key]Session),
+ slugIndex: make(map[Key]string),
+ mu: sync.RWMutex{},
+ }
+ user, key := tt.setupFunc(r)
+ if user == "" {
+ key = tt.key
+ }
+ r.Remove(key)
+ if tt.verify != nil {
+ tt.verify(t, r, user, key)
+ }
+ })
+ }
+}
+
+func TestIsValidSlug(t *testing.T) {
+ tests := []struct {
+ slug string
+ want bool
+ }{
+ {"abc", true},
+ {"abc-123", true},
+ {"a", false},
+ {"verybigdihsixsevenlabubu", false},
+ {"-iamsigma", false},
+ {"ligma-", false},
+ {"invalid$", false},
+ {"valid-slug1", true},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.slug, func(t *testing.T) {
+ got := isValidSlug(tt.slug)
+ if got != tt.want {
+ t.Errorf("isValidSlug(%q) = %v; want %v", tt.slug, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIsValidSlugChar(t *testing.T) {
+ tests := []struct {
+ char byte
+ want bool
+ }{
+ {'a', true},
+ {'z', true},
+ {'0', true},
+ {'9', true},
+ {'-', true},
+ {'A', false},
+ {'$', false},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(string(tt.char), func(t *testing.T) {
+ got := isValidSlugChar(tt.char)
+ if got != tt.want {
+ t.Errorf("isValidSlugChar(%q) = %v; want %v", tt.char, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIsForbiddenSlug(t *testing.T) {
+ forbiddenSlugs = map[string]struct{}{
+ "admin": {},
+ "root": {},
+ }
+
+ tests := []struct {
+ slug string
+ want bool
+ }{
+ {"admin", true},
+ {"root", true},
+ {"user", false},
+ {"guest", false},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.slug, func(t *testing.T) {
+ got := isForbiddenSlug(tt.slug)
+ if got != tt.want {
+ t.Errorf("isForbiddenSlug(%q) = %v; want %v", tt.slug, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/internal/transport/http.go b/internal/transport/http.go
index dd091c3..5c4648d 100644
--- a/internal/transport/http.go
+++ b/internal/transport/http.go
@@ -4,27 +4,28 @@ import (
"errors"
"log"
"net"
+ "tunnel_pls/internal/config"
"tunnel_pls/internal/registry"
)
type httpServer struct {
handler *httpHandler
- port string
+ config config.Config
}
-func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
+func NewHTTPServer(config config.Config, sessionRegistry registry.Registry) Transport {
return &httpServer{
- handler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
- port: port,
+ handler: newHTTPHandler(config, sessionRegistry),
+ config: config,
}
}
func (ht *httpServer) Listen() (net.Listener, error) {
- return net.Listen("tcp", ":"+ht.port)
+ return net.Listen("tcp", ":"+ht.config.HTTPPort())
}
func (ht *httpServer) Serve(listener net.Listener) error {
- log.Printf("HTTP server is starting on port %s", ht.port)
+ log.Printf("HTTP server is starting on port %s", ht.config.HTTPPort())
for {
conn, err := listener.Accept()
if err != nil {
@@ -35,6 +36,6 @@ func (ht *httpServer) Serve(listener net.Listener) error {
continue
}
- go ht.handler.handler(conn, false)
+ go ht.handler.Handler(conn, false)
}
}
diff --git a/internal/transport/http_test.go b/internal/transport/http_test.go
new file mode 100644
index 0000000..cd3cf68
--- /dev/null
+++ b/internal/transport/http_test.go
@@ -0,0 +1,135 @@
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+func TestNewHTTPServer(t *testing.T) {
+ msr := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ port := "0"
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("HTTPPort").Return(port)
+
+ srv := NewHTTPServer(mockConfig, msr)
+ assert.NotNil(t, srv)
+
+ httpSrv, ok := srv.(*httpServer)
+ assert.True(t, ok)
+ assert.Equal(t, msr, httpSrv.handler.sessionRegistry)
+ assert.NotNil(t, srv)
+}
+
+func TestHTTPServer_Listen(t *testing.T) {
+ msr := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ port := "0"
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("HTTPPort").Return(port)
+ srv := NewHTTPServer(mockConfig, msr)
+
+ listener, err := srv.Listen()
+ assert.NoError(t, err)
+ assert.NotNil(t, listener)
+ err = listener.Close()
+ assert.NoError(t, err)
+}
+
+func TestHTTPServer_Serve(t *testing.T) {
+ msr := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ port := "0"
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("HTTPPort").Return(port)
+ srv := NewHTTPServer(mockConfig, msr)
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ assert.NoError(t, err)
+
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ err = listener.Close()
+ assert.NoError(t, err)
+ }()
+
+ err = srv.Serve(listener)
+ assert.True(t, errors.Is(err, net.ErrClosed))
+}
+
+func TestHTTPServer_Serve_AcceptError(t *testing.T) {
+ msr := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ port := "0"
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("HTTPPort").Return(port)
+ srv := NewHTTPServer(mockConfig, msr)
+
+ ml := new(mockListener)
+ ml.On("Accept").Return(nil, errors.New("accept error")).Once()
+ ml.On("Accept").Return(nil, net.ErrClosed).Once()
+
+ err := srv.Serve(ml)
+ assert.True(t, errors.Is(err, net.ErrClosed))
+ ml.AssertExpectations(t)
+}
+
+func TestHTTPServer_Serve_Success(t *testing.T) {
+ msr := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ port := "0"
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("HTTPPort").Return(port)
+ mockConfig.On("HeaderSize").Return(4096)
+ mockConfig.On("TLSRedirect").Return(false)
+ srv := NewHTTPServer(mockConfig, msr)
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ assert.NoError(t, err)
+ listenerport := listener.Addr().(*net.TCPAddr).Port
+
+ go func() {
+ _ = srv.Serve(listener)
+ }()
+
+ conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
+ assert.NoError(t, err)
+
+ _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
+
+ time.Sleep(100 * time.Millisecond)
+ err = conn.Close()
+ assert.NoError(t, err)
+
+ err = listener.Close()
+ assert.NoError(t, err)
+
+}
+
+type mockListener struct {
+ mock.Mock
+}
+
+func (m *mockListener) Accept() (net.Conn, error) {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(net.Conn), args.Error(1)
+}
+
+func (m *mockListener) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+func (m *mockListener) Addr() net.Addr {
+ args := m.Called()
+ return args.Get(0).(net.Addr)
+}
diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go
index 8bab4a0..67aa6fb 100644
--- a/internal/transport/httphandler.go
+++ b/internal/transport/httphandler.go
@@ -1,7 +1,8 @@
package transport
import (
- "bufio"
+ "bytes"
+ "context"
"errors"
"fmt"
"io"
@@ -10,6 +11,7 @@ import (
"net/http"
"strings"
"time"
+ "tunnel_pls/internal/config"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware"
@@ -20,16 +22,14 @@ import (
)
type httpHandler struct {
- domain string
+ config config.Config
sessionRegistry registry.Registry
- redirectTLS bool
}
-func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
+func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler {
return &httpHandler{
- domain: domain,
+ config: config,
sessionRegistry: sessionRegistry,
- redirectTLS: redirectTLS,
}
}
@@ -52,13 +52,28 @@ func (hh *httpHandler) badRequest(conn net.Conn) error {
return nil
}
-func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
+func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
defer hh.closeConnection(conn)
- dstReader := bufio.NewReader(conn)
- reqhf, err := header.NewRequest(dstReader)
+ _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
+ buf := make([]byte, hh.config.HeaderSize())
+ n, err := conn.Read(buf)
+ if err != nil {
+ _ = hh.badRequest(conn)
+ return
+ }
+
+ if idx := bytes.Index(buf[:n], []byte("\r\n\r\n")); idx == -1 {
+ _ = hh.badRequest(conn)
+ return
+ }
+
+ _ = conn.SetReadDeadline(time.Time{})
+
+ reqhf, err := header.NewRequest(buf[:n])
if err != nil {
log.Printf("Error creating request header: %v", err)
+ _ = hh.badRequest(conn)
return
}
@@ -69,7 +84,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
}
if hh.shouldRedirectToTLS(isTLS) {
- _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
+ _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.config.Domain()))
return
}
@@ -77,13 +92,16 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
return
}
- sshSession, err := hh.getSession(slug)
+ sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
+ Id: slug,
+ Type: types.TunnelTypeHTTP,
+ })
if err != nil {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
return
}
- hw := stream.New(conn, dstReader, conn.RemoteAddr())
+ hw := stream.New(conn, conn, conn.RemoteAddr())
defer func(hw stream.HTTP) {
err = hw.Close()
if err != nil {
@@ -102,14 +120,14 @@ func (hh *httpHandler) closeConnection(conn net.Conn) {
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
host := strings.Split(reqhf.Value("Host"), ".")
- if len(host) < 1 {
+ if len(host) <= 1 {
return "", errors.New("invalid host")
}
return host[0], nil
}
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
- return !isTLS && hh.redirectTLS
+ return !isTLS && hh.config.TLSRedirect()
}
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
@@ -128,29 +146,22 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
+ return true
}
return true
}
-func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
- sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
- Id: slug,
- Type: types.TunnelTypeHTTP,
- })
- if err != nil {
- return nil, err
- }
- return sshSession, nil
-}
-
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
- channel, err := hh.openForwardedChannel(hw, sshSession)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+ defer cancel()
+ channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, hw.RemoteAddr())
if err != nil {
- log.Printf("Failed to establish channel: %v", err)
- sshSession.Forwarder().WriteBadGatewayResponse(hw)
+ log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
+ go ssh.DiscardRequests(reqs)
+
defer func() {
err = channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
@@ -167,47 +178,6 @@ func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.Requ
sshSession.Forwarder().HandleConnection(hw, channel)
}
-func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) {
- payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
-
- type channelResult struct {
- channel ssh.Channel
- reqs <-chan *ssh.Request
- err error
- }
-
- resultChan := make(chan channelResult, 1)
-
- go func() {
- channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
- select {
- case resultChan <- channelResult{channel, reqs, err}:
- default:
- hh.cleanupUnusedChannel(channel, reqs)
- }
- }()
-
- select {
- case result := <-resultChan:
- if result.err != nil {
- return nil, result.err
- }
- go ssh.DiscardRequests(result.reqs)
- return result.channel, nil
- case <-time.After(5 * time.Second):
- return nil, errors.New("timeout opening forwarded-tcpip channel")
- }
-}
-
-func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
- if channel != nil {
- if err := channel.Close(); err != nil {
- log.Printf("Failed to close unused channel: %v", err)
- }
- go ssh.DiscardRequests(reqs)
- }
-}
-
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
fingerprintMiddleware := middleware.NewTunnelFingerprint()
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
diff --git a/internal/transport/httphandler_test.go b/internal/transport/httphandler_test.go
new file mode 100644
index 0000000..6801b22
--- /dev/null
+++ b/internal/transport/httphandler_test.go
@@ -0,0 +1,717 @@
+package transport
+
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+ "tunnel_pls/internal/registry"
+ "tunnel_pls/session/forwarder"
+ "tunnel_pls/session/interaction"
+ "tunnel_pls/session/lifecycle"
+ "tunnel_pls/session/slug"
+ "tunnel_pls/types"
+
+ "golang.org/x/crypto/ssh"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+type MockSessionRegistry struct {
+ mock.Mock
+}
+
+func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
+ args := m.Called(key)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(registry.Session), args.Error(1)
+}
+
+func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
+ args := m.Called(user, key)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(registry.Session), args.Error(1)
+}
+
+func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
+ args := m.Called(user, oldKey, newKey)
+ return args.Error(0)
+}
+
+func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
+ args := m.Called(key, session)
+ return args.Bool(0)
+}
+
+func (m *MockSessionRegistry) Remove(key registry.Key) {
+ m.Called(key)
+}
+
+func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
+ args := m.Called(user)
+ return args.Get(0).([]registry.Session)
+}
+
+func (m *MockSessionRegistry) Slug() slug.Slug {
+ args := m.Called()
+ return args.Get(0).(slug.Slug)
+}
+
+type MockSession struct {
+ mock.Mock
+}
+
+func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
+ args := m.Called()
+ return args.Get(0).(lifecycle.Lifecycle)
+}
+
+func (m *MockSession) Interaction() interaction.Interaction {
+ args := m.Called()
+ return args.Get(0).(interaction.Interaction)
+}
+
+func (m *MockSession) Forwarder() forwarder.Forwarder {
+ args := m.Called()
+ return args.Get(0).(forwarder.Forwarder)
+}
+
+func (m *MockSession) Slug() slug.Slug {
+ args := m.Called()
+ return args.Get(0).(slug.Slug)
+}
+
+func (m *MockSession) Detail() *types.Detail {
+ args := m.Called()
+ return args.Get(0).(*types.Detail)
+}
+
+type MockSSHChannel struct {
+ ssh.Channel
+ mock.Mock
+}
+
+func (m *MockSSHChannel) Write(data []byte) (int, error) {
+ args := m.Called(data)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *MockSSHChannel) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+type MockForwarder struct {
+ mock.Mock
+}
+
+func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
+ m.Called(dst, src)
+}
+
+func (m *MockForwarder) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+func (m *MockForwarder) TunnelType() types.TunnelType {
+ args := m.Called()
+ return args.Get(0).(types.TunnelType)
+}
+
+func (m *MockForwarder) ForwardedPort() uint16 {
+ args := m.Called()
+ return uint16(args.Int(0))
+}
+
+func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
+ m.Called(tunnelType)
+}
+
+func (m *MockForwarder) SetForwardedPort(port uint16) {
+ m.Called(port)
+}
+
+func (m *MockForwarder) SetListener(listener net.Listener) {
+ m.Called(listener)
+}
+
+func (m *MockForwarder) Listener() net.Listener {
+ args := m.Called()
+ return args.Get(0).(net.Listener)
+}
+
+func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
+ args := m.Called(ctx, origin)
+ if args.Get(0) == nil {
+ return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
+ }
+ return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
+}
+
+type MockConn struct {
+ mock.Mock
+ ReadBuffer *bytes.Buffer
+}
+
+func (m *MockConn) LocalAddr() net.Addr {
+ args := m.Called()
+ return args.Get(0).(net.Addr)
+}
+
+func (m *MockConn) SetDeadline(t time.Time) error {
+ args := m.Called(t)
+ return args.Error(0)
+}
+
+func (m *MockConn) SetReadDeadline(t time.Time) error {
+ args := m.Called(t)
+ return args.Error(0)
+}
+
+func (m *MockConn) SetWriteDeadline(t time.Time) error {
+ args := m.Called(t)
+ return args.Error(0)
+}
+
+func (m *MockConn) Read(b []byte) (n int, err error) {
+ if m.ReadBuffer != nil {
+ return m.ReadBuffer.Read(b)
+ }
+ args := m.Called(b)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *MockConn) Write(b []byte) (n int, err error) {
+ args := m.Called(b)
+ if args.Int(0) == -1 {
+ return len(b), args.Error(1)
+ }
+ return args.Int(0), args.Error(1)
+}
+
+func (m *MockConn) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+func (m *MockConn) RemoteAddr() net.Addr {
+ args := m.Called()
+ return args.Get(0).(net.Addr)
+}
+
+type wrappedConn struct {
+ net.Conn
+ remoteAddr net.Addr
+}
+
+func (c *wrappedConn) RemoteAddr() net.Addr {
+ return c.remoteAddr
+}
+
+func TestNewHTTPHandler(t *testing.T) {
+ msr := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ mockConfig.On("Domain").Return("domain")
+ mockConfig.On("TLSRedirect").Return(false)
+ hh := newHTTPHandler(mockConfig, msr)
+ assert.NotNil(t, hh)
+ assert.Equal(t, msr, hh.sessionRegistry)
+}
+
+func TestHandler(t *testing.T) {
+ tests := []struct {
+ name string
+ isTLS bool
+ redirectTLS bool
+ request []byte
+ expected []byte
+ setupMocks func(*MockSessionRegistry)
+ setupConn func() (net.Conn, net.Conn)
+ expectError bool
+ }{
+ {
+ name: "bad request - invalid host",
+ isTLS: false,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"),
+ expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ },
+ },
+ {
+ name: "bad request - missing host",
+ isTLS: false,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\n\r\n"),
+ expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ },
+ },
+ {
+ name: "isTLS true and redirectTLS true - no redirect",
+ isTLS: true,
+ redirectTLS: true,
+ request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
+ expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ },
+ },
+ {
+ name: "redirect to TLS",
+ isTLS: false,
+ redirectTLS: true,
+ request: []byte("GET / HTTP/1.1\r\nHost: tunnel.example.com\r\n\r\n"),
+ expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnel.example.com/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ },
+ },
+ {
+ name: "handle ping request",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
+ expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ },
+ },
+ {
+ name: "session not found",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
+ expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ msr.On("Get", types.SessionKey{
+ Id: "test",
+ Type: types.TunnelTypeHTTP,
+ }).Return((registry.Session)(nil), fmt.Errorf("session not found"))
+ },
+ },
+ {
+ name: "bad request - invalid http",
+ isTLS: false,
+ redirectTLS: false,
+ request: []byte("INVALID\r\n\r\n"),
+ expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ },
+ },
+ {
+ name: "bad request - header too large",
+ isTLS: false,
+ redirectTLS: false,
+ request: []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: test.domain\r\n%s\r\n\r\n", strings.Repeat("test", 10000))),
+ expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ },
+ },
+ {
+ name: "bad request - no request",
+ isTLS: false,
+ redirectTLS: false,
+ request: []byte(""),
+ expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ },
+ },
+ {
+ name: "forwarding - open channel fails",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
+ expected: []byte(""),
+ setupMocks: func(msr *MockSessionRegistry) {
+ mockSession := new(MockSession)
+ mockForwarder := new(MockForwarder)
+
+ msr.On("Get", types.SessionKey{
+ Id: "test",
+ Type: types.TunnelTypeHTTP,
+ }).Return(mockSession, nil)
+
+ mockSession.On("Forwarder").Return(mockForwarder)
+
+ mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
+ mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
+ },
+ },
+ {
+ name: "forwarding - send initial request fails",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
+ expected: []byte(""),
+ setupMocks: func(msr *MockSessionRegistry) {
+ mockSession := new(MockSession)
+ mockForwarder := new(MockForwarder)
+ mockSSHChannel := new(MockSSHChannel)
+
+ msr.On("Get", types.SessionKey{
+ Id: "test",
+ Type: types.TunnelTypeHTTP,
+ }).Return(mockSession, nil)
+
+ mockSession.On("Forwarder").Return(mockForwarder)
+
+ mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
+
+ reqCh := make(chan *ssh.Request)
+ mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
+
+ mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
+ mockSSHChannel.On("Close").Return(nil)
+
+ go func() {
+ for range reqCh {
+ }
+ }()
+ },
+ },
+ {
+ name: "forwarding - success",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
+ expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ mockSession := new(MockSession)
+ mockForwarder := new(MockForwarder)
+ mockSSHChannel := new(MockSSHChannel)
+
+ msr.On("Get", types.SessionKey{
+ Id: "test",
+ Type: types.TunnelTypeHTTP,
+ }).Return(mockSession, nil)
+
+ mockSession.On("Forwarder").Return(mockForwarder)
+
+ mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
+
+ reqCh := make(chan *ssh.Request)
+ mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
+
+ mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
+ mockSSHChannel.On("Close").Return(nil)
+
+ mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
+ w := args.Get(0).(io.ReadWriter)
+ _, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
+ })
+
+ go func() {
+ for range reqCh {
+ }
+ }()
+ },
+ },
+ {
+ name: "redirect - write failure",
+ isTLS: false,
+ redirectTLS: true,
+ request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
+ expected: []byte(""),
+ setupConn: func() (net.Conn, net.Conn) {
+ mc := new(MockConn)
+ mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
+ mc.On("SetReadDeadline", mock.Anything).Return(nil)
+ mc.On("Write", mock.Anything).Return(-1, fmt.Errorf("write error"))
+ mc.On("Close").Return(nil)
+ return mc, nil
+ },
+ },
+ {
+ name: "bad request - write failure",
+ isTLS: false,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\n\r\n"),
+ expected: []byte(""),
+ setupConn: func() (net.Conn, net.Conn) {
+ mc := new(MockConn)
+ mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
+ mc.On("SetReadDeadline", mock.Anything).Return(nil)
+ mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
+ mc.On("Close").Return(nil)
+ return mc, nil
+ },
+ },
+ {
+ name: "read error - connection failure",
+ isTLS: false,
+ redirectTLS: false,
+ request: []byte(""),
+ expected: []byte(""),
+ setupConn: func() (net.Conn, net.Conn) {
+ mc := new(MockConn)
+ mc.On("SetReadDeadline", mock.Anything).Return(nil)
+ mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
+ mc.On("Read", mock.Anything).Return(0, fmt.Errorf("connection reset by peer"))
+ mc.On("Close").Return(nil)
+ return mc, nil
+ },
+ },
+ {
+ name: "handle ping request - write failure",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
+ expected: []byte(""),
+ setupConn: func() (net.Conn, net.Conn) {
+ mc := new(MockConn)
+ mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
+ mc.On("SetReadDeadline", mock.Anything).Return(nil)
+ mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
+ mc.On("Close").Return(nil)
+ return mc, nil
+ },
+ },
+ {
+ name: "close connection - error",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
+ expected: []byte(""),
+ setupConn: func() (net.Conn, net.Conn) {
+ mc := new(MockConn)
+ mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
+ mc.On("SetReadDeadline", mock.Anything).Return(nil)
+ mc.On("Write", mock.Anything).Return(182, nil)
+ mc.On("Close").Return(fmt.Errorf("close error"))
+ return mc, nil
+ },
+ },
+ {
+ name: "forwarding - stream close error",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
+ expected: []byte(""),
+ setupMocks: func(msr *MockSessionRegistry) {
+ mockSession := new(MockSession)
+ mockForwarder := new(MockForwarder)
+ mockSSHChannel := new(MockSSHChannel)
+
+ msr.On("Get", mock.Anything).Return(mockSession, nil)
+ mockSession.On("Forwarder").Return(mockForwarder)
+
+ mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
+
+ reqCh := make(chan *ssh.Request)
+ mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
+
+ mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
+ mockSSHChannel.On("Close").Return(nil)
+
+ mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return()
+ },
+ setupConn: func() (net.Conn, net.Conn) {
+ mc := new(MockConn)
+ mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
+ mc.On("SetReadDeadline", mock.Anything).Return(nil)
+ mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
+ addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
+ mc.On("RemoteAddr").Return(addr)
+ return mc, nil
+ },
+ },
+ {
+ name: "forwarding - middleware failure",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
+ expected: []byte(""),
+ setupMocks: func(msr *MockSessionRegistry) {
+ mockSession := new(MockSession)
+ mockForwarder := new(MockForwarder)
+ mockSSHChannel := new(MockSSHChannel)
+
+ msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
+ return k.Id == "test"
+ })).Return(mockSession, nil)
+ mockSession.On("Forwarder").Return(mockForwarder)
+ mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
+
+ reqCh := make(chan *ssh.Request)
+ mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
+ mockSSHChannel.On("Close").Return(nil)
+ },
+ setupConn: func() (net.Conn, net.Conn) {
+ mc := new(MockConn)
+ mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
+ mc.On("SetReadDeadline", mock.Anything).Return(nil)
+ mc.On("Close").Return(nil).Times(2)
+ mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
+ return mc, nil
+ },
+ },
+ {
+ name: "forwarding - channel close error",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
+ expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
+ setupMocks: func(msr *MockSessionRegistry) {
+ mockSession := new(MockSession)
+ mockForwarder := new(MockForwarder)
+ mockSSHChannel := new(MockSSHChannel)
+
+ msr.On("Get", mock.Anything).Return(mockSession, nil)
+ mockSession.On("Forwarder").Return(mockForwarder)
+
+ mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
+ reqCh := make(chan *ssh.Request)
+ mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
+
+ mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
+ mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
+
+ mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
+ w := args.Get(0).(io.ReadWriter)
+ _, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
+ })
+ },
+ },
+ {
+ name: "forwarding - open channel timeout",
+ isTLS: true,
+ redirectTLS: false,
+ request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
+ expected: []byte(""),
+ setupMocks: func(msr *MockSessionRegistry) {
+ mockSession := new(MockSession)
+ mockForwarder := new(MockForwarder)
+
+ msr.On("Get", mock.Anything).Return(mockSession, nil)
+ mockSession.On("Forwarder").Return(mockForwarder)
+
+ mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
+
+ mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
+ ctx := args.Get(0).(context.Context)
+ <-ctx.Done()
+ }).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockSessionRegistry := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ port := "0"
+ mockConfig.On("Domain").Return("example.com")
+ mockConfig.On("HTTPPort").Return(port)
+ mockConfig.On("HeaderSize").Return(4096)
+ mockConfig.On("TLSRedirect").Return(true)
+ hh := &httpHandler{
+ sessionRegistry: mockSessionRegistry,
+ config: mockConfig,
+ }
+
+ if tt.setupMocks != nil {
+ tt.setupMocks(mockSessionRegistry)
+ }
+
+ var serverConn, clientConn net.Conn
+ if tt.setupConn != nil {
+ serverConn, clientConn = tt.setupConn()
+ } else {
+ serverConn, clientConn = net.Pipe()
+ }
+
+ if clientConn != nil {
+ defer func(clientConn net.Conn) {
+ err := clientConn.Close()
+ assert.NoError(t, err)
+ }(clientConn)
+ }
+
+ remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
+ var wrappedServerConn net.Conn
+ if _, ok := serverConn.(*MockConn); ok {
+ wrappedServerConn = serverConn
+ } else {
+ wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr}
+ }
+
+ responseChan := make(chan []byte, 1)
+ doneChan := make(chan struct{})
+
+ if clientConn != nil {
+ go func() {
+ defer close(doneChan)
+ var res []byte
+ for {
+ buf := make([]byte, 4096)
+ n, err := clientConn.Read(buf)
+ if err != nil {
+ if err != io.EOF {
+ t.Logf("Error reading response: %v", err)
+ }
+ break
+ }
+ res = append(res, buf[:n]...)
+ if len(tt.expected) > 0 && len(res) >= len(tt.expected) {
+ break
+ }
+ }
+ responseChan <- res
+ }()
+
+ go func() {
+ _, err := clientConn.Write(tt.request)
+ if err != nil {
+ t.Logf("Error writing request: %v", err)
+ }
+ }()
+ } else {
+ close(responseChan)
+ close(doneChan)
+ }
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ hh.Handler(wrappedServerConn, tt.isTLS)
+ }()
+
+ select {
+ case response := <-responseChan:
+ if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" {
+ resStr := string(response)
+ assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n"))
+ assert.Contains(t, resStr, "Content-Length: 5\r\n")
+ assert.Contains(t, resStr, "Server: Tunnel Please\r\n")
+ assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello"))
+ } else {
+ assert.Equal(t, string(tt.expected), string(response))
+ }
+ case <-time.After(10 * time.Second):
+ if clientConn != nil {
+ t.Fatal("Test timeout - no response received")
+ }
+ }
+
+ wg.Wait()
+ if clientConn != nil {
+ <-doneChan
+ }
+
+ mockSessionRegistry.AssertExpectations(t)
+ if mc, ok := serverConn.(*MockConn); ok {
+ mc.AssertExpectations(t)
+ }
+ })
+ }
+}
diff --git a/internal/transport/https.go b/internal/transport/https.go
index 88ffe27..f1076bf 100644
--- a/internal/transport/https.go
+++ b/internal/transport/https.go
@@ -5,31 +5,30 @@ import (
"errors"
"log"
"net"
+ "tunnel_pls/internal/config"
"tunnel_pls/internal/registry"
)
type https struct {
+ config config.Config
tlsConfig *tls.Config
httpHandler *httpHandler
- domain string
- port string
}
-func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport {
+func NewHTTPSServer(config config.Config, sessionRegistry registry.Registry, tlsConfig *tls.Config) Transport {
return &https{
+ config: config,
tlsConfig: tlsConfig,
- httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
- domain: domain,
- port: port,
+ httpHandler: newHTTPHandler(config, sessionRegistry),
}
}
func (ht *https) Listen() (net.Listener, error) {
- return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig)
+ return tls.Listen("tcp", ":"+ht.config.HTTPSPort(), ht.tlsConfig)
}
func (ht *https) Serve(listener net.Listener) error {
- log.Printf("HTTPS server is starting on port %s", ht.port)
+ log.Printf("HTTPS server is starting on port %s", ht.config.HTTPSPort())
for {
conn, err := listener.Accept()
if err != nil {
@@ -40,6 +39,6 @@ func (ht *https) Serve(listener net.Listener) error {
continue
}
- go ht.httpHandler.handler(conn, true)
+ go ht.httpHandler.Handler(conn, true)
}
}
diff --git a/internal/transport/https_test.go b/internal/transport/https_test.go
new file mode 100644
index 0000000..6081d97
--- /dev/null
+++ b/internal/transport/https_test.go
@@ -0,0 +1,120 @@
+package transport
+
+import (
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestNewHTTPSServer(t *testing.T) {
+ msr := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ port := "0"
+ tlsConfig := &tls.Config{}
+ mockConfig.On("Domain").Return(mockConfig)
+ mockConfig.On("HTTPSPort").Return(port)
+ srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
+ assert.NotNil(t, srv)
+
+ httpsSrv, ok := srv.(*https)
+ assert.True(t, ok)
+ assert.Equal(t, tlsConfig, httpsSrv.tlsConfig)
+ assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry)
+}
+
+func TestHTTPSServer_Listen(t *testing.T) {
+ msr := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ port := "0"
+ mockConfig.On("Domain").Return(mockConfig)
+ mockConfig.On("HTTPSPort").Return(port)
+ tlsConfig := &tls.Config{
+ GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ return nil, nil
+ },
+ }
+ srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
+
+ listener, err := srv.Listen()
+ if err != nil {
+ t.Skip("Skipping tls.Listen test as it requires valid certificates/setup:", err)
+ return
+ }
+ assert.NotNil(t, listener)
+ err = listener.Close()
+ assert.NoError(t, err)
+}
+
+func TestHTTPSServer_Serve(t *testing.T) {
+ msr := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ port := "0"
+ mockConfig.On("Domain").Return(mockConfig)
+ mockConfig.On("HTTPSPort").Return(port)
+ srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ assert.NoError(t, err)
+
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ err = listener.Close()
+ assert.NoError(t, err)
+ }()
+
+ err = srv.Serve(listener)
+ assert.True(t, errors.Is(err, net.ErrClosed))
+}
+
+func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
+ msr := new(MockSessionRegistry)
+
+ mockConfig := &MockConfig{}
+ port := "0"
+ mockConfig.On("Domain").Return(mockConfig)
+ mockConfig.On("HTTPSPort").Return(port)
+ srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
+
+ ml := new(mockListener)
+ ml.On("Accept").Return(nil, errors.New("accept error")).Once()
+ ml.On("Accept").Return(nil, net.ErrClosed).Once()
+
+ err := srv.Serve(ml)
+ assert.True(t, errors.Is(err, net.ErrClosed))
+ ml.AssertExpectations(t)
+}
+
+func TestHTTPSServer_Serve_Success(t *testing.T) {
+ msr := new(MockSessionRegistry)
+ mockConfig := &MockConfig{}
+ port := "0"
+ mockConfig.On("Domain").Return(mockConfig)
+ mockConfig.On("HTTPSPort").Return(port)
+ mockConfig.On("HeaderSize").Return(4096)
+
+ srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ assert.NoError(t, err)
+ listenerport := listener.Addr().(*net.TCPAddr).Port
+
+ go func() {
+ _ = srv.Serve(listener)
+ }()
+
+ conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
+ assert.NoError(t, err)
+
+ _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
+
+ time.Sleep(100 * time.Millisecond)
+ err = conn.Close()
+ assert.NoError(t, err)
+ err = listener.Close()
+ assert.NoError(t, err)
+}
diff --git a/internal/transport/tcp.go b/internal/transport/tcp.go
index 99670d2..34ea3c2 100644
--- a/internal/transport/tcp.go
+++ b/internal/transport/tcp.go
@@ -1,27 +1,28 @@
package transport
import (
+ "context"
"errors"
"fmt"
"io"
"log"
"net"
+ "time"
"golang.org/x/crypto/ssh"
)
type tcp struct {
port uint16
- forwarder forwarder
+ forwarder Forwarder
}
-type forwarder interface {
- CreateForwardedTCPIPPayload(origin net.Addr) []byte
- OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
+type Forwarder interface {
+ OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
HandleConnection(dst io.ReadWriter, src ssh.Channel)
}
-func NewTCPServer(port uint16, forwarder forwarder) Transport {
+func NewTCPServer(port uint16, forwarder Forwarder) Transport {
return &tcp{
port: port,
forwarder: forwarder,
@@ -53,11 +54,11 @@ func (tt *tcp) handleTcp(conn net.Conn) {
log.Printf("Failed to close connection: %v", err)
}
}()
- payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
- channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload)
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+ defer cancel()
+ channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, conn.RemoteAddr())
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
-
return
}
diff --git a/internal/transport/tcp_test.go b/internal/transport/tcp_test.go
new file mode 100644
index 0000000..c4c4963
--- /dev/null
+++ b/internal/transport/tcp_test.go
@@ -0,0 +1,146 @@
+package transport
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "golang.org/x/crypto/ssh"
+)
+
+func TestNewTCPServer(t *testing.T) {
+ mf := new(MockForwarder)
+ port := uint16(9000)
+
+ srv := NewTCPServer(port, mf)
+ assert.NotNil(t, srv)
+
+ tcpSrv, ok := srv.(*tcp)
+ assert.True(t, ok)
+ assert.Equal(t, port, tcpSrv.port)
+ assert.Equal(t, mf, tcpSrv.forwarder)
+}
+
+func TestTCPServer_Listen(t *testing.T) {
+ mf := new(MockForwarder)
+ srv := NewTCPServer(0, mf)
+
+ listener, err := srv.Listen()
+ assert.NoError(t, err)
+ assert.NotNil(t, listener)
+ err = listener.Close()
+ assert.NoError(t, err)
+}
+
+func TestTCPServer_Serve(t *testing.T) {
+ mf := new(MockForwarder)
+ srv := NewTCPServer(0, mf)
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ assert.NoError(t, err)
+
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ err = listener.Close()
+ assert.NoError(t, err)
+ }()
+
+ err = srv.Serve(listener)
+ assert.Nil(t, err)
+}
+
+func TestTCPServer_Serve_AcceptError(t *testing.T) {
+ mf := new(MockForwarder)
+ srv := NewTCPServer(0, mf)
+
+ ml := new(mockListener)
+ ml.On("Accept").Return(nil, errors.New("accept error")).Once()
+ ml.On("Accept").Return(nil, net.ErrClosed).Once()
+
+ err := srv.Serve(ml)
+ assert.Nil(t, err)
+ ml.AssertExpectations(t)
+}
+
+func TestTCPServer_Serve_Success(t *testing.T) {
+ mf := new(MockForwarder)
+ srv := NewTCPServer(0, mf)
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ assert.NoError(t, err)
+ port := listener.Addr().(*net.TCPAddr).Port
+
+ reqs := make(chan *ssh.Request)
+ mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
+ mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
+
+ go func() {
+ _ = srv.Serve(listener)
+ }()
+
+ conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
+ assert.NoError(t, err)
+
+ time.Sleep(100 * time.Millisecond)
+ err = conn.Close()
+ assert.NoError(t, err)
+ err = listener.Close()
+ assert.NoError(t, err)
+ mf.AssertExpectations(t)
+}
+
+func TestTCPServer_handleTcp_Success(t *testing.T) {
+ mf := new(MockForwarder)
+ srv := NewTCPServer(0, mf).(*tcp)
+
+ serverConn, clientConn := net.Pipe()
+ defer func(clientConn net.Conn) {
+ err := clientConn.Close()
+ assert.NoError(t, err)
+ }(clientConn)
+
+ reqs := make(chan *ssh.Request)
+ mockChannel := new(MockSSHChannel)
+ mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
+
+ mf.On("HandleConnection", serverConn, mockChannel).Return()
+
+ srv.handleTcp(serverConn)
+
+ mf.AssertExpectations(t)
+}
+
+func TestTCPServer_handleTcp_CloseError(t *testing.T) {
+ mf := new(MockForwarder)
+ srv := NewTCPServer(0, mf).(*tcp)
+
+ mc := new(MockConn)
+ mc.On("Close").Return(errors.New("close error"))
+ mc.On("RemoteAddr").Return(&net.TCPAddr{})
+
+ mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
+
+ srv.handleTcp(mc)
+ mc.AssertExpectations(t)
+}
+
+func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
+ mf := new(MockForwarder)
+ srv := NewTCPServer(0, mf).(*tcp)
+
+ serverConn, clientConn := net.Pipe()
+ defer func(clientConn net.Conn) {
+ err := clientConn.Close()
+ assert.NoError(t, err)
+ }(clientConn)
+
+ mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
+
+ srv.handleTcp(serverConn)
+
+ mf.AssertExpectations(t)
+}
diff --git a/internal/transport/tls.go b/internal/transport/tls.go
index 877afb4..4d62e60 100644
--- a/internal/transport/tls.go
+++ b/internal/transport/tls.go
@@ -8,6 +8,7 @@ import (
"fmt"
"log"
"os"
+ "path/filepath"
"sync"
"time"
"tunnel_pls/internal/config"
@@ -16,13 +17,22 @@ import (
"github.com/libdns/cloudflare"
)
-type TLSManager interface {
- userCertsExistAndValid() bool
- loadUserCerts() error
- startCertWatcher()
- initCertMagic() error
- getTLSConfig() *tls.Config
- getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
+func NewTLSConfig(config config.Config) (*tls.Config, error) {
+ var initErr error
+
+ tlsManagerOnce.Do(func() {
+ tm := createTLSManager(config)
+ initErr = tm.initialize()
+ if initErr == nil {
+ globalTLSManager = tm
+ }
+ })
+
+ if initErr != nil {
+ return nil, initErr
+ }
+
+ return globalTLSManager.getTLSConfig(), nil
}
type tlsManager struct {
@@ -40,52 +50,60 @@ type tlsManager struct {
useCertMagic bool
}
-var globalTLSManager TLSManager
+var globalTLSManager *tlsManager
var tlsManagerOnce sync.Once
-func NewTLSConfig(config config.Config) (*tls.Config, error) {
- var initErr error
+func createTLSManager(cfg config.Config) *tlsManager {
+ storagePath := cfg.TLSStoragePath()
+ cleanBase := filepath.Clean(storagePath)
- tlsManagerOnce.Do(func() {
- certPath := "certs/tls/cert.pem"
- keyPath := "certs/tls/privkey.pem"
- storagePath := "certs/tls/certmagic"
+ return &tlsManager{
+ config: cfg,
+ certPath: filepath.Join(cleanBase, "cert.pem"),
+ keyPath: filepath.Join(cleanBase, "privkey.pem"),
+ storagePath: filepath.Join(cleanBase, "certmagic"),
+ }
+}
- tm := &tlsManager{
- config: config,
- certPath: certPath,
- keyPath: keyPath,
- storagePath: storagePath,
- }
+func (tm *tlsManager) initialize() error {
+ if tm.userCertsExistAndValid() {
+ return tm.initializeWithUserCerts()
+ }
+ return tm.initializeWithCertMagic()
+}
- if tm.userCertsExistAndValid() {
- log.Printf("Using user-provided certificates from %s and %s", certPath, keyPath)
- if err := tm.loadUserCerts(); err != nil {
- initErr = fmt.Errorf("failed to load user certificates: %w", err)
- return
- }
- tm.useCertMagic = false
- tm.startCertWatcher()
- } else {
- log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain())
- if err := tm.initCertMagic(); err != nil {
- initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
- return
- }
- tm.useCertMagic = true
- }
+func (tm *tlsManager) initializeWithUserCerts() error {
+ log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath)
- globalTLSManager = tm
- })
-
- if initErr != nil {
- return nil, initErr
+ if err := tm.loadUserCerts(); err != nil {
+ return fmt.Errorf("failed to load user certificates: %w", err)
}
- return globalTLSManager.getTLSConfig(), nil
+ tm.useCertMagic = false
+ tm.startCertWatcher()
+ return nil
+}
+
+func (tm *tlsManager) initializeWithCertMagic() error {
+ log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic",
+ tm.config.Domain(), tm.config.Domain())
+
+ if err := tm.initCertMagic(); err != nil {
+ return fmt.Errorf("failed to initialize CertMagic: %w", err)
+ }
+
+ tm.useCertMagic = true
+ return nil
}
func (tm *tlsManager) userCertsExistAndValid() bool {
+ if !tm.certFilesExist() {
+ return false
+ }
+ return validateCertDomains(tm.certPath, tm.config.Domain())
+}
+
+func (tm *tlsManager) certFilesExist() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath)
return false
@@ -94,66 +112,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
log.Printf("Key file not found: %s", tm.keyPath)
return false
}
-
- return ValidateCertDomains(tm.certPath, tm.config.Domain())
-}
-
-func ValidateCertDomains(certPath, domain string) bool {
- certPEM, err := os.ReadFile(certPath)
- if err != nil {
- log.Printf("Failed to read certificate: %v", err)
- return false
- }
-
- block, _ := pem.Decode(certPEM)
- if block == nil {
- log.Printf("Failed to decode PEM block from certificate")
- return false
- }
-
- cert, err := x509.ParseCertificate(block.Bytes)
- if err != nil {
- log.Printf("Failed to parse certificate: %v", err)
- return false
- }
-
- if time.Now().After(cert.NotAfter) {
- log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
- return false
- }
-
- if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) {
- log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
- return false
- }
-
- var certDomains []string
- if cert.Subject.CommonName != "" {
- certDomains = append(certDomains, cert.Subject.CommonName)
- }
- certDomains = append(certDomains, cert.DNSNames...)
-
- hasBase := false
- hasWildcard := false
- wildcardDomain := "*." + domain
-
- for _, d := range certDomains {
- if d == domain {
- hasBase = true
- }
- if d == wildcardDomain {
- hasWildcard = true
- }
- }
-
- if !hasBase {
- log.Printf("Certificate does not cover base domain: %s", domain)
- }
- if !hasWildcard {
- log.Printf("Certificate does not cover wildcard domain: %s", wildcardDomain)
- }
-
- return hasBase && hasWildcard
+ return true
}
func (tm *tlsManager) loadUserCerts() error {
@@ -172,62 +131,34 @@ func (tm *tlsManager) loadUserCerts() error {
func (tm *tlsManager) startCertWatcher() {
go func() {
- var lastCertMod, lastKeyMod time.Time
-
- if info, err := os.Stat(tm.certPath); err == nil {
- lastCertMod = info.ModTime()
- }
- if info, err := os.Stat(tm.keyPath); err == nil {
- lastKeyMod = info.ModTime()
- }
-
- ticker := time.NewTicker(30 * time.Second)
- defer ticker.Stop()
-
- for range ticker.C {
- certInfo, certErr := os.Stat(tm.certPath)
- keyInfo, keyErr := os.Stat(tm.keyPath)
-
- if certErr != nil || keyErr != nil {
- continue
- }
-
- if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
- log.Printf("Certificate files changed, reloading...")
-
- if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
- log.Printf("New certificates don't cover required domains")
-
- if err := tm.initCertMagic(); err != nil {
- log.Printf("Failed to initialize CertMagic: %v", err)
- continue
- }
- tm.useCertMagic = true
- return
- }
-
- if err := tm.loadUserCerts(); err != nil {
- log.Printf("Failed to reload certificates: %v", err)
- continue
- }
-
- lastCertMod = certInfo.ModTime()
- lastKeyMod = keyInfo.ModTime()
- log.Printf("Certificates reloaded successfully")
- }
- }
+ watcher := newCertWatcher(tm)
+ watcher.watch()
}()
}
func (tm *tlsManager) initCertMagic() error {
- if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
- return fmt.Errorf("failed to create cert storage directory: %w", err)
+ if err := tm.createStorageDirectory(); err != nil {
+ return err
}
if tm.config.CFAPIToken() == "" {
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
}
+ magic := tm.createCertMagicConfig()
+ tm.magic = magic
+
+ return tm.obtainCertificates(magic)
+}
+
+func (tm *tlsManager) createStorageDirectory() error {
+ if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
+ return fmt.Errorf("failed to create cert storage directory: %w", err)
+ }
+ return nil
+}
+
+func (tm *tlsManager) createCertMagicConfig() *certmagic.Config {
cfProvider := &cloudflare.Provider{
APIToken: tm.config.CFAPIToken(),
}
@@ -244,6 +175,13 @@ func (tm *tlsManager) initCertMagic() error {
Storage: storage,
})
+ acmeIssuer := tm.createACMEIssuer(magic, cfProvider)
+ magic.Issuers = []certmagic.Issuer{acmeIssuer}
+
+ return magic
+}
+
+func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer {
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
Email: tm.config.ACMEEmail(),
Agreed: true,
@@ -262,9 +200,10 @@ func (tm *tlsManager) initCertMagic() error {
log.Printf("Using Let's Encrypt production server")
}
- magic.Issuers = []certmagic.Issuer{acmeIssuer}
- tm.magic = magic
+ return acmeIssuer
+}
+func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error {
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
log.Printf("Requesting certificates for: %v", domains)
@@ -307,3 +246,190 @@ func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certifica
return tm.userCert, nil
}
+
+func validateCertDomains(certPath, domain string) bool {
+ cert, err := loadAndParseCertificate(certPath)
+ if err != nil {
+ return false
+ }
+
+ if !isCertificateValid(cert) {
+ return false
+ }
+
+ return certCoversRequiredDomains(cert, domain)
+}
+
+func loadAndParseCertificate(certPath string) (*x509.Certificate, error) {
+ certPEM, err := os.ReadFile(certPath)
+ if err != nil {
+ log.Printf("Failed to read certificate: %v", err)
+ return nil, err
+ }
+
+ block, _ := pem.Decode(certPEM)
+ if block == nil {
+ log.Printf("Failed to decode PEM block from certificate")
+ return nil, fmt.Errorf("failed to decode PEM block")
+ }
+
+ cert, err := x509.ParseCertificate(block.Bytes)
+ if err != nil {
+ log.Printf("Failed to parse certificate: %v", err)
+ return nil, err
+ }
+
+ return cert, nil
+}
+
+func isCertificateValid(cert *x509.Certificate) bool {
+ now := time.Now()
+
+ if now.After(cert.NotAfter) {
+ log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
+ return false
+ }
+
+ thirtyDaysFromNow := now.Add(30 * 24 * time.Hour)
+ if thirtyDaysFromNow.After(cert.NotAfter) {
+ log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
+ return false
+ }
+
+ return true
+}
+
+func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool {
+ certDomains := extractCertDomains(cert)
+ hasBase, hasWildcard := checkDomainCoverage(certDomains, domain)
+
+ logDomainCoverage(hasBase, hasWildcard, domain)
+ return hasBase && hasWildcard
+}
+
+func extractCertDomains(cert *x509.Certificate) []string {
+ var domains []string
+ if cert.Subject.CommonName != "" {
+ domains = append(domains, cert.Subject.CommonName)
+ }
+ domains = append(domains, cert.DNSNames...)
+ return domains
+}
+
+func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) {
+ wildcardDomain := "*." + domain
+
+ for _, d := range certDomains {
+ if d == domain {
+ hasBase = true
+ }
+ if d == wildcardDomain {
+ hasWildcard = true
+ }
+ }
+
+ return hasBase, hasWildcard
+}
+
+func logDomainCoverage(hasBase, hasWildcard bool, domain string) {
+ if !hasBase {
+ log.Printf("Certificate does not cover base domain: %s", domain)
+ }
+ if !hasWildcard {
+ log.Printf("Certificate does not cover wildcard domain: *.%s", domain)
+ }
+}
+
+type certWatcher struct {
+ tm *tlsManager
+ lastCertMod time.Time
+ lastKeyMod time.Time
+}
+
+func newCertWatcher(tm *tlsManager) *certWatcher {
+ watcher := &certWatcher{tm: tm}
+ watcher.initializeModTimes()
+ return watcher
+}
+
+func (cw *certWatcher) initializeModTimes() {
+ if info, err := os.Stat(cw.tm.certPath); err == nil {
+ cw.lastCertMod = info.ModTime()
+ }
+ if info, err := os.Stat(cw.tm.keyPath); err == nil {
+ cw.lastKeyMod = info.ModTime()
+ }
+}
+
+func (cw *certWatcher) watch() {
+ ticker := time.NewTicker(30 * time.Second)
+ defer ticker.Stop()
+
+ for range ticker.C {
+ if cw.checkAndReloadCerts() {
+ return
+ }
+ }
+}
+
+func (cw *certWatcher) checkAndReloadCerts() bool {
+ certInfo, keyInfo, err := cw.getFileInfo()
+ if err != nil {
+ return false
+ }
+
+ if !cw.filesModified(certInfo, keyInfo) {
+ return false
+ }
+
+ return cw.handleCertificateChange(certInfo, keyInfo)
+}
+
+func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) {
+ certInfo, certErr := os.Stat(cw.tm.certPath)
+ keyInfo, keyErr := os.Stat(cw.tm.keyPath)
+
+ if certErr != nil || keyErr != nil {
+ return nil, nil, fmt.Errorf("file stat error")
+ }
+
+ return certInfo, keyInfo, nil
+}
+
+func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool {
+ return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod)
+}
+
+func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool {
+ log.Printf("Certificate files changed, reloading...")
+
+ if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) {
+ return cw.switchToCertMagic()
+ }
+
+ if err := cw.tm.loadUserCerts(); err != nil {
+ log.Printf("Failed to reload certificates: %v", err)
+ return false
+ }
+
+ cw.updateModTimes(certInfo, keyInfo)
+ log.Printf("Certificates reloaded successfully")
+ return false
+}
+
+func (cw *certWatcher) switchToCertMagic() bool {
+ log.Printf("New certificates don't cover required domains")
+
+ if err := cw.tm.initCertMagic(); err != nil {
+ log.Printf("Failed to initialize CertMagic: %v", err)
+ return false
+ }
+
+ cw.tm.useCertMagic = true
+ return true
+}
+
+func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) {
+ cw.lastCertMod = certInfo.ModTime()
+ cw.lastKeyMod = keyInfo.ModTime()
+}
diff --git a/internal/transport/tls_test.go b/internal/transport/tls_test.go
new file mode 100644
index 0000000..0c5510c
--- /dev/null
+++ b/internal/transport/tls_test.go
@@ -0,0 +1,1246 @@
+package transport
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "math/big"
+ "os"
+ "path/filepath"
+ "sync"
+ "testing"
+ "time"
+ "tunnel_pls/internal/config"
+ "tunnel_pls/types"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+)
+
+type MockConfig struct {
+ mock.Mock
+}
+
+func (m *MockConfig) Domain() string { return m.Called().String(0) }
+func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
+func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
+func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
+func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
+func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
+func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
+func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
+func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
+func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
+func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
+func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
+func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
+func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
+func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
+func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
+func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
+func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
+func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
+func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
+func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
+
+func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, soon bool) (string, string) {
+ t.Helper()
+
+ priv, err := rsa.GenerateKey(rand.Reader, 2048)
+ assert.NoError(t, err)
+
+ notAfter := time.Now().Add(365 * 24 * time.Hour)
+ if expired {
+ notAfter = time.Now().Add(-24 * time.Hour)
+ } else if soon {
+ notAfter = time.Now().Add(15 * 24 * time.Hour)
+ }
+
+ template := x509.Certificate{
+ SerialNumber: big.NewInt(1),
+ Subject: pkix.Name{
+ CommonName: domain,
+ },
+ NotBefore: time.Now().Add(-24 * time.Hour),
+ NotAfter: notAfter,
+ DNSNames: []string{domain},
+ }
+
+ if wildcard {
+ template.DNSNames = append(template.DNSNames, "*."+domain)
+ }
+
+ derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
+ assert.NoError(t, err)
+
+ certOut, err := os.CreateTemp("", "cert*.pem")
+ assert.NoError(t, err)
+ err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+ assert.NoError(t, err)
+ err = certOut.Close()
+ assert.NoError(t, err)
+
+ keyOut, err := os.CreateTemp("", "key*.pem")
+ assert.NoError(t, err)
+ err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
+ assert.NoError(t, err)
+ err = keyOut.Close()
+ assert.NoError(t, err)
+
+ return certOut.Name(), keyOut.Name()
+}
+
+func setupTestDir(t *testing.T) string {
+ t.Helper()
+
+ tmpDir, err := os.MkdirTemp("", "tls-test-*")
+ assert.NoError(t, err)
+
+ t.Cleanup(func() {
+ err = os.RemoveAll(tmpDir)
+ assert.NoError(t, err)
+ })
+
+ return tmpDir
+}
+
+func TestValidateCertDomains(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) (certPath string, cleanup func())
+ domain string
+ expected bool
+ }{
+ {
+ name: "file not found",
+ setup: func(t *testing.T) (string, func()) {
+ return "nonexistent.pem", func() {}
+ },
+ domain: "example.com",
+ expected: false,
+ },
+ {
+ name: "invalid PEM",
+ setup: func(t *testing.T) (string, func()) {
+ tmpFile, err := os.CreateTemp("", "invalid*.pem")
+ assert.NoError(t, err)
+ _, err = tmpFile.WriteString("not a pem")
+ assert.NoError(t, err)
+ err = tmpFile.Close()
+ assert.NoError(t, err)
+ return tmpFile.Name(), func() {
+ _ = os.Remove(tmpFile.Name())
+ }
+ },
+ domain: "example.com",
+ expected: false,
+ },
+ {
+ name: "valid cert with wildcard",
+ setup: func(t *testing.T) (string, func()) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ return certPath, func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ }
+ },
+ domain: "example.com",
+ expected: true,
+ },
+ {
+ name: "expired cert",
+ setup: func(t *testing.T) (string, func()) {
+ certPath, keyPath := createTestCert(t, "example.com", true, true, false)
+ return certPath, func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ }
+ },
+ domain: "example.com",
+ expected: false,
+ },
+ {
+ name: "cert expiring soon",
+ setup: func(t *testing.T) (string, func()) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, true)
+ return certPath, func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ }
+ },
+ domain: "example.com",
+ expected: false,
+ },
+ {
+ name: "missing wildcard",
+ setup: func(t *testing.T) (string, func()) {
+ certPath, keyPath := createTestCert(t, "example.com", false, false, false)
+ return certPath, func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ }
+ },
+ domain: "example.com",
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ certPath, cleanup := tt.setup(t)
+ defer cleanup()
+
+ result := validateCertDomains(certPath, tt.domain)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestLoadAndParseCertificate(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) (certPath string, cleanup func())
+ wantError bool
+ validate func(t *testing.T, cert *x509.Certificate)
+ }{
+ {
+ name: "success",
+ setup: func(t *testing.T) (string, func()) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ return certPath, func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ }
+ },
+ wantError: false,
+ validate: func(t *testing.T, cert *x509.Certificate) {
+ assert.Equal(t, "example.com", cert.Subject.CommonName)
+ },
+ },
+ {
+ name: "file not found",
+ setup: func(t *testing.T) (string, func()) {
+ return "nonexistent.pem", func() {}
+ },
+ wantError: true,
+ validate: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ certPath, cleanup := tt.setup(t)
+ defer cleanup()
+
+ cert, err := loadAndParseCertificate(certPath)
+
+ if tt.wantError {
+ assert.Error(t, err)
+ assert.Nil(t, cert)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, cert)
+ if tt.validate != nil {
+ tt.validate(t, cert)
+ }
+ }
+ })
+ }
+}
+
+func TestIsCertificateValid(t *testing.T) {
+ tests := []struct {
+ name string
+ expired bool
+ soon bool
+ expected bool
+ }{
+ {
+ name: "valid certificate",
+ expired: false,
+ soon: false,
+ expected: true,
+ },
+ {
+ name: "expired certificate",
+ expired: true,
+ soon: false,
+ expected: false,
+ },
+ {
+ name: "expiring soon",
+ expired: false,
+ soon: true,
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ certPath, keyPath := createTestCert(t, "example.com", true, tt.expired, tt.soon)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(certPath)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(keyPath)
+
+ cert, err := loadAndParseCertificate(certPath)
+ assert.NoError(t, err)
+
+ result := isCertificateValid(cert)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestExtractCertDomains(t *testing.T) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(certPath)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(keyPath)
+
+ cert, err := loadAndParseCertificate(certPath)
+ assert.NoError(t, err)
+
+ domains := extractCertDomains(cert)
+ assert.Contains(t, domains, "example.com")
+ assert.Contains(t, domains, "*.example.com")
+}
+
+func TestCheckDomainCoverage(t *testing.T) {
+ tests := []struct {
+ name string
+ certDomains []string
+ domain string
+ wantBase bool
+ wantWildcard bool
+ }{
+ {
+ name: "both covered",
+ certDomains: []string{"example.com", "*.example.com"},
+ domain: "example.com",
+ wantBase: true,
+ wantWildcard: true,
+ },
+ {
+ name: "only base",
+ certDomains: []string{"example.com"},
+ domain: "example.com",
+ wantBase: true,
+ wantWildcard: false,
+ },
+ {
+ name: "only wildcard",
+ certDomains: []string{"*.example.com"},
+ domain: "example.com",
+ wantBase: false,
+ wantWildcard: true,
+ },
+ {
+ name: "neither",
+ certDomains: []string{"other.com"},
+ domain: "example.com",
+ wantBase: false,
+ wantWildcard: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ hasBase, hasWildcard := checkDomainCoverage(tt.certDomains, tt.domain)
+ assert.Equal(t, tt.wantBase, hasBase)
+ assert.Equal(t, tt.wantWildcard, hasWildcard)
+ })
+ }
+}
+
+func TestTLSManager_getTLSConfig(t *testing.T) {
+ tm := &tlsManager{
+ useCertMagic: false,
+ }
+ cfg := tm.getTLSConfig()
+ assert.NotNil(t, cfg)
+ assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
+ assert.Equal(t, uint16(tls.VersionTLS13), cfg.MaxVersion)
+ assert.NotNil(t, cfg.GetCertificate)
+}
+
+func TestTLSManager_getCertificate(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) *tlsManager
+ wantError bool
+ errorContains string
+ }{
+ {
+ name: "no certificate available",
+ setup: func(t *testing.T) *tlsManager {
+ return &tlsManager{
+ useCertMagic: false,
+ userCert: nil,
+ }
+ },
+ wantError: true,
+ errorContains: "no certificate available",
+ },
+ {
+ name: "with user certificate",
+ setup: func(t *testing.T) *tlsManager {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ cert, err := tls.LoadX509KeyPair(certPath, keyPath)
+ assert.NoError(t, err)
+
+ return &tlsManager{
+ useCertMagic: false,
+ userCert: &cert,
+ }
+ },
+ wantError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tm := tt.setup(t)
+ hello := &tls.ClientHelloInfo{
+ ServerName: "example.com",
+ }
+
+ cert, err := tm.getCertificate(hello)
+
+ if tt.wantError {
+ assert.Error(t, err)
+ assert.Nil(t, cert)
+ if tt.errorContains != "" {
+ assert.Contains(t, err.Error(), tt.errorContains)
+ }
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, cert)
+ }
+ })
+ }
+}
+
+func TestTLSManager_userCertsExistAndValid(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) *tlsManager
+ expected bool
+ }{
+ {
+ name: "no files",
+ setup: func(t *testing.T) *tlsManager {
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+
+ return &tlsManager{
+ config: mockCfg,
+ certPath: "nonexistent.pem",
+ keyPath: "nonexistent.key",
+ }
+ },
+ expected: false,
+ },
+ {
+ name: "missing key file",
+ setup: func(t *testing.T) *tlsManager {
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ t.Cleanup(func() { _ = os.Remove(certPath) })
+ err := os.Remove(keyPath)
+ assert.NoError(t, err)
+
+ return &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+ },
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tm := tt.setup(t)
+ result := tm.userCertsExistAndValid()
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestTLSManager_certFilesExist(t *testing.T) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(certPath)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(keyPath)
+
+ tm := &tlsManager{
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+
+ result := tm.certFilesExist()
+ assert.True(t, result)
+}
+
+func TestTLSManager_loadUserCerts(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) *tlsManager
+ wantError bool
+ }{
+ {
+ name: "success",
+ setup: func(t *testing.T) *tlsManager {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ return &tlsManager{
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+ },
+ wantError: false,
+ },
+ {
+ name: "invalid path",
+ setup: func(t *testing.T) *tlsManager {
+ return &tlsManager{
+ certPath: "nonexistent.pem",
+ keyPath: "nonexistent.key",
+ }
+ },
+ wantError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tm := tt.setup(t)
+ err := tm.loadUserCerts()
+
+ if tt.wantError {
+ assert.Error(t, err)
+ assert.Nil(t, tm.userCert)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, tm.userCert)
+ }
+ })
+ }
+}
+
+func TestCreateTLSManager(t *testing.T) {
+ tmpDir := setupTestDir(t)
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("TLSStoragePath").Return(tmpDir)
+
+ tm := createTLSManager(mockCfg)
+
+ assert.NotNil(t, tm)
+ assert.Equal(t, mockCfg, tm.config)
+ assert.Equal(t, filepath.Join(tmpDir, "cert.pem"), tm.certPath)
+ assert.Equal(t, filepath.Join(tmpDir, "privkey.pem"), tm.keyPath)
+ assert.Equal(t, filepath.Join(tmpDir, "certmagic"), tm.storagePath)
+}
+
+func TestNewCertWatcher(t *testing.T) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(certPath)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(keyPath)
+
+ mockCfg := &MockConfig{}
+
+ tm := &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+
+ watcher := newCertWatcher(tm)
+
+ assert.NotNil(t, watcher)
+ assert.Equal(t, tm, watcher.tm)
+ assert.False(t, watcher.lastCertMod.IsZero())
+ assert.False(t, watcher.lastKeyMod.IsZero())
+}
+
+func TestCertWatcher_filesModified(t *testing.T) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(certPath)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(keyPath)
+
+ mockCfg := &MockConfig{}
+
+ tm := &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+
+ watcher := newCertWatcher(tm)
+
+ certInfo, err := os.Stat(certPath)
+ assert.NoError(t, err)
+ keyInfo, err := os.Stat(keyPath)
+ assert.NoError(t, err)
+
+ result := watcher.filesModified(certInfo, keyInfo)
+ assert.False(t, result)
+
+ watcher.lastCertMod = time.Now().Add(-1 * time.Hour)
+
+ result = watcher.filesModified(certInfo, keyInfo)
+ assert.True(t, result)
+}
+
+func TestCertWatcher_updateModTimes(t *testing.T) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(certPath)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(keyPath)
+
+ mockCfg := &MockConfig{}
+
+ tm := &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+
+ watcher := newCertWatcher(tm)
+
+ certInfo, err := os.Stat(certPath)
+ assert.NoError(t, err)
+ keyInfo, err := os.Stat(keyPath)
+ assert.NoError(t, err)
+
+ time.Sleep(10 * time.Millisecond)
+ watcher.updateModTimes(certInfo, keyInfo)
+
+ assert.Equal(t, certInfo.ModTime(), watcher.lastCertMod)
+ assert.Equal(t, keyInfo.ModTime(), watcher.lastKeyMod)
+}
+
+func TestCertWatcher_getFileInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) *tlsManager
+ wantError bool
+ validate func(t *testing.T, certInfo, keyInfo os.FileInfo)
+ }{
+ {
+ name: "success",
+ setup: func(t *testing.T) *tlsManager {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ return &tlsManager{
+ config: &MockConfig{},
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+ },
+ wantError: false,
+ validate: func(t *testing.T, certInfo, keyInfo os.FileInfo) {
+ assert.NotNil(t, certInfo)
+ assert.NotNil(t, keyInfo)
+ },
+ },
+ {
+ name: "missing cert file",
+ setup: func(t *testing.T) *tlsManager {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ err := os.Remove(certPath)
+ assert.NoError(t, err)
+ t.Cleanup(func() { _ = os.Remove(keyPath) })
+
+ return &tlsManager{
+ config: &MockConfig{},
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+ },
+ wantError: true,
+ },
+ {
+ name: "missing key file",
+ setup: func(t *testing.T) *tlsManager {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ err := os.Remove(keyPath)
+ assert.NoError(t, err)
+ t.Cleanup(func() { _ = os.Remove(certPath) })
+
+ return &tlsManager{
+ config: &MockConfig{},
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+ },
+ wantError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tm := tt.setup(t)
+ watcher := newCertWatcher(tm)
+
+ certInfo, keyInfo, err := watcher.getFileInfo()
+
+ if tt.wantError {
+ assert.Error(t, err)
+ assert.Nil(t, certInfo)
+ assert.Nil(t, keyInfo)
+ } else {
+ assert.NoError(t, err)
+ if tt.validate != nil {
+ tt.validate(t, certInfo, keyInfo)
+ }
+ }
+ })
+ }
+}
+
+func TestCertWatcher_checkAndReloadCerts(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) (*tlsManager, *certWatcher)
+ expected bool
+ }{
+ {
+ name: "file error",
+ setup: func(t *testing.T) (*tlsManager, *certWatcher) {
+ tm := &tlsManager{
+ config: &MockConfig{},
+ certPath: "nonexistent.pem",
+ keyPath: "nonexistent.key",
+ }
+ return tm, newCertWatcher(tm)
+ },
+ expected: false,
+ },
+ {
+ name: "no modification",
+ setup: func(t *testing.T) (*tlsManager, *certWatcher) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ tm := &tlsManager{
+ config: &MockConfig{},
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+ return tm, newCertWatcher(tm)
+ },
+ expected: false,
+ },
+ {
+ name: "with modification",
+ setup: func(t *testing.T) (*tlsManager, *certWatcher) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+
+ tm := &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+
+ err := tm.loadUserCerts()
+ assert.NoError(t, err)
+
+ watcher := newCertWatcher(tm)
+ watcher.lastCertMod = time.Now().Add(-1 * time.Hour)
+
+ return tm, watcher
+ },
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, watcher := tt.setup(t)
+ result := watcher.checkAndReloadCerts()
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestCertWatcher_handleCertificateChange(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo)
+ expected bool
+ }{
+ {
+ name: "successful reload",
+ setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+
+ tm := &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+
+ watcher := newCertWatcher(tm)
+
+ certInfo, _ := os.Stat(certPath)
+ keyInfo, _ := os.Stat(keyPath)
+
+ return tm, watcher, certInfo, keyInfo
+ },
+ expected: false,
+ },
+ {
+ name: "invalid cert triggers certmagic",
+ setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
+ certPath, keyPath := createTestCert(t, "example.com", false, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ tmpDir := setupTestDir(t)
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+ mockCfg.On("CFAPIToken").Return("")
+
+ tm := &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: keyPath,
+ storagePath: tmpDir,
+ }
+
+ watcher := newCertWatcher(tm)
+
+ certInfo, _ := os.Stat(certPath)
+ keyInfo, _ := os.Stat(keyPath)
+
+ return tm, watcher, certInfo, keyInfo
+ },
+ expected: false,
+ },
+ {
+ name: "load error",
+ setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+
+ tm := &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: "nonexistent.key",
+ }
+
+ watcher := newCertWatcher(tm)
+
+ certInfo, _ := os.Stat(certPath)
+ keyInfo, _ := os.Stat(keyPath)
+
+ return tm, watcher, certInfo, keyInfo
+ },
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, watcher, certInfo, keyInfo := tt.setup(t)
+ result := watcher.handleCertificateChange(certInfo, keyInfo)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestCertWatcher_switchToCertMagic(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) *tlsManager
+ expected bool
+ }{
+ {
+ name: "with staging token",
+ setup: func(t *testing.T) *tlsManager {
+ tmpDir := setupTestDir(t)
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+ mockCfg.On("CFAPIToken").Return("test-token")
+ mockCfg.On("ACMEEmail").Return("test@example.com")
+ mockCfg.On("ACMEStaging").Return(true)
+
+ return &tlsManager{
+ config: mockCfg,
+ storagePath: tmpDir,
+ }
+ },
+ expected: false,
+ },
+ {
+ name: "missing token",
+ setup: func(t *testing.T) *tlsManager {
+ tmpDir := setupTestDir(t)
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+ mockCfg.On("CFAPIToken").Return("")
+
+ return &tlsManager{
+ config: mockCfg,
+ storagePath: tmpDir,
+ }
+ },
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tm := tt.setup(t)
+ watcher := newCertWatcher(tm)
+ result := watcher.switchToCertMagic()
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestCertWatcher_watch(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) (*tlsManager, *certWatcher)
+ expected bool
+ }{
+ {
+ name: "exits on certmagic switch attempt",
+ setup: func(t *testing.T) (*tlsManager, *certWatcher) {
+ certPath, keyPath := createTestCert(t, "example.com", false, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ tmpDir := setupTestDir(t)
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+ mockCfg.On("CFAPIToken").Return("")
+
+ tm := &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: keyPath,
+ storagePath: tmpDir,
+ }
+
+ watcher := newCertWatcher(tm)
+ watcher.lastCertMod = time.Now().Add(-1 * time.Hour)
+ watcher.lastKeyMod = time.Now().Add(-1 * time.Hour)
+
+ return tm, watcher
+ },
+ expected: false,
+ },
+ {
+ name: "continues on no modification",
+ setup: func(t *testing.T) (*tlsManager, *certWatcher) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+
+ tm := &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+
+ return tm, newCertWatcher(tm)
+ },
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, watcher := tt.setup(t)
+ result := watcher.checkAndReloadCerts()
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestCertWatcher_watch_Integration(t *testing.T) {
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(certPath)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(keyPath)
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("Domain").Return("example.com")
+
+ tm := &tlsManager{
+ config: mockCfg,
+ certPath: certPath,
+ keyPath: keyPath,
+ }
+
+ err := tm.loadUserCerts()
+ assert.NoError(t, err)
+ initialCert := tm.userCert
+
+ watcher := newCertWatcher(tm)
+
+ go watcher.watch()
+
+ time.Sleep(50 * time.Millisecond)
+
+ newCertPath, newKeyPath := createTestCert(t, "example.com", true, false, false)
+ defer func(name string) {
+ err = os.Remove(name)
+ assert.NoError(t, err)
+ }(newCertPath)
+ defer func(name string) {
+ err = os.Remove(name)
+ assert.NoError(t, err)
+ }(newKeyPath)
+
+ newCertData, err := os.ReadFile(newCertPath)
+ assert.NoError(t, err)
+ newKeyData, err := os.ReadFile(newKeyPath)
+ assert.NoError(t, err)
+
+ err = os.WriteFile(certPath, newCertData, 0644)
+ assert.NoError(t, err)
+ err = os.WriteFile(keyPath, newKeyData, 0644)
+ assert.NoError(t, err)
+
+ time.Sleep(100 * time.Millisecond)
+
+ assert.NotNil(t, tm.userCert)
+ assert.Equal(t, initialCert, tm.userCert)
+}
+
+func TestNewTLSConfig(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func(t *testing.T) config.Config
+ wantError bool
+ errorMsg string
+ validate func(t *testing.T, cfg *tls.Config)
+ }{
+ {
+ name: "with valid user certs",
+ setup: func(t *testing.T) config.Config {
+ globalTLSManager = nil
+ tlsManagerOnce = sync.Once{}
+
+ tmpDir := setupTestDir(t)
+
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ t.Cleanup(func() {
+ _ = os.Remove(certPath)
+ _ = os.Remove(keyPath)
+ })
+
+ certData, err := os.ReadFile(certPath)
+ assert.NoError(t, err)
+ keyData, err := os.ReadFile(keyPath)
+ assert.NoError(t, err)
+
+ err = os.WriteFile(filepath.Join(tmpDir, "cert.pem"), certData, 0644)
+ assert.NoError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir, "privkey.pem"), keyData, 0644)
+ assert.NoError(t, err)
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("TLSStoragePath").Return(tmpDir)
+ mockCfg.On("Domain").Return("example.com")
+
+ return mockCfg
+ },
+ wantError: false,
+ validate: func(t *testing.T, cfg *tls.Config) {
+ assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
+ assert.NotNil(t, cfg.GetCertificate)
+ },
+ },
+ {
+ name: "missing certs requires certmagic",
+ setup: func(t *testing.T) config.Config {
+ globalTLSManager = nil
+ tlsManagerOnce = sync.Once{}
+
+ tmpDir := setupTestDir(t)
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("TLSStoragePath").Return(tmpDir)
+ mockCfg.On("Domain").Return("example.com")
+ mockCfg.On("CFAPIToken").Return("")
+
+ return mockCfg
+ },
+ wantError: true,
+ errorMsg: "CF_API_TOKEN",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := tt.setup(t)
+ tlsConfig, err := NewTLSConfig(cfg)
+
+ if tt.wantError {
+ assert.Error(t, err)
+ if tt.errorMsg != "" {
+ assert.Contains(t, err.Error(), tt.errorMsg)
+ }
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, tlsConfig)
+ if tt.validate != nil {
+ tt.validate(t, tlsConfig)
+ }
+ }
+ })
+ }
+}
+
+func TestNewTLSConfig_Singleton(t *testing.T) {
+ globalTLSManager = nil
+ tlsManagerOnce = sync.Once{}
+
+ tmpDir := setupTestDir(t)
+
+ certPath, keyPath := createTestCert(t, "example.com", true, false, false)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(certPath)
+ defer func(name string) {
+ err := os.Remove(name)
+ assert.NoError(t, err)
+ }(keyPath)
+
+ certData, err := os.ReadFile(certPath)
+ assert.NoError(t, err)
+ keyData, err := os.ReadFile(keyPath)
+ assert.NoError(t, err)
+
+ err = os.WriteFile(filepath.Join(tmpDir, "cert.pem"), certData, 0644)
+ assert.NoError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir, "privkey.pem"), keyData, 0644)
+ assert.NoError(t, err)
+
+ mockCfg := &MockConfig{}
+ mockCfg.On("TLSStoragePath").Return(tmpDir)
+ mockCfg.On("Domain").Return("example.com")
+
+ tlsConfig1, err1 := NewTLSConfig(mockCfg)
+ tlsConfig2, err2 := NewTLSConfig(mockCfg)
+
+ assert.NoError(t, err1)
+ assert.NoError(t, err2)
+ assert.NotNil(t, tlsConfig1)
+ assert.NotNil(t, tlsConfig2)
+
+ assert.Equal(t, tlsConfig1.MinVersion, tlsConfig2.MinVersion)
+ assert.Equal(t, tlsConfig1.MaxVersion, tlsConfig2.MaxVersion)
+ assert.Equal(t, tlsConfig1.SessionTicketsDisabled, tlsConfig2.SessionTicketsDisabled)
+ assert.Equal(t, tlsConfig1.ClientAuth, tlsConfig2.ClientAuth)
+
+ hello := &tls.ClientHelloInfo{ServerName: "example.com"}
+ cert1, err1 := tlsConfig1.GetCertificate(hello)
+ cert2, err2 := tlsConfig2.GetCertificate(hello)
+
+ assert.NoError(t, err1)
+ assert.NoError(t, err2)
+ assert.NotNil(t, cert1)
+ assert.NotNil(t, cert2)
+
+ assert.Equal(t, cert1, cert2)
+}
diff --git a/internal/transport/transport.go b/internal/transport/transport.go
index ca27061..31219fd 100644
--- a/internal/transport/transport.go
+++ b/internal/transport/transport.go
@@ -8,3 +8,7 @@ type Transport interface {
Listen() (net.Listener, error)
Serve(listener net.Listener) error
}
+
+type HTTP interface {
+ Handler(conn net.Conn, isTLS bool)
+}
diff --git a/internal/version/version_test.go b/internal/version/version_test.go
new file mode 100644
index 0000000..f4873f5
--- /dev/null
+++ b/internal/version/version_test.go
@@ -0,0 +1,84 @@
+package version
+
+import (
+ "fmt"
+ "testing"
+)
+
+func TestVersionFunctions(t *testing.T) {
+ origVersion := Version
+ origBuildDate := BuildDate
+ origCommit := Commit
+ defer func() {
+ Version = origVersion
+ BuildDate = origBuildDate
+ Commit = origCommit
+ }()
+
+ tests := []struct {
+ name string
+ version string
+ buildDate string
+ commit string
+ wantFull string
+ wantShort string
+ }{
+ {
+ name: "Default dev version",
+ version: "dev",
+ buildDate: "unknown",
+ commit: "unknown",
+ wantFull: "tunnel_pls dev (commit: unknown, built: unknown)",
+ wantShort: "dev",
+ },
+ {
+ name: "Release version",
+ version: "v1.0.0",
+ buildDate: "2026-01-23",
+ commit: "abcdef123",
+ wantFull: "tunnel_pls v1.0.0 (commit: abcdef123, built: 2026-01-23)",
+ wantShort: "v1.0.0",
+ },
+ {
+ name: "Empty values",
+ version: "",
+ buildDate: "",
+ commit: "",
+ wantFull: "tunnel_pls (commit: , built: )",
+ wantShort: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ Version = tt.version
+ BuildDate = tt.buildDate
+ Commit = tt.commit
+
+ gotFull := GetVersion()
+ if gotFull != tt.wantFull {
+ t.Errorf("GetVersion() = %q, want %q", gotFull, tt.wantFull)
+ }
+
+ gotShort := GetShortVersion()
+ if gotShort != tt.wantShort {
+ t.Errorf("GetShortVersion() = %q, want %q", gotShort, tt.wantShort)
+ }
+ })
+ }
+}
+
+func TestGetVersion_Format(t *testing.T) {
+ v := "1.2.3"
+ c := "brainrot"
+ d := "now"
+
+ Version = v
+ Commit = c
+ BuildDate = d
+
+ expected := fmt.Sprintf("tunnel_pls %s (commit: %s, built: %s)", v, c, d)
+ if GetVersion() != expected {
+ t.Errorf("GetVersion() formatting mismatch")
+ }
+}
diff --git a/main.go b/main.go
index f897b46..a908903 100644
--- a/main.go
+++ b/main.go
@@ -1,27 +1,13 @@
package main
import (
- "context"
"fmt"
"log"
- "net"
- "net/http"
- _ "net/http/pprof"
"os"
- "os/signal"
- "syscall"
- "time"
+ "tunnel_pls/internal/bootstrap"
"tunnel_pls/internal/config"
- "tunnel_pls/internal/grpc/client"
- "tunnel_pls/internal/key"
"tunnel_pls/internal/port"
- "tunnel_pls/internal/registry"
- "tunnel_pls/internal/transport"
"tunnel_pls/internal/version"
- "tunnel_pls/server"
- "tunnel_pls/types"
-
- "golang.org/x/crypto/ssh"
)
func main() {
@@ -32,148 +18,19 @@ func main() {
log.SetOutput(os.Stdout)
log.SetFlags(log.LstdFlags | log.Lshortfile)
-
log.Printf("Starting %s", version.GetVersion())
conf, err := config.MustLoad()
if err != nil {
- log.Fatalf("Failed to load configuration: %s", err)
- return
+ log.Fatalf("Config load error: %v", err)
}
- sshConfig := &ssh.ServerConfig{
- NoClientAuth: true,
- ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
- }
-
- sshKeyPath := "certs/ssh/id_rsa"
- if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
- log.Fatalf("Failed to generate SSH key: %s", err)
- }
-
- privateBytes, err := os.ReadFile(sshKeyPath)
+ boot, err := bootstrap.New(conf, port.New())
if err != nil {
- log.Fatalf("Failed to load private key: %s", err)
+ log.Fatalf("Startup error: %v", err)
}
- private, err := ssh.ParsePrivateKey(privateBytes)
- if err != nil {
- log.Fatalf("Failed to parse private key: %s", err)
- }
-
- sshConfig.AddHostKey(private)
- sessionRegistry := registry.NewRegistry()
-
- ctx, cancel := context.WithCancel(context.Background())
- defer cancel()
-
- errChan := make(chan error, 2)
- shutdownChan := make(chan os.Signal, 1)
- signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
-
- var grpcClient client.Client
-
- if conf.Mode() == types.ServerModeNODE {
- grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort())
-
- grpcClient, err = client.New(conf, grpcAddr, sessionRegistry)
- if err != nil {
- log.Fatalf("failed to create grpc client: %v", err)
- }
-
- healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
- if err = grpcClient.CheckServerHealth(healthCtx); err != nil {
- healthCancel()
- log.Fatalf("gRPC health check failed: %v", err)
- }
- healthCancel()
-
- go func() {
- if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
- errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
- }
- }()
- }
-
- go func() {
- var httpListener net.Listener
- httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect())
- httpListener, err = httpserver.Listen()
- if err != nil {
- errChan <- fmt.Errorf("failed to start http server: %w", err)
- return
- }
- err = httpserver.Serve(httpListener)
- if err != nil {
- errChan <- fmt.Errorf("error when serving http server: %w", err)
- return
- }
- }()
-
- if conf.TLSEnabled() {
- go func() {
- var httpsListener net.Listener
- tlsConfig, _ := transport.NewTLSConfig(conf)
- httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig)
- httpsListener, err = httpsServer.Listen()
- if err != nil {
- errChan <- fmt.Errorf("failed to start http server: %w", err)
- return
- }
- err = httpsServer.Serve(httpsListener)
- if err != nil {
- errChan <- fmt.Errorf("error when serving http server: %w", err)
- return
- }
- }()
- }
-
- portManager := port.New()
- err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd())
- if err != nil {
- log.Fatalf("Failed to initialize port manager: %s", err)
- return
- }
-
- var app server.Server
- go func() {
- app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort())
- if err != nil {
- errChan <- fmt.Errorf("failed to start server: %s", err)
- return
- }
- app.Start()
-
- }()
-
- if conf.PprofEnabled() {
- go func() {
- pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort())
- log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
- if err = http.ListenAndServe(pprofAddr, nil); err != nil {
- log.Printf("pprof server error: %v", err)
- }
- }()
- }
-
- select {
- case err = <-errChan:
- log.Printf("error happen : %s", err)
- case sig := <-shutdownChan:
- log.Printf("received signal %s, shutting down", sig)
- }
-
- cancel()
-
- if app != nil {
- if err = app.Close(); err != nil {
- log.Printf("failed to close server : %s", err)
- }
- }
-
- if grpcClient != nil {
- if err = grpcClient.Close(); err != nil {
- log.Printf("failed to close grpc conn : %s", err)
- }
+ if err = boot.Run(); err != nil {
+ log.Fatalf("Application error: %v", err)
}
}
diff --git a/server/server.go b/server/server.go
index f47c579..d3df5fd 100644
--- a/server/server.go
+++ b/server/server.go
@@ -4,12 +4,14 @@ import (
"context"
"errors"
"fmt"
+ "io"
"log"
"net"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/port"
+ "tunnel_pls/internal/random"
"tunnel_pls/internal/registry"
"tunnel_pls/session"
@@ -21,6 +23,7 @@ type Server interface {
Close() error
}
type server struct {
+ randomizer random.Random
config config.Config
sshPort string
sshListener net.Listener
@@ -30,13 +33,14 @@ type server struct {
portRegistry port.Port
}
-func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
+func New(randomizer random.Random, config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
if err != nil {
return nil, err
}
return &server{
+ randomizer: randomizer,
config: config,
sshPort: sshPort,
sshListener: listener,
@@ -82,7 +86,7 @@ func (s *server) handleConnection(conn net.Conn) {
defer func(sshConn *ssh.ServerConn) {
err = sshConn.Close()
- if err != nil && !errors.Is(err, net.ErrClosed) {
+ if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
log.Printf("failed to close SSH server: %v", err)
}
}(sshConn)
@@ -95,11 +99,19 @@ func (s *server) handleConnection(conn net.Conn) {
cancel()
}
log.Println("SSH connection established:", sshConn.User())
- sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
+ sshSession := session.New(&session.Config{
+ Randomizer: s.randomizer,
+ Config: s.config,
+ Conn: sshConn,
+ InitialReq: forwardingReqs,
+ SshChan: chans,
+ SessionRegistry: s.sessionRegistry,
+ PortRegistry: s.portRegistry,
+ User: user,
+ })
err = sshSession.Start()
if err != nil {
- log.Printf("SSH session ended with error: %v", err)
+ log.Printf("SSH session ended with error: %s", err.Error())
return
}
- return
}
diff --git a/server/server_test.go b/server/server_test.go
new file mode 100644
index 0000000..a4d5c74
--- /dev/null
+++ b/server/server_test.go
@@ -0,0 +1,880 @@
+package server
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/rsa"
+ "errors"
+ "fmt"
+ "net"
+ "testing"
+ "time"
+ "tunnel_pls/internal/registry"
+ "tunnel_pls/session/slug"
+ "tunnel_pls/types"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "golang.org/x/crypto/ssh"
+ "google.golang.org/grpc"
+)
+
+type MockRandom struct {
+ mock.Mock
+}
+
+func (m *MockRandom) String(length int) (string, error) {
+ args := m.Called(length)
+ return args.String(0), args.Error(1)
+}
+
+type MockConfig struct {
+ mock.Mock
+}
+
+func (m *MockConfig) Domain() string { return m.Called().String(0) }
+func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
+func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
+func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
+func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
+func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
+func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
+func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
+func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
+func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
+func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
+func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
+func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
+func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
+func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
+func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
+func (m *MockConfig) Mode() types.ServerMode {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return 0
+ }
+ switch v := args.Get(0).(type) {
+ case types.ServerMode:
+ return v
+ case int:
+ return types.ServerMode(v)
+ default:
+ return types.ServerMode(args.Int(0))
+ }
+}
+func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
+func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
+func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
+func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
+
+type MockSessionRegistry struct {
+ mock.Mock
+}
+
+func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
+ args := m.Called(key)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(registry.Session), args.Error(1)
+}
+
+func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
+ args := m.Called(user, key)
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(registry.Session), args.Error(1)
+}
+
+func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
+ args := m.Called(user, oldKey, newKey)
+ return args.Error(0)
+}
+
+func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
+ args := m.Called(key, session)
+ return args.Bool(0)
+}
+
+func (m *MockSessionRegistry) Remove(key registry.Key) {
+ m.Called(key)
+}
+
+func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
+ args := m.Called(user)
+ return args.Get(0).([]registry.Session)
+}
+
+func (m *MockSessionRegistry) Slug() slug.Slug {
+ args := m.Called()
+ return args.Get(0).(slug.Slug)
+}
+
+type MockGRPCClient struct {
+ mock.Mock
+}
+
+func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
+ args := m.Called()
+ return args.Get(0).(*grpc.ClientConn)
+}
+
+func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
+ args := m.Called(ctx, token)
+ return args.Bool(0), args.String(1), args.Error(2)
+}
+
+func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
+ args := m.Called(ctx)
+ return args.Error(0)
+}
+
+func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
+ args := m.Called(ctx, domain, token)
+ return args.Error(0)
+}
+
+func (m *MockGRPCClient) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+type MockPort struct {
+ mock.Mock
+}
+
+func (m *MockPort) AddRange(startPort, endPort uint16) error {
+ return m.Called(startPort, endPort).Error(0)
+}
+
+func (m *MockPort) Unassigned() (uint16, bool) {
+ args := m.Called()
+ return uint16(args.Int(0)), args.Bool(1)
+}
+
+func (m *MockPort) SetStatus(port uint16, assigned bool) error {
+ return m.Called(port, assigned).Error(0)
+}
+
+func (m *MockPort) Claim(port uint16) bool {
+ return m.Called(port).Bool(0)
+}
+
+type MockListener struct {
+ mock.Mock
+}
+
+func (m *MockListener) Accept() (net.Conn, error) {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil, args.Error(1)
+ }
+ return args.Get(0).(net.Conn), args.Error(1)
+}
+
+func (m *MockListener) Close() error {
+ return m.Called().Error(0)
+}
+
+func (m *MockListener) Addr() net.Addr {
+ return m.Called().Get(0).(net.Addr)
+}
+
+func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
+ key, _ := rsa.GenerateKey(rand.Reader, 2048)
+ signer, _ := ssh.NewSignerFromKey(key)
+ config := &ssh.ServerConfig{
+ NoClientAuth: true,
+ }
+ config.AddHostKey(signer)
+ return config, signer
+}
+
+func TestNew(t *testing.T) {
+ mr := new(MockRandom)
+ mc := new(MockConfig)
+ mreg := new(MockSessionRegistry)
+ mg := new(MockGRPCClient)
+ mp := new(MockPort)
+ sc, _ := getTestSSHConfig()
+
+ tests := []struct {
+ name string
+ port string
+ wantErr bool
+ }{
+ {
+ name: "success",
+ port: "0",
+ wantErr: false,
+ },
+ {
+ name: "invalid port",
+ port: "invalid",
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ s, err := New(mr, mc, sc, mreg, mg, mp, tt.port)
+ if tt.wantErr {
+ assert.Error(t, err)
+ assert.Nil(t, s)
+ } else {
+ assert.NoError(t, err)
+ assert.NotNil(t, s)
+ _ = s.Close()
+ }
+ })
+ }
+
+ t.Run("port already in use", func(t *testing.T) {
+ l, err := net.Listen("tcp", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ port := l.Addr().(*net.TCPAddr).Port
+ defer func(l net.Listener) {
+ err = l.Close()
+ assert.NoError(t, err)
+ }(l)
+
+ s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
+ assert.Error(t, err)
+ assert.Nil(t, s)
+ })
+}
+
+func TestClose(t *testing.T) {
+ mr := new(MockRandom)
+ mc := new(MockConfig)
+ mreg := new(MockSessionRegistry)
+ mg := new(MockGRPCClient)
+ mp := new(MockPort)
+ sc, _ := getTestSSHConfig()
+
+ t.Run("successful close", func(t *testing.T) {
+ s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
+ err := s.Close()
+ assert.NoError(t, err)
+ })
+
+ t.Run("close already closed listener", func(t *testing.T) {
+ s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
+ _ = s.Close()
+ err := s.Close()
+ assert.Error(t, err)
+ })
+
+ t.Run("close with nil listener", func(t *testing.T) {
+ s := &server{
+ sshListener: nil,
+ }
+ defer func() {
+ if r := recover(); r != nil {
+ assert.NotNil(t, r)
+ }
+ }()
+ _ = s.Close()
+ t.Fatal("expected panic for nil listener")
+ })
+}
+
+func TestStart(t *testing.T) {
+ mr := new(MockRandom)
+ mc := new(MockConfig)
+ mreg := new(MockSessionRegistry)
+ mg := new(MockGRPCClient)
+ mp := new(MockPort)
+ sc, _ := getTestSSHConfig()
+
+ t.Run("normal stop", func(t *testing.T) {
+ s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ _ = s.Close()
+ }()
+ s.Start()
+ })
+
+ t.Run("accept error - temporary error continues loop", func(t *testing.T) {
+ ml := new(MockListener)
+ s := &server{
+ sshListener: ml,
+ sshPort: "0",
+ }
+
+ ml.On("Accept").Return(nil, errors.New("temporary error")).Once()
+ ml.On("Accept").Return(nil, net.ErrClosed).Once()
+
+ s.Start()
+ ml.AssertExpectations(t)
+ })
+
+ t.Run("accept error - immediate close", func(t *testing.T) {
+ ml := new(MockListener)
+ s := &server{
+ sshListener: ml,
+ sshPort: "0",
+ }
+
+ ml.On("Accept").Return(nil, net.ErrClosed).Once()
+
+ s.Start()
+ ml.AssertExpectations(t)
+ })
+
+ t.Run("accept success - connection fails SSH handshake", func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockGrpcClient := &MockGRPCClient{}
+ mockPort := &MockPort{}
+
+ sshConfig, _ := getTestSSHConfig()
+
+ serverConn, clientConn := net.Pipe()
+
+ mockListener := &MockListener{}
+ mockListener.On("Accept").Return(serverConn, nil).Once()
+ mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
+
+ s := &server{
+ randomizer: mockRandom,
+ config: mockConfig,
+ sshPort: "0",
+ sshListener: mockListener,
+ sshConfig: sshConfig,
+ grpcClient: mockGrpcClient,
+ sessionRegistry: mockSessionRegistry,
+ portRegistry: mockPort,
+ }
+
+ go s.Start()
+
+ time.Sleep(50 * time.Millisecond)
+ err := clientConn.Close()
+ assert.NoError(t, err)
+
+ time.Sleep(100 * time.Millisecond)
+
+ mockListener.AssertExpectations(t)
+ })
+
+ t.Run("accept success - valid SSH connection without auth", func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockPort := &MockPort{}
+
+ sshConfig, _ := getTestSSHConfig()
+
+ serverConn, clientConn := net.Pipe()
+
+ mockListener := &MockListener{}
+ mockListener.On("Accept").Return(serverConn, nil).Once()
+ mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
+
+ s := &server{
+ randomizer: mockRandom,
+ config: mockConfig,
+ sshPort: "0",
+ sshListener: mockListener,
+ sshConfig: sshConfig,
+ grpcClient: nil,
+ sessionRegistry: mockSessionRegistry,
+ portRegistry: mockPort,
+ }
+
+ go s.Start()
+
+ time.Sleep(50 * time.Millisecond)
+ err := clientConn.Close()
+ assert.NoError(t, err)
+
+ time.Sleep(100 * time.Millisecond)
+
+ mockListener.AssertExpectations(t)
+ })
+}
+
+func TestHandleConnection(t *testing.T) {
+ t.Run("SSH handshake fails - connection closed", func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockGrpcClient := &MockGRPCClient{}
+ mockPort := &MockPort{}
+
+ sshConfig, _ := getTestSSHConfig()
+
+ serverConn, clientConn := net.Pipe()
+
+ s := &server{
+ randomizer: mockRandom,
+ config: mockConfig,
+ sshPort: "0",
+ sshConfig: sshConfig,
+ grpcClient: mockGrpcClient,
+ sessionRegistry: mockSessionRegistry,
+ portRegistry: mockPort,
+ }
+
+ err := clientConn.Close()
+ assert.NoError(t, err)
+
+ s.handleConnection(serverConn)
+ })
+
+ // SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS
+ // GONNA IMPLEMENT THIS UNIT TEST LATER
+
+ //t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) {
+ // mockRandom := &MockRandom{}
+ // mockConfig := &MockConfig{}
+ // mockSessionRegistry := &MockSessionRegistry{}
+ // mockGrpcClient := &MockGRPCClient{}
+ // mockPort := &MockPort{}
+ //
+ // sshConfig, _ := getTestSSHConfig()
+ //
+ // serverConn, clientConn := net.Pipe()
+ //
+ // s := &server{
+ // randomizer: mockRandom,
+ // config: mockConfig,
+ // sshPort: "0",
+ // sshConfig: sshConfig,
+ // grpcClient: mockGrpcClient,
+ // sessionRegistry: mockSessionRegistry,
+ // portRegistry: mockPort,
+ // }
+ //
+ // done := make(chan bool, 1)
+ //
+ // go func() {
+ // s.handleConnection(serverConn)
+ // done <- true
+ // }()
+ //
+ // go func() {
+ // clientConn.Write([]byte("invalid ssh protocol\n"))
+ // clientConn.Close()
+ // }()
+ //
+ // select {
+ // case <-done:
+ // case <-time.After(1 * time.Second):
+ // t.Fatal("handleConnection did not complete in time")
+ // }
+ //})
+
+ t.Run("SSH connection established without gRPC client", func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockPort := &MockPort{}
+
+ serverConfig, _ := getTestSSHConfig()
+
+ mockConfig.On("Domain").Return("test.com")
+ mockConfig.On("Mode").Return(types.ServerModeNODE)
+ mockConfig.On("SSHPort").Return("2200")
+ mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
+ mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
+ mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func(listener net.Listener) {
+ err = listener.Close()
+ assert.NoError(t, err)
+ }(listener)
+
+ serverAddr := listener.Addr().String()
+
+ s := &server{
+ randomizer: mockRandom,
+ config: mockConfig,
+ sshPort: "0",
+ sshConfig: serverConfig,
+ grpcClient: nil,
+ sessionRegistry: mockSessionRegistry,
+ portRegistry: mockPort,
+ }
+
+ done := make(chan bool, 1)
+
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ return
+ }
+ s.handleConnection(conn)
+ done <- true
+ }()
+
+ time.Sleep(50 * time.Millisecond)
+
+ clientConfig := &ssh.ClientConfig{
+ User: "testuser",
+ Auth: []ssh.AuthMethod{ssh.Password("password")},
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ Timeout: 2 * time.Second,
+ }
+
+ go func() {
+ client, err := ssh.Dial("tcp", serverAddr, clientConfig)
+ if err != nil {
+ t.Logf("Client dial failed: %v", err)
+ return
+ }
+ defer func(client *ssh.Client) {
+ err = client.Close()
+ assert.NoError(t, err)
+ }(client)
+
+ type forwardPayload struct {
+ BindAddr string
+ BindPort uint32
+ }
+
+ payload := ssh.Marshal(forwardPayload{
+ BindAddr: "localhost",
+ BindPort: 80,
+ })
+
+ _, _, err = client.SendRequest("tcpip-forward", true, payload)
+ if err != nil {
+ t.Logf("Forward request failed: %v", err)
+ }
+
+ time.Sleep(500 * time.Millisecond)
+ }()
+
+ select {
+ case <-done:
+ t.Log("handleConnection completed")
+ case <-time.After(5 * time.Second):
+ t.Fatal("handleConnection did not complete in time")
+ }
+ })
+
+ t.Run("SSH connection established with gRPC authorization", func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockGrpcClient := &MockGRPCClient{}
+ mockPort := &MockPort{}
+
+ serverConfig, _ := getTestSSHConfig()
+
+ mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
+ mockConfig.On("Domain").Return("test.com")
+ mockConfig.On("Mode").Return(types.ServerModeNODE)
+ mockConfig.On("SSHPort").Return("2200")
+ mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
+ mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
+ mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func(listener net.Listener) {
+ err = listener.Close()
+ assert.NoError(t, err)
+ }(listener)
+
+ serverAddr := listener.Addr().String()
+
+ s := &server{
+ randomizer: mockRandom,
+ config: mockConfig,
+ sshPort: "0",
+ sshConfig: serverConfig,
+ grpcClient: mockGrpcClient,
+ sessionRegistry: mockSessionRegistry,
+ portRegistry: mockPort,
+ }
+
+ done := make(chan bool, 1)
+
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ return
+ }
+ s.handleConnection(conn)
+ done <- true
+ }()
+
+ time.Sleep(50 * time.Millisecond)
+
+ clientConfig := &ssh.ClientConfig{
+ User: "testuser",
+ Auth: []ssh.AuthMethod{ssh.Password("password")},
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ Timeout: 2 * time.Second,
+ }
+
+ go func() {
+ client, err := ssh.Dial("tcp", serverAddr, clientConfig)
+ if err != nil {
+ t.Logf("Client dial failed: %v", err)
+ return
+ }
+ defer func(client *ssh.Client) {
+ err = client.Close()
+ assert.NoError(t, err)
+ }(client)
+
+ type forwardPayload struct {
+ BindAddr string
+ BindPort uint32
+ }
+
+ payload := ssh.Marshal(forwardPayload{
+ BindAddr: "localhost",
+ BindPort: 80,
+ })
+
+ _, _, err = client.SendRequest("tcpip-forward", true, payload)
+ if err != nil {
+ t.Logf("Forward request failed: %v", err)
+ }
+
+ time.Sleep(500 * time.Millisecond)
+ }()
+
+ select {
+ case <-done:
+ mockGrpcClient.AssertExpectations(t)
+ case <-time.After(5 * time.Second):
+ t.Fatal("handleConnection did not complete in time")
+ }
+ })
+
+ t.Run("SSH connection with gRPC authorization error", func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockGrpcClient := &MockGRPCClient{}
+ mockPort := &MockPort{}
+
+ serverConfig, _ := getTestSSHConfig()
+
+ mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
+ mockConfig.On("Domain").Return("test.com")
+ mockConfig.On("Mode").Return(types.ServerModeNODE)
+ mockConfig.On("SSHPort").Return("2200")
+ mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
+ mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
+ mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
+
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func(listener net.Listener) {
+ err = listener.Close()
+ assert.NoError(t, err)
+ }(listener)
+
+ serverAddr := listener.Addr().String()
+
+ s := &server{
+ randomizer: mockRandom,
+ config: mockConfig,
+ sshPort: "0",
+ sshConfig: serverConfig,
+ grpcClient: mockGrpcClient,
+ sessionRegistry: mockSessionRegistry,
+ portRegistry: mockPort,
+ }
+
+ done := make(chan bool, 1)
+
+ go func() {
+ conn, err := listener.Accept()
+ if err != nil {
+ return
+ }
+ s.handleConnection(conn)
+ done <- true
+ }()
+
+ time.Sleep(50 * time.Millisecond)
+
+ clientConfig := &ssh.ClientConfig{
+ User: "testuser",
+ Auth: []ssh.AuthMethod{ssh.Password("password")},
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ Timeout: 2 * time.Second,
+ }
+
+ go func() {
+ client, err := ssh.Dial("tcp", serverAddr, clientConfig)
+ if err != nil {
+ t.Logf("Client dial failed: %v", err)
+ return
+ }
+ defer func(client *ssh.Client) {
+ _ = client.Close()
+ }(client)
+
+ type forwardPayload struct {
+ BindAddr string
+ BindPort uint32
+ }
+
+ payload := ssh.Marshal(forwardPayload{
+ BindAddr: "localhost",
+ BindPort: 8080,
+ })
+
+ _, _, err = client.SendRequest("tcpip-forward", true, payload)
+ if err != nil {
+ t.Logf("Forward request failed: %v", err)
+ }
+
+ time.Sleep(500 * time.Millisecond)
+ }()
+
+ select {
+ case <-done:
+ mockGrpcClient.AssertExpectations(t)
+ case <-time.After(5 * time.Second):
+ t.Fatal("handleConnection did not complete in time")
+ }
+ })
+
+ t.Run("connection cleanup on close", func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockPort := &MockPort{}
+
+ serverConfig, _ := getTestSSHConfig()
+
+ serverConn, clientConn := net.Pipe()
+
+ s := &server{
+ randomizer: mockRandom,
+ config: mockConfig,
+ sshPort: "0",
+ sshConfig: serverConfig,
+ grpcClient: nil,
+ sessionRegistry: mockSessionRegistry,
+ portRegistry: mockPort,
+ }
+
+ done := make(chan bool, 1)
+
+ go func() {
+ s.handleConnection(serverConn)
+ done <- true
+ }()
+
+ err := clientConn.Close()
+ assert.NoError(t, err)
+
+ select {
+ case <-done:
+ case <-time.After(1 * time.Second):
+ t.Fatal("handleConnection did not complete in time")
+ }
+ })
+}
+
+func TestIntegration(t *testing.T) {
+ t.Run("full server lifecycle", func(t *testing.T) {
+ mr := new(MockRandom)
+ mc := new(MockConfig)
+ mreg := new(MockSessionRegistry)
+ mg := new(MockGRPCClient)
+ mp := new(MockPort)
+ sc, _ := getTestSSHConfig()
+
+ s, err := New(mr, mc, sc, mreg, mg, mp, "0")
+ assert.NoError(t, err)
+ assert.NotNil(t, s)
+
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ err := s.Close()
+ assert.NoError(t, err)
+ }()
+
+ s.Start()
+ })
+
+ t.Run("multiple connections", func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockPort := &MockPort{}
+
+ sshConfig, _ := getTestSSHConfig()
+
+ conn1Server, conn1Client := net.Pipe()
+ conn2Server, conn2Client := net.Pipe()
+
+ mockListener := &MockListener{}
+ mockListener.On("Accept").Return(conn1Server, nil).Once()
+ mockListener.On("Accept").Return(conn2Server, nil).Once()
+ mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
+
+ s := &server{
+ randomizer: mockRandom,
+ config: mockConfig,
+ sshPort: "0",
+ sshListener: mockListener,
+ sshConfig: sshConfig,
+ grpcClient: nil,
+ sessionRegistry: mockSessionRegistry,
+ portRegistry: mockPort,
+ }
+
+ go s.Start()
+
+ time.Sleep(50 * time.Millisecond)
+ _ = conn1Client.Close()
+ time.Sleep(50 * time.Millisecond)
+ _ = conn2Client.Close()
+ time.Sleep(100 * time.Millisecond)
+
+ mockListener.AssertExpectations(t)
+ })
+}
+
+func TestErrorHandling(t *testing.T) {
+ t.Run("write error during SSH handshake", func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockPort := &MockPort{}
+
+ sshConfig, _ := getTestSSHConfig()
+
+ serverConn, clientConn := net.Pipe()
+ err := clientConn.Close()
+ assert.NoError(t, err)
+
+ s := &server{
+ randomizer: mockRandom,
+ config: mockConfig,
+ sshPort: "0",
+ sshConfig: sshConfig,
+ grpcClient: nil,
+ sessionRegistry: mockSessionRegistry,
+ portRegistry: mockPort,
+ }
+
+ s.handleConnection(serverConn)
+ })
+}
diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go
index c602565..03528c9 100644
--- a/session/forwarder/forwarder.go
+++ b/session/forwarder/forwarder.go
@@ -1,8 +1,7 @@
package forwarder
import (
- "bytes"
- "encoding/binary"
+ "context"
"errors"
"fmt"
"io"
@@ -10,7 +9,6 @@ import (
"net"
"strconv"
"sync"
- "time"
"tunnel_pls/internal/config"
"tunnel_pls/session/slug"
"tunnel_pls/types"
@@ -26,9 +24,7 @@ type Forwarder interface {
TunnelType() types.TunnelType
ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel)
- CreateForwardedTCPIPPayload(origin net.Addr) []byte
- OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
- WriteBadGatewayResponse(dst io.Writer)
+ OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
Close() error
}
type forwarder struct {
@@ -50,19 +46,21 @@ func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
bufferPool: sync.Pool{
New: func() interface{} {
bufSize := config.BufferSize()
- return make([]byte, bufSize)
+ buf := make([]byte, bufSize)
+ return &buf
},
},
}
}
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
- buf := f.bufferPool.Get().([]byte)
+ buf := f.bufferPool.Get().(*[]byte)
defer f.bufferPool.Put(buf)
- return io.CopyBuffer(dst, src, buf)
+ return io.CopyBuffer(dst, src, *buf)
}
-func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
+func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
+ payload := createForwardedTCPIPPayload(origin, f.forwardedPort)
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
@@ -74,13 +72,9 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
select {
case resultChan <- channelResult{channel, reqs, err}:
- default:
+ case <-ctx.Done():
if channel != nil {
- err = channel.Close()
- if err != nil {
- log.Printf("Failed to close unused channel: %v", err)
- return
- }
+ _ = channel.Close()
go ssh.DiscardRequests(reqs)
}
}
@@ -89,8 +83,8 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s
select {
case result := <-resultChan:
return result.channel, result.reqs, result.err
- case <-time.After(5 * time.Second):
- return nil, nil, errors.New("timeout opening forwarded-tcpip channel")
+ case <-ctx.Done():
+ return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err())
}
}
@@ -119,10 +113,7 @@ func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string)
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
defer func() {
- _, err := io.Copy(io.Discard, src)
- if err != nil {
- log.Printf("Failed to discard connection: %v", err)
- }
+ _, _ = io.Copy(io.Discard, src)
}()
var wg sync.WaitGroup
@@ -173,14 +164,6 @@ func (f *forwarder) Listener() net.Listener {
return f.listener
}
-func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
- _, err := dst.Write(types.BadGatewayResponse)
- if err != nil {
- log.Printf("failed to write Bad Gateway response: %v", err)
- return
- }
-}
-
func (f *forwarder) Close() error {
if f.Listener() != nil {
return f.listener.Close()
@@ -188,43 +171,21 @@ func (f *forwarder) Close() error {
return nil
}
-func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
- var buf bytes.Buffer
-
- host, originPort := parseAddr(origin.String())
-
- writeSSHString(&buf, "localhost")
- err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
- if err != nil {
- log.Printf("Failed to write string to buffer: %v", err)
- return nil
- }
-
- writeSSHString(&buf, host)
- err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
- if err != nil {
- log.Printf("Failed to write string to buffer: %v", err)
- return nil
- }
-
- return buf.Bytes()
-}
-
-func parseAddr(addr string) (string, uint16) {
- host, portStr, err := net.SplitHostPort(addr)
- if err != nil {
- log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
- return "0.0.0.0", uint16(0)
- }
+func createForwardedTCPIPPayload(origin net.Addr, destPort uint16) []byte {
+ host, portStr, _ := net.SplitHostPort(origin.String())
port, _ := strconv.Atoi(portStr)
- return host, uint16(port)
-}
-func writeSSHString(buffer *bytes.Buffer, str string) {
- err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
- if err != nil {
- log.Printf("Failed to write string to buffer: %v", err)
- return
+ forwardPayload := struct {
+ DestAddr string
+ DestPort uint32
+ OriginAddr string
+ OriginPort uint32
+ }{
+ DestAddr: "localhost",
+ DestPort: uint32(destPort),
+ OriginAddr: host,
+ OriginPort: uint32(port),
}
- buffer.WriteString(str)
+
+ return ssh.Marshal(forwardPayload)
}
diff --git a/session/forwarder/forwarder_test.go b/session/forwarder/forwarder_test.go
new file mode 100644
index 0000000..092d783
--- /dev/null
+++ b/session/forwarder/forwarder_test.go
@@ -0,0 +1,1806 @@
+package forwarder
+
+import (
+ "bytes"
+ "context"
+ "errors"
+ "io"
+ "net"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+ "tunnel_pls/session/slug"
+ "tunnel_pls/types"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+)
+
+type mockConfig struct {
+ mock.Mock
+}
+
+func (m *mockConfig) Domain() string { return m.Called().String(0) }
+func (m *mockConfig) SSHPort() string { return m.Called().String(0) }
+func (m *mockConfig) HTTPPort() string { return m.Called().String(0) }
+func (m *mockConfig) HTTPSPort() string { return m.Called().String(0) }
+func (m *mockConfig) KeyLoc() string { return m.Called().String(0) }
+func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
+func (m *mockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
+func (m *mockConfig) TLSStoragePath() string { return m.Called().String(0) }
+func (m *mockConfig) ACMEEmail() string { return m.Called().String(0) }
+func (m *mockConfig) CFAPIToken() string { return m.Called().String(0) }
+func (m *mockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
+func (m *mockConfig) AllowedPortsStart() uint16 { return m.Called().Get(0).(uint16) }
+func (m *mockConfig) AllowedPortsEnd() uint16 { return m.Called().Get(0).(uint16) }
+func (m *mockConfig) BufferSize() int { return m.Called().Int(0) }
+func (m *mockConfig) HeaderSize() int { return m.Called().Int(0) }
+func (m *mockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
+func (m *mockConfig) PprofPort() string { return m.Called().String(0) }
+func (m *mockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
+func (m *mockConfig) GRPCAddress() string { return m.Called().String(0) }
+func (m *mockConfig) GRPCPort() string { return m.Called().String(0) }
+func (m *mockConfig) NodeToken() string { return m.Called().String(0) }
+
+type mockConn struct {
+ mock.Mock
+}
+
+func (c *mockConn) Close() error { return c.Called().Error(0) }
+func (c *mockConn) User() string { return c.Called().String(0) }
+func (c *mockConn) SessionID() []byte { return c.Called().Get(0).([]byte) }
+func (c *mockConn) ClientVersion() []byte { return c.Called().Get(0).([]byte) }
+func (c *mockConn) ServerVersion() []byte { return c.Called().Get(0).([]byte) }
+func (c *mockConn) RemoteAddr() net.Addr { return c.Called().Get(0).(net.Addr) }
+func (c *mockConn) LocalAddr() net.Addr { return c.Called().Get(0).(net.Addr) }
+func (c *mockConn) SendRequest(s string, b bool, d []byte) (bool, []byte, error) {
+ args := c.Called(s, b, d)
+ return args.Bool(0), args.Get(1).([]byte), args.Error(2)
+}
+func (c *mockConn) Wait() error { return c.Called().Error(0) }
+
+func (c *mockConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
+ args := c.Called(name, data)
+ return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
+}
+
+type testChannel struct {
+ mock.Mock
+ readBuf *syncBuffer
+ writeBuf *syncBuffer
+ closedWrite atomic.Bool
+}
+
+func (c *testChannel) Read(b []byte) (int, error) {
+ return c.readBuf.Read(b)
+}
+
+func (c *testChannel) Write(b []byte) (int, error) {
+ return c.writeBuf.Write(b)
+}
+
+func (c *testChannel) Close() error {
+ return c.Called().Error(0)
+}
+
+func (c *testChannel) CloseWrite() error {
+ c.closedWrite.Store(true)
+ return c.writeBuf.Close()
+}
+
+func (c *testChannel) Stderr() io.ReadWriter {
+ return c.Called().Get(0).(io.ReadWriter)
+}
+
+func (c *testChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
+ args := c.Called(name, wantReply, payload)
+ return args.Bool(0), args.Error(1)
+}
+
+func (c *testChannel) AckRequest(ok bool, payload []byte) error {
+ return c.Called(ok, payload).Error(0)
+}
+
+type syncBuffer struct {
+ mu sync.Mutex
+ buf []byte
+ closed bool
+ cond *sync.Cond
+}
+
+func newSyncBuffer() *syncBuffer {
+ sb := &syncBuffer{}
+ sb.cond = sync.NewCond(&sb.mu)
+ return sb
+}
+
+func (sb *syncBuffer) Write(p []byte) (int, error) {
+ sb.mu.Lock()
+ defer sb.mu.Unlock()
+ if sb.closed {
+ return 0, io.ErrClosedPipe
+ }
+ sb.buf = append(sb.buf, p...)
+ sb.cond.Broadcast()
+ return len(p), nil
+}
+
+func (sb *syncBuffer) Read(p []byte) (int, error) {
+ sb.mu.Lock()
+ defer sb.mu.Unlock()
+
+ for len(sb.buf) == 0 {
+ if sb.closed {
+ return 0, io.EOF
+ }
+ sb.cond.Wait()
+ }
+
+ n := copy(p, sb.buf)
+ sb.buf = sb.buf[n:]
+ return n, nil
+}
+
+func (sb *syncBuffer) Close() error {
+ sb.mu.Lock()
+ defer sb.mu.Unlock()
+ sb.closed = true
+ sb.cond.Broadcast()
+ return nil
+}
+
+func newChannelPair() (*testChannel, *testChannelPeer) {
+ peerToChBuf := newSyncBuffer()
+ chToPeerBuf := newSyncBuffer()
+
+ channel := &testChannel{
+ readBuf: peerToChBuf,
+ writeBuf: chToPeerBuf,
+ }
+
+ peer := &testChannelPeer{
+ readBuf: chToPeerBuf,
+ writeBuf: peerToChBuf,
+ }
+
+ channel.On("Close").Return(nil).Maybe()
+
+ return channel, peer
+}
+
+type testChannelPeer struct {
+ readBuf *syncBuffer
+ writeBuf *syncBuffer
+}
+
+func (p *testChannelPeer) Read(b []byte) (int, error) {
+ return p.readBuf.Read(b)
+}
+
+func (p *testChannelPeer) Write(b []byte) (int, error) {
+ return p.writeBuf.Write(b)
+}
+
+func (p *testChannelPeer) CloseWrite() error {
+ return p.writeBuf.Close()
+}
+
+func newPipePair() (*pipeConn, *pipeConn) {
+ r1, w1 := io.Pipe()
+ r2, w2 := io.Pipe()
+
+ conn1 := &pipeConn{
+ reader: r1,
+ writer: w2,
+ }
+
+ conn2 := &pipeConn{
+ reader: r2,
+ writer: w1,
+ }
+
+ return conn1, conn2
+}
+
+type pipeConn struct {
+ reader *io.PipeReader
+ writer *io.PipeWriter
+}
+
+func (p *pipeConn) Read(b []byte) (int, error) {
+ return p.reader.Read(b)
+}
+
+func (p *pipeConn) Write(b []byte) (int, error) {
+ return p.writer.Write(b)
+}
+
+func (p *pipeConn) Close() error {
+ err := p.reader.Close()
+ if err != nil {
+ return err
+ }
+ err = p.writer.Close()
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (p *pipeConn) CloseWrite() error {
+ return p.writer.Close()
+}
+
+func (p *pipeConn) LocalAddr() net.Addr {
+ return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
+}
+
+func (p *pipeConn) RemoteAddr() net.Addr {
+ return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
+}
+
+func (p *pipeConn) SetDeadline(t time.Time) error { return nil }
+func (p *pipeConn) SetReadDeadline(t time.Time) error { return nil }
+func (p *pipeConn) SetWriteDeadline(t time.Time) error { return nil }
+
+func TestNew(t *testing.T) {
+ tests := []struct {
+ name string
+ bufferSize int
+ wantBufLen int
+ }{
+ {
+ name: "default buffer size",
+ bufferSize: 16,
+ wantBufLen: 16,
+ },
+ {
+ name: "custom buffer size",
+ bufferSize: 32,
+ wantBufLen: 32,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
+ s := slug.New()
+ conn := &mockConn{}
+
+ forwarder := New(cfg, s, conn).(*forwarder)
+
+ buf := forwarder.bufferPool.Get().(*[]byte)
+ require.Len(t, *buf, tt.wantBufLen)
+ forwarder.bufferPool.Put(buf)
+
+ assert.Equal(t, types.TunnelTypeUNKNOWN, forwarder.TunnelType())
+ assert.Equal(t, uint16(0), forwarder.ForwardedPort())
+ assert.Equal(t, conn, forwarder.conn)
+ assert.Equal(t, s, forwarder.slug)
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestHandleConnection(t *testing.T) {
+ tests := []struct {
+ name string
+ bufferSize int
+ messageToDst []byte
+ messageToSrc []byte
+ }{
+ {
+ name: "small messages",
+ bufferSize: 4,
+ messageToDst: []byte("hi"),
+ messageToSrc: []byte("yo"),
+ },
+ {
+ name: "medium messages",
+ bufferSize: 8,
+ messageToDst: []byte("hello"),
+ messageToSrc: []byte("world"),
+ },
+ {
+ name: "larger messages",
+ bufferSize: 16,
+ messageToDst: []byte("I love femboy"),
+ messageToSrc: []byte("mee too"),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ channel, channelPeer := newChannelPair()
+ dstEndpoint, dstPeer := newPipePair()
+
+ done := make(chan struct{})
+ go func() {
+ forwarder.HandleConnection(dstEndpoint, channel)
+ close(done)
+ }()
+
+ readDst := make(chan struct {
+ data []byte
+ err error
+ }, 1)
+ go func() {
+ buf := make([]byte, len(tt.messageToDst))
+ n, err := io.ReadFull(dstPeer, buf)
+ readDst <- struct {
+ data []byte
+ err error
+ }{data: buf[:n], err: err}
+ }()
+
+ _, err := channelPeer.Write(tt.messageToDst)
+ require.NoError(t, err)
+
+ dstResult := <-readDst
+ require.NoError(t, dstResult.err)
+ assert.Equal(t, tt.messageToDst, dstResult.data)
+
+ readSrc := make(chan struct {
+ data []byte
+ err error
+ }, 1)
+ go func() {
+ buf := make([]byte, len(tt.messageToSrc))
+ n, err := io.ReadFull(channelPeer, buf)
+ readSrc <- struct {
+ data []byte
+ err error
+ }{data: buf[:n], err: err}
+ }()
+
+ _, err = dstPeer.Write(tt.messageToSrc)
+ require.NoError(t, err)
+
+ srcResult := <-readSrc
+ require.NoError(t, srcResult.err)
+ assert.Equal(t, tt.messageToSrc, srcResult.data)
+
+ require.NoError(t, channelPeer.CloseWrite())
+ require.NoError(t, dstPeer.CloseWrite())
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("HandleConnection did not complete")
+ }
+ assert.True(t, channel.closedWrite.Load())
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestHandleConnection_Error(t *testing.T) {
+ tests := []struct {
+ name string
+ bufferSize int
+ messageToDst []byte
+ messageToSrc []byte
+ }{
+ {
+ name: "small messages",
+ bufferSize: 4,
+ messageToDst: []byte("hi"),
+ messageToSrc: []byte("yo"),
+ },
+ {
+ name: "medium messages",
+ bufferSize: 8,
+ messageToDst: []byte("hello"),
+ messageToSrc: []byte("world"),
+ },
+ {
+ name: "larger messages",
+ bufferSize: 16,
+ messageToDst: []byte("I love femboy"),
+ messageToSrc: []byte("mee too"),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ channel, _ := newChannelPair()
+ dstEndpoint, _ := newPipePair()
+
+ go func() {
+ forwarder.HandleConnection(dstEndpoint, channel)
+ }()
+
+ err := dstEndpoint.Close()
+ assert.NoError(t, err)
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestOpenForwardedChannel(t *testing.T) {
+ tests := []struct {
+ name string
+ forwardedPort uint16
+ originIP string
+ originPort int
+ wantDestAddr string
+ wantDestPort uint32
+ wantOrigAddr string
+ wantOrigPort uint32
+ }{
+ {
+ name: "localhost origin",
+ forwardedPort: 2222,
+ originIP: "127.0.0.1",
+ originPort: 9000,
+ wantDestAddr: "localhost",
+ wantDestPort: 2222,
+ wantOrigAddr: "127.0.0.1",
+ wantOrigPort: 9000,
+ },
+ {
+ name: "remote origin",
+ forwardedPort: 8080,
+ originIP: "192.168.1.100",
+ originPort: 5000,
+ wantDestAddr: "localhost",
+ wantDestPort: 8080,
+ wantOrigAddr: "192.168.1.100",
+ wantOrigPort: 5000,
+ },
+ {
+ name: "different port",
+ forwardedPort: 3000,
+ originIP: "10.0.0.1",
+ originPort: 7777,
+ wantDestAddr: "localhost",
+ wantDestPort: 3000,
+ wantOrigAddr: "10.0.0.1",
+ wantOrigPort: 7777,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(8).Maybe()
+ channel := &testChannel{
+ readBuf: newSyncBuffer(),
+ writeBuf: newSyncBuffer(),
+ }
+ requests := make(chan *ssh.Request)
+
+ var capturedData []byte
+ conn := &mockConn{}
+ conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Run(func(args mock.Arguments) {
+ data := args.Get(1).([]byte)
+ capturedData = make([]byte, len(data))
+ copy(capturedData, data)
+ }).Return(channel, (<-chan *ssh.Request)(requests), nil)
+
+ forwarder := New(cfg, slug.New(), conn).(*forwarder)
+ forwarder.SetForwardedPort(tt.forwardedPort)
+
+ origin := &net.TCPAddr{IP: net.ParseIP(tt.originIP), Port: tt.originPort}
+ ch, reqs, err := forwarder.OpenForwardedChannel(context.Background(), origin)
+ require.NoError(t, err)
+ assert.Same(t, channel, ch)
+ assert.NotNil(t, reqs)
+
+ var payload struct {
+ DestAddr string
+ DestPort uint32
+ OriginAddr string
+ OriginPort uint32
+ }
+ err = ssh.Unmarshal(capturedData, &payload)
+ assert.NoError(t, err)
+ assert.Equal(t, tt.wantDestAddr, payload.DestAddr)
+ assert.Equal(t, tt.wantDestPort, payload.DestPort)
+ assert.Equal(t, tt.wantOrigAddr, payload.OriginAddr)
+ assert.Equal(t, tt.wantOrigPort, payload.OriginPort)
+
+ conn.AssertExpectations(t)
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestOpenForwardedChannelContextCancellation(t *testing.T) {
+ tests := []struct {
+ name string
+ cancelBefore bool
+ cancelDuring bool
+ wantErr bool
+ wantErrType error
+ }{
+ {
+ name: "cancel during open",
+ cancelBefore: false,
+ cancelDuring: true,
+ wantErr: true,
+ wantErrType: context.Canceled,
+ },
+ {
+ name: "cancel before open",
+ cancelBefore: true,
+ cancelDuring: false,
+ wantErr: true,
+ wantErrType: context.Canceled,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(8).Maybe()
+ channel := &testChannel{
+ readBuf: newSyncBuffer(),
+ writeBuf: newSyncBuffer(),
+ }
+ channel.On("Close").Return(nil)
+ requests := make(chan *ssh.Request)
+
+ openChannelCalled := make(chan struct{})
+ openChannelBlock := make(chan struct{})
+
+ conn := &mockConn{}
+ conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Run(func(args mock.Arguments) {
+ close(openChannelCalled)
+ <-openChannelBlock
+ }).Return(channel, (<-chan *ssh.Request)(requests), nil).Maybe()
+
+ forwarder := New(cfg, slug.New(), conn).(*forwarder)
+ forwarder.SetForwardedPort(8080)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ if tt.cancelBefore {
+ cancel()
+ }
+
+ var (
+ openedChannel ssh.Channel
+ openedReqs <-chan *ssh.Request
+ openErr error
+ )
+
+ done := make(chan struct{})
+ go func() {
+ origin := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 7000}
+ openedChannel, openedReqs, openErr = forwarder.OpenForwardedChannel(ctx, origin)
+ close(done)
+ }()
+
+ if tt.cancelDuring {
+ <-openChannelCalled
+ cancel()
+ }
+ close(openChannelBlock)
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("OpenForwardedChannel did not return after cancellation")
+ }
+
+ if tt.wantErr {
+ require.Error(t, openErr)
+ assert.True(t, errors.Is(openErr, tt.wantErrType))
+ assert.Nil(t, openedChannel)
+ assert.Nil(t, openedReqs)
+ } else {
+ require.NoError(t, openErr)
+ assert.NotNil(t, openedChannel)
+ assert.NotNil(t, openedReqs)
+ }
+
+ conn.AssertExpectations(t)
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestCreateForwardedTCPIPPayload(t *testing.T) {
+ tests := []struct {
+ name string
+ originIP string
+ originPort int
+ forwardedPort uint16
+ wantDestAddr string
+ wantDestPort uint32
+ wantOriginAddr string
+ wantOriginPort uint32
+ }{
+ {
+ name: "standard case",
+ originIP: "192.0.2.10",
+ originPort: 5050,
+ forwardedPort: 8080,
+ wantDestAddr: "localhost",
+ wantDestPort: 8080,
+ wantOriginAddr: "192.0.2.10",
+ wantOriginPort: 5050,
+ },
+ {
+ name: "localhost origin",
+ originIP: "127.0.0.1",
+ originPort: 3000,
+ forwardedPort: 9000,
+ wantDestAddr: "localhost",
+ wantDestPort: 9000,
+ wantOriginAddr: "127.0.0.1",
+ wantOriginPort: 3000,
+ },
+ {
+ name: "high port numbers",
+ originIP: "10.0.0.1",
+ originPort: 65535,
+ forwardedPort: 65534,
+ wantDestAddr: "localhost",
+ wantDestPort: 65534,
+ wantOriginAddr: "10.0.0.1",
+ wantOriginPort: 65535,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ origin := &net.TCPAddr{IP: net.ParseIP(tt.originIP), Port: tt.originPort}
+ payload := createForwardedTCPIPPayload(origin, tt.forwardedPort)
+
+ var decoded struct {
+ DestAddr string
+ DestPort uint32
+ OriginAddr string
+ OriginPort uint32
+ }
+
+ err := ssh.Unmarshal(payload, &decoded)
+ assert.NoError(t, err)
+ assert.Equal(t, tt.wantDestAddr, decoded.DestAddr)
+ assert.Equal(t, tt.wantDestPort, decoded.DestPort)
+ assert.Equal(t, tt.wantOriginAddr, decoded.OriginAddr)
+ assert.Equal(t, tt.wantOriginPort, decoded.OriginPort)
+ })
+ }
+}
+
+type mockReader struct {
+ mock.Mock
+}
+
+func (m *mockReader) Read(p []byte) (int, error) {
+ args := m.Called(p)
+ return args.Int(0), args.Error(1)
+}
+
+type mockWriter struct {
+ mock.Mock
+}
+
+func (m *mockWriter) Write(p []byte) (int, error) {
+ args := m.Called(p)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *mockWriter) CloseWrite() error {
+ return m.Called().Error(0)
+}
+
+type mockWriteCloser struct {
+ mock.Mock
+}
+
+func (m *mockWriteCloser) Write(p []byte) (int, error) {
+ args := m.Called(p)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *mockWriteCloser) Close() error {
+ return m.Called().Error(0)
+}
+
+func TestCopyAndClose(t *testing.T) {
+ tests := []struct {
+ name string
+ setupSrc func() io.Reader
+ setupDst func() io.Writer
+ direction string
+ wantErr bool
+ wantErrMsg string
+ checkErrTypes []error
+ }{
+ {
+ name: "successful copy with EOF",
+ setupSrc: func() io.Reader {
+ r := &mockReader{}
+ r.On("Read", mock.Anything).Return(5, nil).Once()
+ r.On("Read", mock.Anything).Return(0, io.EOF).Once()
+ return r
+ },
+ setupDst: func() io.Writer {
+ w := &mockWriter{}
+ w.On("Write", mock.Anything).Return(5, nil).Once()
+ w.On("CloseWrite").Return(nil).Once()
+ return w
+ },
+ direction: "src->dst",
+ wantErr: false,
+ },
+ {
+ name: "copy error - not EOF or ErrClosed",
+ setupSrc: func() io.Reader {
+ r := &mockReader{}
+ customErr := errors.New("custom read error")
+ r.On("Read", mock.Anything).Return(0, customErr).Once()
+ return r
+ },
+ setupDst: func() io.Writer {
+ w := &mockWriter{}
+ w.On("CloseWrite").Return(nil).Once()
+ return w
+ },
+ direction: "src->dst",
+ wantErr: true,
+ wantErrMsg: "copy error (src->dst)",
+ },
+ {
+ name: "copy error - ErrClosed should be ignored",
+ setupSrc: func() io.Reader {
+ r := &mockReader{}
+ r.On("Read", mock.Anything).Return(0, net.ErrClosed).Once()
+ return r
+ },
+ setupDst: func() io.Writer {
+ w := &mockWriter{}
+ w.On("CloseWrite").Return(nil).Once()
+ return w
+ },
+ direction: "src->dst",
+ wantErr: false,
+ },
+ {
+ name: "close writer error - not EOF",
+ setupSrc: func() io.Reader {
+ r := &mockReader{}
+ r.On("Read", mock.Anything).Return(0, io.EOF).Once()
+ return r
+ },
+ setupDst: func() io.Writer {
+ w := &mockWriter{}
+ closeErr := errors.New("close error")
+ w.On("CloseWrite").Return(closeErr).Once()
+ return w
+ },
+ direction: "src->dst",
+ wantErr: true,
+ wantErrMsg: "close stream error (src->dst)",
+ },
+ {
+ name: "close writer error - EOF should be ignored",
+ setupSrc: func() io.Reader {
+ r := &mockReader{}
+ r.On("Read", mock.Anything).Return(0, io.EOF).Once()
+ return r
+ },
+ setupDst: func() io.Writer {
+ w := &mockWriter{}
+ w.On("CloseWrite").Return(io.EOF).Once()
+ return w
+ },
+ direction: "src->dst",
+ wantErr: false,
+ },
+ {
+ name: "both copy and close errors",
+ setupSrc: func() io.Reader {
+ r := &mockReader{}
+ copyErr := errors.New("copy error")
+ r.On("Read", mock.Anything).Return(0, copyErr).Once()
+ return r
+ },
+ setupDst: func() io.Writer {
+ w := &mockWriter{}
+ closeErr := errors.New("close error")
+ w.On("CloseWrite").Return(closeErr).Once()
+ return w
+ },
+ direction: "src->dst",
+ wantErr: true,
+ wantErrMsg: "copy error (src->dst)",
+ },
+ {
+ name: "successful copy with WriteCloser",
+ setupSrc: func() io.Reader {
+ r := &mockReader{}
+ r.On("Read", mock.Anything).Return(5, nil).Once()
+ r.On("Read", mock.Anything).Return(0, io.EOF).Once()
+ return r
+ },
+ setupDst: func() io.Writer {
+ w := &mockWriteCloser{}
+ w.On("Write", mock.Anything).Return(5, nil).Once()
+ w.On("Close").Return(nil).Once()
+ return w
+ },
+ direction: "dst->src",
+ wantErr: false,
+ },
+ {
+ name: "WriteCloser close error",
+ setupSrc: func() io.Reader {
+ r := &mockReader{}
+ r.On("Read", mock.Anything).Return(0, io.EOF).Once()
+ return r
+ },
+ setupDst: func() io.Writer {
+ w := &mockWriteCloser{}
+ closeErr := errors.New("writeCloser close error")
+ w.On("Close").Return(closeErr).Once()
+ return w
+ },
+ direction: "dst->src",
+ wantErr: true,
+ wantErrMsg: "close stream error (dst->src)",
+ },
+ {
+ name: "copy with multiple reads before EOF",
+ setupSrc: func() io.Reader {
+ r := &mockReader{}
+ r.On("Read", mock.Anything).Return(10, nil).Once()
+ r.On("Read", mock.Anything).Return(15, nil).Once()
+ r.On("Read", mock.Anything).Return(5, nil).Once()
+ r.On("Read", mock.Anything).Return(0, io.EOF).Once()
+ return r
+ },
+ setupDst: func() io.Writer {
+ w := &mockWriter{}
+ w.On("Write", mock.Anything).Return(10, nil).Once()
+ w.On("Write", mock.Anything).Return(15, nil).Once()
+ w.On("Write", mock.Anything).Return(5, nil).Once()
+ w.On("CloseWrite").Return(nil).Once()
+ return w
+ },
+ direction: "src->dst",
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(32).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ src := tt.setupSrc()
+ dst := tt.setupDst()
+
+ err := forwarder.copyAndClose(dst, src, tt.direction)
+
+ if tt.wantErr {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.wantErrMsg)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ if mr, ok := src.(*mockReader); ok {
+ mr.AssertExpectations(t)
+ }
+ if mw, ok := dst.(*mockWriter); ok {
+ mw.AssertExpectations(t)
+ }
+ if mwc, ok := dst.(*mockWriteCloser); ok {
+ mwc.AssertExpectations(t)
+ }
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestCopyAndCloseJoinedErrors(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(32).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ src := &mockReader{}
+ copyErr := errors.New("copy failed")
+ src.On("Read", mock.Anything).Return(0, copyErr).Once()
+
+ dst := &mockWriter{}
+ closeErr := errors.New("close failed")
+ dst.On("CloseWrite").Return(closeErr).Once()
+
+ err := forwarder.copyAndClose(dst, src, "test")
+
+ require.Error(t, err)
+
+ assert.Contains(t, err.Error(), "copy error (test)")
+ assert.Contains(t, err.Error(), "close stream error (test)")
+ assert.Contains(t, err.Error(), "copy failed")
+ assert.Contains(t, err.Error(), "close failed")
+
+ src.AssertExpectations(t)
+ dst.AssertExpectations(t)
+ cfg.AssertExpectations(t)
+}
+
+func TestCopyWithBuffer(t *testing.T) {
+ tests := []struct {
+ name string
+ bufferSize int
+ setupSrc func() io.Reader
+ setupDst func() io.Writer
+ wantBytesCount int64
+ wantErr bool
+ wantErrType error
+ }{
+ {
+ name: "successful copy small data",
+ bufferSize: 16,
+ setupSrc: func() io.Reader {
+ return io.NopCloser(bytes.NewReader([]byte("hello world")))
+ },
+ setupDst: func() io.Writer {
+ return &bytes.Buffer{}
+ },
+ wantBytesCount: 11,
+ wantErr: false,
+ },
+ {
+ name: "successful copy large data",
+ bufferSize: 8,
+ setupSrc: func() io.Reader {
+ data := make([]byte, 1024)
+ for i := range data {
+ data[i] = byte(i % 256)
+ }
+ return io.NopCloser(bytes.NewReader(data))
+ },
+ setupDst: func() io.Writer {
+ return &bytes.Buffer{}
+ },
+ wantBytesCount: 1024,
+ wantErr: false,
+ },
+ {
+ name: "empty data",
+ bufferSize: 16,
+ setupSrc: func() io.Reader {
+ return io.NopCloser(bytes.NewReader([]byte{}))
+ },
+ setupDst: func() io.Writer {
+ return &bytes.Buffer{}
+ },
+ wantBytesCount: 0,
+ wantErr: false,
+ },
+ {
+ name: "read error",
+ bufferSize: 16,
+ setupSrc: func() io.Reader {
+ r := &mockReader{}
+ r.On("Read", mock.Anything).Return(0, errors.New("read error")).Once()
+ return r
+ },
+ setupDst: func() io.Writer {
+ return &bytes.Buffer{}
+ },
+ wantBytesCount: 0,
+ wantErr: true,
+ },
+ {
+ name: "write error",
+ bufferSize: 16,
+ setupSrc: func() io.Reader {
+ return io.NopCloser(bytes.NewReader([]byte("test data")))
+ },
+ setupDst: func() io.Writer {
+ w := &mockWriter{}
+ w.On("Write", mock.Anything).Return(0, errors.New("write error")).Once()
+ return w
+ },
+ wantBytesCount: 0,
+ wantErr: true,
+ },
+ {
+ name: "partial write continues",
+ bufferSize: 16,
+ setupSrc: func() io.Reader {
+ return io.NopCloser(bytes.NewReader([]byte("testing")))
+ },
+ setupDst: func() io.Writer {
+ buf := &bytes.Buffer{}
+ return buf
+ },
+ wantBytesCount: 7,
+ wantErr: false,
+ },
+ {
+ name: "multiple buffer fills",
+ bufferSize: 4,
+ setupSrc: func() io.Reader {
+ return io.NopCloser(bytes.NewReader([]byte("this is a longer message")))
+ },
+ setupDst: func() io.Writer {
+ return &bytes.Buffer{}
+ },
+ wantBytesCount: 24,
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ src := tt.setupSrc()
+ dst := tt.setupDst()
+
+ n, err := forwarder.copyWithBuffer(dst, src)
+
+ if tt.wantErr {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+ assert.Equal(t, tt.wantBytesCount, n)
+ }
+
+ if buf, ok := dst.(*bytes.Buffer); ok && !tt.wantErr {
+ assert.Equal(t, tt.wantBytesCount, int64(buf.Len()))
+ }
+
+ if mr, ok := src.(*mockReader); ok {
+ mr.AssertExpectations(t)
+ }
+ if mw, ok := dst.(*mockWriter); ok {
+ mw.AssertExpectations(t)
+ }
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestCopyWithBufferReusesBuffer(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(16).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ buf1 := forwarder.bufferPool.Get().(*[]byte)
+ initialPtr := buf1
+
+ forwarder.bufferPool.Put(buf1)
+
+ src := io.NopCloser(bytes.NewReader([]byte("test")))
+ dst := &bytes.Buffer{}
+ _, err := forwarder.copyWithBuffer(dst, src)
+ require.NoError(t, err)
+
+ buf2 := forwarder.bufferPool.Get().(*[]byte)
+ secondPtr := buf2
+
+ forwarder.bufferPool.Put(buf2)
+
+ assert.Equal(t, initialPtr, secondPtr, "Buffers should be the same pointer")
+
+ assert.Len(t, *buf2, 16)
+ assert.Len(t, *buf1, 16)
+
+ _ = initialPtr
+ _ = secondPtr
+
+ cfg.AssertExpectations(t)
+}
+
+func TestSetType(t *testing.T) {
+ tests := []struct {
+ name string
+ tunnelType types.TunnelType
+ }{
+ {
+ name: "set to HTTP",
+ tunnelType: types.TunnelTypeHTTP,
+ },
+ {
+ name: "set to TCP",
+ tunnelType: types.TunnelTypeTCP,
+ },
+ {
+ name: "set to UNKNOWN",
+ tunnelType: types.TunnelTypeUNKNOWN,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(16).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ assert.Equal(t, types.TunnelTypeUNKNOWN, forwarder.TunnelType())
+
+ forwarder.SetType(tt.tunnelType)
+
+ assert.Equal(t, tt.tunnelType, forwarder.TunnelType())
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestTunnelType(t *testing.T) {
+ tests := []struct {
+ name string
+ tunnelType types.TunnelType
+ }{
+ {
+ name: "get HTTP type",
+ tunnelType: types.TunnelTypeHTTP,
+ },
+ {
+ name: "get TCP type",
+ tunnelType: types.TunnelTypeTCP,
+ },
+ {
+ name: "get UNKNOWN type",
+ tunnelType: types.TunnelTypeUNKNOWN,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(16).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ forwarder.SetType(tt.tunnelType)
+ result := forwarder.TunnelType()
+
+ assert.Equal(t, tt.tunnelType, result)
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestSetForwardedPort(t *testing.T) {
+ tests := []struct {
+ name string
+ port uint16
+ }{
+ {
+ name: "set standard port",
+ port: 8080,
+ },
+ {
+ name: "set low port",
+ port: 80,
+ },
+ {
+ name: "set high port",
+ port: 65535,
+ },
+ {
+ name: "set zero port",
+ port: 0,
+ },
+ {
+ name: "set custom port",
+ port: 3000,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(16).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ assert.Equal(t, uint16(0), forwarder.ForwardedPort())
+
+ forwarder.SetForwardedPort(tt.port)
+
+ assert.Equal(t, tt.port, forwarder.ForwardedPort())
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestForwardedPort(t *testing.T) {
+ tests := []struct {
+ name string
+ port uint16
+ }{
+ {
+ name: "get default port",
+ port: 0,
+ },
+ {
+ name: "get standard port",
+ port: 8080,
+ },
+ {
+ name: "get high port",
+ port: 65535,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(16).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ if tt.port != 0 {
+ forwarder.SetForwardedPort(tt.port)
+ }
+
+ result := forwarder.ForwardedPort()
+ assert.Equal(t, tt.port, result)
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestSetListener(t *testing.T) {
+ tests := []struct {
+ name string
+ setupListener func() net.Listener
+ }{
+ {
+ name: "set TCP listener",
+ setupListener: func() net.Listener {
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ return listener
+ },
+ },
+ {
+ name: "set nil listener",
+ setupListener: func() net.Listener {
+ return nil
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(16).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ listener := tt.setupListener()
+ if listener != nil {
+ defer func(listener net.Listener) {
+ err := listener.Close()
+ assert.NoError(t, err)
+ }(listener)
+ }
+
+ assert.Nil(t, forwarder.Listener())
+
+ forwarder.SetListener(listener)
+
+ assert.Equal(t, listener, forwarder.Listener())
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestListener(t *testing.T) {
+ tests := []struct {
+ name string
+ setupListener func() net.Listener
+ }{
+ {
+ name: "get nil listener",
+ setupListener: func() net.Listener {
+ return nil
+ },
+ },
+ {
+ name: "get TCP listener",
+ setupListener: func() net.Listener {
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ return listener
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(16).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ listener := tt.setupListener()
+ if listener != nil {
+ defer func(listener net.Listener) {
+ err := listener.Close()
+ assert.NoError(t, err)
+ }(listener)
+ forwarder.SetListener(listener)
+ }
+
+ result := forwarder.Listener()
+ assert.Equal(t, listener, result)
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestClose(t *testing.T) {
+ tests := []struct {
+ name string
+ setupListener func() net.Listener
+ wantErr bool
+ }{
+ {
+ name: "close with nil listener",
+ setupListener: func() net.Listener {
+ return nil
+ },
+ wantErr: false,
+ },
+ {
+ name: "close with active listener",
+ setupListener: func() net.Listener {
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ return listener
+ },
+ wantErr: false,
+ },
+ {
+ name: "close already closed listener",
+ setupListener: func() net.Listener {
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ err = listener.Close()
+ assert.NoError(t, err)
+ return listener
+ },
+ wantErr: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(16).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ listener := tt.setupListener()
+ if listener != nil {
+ forwarder.SetListener(listener)
+ }
+
+ err := forwarder.Close()
+
+ if tt.wantErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestCloseWriter(t *testing.T) {
+ tests := []struct {
+ name string
+ setup func() io.Writer
+ wantErr bool
+ }{
+ {
+ name: "close writer with CloseWrite method",
+ setup: func() io.Writer {
+ w := &mockWriter{}
+ w.On("CloseWrite").Return(nil).Once()
+ return w
+ },
+ wantErr: false,
+ },
+ {
+ name: "close writer with CloseWrite error",
+ setup: func() io.Writer {
+ w := &mockWriter{}
+ w.On("CloseWrite").Return(errors.New("close write error")).Once()
+ return w
+ },
+ wantErr: true,
+ },
+ {
+ name: "close WriteCloser",
+ setup: func() io.Writer {
+ w := &mockWriteCloser{}
+ w.On("Close").Return(nil).Once()
+ return w
+ },
+ wantErr: false,
+ },
+ {
+ name: "close WriteCloser with error",
+ setup: func() io.Writer {
+ w := &mockWriteCloser{}
+ w.On("Close").Return(errors.New("close error")).Once()
+ return w
+ },
+ wantErr: true,
+ },
+ {
+ name: "close plain writer (no close method)",
+ setup: func() io.Writer {
+ return &bytes.Buffer{}
+ },
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ writer := tt.setup()
+
+ err := closeWriter(writer)
+
+ if tt.wantErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ if mw, ok := writer.(*mockWriter); ok {
+ mw.AssertExpectations(t)
+ }
+ if mwc, ok := writer.(*mockWriteCloser); ok {
+ mwc.AssertExpectations(t)
+ }
+ })
+ }
+}
+
+func TestHandleConnectionWithErrors(t *testing.T) {
+ tests := []struct {
+ name string
+ bufferSize int
+ setupChannel func() (*testChannel, *testChannelPeer)
+ setupDst func() (net.Conn, *pipeConn)
+ simulateErr func(channel *testChannelPeer, dst *pipeConn)
+ }{
+ {
+ name: "handle read error from channel",
+ bufferSize: 16,
+ setupChannel: func() (*testChannel, *testChannelPeer) {
+ return newChannelPair()
+ },
+ setupDst: func() (net.Conn, *pipeConn) {
+ return newPipePair()
+ },
+ simulateErr: func(channel *testChannelPeer, dst *pipeConn) {
+ err := channel.CloseWrite()
+ assert.NoError(t, err)
+ err = dst.CloseWrite()
+ assert.NoError(t, err)
+ },
+ },
+ {
+ name: "handle write error to destination",
+ bufferSize: 16,
+ setupChannel: func() (*testChannel, *testChannelPeer) {
+ return newChannelPair()
+ },
+ setupDst: func() (net.Conn, *pipeConn) {
+ return newPipePair()
+ },
+ simulateErr: func(channel *testChannelPeer, dst *pipeConn) {
+ err := dst.Close()
+ assert.NoError(t, err)
+ time.Sleep(10 * time.Millisecond)
+ write, err := channel.Write([]byte("test"))
+ assert.NotZero(t, write)
+ assert.NoError(t, err)
+ err = channel.CloseWrite()
+ assert.NoError(t, err)
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ channel, channelPeer := tt.setupChannel()
+ dstEndpoint, dstPeer := tt.setupDst()
+
+ done := make(chan struct{})
+ go func() {
+ forwarder.HandleConnection(dstEndpoint, channel)
+ close(done)
+ }()
+
+ tt.simulateErr(channelPeer, dstPeer)
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("HandleConnection did not complete")
+ }
+
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestHandleConnectionDiscardOnExit(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(16).Maybe()
+ forwarder := New(cfg, slug.New(), nil).(*forwarder)
+
+ channel, channelPeer := newChannelPair()
+ dstEndpoint, dstPeer := newPipePair()
+
+ done := make(chan struct{})
+ go func() {
+ forwarder.HandleConnection(dstEndpoint, channel)
+ close(done)
+ }()
+
+ _, err := channelPeer.Write([]byte("test data"))
+ require.NoError(t, err)
+ require.NoError(t, channelPeer.CloseWrite())
+ require.NoError(t, dstPeer.Close())
+
+ select {
+ case <-done:
+ case <-time.After(10 * time.Second):
+ t.Fatal("HandleConnection did not complete")
+ }
+
+ cfg.AssertExpectations(t)
+}
+
+func TestOpenForwardedChannelSuccess(t *testing.T) {
+ tests := []struct {
+ name string
+ forwardedPort uint16
+ originAddr string
+ originPort int
+ }{
+ {
+ name: "open channel standard port",
+ forwardedPort: 8080,
+ originAddr: "127.0.0.1",
+ originPort: 9000,
+ },
+ {
+ name: "open channel high port",
+ forwardedPort: 65534,
+ originAddr: "192.168.1.100",
+ originPort: 5000,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(8).Maybe()
+ channel := &testChannel{
+ readBuf: newSyncBuffer(),
+ writeBuf: newSyncBuffer(),
+ }
+ requests := make(chan *ssh.Request)
+
+ conn := &mockConn{}
+ conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).
+ Return(channel, (<-chan *ssh.Request)(requests), nil)
+
+ forwarder := New(cfg, slug.New(), conn).(*forwarder)
+ forwarder.SetForwardedPort(tt.forwardedPort)
+
+ origin := &net.TCPAddr{IP: net.ParseIP(tt.originAddr), Port: tt.originPort}
+ ch, reqs, err := forwarder.OpenForwardedChannel(context.Background(), origin)
+
+ require.NoError(t, err)
+ assert.NotNil(t, ch)
+ assert.NotNil(t, reqs)
+
+ conn.AssertExpectations(t)
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestOpenForwardedChannelError(t *testing.T) {
+ tests := []struct {
+ name string
+ setupConn func() *mockConn
+ wantErr bool
+ wantErrMsg string
+ }{
+ {
+ name: "open channel returns error",
+ setupConn: func() *mockConn {
+ conn := &mockConn{}
+ conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).
+ Return((*testChannel)(nil), (<-chan *ssh.Request)(nil), errors.New("channel open failed"))
+ return conn
+ },
+ wantErr: true,
+ wantErrMsg: "channel open failed",
+ },
+ {
+ name: "open channel with nil channel",
+ setupConn: func() *mockConn {
+ conn := &mockConn{}
+ conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).
+ Return((*testChannel)(nil), (<-chan *ssh.Request)(nil), nil)
+ return conn
+ },
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(8).Maybe()
+
+ conn := tt.setupConn()
+ forwarder := New(cfg, slug.New(), conn).(*forwarder)
+ forwarder.SetForwardedPort(8080)
+
+ origin := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 9000}
+ _, _, err := forwarder.OpenForwardedChannel(context.Background(), origin)
+
+ if tt.wantErr {
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), tt.wantErrMsg)
+ } else {
+ assert.NoError(t, err)
+ }
+
+ conn.AssertExpectations(t)
+ cfg.AssertExpectations(t)
+ })
+ }
+}
+
+func TestOpenForwardedChannelContextCancelledDuringOpen(t *testing.T) {
+ cfg := &mockConfig{}
+ cfg.On("BufferSize").Return(8).Maybe()
+
+ channel := &testChannel{
+ readBuf: newSyncBuffer(),
+ writeBuf: newSyncBuffer(),
+ }
+ channel.On("Close").Return(nil).Maybe()
+
+ requests := make(chan *ssh.Request)
+
+ openChannelStarted := make(chan struct{})
+ openChannelBlock := make(chan struct{})
+
+ conn := &mockConn{}
+ conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Run(func(args mock.Arguments) {
+ close(openChannelStarted)
+ <-openChannelBlock
+ }).Return(channel, (<-chan *ssh.Request)(requests), nil)
+
+ forwarder := New(cfg, slug.New(), conn).(*forwarder)
+ forwarder.SetForwardedPort(8080)
+
+ ctx, cancel := context.WithCancel(context.Background())
+
+ resultChan := make(chan error, 1)
+ go func() {
+ origin := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 7000}
+ _, _, err := forwarder.OpenForwardedChannel(ctx, origin)
+ resultChan <- err
+ }()
+
+ <-openChannelStarted
+
+ cancel()
+
+ close(openChannelBlock)
+
+ select {
+ case err := <-resultChan:
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "context cancelled")
+ case <-time.After(2 * time.Second):
+ t.Fatal("OpenForwardedChannel did not return")
+ }
+
+ time.Sleep(50 * time.Millisecond)
+
+ conn.AssertExpectations(t)
+ cfg.AssertExpectations(t)
+ channel.AssertExpectations(t)
+}
+
+func TestCreateForwardedTCPIPPayloadEdgeCases(t *testing.T) {
+ tests := []struct {
+ name string
+ originAddr string
+ destPort uint16
+ wantDestAddr string
+ wantDestPort uint32
+ }{
+ {
+ name: "IPv4 localhost",
+ originAddr: "127.0.0.1:5000",
+ destPort: 8080,
+ wantDestAddr: "localhost",
+ wantDestPort: 8080,
+ },
+ {
+ name: "IPv6 address",
+ originAddr: "[::1]:3000",
+ destPort: 9000,
+ wantDestAddr: "localhost",
+ wantDestPort: 9000,
+ },
+ {
+ name: "private network",
+ originAddr: "192.168.1.1:12345",
+ destPort: 443,
+ wantDestAddr: "localhost",
+ wantDestPort: 443,
+ },
+ {
+ name: "port 1",
+ originAddr: "10.0.0.1:1",
+ destPort: 1,
+ wantDestAddr: "localhost",
+ wantDestPort: 1,
+ },
+ {
+ name: "max port",
+ originAddr: "172.16.0.1:65535",
+ destPort: 65535,
+ wantDestAddr: "localhost",
+ wantDestPort: 65535,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ addr, err := net.ResolveTCPAddr("tcp", tt.originAddr)
+ require.NoError(t, err)
+
+ payload := createForwardedTCPIPPayload(addr, tt.destPort)
+
+ var decoded struct {
+ DestAddr string
+ DestPort uint32
+ OriginAddr string
+ OriginPort uint32
+ }
+
+ err = ssh.Unmarshal(payload, &decoded)
+ require.NoError(t, err)
+
+ assert.Equal(t, tt.wantDestAddr, decoded.DestAddr)
+ assert.Equal(t, tt.wantDestPort, decoded.DestPort)
+ })
+ }
+}
diff --git a/session/interaction/commands.go b/session/interaction/commands.go
index e884aeb..5e24368 100644
--- a/session/interaction/commands.go
+++ b/session/interaction/commands.go
@@ -10,34 +10,37 @@ import (
"github.com/charmbracelet/lipgloss"
)
+func (m *model) handleCommandSelection(item commandItem) (tea.Model, tea.Cmd) {
+ switch item.name {
+ case "slug":
+ m.showingCommands = false
+ m.editingSlug = true
+ m.slugInput.SetValue(m.interaction.slug.String())
+ m.slugInput.Focus()
+ return m, tea.Batch(tea.ClearScreen, textinput.Blink)
+ case "tunnel-type":
+ m.showingCommands = false
+ m.showingComingSoon = true
+ return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
+ default:
+ m.showingCommands = false
+ return m, nil
+ }
+}
+
func (m *model) commandsUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch {
- case key.Matches(msg, m.keymap.quit):
+ case key.Matches(msg, m.keymap.quit), msg.String() == "esc":
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case msg.String() == "enter":
selectedItem := m.commandList.SelectedItem()
if selectedItem != nil {
item := selectedItem.(commandItem)
- if item.name == "slug" {
- m.showingCommands = false
- m.editingSlug = true
- m.slugInput.SetValue(m.interaction.slug.String())
- m.slugInput.Focus()
- return m, tea.Batch(tea.ClearScreen, textinput.Blink)
- } else if item.name == "tunnel-type" {
- m.showingCommands = false
- m.showingComingSoon = true
- return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
- }
- m.showingCommands = false
- return m, nil
+ return m.handleCommandSelection(item)
}
- case msg.String() == "esc":
- m.showingCommands = false
- return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
m.commandList, cmd = m.commandList.Update(msg)
return m, cmd
diff --git a/session/interaction/dashboard.go b/session/interaction/dashboard.go
index a24ab7c..cf10ddb 100644
--- a/session/interaction/dashboard.go
+++ b/session/interaction/dashboard.go
@@ -23,164 +23,194 @@ func (m *model) dashboardUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
}
func (m *model) dashboardView() string {
- titleStyle := lipgloss.NewStyle().
- Bold(true).
- Foreground(lipgloss.Color("#7D56F4")).
- PaddingTop(1)
-
- subtitleStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#888888")).
- Italic(true)
-
- urlStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#7D56F4")).
- Underline(true).
- Italic(true)
-
- urlBoxStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#04B575")).
- Bold(true).
- Italic(true)
-
- keyHintStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#7D56F4")).
- Bold(true)
+ isCompact := shouldUseCompactLayout(m.width, BreakpointLarge)
var b strings.Builder
+ b.WriteString(m.renderHeader(isCompact))
+ b.WriteString(m.renderUserInfo(isCompact))
+ b.WriteString(m.renderQuickActions(isCompact))
+ b.WriteString(m.renderFooter(isCompact))
- isCompact := shouldUseCompactLayout(m.width, 85)
+ return b.String()
+}
- var asciiArtMargin int
- if isCompact {
- asciiArtMargin = 0
- } else {
- asciiArtMargin = 1
- }
+func (m *model) renderHeader(isCompact bool) string {
+ var b strings.Builder
+ asciiArtMargin := getMarginValue(isCompact, 0, 1)
asciiArtStyle := lipgloss.NewStyle().
Bold(true).
- Foreground(lipgloss.Color("#7D56F4")).
+ Foreground(lipgloss.Color(ColorPrimary)).
MarginBottom(asciiArtMargin)
- var asciiArt string
- if shouldUseCompactLayout(m.width, 50) {
- asciiArt = "TUNNEL PLS"
- } else if isCompact {
- asciiArt = `
+ b.WriteString(asciiArtStyle.Render(m.getASCIIArt()))
+ b.WriteString("\n")
+
+ if !shouldUseCompactLayout(m.width, BreakpointSmall) {
+ b.WriteString(m.renderSubtitle())
+ } else {
+ b.WriteString("\n")
+ }
+
+ return b.String()
+}
+
+func (m *model) getASCIIArt() string {
+ if shouldUseCompactLayout(m.width, BreakpointTiny) {
+ return "TUNNEL PLS"
+ }
+
+ if shouldUseCompactLayout(m.width, BreakpointLarge) {
+ return `
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
- } else {
- asciiArt = `
+ }
+
+ return `
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
- }
+}
- b.WriteString(asciiArtStyle.Render(asciiArt))
- b.WriteString("\n")
+func (m *model) renderSubtitle() string {
+ subtitleStyle := lipgloss.NewStyle().
+ Foreground(lipgloss.Color(ColorGray)).
+ Italic(true)
- if !shouldUseCompactLayout(m.width, 60) {
- b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • "))
- b.WriteString(urlStyle.Render("https://fossy.my.id"))
- b.WriteString("\n\n")
- } else {
- b.WriteString("\n")
- }
+ urlStyle := lipgloss.NewStyle().
+ Foreground(lipgloss.Color(ColorPrimary)).
+ Underline(true).
+ Italic(true)
+ return subtitleStyle.Render("Secure tunnel service by Bagas • ") +
+ urlStyle.Render("https://fossy.my.id") + "\n\n"
+}
+
+func (m *model) renderUserInfo(isCompact bool) string {
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
- var boxPadding int
- var boxMargin int
- if isCompact {
- boxPadding = 1
- boxMargin = 1
- } else {
- boxPadding = 2
- boxMargin = 2
- }
+ boxPadding := getMarginValue(isCompact, 1, 2)
+ boxMargin := getMarginValue(isCompact, 1, 2)
responsiveInfoBox := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
- BorderForeground(lipgloss.Color("#7D56F4")).
+ BorderForeground(lipgloss.Color(ColorPrimary)).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(boxMaxWidth)
- authenticatedUser := m.interaction.user
+ infoContent := m.getUserInfoContent(isCompact)
+ return responsiveInfoBox.Render(infoContent) + "\n"
+}
+func (m *model) getUserInfoContent(isCompact bool) string {
userInfoStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#FAFAFA")).
+ Foreground(lipgloss.Color(ColorWhite)).
Bold(true)
sectionHeaderStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#888888")).
+ Foreground(lipgloss.Color(ColorGray)).
Bold(true)
addressStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#FAFAFA"))
+ Foreground(lipgloss.Color(ColorWhite))
- var infoContent string
- if shouldUseCompactLayout(m.width, 70) {
- infoContent = fmt.Sprintf("👤 %s\n\n%s\n%s",
+ urlBoxStyle := lipgloss.NewStyle().
+ Foreground(lipgloss.Color(ColorSecondary)).
+ Bold(true).
+ Italic(true)
+
+ authenticatedUser := m.interaction.user
+ tunnelURL := urlBoxStyle.Render(m.getTunnelURL())
+
+ if isCompact {
+ return fmt.Sprintf("👤 %s\n\n%s\n%s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
- addressStyle.Render(fmt.Sprintf(" %s", urlBoxStyle.Render(m.getTunnelURL()))))
- } else {
- infoContent = fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
- userInfoStyle.Render(authenticatedUser),
- sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
- addressStyle.Render(urlBoxStyle.Render(m.getTunnelURL())))
+ addressStyle.Render(fmt.Sprintf(" %s", tunnelURL)))
}
- b.WriteString(responsiveInfoBox.Render(infoContent))
+ return fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
+ userInfoStyle.Render(authenticatedUser),
+ sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
+ addressStyle.Render(tunnelURL))
+}
+
+func (m *model) renderQuickActions(isCompact bool) string {
+ var b strings.Builder
+
+ titleStyle := lipgloss.NewStyle().
+ Bold(true).
+ Foreground(lipgloss.Color(ColorPrimary)).
+ PaddingTop(1)
+
+ b.WriteString(titleStyle.Render(m.getQuickActionsTitle()))
b.WriteString("\n")
- var quickActionsTitle string
- if shouldUseCompactLayout(m.width, 50) {
- quickActionsTitle = "Actions"
- } else if isCompact {
- quickActionsTitle = "Quick Actions"
- } else {
- quickActionsTitle = "✨ Quick Actions"
- }
- b.WriteString(titleStyle.Render(quickActionsTitle))
- b.WriteString("\n")
-
- var featureMargin int
- if isCompact {
- featureMargin = 1
- } else {
- featureMargin = 2
- }
-
- compactFeatureStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#FAFAFA")).
+ featureMargin := getMarginValue(isCompact, 1, 2)
+ featureStyle := lipgloss.NewStyle().
+ Foreground(lipgloss.Color(ColorWhite)).
MarginLeft(featureMargin)
- var commandsText string
- var quitText string
- if shouldUseCompactLayout(m.width, 60) {
- commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]"))
- quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]"))
- } else {
- commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]"))
- quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]"))
- }
+ keyHintStyle := lipgloss.NewStyle().
+ Foreground(lipgloss.Color(ColorPrimary)).
+ Bold(true)
- b.WriteString(compactFeatureStyle.Render(commandsText))
+ commands := m.getActionCommands(keyHintStyle)
+ b.WriteString(featureStyle.Render(commands.commandsText))
b.WriteString("\n")
- b.WriteString(compactFeatureStyle.Render(quitText))
-
- if !shouldUseCompactLayout(m.width, 70) {
- b.WriteString("\n\n")
- footerStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#666666")).
- Italic(true)
- b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings"))
- }
+ b.WriteString(featureStyle.Render(commands.quitText))
return b.String()
}
+
+func (m *model) getQuickActionsTitle() string {
+ if shouldUseCompactLayout(m.width, BreakpointTiny) {
+ return "Actions"
+ }
+ if shouldUseCompactLayout(m.width, BreakpointLarge) {
+ return "Quick Actions"
+ }
+ return "✨ Quick Actions"
+}
+
+type actionCommands struct {
+ commandsText string
+ quitText string
+}
+
+func (m *model) getActionCommands(keyHintStyle lipgloss.Style) actionCommands {
+ if shouldUseCompactLayout(m.width, BreakpointSmall) {
+ return actionCommands{
+ commandsText: fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]")),
+ quitText: fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]")),
+ }
+ }
+
+ return actionCommands{
+ commandsText: fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]")),
+ quitText: fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]")),
+ }
+}
+
+func (m *model) renderFooter(isCompact bool) string {
+ if isCompact {
+ return ""
+ }
+
+ footerStyle := lipgloss.NewStyle().
+ Foreground(lipgloss.Color(ColorDarkGray)).
+ Italic(true)
+
+ return "\n\n" + footerStyle.Render("Press 'C' to customize your tunnel settings")
+}
+
+func getMarginValue(isCompact bool, compactValue, normalValue int) int {
+ if isCompact {
+ return compactValue
+ }
+ return normalValue
+}
diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go
index 5f68102..814cd67 100644
--- a/session/interaction/interaction.go
+++ b/session/interaction/interaction.go
@@ -3,7 +3,9 @@ package interaction
import (
"context"
"log"
+ "sync"
"tunnel_pls/internal/config"
+ "tunnel_pls/internal/random"
"tunnel_pls/session/slug"
"tunnel_pls/types"
@@ -39,6 +41,7 @@ type Forwarder interface {
type CloseFunc func() error
type interaction struct {
+ randomizer random.Random
config config.Config
channel ssh.Channel
slug slug.Slug
@@ -50,6 +53,7 @@ type interaction struct {
ctx context.Context
cancel context.CancelFunc
mode types.InteractiveMode
+ programMu sync.Mutex
}
func (i *interaction) SetMode(m types.InteractiveMode) {
@@ -76,9 +80,10 @@ func (i *interaction) SetWH(w, h int) {
}
}
-func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
+func New(randomizer random.Random, config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
ctx, cancel := context.WithCancel(context.Background())
return &interaction{
+ randomizer: randomizer,
config: config,
channel: nil,
slug: slug,
@@ -100,6 +105,10 @@ func (i *interaction) Stop() {
if i.cancel != nil {
i.cancel()
}
+
+ i.programMu.Lock()
+ defer i.programMu.Unlock()
+
if i.program != nil {
i.program.Kill()
i.program = nil
@@ -210,6 +219,7 @@ func (i *interaction) Start() {
ti.Width = 50
m := &model{
+ randomizer: i.randomizer,
domain: i.config.Domain(),
protocol: protocol,
tunnelType: tunnelType,
@@ -234,6 +244,7 @@ func (i *interaction) Start() {
help: help.New(),
}
+ i.programMu.Lock()
i.program = tea.NewProgram(
m,
tea.WithInput(i.channel),
@@ -244,16 +255,21 @@ func (i *interaction) Start() {
tea.WithoutSignalHandler(),
tea.WithFPS(30),
)
+ i.programMu.Unlock()
_, err := i.program.Run()
if err != nil {
log.Printf("Cannot close tea: %s \n", err)
}
- i.program.Kill()
- i.program = nil
+
+ i.programMu.Lock()
+ if i.program != nil {
+ i.program.Kill()
+ i.program = nil
+ }
+ i.programMu.Unlock()
+
if i.closeFunc != nil {
- if err := i.closeFunc(); err != nil {
- log.Printf("Cannot close session: %s \n", err)
- }
+ _ = i.closeFunc()
}
}
diff --git a/session/interaction/interaction_test.go b/session/interaction/interaction_test.go
new file mode 100644
index 0000000..4992093
--- /dev/null
+++ b/session/interaction/interaction_test.go
@@ -0,0 +1,2310 @@
+package interaction
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "testing"
+ "time"
+ "tunnel_pls/types"
+
+ "github.com/charmbracelet/bubbles/key"
+ "github.com/charmbracelet/bubbles/list"
+ "github.com/charmbracelet/bubbles/textinput"
+ tea "github.com/charmbracelet/bubbletea"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "golang.org/x/crypto/ssh"
+)
+
+type MockRandom struct {
+ mock.Mock
+}
+
+func (m *MockRandom) String(length int) (string, error) {
+ args := m.Called(length)
+ return args.String(0), args.Error(1)
+}
+
+type MockConfig struct {
+ mock.Mock
+}
+
+func (m *MockConfig) Domain() string { return m.Called().String(0) }
+func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
+func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
+func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
+func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
+func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
+func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
+func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
+func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
+func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
+func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
+func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
+func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
+func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
+func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
+func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
+func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
+func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
+func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
+func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
+func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
+
+type MockSlug struct {
+ mock.Mock
+}
+
+func (ms *MockSlug) Set(slug string) { ms.Called(slug) }
+func (ms *MockSlug) String() string { return ms.Called().String(0) }
+
+type MockForwarder struct {
+ mock.Mock
+}
+
+func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
+ args := m.Called(origin)
+ return args.Get(0).([]byte)
+}
+
+func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
+ m.Called(dst, src)
+}
+
+func (m *MockForwarder) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+func (m *MockForwarder) TunnelType() types.TunnelType {
+ args := m.Called()
+ return args.Get(0).(types.TunnelType)
+}
+
+func (m *MockForwarder) ForwardedPort() uint16 {
+ args := m.Called()
+ return args.Get(0).(uint16)
+}
+
+func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
+ m.Called(tunnelType)
+}
+
+func (m *MockForwarder) SetForwardedPort(port uint16) {
+ m.Called(port)
+}
+
+func (m *MockForwarder) SetListener(listener net.Listener) {
+ m.Called(listener)
+}
+
+func (m *MockForwarder) Listener() net.Listener {
+ args := m.Called()
+ return args.Get(0).(net.Listener)
+}
+
+func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
+ args := m.Called(ctx, origin)
+ if args.Get(0) == nil {
+ return nil, nil, args.Error(2)
+ }
+ return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
+}
+
+type MockSessionRegistry struct {
+ mock.Mock
+}
+
+func (m *MockSessionRegistry) Update(user string, oldKey, newKey types.SessionKey) error {
+ args := m.Called(user, oldKey, newKey)
+ return args.Error(0)
+}
+
+type MockChannel struct {
+ mock.Mock
+ data []byte
+}
+
+func (m *MockChannel) Read(b []byte) (n int, err error) {
+ args := m.Called(b)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *MockChannel) Write(b []byte) (n int, err error) {
+ m.data = append(m.data, b...)
+ args := m.Called(b)
+ return args.Int(0), args.Error(1)
+}
+
+func (m *MockChannel) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+func (m *MockChannel) CloseWrite() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+func (m *MockChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
+ args := m.Called(name, wantReply, payload)
+ return args.Bool(0), args.Error(1)
+}
+
+func (m *MockChannel) Stderr() io.ReadWriter {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return nil
+ }
+ return args.Get(0).(io.ReadWriter)
+}
+
+type MockCloser struct {
+ mock.Mock
+}
+
+func (m *MockCloser) Close() error { return m.Called().Error(0) }
+
+func TestNew(t *testing.T) {
+ tests := []struct {
+ name string
+ user string
+ }{
+ {
+ name: "creates interaction with default mode",
+ user: "testuser",
+ },
+ {
+ name: "creates interaction for different user",
+ user: "anotheruser",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, tt.user, mockCloser.Close)
+
+ assert.NotNil(t, mockInteraction)
+ })
+ }
+}
+
+func TestInteraction_SetMode(t *testing.T) {
+ tests := []struct {
+ name string
+ mode types.InteractiveMode
+ }{
+ {
+ name: "set headless mode",
+ mode: types.InteractiveModeHEADLESS,
+ },
+ {
+ name: "set interactive mode",
+ mode: types.InteractiveModeINTERACTIVE,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close)
+ mockInteraction.SetMode(tt.mode)
+
+ assert.Equal(t, tt.mode, mockInteraction.Mode())
+ })
+ }
+}
+
+func TestInteraction_Mode(t *testing.T) {
+ tests := []struct {
+ name string
+ setMode types.InteractiveMode
+ expected types.InteractiveMode
+ }{
+ {
+ name: "mode returns set value",
+ setMode: types.InteractiveModeINTERACTIVE,
+ expected: types.InteractiveModeINTERACTIVE,
+ },
+ {
+ name: "mode returns headless value",
+ setMode: types.InteractiveModeHEADLESS,
+ expected: types.InteractiveModeHEADLESS,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close)
+
+ mockInteraction.SetMode(tt.setMode)
+ assert.Equal(t, tt.expected, mockInteraction.Mode())
+ })
+ }
+}
+
+func TestInteraction_Send(t *testing.T) {
+ tests := []struct {
+ name string
+ message string
+ setupChannel bool
+ channelReturn int
+ channelError error
+ wantError bool
+ }{
+ {
+ name: "send message successfully",
+ message: "test message",
+ setupChannel: true,
+ channelReturn: 12,
+ channelError: nil,
+ wantError: false,
+ },
+ {
+ name: "send message with channel error",
+ message: "test message",
+ setupChannel: true,
+ channelReturn: 0,
+ channelError: errors.New("channel write error"),
+ wantError: true,
+ },
+ {
+ name: "send message without channel",
+ message: "test message",
+ setupChannel: false,
+ wantError: false,
+ },
+ {
+ name: "send empty message",
+ message: "",
+ setupChannel: true,
+ channelReturn: 0,
+ channelError: nil,
+ wantError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close)
+
+ if tt.setupChannel {
+ mockChannel := &MockChannel{}
+ mockChannel.On("Write", []byte(tt.message)).Return(tt.channelReturn, tt.channelError)
+ mockInteraction.SetChannel(mockChannel)
+ }
+
+ err := mockInteraction.Send(tt.message)
+
+ if tt.wantError {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ })
+ }
+}
+
+func TestInteraction_SetWH(t *testing.T) {
+ tests := []struct {
+ name string
+ width int
+ height int
+ }{
+ {
+ name: "set large window size",
+ width: 100,
+ height: 50,
+ },
+ {
+ name: "set medium window size",
+ width: 80,
+ height: 24,
+ },
+ {
+ name: "set small window size",
+ width: 20,
+ height: 10,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close)
+
+ mockInteraction.SetWH(tt.width, tt.height)
+ })
+ }
+}
+
+func TestInteraction_SetChannel(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close)
+
+ mockChannel := &MockChannel{}
+ mockInteraction.SetChannel(mockChannel)
+
+ mockChannel.On("Write", []byte("test")).Return(4, nil)
+ err := mockInteraction.Send("test")
+ assert.NoError(t, err)
+}
+
+func TestInteraction_Redraw(t *testing.T) {
+ tests := []struct {
+ name string
+ description string
+ }{
+ {
+ name: "redraw interaction",
+ description: "should not panic when calling redraw",
+ },
+ {
+ name: "redraw multiple times",
+ description: "should handle multiple redraws",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close)
+
+ mockInteraction.Redraw()
+ })
+ }
+}
+
+func TestInteraction_Start(t *testing.T) {
+ tests := []struct {
+ name string
+ mode types.InteractiveMode
+ tlsEnabled bool
+ tunnelType types.TunnelType
+ port uint16
+ }{
+ {
+ name: "start in headless mode - should return immediately",
+ mode: types.InteractiveModeHEADLESS,
+ tlsEnabled: false,
+ tunnelType: types.TunnelTypeHTTP,
+ port: 8080,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close)
+ mockInteraction.SetMode(tt.mode)
+
+ mockConfig.On("Domain").Return("tunnl.live")
+ mockConfig.On("TLSEnabled").Return(tt.tlsEnabled)
+ mockForwarder.On("TunnelType").Return(tt.tunnelType)
+ mockForwarder.On("ForwardedPort").Return(tt.port)
+
+ mockInteraction.Start()
+ })
+ }
+}
+
+func TestModel_Update(t *testing.T) {
+ tests := []struct {
+ name string
+ msg tea.Msg
+ showingComingSoon bool
+ editingSlug bool
+ showingCommands bool
+ width int
+ height int
+ expectedWidth int
+ expectedHeight int
+ expectedQuit bool
+ }{
+ {
+ name: "tick message clears coming soon",
+ msg: tickMsg{},
+ showingComingSoon: true,
+ editingSlug: false,
+ showingCommands: false,
+ expectedQuit: false,
+ },
+ {
+ name: "window size message - large screen",
+ msg: tea.WindowSizeMsg{Width: 100, Height: 50},
+ expectedWidth: 100,
+ expectedHeight: 50,
+ expectedQuit: false,
+ },
+ {
+ name: "window size message - small screen",
+ msg: tea.WindowSizeMsg{Width: 60, Height: 20},
+ expectedWidth: 60,
+ expectedHeight: 20,
+ expectedQuit: false,
+ },
+ {
+ name: "quit message",
+ msg: tea.QuitMsg{},
+ expectedQuit: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close)
+
+ m := &model{
+ randomizer: mockRandom,
+ domain: "tunnl.live",
+ protocol: "http",
+ tunnelType: types.TunnelTypeHTTP,
+ port: 8080,
+ commandList: list.New([]list.Item{}, list.NewDefaultDelegate(), 80, 20),
+ interaction: mockInteraction.(*interaction),
+ showingComingSoon: tt.showingComingSoon,
+ editingSlug: tt.editingSlug,
+ showingCommands: tt.showingCommands,
+ width: tt.width,
+ height: tt.height,
+ keymap: keymap{
+ quit: key.NewBinding(
+ key.WithKeys("q", "ctrl+c"),
+ key.WithHelp("q", "quit"),
+ ),
+ command: key.NewBinding(
+ key.WithKeys("c"),
+ key.WithHelp("c", "commands"),
+ ),
+ random: key.NewBinding(
+ key.WithKeys("ctrl+r"),
+ key.WithHelp("ctrl+r", "random"),
+ ),
+ },
+ }
+
+ result, _ := m.Update(tt.msg)
+ updatedModel := result.(*model)
+
+ if tt.expectedQuit {
+ assert.True(t, updatedModel.quitting)
+ }
+
+ if windowMsg, ok := tt.msg.(tea.WindowSizeMsg); ok {
+ assert.Equal(t, windowMsg.Width, updatedModel.width)
+ assert.Equal(t, windowMsg.Height, updatedModel.height)
+ }
+
+ if _, ok := tt.msg.(tickMsg); ok && tt.showingComingSoon {
+ assert.False(t, updatedModel.showingComingSoon)
+ }
+ })
+ }
+}
+
+func TestModel_View(t *testing.T) {
+ tests := []struct {
+ name string
+ quitting bool
+ showingComingSoon bool
+ editingSlug bool
+ showingCommands bool
+ expectedEmpty bool
+ }{
+ {
+ name: "quitting returns empty string",
+ quitting: true,
+ expectedEmpty: true,
+ },
+ {
+ name: "showing coming soon view",
+ showingComingSoon: true,
+ expectedEmpty: false,
+ },
+ {
+ name: "editing slug view",
+ editingSlug: true,
+ expectedEmpty: false,
+ },
+ {
+ name: "showing commands view",
+ showingCommands: true,
+ expectedEmpty: false,
+ },
+ {
+ name: "dashboard view (default)",
+ expectedEmpty: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close)
+
+ mockSlug.On("String").Return("test-slug")
+
+ m := &model{
+ randomizer: mockRandom,
+ domain: "tunnl.live",
+ protocol: "http",
+ tunnelType: types.TunnelTypeHTTP,
+ port: 8080,
+ commandList: list.New([]list.Item{}, list.NewDefaultDelegate(), 80, 20),
+ interaction: mockInteraction.(*interaction),
+ quitting: tt.quitting,
+ showingComingSoon: tt.showingComingSoon,
+ editingSlug: tt.editingSlug,
+ showingCommands: tt.showingCommands,
+ keymap: keymap{
+ quit: key.NewBinding(
+ key.WithKeys("q", "ctrl+c"),
+ key.WithHelp("q", "quit"),
+ ),
+ command: key.NewBinding(
+ key.WithKeys("c"),
+ key.WithHelp("c", "commands"),
+ ),
+ random: key.NewBinding(
+ key.WithKeys("ctrl+r"),
+ key.WithHelp("ctrl+r", "random"),
+ ),
+ },
+ }
+
+ view := m.View()
+
+ if tt.expectedEmpty {
+ assert.Empty(t, view)
+ } else {
+ assert.NotEmpty(t, view)
+ }
+ })
+ }
+}
+
+func TestInteraction_Integration(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeCallCount := 0
+ closeFunc := func() error {
+ closeCallCount++
+ return nil
+ }
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ assert.NotNil(t, mockInteraction)
+
+ mockInteraction.SetMode(types.InteractiveModeINTERACTIVE)
+ assert.Equal(t, types.InteractiveModeINTERACTIVE, mockInteraction.Mode())
+
+ mockChannel := &MockChannel{}
+ mockInteraction.SetChannel(mockChannel)
+
+ mockChannel.On("Write", []byte("hello")).Return(5, nil)
+ err := mockInteraction.Send("hello")
+ assert.NoError(t, err)
+ mockChannel.AssertExpectations(t)
+
+ mockInteraction.SetWH(80, 24)
+
+ mockInteraction.Redraw()
+}
+
+func TestModel_Update_KeyMessages(t *testing.T) {
+ tests := []struct {
+ name string
+ key tea.KeyMsg
+ showingComingSoon bool
+ editingSlug bool
+ showingCommands bool
+ description string
+ }{
+ {
+ name: "key press while showing coming soon",
+ key: tea.KeyMsg{Type: tea.KeyEnter},
+ showingComingSoon: true,
+ description: "should call comingSoonUpdate",
+ },
+ {
+ name: "key press while editing slug",
+ key: tea.KeyMsg{Type: tea.KeyEnter},
+ editingSlug: true,
+ description: "should call slugUpdate",
+ },
+ {
+ name: "key press while showing commands",
+ key: tea.KeyMsg{Type: tea.KeyEnter},
+ showingCommands: true,
+ description: "should call commandsUpdate",
+ },
+ {
+ name: "key press in dashboard view",
+ key: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'c'}},
+ description: "should call dashboardUpdate",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close)
+
+ mockSlug.On("String").Return("test-slug").Maybe()
+ mockSessionRegistry.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(nil)
+
+ m := &model{
+ randomizer: mockRandom,
+ domain: "tunnl.live",
+ protocol: "http",
+ tunnelType: types.TunnelTypeHTTP,
+ port: 8080,
+ commandList: list.New([]list.Item{}, list.NewDefaultDelegate(), 80, 20),
+ interaction: mockInteraction.(*interaction),
+ showingComingSoon: tt.showingComingSoon,
+ editingSlug: tt.editingSlug,
+ showingCommands: tt.showingCommands,
+ keymap: keymap{
+ quit: key.NewBinding(
+ key.WithKeys("q", "ctrl+c"),
+ key.WithHelp("q", "quit"),
+ ),
+ command: key.NewBinding(
+ key.WithKeys("c"),
+ key.WithHelp("c", "commands"),
+ ),
+ random: key.NewBinding(
+ key.WithKeys("ctrl+r"),
+ key.WithHelp("ctrl+r", "random"),
+ ),
+ },
+ }
+
+ result, _ := m.Update(tt.key)
+ assert.NotNil(t, result)
+ })
+ }
+}
+
+func TestModel_SlugUpdate(t *testing.T) {
+ tests := []struct {
+ name string
+ tunnelType types.TunnelType
+ keyMsg tea.KeyMsg
+ inputValue string
+ setupMocks func(*MockSessionRegistry, *MockSlug, *MockRandom)
+ expectedEdit bool
+ expectedError string
+ shouldSetValue bool
+ }{
+ {
+ name: "escape key cancels editing",
+ tunnelType: types.TunnelTypeHTTP,
+ keyMsg: tea.KeyMsg{Type: tea.KeyEsc},
+ setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {},
+ expectedEdit: false,
+ },
+ {
+ name: "ctrl+c cancels editing",
+ tunnelType: types.TunnelTypeHTTP,
+ keyMsg: tea.KeyMsg{Type: tea.KeyCtrlC},
+ setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {},
+ expectedEdit: false,
+ },
+ {
+ name: "enter key saves valid slug",
+ tunnelType: types.TunnelTypeHTTP,
+ keyMsg: tea.KeyMsg{Type: tea.KeyEnter},
+ inputValue: "my-custom-slug",
+ setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {
+ ms.On("String").Return("old-slug")
+ msr.On("Update", "testuser",
+ types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP},
+ types.SessionKey{Id: "my-custom-slug", Type: types.TunnelTypeHTTP},
+ ).Return(nil)
+ },
+ expectedEdit: false,
+ expectedError: "",
+ },
+ {
+ name: "enter key with error shows error message",
+ tunnelType: types.TunnelTypeHTTP,
+ keyMsg: tea.KeyMsg{Type: tea.KeyEnter},
+ inputValue: "invalid",
+ setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {
+ ms.On("String").Return("old-slug")
+ msr.On("Update", "testuser",
+ types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP},
+ types.SessionKey{Id: "invalid", Type: types.TunnelTypeHTTP},
+ ).Return(assert.AnError)
+ },
+ expectedEdit: true,
+ expectedError: assert.AnError.Error(),
+ },
+ {
+ name: "ctrl+r generates random slug",
+ tunnelType: types.TunnelTypeHTTP,
+ keyMsg: tea.KeyMsg{Type: tea.KeyCtrlR},
+ setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {
+ mr.On("String", 20).Return("random-generated-slug", nil)
+ },
+ expectedEdit: true,
+ shouldSetValue: true,
+ },
+ {
+ name: "ctrl+r with error does nothing",
+ tunnelType: types.TunnelTypeHTTP,
+ keyMsg: tea.KeyMsg{Type: tea.KeyCtrlR},
+ setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {
+ mr.On("String", 20).Return("", assert.AnError)
+ },
+ expectedEdit: true,
+ },
+ {
+ name: "regular key updates input",
+ tunnelType: types.TunnelTypeHTTP,
+ keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'a'}},
+ setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {},
+ expectedEdit: true,
+ },
+ {
+ name: "tcp tunnel exits editing immediately",
+ tunnelType: types.TunnelTypeTCP,
+ keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'a'}},
+ setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {},
+ expectedEdit: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close)
+
+ ti := textinput.New()
+ ti.SetValue(tt.inputValue)
+
+ m := &model{
+ randomizer: mockRandom,
+ domain: "tunnl.live",
+ protocol: "http",
+ tunnelType: tt.tunnelType,
+ port: 8080,
+ commandList: list.New([]list.Item{}, list.NewDefaultDelegate(), 80, 20),
+ slugInput: ti,
+ editingSlug: true,
+ interaction: mockInteraction.(*interaction),
+ keymap: keymap{
+ quit: key.NewBinding(
+ key.WithKeys("q", "ctrl+c"),
+ key.WithHelp("q", "quit"),
+ ),
+ command: key.NewBinding(
+ key.WithKeys("c"),
+ key.WithHelp("c", "commands"),
+ ),
+ random: key.NewBinding(
+ key.WithKeys("ctrl+r"),
+ key.WithHelp("ctrl+r", "random"),
+ ),
+ },
+ }
+
+ tt.setupMocks(mockSessionRegistry, mockSlug, mockRandom)
+
+ result, _ := m.slugUpdate(tt.keyMsg)
+ resultModel := result.(*model)
+
+ assert.Equal(t, tt.expectedEdit, resultModel.editingSlug)
+ if tt.expectedError != "" {
+ assert.Equal(t, tt.expectedError, resultModel.slugError)
+ } else if !tt.expectedEdit {
+ assert.Equal(t, "", resultModel.slugError)
+ }
+
+ mockSessionRegistry.AssertExpectations(t)
+ mockSlug.AssertExpectations(t)
+ mockRandom.AssertExpectations(t)
+ })
+ }
+}
+
+func TestModel_SlugView(t *testing.T) {
+ tests := []struct {
+ name string
+ width int
+ tunnelType types.TunnelType
+ slugError string
+ contains string
+ }{
+ {
+ name: "http tunnel - large screen",
+ width: 100,
+ tunnelType: types.TunnelTypeHTTP,
+ contains: "Subdomain",
+ },
+ {
+ name: "http tunnel - small screen",
+ width: 50,
+ tunnelType: types.TunnelTypeHTTP,
+ contains: "Subdomain",
+ },
+ {
+ name: "http tunnel - tiny screen",
+ width: 30,
+ tunnelType: types.TunnelTypeHTTP,
+ contains: "Subdomain",
+ },
+ {
+ name: "http tunnel with error",
+ width: 100,
+ tunnelType: types.TunnelTypeHTTP,
+ slugError: "Slug already exists",
+ contains: "Slug already exists",
+ },
+ {
+ name: "tcp tunnel - large screen",
+ width: 100,
+ tunnelType: types.TunnelTypeTCP,
+ contains: "TCP",
+ },
+ {
+ name: "tcp tunnel - small screen",
+ width: 50,
+ tunnelType: types.TunnelTypeTCP,
+ contains: "TCP",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close)
+
+ ti := textinput.New()
+ ti.SetValue("test-slug")
+
+ m := &model{
+ randomizer: mockRandom,
+ domain: "tunnl.live",
+ protocol: "http",
+ tunnelType: tt.tunnelType,
+ port: 8080,
+ slugInput: ti,
+ slugError: tt.slugError,
+ interaction: mockInteraction.(*interaction),
+ width: tt.width,
+ }
+
+ view := m.slugView()
+ assert.NotEmpty(t, view)
+ assert.Contains(t, view, tt.contains)
+ })
+ }
+}
+
+func TestModel_ComingSoonUpdate(t *testing.T) {
+ tests := []struct {
+ name string
+ keyMsg tea.KeyMsg
+ }{
+ {
+ name: "any key dismisses coming soon",
+ keyMsg: tea.KeyMsg{Type: tea.KeyEnter},
+ },
+ {
+ name: "escape key dismisses",
+ keyMsg: tea.KeyMsg{Type: tea.KeyEsc},
+ },
+ {
+ name: "space key dismisses",
+ keyMsg: tea.KeyMsg{Type: tea.KeySpace},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close)
+
+ m := &model{
+ interaction: mockInteraction.(*interaction),
+ showingComingSoon: true,
+ }
+
+ result, _ := m.comingSoonUpdate(tt.keyMsg)
+ resultModel := result.(*model)
+
+ assert.False(t, resultModel.showingComingSoon)
+ })
+ }
+}
+
+func TestModel_ComingSoonView(t *testing.T) {
+ tests := []struct {
+ name string
+ width int
+ }{
+ {
+ name: "large screen",
+ width: 100,
+ },
+ {
+ name: "medium screen",
+ width: 60,
+ },
+ {
+ name: "small screen",
+ width: 50,
+ },
+ {
+ name: "tiny screen",
+ width: 30,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close)
+
+ m := &model{
+ interaction: mockInteraction.(*interaction),
+ width: tt.width,
+ }
+
+ view := m.comingSoonView()
+ assert.NotEmpty(t, view)
+ assert.Contains(t, view, "Coming")
+ })
+ }
+}
+
+func TestModel_CommandsUpdate(t *testing.T) {
+ tests := []struct {
+ name string
+ keyMsg tea.KeyMsg
+ selectedItem list.Item
+ expectCommands bool
+ expectEditSlug bool
+ expectComingSoon bool
+ }{
+ {
+ name: "escape key closes commands",
+ keyMsg: tea.KeyMsg{Type: tea.KeyEsc},
+ expectCommands: false,
+ },
+ {
+ name: "q key closes commands",
+ keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'q'}},
+ expectCommands: false,
+ },
+ {
+ name: "enter on slug command starts editing",
+ keyMsg: tea.KeyMsg{Type: tea.KeyEnter},
+ selectedItem: commandItem{name: "slug", desc: "Set custom subdomain"},
+ expectCommands: false,
+ expectEditSlug: true,
+ },
+ {
+ name: "enter on tunnel-type shows coming soon",
+ keyMsg: tea.KeyMsg{Type: tea.KeyEnter},
+ selectedItem: commandItem{name: "tunnel-type", desc: "Change tunnel type"},
+ expectCommands: false,
+ expectComingSoon: true,
+ },
+ {
+ name: "arrow key navigates list",
+ keyMsg: tea.KeyMsg{Type: tea.KeyDown},
+ expectCommands: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close)
+
+ mockSlug.On("String").Return("current-slug").Maybe()
+
+ items := []list.Item{
+ commandItem{name: "slug", desc: "Set custom subdomain"},
+ commandItem{name: "tunnel-type", desc: "Change tunnel type"},
+ }
+
+ delegate := list.NewDefaultDelegate()
+ commandList := list.New(items, delegate, 80, 20)
+ if tt.selectedItem != nil {
+ for i, item := range items {
+ if item.(commandItem).name == tt.selectedItem.(commandItem).name {
+ commandList.Select(i)
+ break
+ }
+ }
+ }
+
+ ti := textinput.New()
+
+ m := &model{
+ randomizer: mockRandom,
+ interaction: mockInteraction.(*interaction),
+ showingCommands: true,
+ commandList: commandList,
+ slugInput: ti,
+ keymap: keymap{
+ quit: key.NewBinding(
+ key.WithKeys("q", "ctrl+c"),
+ key.WithHelp("q", "quit"),
+ ),
+ command: key.NewBinding(
+ key.WithKeys("c"),
+ key.WithHelp("c", "commands"),
+ ),
+ random: key.NewBinding(
+ key.WithKeys("ctrl+r"),
+ key.WithHelp("ctrl+r", "random"),
+ ),
+ },
+ }
+
+ result, _ := m.commandsUpdate(tt.keyMsg)
+ resultModel := result.(*model)
+
+ assert.Equal(t, tt.expectCommands, resultModel.showingCommands)
+ if tt.expectEditSlug {
+ assert.True(t, resultModel.editingSlug)
+ }
+ if tt.expectComingSoon {
+ assert.True(t, resultModel.showingComingSoon)
+ }
+ })
+ }
+}
+
+func TestModel_CommandsView(t *testing.T) {
+ tests := []struct {
+ name string
+ width int
+ }{
+ {
+ name: "large screen",
+ width: 100,
+ },
+ {
+ name: "medium screen",
+ width: 60,
+ },
+ {
+ name: "small screen",
+ width: 50,
+ },
+ {
+ name: "tiny screen",
+ width: 30,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close)
+
+ items := []list.Item{
+ commandItem{name: "slug", desc: "Set custom subdomain"},
+ commandItem{name: "tunnel-type", desc: "Change tunnel type"},
+ }
+
+ delegate := list.NewDefaultDelegate()
+ commandList := list.New(items, delegate, 80, 20)
+
+ m := &model{
+ interaction: mockInteraction.(*interaction),
+ commandList: commandList,
+ width: tt.width,
+ }
+
+ view := m.commandsView()
+ assert.NotEmpty(t, view)
+ assert.Contains(t, view, "Commands")
+ })
+ }
+}
+
+func TestModel_DashboardUpdate(t *testing.T) {
+ tests := []struct {
+ name string
+ keyMsg tea.KeyMsg
+ expectQuit bool
+ expectCommands bool
+ }{
+ {
+ name: "q key quits",
+ keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'q'}},
+ expectQuit: true,
+ },
+ {
+ name: "ctrl+c quits",
+ keyMsg: tea.KeyMsg{Type: tea.KeyCtrlC},
+ expectQuit: true,
+ },
+ {
+ name: "c key opens commands",
+ keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'c'}},
+ expectCommands: true,
+ },
+ {
+ name: "other keys do nothing",
+ keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close)
+
+ m := &model{
+ interaction: mockInteraction.(*interaction),
+ keymap: keymap{
+ quit: key.NewBinding(
+ key.WithKeys("q", "ctrl+c"),
+ key.WithHelp("q", "quit"),
+ ),
+ command: key.NewBinding(
+ key.WithKeys("c"),
+ key.WithHelp("c", "commands"),
+ ),
+ random: key.NewBinding(
+ key.WithKeys("ctrl+r"),
+ key.WithHelp("ctrl+r", "random"),
+ ),
+ },
+ }
+
+ result, _ := m.dashboardUpdate(tt.keyMsg)
+ resultModel := result.(*model)
+
+ if tt.expectQuit {
+ assert.True(t, resultModel.quitting)
+ }
+ if tt.expectCommands {
+ assert.True(t, resultModel.showingCommands)
+ }
+ })
+ }
+}
+
+func TestModel_DashboardView(t *testing.T) {
+ tests := []struct {
+ name string
+ width int
+ tunnelType types.TunnelType
+ protocol string
+ port uint16
+ contains string
+ }{
+ {
+ name: "http tunnel - large screen",
+ width: 100,
+ tunnelType: types.TunnelTypeHTTP,
+ protocol: "http",
+ contains: "http",
+ },
+ {
+ name: "https tunnel - large screen",
+ width: 100,
+ tunnelType: types.TunnelTypeHTTP,
+ protocol: "https",
+ contains: "https",
+ },
+ {
+ name: "tcp tunnel - large screen",
+ width: 100,
+ tunnelType: types.TunnelTypeTCP,
+ port: 8080,
+ contains: "tcp",
+ },
+ {
+ name: "http tunnel - medium screen",
+ width: 70,
+ tunnelType: types.TunnelTypeHTTP,
+ protocol: "http",
+ contains: "http",
+ },
+ {
+ name: "http tunnel - small screen",
+ width: 50,
+ tunnelType: types.TunnelTypeHTTP,
+ protocol: "http",
+ contains: "http",
+ },
+ {
+ name: "http tunnel - tiny screen",
+ width: 30,
+ tunnelType: types.TunnelTypeHTTP,
+ protocol: "http",
+ contains: "http",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close)
+
+ mockSlug.On("String").Return("test-slug")
+
+ m := &model{
+ randomizer: mockRandom,
+ domain: "tunnl.live",
+ protocol: tt.protocol,
+ tunnelType: tt.tunnelType,
+ port: tt.port,
+ interaction: mockInteraction.(*interaction),
+ width: tt.width,
+ }
+
+ view := m.dashboardView()
+ assert.NotEmpty(t, view)
+ assert.Contains(t, view, tt.contains)
+ })
+ }
+}
+
+func TestGetResponsiveWidth(t *testing.T) {
+ tests := []struct {
+ name string
+ screenWidth int
+ padding int
+ minWidth int
+ maxWidth int
+ expected int
+ }{
+ {
+ name: "screen wider than max",
+ screenWidth: 100,
+ padding: 10,
+ minWidth: 20,
+ maxWidth: 60,
+ expected: 60,
+ },
+ {
+ name: "screen narrower than min",
+ screenWidth: 30,
+ padding: 10,
+ minWidth: 40,
+ maxWidth: 80,
+ expected: 40,
+ },
+ {
+ name: "screen within range",
+ screenWidth: 70,
+ padding: 10,
+ minWidth: 20,
+ maxWidth: 80,
+ expected: 60,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := getResponsiveWidth(tt.screenWidth, tt.padding, tt.minWidth, tt.maxWidth)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestShouldUseCompactLayout(t *testing.T) {
+ tests := []struct {
+ name string
+ width int
+ threshold int
+ expected bool
+ }{
+ {
+ name: "width below threshold",
+ width: 50,
+ threshold: 60,
+ expected: true,
+ },
+ {
+ name: "width at threshold",
+ width: 60,
+ threshold: 60,
+ expected: false,
+ },
+ {
+ name: "width above threshold",
+ width: 70,
+ threshold: 60,
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := shouldUseCompactLayout(tt.width, tt.threshold)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestTruncateString(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ maxLength int
+ expected string
+ }{
+ {
+ name: "string shorter than max",
+ input: "short",
+ maxLength: 10,
+ expected: "short",
+ },
+ {
+ name: "string equal to max",
+ input: "exactly10c",
+ maxLength: 10,
+ expected: "exactly10c",
+ },
+ {
+ name: "string longer than max",
+ input: "this is a very long string",
+ maxLength: 10,
+ expected: "this is...",
+ },
+ {
+ name: "very short max length",
+ input: "hello",
+ maxLength: 3,
+ expected: "hel",
+ },
+ {
+ name: "max length less than 4",
+ input: "hello",
+ maxLength: 2,
+ expected: "he",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := truncateString(tt.input, tt.maxLength)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestBuildURL(t *testing.T) {
+ tests := []struct {
+ name string
+ protocol string
+ subdomain string
+ domain string
+ expected string
+ }{
+ {
+ name: "http url",
+ protocol: "http",
+ subdomain: "test",
+ domain: "tunnl.live",
+ expected: "http://test.tunnl.live",
+ },
+ {
+ name: "https url",
+ protocol: "https",
+ subdomain: "api",
+ domain: "myapp.io",
+ expected: "https://api.myapp.io",
+ },
+ {
+ name: "custom subdomain",
+ protocol: "http",
+ subdomain: "my-custom-slug",
+ domain: "tunnl.live",
+ expected: "http://my-custom-slug.tunnl.live",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := buildURL(tt.protocol, tt.subdomain, tt.domain)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestTickCmd(t *testing.T) {
+ tests := []struct {
+ name string
+ duration time.Duration
+ }{
+ {
+ name: "5 second tick",
+ duration: 5 * time.Second,
+ },
+ {
+ name: "1 second tick",
+ duration: 1 * time.Second,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cmd := tickCmd(tt.duration)
+ assert.NotNil(t, cmd)
+ })
+ }
+}
+
+func TestGetPaddingValue(t *testing.T) {
+ tests := []struct {
+ name string
+ isVeryCompact bool
+ isCompact bool
+ expected int
+ }{
+ {
+ name: "very compact layout",
+ isVeryCompact: true,
+ isCompact: false,
+ expected: 1,
+ },
+ {
+ name: "compact layout",
+ isVeryCompact: false,
+ isCompact: true,
+ expected: 1,
+ },
+ {
+ name: "normal layout",
+ isVeryCompact: false,
+ isCompact: false,
+ expected: 2,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := getPaddingValue(tt.isVeryCompact, tt.isCompact)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestGetMarginValue(t *testing.T) {
+ tests := []struct {
+ name string
+ isCompact bool
+ compactValue int
+ normalValue int
+ expected int
+ }{
+ {
+ name: "compact layout",
+ isCompact: true,
+ compactValue: 1,
+ normalValue: 2,
+ expected: 1,
+ },
+ {
+ name: "normal layout",
+ isCompact: false,
+ compactValue: 1,
+ normalValue: 2,
+ expected: 2,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := getMarginValue(tt.isCompact, tt.compactValue, tt.normalValue)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestCommandItem(t *testing.T) {
+ tests := []struct {
+ name string
+ item commandItem
+ wantName string
+ wantDesc string
+ }{
+ {
+ name: "slug command",
+ item: commandItem{name: "slug", desc: "Set custom subdomain"},
+ wantName: "slug",
+ wantDesc: "Set custom subdomain",
+ },
+ {
+ name: "tunnel-type command",
+ item: commandItem{name: "tunnel-type", desc: "Change tunnel type"},
+ wantName: "tunnel-type",
+ wantDesc: "Change tunnel type",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ assert.Equal(t, tt.wantName, tt.item.FilterValue())
+ assert.Equal(t, tt.wantName, tt.item.Title())
+ assert.Equal(t, tt.wantDesc, tt.item.Description())
+ })
+ }
+}
+
+func TestModel_GetTunnelURL(t *testing.T) {
+ tests := []struct {
+ name string
+ tunnelType types.TunnelType
+ protocol string
+ slug string
+ domain string
+ port uint16
+ expected string
+ }{
+ {
+ name: "http tunnel",
+ tunnelType: types.TunnelTypeHTTP,
+ protocol: "http",
+ slug: "my-app",
+ domain: "tunnl.live",
+ expected: "http://my-app.tunnl.live",
+ },
+ {
+ name: "https tunnel",
+ tunnelType: types.TunnelTypeHTTP,
+ protocol: "https",
+ slug: "secure-app",
+ domain: "tunnl.live",
+ expected: "https://secure-app.tunnl.live",
+ },
+ {
+ name: "tcp tunnel",
+ tunnelType: types.TunnelTypeTCP,
+ domain: "tunnl.live",
+ port: 8080,
+ expected: "tcp://tunnl.live:8080",
+ },
+ {
+ name: "tcp tunnel with different port",
+ tunnelType: types.TunnelTypeTCP,
+ domain: "tunnl.live",
+ port: 3306,
+ expected: "tcp://tunnl.live:3306",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close)
+
+ mockSlug.On("String").Return(tt.slug).Maybe()
+
+ m := &model{
+ domain: tt.domain,
+ protocol: tt.protocol,
+ tunnelType: tt.tunnelType,
+ port: tt.port,
+ interaction: mockInteraction.(*interaction),
+ }
+
+ result := m.getTunnelURL()
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestModel_Init(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockCloser := &MockCloser{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close)
+
+ m := &model{
+ interaction: mockInteraction.(*interaction),
+ }
+
+ cmd := m.Init()
+ assert.NotNil(t, cmd)
+}
+
+func TestInteraction_Start_Interactive(t *testing.T) {
+ tests := []struct {
+ name string
+ tlsEnabled bool
+ tunnelType types.TunnelType
+ port uint16
+ domain string
+ }{
+ {
+ name: "interactive mode with http",
+ tlsEnabled: false,
+ tunnelType: types.TunnelTypeHTTP,
+ port: 8080,
+ domain: "tunnl.live",
+ },
+ {
+ name: "interactive mode with https",
+ tlsEnabled: true,
+ tunnelType: types.TunnelTypeHTTP,
+ port: 8443,
+ domain: "secure.tunnl.live",
+ },
+ {
+ name: "interactive mode with tcp",
+ tlsEnabled: false,
+ tunnelType: types.TunnelTypeTCP,
+ port: 3306,
+ domain: "db.tunnl.live",
+ },
+ {
+ name: "interactive mode with tcp and tls enabled",
+ tlsEnabled: true,
+ tunnelType: types.TunnelTypeTCP,
+ port: 5432,
+ domain: "postgres.tunnl.live",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeCallCount := 0
+ closeFunc := func() error {
+ closeCallCount++
+ return nil
+ }
+
+ mockConfig.On("Domain").Return(tt.domain)
+ mockConfig.On("TLSEnabled").Return(tt.tlsEnabled)
+ mockForwarder.On("TunnelType").Return(tt.tunnelType)
+ mockForwarder.On("ForwardedPort").Return(tt.port)
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ mockInteraction.SetMode(types.InteractiveModeINTERACTIVE)
+
+ mockChannel := &MockChannel{}
+ mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe()
+ mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe()
+ mockInteraction.SetChannel(mockChannel)
+
+ done := make(chan bool, 1)
+ go func() {
+ mockInteraction.Start()
+ done <- true
+ }()
+
+ time.Sleep(50 * time.Millisecond)
+
+ i := mockInteraction.(*interaction)
+ i.Stop()
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("Start() did not complete in time")
+ }
+
+ assert.Equal(t, 1, closeCallCount, "close function should be called once")
+
+ mockConfig.AssertExpectations(t)
+ mockForwarder.AssertExpectations(t)
+ })
+ }
+}
+
+func TestInteraction_Start_ProtocolSelection(t *testing.T) {
+ tests := []struct {
+ name string
+ tlsEnabled bool
+ expectedProto string
+ }{
+ {
+ name: "http when TLS disabled",
+ tlsEnabled: false,
+ expectedProto: "http",
+ },
+ {
+ name: "https when TLS enabled",
+ tlsEnabled: true,
+ expectedProto: "https",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeFunc := func() error { return nil }
+
+ mockConfig.On("Domain").Return("tunnl.live")
+ mockConfig.On("TLSEnabled").Return(tt.tlsEnabled)
+ mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP)
+ mockForwarder.On("ForwardedPort").Return(uint16(8080))
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ mockInteraction.SetMode(types.InteractiveModeINTERACTIVE)
+
+ mockChannel := &MockChannel{}
+ mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe()
+ mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe()
+ mockInteraction.SetChannel(mockChannel)
+
+ go func() {
+ mockInteraction.Start()
+ }()
+
+ time.Sleep(50 * time.Millisecond)
+
+ i := mockInteraction.(*interaction)
+ if i.program != nil {
+ assert.NotNil(t, i.program, "program should be initialized")
+ }
+
+ i.Stop()
+
+ mockConfig.AssertExpectations(t)
+ mockForwarder.AssertExpectations(t)
+ })
+ }
+}
+
+func TestInteraction_Stop(t *testing.T) {
+ tests := []struct {
+ name string
+ setupProgram bool
+ description string
+ }{
+ {
+ name: "stop with active program",
+ setupProgram: true,
+ description: "should kill program and set to nil",
+ },
+ {
+ name: "stop without program",
+ setupProgram: false,
+ description: "should not panic when program is nil",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeFunc := func() error { return nil }
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ i := mockInteraction.(*interaction)
+
+ if tt.setupProgram {
+ mockConfig.On("Domain").Return("tunnl.live")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP)
+ mockForwarder.On("ForwardedPort").Return(uint16(8080))
+
+ mockInteraction.SetMode(types.InteractiveModeINTERACTIVE)
+ mockChannel := &MockChannel{}
+ mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe()
+ mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe()
+ mockInteraction.SetChannel(mockChannel)
+
+ go func() {
+ mockInteraction.Start()
+ }()
+
+ time.Sleep(50 * time.Millisecond)
+ }
+
+ assert.NotPanics(t, func() {
+ i.Stop()
+ })
+
+ assert.Nil(t, i.program)
+
+ select {
+ case <-i.ctx.Done():
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("context should be cancelled after Stop()")
+ }
+ })
+ }
+}
+
+func TestInteraction_Start_CommandListSetup(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeFunc := func() error { return nil }
+
+ mockConfig.On("Domain").Return("tunnl.live")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP)
+ mockForwarder.On("ForwardedPort").Return(uint16(8080))
+
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ mockInteraction.SetMode(types.InteractiveModeINTERACTIVE)
+
+ mockChannel := &MockChannel{}
+ mockChannel.On("Read", mock.Anything).Return(0, nil)
+ mockChannel.On("Write", mock.Anything).Return(0, nil)
+ mockInteraction.SetChannel(mockChannel)
+
+ go func() {
+ mockInteraction.Start()
+ }()
+
+ time.Sleep(50 * time.Millisecond)
+
+ i := mockInteraction.(*interaction)
+
+ assert.NotNil(t, i.program, "program should be initialized")
+
+ i.Stop()
+}
+
+func TestInteraction_Start_TextInputSetup(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeFunc := func() error { return nil }
+
+ mockConfig.On("Domain").Return("tunnl.live")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP)
+ mockForwarder.On("ForwardedPort").Return(uint16(8080))
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ mockInteraction.SetMode(types.InteractiveModeINTERACTIVE)
+
+ mockChannel := &MockChannel{}
+ mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe()
+ mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe()
+ mockInteraction.SetChannel(mockChannel)
+
+ go func() {
+ mockInteraction.Start()
+ }()
+
+ time.Sleep(50 * time.Millisecond)
+
+ i := mockInteraction.(*interaction)
+ i.Stop()
+
+ mockConfig.AssertExpectations(t)
+ mockForwarder.AssertExpectations(t)
+}
+
+func TestInteraction_Start_CleanupOnExit(t *testing.T) {
+ tests := []struct {
+ name string
+ closeFunc CloseFunc
+ expectCloseCalled bool
+ }{
+ {
+ name: "cleanup calls close function",
+ closeFunc: func() error {
+ return nil
+ },
+ expectCloseCalled: true,
+ },
+ {
+ name: "cleanup with nil close function",
+ closeFunc: nil,
+ expectCloseCalled: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+
+ closeCallCount := 0
+ var closeFunc CloseFunc
+ if tt.closeFunc != nil {
+ closeFunc = func() error {
+ closeCallCount++
+ return tt.closeFunc()
+ }
+ }
+
+ mockConfig.On("Domain").Return("tunnl.live")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP)
+ mockForwarder.On("ForwardedPort").Return(uint16(8080))
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ mockInteraction.SetMode(types.InteractiveModeINTERACTIVE)
+
+ mockChannel := &MockChannel{}
+ mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe()
+ mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe()
+ mockInteraction.SetChannel(mockChannel)
+
+ done := make(chan bool, 1)
+ go func() {
+ mockInteraction.Start()
+ done <- true
+ }()
+
+ time.Sleep(50 * time.Millisecond)
+
+ i := mockInteraction.(*interaction)
+ i.Stop()
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("Start() did not complete")
+ }
+
+ if tt.expectCloseCalled {
+ assert.Equal(t, 1, closeCallCount, "close function should be called")
+ } else {
+ assert.Equal(t, 0, closeCallCount, "close function should not be called when nil")
+ }
+ })
+ }
+}
+
+func TestInteraction_Start_WithDifferentChannels(t *testing.T) {
+ tests := []struct {
+ name string
+ setupChannel bool
+ }{
+ {
+ name: "start with channel set",
+ setupChannel: true,
+ },
+ {
+ name: "start with nil channel",
+ setupChannel: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeFunc := func() error { return nil }
+
+ mockConfig.On("Domain").Return("tunnl.live")
+ mockConfig.On("TLSEnabled").Return(false)
+ mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP)
+ mockForwarder.On("ForwardedPort").Return(uint16(8080))
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ mockInteraction.SetMode(types.InteractiveModeINTERACTIVE)
+
+ if tt.setupChannel {
+ mockChannel := &MockChannel{}
+ mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe()
+ mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe()
+ mockInteraction.SetChannel(mockChannel)
+ }
+
+ go func() {
+ mockInteraction.Start()
+ }()
+
+ time.Sleep(50 * time.Millisecond)
+
+ i := mockInteraction.(*interaction)
+ i.Stop()
+
+ mockConfig.AssertExpectations(t)
+ mockForwarder.AssertExpectations(t)
+ })
+ }
+}
+
+func TestInteraction_Stop_ContextCancellation(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeFunc := func() error { return nil }
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ i := mockInteraction.(*interaction)
+
+ select {
+ case <-i.ctx.Done():
+ t.Fatal("context should not be cancelled initially")
+ default:
+ }
+
+ i.Stop()
+
+ select {
+ case <-i.ctx.Done():
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("context should be cancelled after Stop()")
+ }
+
+ assert.NotPanics(t, func() {
+ i.Stop()
+ })
+}
+
+func TestInteraction_Stop_MultipleCallsSafe(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeFunc := func() error { return nil }
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ i := mockInteraction.(*interaction)
+
+ assert.NotPanics(t, func() {
+ i.Stop()
+ i.Stop()
+ i.Stop()
+ })
+
+ assert.Nil(t, i.program)
+}
+
+func TestInteraction_Start_HeadlessMode_NoOp(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeFunc := func() error { return nil }
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ mockInteraction.SetMode(types.InteractiveModeHEADLESS)
+
+ done := make(chan bool, 1)
+ go func() {
+ mockInteraction.Start()
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("headless mode should return immediately")
+ }
+
+ i := mockInteraction.(*interaction)
+ assert.Nil(t, i.program, "program should not be created in headless mode")
+
+ mockConfig.AssertNotCalled(t, "Domain")
+ mockConfig.AssertNotCalled(t, "TLSEnabled")
+ mockForwarder.AssertNotCalled(t, "TunnelType")
+ mockForwarder.AssertNotCalled(t, "ForwardedPort")
+}
+
+func TestInteraction_New_ContextInitialization(t *testing.T) {
+ mockRandom := &MockRandom{}
+ mockConfig := &MockConfig{}
+ mockSlug := &MockSlug{}
+ mockForwarder := &MockForwarder{}
+ mockSessionRegistry := &MockSessionRegistry{}
+ closeFunc := func() error { return nil }
+ mockSlug.On("String").Return("test-slug")
+
+ mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc)
+ i := mockInteraction.(*interaction)
+
+ assert.NotNil(t, i.ctx, "context should be initialized")
+ assert.NotNil(t, i.cancel, "cancel function should be initialized")
+
+ select {
+ case <-i.ctx.Done():
+ t.Fatal("context should not be cancelled initially")
+ default:
+ }
+}
diff --git a/session/interaction/model.go b/session/interaction/model.go
index 189b0a1..c0d1672 100644
--- a/session/interaction/model.go
+++ b/session/interaction/model.go
@@ -3,6 +3,7 @@ package interaction
import (
"fmt"
"time"
+ "tunnel_pls/internal/random"
"tunnel_pls/types"
"github.com/charmbracelet/bubbles/help"
@@ -22,6 +23,7 @@ func (i commandItem) Title() string { return i.name }
func (i commandItem) Description() string { return i.desc }
type model struct {
+ randomizer random.Random
domain string
protocol string
tunnelType types.TunnelType
@@ -40,6 +42,25 @@ type model struct {
height int
}
+const (
+ ColorPrimary = "#7D56F4"
+ ColorSecondary = "#04B575"
+ ColorGray = "#888888"
+ ColorDarkGray = "#666666"
+ ColorWhite = "#FAFAFA"
+ ColorError = "#FF0000"
+ ColorErrorBg = "#3D0000"
+ ColorWarning = "#FFA500"
+ ColorWarningBg = "#3D2000"
+)
+
+const (
+ BreakpointTiny = 50
+ BreakpointSmall = 60
+ BreakpointMedium = 70
+ BreakpointLarge = 85
+)
+
func (m *model) getTunnelURL() string {
if m.tunnelType == types.TunnelTypeHTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
diff --git a/session/interaction/slug.go b/session/interaction/slug.go
index 2b871d4..ff57cb1 100644
--- a/session/interaction/slug.go
+++ b/session/interaction/slug.go
@@ -3,7 +3,6 @@ package interaction
import (
"fmt"
"strings"
- "tunnel_pls/internal/random"
"tunnel_pls/types"
"github.com/charmbracelet/bubbles/key"
@@ -22,7 +21,7 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
}
switch msg.String() {
- case "esc":
+ case "esc", "ctrl+c":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
@@ -41,19 +40,13 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
- case "ctrl+c":
- m.editingSlug = false
- m.slugError = ""
- return m, tea.Batch(tea.ClearScreen, textinput.Blink)
default:
if key.Matches(msg, m.keymap.random) {
- newSubdomain, err := random.GenerateRandomString(20)
+ newSubdomain, err := m.randomizer.String(20)
if err != nil {
return m, cmd
}
m.slugInput.SetValue(newSubdomain)
- m.slugError = ""
- m.slugInput, cmd = m.slugInput.Update(msg)
}
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
@@ -62,163 +55,211 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
}
func (m *model) slugView() string {
- isCompact := shouldUseCompactLayout(m.width, 70)
- isVeryCompact := shouldUseCompactLayout(m.width, 50)
+ isCompact := shouldUseCompactLayout(m.width, BreakpointMedium)
+ isVeryCompact := shouldUseCompactLayout(m.width, BreakpointTiny)
- var boxPadding int
- var boxMargin int
- if isVeryCompact {
- boxPadding = 1
- boxMargin = 1
- } else if isCompact {
- boxPadding = 1
- boxMargin = 1
- } else {
- boxPadding = 2
- boxMargin = 2
+ var b strings.Builder
+ b.WriteString(m.renderSlugTitle(isVeryCompact))
+
+ if m.tunnelType != types.TunnelTypeHTTP {
+ b.WriteString(m.renderTCPWarning(isVeryCompact, isCompact))
+ return b.String()
}
+ b.WriteString(m.renderSlugRules(isVeryCompact, isCompact))
+ b.WriteString(m.renderSlugInstruction(isVeryCompact))
+ b.WriteString(m.renderSlugInput(isVeryCompact, isCompact))
+ b.WriteString(m.renderSlugPreview(isVeryCompact))
+ b.WriteString(m.renderSlugHelp(isVeryCompact))
+
+ return b.String()
+}
+
+func (m *model) renderSlugTitle(isVeryCompact bool) string {
titleStyle := lipgloss.NewStyle().
Bold(true).
- Foreground(lipgloss.Color("#7D56F4")).
+ Foreground(lipgloss.Color(ColorPrimary)).
PaddingTop(1).
PaddingBottom(1)
- instructionStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#FAFAFA")).
- MarginTop(1)
+ title := "🔧 Edit Subdomain"
+ if isVeryCompact {
+ title = "Edit Subdomain"
+ }
- inputBoxStyle := lipgloss.NewStyle().
+ return titleStyle.Render(title) + "\n\n"
+}
+
+func (m *model) renderTCPWarning(isVeryCompact, isCompact bool) string {
+ boxPadding := getPaddingValue(isVeryCompact, isCompact)
+ boxMargin := getMarginValue(isCompact, 1, 2)
+ warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
+
+ warningBoxStyle := lipgloss.NewStyle().
+ Foreground(lipgloss.Color(ColorWarning)).
+ Background(lipgloss.Color(ColorWarningBg)).
+ Bold(true).
Border(lipgloss.RoundedBorder()).
- BorderForeground(lipgloss.Color("#7D56F4")).
+ BorderForeground(lipgloss.Color(ColorWarning)).
Padding(1, boxPadding).
MarginTop(boxMargin).
- MarginBottom(boxMargin)
+ MarginBottom(boxMargin).
+ Width(warningBoxWidth)
helpStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#666666")).
+ Foreground(lipgloss.Color(ColorDarkGray)).
Italic(true).
MarginTop(1)
- errorBoxStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#FF0000")).
- Background(lipgloss.Color("#3D0000")).
- Bold(true).
- Border(lipgloss.RoundedBorder()).
- BorderForeground(lipgloss.Color("#FF0000")).
- Padding(0, boxPadding).
- MarginTop(1).
- MarginBottom(1)
+ warningText := m.getTCPWarningText(isVeryCompact)
+ helpText := m.getTCPHelpText(isVeryCompact)
+ var b strings.Builder
+ b.WriteString(warningBoxStyle.Render(warningText))
+ b.WriteString("\n\n")
+ b.WriteString(helpStyle.Render(helpText))
+
+ return b.String()
+}
+
+func (m *model) getTCPWarningText(isVeryCompact bool) string {
+ if isVeryCompact {
+ return "⚠️ TCP tunnels don't support custom subdomains."
+ }
+ return "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
+}
+
+func (m *model) getTCPHelpText(isVeryCompact bool) string {
+ if isVeryCompact {
+ return "Press any key to go back"
+ }
+ return "Press Enter or Esc to go back"
+}
+
+func (m *model) renderSlugRules(isVeryCompact, isCompact bool) string {
+ boxPadding := getPaddingValue(isVeryCompact, isCompact)
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
+
rulesBoxStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#FAFAFA")).
+ Foreground(lipgloss.Color(ColorWhite)).
Border(lipgloss.RoundedBorder()).
- BorderForeground(lipgloss.Color("#7D56F4")).
+ BorderForeground(lipgloss.Color(ColorPrimary)).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1).
Width(rulesBoxWidth)
- var b strings.Builder
- var title string
+ rulesContent := m.getRulesContent(isVeryCompact, isCompact)
+ return rulesBoxStyle.Render(rulesContent) + "\n"
+}
+
+func (m *model) getRulesContent(isVeryCompact, isCompact bool) string {
if isVeryCompact {
- title = "Edit Subdomain"
- } else {
- title = "🔧 Edit Subdomain"
- }
- b.WriteString(titleStyle.Render(title))
- b.WriteString("\n\n")
-
- if m.tunnelType != types.TunnelTypeHTTP {
- warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
- warningBoxStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#FFA500")).
- Background(lipgloss.Color("#3D2000")).
- Bold(true).
- Border(lipgloss.RoundedBorder()).
- BorderForeground(lipgloss.Color("#FFA500")).
- Padding(1, boxPadding).
- MarginTop(boxMargin).
- MarginBottom(boxMargin).
- Width(warningBoxWidth)
-
- var warningText string
- if isVeryCompact {
- warningText = "⚠️ TCP tunnels don't support custom subdomains."
- } else {
- warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
- }
- b.WriteString(warningBoxStyle.Render(warningText))
- b.WriteString("\n\n")
-
- var helpText string
- if isVeryCompact {
- helpText = "Press any key to go back"
- } else {
- helpText = "Press Enter or Esc to go back"
- }
- b.WriteString(helpStyle.Render(helpText))
- return b.String()
+ return "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
}
- var rulesContent string
- if isVeryCompact {
- rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
- } else if isCompact {
- rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
- } else {
- rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
+ if isCompact {
+ return "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
}
- b.WriteString(rulesBoxStyle.Render(rulesContent))
- b.WriteString("\n")
- var instruction string
+ return "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
+}
+
+func (m *model) renderSlugInstruction(isVeryCompact bool) string {
+ instructionStyle := lipgloss.NewStyle().
+ Foreground(lipgloss.Color(ColorWhite)).
+ MarginTop(1)
+
+ instruction := "Enter your custom subdomain:"
if isVeryCompact {
instruction = "Custom subdomain:"
- } else {
- instruction = "Enter your custom subdomain:"
}
- b.WriteString(instructionStyle.Render(instruction))
- b.WriteString("\n")
+
+ return instructionStyle.Render(instruction) + "\n"
+}
+
+func (m *model) renderSlugInput(isVeryCompact, isCompact bool) string {
+ boxPadding := getPaddingValue(isVeryCompact, isCompact)
+ boxMargin := getMarginValue(isCompact, 1, 2)
if m.slugError != "" {
- errorInputBoxStyle := lipgloss.NewStyle().
- Border(lipgloss.RoundedBorder()).
- BorderForeground(lipgloss.Color("#FF0000")).
- Padding(1, boxPadding).
- MarginTop(boxMargin).
- MarginBottom(1)
- b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
- b.WriteString("\n")
- b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
- b.WriteString("\n")
- } else {
- b.WriteString(inputBoxStyle.Render(m.slugInput.View()))
- b.WriteString("\n")
+ return m.renderErrorInput(boxPadding, boxMargin)
}
+ return m.renderNormalInput(boxPadding, boxMargin)
+}
+
+func (m *model) renderErrorInput(boxPadding, boxMargin int) string {
+ errorInputBoxStyle := lipgloss.NewStyle().
+ Border(lipgloss.RoundedBorder()).
+ BorderForeground(lipgloss.Color(ColorError)).
+ Padding(1, boxPadding).
+ MarginTop(boxMargin).
+ MarginBottom(1)
+
+ errorBoxStyle := lipgloss.NewStyle().
+ Foreground(lipgloss.Color(ColorError)).
+ Background(lipgloss.Color(ColorErrorBg)).
+ Bold(true).
+ Border(lipgloss.RoundedBorder()).
+ BorderForeground(lipgloss.Color(ColorError)).
+ Padding(0, boxPadding).
+ MarginTop(1).
+ MarginBottom(1)
+
+ var b strings.Builder
+ b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
+ b.WriteString("\n")
+ b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
+ b.WriteString("\n")
+
+ return b.String()
+}
+
+func (m *model) renderNormalInput(boxPadding, boxMargin int) string {
+ inputBoxStyle := lipgloss.NewStyle().
+ Border(lipgloss.RoundedBorder()).
+ BorderForeground(lipgloss.Color(ColorPrimary)).
+ Padding(1, boxPadding).
+ MarginTop(boxMargin).
+ MarginBottom(boxMargin)
+
+ return inputBoxStyle.Render(m.slugInput.View()) + "\n"
+}
+
+func (m *model) renderSlugPreview(isVeryCompact bool) string {
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
- if len(previewURL) > previewWidth-10 {
+ if isVeryCompact {
previewURL = truncateString(previewURL, previewWidth-10)
}
previewStyle := lipgloss.NewStyle().
- Foreground(lipgloss.Color("#04B575")).
+ Foreground(lipgloss.Color(ColorSecondary)).
Italic(true).
Width(previewWidth)
- b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)))
- b.WriteString("\n")
- var helpText string
+ return previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)) + "\n"
+}
+
+func (m *model) renderSlugHelp(isVeryCompact bool) string {
+ helpStyle := lipgloss.NewStyle().
+ Foreground(lipgloss.Color(ColorDarkGray)).
+ Italic(true).
+ MarginTop(1)
+
+ helpText := "Press Enter to save • CTRL+R for random • Esc to cancel"
if isVeryCompact {
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
- } else {
- helpText = "Press Enter to save • CTRL+R for random • Esc to cancel"
}
- b.WriteString(helpStyle.Render(helpText))
- return b.String()
+ return helpStyle.Render(helpText)
+}
+
+func getPaddingValue(isVeryCompact, isCompact bool) int {
+ if isVeryCompact || isCompact {
+ return 1
+ }
+ return 2
}
diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go
index f9f9d6e..2d9dd3e 100644
--- a/session/lifecycle/lifecycle.go
+++ b/session/lifecycle/lifecycle.go
@@ -2,6 +2,9 @@ package lifecycle
import (
"errors"
+ "io"
+ "net"
+ "sync"
"time"
portUtil "tunnel_pls/internal/port"
@@ -22,7 +25,9 @@ type SessionRegistry interface {
}
type lifecycle struct {
+ mu sync.Mutex
status types.SessionStatus
+ closeErr error
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
@@ -49,6 +54,7 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti
type Lifecycle interface {
Connection() ssh.Conn
+ Channel() ssh.Channel
PortRegistry() portUtil.Port
User() string
SetChannel(channel ssh.Channel)
@@ -69,33 +75,48 @@ func (l *lifecycle) User() string {
func (l *lifecycle) SetChannel(channel ssh.Channel) {
l.channel = channel
}
+
+func (l *lifecycle) Channel() ssh.Channel {
+ return l.channel
+}
+
func (l *lifecycle) Connection() ssh.Conn {
return l.conn
}
+
func (l *lifecycle) SetStatus(status types.SessionStatus) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
l.status = status
- if status == types.SessionStatusRUNNING && l.startedAt.IsZero() {
- l.startedAt = time.Now()
- }
}
-func closeIfNotNil(c interface{ Close() error }) error {
- if c != nil {
- return c.Close()
- }
- return nil
+func (l *lifecycle) IsActive() bool {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ return l.status == types.SessionStatusRUNNING
}
func (l *lifecycle) Close() error {
+ l.mu.Lock()
+ defer l.mu.Unlock()
+ if l.status == types.SessionStatusCLOSED {
+ return l.closeErr
+ }
+ l.status = types.SessionStatusCLOSED
+
var errs []error
tunnelType := l.forwarder.TunnelType()
- if err := closeIfNotNil(l.channel); err != nil {
- errs = append(errs, err)
+ if l.channel != nil {
+ if err := l.channel.Close(); err != nil && !isClosedError(err) {
+ errs = append(errs, err)
+ }
}
- if err := closeIfNotNil(l.conn); err != nil {
- errs = append(errs, err)
+ if l.conn != nil {
+ if err := l.conn.Close(); err != nil && !isClosedError(err) {
+ errs = append(errs, err)
+ }
}
clientSlug := l.slug.String()
@@ -106,19 +127,19 @@ func (l *lifecycle) Close() error {
l.sessionRegistry.Remove(key)
if tunnelType == types.TunnelTypeTCP {
- if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil {
- errs = append(errs, err)
- }
- if err := l.forwarder.Close(); err != nil {
- errs = append(errs, err)
- }
+ errs = append(errs, l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false))
+ errs = append(errs, l.forwarder.Close())
}
- return errors.Join(errs...)
+ l.closeErr = errors.Join(errs...)
+ return l.closeErr
}
-func (l *lifecycle) IsActive() bool {
- return l.status == types.SessionStatusRUNNING
+func isClosedError(err error) bool {
+ if err == nil {
+ return false
+ }
+ return errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || err.Error() == "EOF"
}
func (l *lifecycle) StartedAt() time.Time {
diff --git a/session/lifecycle/lifecycle_test.go b/session/lifecycle/lifecycle_test.go
new file mode 100644
index 0000000..608b3a8
--- /dev/null
+++ b/session/lifecycle/lifecycle_test.go
@@ -0,0 +1,303 @@
+package lifecycle
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "testing"
+ "tunnel_pls/types"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "golang.org/x/crypto/ssh"
+)
+
+type MockSessionRegistry struct {
+ mock.Mock
+}
+
+func (m *MockSessionRegistry) Remove(key types.SessionKey) {
+ m.Called(key)
+}
+
+type MockForwarder struct {
+ mock.Mock
+}
+
+func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
+ args := m.Called(origin)
+ return args.Get(0).([]byte)
+}
+
+func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
+ m.Called(dst, src)
+}
+
+func (m *MockForwarder) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+func (m *MockForwarder) TunnelType() types.TunnelType {
+ args := m.Called()
+ return args.Get(0).(types.TunnelType)
+}
+
+func (m *MockForwarder) ForwardedPort() uint16 {
+ args := m.Called()
+ return args.Get(0).(uint16)
+}
+
+func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
+ m.Called(tunnelType)
+}
+
+func (m *MockForwarder) SetForwardedPort(port uint16) {
+ m.Called(port)
+}
+
+func (m *MockForwarder) SetListener(listener net.Listener) {
+ m.Called(listener)
+}
+
+func (m *MockForwarder) Listener() net.Listener {
+ args := m.Called()
+ return args.Get(0).(net.Listener)
+}
+
+func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
+ args := m.Called(ctx, origin)
+ if args.Get(0) == nil {
+ return nil, nil, args.Error(2)
+ }
+ return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
+}
+
+type MockPort struct {
+ mock.Mock
+}
+
+func (m *MockPort) AddRange(startPort, endPort uint16) error {
+ return m.Called(startPort, endPort).Error(0)
+}
+func (m *MockPort) Unassigned() (uint16, bool) {
+ args := m.Called()
+ var port uint16
+ if args.Get(0) != nil {
+ switch v := args.Get(0).(type) {
+ case int:
+ port = uint16(v)
+ case uint16:
+ port = v
+ case uint32:
+ port = uint16(v)
+ case int32:
+ port = uint16(v)
+ case float64:
+ port = uint16(v)
+ default:
+ port = uint16(args.Int(0))
+ }
+ }
+ return port, args.Bool(1)
+}
+func (m *MockPort) SetStatus(port uint16, assigned bool) error {
+ return m.Called(port, assigned).Error(0)
+}
+func (m *MockPort) Claim(port uint16) bool {
+ return m.Called(port).Bool(0)
+}
+
+type MockSlug struct {
+ mock.Mock
+}
+
+func (ms *MockSlug) Set(slug string) {
+ ms.Called(slug)
+}
+func (ms *MockSlug) String() string {
+ return ms.Called().String(0)
+}
+
+type MockSSHConn struct {
+ ssh.Conn
+ mock.Mock
+}
+
+func (m *MockSSHConn) Close() error {
+ args := m.Called()
+ return args.Error(0)
+}
+
+type MockSSHChannel struct {
+ ssh.Channel
+ mock.Mock
+}
+
+func (m *MockSSHChannel) Close() error {
+ return m.Called().Error(0)
+}
+
+func TestNew(t *testing.T) {
+ mockSSHConn := new(MockSSHConn)
+ mockForwarder := &MockForwarder{}
+ mockSlug := &MockSlug{}
+ mockPort := &MockPort{}
+ mockSessionRegistry := &MockSessionRegistry{}
+
+ mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
+
+ assert.NotNil(t, mockLifecycle.Connection())
+ assert.NotNil(t, mockLifecycle.User())
+ assert.NotNil(t, mockLifecycle.PortRegistry())
+ assert.NotNil(t, mockLifecycle.StartedAt())
+}
+
+func TestLifecycle_User(t *testing.T) {
+ mockSSHConn := new(MockSSHConn)
+ mockForwarder := &MockForwarder{}
+ mockSlug := &MockSlug{}
+ mockPort := &MockPort{}
+ mockSessionRegistry := &MockSessionRegistry{}
+
+ user := "mas-fuad"
+ mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, user)
+ assert.Equal(t, user, mockLifecycle.User())
+}
+
+func TestLifecycle_SetChannel(t *testing.T) {
+ mockSSHConn := new(MockSSHConn)
+ mockForwarder := &MockForwarder{}
+ mockSlug := &MockSlug{}
+ mockPort := &MockPort{}
+ mockSessionRegistry := &MockSessionRegistry{}
+
+ mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
+
+ mockSSHChannel := &MockSSHChannel{}
+
+ mockLifecycle.SetChannel(mockSSHChannel)
+
+ assert.Equal(t, mockSSHChannel, mockLifecycle.Channel())
+}
+
+func TestLifecycle_SetStatus(t *testing.T) {
+ mockSSHConn := new(MockSSHConn)
+ mockForwarder := &MockForwarder{}
+ mockSlug := &MockSlug{}
+ mockPort := &MockPort{}
+ mockSessionRegistry := &MockSessionRegistry{}
+
+ mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
+
+ mockLifecycle.SetStatus(types.SessionStatusRUNNING)
+ assert.True(t, mockLifecycle.IsActive())
+}
+
+func TestLifecycle_IsActive(t *testing.T) {
+ mockSSHConn := new(MockSSHConn)
+ mockForwarder := &MockForwarder{}
+ mockSlug := &MockSlug{}
+ mockPort := &MockPort{}
+ mockSessionRegistry := &MockSessionRegistry{}
+
+ mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
+
+ mockLifecycle.SetStatus(types.SessionStatusRUNNING)
+ assert.True(t, mockLifecycle.IsActive())
+}
+
+func TestLifecycle_Close(t *testing.T) {
+ tests := []struct {
+ name string
+ tunnelType types.TunnelType
+ connCloseErr error
+ channelCloseErr error
+ expectErr bool
+ alreadyClosed bool
+ }{
+ {
+ name: "Close HTTP forwarding success",
+ tunnelType: types.TunnelTypeHTTP,
+ expectErr: false,
+ },
+ {
+ name: "Close TCP forwarding success",
+ tunnelType: types.TunnelTypeTCP,
+ expectErr: false,
+ },
+ {
+ name: "Close with conn close error",
+ tunnelType: types.TunnelTypeHTTP,
+ connCloseErr: errors.New("conn close error"),
+ expectErr: true,
+ },
+ {
+ name: "Close with channel close error",
+ tunnelType: types.TunnelTypeHTTP,
+ channelCloseErr: errors.New("channel close error"),
+ expectErr: true,
+ },
+ {
+ name: "Close when already closed",
+ tunnelType: types.TunnelTypeHTTP,
+ alreadyClosed: true,
+ expectErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ mockSSHConn := &MockSSHConn{}
+ mockSSHConn.On("Close").Return(tt.connCloseErr)
+
+ mockForwarder := &MockForwarder{}
+ mockForwarder.On("TunnelType").Return(tt.tunnelType)
+ if tt.tunnelType == types.TunnelTypeTCP {
+ mockForwarder.On("ForwardedPort").Return(uint16(8080))
+ mockForwarder.On("Close").Return(nil)
+ }
+
+ mockSlug := &MockSlug{}
+ mockSlug.On("String").Return("test-slug")
+
+ mockPort := &MockPort{}
+ if tt.tunnelType == types.TunnelTypeTCP {
+ mockPort.On("SetStatus", uint16(8080), false).Return(nil)
+ }
+
+ mockSessionRegistry := &MockSessionRegistry{}
+ mockSessionRegistry.On("Remove", mock.Anything).Return()
+
+ mockSSHChannel := &MockSSHChannel{}
+ mockSSHChannel.On("Close").Return(tt.channelCloseErr)
+
+ mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
+
+ mockLifecycle.SetStatus(types.SessionStatusRUNNING)
+ mockLifecycle.SetChannel(mockSSHChannel)
+
+ if tt.alreadyClosed {
+ err := mockLifecycle.Close()
+ assert.NoError(t, err)
+ }
+
+ err := mockLifecycle.Close()
+
+ if tt.expectErr {
+ assert.Error(t, err)
+ } else {
+ assert.NoError(t, err)
+ }
+ assert.False(t, mockLifecycle.IsActive())
+
+ mockSSHConn.AssertExpectations(t)
+ mockForwarder.AssertExpectations(t)
+ mockSlug.AssertExpectations(t)
+ mockPort.AssertExpectations(t)
+ mockSessionRegistry.AssertExpectations(t)
+ mockSSHChannel.AssertExpectations(t)
+ })
+ }
+}
diff --git a/session/session.go b/session/session.go
index b1895ab..cc27c4c 100644
--- a/session/session.go
+++ b/session/session.go
@@ -1,7 +1,6 @@
package session
import (
- "bytes"
"encoding/binary"
"errors"
"fmt"
@@ -37,6 +36,7 @@ type Session interface {
}
type session struct {
+ randomizer random.Random
config config.Config
initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel
@@ -47,23 +47,35 @@ type session struct {
registry registry.Registry
}
+type Config struct {
+ Randomizer random.Random
+ Config config.Config
+ Conn *ssh.ServerConn
+ InitialReq <-chan *ssh.Request
+ SshChan <-chan ssh.NewChannel
+ SessionRegistry registry.Registry
+ PortRegistry portUtil.Port
+ User string
+}
+
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
-func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
+func New(conf *Config) Session {
slugManager := slug.New()
- forwarderManager := forwarder.New(config, slugManager, conn)
- lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
- interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
+ forwarderManager := forwarder.New(conf.Config, slugManager, conf.Conn)
+ lifecycleManager := lifecycle.New(conf.Conn, forwarderManager, slugManager, conf.PortRegistry, conf.SessionRegistry, conf.User)
+ interactionManager := interaction.New(conf.Randomizer, conf.Config, slugManager, forwarderManager, conf.SessionRegistry, conf.User, lifecycleManager.Close)
return &session{
- config: config,
- initialReq: initialReq,
- sshChan: sshChan,
+ randomizer: conf.Randomizer,
+ config: conf.Config,
+ initialReq: conf.InitialReq,
+ sshChan: conf.SshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slug: slugManager,
- registry: sessionRegistry,
+ registry: conf.SessionRegistry,
}
}
@@ -85,12 +97,12 @@ func (s *session) Slug() slug.Slug {
func (s *session) Detail() *types.Detail {
tunnelTypeMap := map[types.TunnelType]string{
- types.TunnelTypeHTTP: "TunnelTypeHTTP",
- types.TunnelTypeTCP: "TunnelTypeTCP",
+ types.TunnelTypeHTTP: "HTTP",
+ types.TunnelTypeTCP: "TCP",
}
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
if !ok {
- tunnelType = "TunnelTypeUNKNOWN"
+ tunnelType = "UNKNOWN"
}
return &types.Detail{
@@ -113,7 +125,7 @@ func (s *session) Start() error {
}
if s.shouldRejectUnauthorized() {
- return s.denyForwardingRequest(tcpipReq, nil, nil, fmt.Sprintf("headless forwarding only allowed on node mode"))
+ return s.denyForwardingRequest(tcpipReq, nil, nil, "headless forwarding only allowed on node mode")
}
if err := s.HandleTCPIPForward(tcpipReq); err != nil {
@@ -160,13 +172,11 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
}
func (s *session) handleMissingForwardRequest() error {
- err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
+ err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
if err != nil {
return err
}
- if err = s.lifecycle.Close(); err != nil {
- log.Printf("failed to close session: %v", err)
- }
+
return fmt.Errorf("no forwarding Request")
}
@@ -182,7 +192,6 @@ func (s *session) waitForSessionEnd() error {
}
if err := s.lifecycle.Close(); err != nil {
- log.Printf("failed to close session: %v", err)
return err
}
return nil
@@ -227,8 +236,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
for req := range GlobalRequest {
switch req.Type {
case "shell", "pty-req":
- err := req.Reply(true, nil)
- if err != nil {
+ if err := req.Reply(true, nil); err != nil {
return err
}
case "window-change":
@@ -237,8 +245,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
}
default:
log.Println("Unknown request type:", req.Type)
- err := req.Reply(false, nil)
- if err != nil {
+ if err := req.Reply(false, nil); err != nil {
return err
}
}
@@ -246,24 +253,24 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
return nil
}
-func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) {
- address, err = readSSHString(payloadReader)
- if err != nil {
- return "", 0, err
+func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) {
+ var forwardPayload struct {
+ BindAddr string
+ BindPort uint32
}
- var rawPortToBind uint32
- if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil {
- return "", 0, err
+ if err = ssh.Unmarshal(payload, &forwardPayload); err != nil {
+ return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err)
}
- if rawPortToBind > 65535 {
+ if forwardPayload.BindPort > 65535 {
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
}
- port = uint16(rawPortToBind)
+ port = uint16(forwardPayload.BindPort)
+
if isBlockedPort(port) {
- return "", 0, fmt.Errorf("port is block")
+ return "", 0, fmt.Errorf("port is blocked")
}
if port == 0 {
@@ -271,10 +278,10 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string,
if !ok {
return "", 0, fmt.Errorf("no available port")
}
- return address, unassigned, err
+ return forwardPayload.BindAddr, unassigned, nil
}
- return address, port, err
+ return forwardPayload.BindAddr, port, nil
}
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
@@ -282,37 +289,25 @@ func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey,
if key != nil {
s.registry.Remove(*key)
}
+
if listener != nil {
- if err := listener.Close(); err != nil {
- errs = append(errs, fmt.Errorf("close listener: %w", err))
- }
- }
- if err := req.Reply(false, nil); err != nil {
- errs = append(errs, fmt.Errorf("reply request: %w", err))
- }
- if err := s.lifecycle.Close(); err != nil {
- errs = append(errs, fmt.Errorf("close session: %w", err))
+ errs = append(errs, listener.Close())
}
+
+ errs = append(errs, req.Reply(false, nil))
+ errs = append(errs, s.lifecycle.Close())
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
return errors.Join(errs...)
}
-func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) {
- buf := new(bytes.Buffer)
- err = binary.Write(buf, binary.BigEndian, uint32(port))
- if err != nil {
- return err
- }
-
- err = req.Reply(true, buf.Bytes())
- if err != nil {
- return err
- }
- return nil
-}
-
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
- err := s.approveForwardingRequest(req, portToBind)
+ replyPayload := struct {
+ BoundPort uint32
+ }{
+ BoundPort: uint32(portToBind),
+ }
+ err := req.Reply(true, ssh.Marshal(replyPayload))
+
if err != nil {
return err
}
@@ -330,9 +325,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
}
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
- reader := bytes.NewReader(req.Payload)
-
- address, port, err := s.parseForwardPayload(reader)
+ address, port, err := s.parseForwardPayload(req.Payload)
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
}
@@ -346,7 +339,7 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) error {
}
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
- randomString, err := random.GenerateRandomString(20)
+ randomString, err := s.randomizer.String(20)
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
}
@@ -364,13 +357,13 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
- return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
+ return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
}
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
listener, err := tcpServer.Listen()
if err != nil {
- return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
+ return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
@@ -393,18 +386,6 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
return nil
}
-func readSSHString(reader io.Reader) (string, error) {
- var length uint32
- if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
- return "", err
- }
- strBytes := make([]byte, length)
- if _, err := reader.Read(strBytes); err != nil {
- return "", err
- }
- return string(strBytes), nil
-}
-
func isBlockedPort(port uint16) bool {
if port == 80 || port == 443 {
return false
diff --git a/session/session_test.go b/session/session_test.go
new file mode 100644
index 0000000..8a89125
--- /dev/null
+++ b/session/session_test.go
@@ -0,0 +1,1360 @@
+package session
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/binary"
+ "encoding/pem"
+ "fmt"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+ "tunnel_pls/internal/config"
+ "tunnel_pls/internal/registry"
+ "tunnel_pls/session/lifecycle"
+ "tunnel_pls/types"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/crypto/ssh"
+)
+
+type mockRandom struct {
+ mock.Mock
+}
+
+func (m *mockRandom) String(length int) (string, error) {
+ args := m.Called(length)
+ return args.String(0), args.Error(1)
+}
+
+type mockConfig struct {
+ mock.Mock
+ config.Config
+}
+
+func (m *mockConfig) Domain() string { return m.Called().String(0) }
+func (m *mockConfig) SSHPort() string { return m.Called().String(0) }
+func (m *mockConfig) Mode() types.ServerMode {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return 0
+ }
+ switch v := args.Get(0).(type) {
+ case types.ServerMode:
+ return v
+ case int:
+ return types.ServerMode(v)
+ default:
+ return types.ServerMode(args.Int(0))
+ }
+}
+func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
+
+type mockRegistry struct {
+ mock.Mock
+ registry.Registry
+ removedKey types.SessionKey
+}
+
+func (m *mockRegistry) Register(key types.SessionKey, session registry.Session) bool {
+ return m.Called(key, session).Bool(0)
+}
+
+func (m *mockRegistry) Remove(key types.SessionKey) {
+ m.removedKey = key
+}
+
+type mockPort struct {
+ mock.Mock
+}
+
+func (m *mockPort) AddRange(startPort, endPort uint16) error {
+ return m.Called(startPort, endPort).Error(0)
+}
+func (m *mockPort) Unassigned() (uint16, bool) {
+ args := m.Called()
+ var port uint16
+ if args.Get(0) != nil {
+ switch v := args.Get(0).(type) {
+ case int:
+ port = uint16(v)
+ case uint16:
+ port = v
+ case uint32:
+ port = uint16(v)
+ case int32:
+ port = uint16(v)
+ case float64:
+ port = uint16(v)
+ default:
+ port = uint16(args.Int(0))
+ }
+ }
+ return port, args.Bool(1)
+}
+func (m *mockPort) SetStatus(port uint16, assigned bool) error {
+ return m.Called(port, assigned).Error(0)
+}
+func (m *mockPort) Claim(port uint16) bool {
+ return m.Called(port).Bool(0)
+}
+
+type mockSSHConn struct {
+ ssh.Conn
+ mock.Mock
+}
+
+func (m *mockSSHConn) Wait() error {
+ return m.Called().Error(0)
+}
+
+func (m *mockSSHConn) Close() error {
+ return m.Called().Error(0)
+}
+
+func (m *mockSSHConn) User() string {
+ return m.Called().String(0)
+}
+
+func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, sChans <-chan ssh.NewChannel, cConn ssh.Conn, cleanup func()) {
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ require.NoError(t, err)
+
+ privDER := x509.MarshalPKCS1PrivateKey(key)
+ privBlock := pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Bytes: privDER,
+ }
+ pk, err := ssh.ParsePrivateKey(pem.EncodeToMemory(&privBlock))
+ require.NoError(t, err)
+
+ sCfg := &ssh.ServerConfig{
+ NoClientAuth: true,
+ }
+ sCfg.AddHostKey(pk)
+
+ cCfg := &ssh.ClientConfig{
+ User: "test",
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ Timeout: 5 * time.Second,
+ }
+
+ var sConnObj *ssh.ServerConn
+ var sChansChan <-chan ssh.NewChannel
+ var sReqsChan <-chan *ssh.Request
+
+ errChan := make(chan error, 1)
+ go func() {
+ conn, err := l.Accept()
+ if err != nil {
+ errChan <- err
+ return
+ }
+ sConnObj, sChansChan, sReqsChan, err = ssh.NewServerConn(conn, sCfg)
+ errChan <- err
+ }()
+
+ conn, err := net.Dial("tcp", l.Addr().String())
+ require.NoError(t, err)
+ cConnObj, cChans, cReqs, err := ssh.NewClientConn(conn, "pipe", cCfg)
+ require.NoError(t, err)
+
+ go ssh.DiscardRequests(cReqs)
+ go func() {
+ for newChan := range cChans {
+ if newChan.ChannelType() == "session" {
+ continue
+ }
+ err = newChan.Reject(ssh.Prohibited, "")
+ assert.NoError(t, err)
+ }
+ }()
+
+ select {
+ case err := <-errChan:
+ require.NoError(t, err)
+ case <-time.After(5 * time.Second):
+ t.Fatal("SSH handshake timed out")
+ }
+
+ return sConnObj, sReqsChan, sChansChan, cConnObj, func() {
+ _ = cConnObj.Close()
+ _ = sConnObj.Close()
+ _ = l.Close()
+ }
+}
+
+func TestNew(t *testing.T) {
+ conf := &Config{
+ Randomizer: &mockRandom{},
+ Config: &mockConfig{},
+ Conn: &ssh.ServerConn{},
+ InitialReq: make(chan *ssh.Request),
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: &mockRegistry{},
+ PortRegistry: &mockPort{},
+ User: "testuser",
+ }
+
+ s := New(conf)
+ assert.NotNil(t, s)
+ assert.NotNil(t, s.Lifecycle())
+ assert.NotNil(t, s.Interaction())
+ assert.NotNil(t, s.Forwarder())
+ assert.NotNil(t, s.Slug())
+}
+
+func TestDetail(t *testing.T) {
+ conf := &Config{
+ Randomizer: &mockRandom{},
+ Config: &mockConfig{},
+ Conn: &ssh.ServerConn{},
+ InitialReq: make(chan *ssh.Request),
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: &mockRegistry{},
+ PortRegistry: &mockPort{},
+ User: "testuser",
+ }
+
+ s := New(conf).(*session)
+ s.forwarder.SetType(types.TunnelTypeHTTP)
+ s.slug.Set("test-slug")
+ s.lifecycle.SetStatus(types.SessionStatusRUNNING)
+
+ detail := s.Detail()
+ assert.Equal(t, "HTTP", detail.ForwardingType)
+ assert.Equal(t, "test-slug", detail.Slug)
+ assert.Equal(t, "testuser", detail.UserID)
+ assert.True(t, detail.Active)
+
+ s.forwarder.SetType(types.TunnelTypeTCP)
+ detail = s.Detail()
+ assert.Equal(t, "TCP", detail.ForwardingType)
+
+ s.forwarder.SetType(types.TunnelTypeUNKNOWN)
+ detail = s.Detail()
+ assert.Equal(t, "UNKNOWN", detail.ForwardingType)
+}
+
+func TestIsBlockedPort(t *testing.T) {
+ tests := []struct {
+ port uint16
+ expected bool
+ }{
+ {80, false},
+ {443, false},
+ {22, true},
+ {1023, true},
+ {1024, false},
+ {1080, true},
+ {3306, true},
+ {8080, true},
+ {0, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(fmt.Sprintf("Port %d", tt.port), func(t *testing.T) {
+ assert.Equal(t, tt.expected, isBlockedPort(tt.port))
+ })
+ }
+}
+
+func TestHandleGlobalRequest(t *testing.T) {
+ _, sReqs, _, cConn, cleanup := setupSSH(t)
+ defer cleanup()
+
+ conf := &Config{
+ Randomizer: &mockRandom{},
+ Config: &mockConfig{},
+ Conn: &ssh.ServerConn{},
+ InitialReq: make(chan *ssh.Request),
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: &mockRegistry{},
+ PortRegistry: &mockPort{},
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+
+ done := make(chan struct{})
+ go func() {
+ _ = s.HandleGlobalRequest(sReqs)
+ close(done)
+ }()
+
+ tests := []struct {
+ name string
+ reqType string
+ payload []byte
+ wantReply bool
+ expected bool
+ }{
+ {"shell", "shell", nil, true, true},
+ {"pty-req", "pty-req", nil, true, true},
+ {"window-change valid", "window-change", make([]byte, 16), true, true},
+ {"window-change invalid", "window-change", make([]byte, 4), true, false},
+ {"unknown", "unknown", nil, true, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ok, _, err := cConn.SendRequest(tt.reqType, tt.wantReply, tt.payload)
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, ok)
+ })
+ }
+
+ err := cConn.Close()
+ assert.NoError(t, err)
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("HandleGlobalRequest timed out after cConn.Close()")
+ }
+}
+
+func TestHandleTCPIPForward_Table(t *testing.T) {
+ setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) {
+ sConn, sReqs, _, cConn, cleanup := setupSSH(t)
+ mRegistry := &mockRegistry{}
+ mPort := &mockPort{}
+ mRandom := &mockRandom{}
+ conf := &Config{
+ Randomizer: mRandom,
+ Config: &mockConfig{},
+ Conn: sConn,
+ InitialReq: make(chan *ssh.Request),
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: mRegistry,
+ PortRegistry: mPort,
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+ return s, mRegistry, mPort, mRandom, sConn, sReqs, cConn, cleanup
+ }
+
+ t.Run("HTTP Forward Success", func(t *testing.T) {
+ s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mRandom.On("String", 20).Return("test-slug-1234567890", nil)
+ mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 80)
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ var req *ssh.Request
+ select {
+ case req = <-sReqs:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timed out waiting for tcpip-forward request")
+ }
+
+ err := s.HandleTCPIPForward(req)
+ assert.NoError(t, err)
+ assert.Equal(t, "test-slug-1234567890", s.slug.String())
+ })
+
+ t.Run("TCP Forward Success", func(t *testing.T) {
+ s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mPort.On("Claim", mock.Anything).Return(true)
+ mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 0)
+
+ mPort.On("Unassigned").Return(uint16(12345), true)
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ var req *ssh.Request
+ select {
+ case req = <-sReqs:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timed out waiting for tcpip-forward request")
+ }
+
+ err := s.HandleTCPIPForward(req)
+ assert.NoError(t, err)
+ assert.Equal(t, uint16(12345), s.forwarder.ForwardedPort())
+ })
+
+ t.Run("Invalid Payload", func(t *testing.T) {
+ s, _, _, _, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ payload := []byte{0, 0, 0}
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ var req *ssh.Request
+ select {
+ case req = <-sReqs:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timed out waiting for tcpip-forward request")
+ }
+
+ err := s.HandleTCPIPForward(req)
+ assert.Error(t, err)
+ })
+
+ t.Run("Blocked Port", func(t *testing.T) {
+ s, _, _, _, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 22)
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ var req *ssh.Request
+ select {
+ case req = <-sReqs:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timed out waiting for tcpip-forward request")
+ }
+
+ err := s.HandleTCPIPForward(req)
+ assert.Error(t, err)
+ })
+}
+
+func TestStart_Table(t *testing.T) {
+ setup := func(t *testing.T) (*session, *Config, ssh.Conn, func()) {
+ sConn, sReqs, sChans, cConn, cleanup := setupSSH(t)
+ mRegistry := &mockRegistry{}
+ mPort := &mockPort{}
+ mRandom := &mockRandom{}
+ mConfig := &mockConfig{}
+ mConfig.On("Mode").Return(types.ServerModeSTANDALONE)
+ mConfig.On("Domain").Return("example.com")
+ mConfig.On("SSHPort").Return("2222")
+
+ conf := &Config{
+ Randomizer: mRandom,
+ Config: mConfig,
+ Conn: sConn,
+ InitialReq: sReqs,
+ SshChan: sChans,
+ SessionRegistry: mRegistry,
+ PortRegistry: mPort,
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+ return s, conf, cConn, cleanup
+ }
+
+ t.Run("Full Success TCP", func(t *testing.T) {
+ s, conf, cConn, cleanup := setup(t)
+ defer cleanup()
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 0)
+
+ conf.PortRegistry.(*mockPort).On("Claim", mock.Anything).Return(true)
+ conf.PortRegistry.(*mockPort).On("Unassigned").Return(uint16(0), true)
+ conf.PortRegistry.(*mockPort).On("SetStatus", mock.AnythingOfType("uint16"), mock.Anything).Return(nil)
+ conf.SessionRegistry.(*mockRegistry).On("Register", mock.Anything, mock.Anything).Return(true)
+ conf.Config.(*mockConfig).On("TLSEnabled").Return(false)
+ go func() {
+ time.Sleep(200 * time.Millisecond)
+ ch, reqs, err := cConn.OpenChannel("session", nil)
+ if err == nil {
+ go ssh.DiscardRequests(reqs)
+ time.Sleep(200 * time.Millisecond)
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ time.Sleep(200 * time.Millisecond)
+ write, err := ch.Write([]byte("q"))
+ assert.NoError(t, err)
+ assert.NotZero(t, write)
+ time.Sleep(100 * time.Millisecond)
+ _ = ch.Close()
+ }
+ _ = cConn.Close()
+ }()
+
+ err := s.Start()
+ assert.NoError(t, err)
+ })
+
+ t.Run("Headless mode success", func(t *testing.T) {
+ s, conf, cConn, cleanup := setup(t)
+ defer cleanup()
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 80)
+
+ conf.Randomizer.(*mockRandom).On("String", 20).Return("headless-slug", nil)
+ conf.SessionRegistry.(*mockRegistry).On("Register", mock.Anything, mock.Anything).Return(true)
+
+ go func() {
+ time.Sleep(600 * time.Millisecond)
+ _, _, err := cConn.SendRequest("tcpip-forward", true, payload)
+ assert.NoError(t, err)
+
+ time.Sleep(100 * time.Millisecond)
+ err = cConn.Close()
+ assert.NoError(t, err)
+
+ }()
+
+ err := s.Start()
+ assert.NoError(t, err)
+ })
+
+ t.Run("Missing Forward Request", func(t *testing.T) {
+ s, _, cConn, cleanup := setup(t)
+ defer cleanup()
+
+ go func() {
+ time.Sleep(1200 * time.Millisecond)
+ _ = cConn.Close()
+ }()
+
+ err := s.Start()
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "no forwarding Request")
+ })
+
+ t.Run("Unauthorized Headless", func(t *testing.T) {
+ _, conf, cConn, cleanup := setup(t)
+ defer cleanup()
+
+ conf.User = "UNAUTHORIZED"
+ s := New(conf).(*session)
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 80)
+
+ go func() {
+ time.Sleep(600 * time.Millisecond)
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ err := s.Start()
+ assert.Error(t, err)
+ })
+}
+
+func TestForwardingFailures(t *testing.T) {
+ setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) {
+ sConn, sReqs, _, cConn, cleanup := setupSSH(t)
+ mRegistry := &mockRegistry{}
+ mPort := &mockPort{}
+ mRandom := &mockRandom{}
+ conf := &Config{
+ Randomizer: mRandom,
+ Config: &mockConfig{},
+ Conn: sConn,
+ InitialReq: make(chan *ssh.Request),
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: mRegistry,
+ PortRegistry: mPort,
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+ return s, mRegistry, mPort, mRandom, sConn, sReqs, cConn, cleanup
+ }
+
+ t.Run("HTTP Registration Failed", func(t *testing.T) {
+ s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mRandom.On("String", 20).Return("test-slug", nil)
+ mRegistry.On("Register", mock.Anything, mock.Anything).Return(false)
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 80)
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ var req *ssh.Request
+ select {
+ case req = <-sReqs:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timed out waiting for tcpip-forward request")
+ }
+
+ err := s.HandleTCPIPForward(req)
+ assert.Error(t, err)
+ })
+
+ t.Run("TCP Port Claim Failed", func(t *testing.T) {
+ s, _, mPort, _, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mPort.On("Claim", mock.Anything).Return(false)
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 1234)
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ var req *ssh.Request
+ select {
+ case req = <-sReqs:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timed out waiting for tcpip-forward request")
+ }
+
+ err := s.HandleTCPIPForward(req)
+ assert.Error(t, err)
+ })
+
+ t.Run("HTTP Randomizer Error", func(t *testing.T) {
+ s, _, _, mRandom, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mRandom.On("String", 20).Return("", fmt.Errorf("random error"))
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 80)
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ req := <-sReqs
+ err := s.HandleTCPIPForward(req)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "random error")
+ })
+
+ t.Run("Port Registry No Port", func(t *testing.T) {
+ s, _, mPort, _, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mPort.On("Unassigned").Return(uint16(0), false)
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 0)
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ req := <-sReqs
+ err := s.HandleTCPIPForward(req)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "no available port")
+ })
+
+ t.Run("Port too large", func(t *testing.T) {
+ s, _, _, _, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 70000)
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ req := <-sReqs
+ err := s.HandleTCPIPForward(req)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "port is larger than allowed")
+ })
+
+ t.Run("TCP Registration Failed", func(t *testing.T) {
+ s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mPort.On("Claim", mock.Anything).Return(true)
+ mRegistry.On("Register", mock.Anything, mock.Anything).Return(false)
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 1234)
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ req := <-sReqs
+ err := s.HandleTCPIPForward(req)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "Failed to register TunnelTypeTCP client")
+ })
+
+ t.Run("Finalize Forwarding Failure", func(t *testing.T) {
+ s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mRandom.On("String", 20).Return("test-slug", nil)
+ mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], 80)
+
+ go func() {
+ _, _, err := cConn.SendRequest("tcpip-forward", true, payload)
+ assert.Error(t, err, io.EOF)
+ }()
+
+ req := <-sReqs
+ err := cConn.Close()
+ assert.NoError(t, err)
+
+ time.Sleep(50 * time.Millisecond)
+
+ err = s.HandleTCPIPForward(req)
+ assert.Error(t, err)
+ })
+
+ t.Run("TCP Listen Failure", func(t *testing.T) {
+ s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mPort.On("Claim", mock.Anything).Return(true)
+ mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
+
+ l, err := net.Listen("tcp", "0.0.0.0:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func(l net.Listener) {
+ err = l.Close()
+ assert.NoError(t, err)
+ }(l)
+ _, portStr, _ := net.SplitHostPort(l.Addr().String())
+ port, _ := strconv.Atoi(portStr)
+
+ payload := make([]byte, 4+9+4)
+ binary.BigEndian.PutUint32(payload[0:4], 9)
+ copy(payload[4:13], "localhost")
+ binary.BigEndian.PutUint32(payload[13:17], uint32(port))
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
+ }()
+
+ req := <-sReqs
+ err = s.HandleTCPIPForward(req)
+ assert.Error(t, err)
+ assert.Contains(t, err.Error(), "is already in use or restricted")
+ })
+}
+
+func TestSetupInteractiveMode_Error(t *testing.T) {
+ sConn, _, sChans, _, cleanup := setupSSH(t)
+ defer cleanup()
+
+ conf := &Config{
+ Randomizer: &mockRandom{},
+ Config: &mockConfig{},
+ Conn: sConn,
+ InitialReq: make(chan *ssh.Request),
+ SshChan: sChans,
+ SessionRegistry: &mockRegistry{},
+ PortRegistry: &mockPort{},
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+
+ mockChan := &mockNewChanFail{}
+ err := s.setupInteractiveMode(mockChan)
+ if err == nil {
+ t.Error("expected error, got nil")
+ }
+}
+
+type mockNewChanFail struct {
+ ssh.NewChannel
+}
+
+func (m *mockNewChanFail) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
+ return nil, nil, fmt.Errorf("accept failed")
+}
+
+func TestWaitForTCPIPForward_EdgeCases(t *testing.T) {
+ t.Run("Wrong Request Type", func(t *testing.T) {
+ _, sReqs, _, cConn, cleanup := setupSSH(t)
+ defer cleanup()
+
+ s := &session{initialReq: sReqs}
+
+ go func() {
+ _, _, _ = cConn.SendRequest("not-tcpip-forward", true, nil)
+ }()
+
+ req := s.waitForTCPIPForward()
+ if req != nil {
+ t.Error("expected nil request")
+ }
+ })
+
+ t.Run("Channel Closed", func(t *testing.T) {
+ initialReq := make(chan *ssh.Request)
+ s := &session{initialReq: initialReq}
+ close(initialReq)
+
+ req := s.waitForTCPIPForward()
+ if req != nil {
+ t.Error("expected nil request")
+ }
+ })
+}
+
+func TestSetupSessionMode_ChannelClosed(t *testing.T) {
+ sshChan := make(chan ssh.NewChannel)
+ s := &session{sshChan: sshChan}
+ close(sshChan)
+
+ err := s.setupSessionMode()
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+}
+
+func TestStart_SetupSessionModeError(t *testing.T) {
+ sshChan := make(chan ssh.NewChannel, 1)
+ conf := &Config{
+ Randomizer: &mockRandom{},
+ Config: &mockConfig{},
+ Conn: &ssh.ServerConn{},
+ InitialReq: make(chan *ssh.Request),
+ SshChan: sshChan,
+ SessionRegistry: &mockRegistry{},
+ PortRegistry: &mockPort{},
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+
+ mockChan := &mockNewChanFail{}
+ sshChan <- mockChan
+
+ err := s.Start()
+ if err == nil {
+ t.Error("expected error, got nil")
+ }
+}
+
+func TestWaitForSessionEnd_Error(t *testing.T) {
+ mConn := &mockSSHConn{}
+ mConn.On("Wait").Return(fmt.Errorf("wait error"))
+ mConn.On("Close").Return(nil)
+
+ mForwarder := &mockLifecycleForwarder{}
+ mForwarder.On("TunnelType").Return(types.TunnelTypeTCP)
+ mForwarder.On("ForwardedPort").Return(uint16(80))
+ mForwarder.On("Close").Return(fmt.Errorf("close error"))
+
+ mSlug := &mockLifecycleSlug{}
+ mSlug.On("String").Return("slug")
+
+ mPort := &mockPort{}
+ mPort.On("SetStatus", mock.Anything, mock.Anything).Return(nil)
+
+ mRegistry := &mockRegistry{}
+ mRegistry.On("Remove", mock.Anything).Return()
+
+ l := lifecycle.New(mConn, mForwarder, mSlug, mPort, mRegistry, "testuser")
+ s := &session{
+ lifecycle: l,
+ }
+
+ err := s.waitForSessionEnd()
+ assert.Error(t, err)
+}
+
+type mockLifecycleForwarder struct {
+ mock.Mock
+ lifecycle.Forwarder
+}
+
+func (m *mockLifecycleForwarder) TunnelType() types.TunnelType {
+ return m.Called().Get(0).(types.TunnelType)
+}
+func (m *mockLifecycleForwarder) ForwardedPort() uint16 {
+ args := m.Called()
+ if args.Get(0) == nil {
+ return 0
+ }
+ switch v := args.Get(0).(type) {
+ case uint16:
+ return v
+ case uint32:
+ return uint16(v)
+ case uint64:
+ return uint16(v)
+ case uint8:
+ return uint16(v)
+ case uint:
+ return uint16(v)
+ case int:
+ return uint16(v)
+ case int8:
+ return uint16(v)
+ case int16:
+ return uint16(v)
+ case int32:
+ return uint16(v)
+ case int64:
+ return uint16(v)
+ case float32:
+ return uint16(v)
+ case float64:
+ return uint16(v)
+ default:
+ return uint16(args.Int(0))
+ }
+}
+func (m *mockLifecycleForwarder) Close() error {
+ return m.Called().Error(0)
+}
+
+type mockLifecycleSlug struct {
+ mock.Mock
+}
+
+func (m *mockLifecycleSlug) String() string { return m.Called().String(0) }
+func (m *mockLifecycleSlug) Set(slug string) {
+ m.Called(slug)
+}
+
+func TestHandleMissingForwardRequest(t *testing.T) {
+ mConn := &mockSSHConn{}
+ mConfig := &mockConfig{}
+ mConfig.On("Domain").Return("example.com")
+ mConfig.On("SSHPort").Return("2222")
+ mConn.On("Close").Return(nil)
+
+ conf := &Config{
+ Randomizer: &mockRandom{},
+ Config: mConfig,
+ Conn: &ssh.ServerConn{Conn: mConn},
+ InitialReq: make(chan *ssh.Request),
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: &mockRegistry{},
+ PortRegistry: &mockPort{},
+ User: "testuser",
+ }
+
+ s := New(conf).(*session)
+
+ err := s.handleMissingForwardRequest()
+ if err == nil {
+ t.Error("expected error, got nil")
+ }
+}
+
+func TestParseForwardPayload_Errors(t *testing.T) {
+ s := &session{}
+
+ t.Run("Short Address", func(t *testing.T) {
+ _, _, err := s.parseForwardPayload([]byte{0, 0, 0, 4})
+ if err == nil {
+ t.Error("expected error, got nil")
+ }
+ })
+
+ t.Run("Short Port", func(t *testing.T) {
+ payload := append([]byte{0, 0, 0, 4}, []byte("addr")...)
+ _, _, err := s.parseForwardPayload(payload)
+ if err == nil {
+ t.Error("expected error, got nil")
+ }
+ })
+
+ t.Run("Blocked Port", func(t *testing.T) {
+ payload := append([]byte{0, 0, 0, 4}, []byte("addr")...)
+ portBuf := make([]byte, 4)
+ binary.BigEndian.PutUint32(portBuf, 22)
+ payload = append(payload, portBuf...)
+ _, _, err := s.parseForwardPayload(payload)
+ if err == nil {
+ t.Error("expected error, got nil")
+ } else if !strings.Contains(err.Error(), "port is block") {
+ t.Errorf("expected error to contain %q, got %q", "port is block", err.Error())
+ }
+ })
+}
+
+func TestDenyForwardingRequest_TunnelNotSetupYet(t *testing.T) {
+ sConn, sReqs, _, cConn, cleanup := setupSSH(t)
+ defer cleanup()
+
+ mRegistry := &mockRegistry{}
+ mPort := &mockPort{}
+ mRandom := &mockRandom{}
+ conf := &Config{
+ Randomizer: mRandom,
+ Config: &mockConfig{},
+ Conn: sConn,
+ InitialReq: sReqs,
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: mRegistry,
+ PortRegistry: mPort,
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+
+ go func() {
+ _, _, _ = cConn.SendRequest("tcpip-forward", true, nil)
+ }()
+
+ var req *ssh.Request
+ select {
+ case req = <-sReqs:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+
+ key := &types.SessionKey{Id: "", Type: types.TunnelTypeUNKNOWN}
+ err := s.denyForwardingRequest(req, key, &mockCloser{}, "test error")
+ if err == nil {
+ t.Error("expected error, got nil")
+ } else if !strings.Contains(err.Error(), "test error") {
+ t.Errorf("expected error to contain %q, got %q", "test error", err.Error())
+ }
+ assert.Equal(t, *key, mRegistry.removedKey)
+}
+
+func TestDenyForwardingRequest_Full(t *testing.T) {
+ setup := func(t *testing.T) (*session, *mockRegistry, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) {
+ sConn, sReqs, _, cConn, cleanup := setupSSH(t)
+ mRegistry := &mockRegistry{}
+ conf := &Config{
+ Randomizer: &mockRandom{},
+ Config: &mockConfig{},
+ Conn: sConn,
+ InitialReq: sReqs,
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: mRegistry,
+ PortRegistry: &mockPort{},
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+ return s, mRegistry, sConn, sReqs, cConn, cleanup
+ }
+
+ getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request {
+ go func() {
+ _, _, _ = client.SendRequest("tcpip-forward", true, nil)
+ }()
+ select {
+ case req, ok := <-serverReqs:
+ if !ok {
+ t.Fatal("channel closed")
+ }
+ return req
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout getting request")
+ return nil
+ }
+ }
+
+ t.Run("All Success", func(t *testing.T) {
+ s, mRegistry, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ req := getReq(t, cConn, sReqs)
+ key := &types.SessionKey{Id: "test", Type: types.TunnelTypeHTTP}
+
+ s.slug.Set("test")
+ s.forwarder.SetType(types.TunnelTypeHTTP)
+
+ mCloser := &mockCloser{}
+ err := s.denyForwardingRequest(req, key, mCloser, "error")
+ if err == nil {
+ t.Error("expected error, got nil")
+ } else if !strings.Contains(err.Error(), "error") {
+ t.Errorf("expected error to contain %q, got %q", "error", err.Error())
+ }
+ assert.Equal(t, *key, mRegistry.removedKey)
+ })
+
+ t.Run("Listener Close error", func(t *testing.T) {
+ s, _, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ req := getReq(t, cConn, sReqs)
+ mCloser := &mockCloser{err: fmt.Errorf("close error")}
+ err := s.denyForwardingRequest(req, nil, mCloser, "error")
+ assert.Error(t, err, net.ErrClosed)
+ })
+
+ t.Run("Reply error", func(t *testing.T) {
+ s, _, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ req := getReq(t, cConn, sReqs)
+ err := cConn.Close()
+ assert.NoError(t, err)
+
+ time.Sleep(100 * time.Millisecond)
+
+ err = s.denyForwardingRequest(req, nil, nil, assert.AnError.Error())
+ assert.Error(t, err, assert.AnError)
+ })
+}
+
+func TestHandleTCPForward_Failures(t *testing.T) {
+ setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) {
+ sConn, sReqs, _, cConn, cleanup := setupSSH(t)
+ mRegistry := &mockRegistry{}
+ mPort := &mockPort{}
+ conf := &Config{
+ Randomizer: &mockRandom{},
+ Config: &mockConfig{},
+ Conn: sConn,
+ InitialReq: sReqs,
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: mRegistry,
+ PortRegistry: mPort,
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+ return s, mRegistry, mPort, sConn, sReqs, cConn, cleanup
+ }
+
+ getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request {
+ go func() {
+ _, _, _ = client.SendRequest("tcpip-forward", true, nil)
+ }()
+ select {
+ case req, ok := <-serverReqs:
+ if !ok {
+ t.Fatal("channel closed")
+ }
+ return req
+ case <-time.After(2 * time.Second):
+ t.Fatal("timeout getting request")
+ return nil
+ }
+ }
+
+ t.Run("Port Claim fail", func(t *testing.T) {
+ s, _, mPort, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mPort.On("Claim", mock.Anything).Return(false)
+ err := s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", 1234)
+ if err == nil {
+ t.Error("expected error, got nil")
+ } else if !strings.Contains(err.Error(), "already in use") {
+ t.Errorf("expected error to contain %q, got %q", "already in use", err.Error())
+ }
+ })
+
+ t.Run("Listen fail", func(t *testing.T) {
+ s, _, mPort, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mPort.On("Claim", mock.Anything).Return(true)
+ l, err := net.Listen("tcp", "0.0.0.0:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func(l net.Listener) {
+ err = l.Close()
+ assert.NoError(t, err)
+ }(l)
+ port := uint16(l.Addr().(*net.TCPAddr).Port)
+
+ err = s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", port)
+ if err == nil {
+ t.Error("expected error, got nil")
+ } else if !strings.Contains(err.Error(), "already in use") {
+ t.Errorf("expected error to contain %q, got %q", "already in use", err.Error())
+ }
+ })
+
+ t.Run("Registry Register fail", func(t *testing.T) {
+ s, mRegistry, mPort, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mPort.On("Claim", mock.Anything).Return(true)
+ mRegistry.On("Register", mock.Anything, mock.Anything).Return(false)
+ err := s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", 0)
+ if err == nil {
+ t.Error("expected error, got nil")
+ } else if !strings.Contains(err.Error(), "Failed to register") {
+ t.Errorf("expected error to contain %q, got %q", "Failed to register", err.Error())
+ }
+ })
+
+ t.Run("Finalize fail (Reply fail)", func(t *testing.T) {
+ s, mRegistry, mPort, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mPort.On("Claim", mock.Anything).Return(true)
+ mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
+ req := getReq(t, cConn, sReqs)
+ err := cConn.Close()
+ assert.NoError(t, err)
+ time.Sleep(100 * time.Millisecond)
+
+ err = s.HandleTCPForward(req, "localhost", 0)
+ if err == nil {
+ t.Error("expected error, got nil")
+ } else if !strings.Contains(err.Error(), "Failed to finalize forwarding") {
+ t.Errorf("expected error to contain %q, got %q", "Failed to finalize forwarding", err.Error())
+ }
+ })
+}
+
+func TestHandleHTTPForward_Failures(t *testing.T) {
+ setup := func(t *testing.T) (*session, *mockRegistry, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) {
+ sConn, sReqs, _, cConn, cleanup := setupSSH(t)
+ mRegistry := &mockRegistry{}
+ mRandom := &mockRandom{}
+ s := New(&Config{
+ Randomizer: mRandom,
+ Config: &mockConfig{},
+ Conn: sConn,
+ InitialReq: sReqs,
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: mRegistry,
+ PortRegistry: &mockPort{},
+ User: "testuser",
+ }).(*session)
+ return s, mRegistry, mRandom, sConn, sReqs, cConn, cleanup
+ }
+
+ getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request {
+ go func() { _, _, _ = client.SendRequest("tcpip-forward", true, nil) }()
+ return <-serverReqs
+ }
+
+ t.Run("Random fail", func(t *testing.T) {
+ s, _, mRandom, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mRandom.On("String", 20).Return("", fmt.Errorf("random error"))
+ err := s.HandleHTTPForward(getReq(t, cConn, sReqs), 80)
+ if err == nil {
+ t.Error("expected error, got nil")
+ } else if !strings.Contains(err.Error(), "Failed to create slug") {
+ t.Errorf("expected error to contain %q, got %q", "Failed to create slug", err.Error())
+ }
+ })
+
+ t.Run("Register fail", func(t *testing.T) {
+ s, mRegistry, mRandom, _, sReqs, cConn, cleanup := setup(t)
+ defer cleanup()
+ mRandom.On("String", 20).Return("slug", nil)
+ mRegistry.On("Register", mock.Anything, mock.Anything).Return(false)
+ err := s.HandleHTTPForward(getReq(t, cConn, sReqs), 80)
+ if err == nil {
+ t.Error("expected error, got nil")
+ } else if !strings.Contains(err.Error(), "Failed to register") {
+ t.Errorf("expected error to contain %q, got %q", "Failed to register", err.Error())
+ }
+ })
+}
+
+func TestHandleGlobalRequest_Failures(t *testing.T) {
+ _, sReqs, _, cConn, cleanup := setupSSH(t)
+ defer cleanup()
+
+ conf := &Config{
+ Randomizer: &mockRandom{},
+ Config: &mockConfig{},
+ Conn: &ssh.ServerConn{},
+ InitialReq: make(chan *ssh.Request),
+ SshChan: make(chan ssh.NewChannel),
+ SessionRegistry: &mockRegistry{},
+ PortRegistry: &mockPort{},
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+
+ done := make(chan struct{})
+ go func() {
+ _ = s.HandleGlobalRequest(sReqs)
+ close(done)
+ }()
+
+ tests := []struct {
+ name string
+ reqType string
+ payload []byte
+ wantReply bool
+ expected bool
+ }{
+ {"shell", "shell", nil, true, true},
+ {"pty-req", "pty-req", nil, true, true},
+ {"window-change valid", "window-change", make([]byte, 16), true, true},
+ {"window-change invalid", "window-change", make([]byte, 4), true, false},
+ {"unknown", "unknown", nil, true, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ok, _, err := cConn.SendRequest(tt.reqType, tt.wantReply, tt.payload)
+ assert.NoError(t, err)
+ assert.Equal(t, tt.expected, ok)
+ })
+ }
+
+ err := cConn.Close()
+ assert.NoError(t, err)
+
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ t.Fatal("HandleGlobalRequest timed out after cConn.Close()")
+ }
+}
+
+func TestSetupInteractiveMode_GlobalRequestError(t *testing.T) {
+ sConn, _, sChans, _, cleanup := setupSSH(t)
+ defer cleanup()
+
+ conf := &Config{
+ Randomizer: &mockRandom{},
+ Config: &mockConfig{},
+ Conn: sConn,
+ InitialReq: make(chan *ssh.Request),
+ SshChan: sChans,
+ SessionRegistry: &mockRegistry{},
+ PortRegistry: &mockPort{},
+ User: "testuser",
+ }
+ s := New(conf).(*session)
+
+ mockChan := &mockNewChanFail{}
+ err := s.setupInteractiveMode(mockChan)
+ if err == nil {
+ t.Error("expected error, got nil")
+ }
+}
+
+type mockCloser struct {
+ err error
+}
+
+func (m *mockCloser) Close() error { return m.err }
diff --git a/session/slug/slug_test.go b/session/slug/slug_test.go
new file mode 100644
index 0000000..c7af138
--- /dev/null
+++ b/session/slug/slug_test.go
@@ -0,0 +1,99 @@
+package slug
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/suite"
+)
+
+type SlugTestSuite struct {
+ suite.Suite
+ slug Slug
+}
+
+func (suite *SlugTestSuite) SetupTest() {
+ suite.slug = New()
+}
+
+func TestNew(t *testing.T) {
+ s := New()
+
+ assert.NotNil(t, s, "New() should return a non-nil Slug")
+ assert.Implements(t, (*Slug)(nil), s, "New() should return a type that implements Slug interface")
+ assert.Equal(t, "", s.String(), "New() should initialize with empty string")
+}
+
+func (suite *SlugTestSuite) TestString() {
+ assert.Equal(suite.T(), "", suite.slug.String(), "String() should return empty string initially")
+
+ suite.slug.Set("test-slug")
+ assert.Equal(suite.T(), "test-slug", suite.slug.String(), "String() should return the set value")
+}
+
+func (suite *SlugTestSuite) TestSet() {
+ testCases := []struct {
+ name string
+ input string
+ expected string
+ }{
+ {
+ name: "simple slug",
+ input: "hello-world",
+ expected: "hello-world",
+ },
+ {
+ name: "empty string",
+ input: "",
+ expected: "",
+ },
+ {
+ name: "slug with numbers",
+ input: "test-123",
+ expected: "test-123",
+ },
+ {
+ name: "slug with special characters",
+ input: "hello_world-123",
+ expected: "hello_world-123",
+ },
+ {
+ name: "overwrite existing slug",
+ input: "new-slug",
+ expected: "new-slug",
+ },
+ }
+
+ for _, tc := range testCases {
+ suite.Run(tc.name, func() {
+ suite.slug.Set(tc.input)
+ assert.Equal(suite.T(), tc.expected, suite.slug.String())
+ })
+ }
+}
+
+func (suite *SlugTestSuite) TestMultipleSet() {
+ suite.slug.Set("first-slug")
+ assert.Equal(suite.T(), "first-slug", suite.slug.String())
+
+ suite.slug.Set("second-slug")
+ assert.Equal(suite.T(), "second-slug", suite.slug.String())
+
+ suite.slug.Set("")
+ assert.Equal(suite.T(), "", suite.slug.String())
+}
+
+func TestSlugIsolation(t *testing.T) {
+ slug1 := New()
+ slug2 := New()
+
+ slug1.Set("slug-one")
+ slug2.Set("slug-two")
+
+ assert.Equal(t, "slug-one", slug1.String(), "First slug should maintain its value")
+ assert.Equal(t, "slug-two", slug2.String(), "Second slug should maintain its value")
+}
+
+func TestSlugTestSuite(t *testing.T) {
+ suite.Run(t, new(SlugTestSuite))
+}
diff --git a/sonar-project.properties b/sonar-project.properties
deleted file mode 100644
index 277a293..0000000
--- a/sonar-project.properties
+++ /dev/null
@@ -1 +0,0 @@
-sonar.projectKey=tunnel-please
\ No newline at end of file
diff --git a/types/types.go b/types/types.go
index 34ccfb4..77d6ac4 100644
--- a/types/types.go
+++ b/types/types.go
@@ -7,6 +7,7 @@ type SessionStatus int
const (
SessionStatusINITIALIZING SessionStatus = iota
SessionStatusRUNNING
+ SessionStatusCLOSED
)
type InteractiveMode int