diff --git a/.gitea/workflows/sonarqube.yml b/.gitea/workflows/sonarqube.yml index 9c672ac..57625a2 100644 --- a/.gitea/workflows/sonarqube.yml +++ b/.gitea/workflows/sonarqube.yml @@ -13,8 +13,58 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: '1.25.5' + cache: false + + - name: Install dependencies + run: go mod tidy + + - name: Run go vet + run: go vet ./... 2>&1 | tee vet-results.txt + + - name: Run tests with coverage + run: | + go test ./... -v -coverprofile=coverage + + - name: Run GolangCI-Lint Analysis + uses: golangci/golangci-lint-action@v9 + with: + skip-cache: true + version: v2.6 + args: > + --issues-exit-code=0 + --output.text.path=stdout + --output.checkstyle.path=golangci-lint-report.xml + + - name: Set SonarQube project key + run: | + BRANCH_NAME=${GITHUB_REF#refs/heads/} + if [ "$BRANCH_NAME" = "main" ]; then + SONAR_PROJECT_KEY="tunnel-please" + else + BRANCH_KEY=${BRANCH_NAME//\//-} + SONAR_PROJECT_KEY="tunnel-please-$BRANCH_KEY" + fi + echo "SONAR_PROJECT_KEY=tunnel-please-$BRANCH_KEY" >> $GITHUB_ENV + echo "Using SonarQube Project Key: $SONAR_PROJECT_KEY" + - name: SonarQube Scan uses: SonarSource/sonarqube-scan-action@v7.0.0 env: SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }} SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }} + with: + args: > + -Dsonar.projectKey=${{ env.SONAR_PROJECT_KEY }} + -Dsonar.go.coverage.reportPaths=coverage + -Dsonar.test.inclusions=**/*_test.go + -Dsonar.test.exclusions=**/vendor/** + -Dsonar.exclusions=**/*_test.go,**/vendor/**,**/golangci-lint-report.xml + -Dsonar.go.govet.reportPaths=vet-results.txt + -Dsonar.go.golangci-lint.reportPaths=golangci-lint-report.xml + -Dsonar.sources=./ + -Dsonar.tests=./ \ No newline at end of file diff --git a/.gitignore b/.gitignore index dc40a4f..fd6e5af 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,6 @@ id_rsa* .env tmp certs -app \ No newline at end of file +app +coverage +test-results.json \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index c1c452d..6b8b765 100644 --- a/Dockerfile +++ b/Dockerfile @@ -22,7 +22,10 @@ RUN --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=0 GOOS=linux \ go build -trimpath \ - -ldflags="-w -s -X tunnel_pls/version.Version=${VERSION} -X tunnel_pls/version.BuildDate=${BUILD_DATE} -X tunnel_pls/version.Commit=${COMMIT}" \ + -ldflags="-w -s \ + -X tunnel_pls/internal/version.Version=${VERSION} \ + -X tunnel_pls/internal/version.BuildDate=${BUILD_DATE} \ + -X tunnel_pls/internal/version.Commit=${COMMIT}" \ -o /app/tunnel_pls \ . diff --git a/README.md b/README.md index 474f430..628efbe 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,22 @@ +
+ + gopher + # Tunnel Please -A lightweight SSH-based tunnel server written in Go that enables secure TCP and HTTP forwarding with an interactive terminal interface for managing connections and custom subdomains. +A lightweight SSH-based tunnel server + +

+ +[![Coverage](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=coverage&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please) +[![Lines of Code](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=ncloc&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please) +[![Quality Gate Status](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=alert_status&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please) +[![Security Issues](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_security_issues&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please) +[![Maintainability Rating](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_maintainability_rating&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please) +[![Reliability Rating](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_reliability_rating&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please) +[![Security Rating](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_security_rating&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please) + +
## Features @@ -17,108 +33,32 @@ A lightweight SSH-based tunnel server written in Go that enables secure TCP and The following environment variables can be configured in the `.env` file: -| Variable | Description | Default | Required | -|----------|-------------|---------|----------| -| `DOMAIN` | Domain name for subdomain routing | `localhost` | No | -| `PORT` | SSH server port | `2200` | No | -| `HTTP_PORT` | HTTP server port | `8080` | No | -| `HTTPS_PORT` | HTTPS server port | `8443` | No | -| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No | -| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No | -| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@` | No | -| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | - | Yes (if auto-cert) | -| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No | -| `CORS_LIST` | Comma-separated list of allowed CORS origins | - | No | -| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No | -| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No | -| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No | -| `PPROF_PORT` | Port for pprof server | `6060` | No | -| `MODE` | Runtime mode: `standalone` (default, no gRPC/auth) or `node` (enable gRPC + auth) | `standalone` | No | -| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No | -| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No | -| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | - (required in `node`) | Yes (node mode) | +| Variable | Description | Default | Required | +|---------------------|-----------------------------------------------------------------------------|-------------------------|---------------------| +| `DOMAIN` | Domain name for subdomain routing | `localhost` | No | +| `PORT` | SSH server port | `2200` | No | +| `HTTP_PORT` | HTTP server port | `8080` | No | +| `HTTPS_PORT` | HTTPS server port | `8443` | No | +| `KEY_LOC` | Path to the private key file | `certs/privkey.pem` | No | +| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No | +| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No | +| `TLS_STORAGE_PATH` | Path to store TLS certificates | `certs/tls/` | No | +| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@` | No | +| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | `-` | Yes (if auto-cert) | +| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No | +| `CORS_LIST` | Comma-separated list of allowed CORS origins | `-` | No | +| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No | +| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No | +| `MAX_HEADER_SIZE` | Maximum size of HTTP headers in bytes (4096-131072) | `4096` | No | +| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No | +| `PPROF_PORT` | Port for pprof server | `6060` | No | +| `MODE` | Runtime mode: `standalone` or `node` | `standalone` | No | +| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No | +| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No | +| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | `-` | Yes (node mode) | **Note:** All environment variables now use UPPERCASE naming. The application includes sensible defaults for all variables, so you can run it without a `.env` file for basic functionality. -### Automatic TLS Certificate Management - -The server supports automatic TLS certificate generation and renewal using [CertMagic](https://github.com/caddyserver/certmagic) with Cloudflare DNS-01 challenge. This is required for wildcard certificate support (`*.yourdomain.com`). - -**Certificate Storage:** -- TLS certificates are stored in `certs/tls/` (relative to application directory) -- User-provided certificates: `certs/tls/cert.pem` and `certs/tls/privkey.pem` -- CertMagic automatic certificates: `certs/tls/certmagic/` -- SSH keys are stored separately in `certs/ssh/` - -**How it works:** -1. If user-provided certificates exist at `certs/tls/cert.pem` and `certs/tls/privkey.pem` and cover both `DOMAIN` and `*.DOMAIN`, they will be used -2. If certificates are missing, expired, expiring within 30 days, or don't cover the required domains, CertMagic will automatically obtain new certificates from Let's Encrypt -3. Certificates are automatically renewed before expiration -4. User-provided certificates support hot-reload (changes detected every 30 seconds) - -**Cloudflare API Token Setup:** - -To use automatic certificate generation, you need a Cloudflare API token with the following permissions: - -1. Go to [Cloudflare Dashboard](https://dash.cloudflare.com/profile/api-tokens) -2. Click "Create Token" -3. Use "Create Custom Token" with these permissions: - - **Zone → Zone → Read** (for all zones or specific zone) - - **Zone → DNS → Edit** (for all zones or specific zone) -4. Copy the token and set it as `CF_API_TOKEN` environment variable - -**Example configuration for automatic certificates:** -```env -DOMAIN=example.com -TLS_ENABLED=true -CF_API_TOKEN=your_cloudflare_api_token_here -ACME_EMAIL=admin@example.com -# ACME_STAGING=true # Uncomment for testing to avoid rate limits -``` - -### SSH Key Auto-Generation - -The application will automatically generate a new 4096-bit RSA key pair at `certs/ssh/id_rsa` if it doesn't exist. This makes it easier to get started without manually creating SSH keys. SSH keys are stored separately from TLS certificates. - -### Memory Optimization - -The application uses a buffer pool with controlled buffer sizes to prevent excessive memory usage under high concurrent loads. The `BUFFER_SIZE` environment variable controls the size of buffers used for io.Copy operations: - -- **Default:** 32768 bytes (32 KB) - Good balance for most scenarios -- **Minimum:** 4096 bytes (4 KB) - Lower memory usage, more CPU overhead -- **Maximum:** 1048576 bytes (1 MB) - Higher throughput, more memory usage - -**Recommended settings based on load:** -- **Low traffic (<100 concurrent):** `BUFFER_SIZE=32768` (default) -- **High traffic (>100 concurrent):** `BUFFER_SIZE=16384` or `BUFFER_SIZE=8192` -- **Very high traffic (>1000 concurrent):** `BUFFER_SIZE=8192` or `BUFFER_SIZE=4096` - -The buffer pool reuses buffers across connections, preventing memory fragmentation and reducing garbage collection pressure. - -### Profiling with pprof - -To enable profiling for performance analysis: - -1. Set `PPROF_ENABLED=true` in your `.env` file -2. Optionally set `PPROF_PORT` to your desired port (default: 6060) -3. Access profiling data at `http://localhost:6060/debug/pprof/` - -Common pprof endpoints: -- `/debug/pprof/` - Index page with available profiles -- `/debug/pprof/heap` - Memory allocation profile -- `/debug/pprof/goroutine` - Stack traces of all current goroutines -- `/debug/pprof/profile` - CPU profile (30-second sample by default) -- `/debug/pprof/trace` - Execution trace - -Example usage with `go tool pprof`: -```bash -# Analyze CPU profile -go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30 - -# Analyze memory heap -go tool pprof http://localhost:6060/debug/pprof/heap -``` - ## Docker Deployment Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`. @@ -197,22 +137,6 @@ docker-compose -f docker-compose.tcp.yml up -d docker-compose -f docker-compose.root.yml down ``` -### Volume Management - -All configurations use a named volume `certs` for persistent storage: -- SSH keys: `/app/certs/ssh/` -- TLS certificates: `/app/certs/tls/` - -To backup certificates: -```bash -docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar czf /backup/certs-backup.tar.gz -C /data . -``` - -To restore certificates: -```bash -docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar xzf /backup/certs-backup.tar.gz -C /data -``` - ### Recommendation **Use `docker-compose.root.yml`** for production deployments if you need: diff --git a/docs/images/gopher.png b/docs/images/gopher.png new file mode 100644 index 0000000..bdc8fec Binary files /dev/null and b/docs/images/gopher.png differ diff --git a/go.mod b/go.mod index 958657e..214f1c5 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/joho/godotenv v1.5.1 github.com/libdns/cloudflare v0.2.2 github.com/muesli/termenv v0.16.0 + github.com/stretchr/testify v1.8.1 golang.org/x/crypto v0.47.0 google.golang.org/grpc v1.78.0 google.golang.org/protobuf v1.36.11 @@ -27,6 +28,7 @@ require ( github.com/clipperhouse/displaywidth v0.6.2 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/libdns/libdns v1.1.1 // indirect @@ -38,8 +40,10 @@ require ( github.com/miekg/dns v1.1.69 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/zeebo/blake3 v0.2.4 // indirect go.uber.org/multierr v1.11.0 // indirect @@ -52,4 +56,5 @@ require ( golang.org/x/text v0.33.0 // indirect golang.org/x/tools v0.40.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 11912af..4356e9d 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,7 @@ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfa github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= @@ -80,6 +81,12 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= @@ -138,5 +145,8 @@ google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go new file mode 100644 index 0000000..6c3a1f9 --- /dev/null +++ b/internal/bootstrap/bootstrap.go @@ -0,0 +1,196 @@ +package bootstrap + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + "tunnel_pls/internal/config" + "tunnel_pls/internal/grpc/client" + "tunnel_pls/internal/key" + "tunnel_pls/internal/port" + "tunnel_pls/internal/random" + "tunnel_pls/internal/registry" + "tunnel_pls/internal/transport" + "tunnel_pls/internal/version" + "tunnel_pls/server" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" +) + +type Bootstrap struct { + Randomizer random.Random + Config config.Config + SessionRegistry registry.Registry + Port port.Port + GrpcClient client.Client + ErrChan chan error + SignalChan chan os.Signal +} + +func New(config config.Config, port port.Port) (*Bootstrap, error) { + randomizer := random.New() + sessionRegistry := registry.NewRegistry() + + if err := port.AddRange(config.AllowedPortsStart(), config.AllowedPortsEnd()); err != nil { + return nil, err + } + + grpcClient, err := client.New(config, sessionRegistry) + if err != nil { + return nil, err + } + + errChan := make(chan error, 5) + signalChan := make(chan os.Signal, 1) + + return &Bootstrap{ + Randomizer: randomizer, + Config: config, + SessionRegistry: sessionRegistry, + Port: port, + GrpcClient: grpcClient, + ErrChan: errChan, + SignalChan: signalChan, + }, nil +} + +func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) { + sshCfg := &ssh.ServerConfig{ + NoClientAuth: true, + ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()), + } + + if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { + return nil, fmt.Errorf("generate ssh key: %w", err) + } + privateBytes, err := os.ReadFile(sshKeyPath) + if err != nil { + return nil, fmt.Errorf("read private key: %w", err) + } + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + return nil, fmt.Errorf("parse private key: %w", err) + } + sshCfg.AddHostKey(private) + return sshCfg, nil +} + +func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, errChan chan<- error) error { + healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second) + defer healthCancel() + if err := b.GrpcClient.CheckServerHealth(healthCtx); err != nil { + return fmt.Errorf("gRPC health check failed: %w", err) + } + + go func() { + if err := b.GrpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil { + errChan <- fmt.Errorf("failed to subscribe to events: %w", err) + } + }() + + return nil +} + +func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) { + httpserver := transport.NewHTTPServer(conf, registry) + ln, err := httpserver.Listen() + if err != nil { + errChan <- fmt.Errorf("failed to start http server: %w", err) + return + } + if err = httpserver.Serve(ln); err != nil { + errChan <- fmt.Errorf("error when serving http server: %w", err) + } +} + +func startHTTPSServer(conf config.Config, registry registry.Registry, errChan chan<- error) { + tlsCfg, err := transport.NewTLSConfig(conf) + if err != nil { + errChan <- fmt.Errorf("failed to create TLS config: %w", err) + return + } + httpsServer := transport.NewHTTPSServer(conf, registry, tlsCfg) + ln, err := httpsServer.Listen() + if err != nil { + errChan <- fmt.Errorf("failed to create TLS config: %w", err) + return + } + if err = httpsServer.Serve(ln); err != nil { + errChan <- fmt.Errorf("error when serving https server: %w", err) + } +} + +func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, errChan chan<- error) { + sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, conf.SSHPort()) + if err != nil { + errChan <- err + return + } + + sshServer.Start() + + errChan <- sshServer.Close() +} + +func startPprof(pprofPort string, errChan chan<- error) { + pprofAddr := fmt.Sprintf("localhost:%s", pprofPort) + log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) + if err := http.ListenAndServe(pprofAddr, nil); err != nil { + errChan <- fmt.Errorf("pprof server error: %v", err) + } +} +func (b *Bootstrap) Run() error { + sshConfig, err := newSSHConfig(b.Config.KeyLoc()) + if err != nil { + return fmt.Errorf("failed to create SSH config: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + signal.Notify(b.SignalChan, os.Interrupt, syscall.SIGTERM) + + if b.Config.Mode() == types.ServerModeNODE { + err = b.startGRPCClient(ctx, b.Config, b.ErrChan) + if err != nil { + return fmt.Errorf("failed to start gRPC client: %w", err) + } + defer func(grpcClient client.Client) { + err = grpcClient.Close() + if err != nil { + log.Printf("failed to close gRPC client") + } + }(b.GrpcClient) + } + + go startHTTPServer(b.Config, b.SessionRegistry, b.ErrChan) + + if b.Config.TLSEnabled() { + go startHTTPSServer(b.Config, b.SessionRegistry, b.ErrChan) + } + + go func() { + startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, b.GrpcClient, b.Port, b.ErrChan) + }() + + if b.Config.PprofEnabled() { + go startPprof(b.Config.PprofPort(), b.ErrChan) + } + + log.Println("All services started successfully") + + select { + case err = <-b.ErrChan: + return fmt.Errorf("service error: %w", err) + case sig := <-b.SignalChan: + log.Printf("Received signal %s, initiating graceful shutdown", sig) + cancel() + return nil + } +} diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go new file mode 100644 index 0000000..2453cde --- /dev/null +++ b/internal/bootstrap/bootstrap_test.go @@ -0,0 +1,558 @@ +package bootstrap + +import ( + "context" + "fmt" + "net" + "net/http" + _ "net/http/pprof" + "os" + "path/filepath" + "strconv" + "testing" + "time" + "tunnel_pls/internal/config" + "tunnel_pls/internal/port" + "tunnel_pls/internal/registry" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "google.golang.org/grpc" +) + +type MockSessionRegistry struct { + mock.Mock +} + +func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) { + args := m.Called(key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} + +func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) { + args := m.Called(user, key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} + +func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error { + args := m.Called(user, oldKey, newKey) + return args.Error(0) +} + +func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool { + args := m.Called(key, session) + return args.Bool(0) +} + +func (m *MockSessionRegistry) Remove(key registry.Key) { + m.Called(key) +} + +func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session { + args := m.Called(user) + return args.Get(0).([]registry.Session) +} + +func (m *MockSessionRegistry) Slug() slug.Slug { + args := m.Called() + return args.Get(0).(slug.Slug) +} + +type MockRandom struct { + mock.Mock +} + +func (m *MockRandom) String(length int) (string, error) { + args := m.Called(length) + return args.String(0), args.Error(1) +} + +type MockConfig struct { + mock.Mock +} + +func (m *MockConfig) Domain() string { return m.Called().String(0) } +func (m *MockConfig) SSHPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) } +func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) } +func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) } +func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) } +func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } +func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) } +func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) PprofPort() string { return m.Called().String(0) } +func (m *MockConfig) Mode() types.ServerMode { + args := m.Called() + if args.Get(0) == nil { + return 0 + } + switch v := args.Get(0).(type) { + case types.ServerMode: + return v + case int: + return types.ServerMode(v) + default: + return types.ServerMode(args.Int(0)) + } +} +func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *MockConfig) NodeToken() string { return m.Called().String(0) } +func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } + +type MockPort struct { + mock.Mock +} + +func (m *MockPort) AddRange(startPort, endPort uint16) error { + return m.Called(startPort, endPort).Error(0) +} +func (m *MockPort) Unassigned() (uint16, bool) { + args := m.Called() + var mPort uint16 + if args.Get(0) != nil { + switch v := args.Get(0).(type) { + case int: + mPort = uint16(v) + case uint16: + mPort = v + case uint32: + mPort = uint16(v) + case int32: + mPort = uint16(v) + case float64: + mPort = uint16(v) + default: + mPort = uint16(args.Int(0)) + } + } + return mPort, args.Bool(1) +} +func (m *MockPort) SetStatus(port uint16, assigned bool) error { + return m.Called(port, assigned).Error(0) +} +func (m *MockPort) Claim(port uint16) bool { + return m.Called(port).Bool(0) +} + +type MockGRPCClient struct { + mock.Mock +} + +func (m *MockGRPCClient) ClientConn() *grpc.ClientConn { + args := m.Called() + return args.Get(0).(*grpc.ClientConn) +} + +func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) { + args := m.Called(ctx, token) + return args.Bool(0), args.String(1), args.Error(2) +} + +func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error { + args := m.Called(ctx, domain, token) + return args.Error(0) +} + +func (m *MockGRPCClient) Close() error { + args := m.Called() + return args.Error(0) +} + +func TestNew(t *testing.T) { + tests := []struct { + name string + setupConfig func() config.Config + setupPort func() port.Port + wantErr bool + errContains string + }{ + { + name: "Success New with default value", + wantErr: false, + }, + { + name: "Error when AddRange fails", + setupPort: func() port.Port { + mockPort := &MockPort{} + mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range")) + return mockPort + }, + wantErr: true, + errContains: "invalid port range", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var mockPort port.Port + if tt.setupPort != nil { + mockPort = tt.setupPort() + } else { + mockPort = port.New() + } + + var mockConfig config.Config + if tt.setupConfig != nil { + mockConfig = tt.setupConfig() + } else { + var err error + mockConfig, err = config.MustLoad() + assert.NoError(t, err) + } + + bootstrap, err := New(mockConfig, mockPort) + + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + assert.Nil(t, bootstrap) + } else { + assert.NoError(t, err) + assert.NotNil(t, bootstrap) + assert.NotNil(t, bootstrap.Randomizer) + assert.NotNil(t, bootstrap.SessionRegistry) + assert.NotNil(t, bootstrap.Config) + assert.NotNil(t, bootstrap.Port) + assert.NotNil(t, bootstrap.ErrChan) + assert.NotNil(t, bootstrap.SignalChan) + } + }) + } +} + +func randomAvailablePort() (string, error) { + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", err + } + defer func(listener net.Listener) { + _ = listener.Close() + }(listener) + + mPort := listener.Addr().(*net.TCPAddr).Port + return strconv.Itoa(mPort), nil +} + +func TestRun(t *testing.T) { + mockRandom := &MockRandom{} + mockErrChan := make(chan error, 1) + mockSignalChan := make(chan os.Signal, 1) + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + tmpDir := t.TempDir() + keyLoc := filepath.Join(tmpDir, "key.key") + + tests := []struct { + name string + setupConfig func() *MockConfig + setupGrpcClient func() *MockGRPCClient + needCerts bool + expectError bool + }{ + { + name: "successful run and termination", + setupConfig: func() *MockConfig { + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig + }, + expectError: false, + }, + { + name: "error from SSH server invalid port", + setupConfig: func() *MockConfig { + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("invalid") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig + }, + expectError: true, + }, + { + name: "error from HTTP server invalid port", + setupConfig: func() *MockConfig { + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("invalid") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig + }, + expectError: true, + }, + { + name: "error from HTTPS server invalid port", + setupConfig: func() *MockConfig { + tempDir := os.TempDir() + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("invalid") + mockConfig.On("TLSEnabled").Return(true) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("TLSStoragePath").Return(tempDir) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig + }, + expectError: true, + }, + { + name: "grpc health check failed", + setupConfig: func() *MockConfig { + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("invalid") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig + }, + setupGrpcClient: func() *MockGRPCClient { + mockGRPCClient := &MockGRPCClient{} + mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed")) + return mockGRPCClient + }, + expectError: true, + }, + { + name: "successful run with pprof enabled", + setupConfig: func() *MockConfig { + mockConfig := &MockConfig{} + pprofPort, _ := randomAvailablePort() + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(true) + mockConfig.On("PprofPort").Return(pprofPort) + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig + }, + expectError: false, + }, { + name: "successful run in NODE mode with signal", + setupConfig: func() *MockConfig { + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig + }, + setupGrpcClient: func() *MockGRPCClient { + mockGRPCClient := &MockGRPCClient{} + mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil) + mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockGRPCClient.On("Close").Return(nil) + return mockGRPCClient + }, + expectError: false, + }, { + name: "successful run in NODE mode with signal buf error when closing", + setupConfig: func() *MockConfig { + mockConfig := &MockConfig{} + mockConfig.On("KeyLoc").Return(keyLoc) + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("Domain").Return("example.com") + mockConfig.On("SSHPort").Return("0") + mockConfig.On("HTTPPort").Return("0") + mockConfig.On("HTTPSPort").Return("0") + mockConfig.On("TLSEnabled").Return(false) + mockConfig.On("TLSRedirect").Return(false) + mockConfig.On("ACMEEmail").Return("test@example.com") + mockConfig.On("CFAPIToken").Return("fake-token") + mockConfig.On("ACMEStaging").Return(true) + mockConfig.On("AllowedPortsStart").Return(uint16(1024)) + mockConfig.On("AllowedPortsEnd").Return(uint16(65535)) + mockConfig.On("BufferSize").Return(4096) + mockConfig.On("PprofEnabled").Return(false) + mockConfig.On("PprofPort").Return("0") + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("0") + mockConfig.On("NodeToken").Return("fake-node-token") + return mockConfig + }, + setupGrpcClient: func() *MockGRPCClient { + mockGRPCClient := &MockGRPCClient{} + mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil) + mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy")) + return mockGRPCClient + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockConfig := tt.setupConfig() + mockGRPCClient := &MockGRPCClient{} + bootstrap := &Bootstrap{ + Randomizer: mockRandom, + Config: mockConfig, + SessionRegistry: mockSessionRegistry, + Port: mockPort, + ErrChan: mockErrChan, + SignalChan: mockSignalChan, + GrpcClient: mockGRPCClient, + } + + if tt.setupGrpcClient != nil { + bootstrap.GrpcClient = tt.setupGrpcClient() + } + + done := make(chan error, 1) + go func() { + done <- bootstrap.Run() + }() + + if tt.expectError { + err := <-done + assert.Error(t, err) + } else if tt.name == "successful run with pprof enabled" { + time.Sleep(200 * time.Millisecond) + fmt.Println(mockConfig.PprofPort()) + resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort())) + assert.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + err = resp.Body.Close() + assert.NoError(t, err) + mockSignalChan <- os.Interrupt + err = <-done + assert.NoError(t, err) + } else { + time.Sleep(time.Second) + mockSignalChan <- os.Interrupt + err := <-done + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 62e1aca..3e6c9e1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -9,8 +9,11 @@ type Config interface { HTTPPort() string HTTPSPort() string + KeyLoc() string + TLSEnabled() bool TLSRedirect() bool + TLSStoragePath() string ACMEEmail() string CFAPIToken() string @@ -20,6 +23,7 @@ type Config interface { AllowedPortsEnd() uint16 BufferSize() int + HeaderSize() int PprofEnabled() bool PprofPort() string @@ -47,14 +51,17 @@ func (c *config) Domain() string { return c.domain } func (c *config) SSHPort() string { return c.sshPort } func (c *config) HTTPPort() string { return c.httpPort } func (c *config) HTTPSPort() string { return c.httpsPort } +func (c *config) KeyLoc() string { return c.keyLoc } func (c *config) TLSEnabled() bool { return c.tlsEnabled } func (c *config) TLSRedirect() bool { return c.tlsRedirect } +func (c *config) TLSStoragePath() string { return c.tlsStoragePath } func (c *config) ACMEEmail() string { return c.acmeEmail } func (c *config) CFAPIToken() string { return c.cfAPIToken } func (c *config) ACMEStaging() bool { return c.acmeStaging } func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart } func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd } func (c *config) BufferSize() int { return c.bufferSize } +func (c *config) HeaderSize() int { return c.headerSize } func (c *config) PprofEnabled() bool { return c.pprofEnabled } func (c *config) PprofPort() string { return c.pprofPort } func (c *config) Mode() types.ServerMode { return c.mode } diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..85f93f3 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,405 @@ +package config + +import ( + "os" + "testing" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" +) + +func TestGetenv(t *testing.T) { + tests := []struct { + name string + key string + val string + def string + expected string + }{ + { + name: "returns existing env", + key: "TEST_ENV_EXIST", + val: "value", + def: "default", + expected: "value", + }, + { + name: "returns default when env missing", + key: "TEST_ENV_MISSING", + val: "", + def: "default", + expected: "default", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.val != "" { + t.Setenv(tt.key, tt.val) + } else { + err := os.Unsetenv(tt.key) + assert.NoError(t, err) + } + assert.Equal(t, tt.expected, getenv(tt.key, tt.def)) + }) + } +} + +func TestGetenvBool(t *testing.T) { + tests := []struct { + name string + key string + val string + def bool + expected bool + }{ + { + name: "returns true when env is true", + key: "TEST_BOOL_TRUE", + val: "true", + def: false, + expected: true, + }, + { + name: "returns false when env is false", + key: "TEST_BOOL_FALSE", + val: "false", + def: true, + expected: false, + }, + { + name: "returns default when env missing", + key: "TEST_BOOL_MISSING", + val: "", + def: true, + expected: true, + }, + { + name: "returns false when env is not true", + key: "TEST_BOOL_INVALID", + val: "yes", + def: true, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.val != "" { + t.Setenv(tt.key, tt.val) + } else { + err := os.Unsetenv(tt.key) + assert.NoError(t, err) + } + assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def)) + }) + } +} + +func TestParseMode(t *testing.T) { + tests := []struct { + name string + mode string + expect types.ServerMode + expectErr bool + }{ + {"standalone", "standalone", types.ServerModeSTANDALONE, false}, + {"node", "node", types.ServerModeNODE, false}, + {"uppercase", "STANDALONE", types.ServerModeSTANDALONE, false}, + {"invalid", "invalid", 0, true}, + {"empty (default)", "", types.ServerModeSTANDALONE, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.mode != "" { + t.Setenv("MODE", tt.mode) + } else { + err := os.Unsetenv("MODE") + assert.NoError(t, err) + } + mode, err := parseMode() + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expect, mode) + } + }) + } +} + +func TestParseAllowedPorts(t *testing.T) { + tests := []struct { + name string + val string + start uint16 + end uint16 + expectErr bool + }{ + {"valid range", "1000-2000", 1000, 2000, false}, + {"empty", "", 0, 0, false}, + {"invalid format - no dash", "1000", 0, 0, true}, + {"invalid format - too many dashes", "1000-2000-3000", 0, 0, true}, + {"invalid start port", "abc-2000", 0, 0, true}, + {"invalid end port", "1000-abc", 0, 0, true}, + {"out of range start", "70000-80000", 0, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.val != "" { + t.Setenv("ALLOWED_PORTS", tt.val) + } else { + err := os.Unsetenv("ALLOWED_PORTS") + assert.NoError(t, err) + } + start, end, err := parseAllowedPorts() + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.start, start) + assert.Equal(t, tt.end, end) + } + }) + } +} + +func TestParseBufferSize(t *testing.T) { + tests := []struct { + name string + val string + expect int + }{ + {"valid size", "8192", 8192}, + {"default size", "", 32768}, + {"too small", "1024", 4096}, + {"too large", "2000000", 4096}, + {"invalid format", "abc", 4096}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.val != "" { + t.Setenv("BUFFER_SIZE", tt.val) + } else { + err := os.Unsetenv("BUFFER_SIZE") + assert.NoError(t, err) + } + size := parseBufferSize() + assert.Equal(t, tt.expect, size) + }) + } +} + +func TestParseHeaderSize(t *testing.T) { + tests := []struct { + name string + val string + expect int + }{ + {"valid size", "8192", 8192}, + {"default size", "", 4096}, + {"too small", "1024", 4096}, + {"too large", "2000000", 4096}, + {"invalid format", "abc", 4096}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.val != "" { + t.Setenv("MAX_HEADER_SIZE", tt.val) + } else { + err := os.Unsetenv("MAX_HEADER_SIZE") + assert.NoError(t, err) + } + size := parseHeaderSize() + assert.Equal(t, tt.expect, size) + }) + } +} + +func TestParse(t *testing.T) { + tests := []struct { + name string + envs map[string]string + expectErr bool + }{ + { + name: "minimal valid config", + envs: map[string]string{ + "DOMAIN": "example.com", + }, + expectErr: false, + }, + { + name: "TLS enabled without token", + envs: map[string]string{ + "TLS_ENABLED": "true", + }, + expectErr: true, + }, + { + name: "TLS enabled with token", + envs: map[string]string{ + "TLS_ENABLED": "true", + "CF_API_TOKEN": "secret", + }, + expectErr: false, + }, + { + name: "Node mode without token", + envs: map[string]string{ + "MODE": "node", + }, + expectErr: true, + }, + { + name: "Node mode with token", + envs: map[string]string{ + "MODE": "node", + "NODE_TOKEN": "token", + }, + expectErr: false, + }, + { + name: "invalid mode", + envs: map[string]string{ + "MODE": "invalid", + }, + expectErr: true, + }, + { + name: "invalid allowed ports", + envs: map[string]string{ + "ALLOWED_PORTS": "1000", + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Clearenv() + for k, v := range tt.envs { + t.Setenv(k, v) + } + cfg, err := parse() + if tt.expectErr { + assert.Error(t, err) + assert.Nil(t, cfg) + } else { + assert.NoError(t, err) + assert.NotNil(t, cfg) + } + }) + } +} + +func TestGetters(t *testing.T) { + envs := map[string]string{ + "DOMAIN": "example.com", + "PORT": "2222", + "HTTP_PORT": "80", + "HTTPS_PORT": "443", + "KEY_LOC": "certs/ssh/id_rsa", + "TLS_ENABLED": "true", + "TLS_REDIRECT": "true", + "TLS_STORAGE_PATH": "certs/tls/", + "ACME_EMAIL": "test@example.com", + "CF_API_TOKEN": "token", + "ACME_STAGING": "true", + "ALLOWED_PORTS": "1000-2000", + "BUFFER_SIZE": "16384", + "MAX_HEADER_SIZE": "4096", + "PPROF_ENABLED": "true", + "PPROF_PORT": "7070", + "MODE": "standalone", + "GRPC_ADDRESS": "127.0.0.1", + "GRPC_PORT": "9090", + "NODE_TOKEN": "ntoken", + } + + os.Clearenv() + for k, v := range envs { + t.Setenv(k, v) + } + + cfg, err := parse() + assert.NoError(t, err) + + assert.Equal(t, "example.com", cfg.Domain()) + assert.Equal(t, "2222", cfg.SSHPort()) + assert.Equal(t, "80", cfg.HTTPPort()) + assert.Equal(t, "443", cfg.HTTPSPort()) + assert.Equal(t, "certs/ssh/id_rsa", cfg.KeyLoc()) + assert.Equal(t, true, cfg.TLSEnabled()) + assert.Equal(t, true, cfg.TLSRedirect()) + assert.Equal(t, "certs/tls/", cfg.TLSStoragePath()) + assert.Equal(t, "test@example.com", cfg.ACMEEmail()) + assert.Equal(t, "token", cfg.CFAPIToken()) + assert.Equal(t, true, cfg.ACMEStaging()) + assert.Equal(t, uint16(1000), cfg.AllowedPortsStart()) + assert.Equal(t, uint16(2000), cfg.AllowedPortsEnd()) + assert.Equal(t, 16384, cfg.BufferSize()) + assert.Equal(t, 4096, cfg.HeaderSize()) + assert.Equal(t, true, cfg.PprofEnabled()) + assert.Equal(t, "7070", cfg.PprofPort()) + assert.Equal(t, types.ServerMode(types.ServerModeSTANDALONE), cfg.Mode()) + assert.Equal(t, "127.0.0.1", cfg.GRPCAddress()) + assert.Equal(t, "9090", cfg.GRPCPort()) + assert.Equal(t, "ntoken", cfg.NodeToken()) +} + +func TestMustLoad(t *testing.T) { + t.Run("success", func(t *testing.T) { + os.Clearenv() + t.Setenv("DOMAIN", "example.com") + cfg, err := MustLoad() + assert.NoError(t, err) + assert.NotNil(t, cfg) + }) + + t.Run("loadEnvFile error", func(t *testing.T) { + err := os.Mkdir(".env", 0755) + assert.NoError(t, err) + defer func() { + err = os.Remove(".env") + assert.NoError(t, err) + }() + + cfg, err := MustLoad() + assert.Error(t, err) + assert.Nil(t, cfg) + }) + + t.Run("parse error", func(t *testing.T) { + os.Clearenv() + t.Setenv("MODE", "invalid") + cfg, err := MustLoad() + assert.Error(t, err) + assert.Nil(t, cfg) + }) +} + +func TestLoadEnvFile(t *testing.T) { + t.Run("file exists", func(t *testing.T) { + err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644) + assert.NoError(t, err) + defer func() { + err = os.Remove(".env") + assert.NoError(t, err) + }() + + err = loadEnvFile() + assert.NoError(t, err) + assert.Equal(t, "true", os.Getenv("TEST_ENV_FILE")) + }) + + t.Run("file missing", func(t *testing.T) { + _ = os.Remove(".env") + err := loadEnvFile() + assert.NoError(t, err) + }) +} diff --git a/internal/config/loader.go b/internal/config/loader.go index cde9fd0..5cbfe1f 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -18,18 +18,21 @@ type config struct { httpPort string httpsPort string - tlsEnabled bool - tlsRedirect bool + keyLoc string - acmeEmail string - cfAPIToken string - acmeStaging bool + tlsEnabled bool + tlsRedirect bool + tlsStoragePath string + acmeEmail string + cfAPIToken string + acmeStaging bool allowedPortsStart uint16 allowedPortsEnd uint16 bufferSize int - + headerSize int + pprofEnabled bool pprofPort string @@ -51,8 +54,11 @@ func parse() (*config, error) { httpPort := getenv("HTTP_PORT", "8080") httpsPort := getenv("HTTPS_PORT", "8443") + keyLoc := getenv("KEY_LOC", "certs/privkey.pem") + tlsEnabled := getenvBool("TLS_ENABLED", false) tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false) + tlsStoragePath := getenv("TLS_STORAGE_PATH", "certs/tls/") acmeEmail := getenv("ACME_EMAIL", "admin@"+domain) acmeStaging := getenvBool("ACME_STAGING", false) @@ -68,6 +74,7 @@ func parse() (*config, error) { } bufferSize := parseBufferSize() + headerSize := parseHeaderSize() pprofEnabled := getenvBool("PPROF_ENABLED", false) pprofPort := getenv("PPROF_PORT", "6060") @@ -85,14 +92,17 @@ func parse() (*config, error) { sshPort: sshPort, httpPort: httpPort, httpsPort: httpsPort, + keyLoc: keyLoc, tlsEnabled: tlsEnabled, tlsRedirect: tlsRedirect, + tlsStoragePath: tlsStoragePath, acmeEmail: acmeEmail, cfAPIToken: cfToken, acmeStaging: acmeStaging, allowedPortsStart: start, allowedPortsEnd: end, bufferSize: bufferSize, + headerSize: headerSize, pprofEnabled: pprofEnabled, pprofPort: pprofPort, mode: mode, @@ -154,6 +164,16 @@ func parseBufferSize() int { return size } +func parseHeaderSize() int { + raw := getenv("MAX_HEADER_SIZE", "4096") + size, err := strconv.Atoi(raw) + if err != nil || size < 4096 || size > 131072 { + log.Println("Invalid BUFFER_SIZE, falling back to 4096") + return 4096 + } + return size +} + func getenv(key, def string) string { if v := os.Getenv(key); v != "" { return v diff --git a/internal/grpc/client/client.go b/internal/grpc/client/client.go index f2e0a1e..3adcc57 100644 --- a/internal/grpc/client/client.go +++ b/internal/grpc/client/client.go @@ -38,7 +38,15 @@ type client struct { closing bool } -func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) { +var ( + grpcNewClient = grpc.NewClient + healthNewHealthClient = grpc_health_v1.NewHealthClient + initialBackoff = time.Second +) + +func New(config config.Config, sessionRegistry registry.Registry) (Client, error) { + address := fmt.Sprintf("%s:%s", config.GRPCAddress(), config.GRPCPort()) + var opts []grpc.DialOption opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) @@ -58,7 +66,7 @@ func New(config config.Config, address string, sessionRegistry registry.Registry ), ) - conn, err := grpc.NewClient(address, opts...) + conn, err := grpcNewClient(address, opts...) if err != nil { return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err) } @@ -77,85 +85,100 @@ func New(config config.Config, address string, sessionRegistry registry.Registry } func (c *client) SubscribeEvents(ctx context.Context, identity, authToken string) error { - const ( - baseBackoff = time.Second - maxBackoff = 30 * time.Second - ) - - backoff := baseBackoff - wait := func() error { - if backoff <= 0 { - return nil - } - select { - case <-time.After(backoff): - return nil - case <-ctx.Done(): - return ctx.Err() - } - } - growBackoff := func() { - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } - } + backoff := initialBackoff for { - subscribe, err := c.eventService.Subscribe(ctx) - if err != nil { - if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil { - return err - } - if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated { - return err - } - if err = wait(); err != nil { - return err - } - growBackoff() - log.Printf("Reconnect to controller within %v sec", backoff.Seconds()) - continue + if err := c.subscribeAndProcess(ctx, identity, authToken, &backoff); err != nil { + return err } + } +} - err = subscribe.Send(&proto.Node{ - Type: proto.EventType_AUTHENTICATION, - Payload: &proto.Node_AuthEvent{ - AuthEvent: &proto.Authentication{ - Identity: identity, - AuthToken: authToken, - }, +func (c *client) subscribeAndProcess(ctx context.Context, identity, authToken string, backoff *time.Duration) error { + subscribe, err := c.eventService.Subscribe(ctx) + if err != nil { + return c.handleSubscribeError(ctx, err, backoff) + } + + err = subscribe.Send(&proto.Node{ + Type: proto.EventType_AUTHENTICATION, + Payload: &proto.Node_AuthEvent{ + AuthEvent: &proto.Authentication{ + Identity: identity, + AuthToken: authToken, }, - }) + }, + }) - if err != nil { - log.Println("Authentication failed to send to gRPC server:", err) - if c.isConnectionError(err) { - if err = wait(); err != nil { - return err - } - growBackoff() - continue - } - return err - } - log.Println("Authentication Successfully sent to gRPC server") - backoff = baseBackoff + if err != nil { + return c.handleAuthError(ctx, err, backoff) + } - if err = c.processEventStream(subscribe); err != nil { - if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil { - return err - } - if c.isConnectionError(err) { - log.Printf("Reconnect to controller within %v sec", backoff.Seconds()) - if err = wait(); err != nil { - return err - } - growBackoff() - continue - } - return err - } + log.Println("Authentication Successfully sent to gRPC server") + *backoff = time.Second + + return c.handleStreamError(ctx, c.processEventStream(subscribe), backoff) +} + +func (c *client) handleSubscribeError(ctx context.Context, err error, backoff *time.Duration) error { + if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil { + return err + } + if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated { + return err + } + if err = c.wait(ctx, *backoff); err != nil { + return err + } + c.growBackoff(backoff) + log.Printf("Reconnect to controller within %v sec", backoff.Seconds()) + return nil +} + +func (c *client) handleAuthError(ctx context.Context, err error, backoff *time.Duration) error { + log.Println("Authentication failed to send to gRPC server:", err) + if !c.isConnectionError(err) { + return err + } + if err := c.wait(ctx, *backoff); err != nil { + return err + } + c.growBackoff(backoff) + return nil +} + +func (c *client) handleStreamError(ctx context.Context, err error, backoff *time.Duration) error { + if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil { + return err + } + if !c.isConnectionError(err) { + return err + } + log.Printf("Reconnect to controller within %v sec", backoff.Seconds()) + if err := c.wait(ctx, *backoff); err != nil { + return err + } + c.growBackoff(backoff) + return nil +} + +func (c *client) wait(ctx context.Context, duration time.Duration) error { + if duration <= 0 { + return nil + } + select { + case <-time.After(duration): + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (c *client) growBackoff(backoff *time.Duration) { + const maxBackoff = 30 * time.Second + *backoff *= 2 + if *backoff > maxBackoff { + *backoff = maxBackoff } } @@ -191,35 +214,20 @@ func (c *client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, pr func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { slugEvent := evt.GetSlugEvent() user := slugEvent.GetUser() - oldSlug := slugEvent.GetOld() - newSlug := slugEvent.GetNew() + oldKey := types.SessionKey{Id: slugEvent.GetOld(), Type: types.TunnelTypeHTTP} + newKey := types.SessionKey{Id: slugEvent.GetNew(), Type: types.TunnelTypeHTTP} - userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}) + userSession, err := c.sessionRegistry.Get(oldKey) if err != nil { - return c.sendNode(subscribe, &proto.Node{ - Type: proto.EventType_SLUG_CHANGE_RESPONSE, - Payload: &proto.Node_SlugEventResponse{ - SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()}, - }, - }, "slug change failure response") + return c.sendSlugChangeResponse(subscribe, false, err.Error()) } - if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil { - return c.sendNode(subscribe, &proto.Node{ - Type: proto.EventType_SLUG_CHANGE_RESPONSE, - Payload: &proto.Node_SlugEventResponse{ - SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()}, - }, - }, "slug change failure response") + if err = c.sessionRegistry.Update(user, oldKey, newKey); err != nil { + return c.sendSlugChangeResponse(subscribe, false, err.Error()) } userSession.Interaction().Redraw() - return c.sendNode(subscribe, &proto.Node{ - Type: proto.EventType_SLUG_CHANGE_RESPONSE, - Payload: &proto.Node_SlugEventResponse{ - SlugEventResponse: &proto.SlugChangeEventResponse{Success: true, Message: ""}, - }, - }, "slug change success response") + return c.sendSlugChangeResponse(subscribe, true, "") } func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { @@ -238,12 +246,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node }) } - return c.sendNode(subscribe, &proto.Node{ - Type: proto.EventType_GET_SESSIONS, - Payload: &proto.Node_GetSessionsEvent{ - GetSessionsEvent: &proto.GetSessionsResponse{Details: details}, - }, - }, "send get sessions response") + return c.sendGetSessionsResponse(subscribe, details) } func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { @@ -253,39 +256,46 @@ func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType()) if err != nil { - return c.sendNode(subscribe, &proto.Node{ - Type: proto.EventType_TERMINATE_SESSION, - Payload: &proto.Node_TerminateSessionEventResponse{ - TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()}, - }, - }, "terminate session invalid tunnel type") + return c.sendTerminateSessionResponse(subscribe, false, err.Error()) } userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType}) if err != nil { - return c.sendNode(subscribe, &proto.Node{ - Type: proto.EventType_TERMINATE_SESSION, - Payload: &proto.Node_TerminateSessionEventResponse{ - TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()}, - }, - }, "terminate session fetch failed") + return c.sendTerminateSessionResponse(subscribe, false, err.Error()) } if err = userSession.Lifecycle().Close(); err != nil { - return c.sendNode(subscribe, &proto.Node{ - Type: proto.EventType_TERMINATE_SESSION, - Payload: &proto.Node_TerminateSessionEventResponse{ - TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()}, - }, - }, "terminate session close failed") + return c.sendTerminateSessionResponse(subscribe, false, err.Error()) } + return c.sendTerminateSessionResponse(subscribe, true, "") +} + +func (c *client) sendSlugChangeResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error { + return c.sendNode(subscribe, &proto.Node{ + Type: proto.EventType_SLUG_CHANGE_RESPONSE, + Payload: &proto.Node_SlugEventResponse{ + SlugEventResponse: &proto.SlugChangeEventResponse{Success: success, Message: message}, + }, + }, "slug change response") +} + +func (c *client) sendGetSessionsResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], details []*proto.Detail) error { + return c.sendNode(subscribe, &proto.Node{ + Type: proto.EventType_GET_SESSIONS, + Payload: &proto.Node_GetSessionsEvent{ + GetSessionsEvent: &proto.GetSessionsResponse{Details: details}, + }, + }, "send get sessions response") +} + +func (c *client) sendTerminateSessionResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error { return c.sendNode(subscribe, &proto.Node{ Type: proto.EventType_TERMINATE_SESSION, Payload: &proto.Node_TerminateSessionEventResponse{ - TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: true, Message: ""}, + TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: success, Message: message}, }, - }, "terminate session success response") + }, "terminate session response") } func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error { @@ -326,7 +336,7 @@ func (c *client) AuthorizeConn(ctx context.Context, token string) (authorized bo } func (c *client) CheckServerHealth(ctx context.Context) error { - healthClient := grpc_health_v1.NewHealthClient(c.ClientConn()) + healthClient := healthNewHealthClient(c.ClientConn()) resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{ Service: "", }) diff --git a/internal/grpc/client/client_test.go b/internal/grpc/client/client_test.go new file mode 100644 index 0000000..e69065d --- /dev/null +++ b/internal/grpc/client/client_test.go @@ -0,0 +1,1078 @@ +package client + +import ( + "context" + "errors" + "fmt" + "io" + "testing" + "time" + + "tunnel_pls/internal/port" + "tunnel_pls/internal/registry" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/status" +) + +func TestClient_ClientConn(t *testing.T) { + conn := &grpc.ClientConn{} + c := &client{conn: conn} + if c.ClientConn() != conn { + t.Errorf("ClientConn() did not return expected connection") + } +} + +func TestClient_Close(t *testing.T) { + c := &client{} + if err := c.Close(); err != nil { + t.Errorf("Close() on nil connection returned error: %v", err) + } +} + +func TestAuthorizeConn(t *testing.T) { + mockUserSvc := &mockUserServiceClient{} + c := &client{authorizeConnectionService: mockUserSvc} + + tests := []struct { + name string + token string + mockResp *proto.CheckResponse + mockErr error + wantAuth bool + wantUser string + wantErr bool + }{ + { + name: "Success", + token: "valid", + mockResp: &proto.CheckResponse{Response: proto.AuthorizationResponse_MESSAGE_TYPE_AUTHORIZED, User: "mas-fuad"}, + wantAuth: true, + wantUser: "mas-fuad", + wantErr: false, + }, + { + name: "Unauthorized", + token: "invalid", + mockResp: &proto.CheckResponse{Response: proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED}, + wantAuth: false, + wantUser: "UNAUTHORIZED", + wantErr: false, + }, + { + name: "Error", + token: "error", + mockErr: errors.New("grpc error"), + wantAuth: false, + wantUser: "UNAUTHORIZED", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockUserSvc.On("Check", mock.Anything, &proto.CheckRequest{AuthToken: tt.token}, mock.Anything).Return(tt.mockResp, tt.mockErr).Once() + + auth, user, err := c.AuthorizeConn(context.Background(), tt.token) + if (err != nil) != tt.wantErr { + t.Errorf("AuthorizeConn() error = %v, wantErr %v", err, tt.wantErr) + } + assert.Equal(t, tt.wantAuth, auth) + assert.Equal(t, tt.wantUser, user) + mockUserSvc.AssertExpectations(t) + }) + } +} + +func TestHandleSubscribeError(t *testing.T) { + c := &client{} + ctx := context.Background() + canceledCtx, cancel := context.WithCancel(ctx) + cancel() + + tests := []struct { + name string + ctx context.Context + err error + backoff time.Duration + wantErr bool + wantB time.Duration + }{ + { + name: "ContextCanceled", + ctx: canceledCtx, + err: context.Canceled, + backoff: time.Second, + wantErr: true, + }, + { + name: "GrpcCanceled", + ctx: ctx, + err: status.Error(codes.Canceled, "canceled"), + backoff: time.Second, + wantErr: true, + }, + { + name: "CtxErrSet", + ctx: canceledCtx, + err: errors.New("other error"), + backoff: time.Second, + wantErr: true, + }, + { + name: "Unauthenticated", + ctx: ctx, + err: status.Error(codes.Unauthenticated, "unauth"), + backoff: time.Second, + wantErr: true, + }, + { + name: "ConnectionError", + ctx: ctx, + err: status.Error(codes.Unavailable, "unavailable"), + backoff: time.Second, + wantErr: false, + wantB: 2 * time.Second, + }, + { + name: "NonConnectionError", + ctx: ctx, + err: status.Error(codes.Internal, "internal"), + backoff: time.Second, + wantErr: true, + }, + { + name: "WaitCanceled", + ctx: canceledCtx, + err: status.Error(codes.Unavailable, "unavailable"), + backoff: time.Second, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + backoff := tt.backoff + err := c.handleSubscribeError(tt.ctx, tt.err, &backoff) + if (err != nil) != tt.wantErr { + t.Errorf("handleSubscribeError() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && backoff != tt.wantB { + t.Errorf("handleSubscribeError() backoff = %v, want %v", backoff, tt.wantB) + } + }) + } + + t.Run("WaitCanceledReal", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + backoff := 50 * time.Millisecond + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + err := c.handleSubscribeError(ctx, status.Error(codes.Unavailable, "unavailable"), &backoff) + if err == nil { + t.Errorf("expected error from wait") + } + }) +} + +func TestHandleStreamError(t *testing.T) { + c := &client{} + ctx := context.Background() + canceledCtx, cancel := context.WithCancel(ctx) + cancel() + + tests := []struct { + name string + ctx context.Context + err error + backoff time.Duration + wantErr bool + wantB time.Duration + }{ + { + name: "ContextCanceled", + ctx: canceledCtx, + err: context.Canceled, + backoff: time.Second, + wantErr: true, + }, + { + name: "GrpcCanceled", + ctx: ctx, + err: status.Error(codes.Canceled, "canceled"), + backoff: time.Second, + wantErr: true, + }, + { + name: "CtxErrSet", + ctx: canceledCtx, + err: errors.New("other error"), + backoff: time.Second, + wantErr: true, + }, + { + name: "ConnectionError", + ctx: ctx, + err: status.Error(codes.Unavailable, "unavailable"), + backoff: time.Second, + wantErr: false, + wantB: 2 * time.Second, + }, + { + name: "NonConnectionError", + ctx: ctx, + err: status.Error(codes.Internal, "internal"), + backoff: time.Second, + wantErr: true, + }, + { + name: "WaitCanceled", + ctx: canceledCtx, + err: status.Error(codes.Unavailable, "unavailable"), + backoff: time.Second, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + backoff := tt.backoff + err := c.handleStreamError(tt.ctx, tt.err, &backoff) + if (err != nil) != tt.wantErr { + t.Errorf("handleStreamError() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && backoff != tt.wantB { + t.Errorf("handleStreamError() backoff = %v, want %v", backoff, tt.wantB) + } + }) + } + + t.Run("WaitCanceledReal", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + backoff := 50 * time.Millisecond + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + err := c.handleStreamError(ctx, status.Error(codes.Unavailable, "unavailable"), &backoff) + if err == nil { + t.Errorf("expected error from wait") + } + }) +} + +func TestHandleAuthError(t *testing.T) { + c := &client{} + ctx := context.Background() + + tests := []struct { + name string + err error + backoff time.Duration + wantErr bool + wantB time.Duration + }{ + { + name: "ConnectionError", + err: status.Error(codes.Unavailable, "unavailable"), + backoff: time.Second, + wantErr: false, + wantB: 2 * time.Second, + }, + { + name: "NonConnectionError", + err: status.Error(codes.Internal, "internal"), + backoff: time.Second, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + backoff := tt.backoff + err := c.handleAuthError(ctx, tt.err, &backoff) + if (err != nil) != tt.wantErr { + t.Errorf("handleAuthError() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && backoff != tt.wantB { + t.Errorf("handleAuthError() backoff = %v, want %v", backoff, tt.wantB) + } + }) + } +} + +func TestHandleAuthError_WaitFail(t *testing.T) { + c := &client{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + backoff := time.Second + err := c.handleAuthError(ctx, status.Error(codes.Unavailable, "unavailable"), &backoff) + if err == nil { + t.Errorf("expected error when wait fails") + } +} + +func TestProcessEventStream(t *testing.T) { + c := &client{} + + t.Run("UnknownEventType", func(t *testing.T) { + mockStream := &mockSubscribeClient{} + mockStream.On("Recv").Return(&proto.Events{Type: proto.EventType(999)}, nil).Once() + mockStream.On("Recv").Return(nil, io.EOF).Once() + + err := c.processEventStream(mockStream) + assert.ErrorIs(t, err, io.EOF) + }) + + t.Run("DispatchSuccess", func(t *testing.T) { + events := []proto.EventType{ + proto.EventType_SLUG_CHANGE, + proto.EventType_GET_SESSIONS, + proto.EventType_TERMINATE_SESSION, + } + + for _, et := range events { + t.Run(et.String(), func(t *testing.T) { + mockStream := &mockSubscribeClient{} + payload := &proto.Events{Type: et} + switch et { + case proto.EventType_SLUG_CHANGE: + payload.Payload = &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}} + case proto.EventType_GET_SESSIONS: + payload.Payload = &proto.Events_GetSessionsEvent{GetSessionsEvent: &proto.GetSessionsEvent{}} + case proto.EventType_TERMINATE_SESSION: + payload.Payload = &proto.Events_TerminateSessionEvent{TerminateSessionEvent: &proto.TerminateSessionEvent{}} + } + + mockStream.On("Recv").Return(payload, nil).Once() + mockStream.On("Recv").Return(nil, io.EOF).Once() + + mockReg := &mockRegistry{} + c.sessionRegistry = mockReg + mCfg := &MockConfig{} + c.config = mCfg + mCfg.On("Domain").Return("test.com").Maybe() + + switch et { + case proto.EventType_SLUG_CHANGE: + mockReg.On("Get", mock.Anything).Return(nil, errors.New("fail")).Once() + case proto.EventType_GET_SESSIONS: + mockReg.On("GetAllSessionFromUser", mock.Anything).Return(nil).Once() + case proto.EventType_TERMINATE_SESSION: + mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(nil, errors.New("fail")).Once() + } + mockStream.On("Send", mock.Anything).Return(nil).Once() + + err := c.processEventStream(mockStream) + assert.ErrorIs(t, err, io.EOF) + }) + } + }) + + t.Run("HandlerError", func(t *testing.T) { + mockStream := &mockSubscribeClient{} + mockStream.On("Recv").Return(&proto.Events{ + Type: proto.EventType_SLUG_CHANGE, + Payload: &proto.Events_SlugEvent{SlugEvent: &proto.SlugChangeEvent{}}, + }, nil).Once() + + mockReg := &mockRegistry{} + mockReg.On("Get", mock.Anything).Return(nil, errors.New("fail")).Once() + c.sessionRegistry = mockReg + + expectedErr := status.Error(codes.Unavailable, "send fail") + mockStream.On("Send", mock.Anything).Return(expectedErr).Once() + + err := c.processEventStream(mockStream) + assert.Equal(t, expectedErr, err) + }) +} + +func TestSendNode(t *testing.T) { + c := &client{} + + t.Run("Success", func(t *testing.T) { + mockStream := &mockSubscribeClient{} + mockStream.On("Send", mock.Anything).Return(nil).Once() + err := c.sendNode(mockStream, &proto.Node{}, "context") + assert.NoError(t, err) + mockStream.AssertExpectations(t) + }) + + t.Run("ConnectionError", func(t *testing.T) { + mockStream := &mockSubscribeClient{} + expectedErr := status.Error(codes.Unavailable, "fail") + mockStream.On("Send", mock.Anything).Return(expectedErr).Once() + err := c.sendNode(mockStream, &proto.Node{}, "context") + assert.ErrorIs(t, err, expectedErr) + mockStream.AssertExpectations(t) + }) + + t.Run("OtherError", func(t *testing.T) { + mockStream := &mockSubscribeClient{} + mockStream.On("Send", mock.Anything).Return(status.Error(codes.Internal, "fail")).Once() + err := c.sendNode(mockStream, &proto.Node{}, "context") + assert.NoError(t, err) + mockStream.AssertExpectations(t) + }) +} + +func TestHandleSlugChange(t *testing.T) { + mockReg := &mockRegistry{} + mockStream := &mockSubscribeClient{} + c := &client{sessionRegistry: mockReg} + + evt := &proto.Events{ + Payload: &proto.Events_SlugEvent{ + SlugEvent: &proto.SlugChangeEvent{ + User: "mas-fuad", + Old: "old-slug", + New: "new-slug", + }, + }, + } + + t.Run("Success", func(t *testing.T) { + mockSess := &mockSession{} + mockInter := &mockInteraction{} + mockSess.On("Interaction").Return(mockInter).Once() + mockInter.On("Redraw").Return().Once() + + mockReg.On("Get", types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}).Return(mockSess, nil).Once() + mockReg.On("Update", "mas-fuad", types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}, types.SessionKey{Id: "new-slug", Type: types.TunnelTypeHTTP}).Return(nil).Once() + + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + return n.Type == proto.EventType_SLUG_CHANGE_RESPONSE && n.GetSlugEventResponse().Success + })).Return(nil).Once() + + err := c.handleSlugChange(mockStream, evt) + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + mockInter.AssertExpectations(t) + }) + + t.Run("SessionNotFound", func(t *testing.T) { + mockReg.On("Get", mock.Anything).Return(nil, errors.New("not found")).Once() + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + return !n.GetSlugEventResponse().Success && n.GetSlugEventResponse().Message == "not found" + })).Return(nil).Once() + + err := c.handleSlugChange(mockStream, evt) + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + }) + + t.Run("UpdateError", func(t *testing.T) { + mockSess := &mockSession{} + mockReg.On("Get", mock.Anything).Return(mockSess, nil).Once() + mockReg.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("update fail")).Once() + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + return !n.GetSlugEventResponse().Success && n.GetSlugEventResponse().Message == "update fail" + })).Return(nil).Once() + + err := c.handleSlugChange(mockStream, evt) + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + }) +} + +func TestHandleGetSessions(t *testing.T) { + mockReg := &mockRegistry{} + mockStream := &mockSubscribeClient{} + mockCfg := &MockConfig{} + c := &client{sessionRegistry: mockReg, config: mockCfg} + + evt := &proto.Events{ + Payload: &proto.Events_GetSessionsEvent{ + GetSessionsEvent: &proto.GetSessionsEvent{ + Identity: "mas-fuad", + }, + }, + } + + t.Run("Success", func(t *testing.T) { + now := time.Now() + mockSess := &mockSession{} + mockSess.On("Detail").Return(&types.Detail{ + ForwardingType: "http", + Slug: "myslug", + UserID: "mas-fuad", + Active: true, + StartedAt: now, + }).Once() + + mockReg.On("GetAllSessionFromUser", "mas-fuad").Return([]registry.Session{mockSess}).Once() + mockCfg.On("Domain").Return("test.com").Once() + + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + if n.Type != proto.EventType_GET_SESSIONS { + return false + } + details := n.GetGetSessionsEvent().Details + return len(details) == 1 && details[0].Slug == "myslug" + })).Return(nil).Once() + + err := c.handleGetSessions(mockStream, evt) + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + mockCfg.AssertExpectations(t) + }) +} + +func TestHandleTerminateSession(t *testing.T) { + mockReg := &mockRegistry{} + mockStream := &mockSubscribeClient{} + c := &client{sessionRegistry: mockReg} + + evt := &proto.Events{ + Payload: &proto.Events_TerminateSessionEvent{ + TerminateSessionEvent: &proto.TerminateSessionEvent{ + User: "mas-fuad", + Slug: "myslug", + TunnelType: proto.TunnelType_HTTP, + }, + }, + } + + t.Run("Success", func(t *testing.T) { + mockSess := &mockSession{} + mockLife := &mockLifecycle{} + mockSess.On("Lifecycle").Return(mockLife).Once() + mockLife.On("Close").Return(nil).Once() + + mockReg.On("GetWithUser", "mas-fuad", types.SessionKey{Id: "myslug", Type: types.TunnelTypeHTTP}).Return(mockSess, nil).Once() + + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + return n.GetTerminateSessionEventResponse().Success + })).Return(nil).Once() + + err := c.handleTerminateSession(mockStream, evt) + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + mockLife.AssertExpectations(t) + }) + + t.Run("TunnelTypeUnknown", func(t *testing.T) { + badEvt := &proto.Events{ + Payload: &proto.Events_TerminateSessionEvent{ + TerminateSessionEvent: &proto.TerminateSessionEvent{ + TunnelType: proto.TunnelType(999), + }, + }, + } + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + resp := n.GetTerminateSessionEventResponse() + return !resp.Success && resp.Message != "" + })).Return(nil).Once() + + err := c.handleTerminateSession(mockStream, badEvt) + assert.NoError(t, err) + mockStream.AssertExpectations(t) + }) + + t.Run("SessionNotFound", func(t *testing.T) { + mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(nil, errors.New("not found")).Once() + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + resp := n.GetTerminateSessionEventResponse() + return !resp.Success && resp.Message == "not found" + })).Return(nil).Once() + + err := c.handleTerminateSession(mockStream, evt) + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + }) + + t.Run("CloseError", func(t *testing.T) { + mockSess := &mockSession{} + mockLife := &mockLifecycle{} + mockSess.On("Lifecycle").Return(mockLife).Once() + mockLife.On("Close").Return(errors.New("close fail")).Once() + mockReg.On("GetWithUser", mock.Anything, mock.Anything).Return(mockSess, nil).Once() + + mockStream.On("Send", mock.MatchedBy(func(n *proto.Node) bool { + resp := n.GetTerminateSessionEventResponse() + return !resp.Success && resp.Message == "close fail" + })).Return(nil).Once() + + err := c.handleTerminateSession(mockStream, evt) + assert.NoError(t, err) + mockReg.AssertExpectations(t) + mockStream.AssertExpectations(t) + mockLife.AssertExpectations(t) + }) +} + +func TestSubscribeAndProcess(t *testing.T) { + mockEventSvc := &mockEventServiceClient{} + c := &client{eventService: mockEventSvc} + ctx := context.Background() + backoff := time.Second + + t.Run("SubscribeError", func(t *testing.T) { + expectedErr := status.Error(codes.Unauthenticated, "unauth") + mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, expectedErr).Once() + err := c.subscribeAndProcess(ctx, "id", "token", &backoff) + assert.ErrorIs(t, err, expectedErr) + }) + + t.Run("AuthSendError", func(t *testing.T) { + mockStream := &mockSubscribeClient{} + mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(mockStream, nil).Once() + expectedErr := status.Error(codes.Internal, "send fail") + mockStream.On("Send", mock.Anything).Return(expectedErr).Once() + err := c.subscribeAndProcess(ctx, "id", "token", &backoff) + assert.ErrorIs(t, err, expectedErr) + }) + + t.Run("StreamError", func(t *testing.T) { + mockStream := &mockSubscribeClient{} + mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(mockStream, nil).Once() + mockStream.On("Send", mock.Anything).Return(nil).Once() + expectedErr := status.Error(codes.Internal, "stream fail") + mockStream.On("Recv").Return(nil, expectedErr).Once() + err := c.subscribeAndProcess(ctx, "id", "token", &backoff) + assert.ErrorIs(t, err, expectedErr) + }) +} + +func TestSubscribeEvents(t *testing.T) { + mockEventSvc := &mockEventServiceClient{} + c := &client{eventService: mockEventSvc} + + t.Run("ReturnsOnError", func(t *testing.T) { + expectedErr := errors.New("fatal error") + mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, expectedErr).Once() + err := c.SubscribeEvents(context.Background(), "id", "token") + assert.ErrorIs(t, err, expectedErr) + }) + + t.Run("RetryLoop", func(t *testing.T) { + oldB := initialBackoff + initialBackoff = 5 * time.Millisecond + defer func() { initialBackoff = oldB }() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + mockEventSvc.On("Subscribe", mock.Anything, mock.Anything).Return(nil, status.Error(codes.Unavailable, "unavailable")) + + err := c.SubscribeEvents(ctx, "id", "token") + assert.True(t, errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled)) + mockEventSvc.AssertExpectations(t) + }) +} + +func TestCheckServerHealth(t *testing.T) { + mockHealth := &mockHealthClient{} + old := healthNewHealthClient + healthNewHealthClient = func(cc grpc.ClientConnInterface) grpc_health_v1.HealthClient { + return mockHealth + } + defer func() { healthNewHealthClient = old }() + + c := &client{} + + t.Run("Success", func(t *testing.T) { + mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(&grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_SERVING}, nil).Once() + err := c.CheckServerHealth(context.Background()) + assert.NoError(t, err) + mockHealth.AssertExpectations(t) + }) + + t.Run("Error", func(t *testing.T) { + mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("health fail")).Once() + err := c.CheckServerHealth(context.Background()) + assert.ErrorContains(t, err, "health check failed: health fail") + mockHealth.AssertExpectations(t) + }) + + t.Run("NotServing", func(t *testing.T) { + mockHealth.On("Check", mock.Anything, mock.Anything, mock.Anything).Return(&grpc_health_v1.HealthCheckResponse{Status: grpc_health_v1.HealthCheckResponse_NOT_SERVING}, nil).Once() + err := c.CheckServerHealth(context.Background()) + assert.ErrorContains(t, err, "server not serving: NOT_SERVING") + mockHealth.AssertExpectations(t) + }) +} + +func TestNew_Error(t *testing.T) { + old := grpcNewClient + grpcNewClient = func(target string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + return nil, errors.New("dial fail") + } + defer func() { grpcNewClient = old }() + mockConfig := &MockConfig{} + + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("1234") + cli, err := New(mockConfig, &mockRegistry{}) + if err == nil || err.Error() != "failed to connect to gRPC server at localhost:1234: dial fail" { + t.Errorf("expected dial fail error, got %v", err) + } + if cli != nil { + t.Errorf("expected nil client") + } +} + +func TestNew(t *testing.T) { + mockConfig := &MockConfig{} + mockReg := &mockRegistry{} + mockConfig.On("GRPCAddress").Return("localhost") + mockConfig.On("GRPCPort").Return("1234") + cli, err := New(mockConfig, mockReg) + if err != nil { + t.Errorf("New() error = %v", err) + } + if cli == nil { + t.Fatal("New() returned nil client") + } + defer func(cli Client) { + _ = cli.Close() + }(cli) +} + +type MockConfig struct { + mock.Mock +} + +func (m *MockConfig) Domain() string { return m.Called().String(0) } +func (m *MockConfig) SSHPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) } +func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) } +func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) } +func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) } +func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } +func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) } +func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) PprofPort() string { return m.Called().String(0) } +func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } +func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *MockConfig) NodeToken() string { return m.Called().String(0) } +func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } + +type mockRegistry struct { + mock.Mock +} + +func (m *mockRegistry) Get(key registry.Key) (registry.Session, error) { + args := m.Called(key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} +func (m *mockRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) { + args := m.Called(user, key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} +func (m *mockRegistry) Update(user string, oldKey, newKey registry.Key) error { + return m.Called(user, oldKey, newKey).Error(0) +} +func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session { + args := m.Called(user) + if args.Get(0) == nil { + return nil + } + return args.Get(0).([]registry.Session) +} +func (m *mockRegistry) Register(key registry.Key, session registry.Session) bool { + return m.Called(key, session).Bool(0) +} +func (m *mockRegistry) Remove(key registry.Key) { + m.Called(key) +} + +type mockSession struct { + mock.Mock +} + +func (m *mockSession) Lifecycle() lifecycle.Lifecycle { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(lifecycle.Lifecycle) +} +func (m *mockSession) Interaction() interaction.Interaction { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(interaction.Interaction) +} +func (m *mockSession) Detail() *types.Detail { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(*types.Detail) +} +func (m *mockSession) Slug() slug.Slug { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(slug.Slug) +} +func (m *mockSession) Forwarder() forwarder.Forwarder { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(forwarder.Forwarder) +} + +type mockInteraction struct { + mock.Mock +} + +func (m *mockInteraction) Start() { m.Called() } +func (m *mockInteraction) Stop() { m.Called() } +func (m *mockInteraction) Redraw() { m.Called() } +func (m *mockInteraction) SetWH(w, h int) { m.Called(w, h) } +func (m *mockInteraction) SetChannel(channel ssh.Channel) { m.Called(channel) } +func (m *mockInteraction) SetMode(mode types.InteractiveMode) { m.Called(mode) } +func (m *mockInteraction) Mode() types.InteractiveMode { + return m.Called().Get(0).(types.InteractiveMode) +} +func (m *mockInteraction) Send(message string) error { return m.Called(message).Error(0) } + +type mockLifecycle struct { + mock.Mock +} + +func (m *mockLifecycle) Close() error { return m.Called().Error(0) } +func (m *mockLifecycle) Channel() ssh.Channel { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(ssh.Channel) +} +func (m *mockLifecycle) Connection() ssh.Conn { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(ssh.Conn) +} +func (m *mockLifecycle) User() string { return m.Called().String(0) } +func (m *mockLifecycle) SetChannel(channel ssh.Channel) { m.Called(channel) } +func (m *mockLifecycle) SetStatus(status types.SessionStatus) { m.Called(status) } +func (m *mockLifecycle) IsActive() bool { return m.Called().Bool(0) } +func (m *mockLifecycle) StartedAt() time.Time { return m.Called().Get(0).(time.Time) } +func (m *mockLifecycle) PortRegistry() port.Port { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(port.Port) +} + +type mockEventServiceClient struct { + mock.Mock +} + +func (m *mockEventServiceClient) Subscribe(ctx context.Context, opts ...grpc.CallOption) (proto.EventService_SubscribeClient, error) { + args := m.Called(ctx, opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(proto.EventService_SubscribeClient), args.Error(1) +} + +type mockSubscribeClient struct { + mock.Mock + grpc.ClientStream +} + +func (m *mockSubscribeClient) Send(n *proto.Node) error { return m.Called(n).Error(0) } +func (m *mockSubscribeClient) Recv() (*proto.Events, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*proto.Events), args.Error(1) +} +func (m *mockSubscribeClient) Context() context.Context { return m.Called().Get(0).(context.Context) } + +type mockUserServiceClient struct { + mock.Mock +} + +func (m *mockUserServiceClient) Check(ctx context.Context, in *proto.CheckRequest, opts ...grpc.CallOption) (*proto.CheckResponse, error) { + args := m.Called(ctx, in, opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*proto.CheckResponse), args.Error(1) +} + +type mockHealthClient struct { + mock.Mock +} + +func (m *mockHealthClient) Check(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthCheckResponse, error) { + args := m.Called(ctx, in, opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*grpc_health_v1.HealthCheckResponse), args.Error(1) +} + +func (m *mockHealthClient) Watch(ctx context.Context, in *grpc_health_v1.HealthCheckRequest, opts ...grpc.CallOption) (grpc_health_v1.Health_WatchClient, error) { + args := m.Called(ctx, in, opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(grpc_health_v1.Health_WatchClient), args.Error(1) +} + +func (m *mockHealthClient) List(ctx context.Context, in *grpc_health_v1.HealthListRequest, opts ...grpc.CallOption) (*grpc_health_v1.HealthListResponse, error) { + args := m.Called(ctx, in, opts) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(*grpc_health_v1.HealthListResponse), args.Error(1) +} + +func TestProtoToTunnelType(t *testing.T) { + c := &client{} + tests := []struct { + name string + input proto.TunnelType + want types.TunnelType + wantErr bool + }{ + {"HTTP", proto.TunnelType_HTTP, types.TunnelTypeHTTP, false}, + {"TCP", proto.TunnelType_TCP, types.TunnelTypeTCP, false}, + {"Unknown", proto.TunnelType(999), types.TunnelTypeUNKNOWN, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := c.protoToTunnelType(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("protoToTunnelType() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("protoToTunnelType() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsConnectionError(t *testing.T) { + c := &client{} + tests := []struct { + name string + closing bool + err error + want bool + }{ + {"NilError", false, nil, false}, + {"Closing", true, io.EOF, false}, + {"EOF", false, io.EOF, true}, + {"Unavailable", false, status.Error(codes.Unavailable, "unavailable"), true}, + {"Canceled", false, status.Error(codes.Canceled, "canceled"), true}, + {"DeadlineExceeded", false, status.Error(codes.DeadlineExceeded, "deadline"), true}, + {"Internal", false, status.Error(codes.Internal, "internal"), false}, + {"WrappedEOF", false, errors.New("wrapped: " + io.EOF.Error()), false}, + } + + tests[7].err = fmt.Errorf("wrapped: %w", io.EOF) + tests[7].want = true + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c.closing = tt.closing + if got := c.isConnectionError(tt.err); got != tt.want { + t.Errorf("isConnectionError() = %v, want %v for error %v", got, tt.want, tt.err) + } + }) + } +} + +func TestGrowBackoff(t *testing.T) { + c := &client{} + tests := []struct { + name string + initial time.Duration + want time.Duration + }{ + {"NormalGrow", time.Second, 2 * time.Second}, + {"MaxLimit", 20 * time.Second, 30 * time.Second}, + {"AlreadyAtMax", 30 * time.Second, 30 * time.Second}, + {"OverMax", 40 * time.Second, 30 * time.Second}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + backoff := tt.initial + c.growBackoff(&backoff) + if backoff != tt.want { + t.Errorf("growBackoff() = %v, want %v", backoff, tt.want) + } + }) + } +} + +func TestWait(t *testing.T) { + c := &client{} + + t.Run("ZeroDuration", func(t *testing.T) { + err := c.wait(context.Background(), 0) + if err != nil { + t.Errorf("wait() zero duration error = %v", err) + } + }) + + t.Run("ContextCanceled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := c.wait(ctx, time.Minute) + if !errors.Is(err, context.Canceled) { + t.Errorf("wait() context canceled error = %v", err) + } + }) + + t.Run("Timeout", func(t *testing.T) { + start := time.Now() + err := c.wait(context.Background(), 10*time.Millisecond) + if err != nil { + t.Errorf("wait() timeout error = %v", err) + } + if time.Since(start) < 10*time.Millisecond { + t.Errorf("wait() returned too early") + } + }) +} diff --git a/internal/http/header/header_test.go b/internal/http/header/header_test.go new file mode 100644 index 0000000..b3f9228 --- /dev/null +++ b/internal/http/header/header_test.go @@ -0,0 +1,227 @@ +package header + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewRequest(t *testing.T) { + tests := []struct { + name string + data []byte + expectErr bool + errContains string + expectMethod string + expectPath string + expectVersion string + expectHeaders map[string]string + }{ + { + name: "success", + data: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\nX-Custom: value\r\n\r\n"), + expectErr: false, + expectMethod: "GET", + expectPath: "/path", + expectVersion: "HTTP/1.1", + expectHeaders: map[string]string{ + "Host": "example.com", + "X-Custom": "value", + }, + }, + { + name: "no CRLF in start line", + data: []byte("GET /path HTTP/1.1"), + expectErr: true, + errContains: "no CRLF found in start line", + }, + { + name: "invalid start line - missing method", + data: []byte("INVALID\r\n\r\n"), + expectErr: true, + errContains: "invalid start line: missing method", + }, + { + name: "invalid start line - missing version", + data: []byte("GET /path\r\n\r\n"), + expectErr: true, + errContains: "invalid start line: missing version", + }, + { + name: "invalid start line - multiple spaces", + data: []byte("GET /path HTTP/1.1\r\n\r\n"), + expectErr: false, + expectMethod: "GET", + expectPath: "", + expectVersion: "/path HTTP/1.1", + expectHeaders: map[string]string{}, + }, + { + name: "start line with trailing space", + data: []byte("GET / HTTP/1.1 \r\n\r\n"), + expectErr: false, + expectMethod: "GET", + expectPath: "/", + expectVersion: "HTTP/1.1 ", + expectHeaders: map[string]string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := NewRequest(tt.data) + if tt.expectErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + assert.Nil(t, req) + } else { + assert.NoError(t, err) + assert.NotNil(t, req) + assert.Equal(t, tt.expectMethod, req.Method()) + assert.Equal(t, tt.expectPath, req.Path()) + assert.Equal(t, tt.expectVersion, req.Version()) + for k, v := range tt.expectHeaders { + assert.Equal(t, v, req.Value(k)) + } + } + }) + } +} + +func TestRequestHeaderMethods(t *testing.T) { + data := []byte("GET / HTTP/1.1\r\nHost: original\r\n\r\n") + req, _ := NewRequest(data) + + req.Set("Host", "updated") + req.Set("X-New", "new-value") + assert.Equal(t, "updated", req.Value("Host")) + assert.Equal(t, "new-value", req.Value("X-New")) + + assert.Equal(t, "", req.Value("Non-Existent")) + + req.Remove("X-New") + assert.Equal(t, "", req.Value("X-New")) + + final := req.Finalize() + assert.Contains(t, string(final), "GET / HTTP/1.1\r\n") + assert.Contains(t, string(final), "Host: updated\r\n") + assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n"))) +} + +func TestNewResponse(t *testing.T) { + tests := []struct { + name string + data []byte + expectErr bool + errContains string + expectHeaders map[string]string + }{ + { + name: "success", + data: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"), + expectErr: false, + expectHeaders: map[string]string{ + "Content-Length": "0", + }, + }, + { + name: "invalid response - no CRLF", + data: []byte("HTTP/1.1 200 OK"), + expectErr: true, + errContains: "no CRLF found in start line", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := NewResponse(tt.data) + if tt.expectErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + assert.Nil(t, resp) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + for k, v := range tt.expectHeaders { + assert.Equal(t, v, resp.Value(k)) + } + } + }) + } +} + +func TestResponseHeaderMethods(t *testing.T) { + data := []byte("HTTP/1.1 200 OK\r\nServer: old\r\n\r\n") + resp, _ := NewResponse(data) + + resp.Set("Server", "new") + resp.Set("X-Res", "val") + assert.Equal(t, "new", resp.Value("Server")) + assert.Equal(t, "val", resp.Value("X-Res")) + + resp.Remove("X-Res") + assert.Equal(t, "", resp.Value("X-Res")) + + final := resp.Finalize() + assert.Contains(t, string(final), "HTTP/1.1 200 OK\r\n") + assert.Contains(t, string(final), "Server: new\r\n") + assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n"))) +} + +func TestSetRemainingHeaders(t *testing.T) { + tests := []struct { + name string + data []byte + initialHeaders map[string]string + expectHeaders map[string]string + }{ + { + name: "various header formats", + data: []byte("K1: V1\r\nK2:V2\r\n K3 : V3 \r\nNoColon\r\n\r\n"), + expectHeaders: map[string]string{ + "K1": "V1", + "K2": "V2", + "K3": "V3", + }, + }, + { + name: "no trailing CRLF", + data: []byte("K1: V1"), + expectHeaders: map[string]string{ + "K1": "V1", + }, + }, + { + name: "empty lines", + data: []byte("\r\nK1: V1"), + expectHeaders: map[string]string{}, + }, + { + name: "headers with only colon", + data: []byte(": value\r\nkey:\r\n"), + expectHeaders: map[string]string{ + "": "value", + "key": "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &requestHeader{headers: make(map[string]string)} + if tt.initialHeaders != nil { + req.headers = tt.initialHeaders + } + setRemainingHeaders(tt.data, req) + assert.Equal(t, len(tt.expectHeaders), len(req.headers)) + for k, v := range tt.expectHeaders { + assert.Equal(t, v, req.headers[k]) + } + }) + } +} diff --git a/internal/http/header/parser.go b/internal/http/header/parser.go index 861c49e..9a58d59 100644 --- a/internal/http/header/parser.go +++ b/internal/http/header/parser.go @@ -1,7 +1,6 @@ package header import ( - "bufio" "bytes" "fmt" ) @@ -36,31 +35,6 @@ func setRemainingHeaders(remaining []byte, header interface { } } -func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) { - header := &requestHeader{ - headers: make(map[string]string, 16), - } - - lineEnd := bytes.Index(headerData, []byte("\r\n")) - if lineEnd == -1 { - return nil, fmt.Errorf("invalid request: no CRLF found in start line") - } - - startLine := headerData[:lineEnd] - header.startLine = startLine - var err error - header.method, header.path, header.version, err = parseStartLine(startLine) - if err != nil { - return nil, err - } - - remaining := headerData[lineEnd+2:] - - setRemainingHeaders(remaining, header) - - return header, nil -} - func parseStartLine(startLine []byte) (method, path, version string, err error) { firstSpace := bytes.IndexByte(startLine, ' ') if firstSpace == -1 { @@ -80,51 +54,6 @@ func parseStartLine(startLine []byte) (method, path, version string, err error) return method, path, version, nil } -func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) { - header := &requestHeader{ - headers: make(map[string]string, 16), - } - - startLineBytes, err := br.ReadSlice('\n') - if err != nil { - return nil, err - } - - startLineBytes = bytes.TrimRight(startLineBytes, "\r\n") - header.startLine = make([]byte, len(startLineBytes)) - copy(header.startLine, startLineBytes) - - header.method, header.path, header.version, err = parseStartLine(header.startLine) - if err != nil { - return nil, err - } - - for { - lineBytes, err := br.ReadSlice('\n') - if err != nil { - return nil, err - } - - lineBytes = bytes.TrimRight(lineBytes, "\r\n") - - if len(lineBytes) == 0 { - break - } - - colonIdx := bytes.IndexByte(lineBytes, ':') - if colonIdx == -1 { - continue - } - - key := bytes.TrimSpace(lineBytes[:colonIdx]) - value := bytes.TrimSpace(lineBytes[colonIdx+1:]) - - header.headers[string(key)] = string(value) - } - - return header, nil -} - func finalize(startLine []byte, headers map[string]string) []byte { size := len(startLine) + 2 for key, val := range headers { diff --git a/internal/http/header/request.go b/internal/http/header/request.go index 1fbe57a..e35f169 100644 --- a/internal/http/header/request.go +++ b/internal/http/header/request.go @@ -1,19 +1,33 @@ package header import ( - "bufio" + "bytes" "fmt" ) -func NewRequest(r interface{}) (RequestHeader, error) { - switch v := r.(type) { - case []byte: - return parseHeadersFromBytes(v) - case *bufio.Reader: - return parseHeadersFromReader(v) - default: - return nil, fmt.Errorf("unsupported type: %T", r) +func NewRequest(headerData []byte) (RequestHeader, error) { + header := &requestHeader{ + headers: make(map[string]string, 16), } + + lineEnd := bytes.Index(headerData, []byte("\r\n")) + if lineEnd == -1 { + return nil, fmt.Errorf("invalid request: no CRLF found in start line") + } + + startLine := headerData[:lineEnd] + header.startLine = startLine + var err error + header.method, header.path, header.version, err = parseStartLine(startLine) + if err != nil { + return nil, err + } + + remaining := headerData[lineEnd+2:] + + setRemainingHeaders(remaining, header) + + return header, nil } func (req *requestHeader) Value(key string) string { diff --git a/internal/http/stream/stream.go b/internal/http/stream/stream.go index 97d2752..d339474 100644 --- a/internal/http/stream/stream.go +++ b/internal/http/stream/stream.go @@ -30,7 +30,6 @@ type http struct { remoteAddr net.Addr writer io.Writer reader io.Reader - headerBuf []byte buf []byte respHeader header.ResponseHeader reqHeader header.RequestHeader @@ -72,7 +71,10 @@ func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware { } func (hs *http) Close() error { - return hs.writer.(io.Closer).Close() + if closer, ok := hs.writer.(io.Closer); ok { + return closer.Close() + } + return nil } func (hs *http) CloseWrite() error { diff --git a/internal/http/stream/stream_test.go b/internal/http/stream/stream_test.go new file mode 100644 index 0000000..60c9d65 --- /dev/null +++ b/internal/http/stream/stream_test.go @@ -0,0 +1,765 @@ +package stream + +import ( + "bytes" + "io" + "strings" + "testing" + + "tunnel_pls/internal/http/header" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockAddr struct { + mock.Mock +} + +func (m *MockAddr) String() string { + args := m.Called() + return args.String(0) +} + +func (m *MockAddr) Network() string { + args := m.Called() + return args.String(0) +} + +type MockRequestMiddleware struct { + mock.Mock +} + +func (m *MockRequestMiddleware) HandleRequest(h header.RequestHeader) error { + args := m.Called(h) + return args.Error(0) +} + +type MockResponseMiddleware struct { + mock.Mock +} + +func (m *MockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error { + args := m.Called(h, body) + return args.Error(0) +} + +type MockReadWriter struct { + mock.Mock + bytes.Buffer +} + +func (m *MockReadWriter) Read(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *MockReadWriter) Write(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *MockReadWriter) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockReadWriter) CloseWrite() error { + args := m.Called() + return args.Error(0) +} + +type MockReadWriterOnlyCloser struct { + mock.Mock + bytes.Buffer +} + +func (m *MockReadWriterOnlyCloser) Read(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *MockReadWriterOnlyCloser) Write(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *MockReadWriterOnlyCloser) Close() error { + args := m.Called() + return args.Error(0) +} + +type MockWriterOnly struct { + mock.Mock +} + +func (m *MockWriterOnly) Write(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *MockWriterOnly) Read(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +type MockReader struct { + mock.Mock +} + +func (m *MockReader) Read(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +type MockWriter struct { + mock.Mock +} + +func (m *MockWriter) Write(p []byte) (int, error) { + ret := m.Called(p) + + var n int + var err error + + switch v := ret.Get(0).(type) { + case func([]byte) int: + n = v(p) + case int: + n = v + default: + n = len(p) + } + + switch v := ret.Get(1).(type) { + case func([]byte) error: + err = v(p) + case error: + err = v + default: + err = nil + } + + return n, err +} + +func (m *MockWriter) Close() error { + args := m.Called() + return args.Error(0) +} + +func TestHTTPMethods(t *testing.T) { + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + rw := new(MockReadWriter) + hs := New(rw, rw, addr) + + assert.Equal(t, addr, hs.RemoteAddr()) + + reqMW := new(MockRequestMiddleware) + hs.UseRequestMiddleware(reqMW) + assert.Equal(t, 1, len(hs.RequestMiddlewares())) + assert.Equal(t, reqMW, hs.RequestMiddlewares()[0]) + + respMW := new(MockResponseMiddleware) + hs.UseResponseMiddleware(respMW) + assert.Equal(t, 1, len(hs.ResponseMiddlewares())) + assert.Equal(t, respMW, hs.ResponseMiddlewares()[0]) + + reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n")) + hs.SetRequestHeader(reqH) +} + +func TestApplyMiddlewares(t *testing.T) { + tests := []struct { + name string + setup func(HTTP, *MockRequestMiddleware, *MockResponseMiddleware) + apply func(HTTP, header.RequestHeader, header.ResponseHeader) error + verify func(*testing.T, header.RequestHeader, header.ResponseHeader) + expectErr bool + }{ + { + name: "apply request middleware success", + setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) { + reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) { + h := args.Get(0).(header.RequestHeader) + h.Set("X-Middleware", "true") + }).Return(nil) + hs.UseRequestMiddleware(reqMW) + }, + apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { + return hs.ApplyRequestMiddlewares(reqH) + }, + verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) { + assert.Equal(t, "true", reqH.Value("X-Middleware")) + }, + }, + { + name: "apply response middleware success", + setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) { + respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + h := args.Get(0).(header.ResponseHeader) + h.Set("X-Resp-Middleware", "true") + }).Return(nil) + hs.UseResponseMiddleware(respMW) + }, + apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { + return hs.ApplyResponseMiddlewares(respH, []byte("body")) + }, + verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) { + assert.Equal(t, "true", respH.Value("X-Resp-Middleware")) + }, + }, + { + name: "apply request middleware error", + setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) { + reqMW.On("HandleRequest", mock.Anything).Return(assert.AnError) + hs.UseRequestMiddleware(reqMW) + }, + apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { + return hs.ApplyRequestMiddlewares(reqH) + }, + expectErr: true, + }, + { + name: "apply response middleware error", + setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) { + respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(assert.AnError) + hs.UseResponseMiddleware(respMW) + }, + apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { + return hs.ApplyResponseMiddlewares(respH, []byte("body")) + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n")) + respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n")) + + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + rw := new(MockReadWriter) + hs := New(rw, rw, addr) + + reqMW := new(MockRequestMiddleware) + respMW := new(MockResponseMiddleware) + tt.setup(hs, reqMW, respMW) + + err := tt.apply(hs, reqH, respH) + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.verify != nil { + tt.verify(t, reqH, respH) + } + } + + reqMW.AssertExpectations(t) + respMW.AssertExpectations(t) + }) + } +} + +func TestCloseMethods(t *testing.T) { + tests := []struct { + name string + setup func() (io.Writer, io.Reader) + op func(HTTP) error + verify func(*testing.T, io.Writer) + }{ + { + name: "Close success", + setup: func() (io.Writer, io.Reader) { + rw := new(MockReadWriter) + rw.On("Close").Return(nil) + return rw, rw + }, + op: func(hs HTTP) error { return hs.Close() }, + verify: func(t *testing.T, w io.Writer) { + w.(*MockReadWriter).AssertCalled(t, "Close") + }, + }, + { + name: "CloseWrite with CloseWrite implementation", + setup: func() (io.Writer, io.Reader) { + rw := new(MockReadWriter) + rw.On("CloseWrite").Return(nil) + return rw, rw + }, + op: func(hs HTTP) error { return hs.CloseWrite() }, + verify: func(t *testing.T, w io.Writer) { + w.(*MockReadWriter).AssertCalled(t, "CloseWrite") + }, + }, + { + name: "CloseWrite fallback to Close", + setup: func() (io.Writer, io.Reader) { + rw := new(MockReadWriterOnlyCloser) + rw.On("Close").Return(nil) + return rw, rw + }, + op: func(hs HTTP) error { return hs.CloseWrite() }, + verify: func(t *testing.T, w io.Writer) { + w.(*MockReadWriterOnlyCloser).AssertCalled(t, "Close") + }, + }, + { + name: "Close with No Closer", + setup: func() (io.Writer, io.Reader) { + w := new(MockWriterOnly) + r := new(MockReader) + return w, r + }, + op: func(hs HTTP) error { return hs.Close() }, + }, + { + name: "CloseWrite with No CloseWrite and No Closer", + setup: func() (io.Writer, io.Reader) { + w := new(MockWriterOnly) + r := new(MockReader) + return w, r + }, + op: func(hs HTTP) error { return hs.CloseWrite() }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + w, r := tt.setup() + hs := New(w, r, addr) + + assert.NotPanics(t, func() { + err := tt.op(hs) + assert.NoError(t, err) + }) + + if tt.verify != nil { + tt.verify(t, w) + } + }) + } +} + +func TestSplitHeaderAndBody(t *testing.T) { + tests := []struct { + name string + data []byte + delimiterIdx int + expectHeader []byte + expectBody []byte + }{ + { + name: "standard", + data: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nBodyContent"), + delimiterIdx: 31, + expectHeader: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"), + expectBody: []byte("BodyContent"), + }, + { + name: "empty body", + data: []byte("HTTP/1.1 200 OK\r\n\r\n"), + delimiterIdx: 15, + expectHeader: []byte("HTTP/1.1 200 OK\r\n\r\n"), + expectBody: []byte(""), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, b := splitHeaderAndBody(tt.data, tt.delimiterIdx) + assert.Equal(t, tt.expectHeader, h) + assert.Equal(t, tt.expectBody, b) + }) + } +} + +func TestIsHTTPHeader(t *testing.T) { + tests := []struct { + name string + buf []byte + expect bool + }{ + { + name: "valid request", + buf: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n"), + expect: true, + }, + { + name: "valid response", + buf: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"), + expect: true, + }, + { + name: "invalid start line", + buf: []byte("NOT_HTTP /path\r\nHost: example.com\r\n\r\n"), + expect: false, + }, + { + name: "invalid header line (no colon)", + buf: []byte("GET / HTTP/1.1\r\nInvalidHeaderLine\r\n\r\n"), + expect: false, + }, + { + name: "invalid header line (colon at 0)", + buf: []byte("GET / HTTP/1.1\r\n: value\r\n\r\n"), + expect: false, + }, + { + name: "empty header section", + buf: []byte("GET / HTTP/1.1\r\n\r\n"), + expect: true, + }, + { + name: "multiple headers", + buf: []byte("GET / HTTP/1.1\r\nH1: V1\r\nH2: V2\r\n\r\n"), + expect: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isHTTPHeader(tt.buf) + assert.Equal(t, tt.expect, result) + }) + } +} + +func TestRead(t *testing.T) { + tests := []struct { + name string + input []byte + readLen int + expectContent string + expectRead int + expectErr bool + middlewareErr error + isHTTP bool + }{ + { + name: "valid http request", + input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\nBody"), + readLen: 100, + expectContent: "Body", + expectRead: 54, + isHTTP: true, + }, + { + name: "non-http data", + input: []byte("Some random data\r\n\r\nMore data"), + readLen: 100, + expectContent: "Some random data\r\n\r\nMore data", + expectRead: 29, + isHTTP: false, + }, + { + name: "no delimiter", + input: []byte("Partial data without delimiter"), + readLen: 100, + expectContent: "Partial data without delimiter", + expectRead: 30, + isHTTP: false, + }, + { + name: "middleware error", + input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\n"), + readLen: 100, + middlewareErr: assert.AnError, + expectErr: true, + isHTTP: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + reader := new(MockReader) + writer := new(MockWriterOnly) + + if tt.expectErr || tt.name == "valid http request" { + reader.On("Read", mock.Anything).Run(func(args mock.Arguments) { + p := args.Get(0).([]byte) + copy(p, tt.input) + }).Return(len(tt.input), io.EOF).Once() + } else { + reader.On("Read", mock.Anything).Run(func(args mock.Arguments) { + p := args.Get(0).([]byte) + copy(p, tt.input) + }).Return(len(tt.input), nil).Once() + } + + hs := New(writer, reader, addr) + + reqMW := new(MockRequestMiddleware) + if tt.isHTTP { + if tt.middlewareErr != nil { + reqMW.On("HandleRequest", mock.Anything).Return(tt.middlewareErr) + } else { + reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) { + h := args.Get(0).(header.RequestHeader) + h.Set("X-Middleware", "true") + }).Return(nil) + } + } + hs.UseRequestMiddleware(reqMW) + + p := make([]byte, tt.readLen) + n, err := hs.Read(p) + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectRead, n) + if tt.name == "valid http request" { + content := string(p[:n]) + assert.Contains(t, content, "GET / HTTP/1.1\r\n") + assert.Contains(t, content, "Host: test\r\n") + assert.Contains(t, content, "X-Middleware: true\r\n") + assert.True(t, bytes.HasSuffix(p[:n], []byte("\r\n\r\nBody"))) + } else { + assert.Equal(t, tt.expectContent, string(p[:n])) + } + } + + if tt.isHTTP { + reqMW.AssertExpectations(t) + } + reader.AssertExpectations(t) + }) + } +} + +func TestWrite(t *testing.T) { + tests := []struct { + name string + writes [][]byte + expectWritten string + expectErr bool + middlewareErr error + isHTTP bool + }{ + { + name: "valid http response in one write", + writes: [][]byte{ + []byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"), + }, + expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody", + isHTTP: true, + }, + { + name: "valid http response in multiple writes", + writes: [][]byte{ + []byte("HTTP/1.1 200 OK\r\n"), + []byte("Content-Length: 4\r\n\r\n"), + []byte("Body"), + }, + expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody", + isHTTP: true, + }, + { + name: "non-http data", + writes: [][]byte{ + []byte("Random data with delimiter\r\n\r\nFlush"), + }, + expectWritten: "Random data with delimiter\r\n\r\nFlush", + isHTTP: false, + }, + { + name: "bypass buffering", + writes: [][]byte{ + []byte("HTTP/1.1 200 OK\r\n\r\n"), + []byte("HTTP/1.1 200 OK\r\n\r\n"), + }, + expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" + + "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n", + isHTTP: true, + }, + { + name: "middleware error", + writes: [][]byte{ + []byte("HTTP/1.1 200 OK\r\n\r\n"), + }, + middlewareErr: assert.AnError, + expectErr: true, + isHTTP: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + var writtenData bytes.Buffer + writer := new(MockWriter) + + writer.On("Write", mock.Anything).Run(func(args mock.Arguments) { + p := args.Get(0).([]byte) + writtenData.Write(p) + }).Return(func(p []byte) int { + return len(p) + }, nil) + + reader := new(MockReader) + hs := New(writer, reader, addr) + + respMW := new(MockResponseMiddleware) + if tt.isHTTP { + if tt.middlewareErr != nil { + respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(tt.middlewareErr) + } else { + respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + h := args.Get(0).(header.ResponseHeader) + h.Set("X-Resp-Middleware", "true") + }).Return(nil) + } + } + hs.UseResponseMiddleware(respMW) + + var totalN int + var err error + for _, w := range tt.writes { + var n int + n, err = hs.Write(w) + if err != nil { + break + } + totalN += n + } + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + written := writtenData.String() + if strings.HasPrefix(tt.expectWritten, "HTTP/") { + assert.Contains(t, written, "HTTP/1.1 200 OK\r\n") + assert.Contains(t, written, "X-Resp-Middleware: true\r\n") + if strings.Contains(tt.expectWritten, "Content-Length: 4") { + assert.Contains(t, written, "Content-Length: 4\r\n") + } + assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n")) + } else { + assert.Equal(t, tt.expectWritten, written) + } + } + + if tt.isHTTP { + respMW.AssertExpectations(t) + } + if tt.middlewareErr == nil { + writer.AssertExpectations(t) + } + }) + } +} + +func TestWriteErrors(t *testing.T) { + tests := []struct { + name string + setup func() (io.Writer, io.Reader) + data []byte + }{ + { + name: "write error in writeHeaderAndBody", + setup: func() (io.Writer, io.Reader) { + writer := new(MockWriter) + writer.On("Write", mock.Anything).Return(0, assert.AnError) + reader := new(MockReader) + return writer, reader + }, + data: []byte("HTTP/1.1 200 OK\r\n\r\n"), + }, + { + name: "write error in writeHeaderAndBody second write", + setup: func() (io.Writer, io.Reader) { + writer := new(MockWriter) + writer.On("Write", mock.Anything).Return(len([]byte("HTTP/1.1 200 OK\r\n\r\n")), nil).Once() + writer.On("Write", mock.Anything).Return(0, assert.AnError).Once() + reader := new(MockReader) + return writer, reader + }, + data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"), + }, + { + name: "write error in writeRawBuffer", + setup: func() (io.Writer, io.Reader) { + writer := new(MockWriter) + writer.On("Write", mock.Anything).Return(0, assert.AnError) + reader := new(MockReader) + return writer, reader + }, + data: []byte("Not HTTP\r\n\r\nFlush"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + w, r := tt.setup() + hs := New(w, r, addr) + + _, err := hs.Write(tt.data) + assert.Error(t, err) + + w.(*MockWriter).AssertExpectations(t) + }) + } +} + +func TestReadEOF(t *testing.T) { + tests := []struct { + name string + setup func() io.Reader + expectN int + expectErr error + expectContent string + }{ + { + name: "read eof", + setup: func() io.Reader { + reader := new(MockReader) + reader.On("Read", mock.Anything).Run(func(args mock.Arguments) { + p := args.Get(0).([]byte) + copy(p, "data") + }).Return(4, io.EOF) + return reader + }, + expectN: 4, + expectErr: io.EOF, + expectContent: "data", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + reader := tt.setup() + hs := New(nil, reader, addr) + + p := make([]byte, 100) + n, err := hs.Read(p) + + assert.Equal(t, tt.expectN, n) + assert.Equal(t, tt.expectErr, err) + assert.Equal(t, tt.expectContent, string(p[:n])) + + reader.(*MockReader).AssertExpectations(t) + }) + } +} diff --git a/internal/key/key.go b/internal/key/key.go index 659abe3..b1c387a 100644 --- a/internal/key/key.go +++ b/internal/key/key.go @@ -5,6 +5,8 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" + "io" "log" "os" "path/filepath" @@ -12,7 +14,20 @@ import ( "golang.org/x/crypto/ssh" ) +var ( + rsaGenerateKey = rsa.GenerateKey + pemEncode = pem.Encode + sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) { + return ssh.NewPublicKey(key) + } + pubKeyWrite = func(w io.Writer, data []byte) (int, error) { + return w.Write(data) + } + osOpenFile = os.OpenFile +) + func GenerateSSHKeyIfNotExist(keyPath string) error { + var errGroup = make([]error, 0) if _, err := os.Stat(keyPath); err == nil { log.Printf("SSH key already exists at %s", keyPath) return nil @@ -20,7 +35,7 @@ func GenerateSSHKeyIfNotExist(keyPath string) error { log.Printf("SSH key not found at %s, generating new key pair...", keyPath) - privateKey, err := rsa.GenerateKey(rand.Reader, 4096) + privateKey, err := rsaGenerateKey(rand.Reader, 4096) if err != nil { return err } @@ -35,33 +50,37 @@ func GenerateSSHKeyIfNotExist(keyPath string) error { return err } - privateKeyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + privateKeyFile, err := osOpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return err } - defer privateKeyFile.Close() + defer func(privateKeyFile *os.File) { + errGroup = append(errGroup, privateKeyFile.Close()) + }(privateKeyFile) - if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil { + if err := pemEncode(privateKeyFile, privateKeyPEM); err != nil { return err } - publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + publicKey, err := sshNewPublicKey(&privateKey.PublicKey) if err != nil { return err } pubKeyPath := keyPath + ".pub" - pubKeyFile, err := os.OpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + pubKeyFile, err := osOpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { return err } - defer pubKeyFile.Close() + defer func(pubKeyFile *os.File) { + errGroup = append(errGroup, pubKeyFile.Close()) + }(pubKeyFile) - _, err = pubKeyFile.Write(ssh.MarshalAuthorizedKey(publicKey)) + _, err = pubKeyWrite(pubKeyFile, ssh.MarshalAuthorizedKey(publicKey)) if err != nil { return err } log.Printf("SSH key pair generated successfully at %s and %s", keyPath, pubKeyPath) - return nil + return errors.Join(errGroup...) } diff --git a/internal/key/key_test.go b/internal/key/key_test.go new file mode 100644 index 0000000..d28c33b --- /dev/null +++ b/internal/key/key_test.go @@ -0,0 +1,235 @@ +package key + +import ( + "crypto/rsa" + "encoding/pem" + "errors" + "io" + "os" + "path/filepath" + "testing" + + "golang.org/x/crypto/ssh" +) + +func TestGenerateSSHKeyIfNotExist(t *testing.T) { + tempDir := t.TempDir() + + tests := []struct { + name string + setup func(t *testing.T, tempDir string) string + mockSetup func() func() + wantErr bool + errStr string + verify func(t *testing.T, keyPath string) + }{ + { + name: "GenerateNewKey", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "id_rsa") + }, + verify: func(t *testing.T, keyPath string) { + pubKeyPath := keyPath + ".pub" + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + t.Errorf("Private key file not created") + } + if _, err := os.Stat(pubKeyPath); os.IsNotExist(err) { + t.Errorf("Public key file not created") + } + privateKeyBytes, err := os.ReadFile(keyPath) + if err != nil { + t.Fatalf("Failed to read private key: %v", err) + } + if _, err = ssh.ParseRawPrivateKey(privateKeyBytes); err != nil { + t.Errorf("Failed to parse private key: %v", err) + } + publicKeyBytes, err := os.ReadFile(pubKeyPath) + if err != nil { + t.Fatalf("Failed to read public key: %v", err) + } + if _, _, _, _, err = ssh.ParseAuthorizedKey(publicKeyBytes); err != nil { + t.Errorf("Failed to parse public key: %v", err) + } + }, + }, + { + name: "DoNotOverwriteExistingKey", + setup: func(t *testing.T, tempDir string) string { + keyPath := filepath.Join(tempDir, "existing_id_rsa") + dummyPrivate := "dummy private" + dummyPublic := "dummy public" + if err := os.WriteFile(keyPath, []byte(dummyPrivate), 0600); err != nil { + t.Fatalf("Failed to create dummy private key: %v", err) + } + if err := os.WriteFile(keyPath+".pub", []byte(dummyPublic), 0644); err != nil { + t.Fatalf("Failed to create dummy public key: %v", err) + } + return keyPath + }, + verify: func(t *testing.T, keyPath string) { + gotPrivate, _ := os.ReadFile(keyPath) + if string(gotPrivate) != "dummy private" { + t.Errorf("Private key was overwritten") + } + gotPublic, _ := os.ReadFile(keyPath + ".pub") + if string(gotPublic) != "dummy public" { + t.Errorf("Public key was overwritten") + } + }, + }, + { + name: "CreateNestedDirectories", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "nested", "dir", "id_rsa") + }, + verify: func(t *testing.T, keyPath string) { + if _, err := os.Stat(keyPath); os.IsNotExist(err) { + t.Errorf("Private key file not created in nested directory") + } + }, + }, + { + name: "FailureMkdirAll", + setup: func(t *testing.T, tempDir string) string { + dirPath := filepath.Join(tempDir, "file_as_dir") + if err := os.WriteFile(dirPath, []byte("not a dir"), 0644); err != nil { + t.Fatalf("Failed to create file: %v", err) + } + return filepath.Join(dirPath, "id_rsa") + }, + wantErr: true, + }, + { + name: "PrivateExistsPublicMissing", + setup: func(t *testing.T, tempDir string) string { + keyPath := filepath.Join(tempDir, "partial_id_rsa") + if err := os.WriteFile(keyPath, []byte("private"), 0600); err != nil { + t.Fatalf("Failed to create private key: %v", err) + } + return keyPath + }, + verify: func(t *testing.T, keyPath string) { + if _, err := os.Stat(keyPath + ".pub"); !os.IsNotExist(err) { + t.Errorf("Public key should NOT have been created if private key existed") + } + }, + }, + { + name: "FailureRSAGenerateKey", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_rsa") + }, + mockSetup: func() func() { + old := rsaGenerateKey + rsaGenerateKey = func(random io.Reader, bits int) (*rsa.PrivateKey, error) { + return nil, errors.New("rsa error") + } + return func() { rsaGenerateKey = old } + }, + wantErr: true, + errStr: "rsa error", + }, + { + name: "FailureOpenFilePrivate", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_open_private") + }, + mockSetup: func() func() { + old := osOpenFile + osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + return nil, errors.New("open error") + } + return func() { osOpenFile = old } + }, + wantErr: true, + errStr: "open error", + }, + { + name: "FailurePemEncode", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_pem") + }, + mockSetup: func() func() { + old := pemEncode + pemEncode = func(out io.Writer, b *pem.Block) error { + return errors.New("pem error") + } + return func() { pemEncode = old } + }, + wantErr: true, + errStr: "pem error", + }, + { + name: "FailureSSHNewPublicKey", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_ssh") + }, + mockSetup: func() func() { + old := sshNewPublicKey + sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) { + return nil, errors.New("ssh error") + } + return func() { sshNewPublicKey = old } + }, + wantErr: true, + errStr: "ssh error", + }, + { + name: "FailureOpenFilePublic", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_open_public") + }, + mockSetup: func() func() { + old := osOpenFile + osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) { + if filepath.Ext(name) == ".pub" { + return nil, errors.New("open pub error") + } + return os.OpenFile(name, flag, perm) + } + return func() { osOpenFile = old } + }, + wantErr: true, + errStr: "open pub error", + }, + { + name: "FailurePubKeyWrite", + setup: func(t *testing.T, tempDir string) string { + return filepath.Join(tempDir, "fail_write") + }, + mockSetup: func() func() { + old := pubKeyWrite + pubKeyWrite = func(w io.Writer, data []byte) (int, error) { + return 0, errors.New("write error") + } + return func() { pubKeyWrite = old } + }, + wantErr: true, + errStr: "write error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + keyPath := tt.setup(t, tempDir) + if tt.mockSetup != nil { + cleanup := tt.mockSetup() + defer cleanup() + } + + err := GenerateSSHKeyIfNotExist(keyPath) + + if (err != nil) != tt.wantErr { + t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && tt.errStr != "" && err != nil && err.Error() != tt.errStr { + t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErrStr %v", err, tt.errStr) + } + + if tt.verify != nil { + tt.verify(t, keyPath) + } + }) + } +} diff --git a/internal/middleware/forwardedfor_test.go b/internal/middleware/forwardedfor_test.go new file mode 100644 index 0000000..49f9980 --- /dev/null +++ b/internal/middleware/forwardedfor_test.go @@ -0,0 +1,126 @@ +package middleware + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type mockRequestHeader struct { + mock.Mock +} + +func (m *mockRequestHeader) Value(key string) string { + return m.Called(key).String(0) +} + +func (m *mockRequestHeader) Set(key string, value string) { + m.Called(key, value) +} + +func (m *mockRequestHeader) Remove(key string) { + m.Called(key) +} + +func (m *mockRequestHeader) Finalize() []byte { + return m.Called().Get(0).([]byte) +} + +func (m *mockRequestHeader) Method() string { + return m.Called().String(0) +} + +func (m *mockRequestHeader) Path() string { + return m.Called().String(0) +} + +func (m *mockRequestHeader) Version() string { + return m.Called().String(0) +} + +func TestForwardedFor_HandleRequest(t *testing.T) { + tests := []struct { + name string + addr net.Addr + expectedHost string + expectError bool + }{ + { + name: "valid IPv4 address", + addr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 8080}, + expectedHost: "192.168.1.100", + expectError: false, + }, + { + name: "valid IPv6 address", + addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 8080}, + expectedHost: "2001:db8::ff00:42:8329", + expectError: false, + }, + { + name: "invalid address format", + addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"}, + expectedHost: "", + expectError: true, + }, + { + name: "valid IPv4 address with port", + addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234}, + expectedHost: "127.0.0.1", + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ff := NewForwardedFor(tc.addr) + reqHeader := new(mockRequestHeader) + + if !tc.expectError { + reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return() + } + + err := ff.HandleRequest(reqHeader) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + reqHeader.AssertExpectations(t) + } + }) + } +} + +func TestNewForwardedFor(t *testing.T) { + tests := []struct { + name string + addr net.Addr + expectAddr net.Addr + }{ + { + name: "IPv4 address", + addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, + expectAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, + }, + { + name: "IPv6 address", + addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0}, + expectAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0}, + }, + { + name: "Unix address", + addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"}, + expectAddr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ff := NewForwardedFor(tc.addr) + assert.Equal(t, tc.expectAddr.String(), ff.addr.String()) + }) + } +} diff --git a/internal/middleware/tunnelfingerprint_test.go b/internal/middleware/tunnelfingerprint_test.go new file mode 100644 index 0000000..0054d1e --- /dev/null +++ b/internal/middleware/tunnelfingerprint_test.go @@ -0,0 +1,70 @@ +package middleware + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type mockResponseHeader struct { + mock.Mock +} + +func (m *mockResponseHeader) Value(key string) string { + return m.Called(key).String(0) +} + +func (m *mockResponseHeader) Set(key string, value string) { + m.Called(key, value) +} + +func (m *mockResponseHeader) Remove(key string) { + m.Called(key) +} + +func (m *mockResponseHeader) Finalize() []byte { + return m.Called().Get(0).([]byte) +} + +func TestTunnelFingerprintHandleResponse(t *testing.T) { + tests := []struct { + name string + expected map[string]string + body []byte + wantErr error + }{ + { + name: "Sets Server Header", + expected: map[string]string{"Server": "Tunnel Please"}, + body: []byte("Sample body"), + wantErr: nil, + }, + { + name: "Overwrites Server Header", + expected: map[string]string{"Server": "Tunnel Please"}, + body: nil, + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockHeader := new(mockResponseHeader) + for k, v := range tt.expected { + mockHeader.On("Set", k, v).Return() + } + + tunnelFingerprint := NewTunnelFingerprint() + + err := tunnelFingerprint.HandleResponse(mockHeader, tt.body) + assert.ErrorIs(t, err, tt.wantErr) + mockHeader.AssertExpectations(t) + }) + } +} + +func TestNewTunnelFingerprint(t *testing.T) { + instance := NewTunnelFingerprint() + assert.NotNil(t, instance) +} diff --git a/internal/port/port_test.go b/internal/port/port_test.go new file mode 100644 index 0000000..fcc64d3 --- /dev/null +++ b/internal/port/port_test.go @@ -0,0 +1,114 @@ +package port + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddRange(t *testing.T) { + tests := []struct { + name string + startPort uint16 + endPort uint16 + wantErr bool + }{ + {"normal range", 1000, 1002, false}, + {"invalid range", 2000, 1999, true}, + {"single port range", 3000, 3000, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm := New() + err := pm.AddRange(tt.startPort, tt.endPort) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestUnassigned(t *testing.T) { + pm := New() + _ = pm.AddRange(1000, 1002) + + tests := []struct { + name string + status map[uint16]bool + want uint16 + wantOk bool + }{ + {"all unassigned", map[uint16]bool{1000: false, 1001: false, 1002: false}, 1000, true}, + {"some assigned", map[uint16]bool{1000: true, 1001: false, 1002: true}, 1001, true}, + {"all assigned", map[uint16]bool{1000: true, 1001: true, 1002: true}, 0, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.status { + _ = pm.SetStatus(k, v) + } + got, gotOk := pm.Unassigned() + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.wantOk, gotOk) + }) + } +} + +func TestSetStatus(t *testing.T) { + pm := New() + _ = pm.AddRange(1000, 1002) + + tests := []struct { + name string + port uint16 + assigned bool + }{ + {"assign port 1000", 1000, true}, + {"unassign port 1001", 1001, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := pm.SetStatus(tt.port, tt.assigned) + assert.NoError(t, err) + + status, ok := pm.(*port).ports[tt.port] + assert.True(t, ok) + assert.Equal(t, tt.assigned, status) + }) + } +} + +func TestClaim(t *testing.T) { + pm := New() + _ = pm.AddRange(1000, 1002) + + tests := []struct { + name string + port uint16 + status bool + want bool + }{ + {"claim unassigned port", 1000, false, true}, + {"claim already assigned port", 1001, true, false}, + {"claim non-existent port", 5000, false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if _, exists := pm.(*port).ports[tt.port]; exists { + _ = pm.SetStatus(tt.port, tt.status) + } + + got := pm.Claim(tt.port) + assert.Equal(t, tt.want, got) + + finalState := pm.(*port).ports[tt.port] + assert.True(t, finalState) + }) + } +} diff --git a/internal/random/random.go b/internal/random/random.go index 929cc7b..cb793d4 100644 --- a/internal/random/random.go +++ b/internal/random/random.go @@ -1,12 +1,35 @@ package random -import "crypto/rand" +import ( + "crypto/rand" + "fmt" + "io" +) -func GenerateRandomString(length int) (string, error) { +var ( + ErrInvalidLength = fmt.Errorf("invalid length") +) + +type Random interface { + String(length int) (string, error) +} + +type random struct { + reader io.Reader +} + +func New() Random { + return &random{reader: rand.Reader} +} + +func (ran *random) String(length int) (string, error) { + if length < 0 { + return "", ErrInvalidLength + } const charset = "abcdefghijklmnopqrstuvwxyz0123456789" b := make([]byte, length) - if _, err := rand.Read(b); err != nil { + if _, err := ran.reader.Read(b); err != nil { return "", err } diff --git a/internal/random/random_test.go b/internal/random/random_test.go new file mode 100644 index 0000000..e0cd512 --- /dev/null +++ b/internal/random/random_test.go @@ -0,0 +1,70 @@ +package random + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRandom_String(t *testing.T) { + tests := []struct { + name string + length int + wantErr bool + }{ + {"ValidLengthZero", 0, false}, + {"ValidPositiveLength", 10, false}, + {"NegativeLength", -1, true}, + {"VeryLargeLength", 1_000_000, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + randomizer := New() + + result, err := randomizer.String(tt.length) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Len(t, result, tt.length) + } + }) + } +} + +func TestRandomWithFailingReader_String(t *testing.T) { + errBrainrot := assert.AnError + + tests := []struct { + name string + reader io.Reader + expectErr error + }{ + { + name: "failing reader", + reader: func() io.Reader { + return &failingReader{err: errBrainrot} + }(), + expectErr: errBrainrot, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + randomizer := &random{reader: tt.reader} + result, err := randomizer.String(20) + assert.ErrorIs(t, err, tt.expectErr) + assert.Empty(t, result) + }) + } +} + +type failingReader struct { + err error +} + +func (f *failingReader) Read(p []byte) (int, error) { + return 0, f.err +} diff --git a/internal/registry/registry.go b/internal/registry/registry.go index 89cac48..e12ea0b 100644 --- a/internal/registry/registry.go +++ b/internal/registry/registry.go @@ -94,12 +94,13 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { return ErrInvalidSlug } - r.mu.Lock() - defer r.mu.Unlock() - if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { return ErrSlugInUse } + + r.mu.Lock() + defer r.mu.Unlock() + client, ok := r.byUser[user][oldKey] if !ok { return ErrSessionNotFound @@ -111,9 +112,6 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { client.Slug().Set(newKey.Id) r.slugIndex[newKey] = user - if r.byUser[user] == nil { - r.byUser[user] = make(map[Key]Session) - } r.byUser[user][newKey] = client return nil } diff --git a/internal/registry/registry_test.go b/internal/registry/registry_test.go new file mode 100644 index 0000000..484b4e9 --- /dev/null +++ b/internal/registry/registry_test.go @@ -0,0 +1,695 @@ +package registry + +import ( + "sync" + "testing" + "time" + "tunnel_pls/internal/port" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "golang.org/x/crypto/ssh" +) + +type mockSession struct { + mock.Mock +} + +func (m *mockSession) Lifecycle() lifecycle.Lifecycle { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(lifecycle.Lifecycle) +} +func (m *mockSession) Interaction() interaction.Interaction { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(interaction.Interaction) +} +func (m *mockSession) Forwarder() forwarder.Forwarder { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(forwarder.Forwarder) +} +func (m *mockSession) Slug() slug.Slug { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(slug.Slug) +} +func (m *mockSession) Detail() *types.Detail { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(*types.Detail) +} + +type mockLifecycle struct { + mock.Mock +} + +func (ml *mockLifecycle) Channel() ssh.Channel { + args := ml.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(ssh.Channel) +} + +func (ml *mockLifecycle) Connection() ssh.Conn { + args := ml.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(ssh.Conn) +} + +func (ml *mockLifecycle) PortRegistry() port.Port { + args := ml.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(port.Port) +} + +func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { ml.Called(channel) } +func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { ml.Called(status) } +func (ml *mockLifecycle) IsActive() bool { return ml.Called().Bool(0) } +func (ml *mockLifecycle) StartedAt() time.Time { return ml.Called().Get(0).(time.Time) } +func (ml *mockLifecycle) Close() error { return ml.Called().Error(0) } +func (ml *mockLifecycle) User() string { return ml.Called().String(0) } + +type mockSlug struct { + mock.Mock +} + +func (ms *mockSlug) Set(slug string) { ms.Called(slug) } +func (ms *mockSlug) String() string { return ms.Called().String(0) } + +func createMockSession(user ...string) *mockSession { + u := "user1" + if len(user) > 0 { + u = user[0] + } + m := new(mockSession) + ml := new(mockLifecycle) + ml.On("User").Return(u).Maybe() + m.On("Lifecycle").Return(ml).Maybe() + ms := new(mockSlug) + ms.On("Set", mock.Anything).Maybe() + m.On("Slug").Return(ms).Maybe() + m.On("Interaction").Return(nil).Maybe() + m.On("Forwarder").Return(nil).Maybe() + m.On("Detail").Return(nil).Maybe() + return m +} + +func TestNewRegistry(t *testing.T) { + r := NewRegistry() + require.NotNil(t, r) +} + +func TestRegistry_Get(t *testing.T) { + tests := []struct { + name string + setupFunc func(r *registry) + key types.SessionKey + wantErr error + wantResult bool + }{ + { + name: "session found", + setupFunc: func(r *registry) { + user := "user1" + key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} + session := createMockSession(user) + + r.mu.Lock() + defer r.mu.Unlock() + r.byUser[user] = map[types.SessionKey]Session{ + key: session, + } + r.slugIndex[key] = user + }, + key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}, + wantErr: nil, + wantResult: true, + }, + { + name: "session not found in slugIndex", + setupFunc: func(r *registry) {}, + key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}, + wantErr: ErrSessionNotFound, + }, + { + name: "session not found in byUser", + setupFunc: func(r *registry) { + r.mu.Lock() + defer r.mu.Unlock() + r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user" + }, + key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}, + wantErr: ErrSessionNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := ®istry{ + byUser: make(map[string]map[types.SessionKey]Session), + slugIndex: make(map[types.SessionKey]string), + mu: sync.RWMutex{}, + } + tt.setupFunc(r) + + session, err := r.Get(tt.key) + + assert.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantResult, session != nil) + }) + } +} + +func TestRegistry_GetWithUser(t *testing.T) { + tests := []struct { + name string + setupFunc func(r *registry) + user string + key types.SessionKey + wantErr error + wantResult bool + }{ + { + name: "session found", + setupFunc: func(r *registry) { + user := "user1" + key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} + session := createMockSession() + + r.mu.Lock() + defer r.mu.Unlock() + r.byUser[user] = map[types.SessionKey]Session{ + key: session, + } + r.slugIndex[key] = user + }, + user: "user1", + key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}, + wantErr: nil, + wantResult: true, + }, + { + name: "session not found in slugIndex", + setupFunc: func(r *registry) {}, + user: "user1", + key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}, + wantErr: ErrSessionNotFound, + }, + { + name: "session not found in byUser", + setupFunc: func(r *registry) { + r.mu.Lock() + defer r.mu.Unlock() + r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user" + }, + user: "user1", + key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}, + wantErr: ErrSessionNotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := ®istry{ + byUser: make(map[string]map[types.SessionKey]Session), + slugIndex: make(map[types.SessionKey]string), + mu: sync.RWMutex{}, + } + tt.setupFunc(r) + + session, err := r.GetWithUser(tt.user, tt.key) + + assert.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantResult, session != nil) + }) + } +} + +func TestRegistry_Update(t *testing.T) { + tests := []struct { + name string + user string + setupFunc func(r *registry) (oldKey, newKey types.SessionKey) + wantErr error + }{ + { + name: "change slug success", + user: "user1", + setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { + oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} + newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP} + session := createMockSession("user1") + + r.mu.Lock() + defer r.mu.Unlock() + r.byUser["user1"] = map[types.SessionKey]Session{ + oldKey: session, + } + r.slugIndex[oldKey] = "user1" + + return oldKey, newKey + }, + wantErr: nil, + }, + { + name: "change slug to already used slug", + user: "user1", + setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { + oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} + newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP} + session := createMockSession() + + r.mu.Lock() + defer r.mu.Unlock() + r.byUser["user1"] = map[types.SessionKey]Session{ + oldKey: session, + newKey: session, + } + r.slugIndex[oldKey] = "user1" + r.slugIndex[newKey] = "user1" + + return oldKey, newKey + }, + wantErr: ErrSlugInUse, + }, + { + name: "change slug to forbidden slug", + user: "user1", + setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { + oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} + newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP} + session := createMockSession() + + r.mu.Lock() + defer r.mu.Unlock() + r.byUser["user1"] = map[types.SessionKey]Session{ + oldKey: session, + } + r.slugIndex[oldKey] = "user1" + + return oldKey, newKey + }, + wantErr: ErrForbiddenSlug, + }, + { + name: "change slug to invalid slug", + user: "user1", + setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { + oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} + newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP} + session := createMockSession() + + r.mu.Lock() + defer r.mu.Unlock() + r.byUser["user1"] = map[types.SessionKey]Session{ + oldKey: session, + } + r.slugIndex[oldKey] = "user1" + + return oldKey, newKey + }, + wantErr: ErrInvalidSlug, + }, + { + name: "change slug but session not found", + user: "user2", + setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { + oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP} + newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP} + session := createMockSession() + + r.mu.Lock() + defer r.mu.Unlock() + r.byUser["user1"] = map[types.SessionKey]Session{ + types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session, + } + r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1" + + return oldKey, newKey + }, + wantErr: ErrSessionNotFound, + }, + { + name: "change slug but session is not in the map", + user: "user2", + setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { + oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP} + newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP} + session := createMockSession() + + r.mu.Lock() + defer r.mu.Unlock() + r.byUser["user1"] = map[types.SessionKey]Session{ + types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session, + } + r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1" + + return oldKey, newKey + }, + wantErr: ErrSessionNotFound, + }, + { + name: "change slug with same slug", + user: "user1", + setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { + oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} + newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP} + session := createMockSession() + + r.mu.Lock() + defer r.mu.Unlock() + r.byUser["user1"] = map[types.SessionKey]Session{ + oldKey: session, + } + r.slugIndex[oldKey] = "user1" + + return oldKey, newKey + }, + wantErr: ErrSlugUnchanged, + }, + { + name: "tcp tunnel cannot change slug", + user: "user1", + setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) { + oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP} + newKey := oldKey + session := createMockSession() + + r.mu.Lock() + defer r.mu.Unlock() + r.byUser["user1"] = map[types.SessionKey]Session{ + oldKey: session, + } + r.slugIndex[oldKey] = "user1" + + return oldKey, newKey + }, + wantErr: ErrSlugChangeNotAllowed, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := ®istry{ + byUser: make(map[string]map[types.SessionKey]Session), + slugIndex: make(map[types.SessionKey]string), + mu: sync.RWMutex{}, + } + + oldKey, newKey := tt.setupFunc(r) + + err := r.Update(tt.user, oldKey, newKey) + assert.ErrorIs(t, err, tt.wantErr) + + if err == nil { + r.mu.RLock() + defer r.mu.RUnlock() + _, ok := r.byUser[tt.user][newKey] + assert.True(t, ok, "newKey not found in registry") + _, ok = r.byUser[tt.user][oldKey] + assert.False(t, ok, "oldKey still exists in registry") + } + }) + } +} + +func TestRegistry_Register(t *testing.T) { + tests := []struct { + name string + user string + setupFunc func(r *registry) Key + wantOK bool + }{ + { + name: "register new key successfully", + user: "user1", + setupFunc: func(r *registry) Key { + key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} + return key + }, + wantOK: true, + }, + { + name: "register already existing key fails", + user: "user1", + setupFunc: func(r *registry) Key { + key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP} + session := createMockSession() + + r.mu.Lock() + r.byUser["user1"] = map[Key]Session{key: session} + r.slugIndex[key] = "user1" + r.mu.Unlock() + + return key + }, + wantOK: false, + }, + { + name: "register multiple keys for same user", + user: "user1", + setupFunc: func(r *registry) Key { + firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP} + session := createMockSession() + r.mu.Lock() + r.byUser["user1"] = map[Key]Session{firstKey: session} + r.slugIndex[firstKey] = "user1" + r.mu.Unlock() + + return types.SessionKey{Id: "second", Type: types.TunnelTypeHTTP} + }, + wantOK: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := ®istry{ + byUser: make(map[string]map[Key]Session), + slugIndex: make(map[Key]string), + mu: sync.RWMutex{}, + } + + key := tt.setupFunc(r) + session := createMockSession() + + ok := r.Register(key, session) + assert.Equal(t, tt.wantOK, ok) + + if ok { + r.mu.RLock() + defer r.mu.RUnlock() + assert.Equal(t, session, r.byUser[tt.user][key], "session not stored in byUser") + assert.Equal(t, tt.user, r.slugIndex[key], "slugIndex not updated") + } + }) + } +} + +func TestRegistry_GetAllSessionFromUser(t *testing.T) { + tests := []struct { + name string + setupFunc func(r *registry) string + expectN int + }{ + { + name: "user has no sessions", + setupFunc: func(r *registry) string { + return "user1" + }, + expectN: 0, + }, + { + name: "user has multiple sessions", + setupFunc: func(r *registry) string { + user := "user1" + key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP} + key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP} + r.mu.Lock() + r.byUser[user] = map[Key]Session{ + key1: createMockSession(), + key2: createMockSession(), + } + r.mu.Unlock() + return user + }, + expectN: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := ®istry{ + byUser: make(map[string]map[Key]Session), + slugIndex: make(map[Key]string), + mu: sync.RWMutex{}, + } + user := tt.setupFunc(r) + sessions := r.GetAllSessionFromUser(user) + assert.Len(t, sessions, tt.expectN) + }) + } +} + +func TestRegistry_Remove(t *testing.T) { + tests := []struct { + name string + setupFunc func(r *registry) (string, types.SessionKey) + key types.SessionKey + verify func(*testing.T, *registry, string, types.SessionKey) + }{ + { + name: "remove existing key", + setupFunc: func(r *registry) (string, types.SessionKey) { + user := "user1" + key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP} + session := createMockSession() + r.mu.Lock() + r.byUser[user] = map[Key]Session{key: session} + r.slugIndex[key] = user + r.mu.Unlock() + return user, key + }, + verify: func(t *testing.T, r *registry, user string, key types.SessionKey) { + _, ok := r.byUser[user][key] + assert.False(t, ok, "expected key to be removed from byUser") + _, ok = r.slugIndex[key] + assert.False(t, ok, "expected key to be removed from slugIndex") + _, ok = r.byUser[user] + assert.False(t, ok, "expected user to be removed from byUser map") + }, + }, + { + name: "remove non-existing key", + setupFunc: func(r *registry) (string, types.SessionKey) { + return "", types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP} + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := ®istry{ + byUser: make(map[string]map[Key]Session), + slugIndex: make(map[Key]string), + mu: sync.RWMutex{}, + } + user, key := tt.setupFunc(r) + if user == "" { + key = tt.key + } + r.Remove(key) + if tt.verify != nil { + tt.verify(t, r, user, key) + } + }) + } +} + +func TestIsValidSlug(t *testing.T) { + tests := []struct { + slug string + want bool + }{ + {"abc", true}, + {"abc-123", true}, + {"a", false}, + {"verybigdihsixsevenlabubu", false}, + {"-iamsigma", false}, + {"ligma-", false}, + {"invalid$", false}, + {"valid-slug1", true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.slug, func(t *testing.T) { + got := isValidSlug(tt.slug) + if got != tt.want { + t.Errorf("isValidSlug(%q) = %v; want %v", tt.slug, got, tt.want) + } + }) + } +} + +func TestIsValidSlugChar(t *testing.T) { + tests := []struct { + char byte + want bool + }{ + {'a', true}, + {'z', true}, + {'0', true}, + {'9', true}, + {'-', true}, + {'A', false}, + {'$', false}, + } + + for _, tt := range tests { + tt := tt + t.Run(string(tt.char), func(t *testing.T) { + got := isValidSlugChar(tt.char) + if got != tt.want { + t.Errorf("isValidSlugChar(%q) = %v; want %v", tt.char, got, tt.want) + } + }) + } +} + +func TestIsForbiddenSlug(t *testing.T) { + forbiddenSlugs = map[string]struct{}{ + "admin": {}, + "root": {}, + } + + tests := []struct { + slug string + want bool + }{ + {"admin", true}, + {"root", true}, + {"user", false}, + {"guest", false}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.slug, func(t *testing.T) { + got := isForbiddenSlug(tt.slug) + if got != tt.want { + t.Errorf("isForbiddenSlug(%q) = %v; want %v", tt.slug, got, tt.want) + } + }) + } +} diff --git a/internal/transport/http.go b/internal/transport/http.go index dd091c3..5c4648d 100644 --- a/internal/transport/http.go +++ b/internal/transport/http.go @@ -4,27 +4,28 @@ import ( "errors" "log" "net" + "tunnel_pls/internal/config" "tunnel_pls/internal/registry" ) type httpServer struct { handler *httpHandler - port string + config config.Config } -func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { +func NewHTTPServer(config config.Config, sessionRegistry registry.Registry) Transport { return &httpServer{ - handler: newHTTPHandler(domain, sessionRegistry, redirectTLS), - port: port, + handler: newHTTPHandler(config, sessionRegistry), + config: config, } } func (ht *httpServer) Listen() (net.Listener, error) { - return net.Listen("tcp", ":"+ht.port) + return net.Listen("tcp", ":"+ht.config.HTTPPort()) } func (ht *httpServer) Serve(listener net.Listener) error { - log.Printf("HTTP server is starting on port %s", ht.port) + log.Printf("HTTP server is starting on port %s", ht.config.HTTPPort()) for { conn, err := listener.Accept() if err != nil { @@ -35,6 +36,6 @@ func (ht *httpServer) Serve(listener net.Listener) error { continue } - go ht.handler.handler(conn, false) + go ht.handler.Handler(conn, false) } } diff --git a/internal/transport/http_test.go b/internal/transport/http_test.go new file mode 100644 index 0000000..cd3cf68 --- /dev/null +++ b/internal/transport/http_test.go @@ -0,0 +1,135 @@ +package transport + +import ( + "errors" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNewHTTPServer(t *testing.T) { + msr := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + + srv := NewHTTPServer(mockConfig, msr) + assert.NotNil(t, srv) + + httpSrv, ok := srv.(*httpServer) + assert.True(t, ok) + assert.Equal(t, msr, httpSrv.handler.sessionRegistry) + assert.NotNil(t, srv) +} + +func TestHTTPServer_Listen(t *testing.T) { + msr := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + srv := NewHTTPServer(mockConfig, msr) + + listener, err := srv.Listen() + assert.NoError(t, err) + assert.NotNil(t, listener) + err = listener.Close() + assert.NoError(t, err) +} + +func TestHTTPServer_Serve(t *testing.T) { + msr := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + srv := NewHTTPServer(mockConfig, msr) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + + go func() { + time.Sleep(100 * time.Millisecond) + err = listener.Close() + assert.NoError(t, err) + }() + + err = srv.Serve(listener) + assert.True(t, errors.Is(err, net.ErrClosed)) +} + +func TestHTTPServer_Serve_AcceptError(t *testing.T) { + msr := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + srv := NewHTTPServer(mockConfig, msr) + + ml := new(mockListener) + ml.On("Accept").Return(nil, errors.New("accept error")).Once() + ml.On("Accept").Return(nil, net.ErrClosed).Once() + + err := srv.Serve(ml) + assert.True(t, errors.Is(err, net.ErrClosed)) + ml.AssertExpectations(t) +} + +func TestHTTPServer_Serve_Success(t *testing.T) { + msr := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + mockConfig.On("HeaderSize").Return(4096) + mockConfig.On("TLSRedirect").Return(false) + srv := NewHTTPServer(mockConfig, msr) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + listenerport := listener.Addr().(*net.TCPAddr).Port + + go func() { + _ = srv.Serve(listener) + }() + + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport)) + assert.NoError(t, err) + + _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n")) + + time.Sleep(100 * time.Millisecond) + err = conn.Close() + assert.NoError(t, err) + + err = listener.Close() + assert.NoError(t, err) + +} + +type mockListener struct { + mock.Mock +} + +func (m *mockListener) Accept() (net.Conn, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(net.Conn), args.Error(1) +} + +func (m *mockListener) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *mockListener) Addr() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go index 8bab4a0..67aa6fb 100644 --- a/internal/transport/httphandler.go +++ b/internal/transport/httphandler.go @@ -1,7 +1,8 @@ package transport import ( - "bufio" + "bytes" + "context" "errors" "fmt" "io" @@ -10,6 +11,7 @@ import ( "net/http" "strings" "time" + "tunnel_pls/internal/config" "tunnel_pls/internal/http/header" "tunnel_pls/internal/http/stream" "tunnel_pls/internal/middleware" @@ -20,16 +22,14 @@ import ( ) type httpHandler struct { - domain string + config config.Config sessionRegistry registry.Registry - redirectTLS bool } -func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { +func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler { return &httpHandler{ - domain: domain, + config: config, sessionRegistry: sessionRegistry, - redirectTLS: redirectTLS, } } @@ -52,13 +52,28 @@ func (hh *httpHandler) badRequest(conn net.Conn) error { return nil } -func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { +func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) { defer hh.closeConnection(conn) - dstReader := bufio.NewReader(conn) - reqhf, err := header.NewRequest(dstReader) + _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + buf := make([]byte, hh.config.HeaderSize()) + n, err := conn.Read(buf) + if err != nil { + _ = hh.badRequest(conn) + return + } + + if idx := bytes.Index(buf[:n], []byte("\r\n\r\n")); idx == -1 { + _ = hh.badRequest(conn) + return + } + + _ = conn.SetReadDeadline(time.Time{}) + + reqhf, err := header.NewRequest(buf[:n]) if err != nil { log.Printf("Error creating request header: %v", err) + _ = hh.badRequest(conn) return } @@ -69,7 +84,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { } if hh.shouldRedirectToTLS(isTLS) { - _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain)) + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.config.Domain())) return } @@ -77,13 +92,16 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { return } - sshSession, err := hh.getSession(slug) + sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ + Id: slug, + Type: types.TunnelTypeHTTP, + }) if err != nil { _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug)) return } - hw := stream.New(conn, dstReader, conn.RemoteAddr()) + hw := stream.New(conn, conn, conn.RemoteAddr()) defer func(hw stream.HTTP) { err = hw.Close() if err != nil { @@ -102,14 +120,14 @@ func (hh *httpHandler) closeConnection(conn net.Conn) { func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) { host := strings.Split(reqhf.Value("Host"), ".") - if len(host) < 1 { + if len(host) <= 1 { return "", errors.New("invalid host") } return host[0], nil } func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool { - return !isTLS && hh.redirectTLS + return !isTLS && hh.config.TLSRedirect() } func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { @@ -128,29 +146,22 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { )) if err != nil { log.Println("Failed to write 200 OK:", err) + return true } return true } -func (hh *httpHandler) getSession(slug string) (registry.Session, error) { - sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ - Id: slug, - Type: types.TunnelTypeHTTP, - }) - if err != nil { - return nil, err - } - return sshSession, nil -} - func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { - channel, err := hh.openForwardedChannel(hw, sshSession) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, hw.RemoteAddr()) if err != nil { - log.Printf("Failed to establish channel: %v", err) - sshSession.Forwarder().WriteBadGatewayResponse(hw) + log.Printf("Failed to open forwarded-tcpip channel: %v", err) return } + go ssh.DiscardRequests(reqs) + defer func() { err = channel.Close() if err != nil && !errors.Is(err, io.EOF) { @@ -167,47 +178,6 @@ func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.Requ sshSession.Forwarder().HandleConnection(hw, channel) } -func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) { - payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr()) - - type channelResult struct { - channel ssh.Channel - reqs <-chan *ssh.Request - err error - } - - resultChan := make(chan channelResult, 1) - - go func() { - channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) - select { - case resultChan <- channelResult{channel, reqs, err}: - default: - hh.cleanupUnusedChannel(channel, reqs) - } - }() - - select { - case result := <-resultChan: - if result.err != nil { - return nil, result.err - } - go ssh.DiscardRequests(result.reqs) - return result.channel, nil - case <-time.After(5 * time.Second): - return nil, errors.New("timeout opening forwarded-tcpip channel") - } -} - -func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) { - if channel != nil { - if err := channel.Close(); err != nil { - log.Printf("Failed to close unused channel: %v", err) - } - go ssh.DiscardRequests(reqs) - } -} - func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) { fingerprintMiddleware := middleware.NewTunnelFingerprint() forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr()) diff --git a/internal/transport/httphandler_test.go b/internal/transport/httphandler_test.go new file mode 100644 index 0000000..6801b22 --- /dev/null +++ b/internal/transport/httphandler_test.go @@ -0,0 +1,717 @@ +package transport + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "strings" + "sync" + "testing" + "time" + "tunnel_pls/internal/registry" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockSessionRegistry struct { + mock.Mock +} + +func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) { + args := m.Called(key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} + +func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) { + args := m.Called(user, key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} + +func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error { + args := m.Called(user, oldKey, newKey) + return args.Error(0) +} + +func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool { + args := m.Called(key, session) + return args.Bool(0) +} + +func (m *MockSessionRegistry) Remove(key registry.Key) { + m.Called(key) +} + +func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session { + args := m.Called(user) + return args.Get(0).([]registry.Session) +} + +func (m *MockSessionRegistry) Slug() slug.Slug { + args := m.Called() + return args.Get(0).(slug.Slug) +} + +type MockSession struct { + mock.Mock +} + +func (m *MockSession) Lifecycle() lifecycle.Lifecycle { + args := m.Called() + return args.Get(0).(lifecycle.Lifecycle) +} + +func (m *MockSession) Interaction() interaction.Interaction { + args := m.Called() + return args.Get(0).(interaction.Interaction) +} + +func (m *MockSession) Forwarder() forwarder.Forwarder { + args := m.Called() + return args.Get(0).(forwarder.Forwarder) +} + +func (m *MockSession) Slug() slug.Slug { + args := m.Called() + return args.Get(0).(slug.Slug) +} + +func (m *MockSession) Detail() *types.Detail { + args := m.Called() + return args.Get(0).(*types.Detail) +} + +type MockSSHChannel struct { + ssh.Channel + mock.Mock +} + +func (m *MockSSHChannel) Write(data []byte) (int, error) { + args := m.Called(data) + return args.Int(0), args.Error(1) +} + +func (m *MockSSHChannel) Close() error { + args := m.Called() + return args.Error(0) +} + +type MockForwarder struct { + mock.Mock +} + +func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { + m.Called(dst, src) +} + +func (m *MockForwarder) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockForwarder) TunnelType() types.TunnelType { + args := m.Called() + return args.Get(0).(types.TunnelType) +} + +func (m *MockForwarder) ForwardedPort() uint16 { + args := m.Called() + return uint16(args.Int(0)) +} + +func (m *MockForwarder) SetType(tunnelType types.TunnelType) { + m.Called(tunnelType) +} + +func (m *MockForwarder) SetForwardedPort(port uint16) { + m.Called(port) +} + +func (m *MockForwarder) SetListener(listener net.Listener) { + m.Called(listener) +} + +func (m *MockForwarder) Listener() net.Listener { + args := m.Called() + return args.Get(0).(net.Listener) +} + +func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(ctx, origin) + if args.Get(0) == nil { + return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) + } + return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) +} + +type MockConn struct { + mock.Mock + ReadBuffer *bytes.Buffer +} + +func (m *MockConn) LocalAddr() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} + +func (m *MockConn) SetDeadline(t time.Time) error { + args := m.Called(t) + return args.Error(0) +} + +func (m *MockConn) SetReadDeadline(t time.Time) error { + args := m.Called(t) + return args.Error(0) +} + +func (m *MockConn) SetWriteDeadline(t time.Time) error { + args := m.Called(t) + return args.Error(0) +} + +func (m *MockConn) Read(b []byte) (n int, err error) { + if m.ReadBuffer != nil { + return m.ReadBuffer.Read(b) + } + args := m.Called(b) + return args.Int(0), args.Error(1) +} + +func (m *MockConn) Write(b []byte) (n int, err error) { + args := m.Called(b) + if args.Int(0) == -1 { + return len(b), args.Error(1) + } + return args.Int(0), args.Error(1) +} + +func (m *MockConn) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockConn) RemoteAddr() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} + +type wrappedConn struct { + net.Conn + remoteAddr net.Addr +} + +func (c *wrappedConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func TestNewHTTPHandler(t *testing.T) { + msr := new(MockSessionRegistry) + mockConfig := &MockConfig{} + mockConfig.On("Domain").Return("domain") + mockConfig.On("TLSRedirect").Return(false) + hh := newHTTPHandler(mockConfig, msr) + assert.NotNil(t, hh) + assert.Equal(t, msr, hh.sessionRegistry) +} + +func TestHandler(t *testing.T) { + tests := []struct { + name string + isTLS bool + redirectTLS bool + request []byte + expected []byte + setupMocks func(*MockSessionRegistry) + setupConn func() (net.Conn, net.Conn) + expectError bool + }{ + { + name: "bad request - invalid host", + isTLS: false, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"), + expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "bad request - missing host", + isTLS: false, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\n\r\n"), + expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "isTLS true and redirectTLS true - no redirect", + isTLS: true, + redirectTLS: true, + request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "redirect to TLS", + isTLS: false, + redirectTLS: true, + request: []byte("GET / HTTP/1.1\r\nHost: tunnel.example.com\r\n\r\n"), + expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnel.example.com/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "handle ping request", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "session not found", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + msr.On("Get", types.SessionKey{ + Id: "test", + Type: types.TunnelTypeHTTP, + }).Return((registry.Session)(nil), fmt.Errorf("session not found")) + }, + }, + { + name: "bad request - invalid http", + isTLS: false, + redirectTLS: false, + request: []byte("INVALID\r\n\r\n"), + expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "bad request - header too large", + isTLS: false, + redirectTLS: false, + request: []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: test.domain\r\n%s\r\n\r\n", strings.Repeat("test", 10000))), + expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "bad request - no request", + isTLS: false, + redirectTLS: false, + request: []byte(""), + expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "forwarding - open channel fails", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte(""), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + + msr.On("Get", types.SessionKey{ + Id: "test", + Type: types.TunnelTypeHTTP, + }).Return(mockSession, nil) + + mockSession.On("Forwarder").Return(mockForwarder) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed")) + }, + }, + { + name: "forwarding - send initial request fails", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte(""), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockSSHChannel := new(MockSSHChannel) + + msr.On("Get", types.SessionKey{ + Id: "test", + Type: types.TunnelTypeHTTP, + }).Return(mockSession, nil) + + mockSession.On("Forwarder").Return(mockForwarder) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + + reqCh := make(chan *ssh.Request) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + + mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) + mockSSHChannel.On("Close").Return(nil) + + go func() { + for range reqCh { + } + }() + }, + }, + { + name: "forwarding - success", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockSSHChannel := new(MockSSHChannel) + + msr.On("Get", types.SessionKey{ + Id: "test", + Type: types.TunnelTypeHTTP, + }).Return(mockSession, nil) + + mockSession.On("Forwarder").Return(mockForwarder) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + + reqCh := make(chan *ssh.Request) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + + mockSSHChannel.On("Write", mock.Anything).Return(0, nil) + mockSSHChannel.On("Close").Return(nil) + + mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) { + w := args.Get(0).(io.ReadWriter) + _, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello")) + }) + + go func() { + for range reqCh { + } + }() + }, + }, + { + name: "redirect - write failure", + isTLS: false, + redirectTLS: true, + request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"), + expected: []byte(""), + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) + mc.On("Write", mock.Anything).Return(-1, fmt.Errorf("write error")) + mc.On("Close").Return(nil) + return mc, nil + }, + }, + { + name: "bad request - write failure", + isTLS: false, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\n\r\n"), + expected: []byte(""), + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) + mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) + mc.On("Close").Return(nil) + return mc, nil + }, + }, + { + name: "read error - connection failure", + isTLS: false, + redirectTLS: false, + request: []byte(""), + expected: []byte(""), + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.On("SetReadDeadline", mock.Anything).Return(nil) + mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) + mc.On("Read", mock.Anything).Return(0, fmt.Errorf("connection reset by peer")) + mc.On("Close").Return(nil) + return mc, nil + }, + }, + { + name: "handle ping request - write failure", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"), + expected: []byte(""), + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) + mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) + mc.On("Close").Return(nil) + return mc, nil + }, + }, + { + name: "close connection - error", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"), + expected: []byte(""), + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) + mc.On("Write", mock.Anything).Return(182, nil) + mc.On("Close").Return(fmt.Errorf("close error")) + return mc, nil + }, + }, + { + name: "forwarding - stream close error", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte(""), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockSSHChannel := new(MockSSHChannel) + + msr.On("Get", mock.Anything).Return(mockSession, nil) + mockSession.On("Forwarder").Return(mockForwarder) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + + reqCh := make(chan *ssh.Request) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + + mockSSHChannel.On("Write", mock.Anything).Return(0, nil) + mockSSHChannel.On("Close").Return(nil) + + mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return() + }, + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) + mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2) + addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + mc.On("RemoteAddr").Return(addr) + return mc, nil + }, + }, + { + name: "forwarding - middleware failure", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte(""), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockSSHChannel := new(MockSSHChannel) + + msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool { + return k.Id == "test" + })).Return(mockSession, nil) + mockSession.On("Forwarder").Return(mockForwarder) + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + + reqCh := make(chan *ssh.Request) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockSSHChannel.On("Close").Return(nil) + }, + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) + mc.On("Close").Return(nil).Times(2) + mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")}) + return mc, nil + }, + }, + { + name: "forwarding - channel close error", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockSSHChannel := new(MockSSHChannel) + + msr.On("Get", mock.Anything).Return(mockSession, nil) + mockSession.On("Forwarder").Return(mockForwarder) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + reqCh := make(chan *ssh.Request) + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + + mockSSHChannel.On("Write", mock.Anything).Return(0, nil) + mockSSHChannel.On("Close").Return(fmt.Errorf("close error")) + + mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) { + w := args.Get(0).(io.ReadWriter) + _, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello")) + }) + }, + }, + { + name: "forwarding - open channel timeout", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte(""), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + + msr.On("Get", mock.Anything).Return(mockSession, nil) + mockSession.On("Forwarder").Return(mockForwarder) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + + mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + ctx := args.Get(0).(context.Context) + <-ctx.Done() + }).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSessionRegistry := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + mockConfig.On("HeaderSize").Return(4096) + mockConfig.On("TLSRedirect").Return(true) + hh := &httpHandler{ + sessionRegistry: mockSessionRegistry, + config: mockConfig, + } + + if tt.setupMocks != nil { + tt.setupMocks(mockSessionRegistry) + } + + var serverConn, clientConn net.Conn + if tt.setupConn != nil { + serverConn, clientConn = tt.setupConn() + } else { + serverConn, clientConn = net.Pipe() + } + + if clientConn != nil { + defer func(clientConn net.Conn) { + err := clientConn.Close() + assert.NoError(t, err) + }(clientConn) + } + + remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + var wrappedServerConn net.Conn + if _, ok := serverConn.(*MockConn); ok { + wrappedServerConn = serverConn + } else { + wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr} + } + + responseChan := make(chan []byte, 1) + doneChan := make(chan struct{}) + + if clientConn != nil { + go func() { + defer close(doneChan) + var res []byte + for { + buf := make([]byte, 4096) + n, err := clientConn.Read(buf) + if err != nil { + if err != io.EOF { + t.Logf("Error reading response: %v", err) + } + break + } + res = append(res, buf[:n]...) + if len(tt.expected) > 0 && len(res) >= len(tt.expected) { + break + } + } + responseChan <- res + }() + + go func() { + _, err := clientConn.Write(tt.request) + if err != nil { + t.Logf("Error writing request: %v", err) + } + }() + } else { + close(responseChan) + close(doneChan) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + hh.Handler(wrappedServerConn, tt.isTLS) + }() + + select { + case response := <-responseChan: + if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" { + resStr := string(response) + assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n")) + assert.Contains(t, resStr, "Content-Length: 5\r\n") + assert.Contains(t, resStr, "Server: Tunnel Please\r\n") + assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello")) + } else { + assert.Equal(t, string(tt.expected), string(response)) + } + case <-time.After(10 * time.Second): + if clientConn != nil { + t.Fatal("Test timeout - no response received") + } + } + + wg.Wait() + if clientConn != nil { + <-doneChan + } + + mockSessionRegistry.AssertExpectations(t) + if mc, ok := serverConn.(*MockConn); ok { + mc.AssertExpectations(t) + } + }) + } +} diff --git a/internal/transport/https.go b/internal/transport/https.go index 88ffe27..f1076bf 100644 --- a/internal/transport/https.go +++ b/internal/transport/https.go @@ -5,31 +5,30 @@ import ( "errors" "log" "net" + "tunnel_pls/internal/config" "tunnel_pls/internal/registry" ) type https struct { + config config.Config tlsConfig *tls.Config httpHandler *httpHandler - domain string - port string } -func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport { +func NewHTTPSServer(config config.Config, sessionRegistry registry.Registry, tlsConfig *tls.Config) Transport { return &https{ + config: config, tlsConfig: tlsConfig, - httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS), - domain: domain, - port: port, + httpHandler: newHTTPHandler(config, sessionRegistry), } } func (ht *https) Listen() (net.Listener, error) { - return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig) + return tls.Listen("tcp", ":"+ht.config.HTTPSPort(), ht.tlsConfig) } func (ht *https) Serve(listener net.Listener) error { - log.Printf("HTTPS server is starting on port %s", ht.port) + log.Printf("HTTPS server is starting on port %s", ht.config.HTTPSPort()) for { conn, err := listener.Accept() if err != nil { @@ -40,6 +39,6 @@ func (ht *https) Serve(listener net.Listener) error { continue } - go ht.httpHandler.handler(conn, true) + go ht.httpHandler.Handler(conn, true) } } diff --git a/internal/transport/https_test.go b/internal/transport/https_test.go new file mode 100644 index 0000000..6081d97 --- /dev/null +++ b/internal/transport/https_test.go @@ -0,0 +1,120 @@ +package transport + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewHTTPSServer(t *testing.T) { + msr := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + tlsConfig := &tls.Config{} + mockConfig.On("Domain").Return(mockConfig) + mockConfig.On("HTTPSPort").Return(port) + srv := NewHTTPSServer(mockConfig, msr, tlsConfig) + assert.NotNil(t, srv) + + httpsSrv, ok := srv.(*https) + assert.True(t, ok) + assert.Equal(t, tlsConfig, httpsSrv.tlsConfig) + assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry) +} + +func TestHTTPSServer_Listen(t *testing.T) { + msr := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return(mockConfig) + mockConfig.On("HTTPSPort").Return(port) + tlsConfig := &tls.Config{ + GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return nil, nil + }, + } + srv := NewHTTPSServer(mockConfig, msr, tlsConfig) + + listener, err := srv.Listen() + if err != nil { + t.Skip("Skipping tls.Listen test as it requires valid certificates/setup:", err) + return + } + assert.NotNil(t, listener) + err = listener.Close() + assert.NoError(t, err) +} + +func TestHTTPSServer_Serve(t *testing.T) { + msr := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return(mockConfig) + mockConfig.On("HTTPSPort").Return(port) + srv := NewHTTPSServer(mockConfig, msr, &tls.Config{}) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + + go func() { + time.Sleep(100 * time.Millisecond) + err = listener.Close() + assert.NoError(t, err) + }() + + err = srv.Serve(listener) + assert.True(t, errors.Is(err, net.ErrClosed)) +} + +func TestHTTPSServer_Serve_AcceptError(t *testing.T) { + msr := new(MockSessionRegistry) + + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return(mockConfig) + mockConfig.On("HTTPSPort").Return(port) + srv := NewHTTPSServer(mockConfig, msr, &tls.Config{}) + + ml := new(mockListener) + ml.On("Accept").Return(nil, errors.New("accept error")).Once() + ml.On("Accept").Return(nil, net.ErrClosed).Once() + + err := srv.Serve(ml) + assert.True(t, errors.Is(err, net.ErrClosed)) + ml.AssertExpectations(t) +} + +func TestHTTPSServer_Serve_Success(t *testing.T) { + msr := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return(mockConfig) + mockConfig.On("HTTPSPort").Return(port) + mockConfig.On("HeaderSize").Return(4096) + + srv := NewHTTPSServer(mockConfig, msr, &tls.Config{}) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + listenerport := listener.Addr().(*net.TCPAddr).Port + + go func() { + _ = srv.Serve(listener) + }() + + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport)) + assert.NoError(t, err) + + _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n")) + + time.Sleep(100 * time.Millisecond) + err = conn.Close() + assert.NoError(t, err) + err = listener.Close() + assert.NoError(t, err) +} diff --git a/internal/transport/tcp.go b/internal/transport/tcp.go index 99670d2..34ea3c2 100644 --- a/internal/transport/tcp.go +++ b/internal/transport/tcp.go @@ -1,27 +1,28 @@ package transport import ( + "context" "errors" "fmt" "io" "log" "net" + "time" "golang.org/x/crypto/ssh" ) type tcp struct { port uint16 - forwarder forwarder + forwarder Forwarder } -type forwarder interface { - CreateForwardedTCPIPPayload(origin net.Addr) []byte - OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) +type Forwarder interface { + OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) HandleConnection(dst io.ReadWriter, src ssh.Channel) } -func NewTCPServer(port uint16, forwarder forwarder) Transport { +func NewTCPServer(port uint16, forwarder Forwarder) Transport { return &tcp{ port: port, forwarder: forwarder, @@ -53,11 +54,11 @@ func (tt *tcp) handleTcp(conn net.Conn) { log.Printf("Failed to close connection: %v", err) } }() - payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr()) - channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, conn.RemoteAddr()) if err != nil { log.Printf("Failed to open forwarded-tcpip channel: %v", err) - return } diff --git a/internal/transport/tcp_test.go b/internal/transport/tcp_test.go new file mode 100644 index 0000000..c4c4963 --- /dev/null +++ b/internal/transport/tcp_test.go @@ -0,0 +1,146 @@ +package transport + +import ( + "errors" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/ssh" +) + +func TestNewTCPServer(t *testing.T) { + mf := new(MockForwarder) + port := uint16(9000) + + srv := NewTCPServer(port, mf) + assert.NotNil(t, srv) + + tcpSrv, ok := srv.(*tcp) + assert.True(t, ok) + assert.Equal(t, port, tcpSrv.port) + assert.Equal(t, mf, tcpSrv.forwarder) +} + +func TestTCPServer_Listen(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf) + + listener, err := srv.Listen() + assert.NoError(t, err) + assert.NotNil(t, listener) + err = listener.Close() + assert.NoError(t, err) +} + +func TestTCPServer_Serve(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + + go func() { + time.Sleep(100 * time.Millisecond) + err = listener.Close() + assert.NoError(t, err) + }() + + err = srv.Serve(listener) + assert.Nil(t, err) +} + +func TestTCPServer_Serve_AcceptError(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf) + + ml := new(mockListener) + ml.On("Accept").Return(nil, errors.New("accept error")).Once() + ml.On("Accept").Return(nil, net.ErrClosed).Once() + + err := srv.Serve(ml) + assert.Nil(t, err) + ml.AssertExpectations(t) +} + +func TestTCPServer_Serve_Success(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + + reqs := make(chan *ssh.Request) + mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil) + mf.On("HandleConnection", mock.Anything, mock.Anything).Return() + + go func() { + _ = srv.Serve(listener) + }() + + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + err = conn.Close() + assert.NoError(t, err) + err = listener.Close() + assert.NoError(t, err) + mf.AssertExpectations(t) +} + +func TestTCPServer_handleTcp_Success(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf).(*tcp) + + serverConn, clientConn := net.Pipe() + defer func(clientConn net.Conn) { + err := clientConn.Close() + assert.NoError(t, err) + }(clientConn) + + reqs := make(chan *ssh.Request) + mockChannel := new(MockSSHChannel) + mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil) + + mf.On("HandleConnection", serverConn, mockChannel).Return() + + srv.handleTcp(serverConn) + + mf.AssertExpectations(t) +} + +func TestTCPServer_handleTcp_CloseError(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf).(*tcp) + + mc := new(MockConn) + mc.On("Close").Return(errors.New("close error")) + mc.On("RemoteAddr").Return(&net.TCPAddr{}) + + mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) + + srv.handleTcp(mc) + mc.AssertExpectations(t) +} + +func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf).(*tcp) + + serverConn, clientConn := net.Pipe() + defer func(clientConn net.Conn) { + err := clientConn.Close() + assert.NoError(t, err) + }(clientConn) + + mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) + + srv.handleTcp(serverConn) + + mf.AssertExpectations(t) +} diff --git a/internal/transport/tls.go b/internal/transport/tls.go index 877afb4..4d62e60 100644 --- a/internal/transport/tls.go +++ b/internal/transport/tls.go @@ -8,6 +8,7 @@ import ( "fmt" "log" "os" + "path/filepath" "sync" "time" "tunnel_pls/internal/config" @@ -16,13 +17,22 @@ import ( "github.com/libdns/cloudflare" ) -type TLSManager interface { - userCertsExistAndValid() bool - loadUserCerts() error - startCertWatcher() - initCertMagic() error - getTLSConfig() *tls.Config - getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) +func NewTLSConfig(config config.Config) (*tls.Config, error) { + var initErr error + + tlsManagerOnce.Do(func() { + tm := createTLSManager(config) + initErr = tm.initialize() + if initErr == nil { + globalTLSManager = tm + } + }) + + if initErr != nil { + return nil, initErr + } + + return globalTLSManager.getTLSConfig(), nil } type tlsManager struct { @@ -40,52 +50,60 @@ type tlsManager struct { useCertMagic bool } -var globalTLSManager TLSManager +var globalTLSManager *tlsManager var tlsManagerOnce sync.Once -func NewTLSConfig(config config.Config) (*tls.Config, error) { - var initErr error +func createTLSManager(cfg config.Config) *tlsManager { + storagePath := cfg.TLSStoragePath() + cleanBase := filepath.Clean(storagePath) - tlsManagerOnce.Do(func() { - certPath := "certs/tls/cert.pem" - keyPath := "certs/tls/privkey.pem" - storagePath := "certs/tls/certmagic" + return &tlsManager{ + config: cfg, + certPath: filepath.Join(cleanBase, "cert.pem"), + keyPath: filepath.Join(cleanBase, "privkey.pem"), + storagePath: filepath.Join(cleanBase, "certmagic"), + } +} - tm := &tlsManager{ - config: config, - certPath: certPath, - keyPath: keyPath, - storagePath: storagePath, - } +func (tm *tlsManager) initialize() error { + if tm.userCertsExistAndValid() { + return tm.initializeWithUserCerts() + } + return tm.initializeWithCertMagic() +} - if tm.userCertsExistAndValid() { - log.Printf("Using user-provided certificates from %s and %s", certPath, keyPath) - if err := tm.loadUserCerts(); err != nil { - initErr = fmt.Errorf("failed to load user certificates: %w", err) - return - } - tm.useCertMagic = false - tm.startCertWatcher() - } else { - log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain()) - if err := tm.initCertMagic(); err != nil { - initErr = fmt.Errorf("failed to initialize CertMagic: %w", err) - return - } - tm.useCertMagic = true - } +func (tm *tlsManager) initializeWithUserCerts() error { + log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath) - globalTLSManager = tm - }) - - if initErr != nil { - return nil, initErr + if err := tm.loadUserCerts(); err != nil { + return fmt.Errorf("failed to load user certificates: %w", err) } - return globalTLSManager.getTLSConfig(), nil + tm.useCertMagic = false + tm.startCertWatcher() + return nil +} + +func (tm *tlsManager) initializeWithCertMagic() error { + log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", + tm.config.Domain(), tm.config.Domain()) + + if err := tm.initCertMagic(); err != nil { + return fmt.Errorf("failed to initialize CertMagic: %w", err) + } + + tm.useCertMagic = true + return nil } func (tm *tlsManager) userCertsExistAndValid() bool { + if !tm.certFilesExist() { + return false + } + return validateCertDomains(tm.certPath, tm.config.Domain()) +} + +func (tm *tlsManager) certFilesExist() bool { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { log.Printf("Certificate file not found: %s", tm.certPath) return false @@ -94,66 +112,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool { log.Printf("Key file not found: %s", tm.keyPath) return false } - - return ValidateCertDomains(tm.certPath, tm.config.Domain()) -} - -func ValidateCertDomains(certPath, domain string) bool { - certPEM, err := os.ReadFile(certPath) - if err != nil { - log.Printf("Failed to read certificate: %v", err) - return false - } - - block, _ := pem.Decode(certPEM) - if block == nil { - log.Printf("Failed to decode PEM block from certificate") - return false - } - - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - log.Printf("Failed to parse certificate: %v", err) - return false - } - - if time.Now().After(cert.NotAfter) { - log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter) - return false - } - - if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) { - log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter) - return false - } - - var certDomains []string - if cert.Subject.CommonName != "" { - certDomains = append(certDomains, cert.Subject.CommonName) - } - certDomains = append(certDomains, cert.DNSNames...) - - hasBase := false - hasWildcard := false - wildcardDomain := "*." + domain - - for _, d := range certDomains { - if d == domain { - hasBase = true - } - if d == wildcardDomain { - hasWildcard = true - } - } - - if !hasBase { - log.Printf("Certificate does not cover base domain: %s", domain) - } - if !hasWildcard { - log.Printf("Certificate does not cover wildcard domain: %s", wildcardDomain) - } - - return hasBase && hasWildcard + return true } func (tm *tlsManager) loadUserCerts() error { @@ -172,62 +131,34 @@ func (tm *tlsManager) loadUserCerts() error { func (tm *tlsManager) startCertWatcher() { go func() { - var lastCertMod, lastKeyMod time.Time - - if info, err := os.Stat(tm.certPath); err == nil { - lastCertMod = info.ModTime() - } - if info, err := os.Stat(tm.keyPath); err == nil { - lastKeyMod = info.ModTime() - } - - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for range ticker.C { - certInfo, certErr := os.Stat(tm.certPath) - keyInfo, keyErr := os.Stat(tm.keyPath) - - if certErr != nil || keyErr != nil { - continue - } - - if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) { - log.Printf("Certificate files changed, reloading...") - - if !ValidateCertDomains(tm.certPath, tm.config.Domain()) { - log.Printf("New certificates don't cover required domains") - - if err := tm.initCertMagic(); err != nil { - log.Printf("Failed to initialize CertMagic: %v", err) - continue - } - tm.useCertMagic = true - return - } - - if err := tm.loadUserCerts(); err != nil { - log.Printf("Failed to reload certificates: %v", err) - continue - } - - lastCertMod = certInfo.ModTime() - lastKeyMod = keyInfo.ModTime() - log.Printf("Certificates reloaded successfully") - } - } + watcher := newCertWatcher(tm) + watcher.watch() }() } func (tm *tlsManager) initCertMagic() error { - if err := os.MkdirAll(tm.storagePath, 0700); err != nil { - return fmt.Errorf("failed to create cert storage directory: %w", err) + if err := tm.createStorageDirectory(); err != nil { + return err } if tm.config.CFAPIToken() == "" { return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation") } + magic := tm.createCertMagicConfig() + tm.magic = magic + + return tm.obtainCertificates(magic) +} + +func (tm *tlsManager) createStorageDirectory() error { + if err := os.MkdirAll(tm.storagePath, 0700); err != nil { + return fmt.Errorf("failed to create cert storage directory: %w", err) + } + return nil +} + +func (tm *tlsManager) createCertMagicConfig() *certmagic.Config { cfProvider := &cloudflare.Provider{ APIToken: tm.config.CFAPIToken(), } @@ -244,6 +175,13 @@ func (tm *tlsManager) initCertMagic() error { Storage: storage, }) + acmeIssuer := tm.createACMEIssuer(magic, cfProvider) + magic.Issuers = []certmagic.Issuer{acmeIssuer} + + return magic +} + +func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer { acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{ Email: tm.config.ACMEEmail(), Agreed: true, @@ -262,9 +200,10 @@ func (tm *tlsManager) initCertMagic() error { log.Printf("Using Let's Encrypt production server") } - magic.Issuers = []certmagic.Issuer{acmeIssuer} - tm.magic = magic + return acmeIssuer +} +func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error { domains := []string{tm.config.Domain(), "*." + tm.config.Domain()} log.Printf("Requesting certificates for: %v", domains) @@ -307,3 +246,190 @@ func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certifica return tm.userCert, nil } + +func validateCertDomains(certPath, domain string) bool { + cert, err := loadAndParseCertificate(certPath) + if err != nil { + return false + } + + if !isCertificateValid(cert) { + return false + } + + return certCoversRequiredDomains(cert, domain) +} + +func loadAndParseCertificate(certPath string) (*x509.Certificate, error) { + certPEM, err := os.ReadFile(certPath) + if err != nil { + log.Printf("Failed to read certificate: %v", err) + return nil, err + } + + block, _ := pem.Decode(certPEM) + if block == nil { + log.Printf("Failed to decode PEM block from certificate") + return nil, fmt.Errorf("failed to decode PEM block") + } + + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + log.Printf("Failed to parse certificate: %v", err) + return nil, err + } + + return cert, nil +} + +func isCertificateValid(cert *x509.Certificate) bool { + now := time.Now() + + if now.After(cert.NotAfter) { + log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter) + return false + } + + thirtyDaysFromNow := now.Add(30 * 24 * time.Hour) + if thirtyDaysFromNow.After(cert.NotAfter) { + log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter) + return false + } + + return true +} + +func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool { + certDomains := extractCertDomains(cert) + hasBase, hasWildcard := checkDomainCoverage(certDomains, domain) + + logDomainCoverage(hasBase, hasWildcard, domain) + return hasBase && hasWildcard +} + +func extractCertDomains(cert *x509.Certificate) []string { + var domains []string + if cert.Subject.CommonName != "" { + domains = append(domains, cert.Subject.CommonName) + } + domains = append(domains, cert.DNSNames...) + return domains +} + +func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) { + wildcardDomain := "*." + domain + + for _, d := range certDomains { + if d == domain { + hasBase = true + } + if d == wildcardDomain { + hasWildcard = true + } + } + + return hasBase, hasWildcard +} + +func logDomainCoverage(hasBase, hasWildcard bool, domain string) { + if !hasBase { + log.Printf("Certificate does not cover base domain: %s", domain) + } + if !hasWildcard { + log.Printf("Certificate does not cover wildcard domain: *.%s", domain) + } +} + +type certWatcher struct { + tm *tlsManager + lastCertMod time.Time + lastKeyMod time.Time +} + +func newCertWatcher(tm *tlsManager) *certWatcher { + watcher := &certWatcher{tm: tm} + watcher.initializeModTimes() + return watcher +} + +func (cw *certWatcher) initializeModTimes() { + if info, err := os.Stat(cw.tm.certPath); err == nil { + cw.lastCertMod = info.ModTime() + } + if info, err := os.Stat(cw.tm.keyPath); err == nil { + cw.lastKeyMod = info.ModTime() + } +} + +func (cw *certWatcher) watch() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for range ticker.C { + if cw.checkAndReloadCerts() { + return + } + } +} + +func (cw *certWatcher) checkAndReloadCerts() bool { + certInfo, keyInfo, err := cw.getFileInfo() + if err != nil { + return false + } + + if !cw.filesModified(certInfo, keyInfo) { + return false + } + + return cw.handleCertificateChange(certInfo, keyInfo) +} + +func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) { + certInfo, certErr := os.Stat(cw.tm.certPath) + keyInfo, keyErr := os.Stat(cw.tm.keyPath) + + if certErr != nil || keyErr != nil { + return nil, nil, fmt.Errorf("file stat error") + } + + return certInfo, keyInfo, nil +} + +func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool { + return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod) +} + +func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool { + log.Printf("Certificate files changed, reloading...") + + if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) { + return cw.switchToCertMagic() + } + + if err := cw.tm.loadUserCerts(); err != nil { + log.Printf("Failed to reload certificates: %v", err) + return false + } + + cw.updateModTimes(certInfo, keyInfo) + log.Printf("Certificates reloaded successfully") + return false +} + +func (cw *certWatcher) switchToCertMagic() bool { + log.Printf("New certificates don't cover required domains") + + if err := cw.tm.initCertMagic(); err != nil { + log.Printf("Failed to initialize CertMagic: %v", err) + return false + } + + cw.tm.useCertMagic = true + return true +} + +func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) { + cw.lastCertMod = certInfo.ModTime() + cw.lastKeyMod = keyInfo.ModTime() +} diff --git a/internal/transport/tls_test.go b/internal/transport/tls_test.go new file mode 100644 index 0000000..0c5510c --- /dev/null +++ b/internal/transport/tls_test.go @@ -0,0 +1,1246 @@ +package transport + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "sync" + "testing" + "time" + "tunnel_pls/internal/config" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockConfig struct { + mock.Mock +} + +func (m *MockConfig) Domain() string { return m.Called().String(0) } +func (m *MockConfig) SSHPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) } +func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) } +func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) } +func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) } +func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } +func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) } +func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) PprofPort() string { return m.Called().String(0) } +func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } +func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *MockConfig) NodeToken() string { return m.Called().String(0) } +func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) } +func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } + +func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, soon bool) (string, string) { + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + assert.NoError(t, err) + + notAfter := time.Now().Add(365 * 24 * time.Hour) + if expired { + notAfter = time.Now().Add(-24 * time.Hour) + } else if soon { + notAfter = time.Now().Add(15 * 24 * time.Hour) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: domain, + }, + NotBefore: time.Now().Add(-24 * time.Hour), + NotAfter: notAfter, + DNSNames: []string{domain}, + } + + if wildcard { + template.DNSNames = append(template.DNSNames, "*."+domain) + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + assert.NoError(t, err) + + certOut, err := os.CreateTemp("", "cert*.pem") + assert.NoError(t, err) + err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + assert.NoError(t, err) + err = certOut.Close() + assert.NoError(t, err) + + keyOut, err := os.CreateTemp("", "key*.pem") + assert.NoError(t, err) + err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + assert.NoError(t, err) + err = keyOut.Close() + assert.NoError(t, err) + + return certOut.Name(), keyOut.Name() +} + +func setupTestDir(t *testing.T) string { + t.Helper() + + tmpDir, err := os.MkdirTemp("", "tls-test-*") + assert.NoError(t, err) + + t.Cleanup(func() { + err = os.RemoveAll(tmpDir) + assert.NoError(t, err) + }) + + return tmpDir +} + +func TestValidateCertDomains(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) (certPath string, cleanup func()) + domain string + expected bool + }{ + { + name: "file not found", + setup: func(t *testing.T) (string, func()) { + return "nonexistent.pem", func() {} + }, + domain: "example.com", + expected: false, + }, + { + name: "invalid PEM", + setup: func(t *testing.T) (string, func()) { + tmpFile, err := os.CreateTemp("", "invalid*.pem") + assert.NoError(t, err) + _, err = tmpFile.WriteString("not a pem") + assert.NoError(t, err) + err = tmpFile.Close() + assert.NoError(t, err) + return tmpFile.Name(), func() { + _ = os.Remove(tmpFile.Name()) + } + }, + domain: "example.com", + expected: false, + }, + { + name: "valid cert with wildcard", + setup: func(t *testing.T) (string, func()) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + return certPath, func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + } + }, + domain: "example.com", + expected: true, + }, + { + name: "expired cert", + setup: func(t *testing.T) (string, func()) { + certPath, keyPath := createTestCert(t, "example.com", true, true, false) + return certPath, func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + } + }, + domain: "example.com", + expected: false, + }, + { + name: "cert expiring soon", + setup: func(t *testing.T) (string, func()) { + certPath, keyPath := createTestCert(t, "example.com", true, false, true) + return certPath, func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + } + }, + domain: "example.com", + expected: false, + }, + { + name: "missing wildcard", + setup: func(t *testing.T) (string, func()) { + certPath, keyPath := createTestCert(t, "example.com", false, false, false) + return certPath, func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + } + }, + domain: "example.com", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + certPath, cleanup := tt.setup(t) + defer cleanup() + + result := validateCertDomains(certPath, tt.domain) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestLoadAndParseCertificate(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) (certPath string, cleanup func()) + wantError bool + validate func(t *testing.T, cert *x509.Certificate) + }{ + { + name: "success", + setup: func(t *testing.T) (string, func()) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + return certPath, func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + } + }, + wantError: false, + validate: func(t *testing.T, cert *x509.Certificate) { + assert.Equal(t, "example.com", cert.Subject.CommonName) + }, + }, + { + name: "file not found", + setup: func(t *testing.T) (string, func()) { + return "nonexistent.pem", func() {} + }, + wantError: true, + validate: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + certPath, cleanup := tt.setup(t) + defer cleanup() + + cert, err := loadAndParseCertificate(certPath) + + if tt.wantError { + assert.Error(t, err) + assert.Nil(t, cert) + } else { + assert.NoError(t, err) + assert.NotNil(t, cert) + if tt.validate != nil { + tt.validate(t, cert) + } + } + }) + } +} + +func TestIsCertificateValid(t *testing.T) { + tests := []struct { + name string + expired bool + soon bool + expected bool + }{ + { + name: "valid certificate", + expired: false, + soon: false, + expected: true, + }, + { + name: "expired certificate", + expired: true, + soon: false, + expected: false, + }, + { + name: "expiring soon", + expired: false, + soon: true, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, tt.expired, tt.soon) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) + + cert, err := loadAndParseCertificate(certPath) + assert.NoError(t, err) + + result := isCertificateValid(cert) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestExtractCertDomains(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) + + cert, err := loadAndParseCertificate(certPath) + assert.NoError(t, err) + + domains := extractCertDomains(cert) + assert.Contains(t, domains, "example.com") + assert.Contains(t, domains, "*.example.com") +} + +func TestCheckDomainCoverage(t *testing.T) { + tests := []struct { + name string + certDomains []string + domain string + wantBase bool + wantWildcard bool + }{ + { + name: "both covered", + certDomains: []string{"example.com", "*.example.com"}, + domain: "example.com", + wantBase: true, + wantWildcard: true, + }, + { + name: "only base", + certDomains: []string{"example.com"}, + domain: "example.com", + wantBase: true, + wantWildcard: false, + }, + { + name: "only wildcard", + certDomains: []string{"*.example.com"}, + domain: "example.com", + wantBase: false, + wantWildcard: true, + }, + { + name: "neither", + certDomains: []string{"other.com"}, + domain: "example.com", + wantBase: false, + wantWildcard: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hasBase, hasWildcard := checkDomainCoverage(tt.certDomains, tt.domain) + assert.Equal(t, tt.wantBase, hasBase) + assert.Equal(t, tt.wantWildcard, hasWildcard) + }) + } +} + +func TestTLSManager_getTLSConfig(t *testing.T) { + tm := &tlsManager{ + useCertMagic: false, + } + cfg := tm.getTLSConfig() + assert.NotNil(t, cfg) + assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion) + assert.Equal(t, uint16(tls.VersionTLS13), cfg.MaxVersion) + assert.NotNil(t, cfg.GetCertificate) +} + +func TestTLSManager_getCertificate(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) *tlsManager + wantError bool + errorContains string + }{ + { + name: "no certificate available", + setup: func(t *testing.T) *tlsManager { + return &tlsManager{ + useCertMagic: false, + userCert: nil, + } + }, + wantError: true, + errorContains: "no certificate available", + }, + { + name: "with user certificate", + setup: func(t *testing.T) *tlsManager { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + assert.NoError(t, err) + + return &tlsManager{ + useCertMagic: false, + userCert: &cert, + } + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.setup(t) + hello := &tls.ClientHelloInfo{ + ServerName: "example.com", + } + + cert, err := tm.getCertificate(hello) + + if tt.wantError { + assert.Error(t, err) + assert.Nil(t, cert) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, cert) + } + }) + } +} + +func TestTLSManager_userCertsExistAndValid(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) *tlsManager + expected bool + }{ + { + name: "no files", + setup: func(t *testing.T) *tlsManager { + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + return &tlsManager{ + config: mockCfg, + certPath: "nonexistent.pem", + keyPath: "nonexistent.key", + } + }, + expected: false, + }, + { + name: "missing key file", + setup: func(t *testing.T) *tlsManager { + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { _ = os.Remove(certPath) }) + err := os.Remove(keyPath) + assert.NoError(t, err) + + return &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.setup(t) + result := tm.userCertsExistAndValid() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTLSManager_certFilesExist(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) + + tm := &tlsManager{ + certPath: certPath, + keyPath: keyPath, + } + + result := tm.certFilesExist() + assert.True(t, result) +} + +func TestTLSManager_loadUserCerts(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) *tlsManager + wantError bool + }{ + { + name: "success", + setup: func(t *testing.T) *tlsManager { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + return &tlsManager{ + certPath: certPath, + keyPath: keyPath, + } + }, + wantError: false, + }, + { + name: "invalid path", + setup: func(t *testing.T) *tlsManager { + return &tlsManager{ + certPath: "nonexistent.pem", + keyPath: "nonexistent.key", + } + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.setup(t) + err := tm.loadUserCerts() + + if tt.wantError { + assert.Error(t, err) + assert.Nil(t, tm.userCert) + } else { + assert.NoError(t, err) + assert.NotNil(t, tm.userCert) + } + }) + } +} + +func TestCreateTLSManager(t *testing.T) { + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("TLSStoragePath").Return(tmpDir) + + tm := createTLSManager(mockCfg) + + assert.NotNil(t, tm) + assert.Equal(t, mockCfg, tm.config) + assert.Equal(t, filepath.Join(tmpDir, "cert.pem"), tm.certPath) + assert.Equal(t, filepath.Join(tmpDir, "privkey.pem"), tm.keyPath) + assert.Equal(t, filepath.Join(tmpDir, "certmagic"), tm.storagePath) +} + +func TestNewCertWatcher(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) + + mockCfg := &MockConfig{} + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + watcher := newCertWatcher(tm) + + assert.NotNil(t, watcher) + assert.Equal(t, tm, watcher.tm) + assert.False(t, watcher.lastCertMod.IsZero()) + assert.False(t, watcher.lastKeyMod.IsZero()) +} + +func TestCertWatcher_filesModified(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) + + mockCfg := &MockConfig{} + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + watcher := newCertWatcher(tm) + + certInfo, err := os.Stat(certPath) + assert.NoError(t, err) + keyInfo, err := os.Stat(keyPath) + assert.NoError(t, err) + + result := watcher.filesModified(certInfo, keyInfo) + assert.False(t, result) + + watcher.lastCertMod = time.Now().Add(-1 * time.Hour) + + result = watcher.filesModified(certInfo, keyInfo) + assert.True(t, result) +} + +func TestCertWatcher_updateModTimes(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) + + mockCfg := &MockConfig{} + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + watcher := newCertWatcher(tm) + + certInfo, err := os.Stat(certPath) + assert.NoError(t, err) + keyInfo, err := os.Stat(keyPath) + assert.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + watcher.updateModTimes(certInfo, keyInfo) + + assert.Equal(t, certInfo.ModTime(), watcher.lastCertMod) + assert.Equal(t, keyInfo.ModTime(), watcher.lastKeyMod) +} + +func TestCertWatcher_getFileInfo(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) *tlsManager + wantError bool + validate func(t *testing.T, certInfo, keyInfo os.FileInfo) + }{ + { + name: "success", + setup: func(t *testing.T) *tlsManager { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + return &tlsManager{ + config: &MockConfig{}, + certPath: certPath, + keyPath: keyPath, + } + }, + wantError: false, + validate: func(t *testing.T, certInfo, keyInfo os.FileInfo) { + assert.NotNil(t, certInfo) + assert.NotNil(t, keyInfo) + }, + }, + { + name: "missing cert file", + setup: func(t *testing.T) *tlsManager { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + err := os.Remove(certPath) + assert.NoError(t, err) + t.Cleanup(func() { _ = os.Remove(keyPath) }) + + return &tlsManager{ + config: &MockConfig{}, + certPath: certPath, + keyPath: keyPath, + } + }, + wantError: true, + }, + { + name: "missing key file", + setup: func(t *testing.T) *tlsManager { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + err := os.Remove(keyPath) + assert.NoError(t, err) + t.Cleanup(func() { _ = os.Remove(certPath) }) + + return &tlsManager{ + config: &MockConfig{}, + certPath: certPath, + keyPath: keyPath, + } + }, + wantError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.setup(t) + watcher := newCertWatcher(tm) + + certInfo, keyInfo, err := watcher.getFileInfo() + + if tt.wantError { + assert.Error(t, err) + assert.Nil(t, certInfo) + assert.Nil(t, keyInfo) + } else { + assert.NoError(t, err) + if tt.validate != nil { + tt.validate(t, certInfo, keyInfo) + } + } + }) + } +} + +func TestCertWatcher_checkAndReloadCerts(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) (*tlsManager, *certWatcher) + expected bool + }{ + { + name: "file error", + setup: func(t *testing.T) (*tlsManager, *certWatcher) { + tm := &tlsManager{ + config: &MockConfig{}, + certPath: "nonexistent.pem", + keyPath: "nonexistent.key", + } + return tm, newCertWatcher(tm) + }, + expected: false, + }, + { + name: "no modification", + setup: func(t *testing.T) (*tlsManager, *certWatcher) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + tm := &tlsManager{ + config: &MockConfig{}, + certPath: certPath, + keyPath: keyPath, + } + return tm, newCertWatcher(tm) + }, + expected: false, + }, + { + name: "with modification", + setup: func(t *testing.T) (*tlsManager, *certWatcher) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + err := tm.loadUserCerts() + assert.NoError(t, err) + + watcher := newCertWatcher(tm) + watcher.lastCertMod = time.Now().Add(-1 * time.Hour) + + return tm, watcher + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, watcher := tt.setup(t) + result := watcher.checkAndReloadCerts() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCertWatcher_handleCertificateChange(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) + expected bool + }{ + { + name: "successful reload", + setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + watcher := newCertWatcher(tm) + + certInfo, _ := os.Stat(certPath) + keyInfo, _ := os.Stat(keyPath) + + return tm, watcher, certInfo, keyInfo + }, + expected: false, + }, + { + name: "invalid cert triggers certmagic", + setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) { + certPath, keyPath := createTestCert(t, "example.com", false, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + mockCfg.On("CFAPIToken").Return("") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + storagePath: tmpDir, + } + + watcher := newCertWatcher(tm) + + certInfo, _ := os.Stat(certPath) + keyInfo, _ := os.Stat(keyPath) + + return tm, watcher, certInfo, keyInfo + }, + expected: false, + }, + { + name: "load error", + setup: func(t *testing.T) (*tlsManager, *certWatcher, os.FileInfo, os.FileInfo) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: "nonexistent.key", + } + + watcher := newCertWatcher(tm) + + certInfo, _ := os.Stat(certPath) + keyInfo, _ := os.Stat(keyPath) + + return tm, watcher, certInfo, keyInfo + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, watcher, certInfo, keyInfo := tt.setup(t) + result := watcher.handleCertificateChange(certInfo, keyInfo) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCertWatcher_switchToCertMagic(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) *tlsManager + expected bool + }{ + { + name: "with staging token", + setup: func(t *testing.T) *tlsManager { + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + mockCfg.On("CFAPIToken").Return("test-token") + mockCfg.On("ACMEEmail").Return("test@example.com") + mockCfg.On("ACMEStaging").Return(true) + + return &tlsManager{ + config: mockCfg, + storagePath: tmpDir, + } + }, + expected: false, + }, + { + name: "missing token", + setup: func(t *testing.T) *tlsManager { + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + mockCfg.On("CFAPIToken").Return("") + + return &tlsManager{ + config: mockCfg, + storagePath: tmpDir, + } + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tm := tt.setup(t) + watcher := newCertWatcher(tm) + result := watcher.switchToCertMagic() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCertWatcher_watch(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) (*tlsManager, *certWatcher) + expected bool + }{ + { + name: "exits on certmagic switch attempt", + setup: func(t *testing.T) (*tlsManager, *certWatcher) { + certPath, keyPath := createTestCert(t, "example.com", false, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + mockCfg.On("CFAPIToken").Return("") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + storagePath: tmpDir, + } + + watcher := newCertWatcher(tm) + watcher.lastCertMod = time.Now().Add(-1 * time.Hour) + watcher.lastKeyMod = time.Now().Add(-1 * time.Hour) + + return tm, watcher + }, + expected: false, + }, + { + name: "continues on no modification", + setup: func(t *testing.T) (*tlsManager, *certWatcher) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + return tm, newCertWatcher(tm) + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, watcher := tt.setup(t) + result := watcher.checkAndReloadCerts() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCertWatcher_watch_Integration(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) + + mockCfg := &MockConfig{} + mockCfg.On("Domain").Return("example.com") + + tm := &tlsManager{ + config: mockCfg, + certPath: certPath, + keyPath: keyPath, + } + + err := tm.loadUserCerts() + assert.NoError(t, err) + initialCert := tm.userCert + + watcher := newCertWatcher(tm) + + go watcher.watch() + + time.Sleep(50 * time.Millisecond) + + newCertPath, newKeyPath := createTestCert(t, "example.com", true, false, false) + defer func(name string) { + err = os.Remove(name) + assert.NoError(t, err) + }(newCertPath) + defer func(name string) { + err = os.Remove(name) + assert.NoError(t, err) + }(newKeyPath) + + newCertData, err := os.ReadFile(newCertPath) + assert.NoError(t, err) + newKeyData, err := os.ReadFile(newKeyPath) + assert.NoError(t, err) + + err = os.WriteFile(certPath, newCertData, 0644) + assert.NoError(t, err) + err = os.WriteFile(keyPath, newKeyData, 0644) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + assert.NotNil(t, tm.userCert) + assert.Equal(t, initialCert, tm.userCert) +} + +func TestNewTLSConfig(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) config.Config + wantError bool + errorMsg string + validate func(t *testing.T, cfg *tls.Config) + }{ + { + name: "with valid user certs", + setup: func(t *testing.T) config.Config { + globalTLSManager = nil + tlsManagerOnce = sync.Once{} + + tmpDir := setupTestDir(t) + + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + t.Cleanup(func() { + _ = os.Remove(certPath) + _ = os.Remove(keyPath) + }) + + certData, err := os.ReadFile(certPath) + assert.NoError(t, err) + keyData, err := os.ReadFile(keyPath) + assert.NoError(t, err) + + err = os.WriteFile(filepath.Join(tmpDir, "cert.pem"), certData, 0644) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, "privkey.pem"), keyData, 0644) + assert.NoError(t, err) + + mockCfg := &MockConfig{} + mockCfg.On("TLSStoragePath").Return(tmpDir) + mockCfg.On("Domain").Return("example.com") + + return mockCfg + }, + wantError: false, + validate: func(t *testing.T, cfg *tls.Config) { + assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion) + assert.NotNil(t, cfg.GetCertificate) + }, + }, + { + name: "missing certs requires certmagic", + setup: func(t *testing.T) config.Config { + globalTLSManager = nil + tlsManagerOnce = sync.Once{} + + tmpDir := setupTestDir(t) + + mockCfg := &MockConfig{} + mockCfg.On("TLSStoragePath").Return(tmpDir) + mockCfg.On("Domain").Return("example.com") + mockCfg.On("CFAPIToken").Return("") + + return mockCfg + }, + wantError: true, + errorMsg: "CF_API_TOKEN", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.setup(t) + tlsConfig, err := NewTLSConfig(cfg) + + if tt.wantError { + assert.Error(t, err) + if tt.errorMsg != "" { + assert.Contains(t, err.Error(), tt.errorMsg) + } + } else { + assert.NoError(t, err) + assert.NotNil(t, tlsConfig) + if tt.validate != nil { + tt.validate(t, tlsConfig) + } + } + }) + } +} + +func TestNewTLSConfig_Singleton(t *testing.T) { + globalTLSManager = nil + tlsManagerOnce = sync.Once{} + + tmpDir := setupTestDir(t) + + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(certPath) + defer func(name string) { + err := os.Remove(name) + assert.NoError(t, err) + }(keyPath) + + certData, err := os.ReadFile(certPath) + assert.NoError(t, err) + keyData, err := os.ReadFile(keyPath) + assert.NoError(t, err) + + err = os.WriteFile(filepath.Join(tmpDir, "cert.pem"), certData, 0644) + assert.NoError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, "privkey.pem"), keyData, 0644) + assert.NoError(t, err) + + mockCfg := &MockConfig{} + mockCfg.On("TLSStoragePath").Return(tmpDir) + mockCfg.On("Domain").Return("example.com") + + tlsConfig1, err1 := NewTLSConfig(mockCfg) + tlsConfig2, err2 := NewTLSConfig(mockCfg) + + assert.NoError(t, err1) + assert.NoError(t, err2) + assert.NotNil(t, tlsConfig1) + assert.NotNil(t, tlsConfig2) + + assert.Equal(t, tlsConfig1.MinVersion, tlsConfig2.MinVersion) + assert.Equal(t, tlsConfig1.MaxVersion, tlsConfig2.MaxVersion) + assert.Equal(t, tlsConfig1.SessionTicketsDisabled, tlsConfig2.SessionTicketsDisabled) + assert.Equal(t, tlsConfig1.ClientAuth, tlsConfig2.ClientAuth) + + hello := &tls.ClientHelloInfo{ServerName: "example.com"} + cert1, err1 := tlsConfig1.GetCertificate(hello) + cert2, err2 := tlsConfig2.GetCertificate(hello) + + assert.NoError(t, err1) + assert.NoError(t, err2) + assert.NotNil(t, cert1) + assert.NotNil(t, cert2) + + assert.Equal(t, cert1, cert2) +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index ca27061..31219fd 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -8,3 +8,7 @@ type Transport interface { Listen() (net.Listener, error) Serve(listener net.Listener) error } + +type HTTP interface { + Handler(conn net.Conn, isTLS bool) +} diff --git a/internal/version/version_test.go b/internal/version/version_test.go new file mode 100644 index 0000000..f4873f5 --- /dev/null +++ b/internal/version/version_test.go @@ -0,0 +1,84 @@ +package version + +import ( + "fmt" + "testing" +) + +func TestVersionFunctions(t *testing.T) { + origVersion := Version + origBuildDate := BuildDate + origCommit := Commit + defer func() { + Version = origVersion + BuildDate = origBuildDate + Commit = origCommit + }() + + tests := []struct { + name string + version string + buildDate string + commit string + wantFull string + wantShort string + }{ + { + name: "Default dev version", + version: "dev", + buildDate: "unknown", + commit: "unknown", + wantFull: "tunnel_pls dev (commit: unknown, built: unknown)", + wantShort: "dev", + }, + { + name: "Release version", + version: "v1.0.0", + buildDate: "2026-01-23", + commit: "abcdef123", + wantFull: "tunnel_pls v1.0.0 (commit: abcdef123, built: 2026-01-23)", + wantShort: "v1.0.0", + }, + { + name: "Empty values", + version: "", + buildDate: "", + commit: "", + wantFull: "tunnel_pls (commit: , built: )", + wantShort: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + Version = tt.version + BuildDate = tt.buildDate + Commit = tt.commit + + gotFull := GetVersion() + if gotFull != tt.wantFull { + t.Errorf("GetVersion() = %q, want %q", gotFull, tt.wantFull) + } + + gotShort := GetShortVersion() + if gotShort != tt.wantShort { + t.Errorf("GetShortVersion() = %q, want %q", gotShort, tt.wantShort) + } + }) + } +} + +func TestGetVersion_Format(t *testing.T) { + v := "1.2.3" + c := "brainrot" + d := "now" + + Version = v + Commit = c + BuildDate = d + + expected := fmt.Sprintf("tunnel_pls %s (commit: %s, built: %s)", v, c, d) + if GetVersion() != expected { + t.Errorf("GetVersion() formatting mismatch") + } +} diff --git a/main.go b/main.go index f897b46..a908903 100644 --- a/main.go +++ b/main.go @@ -1,27 +1,13 @@ package main import ( - "context" "fmt" "log" - "net" - "net/http" - _ "net/http/pprof" "os" - "os/signal" - "syscall" - "time" + "tunnel_pls/internal/bootstrap" "tunnel_pls/internal/config" - "tunnel_pls/internal/grpc/client" - "tunnel_pls/internal/key" "tunnel_pls/internal/port" - "tunnel_pls/internal/registry" - "tunnel_pls/internal/transport" "tunnel_pls/internal/version" - "tunnel_pls/server" - "tunnel_pls/types" - - "golang.org/x/crypto/ssh" ) func main() { @@ -32,148 +18,19 @@ func main() { log.SetOutput(os.Stdout) log.SetFlags(log.LstdFlags | log.Lshortfile) - log.Printf("Starting %s", version.GetVersion()) conf, err := config.MustLoad() if err != nil { - log.Fatalf("Failed to load configuration: %s", err) - return + log.Fatalf("Config load error: %v", err) } - sshConfig := &ssh.ServerConfig{ - NoClientAuth: true, - ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()), - } - - sshKeyPath := "certs/ssh/id_rsa" - if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { - log.Fatalf("Failed to generate SSH key: %s", err) - } - - privateBytes, err := os.ReadFile(sshKeyPath) + boot, err := bootstrap.New(conf, port.New()) if err != nil { - log.Fatalf("Failed to load private key: %s", err) + log.Fatalf("Startup error: %v", err) } - private, err := ssh.ParsePrivateKey(privateBytes) - if err != nil { - log.Fatalf("Failed to parse private key: %s", err) - } - - sshConfig.AddHostKey(private) - sessionRegistry := registry.NewRegistry() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - errChan := make(chan error, 2) - shutdownChan := make(chan os.Signal, 1) - signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) - - var grpcClient client.Client - - if conf.Mode() == types.ServerModeNODE { - grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort()) - - grpcClient, err = client.New(conf, grpcAddr, sessionRegistry) - if err != nil { - log.Fatalf("failed to create grpc client: %v", err) - } - - healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second) - if err = grpcClient.CheckServerHealth(healthCtx); err != nil { - healthCancel() - log.Fatalf("gRPC health check failed: %v", err) - } - healthCancel() - - go func() { - if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil { - errChan <- fmt.Errorf("failed to subscribe to events: %w", err) - } - }() - } - - go func() { - var httpListener net.Listener - httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect()) - httpListener, err = httpserver.Listen() - if err != nil { - errChan <- fmt.Errorf("failed to start http server: %w", err) - return - } - err = httpserver.Serve(httpListener) - if err != nil { - errChan <- fmt.Errorf("error when serving http server: %w", err) - return - } - }() - - if conf.TLSEnabled() { - go func() { - var httpsListener net.Listener - tlsConfig, _ := transport.NewTLSConfig(conf) - httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig) - httpsListener, err = httpsServer.Listen() - if err != nil { - errChan <- fmt.Errorf("failed to start http server: %w", err) - return - } - err = httpsServer.Serve(httpsListener) - if err != nil { - errChan <- fmt.Errorf("error when serving http server: %w", err) - return - } - }() - } - - portManager := port.New() - err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd()) - if err != nil { - log.Fatalf("Failed to initialize port manager: %s", err) - return - } - - var app server.Server - go func() { - app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort()) - if err != nil { - errChan <- fmt.Errorf("failed to start server: %s", err) - return - } - app.Start() - - }() - - if conf.PprofEnabled() { - go func() { - pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort()) - log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr) - if err = http.ListenAndServe(pprofAddr, nil); err != nil { - log.Printf("pprof server error: %v", err) - } - }() - } - - select { - case err = <-errChan: - log.Printf("error happen : %s", err) - case sig := <-shutdownChan: - log.Printf("received signal %s, shutting down", sig) - } - - cancel() - - if app != nil { - if err = app.Close(); err != nil { - log.Printf("failed to close server : %s", err) - } - } - - if grpcClient != nil { - if err = grpcClient.Close(); err != nil { - log.Printf("failed to close grpc conn : %s", err) - } + if err = boot.Run(); err != nil { + log.Fatalf("Application error: %v", err) } } diff --git a/server/server.go b/server/server.go index f47c579..d3df5fd 100644 --- a/server/server.go +++ b/server/server.go @@ -4,12 +4,14 @@ import ( "context" "errors" "fmt" + "io" "log" "net" "time" "tunnel_pls/internal/config" "tunnel_pls/internal/grpc/client" "tunnel_pls/internal/port" + "tunnel_pls/internal/random" "tunnel_pls/internal/registry" "tunnel_pls/session" @@ -21,6 +23,7 @@ type Server interface { Close() error } type server struct { + randomizer random.Random config config.Config sshPort string sshListener net.Listener @@ -30,13 +33,14 @@ type server struct { portRegistry port.Port } -func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { +func New(randomizer random.Random, config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort)) if err != nil { return nil, err } return &server{ + randomizer: randomizer, config: config, sshPort: sshPort, sshListener: listener, @@ -82,7 +86,7 @@ func (s *server) handleConnection(conn net.Conn) { defer func(sshConn *ssh.ServerConn) { err = sshConn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { + if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { log.Printf("failed to close SSH server: %v", err) } }(sshConn) @@ -95,11 +99,19 @@ func (s *server) handleConnection(conn net.Conn) { cancel() } log.Println("SSH connection established:", sshConn.User()) - sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) + sshSession := session.New(&session.Config{ + Randomizer: s.randomizer, + Config: s.config, + Conn: sshConn, + InitialReq: forwardingReqs, + SshChan: chans, + SessionRegistry: s.sessionRegistry, + PortRegistry: s.portRegistry, + User: user, + }) err = sshSession.Start() if err != nil { - log.Printf("SSH session ended with error: %v", err) + log.Printf("SSH session ended with error: %s", err.Error()) return } - return } diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..a4d5c74 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,880 @@ +package server + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "errors" + "fmt" + "net" + "testing" + "time" + "tunnel_pls/internal/registry" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/ssh" + "google.golang.org/grpc" +) + +type MockRandom struct { + mock.Mock +} + +func (m *MockRandom) String(length int) (string, error) { + args := m.Called(length) + return args.String(0), args.Error(1) +} + +type MockConfig struct { + mock.Mock +} + +func (m *MockConfig) Domain() string { return m.Called().String(0) } +func (m *MockConfig) SSHPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) } +func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) } +func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) } +func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) } +func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } +func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) } +func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) PprofPort() string { return m.Called().String(0) } +func (m *MockConfig) Mode() types.ServerMode { + args := m.Called() + if args.Get(0) == nil { + return 0 + } + switch v := args.Get(0).(type) { + case types.ServerMode: + return v + case int: + return types.ServerMode(v) + default: + return types.ServerMode(args.Int(0)) + } +} +func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *MockConfig) NodeToken() string { return m.Called().String(0) } +func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } + +type MockSessionRegistry struct { + mock.Mock +} + +func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) { + args := m.Called(key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} + +func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) { + args := m.Called(user, key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} + +func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error { + args := m.Called(user, oldKey, newKey) + return args.Error(0) +} + +func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool { + args := m.Called(key, session) + return args.Bool(0) +} + +func (m *MockSessionRegistry) Remove(key registry.Key) { + m.Called(key) +} + +func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session { + args := m.Called(user) + return args.Get(0).([]registry.Session) +} + +func (m *MockSessionRegistry) Slug() slug.Slug { + args := m.Called() + return args.Get(0).(slug.Slug) +} + +type MockGRPCClient struct { + mock.Mock +} + +func (m *MockGRPCClient) ClientConn() *grpc.ClientConn { + args := m.Called() + return args.Get(0).(*grpc.ClientConn) +} + +func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) { + args := m.Called(ctx, token) + return args.Bool(0), args.String(1), args.Error(2) +} + +func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error { + args := m.Called(ctx, domain, token) + return args.Error(0) +} + +func (m *MockGRPCClient) Close() error { + args := m.Called() + return args.Error(0) +} + +type MockPort struct { + mock.Mock +} + +func (m *MockPort) AddRange(startPort, endPort uint16) error { + return m.Called(startPort, endPort).Error(0) +} + +func (m *MockPort) Unassigned() (uint16, bool) { + args := m.Called() + return uint16(args.Int(0)), args.Bool(1) +} + +func (m *MockPort) SetStatus(port uint16, assigned bool) error { + return m.Called(port, assigned).Error(0) +} + +func (m *MockPort) Claim(port uint16) bool { + return m.Called(port).Bool(0) +} + +type MockListener struct { + mock.Mock +} + +func (m *MockListener) Accept() (net.Conn, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(net.Conn), args.Error(1) +} + +func (m *MockListener) Close() error { + return m.Called().Error(0) +} + +func (m *MockListener) Addr() net.Addr { + return m.Called().Get(0).(net.Addr) +} + +func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) { + key, _ := rsa.GenerateKey(rand.Reader, 2048) + signer, _ := ssh.NewSignerFromKey(key) + config := &ssh.ServerConfig{ + NoClientAuth: true, + } + config.AddHostKey(signer) + return config, signer +} + +func TestNew(t *testing.T) { + mr := new(MockRandom) + mc := new(MockConfig) + mreg := new(MockSessionRegistry) + mg := new(MockGRPCClient) + mp := new(MockPort) + sc, _ := getTestSSHConfig() + + tests := []struct { + name string + port string + wantErr bool + }{ + { + name: "success", + port: "0", + wantErr: false, + }, + { + name: "invalid port", + port: "invalid", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, err := New(mr, mc, sc, mreg, mg, mp, tt.port) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, s) + } else { + assert.NoError(t, err) + assert.NotNil(t, s) + _ = s.Close() + } + }) + } + + t.Run("port already in use", func(t *testing.T) { + l, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + port := l.Addr().(*net.TCPAddr).Port + defer func(l net.Listener) { + err = l.Close() + assert.NoError(t, err) + }(l) + + s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port)) + assert.Error(t, err) + assert.Nil(t, s) + }) +} + +func TestClose(t *testing.T) { + mr := new(MockRandom) + mc := new(MockConfig) + mreg := new(MockSessionRegistry) + mg := new(MockGRPCClient) + mp := new(MockPort) + sc, _ := getTestSSHConfig() + + t.Run("successful close", func(t *testing.T) { + s, _ := New(mr, mc, sc, mreg, mg, mp, "0") + err := s.Close() + assert.NoError(t, err) + }) + + t.Run("close already closed listener", func(t *testing.T) { + s, _ := New(mr, mc, sc, mreg, mg, mp, "0") + _ = s.Close() + err := s.Close() + assert.Error(t, err) + }) + + t.Run("close with nil listener", func(t *testing.T) { + s := &server{ + sshListener: nil, + } + defer func() { + if r := recover(); r != nil { + assert.NotNil(t, r) + } + }() + _ = s.Close() + t.Fatal("expected panic for nil listener") + }) +} + +func TestStart(t *testing.T) { + mr := new(MockRandom) + mc := new(MockConfig) + mreg := new(MockSessionRegistry) + mg := new(MockGRPCClient) + mp := new(MockPort) + sc, _ := getTestSSHConfig() + + t.Run("normal stop", func(t *testing.T) { + s, _ := New(mr, mc, sc, mreg, mg, mp, "0") + go func() { + time.Sleep(100 * time.Millisecond) + _ = s.Close() + }() + s.Start() + }) + + t.Run("accept error - temporary error continues loop", func(t *testing.T) { + ml := new(MockListener) + s := &server{ + sshListener: ml, + sshPort: "0", + } + + ml.On("Accept").Return(nil, errors.New("temporary error")).Once() + ml.On("Accept").Return(nil, net.ErrClosed).Once() + + s.Start() + ml.AssertExpectations(t) + }) + + t.Run("accept error - immediate close", func(t *testing.T) { + ml := new(MockListener) + s := &server{ + sshListener: ml, + sshPort: "0", + } + + ml.On("Accept").Return(nil, net.ErrClosed).Once() + + s.Start() + ml.AssertExpectations(t) + }) + + t.Run("accept success - connection fails SSH handshake", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockGrpcClient := &MockGRPCClient{} + mockPort := &MockPort{} + + sshConfig, _ := getTestSSHConfig() + + serverConn, clientConn := net.Pipe() + + mockListener := &MockListener{} + mockListener.On("Accept").Return(serverConn, nil).Once() + mockListener.On("Accept").Return(nil, net.ErrClosed).Once() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshListener: mockListener, + sshConfig: sshConfig, + grpcClient: mockGrpcClient, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + go s.Start() + + time.Sleep(50 * time.Millisecond) + err := clientConn.Close() + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + mockListener.AssertExpectations(t) + }) + + t.Run("accept success - valid SSH connection without auth", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + sshConfig, _ := getTestSSHConfig() + + serverConn, clientConn := net.Pipe() + + mockListener := &MockListener{} + mockListener.On("Accept").Return(serverConn, nil).Once() + mockListener.On("Accept").Return(nil, net.ErrClosed).Once() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshListener: mockListener, + sshConfig: sshConfig, + grpcClient: nil, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + go s.Start() + + time.Sleep(50 * time.Millisecond) + err := clientConn.Close() + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + mockListener.AssertExpectations(t) + }) +} + +func TestHandleConnection(t *testing.T) { + t.Run("SSH handshake fails - connection closed", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockGrpcClient := &MockGRPCClient{} + mockPort := &MockPort{} + + sshConfig, _ := getTestSSHConfig() + + serverConn, clientConn := net.Pipe() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: sshConfig, + grpcClient: mockGrpcClient, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + err := clientConn.Close() + assert.NoError(t, err) + + s.handleConnection(serverConn) + }) + + // SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS + // GONNA IMPLEMENT THIS UNIT TEST LATER + + //t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) { + // mockRandom := &MockRandom{} + // mockConfig := &MockConfig{} + // mockSessionRegistry := &MockSessionRegistry{} + // mockGrpcClient := &MockGRPCClient{} + // mockPort := &MockPort{} + // + // sshConfig, _ := getTestSSHConfig() + // + // serverConn, clientConn := net.Pipe() + // + // s := &server{ + // randomizer: mockRandom, + // config: mockConfig, + // sshPort: "0", + // sshConfig: sshConfig, + // grpcClient: mockGrpcClient, + // sessionRegistry: mockSessionRegistry, + // portRegistry: mockPort, + // } + // + // done := make(chan bool, 1) + // + // go func() { + // s.handleConnection(serverConn) + // done <- true + // }() + // + // go func() { + // clientConn.Write([]byte("invalid ssh protocol\n")) + // clientConn.Close() + // }() + // + // select { + // case <-done: + // case <-time.After(1 * time.Second): + // t.Fatal("handleConnection did not complete in time") + // } + //}) + + t.Run("SSH connection established without gRPC client", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + serverConfig, _ := getTestSSHConfig() + + mockConfig.On("Domain").Return("test.com") + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("SSHPort").Return("2200") + mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil) + mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + mockSessionRegistry.On("Remove", mock.Anything).Return(nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer func(listener net.Listener) { + err = listener.Close() + assert.NoError(t, err) + }(listener) + + serverAddr := listener.Addr().String() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: serverConfig, + grpcClient: nil, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + done := make(chan bool, 1) + + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + s.handleConnection(conn) + done <- true + }() + + time.Sleep(50 * time.Millisecond) + + clientConfig := &ssh.ClientConfig{ + User: "testuser", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + go func() { + client, err := ssh.Dial("tcp", serverAddr, clientConfig) + if err != nil { + t.Logf("Client dial failed: %v", err) + return + } + defer func(client *ssh.Client) { + err = client.Close() + assert.NoError(t, err) + }(client) + + type forwardPayload struct { + BindAddr string + BindPort uint32 + } + + payload := ssh.Marshal(forwardPayload{ + BindAddr: "localhost", + BindPort: 80, + }) + + _, _, err = client.SendRequest("tcpip-forward", true, payload) + if err != nil { + t.Logf("Forward request failed: %v", err) + } + + time.Sleep(500 * time.Millisecond) + }() + + select { + case <-done: + t.Log("handleConnection completed") + case <-time.After(5 * time.Second): + t.Fatal("handleConnection did not complete in time") + } + }) + + t.Run("SSH connection established with gRPC authorization", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockGrpcClient := &MockGRPCClient{} + mockPort := &MockPort{} + + serverConfig, _ := getTestSSHConfig() + + mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil) + mockConfig.On("Domain").Return("test.com") + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("SSHPort").Return("2200") + mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil) + mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + mockSessionRegistry.On("Remove", mock.Anything).Return(nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer func(listener net.Listener) { + err = listener.Close() + assert.NoError(t, err) + }(listener) + + serverAddr := listener.Addr().String() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: serverConfig, + grpcClient: mockGrpcClient, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + done := make(chan bool, 1) + + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + s.handleConnection(conn) + done <- true + }() + + time.Sleep(50 * time.Millisecond) + + clientConfig := &ssh.ClientConfig{ + User: "testuser", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + go func() { + client, err := ssh.Dial("tcp", serverAddr, clientConfig) + if err != nil { + t.Logf("Client dial failed: %v", err) + return + } + defer func(client *ssh.Client) { + err = client.Close() + assert.NoError(t, err) + }(client) + + type forwardPayload struct { + BindAddr string + BindPort uint32 + } + + payload := ssh.Marshal(forwardPayload{ + BindAddr: "localhost", + BindPort: 80, + }) + + _, _, err = client.SendRequest("tcpip-forward", true, payload) + if err != nil { + t.Logf("Forward request failed: %v", err) + } + + time.Sleep(500 * time.Millisecond) + }() + + select { + case <-done: + mockGrpcClient.AssertExpectations(t) + case <-time.After(5 * time.Second): + t.Fatal("handleConnection did not complete in time") + } + }) + + t.Run("SSH connection with gRPC authorization error", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockGrpcClient := &MockGRPCClient{} + mockPort := &MockPort{} + + serverConfig, _ := getTestSSHConfig() + + mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil) + mockConfig.On("Domain").Return("test.com") + mockConfig.On("Mode").Return(types.ServerModeNODE) + mockConfig.On("SSHPort").Return("2200") + mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil) + mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + mockSessionRegistry.On("Remove", mock.Anything).Return(nil) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer func(listener net.Listener) { + err = listener.Close() + assert.NoError(t, err) + }(listener) + + serverAddr := listener.Addr().String() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: serverConfig, + grpcClient: mockGrpcClient, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + done := make(chan bool, 1) + + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + s.handleConnection(conn) + done <- true + }() + + time.Sleep(50 * time.Millisecond) + + clientConfig := &ssh.ClientConfig{ + User: "testuser", + Auth: []ssh.AuthMethod{ssh.Password("password")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + + go func() { + client, err := ssh.Dial("tcp", serverAddr, clientConfig) + if err != nil { + t.Logf("Client dial failed: %v", err) + return + } + defer func(client *ssh.Client) { + _ = client.Close() + }(client) + + type forwardPayload struct { + BindAddr string + BindPort uint32 + } + + payload := ssh.Marshal(forwardPayload{ + BindAddr: "localhost", + BindPort: 8080, + }) + + _, _, err = client.SendRequest("tcpip-forward", true, payload) + if err != nil { + t.Logf("Forward request failed: %v", err) + } + + time.Sleep(500 * time.Millisecond) + }() + + select { + case <-done: + mockGrpcClient.AssertExpectations(t) + case <-time.After(5 * time.Second): + t.Fatal("handleConnection did not complete in time") + } + }) + + t.Run("connection cleanup on close", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + serverConfig, _ := getTestSSHConfig() + + serverConn, clientConn := net.Pipe() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: serverConfig, + grpcClient: nil, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + done := make(chan bool, 1) + + go func() { + s.handleConnection(serverConn) + done <- true + }() + + err := clientConn.Close() + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(1 * time.Second): + t.Fatal("handleConnection did not complete in time") + } + }) +} + +func TestIntegration(t *testing.T) { + t.Run("full server lifecycle", func(t *testing.T) { + mr := new(MockRandom) + mc := new(MockConfig) + mreg := new(MockSessionRegistry) + mg := new(MockGRPCClient) + mp := new(MockPort) + sc, _ := getTestSSHConfig() + + s, err := New(mr, mc, sc, mreg, mg, mp, "0") + assert.NoError(t, err) + assert.NotNil(t, s) + + go func() { + time.Sleep(100 * time.Millisecond) + err := s.Close() + assert.NoError(t, err) + }() + + s.Start() + }) + + t.Run("multiple connections", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + sshConfig, _ := getTestSSHConfig() + + conn1Server, conn1Client := net.Pipe() + conn2Server, conn2Client := net.Pipe() + + mockListener := &MockListener{} + mockListener.On("Accept").Return(conn1Server, nil).Once() + mockListener.On("Accept").Return(conn2Server, nil).Once() + mockListener.On("Accept").Return(nil, net.ErrClosed).Once() + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshListener: mockListener, + sshConfig: sshConfig, + grpcClient: nil, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + go s.Start() + + time.Sleep(50 * time.Millisecond) + _ = conn1Client.Close() + time.Sleep(50 * time.Millisecond) + _ = conn2Client.Close() + time.Sleep(100 * time.Millisecond) + + mockListener.AssertExpectations(t) + }) +} + +func TestErrorHandling(t *testing.T) { + t.Run("write error during SSH handshake", func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSessionRegistry := &MockSessionRegistry{} + mockPort := &MockPort{} + + sshConfig, _ := getTestSSHConfig() + + serverConn, clientConn := net.Pipe() + err := clientConn.Close() + assert.NoError(t, err) + + s := &server{ + randomizer: mockRandom, + config: mockConfig, + sshPort: "0", + sshConfig: sshConfig, + grpcClient: nil, + sessionRegistry: mockSessionRegistry, + portRegistry: mockPort, + } + + s.handleConnection(serverConn) + }) +} diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index c602565..03528c9 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -1,8 +1,7 @@ package forwarder import ( - "bytes" - "encoding/binary" + "context" "errors" "fmt" "io" @@ -10,7 +9,6 @@ import ( "net" "strconv" "sync" - "time" "tunnel_pls/internal/config" "tunnel_pls/session/slug" "tunnel_pls/types" @@ -26,9 +24,7 @@ type Forwarder interface { TunnelType() types.TunnelType ForwardedPort() uint16 HandleConnection(dst io.ReadWriter, src ssh.Channel) - CreateForwardedTCPIPPayload(origin net.Addr) []byte - OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) - WriteBadGatewayResponse(dst io.Writer) + OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) Close() error } type forwarder struct { @@ -50,19 +46,21 @@ func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder { bufferPool: sync.Pool{ New: func() interface{} { bufSize := config.BufferSize() - return make([]byte, bufSize) + buf := make([]byte, bufSize) + return &buf }, }, } } func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { - buf := f.bufferPool.Get().([]byte) + buf := f.bufferPool.Get().(*[]byte) defer f.bufferPool.Put(buf) - return io.CopyBuffer(dst, src, buf) + return io.CopyBuffer(dst, src, *buf) } -func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { +func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) { + payload := createForwardedTCPIPPayload(origin, f.forwardedPort) type channelResult struct { channel ssh.Channel reqs <-chan *ssh.Request @@ -74,13 +72,9 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) select { case resultChan <- channelResult{channel, reqs, err}: - default: + case <-ctx.Done(): if channel != nil { - err = channel.Close() - if err != nil { - log.Printf("Failed to close unused channel: %v", err) - return - } + _ = channel.Close() go ssh.DiscardRequests(reqs) } } @@ -89,8 +83,8 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s select { case result := <-resultChan: return result.channel, result.reqs, result.err - case <-time.After(5 * time.Second): - return nil, nil, errors.New("timeout opening forwarded-tcpip channel") + case <-ctx.Done(): + return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err()) } } @@ -119,10 +113,7 @@ func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { defer func() { - _, err := io.Copy(io.Discard, src) - if err != nil { - log.Printf("Failed to discard connection: %v", err) - } + _, _ = io.Copy(io.Discard, src) }() var wg sync.WaitGroup @@ -173,14 +164,6 @@ func (f *forwarder) Listener() net.Listener { return f.listener } -func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) { - _, err := dst.Write(types.BadGatewayResponse) - if err != nil { - log.Printf("failed to write Bad Gateway response: %v", err) - return - } -} - func (f *forwarder) Close() error { if f.Listener() != nil { return f.listener.Close() @@ -188,43 +171,21 @@ func (f *forwarder) Close() error { return nil } -func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { - var buf bytes.Buffer - - host, originPort := parseAddr(origin.String()) - - writeSSHString(&buf, "localhost") - err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort())) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil - } - - writeSSHString(&buf, host) - err = binary.Write(&buf, binary.BigEndian, uint32(originPort)) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return nil - } - - return buf.Bytes() -} - -func parseAddr(addr string) (string, uint16) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr) - return "0.0.0.0", uint16(0) - } +func createForwardedTCPIPPayload(origin net.Addr, destPort uint16) []byte { + host, portStr, _ := net.SplitHostPort(origin.String()) port, _ := strconv.Atoi(portStr) - return host, uint16(port) -} -func writeSSHString(buffer *bytes.Buffer, str string) { - err := binary.Write(buffer, binary.BigEndian, uint32(len(str))) - if err != nil { - log.Printf("Failed to write string to buffer: %v", err) - return + forwardPayload := struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 + }{ + DestAddr: "localhost", + DestPort: uint32(destPort), + OriginAddr: host, + OriginPort: uint32(port), } - buffer.WriteString(str) + + return ssh.Marshal(forwardPayload) } diff --git a/session/forwarder/forwarder_test.go b/session/forwarder/forwarder_test.go new file mode 100644 index 0000000..092d783 --- /dev/null +++ b/session/forwarder/forwarder_test.go @@ -0,0 +1,1806 @@ +package forwarder + +import ( + "bytes" + "context" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "testing" + "time" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +type mockConfig struct { + mock.Mock +} + +func (m *mockConfig) Domain() string { return m.Called().String(0) } +func (m *mockConfig) SSHPort() string { return m.Called().String(0) } +func (m *mockConfig) HTTPPort() string { return m.Called().String(0) } +func (m *mockConfig) HTTPSPort() string { return m.Called().String(0) } +func (m *mockConfig) KeyLoc() string { return m.Called().String(0) } +func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) } +func (m *mockConfig) TLSRedirect() bool { return m.Called().Bool(0) } +func (m *mockConfig) TLSStoragePath() string { return m.Called().String(0) } +func (m *mockConfig) ACMEEmail() string { return m.Called().String(0) } +func (m *mockConfig) CFAPIToken() string { return m.Called().String(0) } +func (m *mockConfig) ACMEStaging() bool { return m.Called().Bool(0) } +func (m *mockConfig) AllowedPortsStart() uint16 { return m.Called().Get(0).(uint16) } +func (m *mockConfig) AllowedPortsEnd() uint16 { return m.Called().Get(0).(uint16) } +func (m *mockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *mockConfig) HeaderSize() int { return m.Called().Int(0) } +func (m *mockConfig) PprofEnabled() bool { return m.Called().Bool(0) } +func (m *mockConfig) PprofPort() string { return m.Called().String(0) } +func (m *mockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } +func (m *mockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *mockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *mockConfig) NodeToken() string { return m.Called().String(0) } + +type mockConn struct { + mock.Mock +} + +func (c *mockConn) Close() error { return c.Called().Error(0) } +func (c *mockConn) User() string { return c.Called().String(0) } +func (c *mockConn) SessionID() []byte { return c.Called().Get(0).([]byte) } +func (c *mockConn) ClientVersion() []byte { return c.Called().Get(0).([]byte) } +func (c *mockConn) ServerVersion() []byte { return c.Called().Get(0).([]byte) } +func (c *mockConn) RemoteAddr() net.Addr { return c.Called().Get(0).(net.Addr) } +func (c *mockConn) LocalAddr() net.Addr { return c.Called().Get(0).(net.Addr) } +func (c *mockConn) SendRequest(s string, b bool, d []byte) (bool, []byte, error) { + args := c.Called(s, b, d) + return args.Bool(0), args.Get(1).([]byte), args.Error(2) +} +func (c *mockConn) Wait() error { return c.Called().Error(0) } + +func (c *mockConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { + args := c.Called(name, data) + return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) +} + +type testChannel struct { + mock.Mock + readBuf *syncBuffer + writeBuf *syncBuffer + closedWrite atomic.Bool +} + +func (c *testChannel) Read(b []byte) (int, error) { + return c.readBuf.Read(b) +} + +func (c *testChannel) Write(b []byte) (int, error) { + return c.writeBuf.Write(b) +} + +func (c *testChannel) Close() error { + return c.Called().Error(0) +} + +func (c *testChannel) CloseWrite() error { + c.closedWrite.Store(true) + return c.writeBuf.Close() +} + +func (c *testChannel) Stderr() io.ReadWriter { + return c.Called().Get(0).(io.ReadWriter) +} + +func (c *testChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + args := c.Called(name, wantReply, payload) + return args.Bool(0), args.Error(1) +} + +func (c *testChannel) AckRequest(ok bool, payload []byte) error { + return c.Called(ok, payload).Error(0) +} + +type syncBuffer struct { + mu sync.Mutex + buf []byte + closed bool + cond *sync.Cond +} + +func newSyncBuffer() *syncBuffer { + sb := &syncBuffer{} + sb.cond = sync.NewCond(&sb.mu) + return sb +} + +func (sb *syncBuffer) Write(p []byte) (int, error) { + sb.mu.Lock() + defer sb.mu.Unlock() + if sb.closed { + return 0, io.ErrClosedPipe + } + sb.buf = append(sb.buf, p...) + sb.cond.Broadcast() + return len(p), nil +} + +func (sb *syncBuffer) Read(p []byte) (int, error) { + sb.mu.Lock() + defer sb.mu.Unlock() + + for len(sb.buf) == 0 { + if sb.closed { + return 0, io.EOF + } + sb.cond.Wait() + } + + n := copy(p, sb.buf) + sb.buf = sb.buf[n:] + return n, nil +} + +func (sb *syncBuffer) Close() error { + sb.mu.Lock() + defer sb.mu.Unlock() + sb.closed = true + sb.cond.Broadcast() + return nil +} + +func newChannelPair() (*testChannel, *testChannelPeer) { + peerToChBuf := newSyncBuffer() + chToPeerBuf := newSyncBuffer() + + channel := &testChannel{ + readBuf: peerToChBuf, + writeBuf: chToPeerBuf, + } + + peer := &testChannelPeer{ + readBuf: chToPeerBuf, + writeBuf: peerToChBuf, + } + + channel.On("Close").Return(nil).Maybe() + + return channel, peer +} + +type testChannelPeer struct { + readBuf *syncBuffer + writeBuf *syncBuffer +} + +func (p *testChannelPeer) Read(b []byte) (int, error) { + return p.readBuf.Read(b) +} + +func (p *testChannelPeer) Write(b []byte) (int, error) { + return p.writeBuf.Write(b) +} + +func (p *testChannelPeer) CloseWrite() error { + return p.writeBuf.Close() +} + +func newPipePair() (*pipeConn, *pipeConn) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + conn1 := &pipeConn{ + reader: r1, + writer: w2, + } + + conn2 := &pipeConn{ + reader: r2, + writer: w1, + } + + return conn1, conn2 +} + +type pipeConn struct { + reader *io.PipeReader + writer *io.PipeWriter +} + +func (p *pipeConn) Read(b []byte) (int, error) { + return p.reader.Read(b) +} + +func (p *pipeConn) Write(b []byte) (int, error) { + return p.writer.Write(b) +} + +func (p *pipeConn) Close() error { + err := p.reader.Close() + if err != nil { + return err + } + err = p.writer.Close() + if err != nil { + return err + } + return nil +} + +func (p *pipeConn) CloseWrite() error { + return p.writer.Close() +} + +func (p *pipeConn) LocalAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (p *pipeConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0} +} + +func (p *pipeConn) SetDeadline(t time.Time) error { return nil } +func (p *pipeConn) SetReadDeadline(t time.Time) error { return nil } +func (p *pipeConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestNew(t *testing.T) { + tests := []struct { + name string + bufferSize int + wantBufLen int + }{ + { + name: "default buffer size", + bufferSize: 16, + wantBufLen: 16, + }, + { + name: "custom buffer size", + bufferSize: 32, + wantBufLen: 32, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(tt.bufferSize).Maybe() + s := slug.New() + conn := &mockConn{} + + forwarder := New(cfg, s, conn).(*forwarder) + + buf := forwarder.bufferPool.Get().(*[]byte) + require.Len(t, *buf, tt.wantBufLen) + forwarder.bufferPool.Put(buf) + + assert.Equal(t, types.TunnelTypeUNKNOWN, forwarder.TunnelType()) + assert.Equal(t, uint16(0), forwarder.ForwardedPort()) + assert.Equal(t, conn, forwarder.conn) + assert.Equal(t, s, forwarder.slug) + cfg.AssertExpectations(t) + }) + } +} + +func TestHandleConnection(t *testing.T) { + tests := []struct { + name string + bufferSize int + messageToDst []byte + messageToSrc []byte + }{ + { + name: "small messages", + bufferSize: 4, + messageToDst: []byte("hi"), + messageToSrc: []byte("yo"), + }, + { + name: "medium messages", + bufferSize: 8, + messageToDst: []byte("hello"), + messageToSrc: []byte("world"), + }, + { + name: "larger messages", + bufferSize: 16, + messageToDst: []byte("I love femboy"), + messageToSrc: []byte("mee too"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(tt.bufferSize).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + channel, channelPeer := newChannelPair() + dstEndpoint, dstPeer := newPipePair() + + done := make(chan struct{}) + go func() { + forwarder.HandleConnection(dstEndpoint, channel) + close(done) + }() + + readDst := make(chan struct { + data []byte + err error + }, 1) + go func() { + buf := make([]byte, len(tt.messageToDst)) + n, err := io.ReadFull(dstPeer, buf) + readDst <- struct { + data []byte + err error + }{data: buf[:n], err: err} + }() + + _, err := channelPeer.Write(tt.messageToDst) + require.NoError(t, err) + + dstResult := <-readDst + require.NoError(t, dstResult.err) + assert.Equal(t, tt.messageToDst, dstResult.data) + + readSrc := make(chan struct { + data []byte + err error + }, 1) + go func() { + buf := make([]byte, len(tt.messageToSrc)) + n, err := io.ReadFull(channelPeer, buf) + readSrc <- struct { + data []byte + err error + }{data: buf[:n], err: err} + }() + + _, err = dstPeer.Write(tt.messageToSrc) + require.NoError(t, err) + + srcResult := <-readSrc + require.NoError(t, srcResult.err) + assert.Equal(t, tt.messageToSrc, srcResult.data) + + require.NoError(t, channelPeer.CloseWrite()) + require.NoError(t, dstPeer.CloseWrite()) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("HandleConnection did not complete") + } + assert.True(t, channel.closedWrite.Load()) + cfg.AssertExpectations(t) + }) + } +} + +func TestHandleConnection_Error(t *testing.T) { + tests := []struct { + name string + bufferSize int + messageToDst []byte + messageToSrc []byte + }{ + { + name: "small messages", + bufferSize: 4, + messageToDst: []byte("hi"), + messageToSrc: []byte("yo"), + }, + { + name: "medium messages", + bufferSize: 8, + messageToDst: []byte("hello"), + messageToSrc: []byte("world"), + }, + { + name: "larger messages", + bufferSize: 16, + messageToDst: []byte("I love femboy"), + messageToSrc: []byte("mee too"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(tt.bufferSize).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + channel, _ := newChannelPair() + dstEndpoint, _ := newPipePair() + + go func() { + forwarder.HandleConnection(dstEndpoint, channel) + }() + + err := dstEndpoint.Close() + assert.NoError(t, err) + cfg.AssertExpectations(t) + }) + } +} + +func TestOpenForwardedChannel(t *testing.T) { + tests := []struct { + name string + forwardedPort uint16 + originIP string + originPort int + wantDestAddr string + wantDestPort uint32 + wantOrigAddr string + wantOrigPort uint32 + }{ + { + name: "localhost origin", + forwardedPort: 2222, + originIP: "127.0.0.1", + originPort: 9000, + wantDestAddr: "localhost", + wantDestPort: 2222, + wantOrigAddr: "127.0.0.1", + wantOrigPort: 9000, + }, + { + name: "remote origin", + forwardedPort: 8080, + originIP: "192.168.1.100", + originPort: 5000, + wantDestAddr: "localhost", + wantDestPort: 8080, + wantOrigAddr: "192.168.1.100", + wantOrigPort: 5000, + }, + { + name: "different port", + forwardedPort: 3000, + originIP: "10.0.0.1", + originPort: 7777, + wantDestAddr: "localhost", + wantDestPort: 3000, + wantOrigAddr: "10.0.0.1", + wantOrigPort: 7777, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(8).Maybe() + channel := &testChannel{ + readBuf: newSyncBuffer(), + writeBuf: newSyncBuffer(), + } + requests := make(chan *ssh.Request) + + var capturedData []byte + conn := &mockConn{} + conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Run(func(args mock.Arguments) { + data := args.Get(1).([]byte) + capturedData = make([]byte, len(data)) + copy(capturedData, data) + }).Return(channel, (<-chan *ssh.Request)(requests), nil) + + forwarder := New(cfg, slug.New(), conn).(*forwarder) + forwarder.SetForwardedPort(tt.forwardedPort) + + origin := &net.TCPAddr{IP: net.ParseIP(tt.originIP), Port: tt.originPort} + ch, reqs, err := forwarder.OpenForwardedChannel(context.Background(), origin) + require.NoError(t, err) + assert.Same(t, channel, ch) + assert.NotNil(t, reqs) + + var payload struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 + } + err = ssh.Unmarshal(capturedData, &payload) + assert.NoError(t, err) + assert.Equal(t, tt.wantDestAddr, payload.DestAddr) + assert.Equal(t, tt.wantDestPort, payload.DestPort) + assert.Equal(t, tt.wantOrigAddr, payload.OriginAddr) + assert.Equal(t, tt.wantOrigPort, payload.OriginPort) + + conn.AssertExpectations(t) + cfg.AssertExpectations(t) + }) + } +} + +func TestOpenForwardedChannelContextCancellation(t *testing.T) { + tests := []struct { + name string + cancelBefore bool + cancelDuring bool + wantErr bool + wantErrType error + }{ + { + name: "cancel during open", + cancelBefore: false, + cancelDuring: true, + wantErr: true, + wantErrType: context.Canceled, + }, + { + name: "cancel before open", + cancelBefore: true, + cancelDuring: false, + wantErr: true, + wantErrType: context.Canceled, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(8).Maybe() + channel := &testChannel{ + readBuf: newSyncBuffer(), + writeBuf: newSyncBuffer(), + } + channel.On("Close").Return(nil) + requests := make(chan *ssh.Request) + + openChannelCalled := make(chan struct{}) + openChannelBlock := make(chan struct{}) + + conn := &mockConn{} + conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Run(func(args mock.Arguments) { + close(openChannelCalled) + <-openChannelBlock + }).Return(channel, (<-chan *ssh.Request)(requests), nil).Maybe() + + forwarder := New(cfg, slug.New(), conn).(*forwarder) + forwarder.SetForwardedPort(8080) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + if tt.cancelBefore { + cancel() + } + + var ( + openedChannel ssh.Channel + openedReqs <-chan *ssh.Request + openErr error + ) + + done := make(chan struct{}) + go func() { + origin := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 7000} + openedChannel, openedReqs, openErr = forwarder.OpenForwardedChannel(ctx, origin) + close(done) + }() + + if tt.cancelDuring { + <-openChannelCalled + cancel() + } + close(openChannelBlock) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("OpenForwardedChannel did not return after cancellation") + } + + if tt.wantErr { + require.Error(t, openErr) + assert.True(t, errors.Is(openErr, tt.wantErrType)) + assert.Nil(t, openedChannel) + assert.Nil(t, openedReqs) + } else { + require.NoError(t, openErr) + assert.NotNil(t, openedChannel) + assert.NotNil(t, openedReqs) + } + + conn.AssertExpectations(t) + cfg.AssertExpectations(t) + }) + } +} + +func TestCreateForwardedTCPIPPayload(t *testing.T) { + tests := []struct { + name string + originIP string + originPort int + forwardedPort uint16 + wantDestAddr string + wantDestPort uint32 + wantOriginAddr string + wantOriginPort uint32 + }{ + { + name: "standard case", + originIP: "192.0.2.10", + originPort: 5050, + forwardedPort: 8080, + wantDestAddr: "localhost", + wantDestPort: 8080, + wantOriginAddr: "192.0.2.10", + wantOriginPort: 5050, + }, + { + name: "localhost origin", + originIP: "127.0.0.1", + originPort: 3000, + forwardedPort: 9000, + wantDestAddr: "localhost", + wantDestPort: 9000, + wantOriginAddr: "127.0.0.1", + wantOriginPort: 3000, + }, + { + name: "high port numbers", + originIP: "10.0.0.1", + originPort: 65535, + forwardedPort: 65534, + wantDestAddr: "localhost", + wantDestPort: 65534, + wantOriginAddr: "10.0.0.1", + wantOriginPort: 65535, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + origin := &net.TCPAddr{IP: net.ParseIP(tt.originIP), Port: tt.originPort} + payload := createForwardedTCPIPPayload(origin, tt.forwardedPort) + + var decoded struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 + } + + err := ssh.Unmarshal(payload, &decoded) + assert.NoError(t, err) + assert.Equal(t, tt.wantDestAddr, decoded.DestAddr) + assert.Equal(t, tt.wantDestPort, decoded.DestPort) + assert.Equal(t, tt.wantOriginAddr, decoded.OriginAddr) + assert.Equal(t, tt.wantOriginPort, decoded.OriginPort) + }) + } +} + +type mockReader struct { + mock.Mock +} + +func (m *mockReader) Read(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +type mockWriter struct { + mock.Mock +} + +func (m *mockWriter) Write(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *mockWriter) CloseWrite() error { + return m.Called().Error(0) +} + +type mockWriteCloser struct { + mock.Mock +} + +func (m *mockWriteCloser) Write(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *mockWriteCloser) Close() error { + return m.Called().Error(0) +} + +func TestCopyAndClose(t *testing.T) { + tests := []struct { + name string + setupSrc func() io.Reader + setupDst func() io.Writer + direction string + wantErr bool + wantErrMsg string + checkErrTypes []error + }{ + { + name: "successful copy with EOF", + setupSrc: func() io.Reader { + r := &mockReader{} + r.On("Read", mock.Anything).Return(5, nil).Once() + r.On("Read", mock.Anything).Return(0, io.EOF).Once() + return r + }, + setupDst: func() io.Writer { + w := &mockWriter{} + w.On("Write", mock.Anything).Return(5, nil).Once() + w.On("CloseWrite").Return(nil).Once() + return w + }, + direction: "src->dst", + wantErr: false, + }, + { + name: "copy error - not EOF or ErrClosed", + setupSrc: func() io.Reader { + r := &mockReader{} + customErr := errors.New("custom read error") + r.On("Read", mock.Anything).Return(0, customErr).Once() + return r + }, + setupDst: func() io.Writer { + w := &mockWriter{} + w.On("CloseWrite").Return(nil).Once() + return w + }, + direction: "src->dst", + wantErr: true, + wantErrMsg: "copy error (src->dst)", + }, + { + name: "copy error - ErrClosed should be ignored", + setupSrc: func() io.Reader { + r := &mockReader{} + r.On("Read", mock.Anything).Return(0, net.ErrClosed).Once() + return r + }, + setupDst: func() io.Writer { + w := &mockWriter{} + w.On("CloseWrite").Return(nil).Once() + return w + }, + direction: "src->dst", + wantErr: false, + }, + { + name: "close writer error - not EOF", + setupSrc: func() io.Reader { + r := &mockReader{} + r.On("Read", mock.Anything).Return(0, io.EOF).Once() + return r + }, + setupDst: func() io.Writer { + w := &mockWriter{} + closeErr := errors.New("close error") + w.On("CloseWrite").Return(closeErr).Once() + return w + }, + direction: "src->dst", + wantErr: true, + wantErrMsg: "close stream error (src->dst)", + }, + { + name: "close writer error - EOF should be ignored", + setupSrc: func() io.Reader { + r := &mockReader{} + r.On("Read", mock.Anything).Return(0, io.EOF).Once() + return r + }, + setupDst: func() io.Writer { + w := &mockWriter{} + w.On("CloseWrite").Return(io.EOF).Once() + return w + }, + direction: "src->dst", + wantErr: false, + }, + { + name: "both copy and close errors", + setupSrc: func() io.Reader { + r := &mockReader{} + copyErr := errors.New("copy error") + r.On("Read", mock.Anything).Return(0, copyErr).Once() + return r + }, + setupDst: func() io.Writer { + w := &mockWriter{} + closeErr := errors.New("close error") + w.On("CloseWrite").Return(closeErr).Once() + return w + }, + direction: "src->dst", + wantErr: true, + wantErrMsg: "copy error (src->dst)", + }, + { + name: "successful copy with WriteCloser", + setupSrc: func() io.Reader { + r := &mockReader{} + r.On("Read", mock.Anything).Return(5, nil).Once() + r.On("Read", mock.Anything).Return(0, io.EOF).Once() + return r + }, + setupDst: func() io.Writer { + w := &mockWriteCloser{} + w.On("Write", mock.Anything).Return(5, nil).Once() + w.On("Close").Return(nil).Once() + return w + }, + direction: "dst->src", + wantErr: false, + }, + { + name: "WriteCloser close error", + setupSrc: func() io.Reader { + r := &mockReader{} + r.On("Read", mock.Anything).Return(0, io.EOF).Once() + return r + }, + setupDst: func() io.Writer { + w := &mockWriteCloser{} + closeErr := errors.New("writeCloser close error") + w.On("Close").Return(closeErr).Once() + return w + }, + direction: "dst->src", + wantErr: true, + wantErrMsg: "close stream error (dst->src)", + }, + { + name: "copy with multiple reads before EOF", + setupSrc: func() io.Reader { + r := &mockReader{} + r.On("Read", mock.Anything).Return(10, nil).Once() + r.On("Read", mock.Anything).Return(15, nil).Once() + r.On("Read", mock.Anything).Return(5, nil).Once() + r.On("Read", mock.Anything).Return(0, io.EOF).Once() + return r + }, + setupDst: func() io.Writer { + w := &mockWriter{} + w.On("Write", mock.Anything).Return(10, nil).Once() + w.On("Write", mock.Anything).Return(15, nil).Once() + w.On("Write", mock.Anything).Return(5, nil).Once() + w.On("CloseWrite").Return(nil).Once() + return w + }, + direction: "src->dst", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(32).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + src := tt.setupSrc() + dst := tt.setupDst() + + err := forwarder.copyAndClose(dst, src, tt.direction) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErrMsg) + } else { + assert.NoError(t, err) + } + + if mr, ok := src.(*mockReader); ok { + mr.AssertExpectations(t) + } + if mw, ok := dst.(*mockWriter); ok { + mw.AssertExpectations(t) + } + if mwc, ok := dst.(*mockWriteCloser); ok { + mwc.AssertExpectations(t) + } + cfg.AssertExpectations(t) + }) + } +} + +func TestCopyAndCloseJoinedErrors(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(32).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + src := &mockReader{} + copyErr := errors.New("copy failed") + src.On("Read", mock.Anything).Return(0, copyErr).Once() + + dst := &mockWriter{} + closeErr := errors.New("close failed") + dst.On("CloseWrite").Return(closeErr).Once() + + err := forwarder.copyAndClose(dst, src, "test") + + require.Error(t, err) + + assert.Contains(t, err.Error(), "copy error (test)") + assert.Contains(t, err.Error(), "close stream error (test)") + assert.Contains(t, err.Error(), "copy failed") + assert.Contains(t, err.Error(), "close failed") + + src.AssertExpectations(t) + dst.AssertExpectations(t) + cfg.AssertExpectations(t) +} + +func TestCopyWithBuffer(t *testing.T) { + tests := []struct { + name string + bufferSize int + setupSrc func() io.Reader + setupDst func() io.Writer + wantBytesCount int64 + wantErr bool + wantErrType error + }{ + { + name: "successful copy small data", + bufferSize: 16, + setupSrc: func() io.Reader { + return io.NopCloser(bytes.NewReader([]byte("hello world"))) + }, + setupDst: func() io.Writer { + return &bytes.Buffer{} + }, + wantBytesCount: 11, + wantErr: false, + }, + { + name: "successful copy large data", + bufferSize: 8, + setupSrc: func() io.Reader { + data := make([]byte, 1024) + for i := range data { + data[i] = byte(i % 256) + } + return io.NopCloser(bytes.NewReader(data)) + }, + setupDst: func() io.Writer { + return &bytes.Buffer{} + }, + wantBytesCount: 1024, + wantErr: false, + }, + { + name: "empty data", + bufferSize: 16, + setupSrc: func() io.Reader { + return io.NopCloser(bytes.NewReader([]byte{})) + }, + setupDst: func() io.Writer { + return &bytes.Buffer{} + }, + wantBytesCount: 0, + wantErr: false, + }, + { + name: "read error", + bufferSize: 16, + setupSrc: func() io.Reader { + r := &mockReader{} + r.On("Read", mock.Anything).Return(0, errors.New("read error")).Once() + return r + }, + setupDst: func() io.Writer { + return &bytes.Buffer{} + }, + wantBytesCount: 0, + wantErr: true, + }, + { + name: "write error", + bufferSize: 16, + setupSrc: func() io.Reader { + return io.NopCloser(bytes.NewReader([]byte("test data"))) + }, + setupDst: func() io.Writer { + w := &mockWriter{} + w.On("Write", mock.Anything).Return(0, errors.New("write error")).Once() + return w + }, + wantBytesCount: 0, + wantErr: true, + }, + { + name: "partial write continues", + bufferSize: 16, + setupSrc: func() io.Reader { + return io.NopCloser(bytes.NewReader([]byte("testing"))) + }, + setupDst: func() io.Writer { + buf := &bytes.Buffer{} + return buf + }, + wantBytesCount: 7, + wantErr: false, + }, + { + name: "multiple buffer fills", + bufferSize: 4, + setupSrc: func() io.Reader { + return io.NopCloser(bytes.NewReader([]byte("this is a longer message"))) + }, + setupDst: func() io.Writer { + return &bytes.Buffer{} + }, + wantBytesCount: 24, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(tt.bufferSize).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + src := tt.setupSrc() + dst := tt.setupDst() + + n, err := forwarder.copyWithBuffer(dst, src) + + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantBytesCount, n) + } + + if buf, ok := dst.(*bytes.Buffer); ok && !tt.wantErr { + assert.Equal(t, tt.wantBytesCount, int64(buf.Len())) + } + + if mr, ok := src.(*mockReader); ok { + mr.AssertExpectations(t) + } + if mw, ok := dst.(*mockWriter); ok { + mw.AssertExpectations(t) + } + cfg.AssertExpectations(t) + }) + } +} + +func TestCopyWithBufferReusesBuffer(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(16).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + buf1 := forwarder.bufferPool.Get().(*[]byte) + initialPtr := buf1 + + forwarder.bufferPool.Put(buf1) + + src := io.NopCloser(bytes.NewReader([]byte("test"))) + dst := &bytes.Buffer{} + _, err := forwarder.copyWithBuffer(dst, src) + require.NoError(t, err) + + buf2 := forwarder.bufferPool.Get().(*[]byte) + secondPtr := buf2 + + forwarder.bufferPool.Put(buf2) + + assert.Equal(t, initialPtr, secondPtr, "Buffers should be the same pointer") + + assert.Len(t, *buf2, 16) + assert.Len(t, *buf1, 16) + + _ = initialPtr + _ = secondPtr + + cfg.AssertExpectations(t) +} + +func TestSetType(t *testing.T) { + tests := []struct { + name string + tunnelType types.TunnelType + }{ + { + name: "set to HTTP", + tunnelType: types.TunnelTypeHTTP, + }, + { + name: "set to TCP", + tunnelType: types.TunnelTypeTCP, + }, + { + name: "set to UNKNOWN", + tunnelType: types.TunnelTypeUNKNOWN, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(16).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + assert.Equal(t, types.TunnelTypeUNKNOWN, forwarder.TunnelType()) + + forwarder.SetType(tt.tunnelType) + + assert.Equal(t, tt.tunnelType, forwarder.TunnelType()) + cfg.AssertExpectations(t) + }) + } +} + +func TestTunnelType(t *testing.T) { + tests := []struct { + name string + tunnelType types.TunnelType + }{ + { + name: "get HTTP type", + tunnelType: types.TunnelTypeHTTP, + }, + { + name: "get TCP type", + tunnelType: types.TunnelTypeTCP, + }, + { + name: "get UNKNOWN type", + tunnelType: types.TunnelTypeUNKNOWN, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(16).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + forwarder.SetType(tt.tunnelType) + result := forwarder.TunnelType() + + assert.Equal(t, tt.tunnelType, result) + cfg.AssertExpectations(t) + }) + } +} + +func TestSetForwardedPort(t *testing.T) { + tests := []struct { + name string + port uint16 + }{ + { + name: "set standard port", + port: 8080, + }, + { + name: "set low port", + port: 80, + }, + { + name: "set high port", + port: 65535, + }, + { + name: "set zero port", + port: 0, + }, + { + name: "set custom port", + port: 3000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(16).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + assert.Equal(t, uint16(0), forwarder.ForwardedPort()) + + forwarder.SetForwardedPort(tt.port) + + assert.Equal(t, tt.port, forwarder.ForwardedPort()) + cfg.AssertExpectations(t) + }) + } +} + +func TestForwardedPort(t *testing.T) { + tests := []struct { + name string + port uint16 + }{ + { + name: "get default port", + port: 0, + }, + { + name: "get standard port", + port: 8080, + }, + { + name: "get high port", + port: 65535, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(16).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + if tt.port != 0 { + forwarder.SetForwardedPort(tt.port) + } + + result := forwarder.ForwardedPort() + assert.Equal(t, tt.port, result) + cfg.AssertExpectations(t) + }) + } +} + +func TestSetListener(t *testing.T) { + tests := []struct { + name string + setupListener func() net.Listener + }{ + { + name: "set TCP listener", + setupListener: func() net.Listener { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + return listener + }, + }, + { + name: "set nil listener", + setupListener: func() net.Listener { + return nil + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(16).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + listener := tt.setupListener() + if listener != nil { + defer func(listener net.Listener) { + err := listener.Close() + assert.NoError(t, err) + }(listener) + } + + assert.Nil(t, forwarder.Listener()) + + forwarder.SetListener(listener) + + assert.Equal(t, listener, forwarder.Listener()) + cfg.AssertExpectations(t) + }) + } +} + +func TestListener(t *testing.T) { + tests := []struct { + name string + setupListener func() net.Listener + }{ + { + name: "get nil listener", + setupListener: func() net.Listener { + return nil + }, + }, + { + name: "get TCP listener", + setupListener: func() net.Listener { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + return listener + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(16).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + listener := tt.setupListener() + if listener != nil { + defer func(listener net.Listener) { + err := listener.Close() + assert.NoError(t, err) + }(listener) + forwarder.SetListener(listener) + } + + result := forwarder.Listener() + assert.Equal(t, listener, result) + cfg.AssertExpectations(t) + }) + } +} + +func TestClose(t *testing.T) { + tests := []struct { + name string + setupListener func() net.Listener + wantErr bool + }{ + { + name: "close with nil listener", + setupListener: func() net.Listener { + return nil + }, + wantErr: false, + }, + { + name: "close with active listener", + setupListener: func() net.Listener { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + return listener + }, + wantErr: false, + }, + { + name: "close already closed listener", + setupListener: func() net.Listener { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + err = listener.Close() + assert.NoError(t, err) + return listener + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(16).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + listener := tt.setupListener() + if listener != nil { + forwarder.SetListener(listener) + } + + err := forwarder.Close() + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + cfg.AssertExpectations(t) + }) + } +} + +func TestCloseWriter(t *testing.T) { + tests := []struct { + name string + setup func() io.Writer + wantErr bool + }{ + { + name: "close writer with CloseWrite method", + setup: func() io.Writer { + w := &mockWriter{} + w.On("CloseWrite").Return(nil).Once() + return w + }, + wantErr: false, + }, + { + name: "close writer with CloseWrite error", + setup: func() io.Writer { + w := &mockWriter{} + w.On("CloseWrite").Return(errors.New("close write error")).Once() + return w + }, + wantErr: true, + }, + { + name: "close WriteCloser", + setup: func() io.Writer { + w := &mockWriteCloser{} + w.On("Close").Return(nil).Once() + return w + }, + wantErr: false, + }, + { + name: "close WriteCloser with error", + setup: func() io.Writer { + w := &mockWriteCloser{} + w.On("Close").Return(errors.New("close error")).Once() + return w + }, + wantErr: true, + }, + { + name: "close plain writer (no close method)", + setup: func() io.Writer { + return &bytes.Buffer{} + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + writer := tt.setup() + + err := closeWriter(writer) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + + if mw, ok := writer.(*mockWriter); ok { + mw.AssertExpectations(t) + } + if mwc, ok := writer.(*mockWriteCloser); ok { + mwc.AssertExpectations(t) + } + }) + } +} + +func TestHandleConnectionWithErrors(t *testing.T) { + tests := []struct { + name string + bufferSize int + setupChannel func() (*testChannel, *testChannelPeer) + setupDst func() (net.Conn, *pipeConn) + simulateErr func(channel *testChannelPeer, dst *pipeConn) + }{ + { + name: "handle read error from channel", + bufferSize: 16, + setupChannel: func() (*testChannel, *testChannelPeer) { + return newChannelPair() + }, + setupDst: func() (net.Conn, *pipeConn) { + return newPipePair() + }, + simulateErr: func(channel *testChannelPeer, dst *pipeConn) { + err := channel.CloseWrite() + assert.NoError(t, err) + err = dst.CloseWrite() + assert.NoError(t, err) + }, + }, + { + name: "handle write error to destination", + bufferSize: 16, + setupChannel: func() (*testChannel, *testChannelPeer) { + return newChannelPair() + }, + setupDst: func() (net.Conn, *pipeConn) { + return newPipePair() + }, + simulateErr: func(channel *testChannelPeer, dst *pipeConn) { + err := dst.Close() + assert.NoError(t, err) + time.Sleep(10 * time.Millisecond) + write, err := channel.Write([]byte("test")) + assert.NotZero(t, write) + assert.NoError(t, err) + err = channel.CloseWrite() + assert.NoError(t, err) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(tt.bufferSize).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + channel, channelPeer := tt.setupChannel() + dstEndpoint, dstPeer := tt.setupDst() + + done := make(chan struct{}) + go func() { + forwarder.HandleConnection(dstEndpoint, channel) + close(done) + }() + + tt.simulateErr(channelPeer, dstPeer) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("HandleConnection did not complete") + } + + cfg.AssertExpectations(t) + }) + } +} + +func TestHandleConnectionDiscardOnExit(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(16).Maybe() + forwarder := New(cfg, slug.New(), nil).(*forwarder) + + channel, channelPeer := newChannelPair() + dstEndpoint, dstPeer := newPipePair() + + done := make(chan struct{}) + go func() { + forwarder.HandleConnection(dstEndpoint, channel) + close(done) + }() + + _, err := channelPeer.Write([]byte("test data")) + require.NoError(t, err) + require.NoError(t, channelPeer.CloseWrite()) + require.NoError(t, dstPeer.Close()) + + select { + case <-done: + case <-time.After(10 * time.Second): + t.Fatal("HandleConnection did not complete") + } + + cfg.AssertExpectations(t) +} + +func TestOpenForwardedChannelSuccess(t *testing.T) { + tests := []struct { + name string + forwardedPort uint16 + originAddr string + originPort int + }{ + { + name: "open channel standard port", + forwardedPort: 8080, + originAddr: "127.0.0.1", + originPort: 9000, + }, + { + name: "open channel high port", + forwardedPort: 65534, + originAddr: "192.168.1.100", + originPort: 5000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(8).Maybe() + channel := &testChannel{ + readBuf: newSyncBuffer(), + writeBuf: newSyncBuffer(), + } + requests := make(chan *ssh.Request) + + conn := &mockConn{} + conn.On("OpenChannel", "forwarded-tcpip", mock.Anything). + Return(channel, (<-chan *ssh.Request)(requests), nil) + + forwarder := New(cfg, slug.New(), conn).(*forwarder) + forwarder.SetForwardedPort(tt.forwardedPort) + + origin := &net.TCPAddr{IP: net.ParseIP(tt.originAddr), Port: tt.originPort} + ch, reqs, err := forwarder.OpenForwardedChannel(context.Background(), origin) + + require.NoError(t, err) + assert.NotNil(t, ch) + assert.NotNil(t, reqs) + + conn.AssertExpectations(t) + cfg.AssertExpectations(t) + }) + } +} + +func TestOpenForwardedChannelError(t *testing.T) { + tests := []struct { + name string + setupConn func() *mockConn + wantErr bool + wantErrMsg string + }{ + { + name: "open channel returns error", + setupConn: func() *mockConn { + conn := &mockConn{} + conn.On("OpenChannel", "forwarded-tcpip", mock.Anything). + Return((*testChannel)(nil), (<-chan *ssh.Request)(nil), errors.New("channel open failed")) + return conn + }, + wantErr: true, + wantErrMsg: "channel open failed", + }, + { + name: "open channel with nil channel", + setupConn: func() *mockConn { + conn := &mockConn{} + conn.On("OpenChannel", "forwarded-tcpip", mock.Anything). + Return((*testChannel)(nil), (<-chan *ssh.Request)(nil), nil) + return conn + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(8).Maybe() + + conn := tt.setupConn() + forwarder := New(cfg, slug.New(), conn).(*forwarder) + forwarder.SetForwardedPort(8080) + + origin := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 9000} + _, _, err := forwarder.OpenForwardedChannel(context.Background(), origin) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErrMsg) + } else { + assert.NoError(t, err) + } + + conn.AssertExpectations(t) + cfg.AssertExpectations(t) + }) + } +} + +func TestOpenForwardedChannelContextCancelledDuringOpen(t *testing.T) { + cfg := &mockConfig{} + cfg.On("BufferSize").Return(8).Maybe() + + channel := &testChannel{ + readBuf: newSyncBuffer(), + writeBuf: newSyncBuffer(), + } + channel.On("Close").Return(nil).Maybe() + + requests := make(chan *ssh.Request) + + openChannelStarted := make(chan struct{}) + openChannelBlock := make(chan struct{}) + + conn := &mockConn{} + conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Run(func(args mock.Arguments) { + close(openChannelStarted) + <-openChannelBlock + }).Return(channel, (<-chan *ssh.Request)(requests), nil) + + forwarder := New(cfg, slug.New(), conn).(*forwarder) + forwarder.SetForwardedPort(8080) + + ctx, cancel := context.WithCancel(context.Background()) + + resultChan := make(chan error, 1) + go func() { + origin := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 7000} + _, _, err := forwarder.OpenForwardedChannel(ctx, origin) + resultChan <- err + }() + + <-openChannelStarted + + cancel() + + close(openChannelBlock) + + select { + case err := <-resultChan: + require.Error(t, err) + assert.Contains(t, err.Error(), "context cancelled") + case <-time.After(2 * time.Second): + t.Fatal("OpenForwardedChannel did not return") + } + + time.Sleep(50 * time.Millisecond) + + conn.AssertExpectations(t) + cfg.AssertExpectations(t) + channel.AssertExpectations(t) +} + +func TestCreateForwardedTCPIPPayloadEdgeCases(t *testing.T) { + tests := []struct { + name string + originAddr string + destPort uint16 + wantDestAddr string + wantDestPort uint32 + }{ + { + name: "IPv4 localhost", + originAddr: "127.0.0.1:5000", + destPort: 8080, + wantDestAddr: "localhost", + wantDestPort: 8080, + }, + { + name: "IPv6 address", + originAddr: "[::1]:3000", + destPort: 9000, + wantDestAddr: "localhost", + wantDestPort: 9000, + }, + { + name: "private network", + originAddr: "192.168.1.1:12345", + destPort: 443, + wantDestAddr: "localhost", + wantDestPort: 443, + }, + { + name: "port 1", + originAddr: "10.0.0.1:1", + destPort: 1, + wantDestAddr: "localhost", + wantDestPort: 1, + }, + { + name: "max port", + originAddr: "172.16.0.1:65535", + destPort: 65535, + wantDestAddr: "localhost", + wantDestPort: 65535, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, err := net.ResolveTCPAddr("tcp", tt.originAddr) + require.NoError(t, err) + + payload := createForwardedTCPIPPayload(addr, tt.destPort) + + var decoded struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 + } + + err = ssh.Unmarshal(payload, &decoded) + require.NoError(t, err) + + assert.Equal(t, tt.wantDestAddr, decoded.DestAddr) + assert.Equal(t, tt.wantDestPort, decoded.DestPort) + }) + } +} diff --git a/session/interaction/commands.go b/session/interaction/commands.go index e884aeb..5e24368 100644 --- a/session/interaction/commands.go +++ b/session/interaction/commands.go @@ -10,34 +10,37 @@ import ( "github.com/charmbracelet/lipgloss" ) +func (m *model) handleCommandSelection(item commandItem) (tea.Model, tea.Cmd) { + switch item.name { + case "slug": + m.showingCommands = false + m.editingSlug = true + m.slugInput.SetValue(m.interaction.slug.String()) + m.slugInput.Focus() + return m, tea.Batch(tea.ClearScreen, textinput.Blink) + case "tunnel-type": + m.showingCommands = false + m.showingComingSoon = true + return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink) + default: + m.showingCommands = false + return m, nil + } +} + func (m *model) commandsUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { var cmd tea.Cmd switch { - case key.Matches(msg, m.keymap.quit): + case key.Matches(msg, m.keymap.quit), msg.String() == "esc": m.showingCommands = false return m, tea.Batch(tea.ClearScreen, textinput.Blink) case msg.String() == "enter": selectedItem := m.commandList.SelectedItem() if selectedItem != nil { item := selectedItem.(commandItem) - if item.name == "slug" { - m.showingCommands = false - m.editingSlug = true - m.slugInput.SetValue(m.interaction.slug.String()) - m.slugInput.Focus() - return m, tea.Batch(tea.ClearScreen, textinput.Blink) - } else if item.name == "tunnel-type" { - m.showingCommands = false - m.showingComingSoon = true - return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink) - } - m.showingCommands = false - return m, nil + return m.handleCommandSelection(item) } - case msg.String() == "esc": - m.showingCommands = false - return m, tea.Batch(tea.ClearScreen, textinput.Blink) } m.commandList, cmd = m.commandList.Update(msg) return m, cmd diff --git a/session/interaction/dashboard.go b/session/interaction/dashboard.go index a24ab7c..cf10ddb 100644 --- a/session/interaction/dashboard.go +++ b/session/interaction/dashboard.go @@ -23,164 +23,194 @@ func (m *model) dashboardUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } func (m *model) dashboardView() string { - titleStyle := lipgloss.NewStyle(). - Bold(true). - Foreground(lipgloss.Color("#7D56F4")). - PaddingTop(1) - - subtitleStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#888888")). - Italic(true) - - urlStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#7D56F4")). - Underline(true). - Italic(true) - - urlBoxStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#04B575")). - Bold(true). - Italic(true) - - keyHintStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#7D56F4")). - Bold(true) + isCompact := shouldUseCompactLayout(m.width, BreakpointLarge) var b strings.Builder + b.WriteString(m.renderHeader(isCompact)) + b.WriteString(m.renderUserInfo(isCompact)) + b.WriteString(m.renderQuickActions(isCompact)) + b.WriteString(m.renderFooter(isCompact)) - isCompact := shouldUseCompactLayout(m.width, 85) + return b.String() +} - var asciiArtMargin int - if isCompact { - asciiArtMargin = 0 - } else { - asciiArtMargin = 1 - } +func (m *model) renderHeader(isCompact bool) string { + var b strings.Builder + asciiArtMargin := getMarginValue(isCompact, 0, 1) asciiArtStyle := lipgloss.NewStyle(). Bold(true). - Foreground(lipgloss.Color("#7D56F4")). + Foreground(lipgloss.Color(ColorPrimary)). MarginBottom(asciiArtMargin) - var asciiArt string - if shouldUseCompactLayout(m.width, 50) { - asciiArt = "TUNNEL PLS" - } else if isCompact { - asciiArt = ` + b.WriteString(asciiArtStyle.Render(m.getASCIIArt())) + b.WriteString("\n") + + if !shouldUseCompactLayout(m.width, BreakpointSmall) { + b.WriteString(m.renderSubtitle()) + } else { + b.WriteString("\n") + } + + return b.String() +} + +func (m *model) getASCIIArt() string { + if shouldUseCompactLayout(m.width, BreakpointTiny) { + return "TUNNEL PLS" + } + + if shouldUseCompactLayout(m.width, BreakpointLarge) { + return ` ▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀ █ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀` - } else { - asciiArt = ` + } + + return ` ████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗ ╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝ ██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗ ██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║ ██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║ ╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝` - } +} - b.WriteString(asciiArtStyle.Render(asciiArt)) - b.WriteString("\n") +func (m *model) renderSubtitle() string { + subtitleStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(ColorGray)). + Italic(true) - if !shouldUseCompactLayout(m.width, 60) { - b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • ")) - b.WriteString(urlStyle.Render("https://fossy.my.id")) - b.WriteString("\n\n") - } else { - b.WriteString("\n") - } + urlStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(ColorPrimary)). + Underline(true). + Italic(true) + return subtitleStyle.Render("Secure tunnel service by Bagas • ") + + urlStyle.Render("https://fossy.my.id") + "\n\n" +} + +func (m *model) renderUserInfo(isCompact bool) string { boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80) - var boxPadding int - var boxMargin int - if isCompact { - boxPadding = 1 - boxMargin = 1 - } else { - boxPadding = 2 - boxMargin = 2 - } + boxPadding := getMarginValue(isCompact, 1, 2) + boxMargin := getMarginValue(isCompact, 1, 2) responsiveInfoBox := lipgloss.NewStyle(). Border(lipgloss.RoundedBorder()). - BorderForeground(lipgloss.Color("#7D56F4")). + BorderForeground(lipgloss.Color(ColorPrimary)). Padding(1, boxPadding). MarginTop(boxMargin). MarginBottom(boxMargin). Width(boxMaxWidth) - authenticatedUser := m.interaction.user + infoContent := m.getUserInfoContent(isCompact) + return responsiveInfoBox.Render(infoContent) + "\n" +} +func (m *model) getUserInfoContent(isCompact bool) string { userInfoStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#FAFAFA")). + Foreground(lipgloss.Color(ColorWhite)). Bold(true) sectionHeaderStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#888888")). + Foreground(lipgloss.Color(ColorGray)). Bold(true) addressStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#FAFAFA")) + Foreground(lipgloss.Color(ColorWhite)) - var infoContent string - if shouldUseCompactLayout(m.width, 70) { - infoContent = fmt.Sprintf("👤 %s\n\n%s\n%s", + urlBoxStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(ColorSecondary)). + Bold(true). + Italic(true) + + authenticatedUser := m.interaction.user + tunnelURL := urlBoxStyle.Render(m.getTunnelURL()) + + if isCompact { + return fmt.Sprintf("👤 %s\n\n%s\n%s", userInfoStyle.Render(authenticatedUser), sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"), - addressStyle.Render(fmt.Sprintf(" %s", urlBoxStyle.Render(m.getTunnelURL())))) - } else { - infoContent = fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s", - userInfoStyle.Render(authenticatedUser), - sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"), - addressStyle.Render(urlBoxStyle.Render(m.getTunnelURL()))) + addressStyle.Render(fmt.Sprintf(" %s", tunnelURL))) } - b.WriteString(responsiveInfoBox.Render(infoContent)) + return fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s", + userInfoStyle.Render(authenticatedUser), + sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"), + addressStyle.Render(tunnelURL)) +} + +func (m *model) renderQuickActions(isCompact bool) string { + var b strings.Builder + + titleStyle := lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color(ColorPrimary)). + PaddingTop(1) + + b.WriteString(titleStyle.Render(m.getQuickActionsTitle())) b.WriteString("\n") - var quickActionsTitle string - if shouldUseCompactLayout(m.width, 50) { - quickActionsTitle = "Actions" - } else if isCompact { - quickActionsTitle = "Quick Actions" - } else { - quickActionsTitle = "✨ Quick Actions" - } - b.WriteString(titleStyle.Render(quickActionsTitle)) - b.WriteString("\n") - - var featureMargin int - if isCompact { - featureMargin = 1 - } else { - featureMargin = 2 - } - - compactFeatureStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#FAFAFA")). + featureMargin := getMarginValue(isCompact, 1, 2) + featureStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(ColorWhite)). MarginLeft(featureMargin) - var commandsText string - var quitText string - if shouldUseCompactLayout(m.width, 60) { - commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]")) - quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]")) - } else { - commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]")) - quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]")) - } + keyHintStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(ColorPrimary)). + Bold(true) - b.WriteString(compactFeatureStyle.Render(commandsText)) + commands := m.getActionCommands(keyHintStyle) + b.WriteString(featureStyle.Render(commands.commandsText)) b.WriteString("\n") - b.WriteString(compactFeatureStyle.Render(quitText)) - - if !shouldUseCompactLayout(m.width, 70) { - b.WriteString("\n\n") - footerStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#666666")). - Italic(true) - b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings")) - } + b.WriteString(featureStyle.Render(commands.quitText)) return b.String() } + +func (m *model) getQuickActionsTitle() string { + if shouldUseCompactLayout(m.width, BreakpointTiny) { + return "Actions" + } + if shouldUseCompactLayout(m.width, BreakpointLarge) { + return "Quick Actions" + } + return "✨ Quick Actions" +} + +type actionCommands struct { + commandsText string + quitText string +} + +func (m *model) getActionCommands(keyHintStyle lipgloss.Style) actionCommands { + if shouldUseCompactLayout(m.width, BreakpointSmall) { + return actionCommands{ + commandsText: fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]")), + quitText: fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]")), + } + } + + return actionCommands{ + commandsText: fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]")), + quitText: fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]")), + } +} + +func (m *model) renderFooter(isCompact bool) string { + if isCompact { + return "" + } + + footerStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(ColorDarkGray)). + Italic(true) + + return "\n\n" + footerStyle.Render("Press 'C' to customize your tunnel settings") +} + +func getMarginValue(isCompact bool, compactValue, normalValue int) int { + if isCompact { + return compactValue + } + return normalValue +} diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index 5f68102..814cd67 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -3,7 +3,9 @@ package interaction import ( "context" "log" + "sync" "tunnel_pls/internal/config" + "tunnel_pls/internal/random" "tunnel_pls/session/slug" "tunnel_pls/types" @@ -39,6 +41,7 @@ type Forwarder interface { type CloseFunc func() error type interaction struct { + randomizer random.Random config config.Config channel ssh.Channel slug slug.Slug @@ -50,6 +53,7 @@ type interaction struct { ctx context.Context cancel context.CancelFunc mode types.InteractiveMode + programMu sync.Mutex } func (i *interaction) SetMode(m types.InteractiveMode) { @@ -76,9 +80,10 @@ func (i *interaction) SetWH(w, h int) { } } -func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { +func New(randomizer random.Random, config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { ctx, cancel := context.WithCancel(context.Background()) return &interaction{ + randomizer: randomizer, config: config, channel: nil, slug: slug, @@ -100,6 +105,10 @@ func (i *interaction) Stop() { if i.cancel != nil { i.cancel() } + + i.programMu.Lock() + defer i.programMu.Unlock() + if i.program != nil { i.program.Kill() i.program = nil @@ -210,6 +219,7 @@ func (i *interaction) Start() { ti.Width = 50 m := &model{ + randomizer: i.randomizer, domain: i.config.Domain(), protocol: protocol, tunnelType: tunnelType, @@ -234,6 +244,7 @@ func (i *interaction) Start() { help: help.New(), } + i.programMu.Lock() i.program = tea.NewProgram( m, tea.WithInput(i.channel), @@ -244,16 +255,21 @@ func (i *interaction) Start() { tea.WithoutSignalHandler(), tea.WithFPS(30), ) + i.programMu.Unlock() _, err := i.program.Run() if err != nil { log.Printf("Cannot close tea: %s \n", err) } - i.program.Kill() - i.program = nil + + i.programMu.Lock() + if i.program != nil { + i.program.Kill() + i.program = nil + } + i.programMu.Unlock() + if i.closeFunc != nil { - if err := i.closeFunc(); err != nil { - log.Printf("Cannot close session: %s \n", err) - } + _ = i.closeFunc() } } diff --git a/session/interaction/interaction_test.go b/session/interaction/interaction_test.go new file mode 100644 index 0000000..4992093 --- /dev/null +++ b/session/interaction/interaction_test.go @@ -0,0 +1,2310 @@ +package interaction + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" + "tunnel_pls/types" + + "github.com/charmbracelet/bubbles/key" + "github.com/charmbracelet/bubbles/list" + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/ssh" +) + +type MockRandom struct { + mock.Mock +} + +func (m *MockRandom) String(length int) (string, error) { + args := m.Called(length) + return args.String(0), args.Error(1) +} + +type MockConfig struct { + mock.Mock +} + +func (m *MockConfig) Domain() string { return m.Called().String(0) } +func (m *MockConfig) SSHPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) } +func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) } +func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) } +func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) } +func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } +func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) } +func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) PprofPort() string { return m.Called().String(0) } +func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } +func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *MockConfig) NodeToken() string { return m.Called().String(0) } +func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) } +func (m *MockConfig) KeyLoc() string { return m.Called().String(0) } + +type MockSlug struct { + mock.Mock +} + +func (ms *MockSlug) Set(slug string) { ms.Called(slug) } +func (ms *MockSlug) String() string { return ms.Called().String(0) } + +type MockForwarder struct { + mock.Mock +} + +func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { + args := m.Called(origin) + return args.Get(0).([]byte) +} + +func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { + m.Called(dst, src) +} + +func (m *MockForwarder) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockForwarder) TunnelType() types.TunnelType { + args := m.Called() + return args.Get(0).(types.TunnelType) +} + +func (m *MockForwarder) ForwardedPort() uint16 { + args := m.Called() + return args.Get(0).(uint16) +} + +func (m *MockForwarder) SetType(tunnelType types.TunnelType) { + m.Called(tunnelType) +} + +func (m *MockForwarder) SetForwardedPort(port uint16) { + m.Called(port) +} + +func (m *MockForwarder) SetListener(listener net.Listener) { + m.Called(listener) +} + +func (m *MockForwarder) Listener() net.Listener { + args := m.Called() + return args.Get(0).(net.Listener) +} + +func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(ctx, origin) + if args.Get(0) == nil { + return nil, nil, args.Error(2) + } + return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) +} + +type MockSessionRegistry struct { + mock.Mock +} + +func (m *MockSessionRegistry) Update(user string, oldKey, newKey types.SessionKey) error { + args := m.Called(user, oldKey, newKey) + return args.Error(0) +} + +type MockChannel struct { + mock.Mock + data []byte +} + +func (m *MockChannel) Read(b []byte) (n int, err error) { + args := m.Called(b) + return args.Int(0), args.Error(1) +} + +func (m *MockChannel) Write(b []byte) (n int, err error) { + m.data = append(m.data, b...) + args := m.Called(b) + return args.Int(0), args.Error(1) +} + +func (m *MockChannel) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockChannel) CloseWrite() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + args := m.Called(name, wantReply, payload) + return args.Bool(0), args.Error(1) +} + +func (m *MockChannel) Stderr() io.ReadWriter { + args := m.Called() + if args.Get(0) == nil { + return nil + } + return args.Get(0).(io.ReadWriter) +} + +type MockCloser struct { + mock.Mock +} + +func (m *MockCloser) Close() error { return m.Called().Error(0) } + +func TestNew(t *testing.T) { + tests := []struct { + name string + user string + }{ + { + name: "creates interaction with default mode", + user: "testuser", + }, + { + name: "creates interaction for different user", + user: "anotheruser", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, tt.user, mockCloser.Close) + + assert.NotNil(t, mockInteraction) + }) + } +} + +func TestInteraction_SetMode(t *testing.T) { + tests := []struct { + name string + mode types.InteractiveMode + }{ + { + name: "set headless mode", + mode: types.InteractiveModeHEADLESS, + }, + { + name: "set interactive mode", + mode: types.InteractiveModeINTERACTIVE, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close) + mockInteraction.SetMode(tt.mode) + + assert.Equal(t, tt.mode, mockInteraction.Mode()) + }) + } +} + +func TestInteraction_Mode(t *testing.T) { + tests := []struct { + name string + setMode types.InteractiveMode + expected types.InteractiveMode + }{ + { + name: "mode returns set value", + setMode: types.InteractiveModeINTERACTIVE, + expected: types.InteractiveModeINTERACTIVE, + }, + { + name: "mode returns headless value", + setMode: types.InteractiveModeHEADLESS, + expected: types.InteractiveModeHEADLESS, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close) + + mockInteraction.SetMode(tt.setMode) + assert.Equal(t, tt.expected, mockInteraction.Mode()) + }) + } +} + +func TestInteraction_Send(t *testing.T) { + tests := []struct { + name string + message string + setupChannel bool + channelReturn int + channelError error + wantError bool + }{ + { + name: "send message successfully", + message: "test message", + setupChannel: true, + channelReturn: 12, + channelError: nil, + wantError: false, + }, + { + name: "send message with channel error", + message: "test message", + setupChannel: true, + channelReturn: 0, + channelError: errors.New("channel write error"), + wantError: true, + }, + { + name: "send message without channel", + message: "test message", + setupChannel: false, + wantError: false, + }, + { + name: "send empty message", + message: "", + setupChannel: true, + channelReturn: 0, + channelError: nil, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close) + + if tt.setupChannel { + mockChannel := &MockChannel{} + mockChannel.On("Write", []byte(tt.message)).Return(tt.channelReturn, tt.channelError) + mockInteraction.SetChannel(mockChannel) + } + + err := mockInteraction.Send(tt.message) + + if tt.wantError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestInteraction_SetWH(t *testing.T) { + tests := []struct { + name string + width int + height int + }{ + { + name: "set large window size", + width: 100, + height: 50, + }, + { + name: "set medium window size", + width: 80, + height: 24, + }, + { + name: "set small window size", + width: 20, + height: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close) + + mockInteraction.SetWH(tt.width, tt.height) + }) + } +} + +func TestInteraction_SetChannel(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close) + + mockChannel := &MockChannel{} + mockInteraction.SetChannel(mockChannel) + + mockChannel.On("Write", []byte("test")).Return(4, nil) + err := mockInteraction.Send("test") + assert.NoError(t, err) +} + +func TestInteraction_Redraw(t *testing.T) { + tests := []struct { + name string + description string + }{ + { + name: "redraw interaction", + description: "should not panic when calling redraw", + }, + { + name: "redraw multiple times", + description: "should handle multiple redraws", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close) + + mockInteraction.Redraw() + }) + } +} + +func TestInteraction_Start(t *testing.T) { + tests := []struct { + name string + mode types.InteractiveMode + tlsEnabled bool + tunnelType types.TunnelType + port uint16 + }{ + { + name: "start in headless mode - should return immediately", + mode: types.InteractiveModeHEADLESS, + tlsEnabled: false, + tunnelType: types.TunnelTypeHTTP, + port: 8080, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close) + mockInteraction.SetMode(tt.mode) + + mockConfig.On("Domain").Return("tunnl.live") + mockConfig.On("TLSEnabled").Return(tt.tlsEnabled) + mockForwarder.On("TunnelType").Return(tt.tunnelType) + mockForwarder.On("ForwardedPort").Return(tt.port) + + mockInteraction.Start() + }) + } +} + +func TestModel_Update(t *testing.T) { + tests := []struct { + name string + msg tea.Msg + showingComingSoon bool + editingSlug bool + showingCommands bool + width int + height int + expectedWidth int + expectedHeight int + expectedQuit bool + }{ + { + name: "tick message clears coming soon", + msg: tickMsg{}, + showingComingSoon: true, + editingSlug: false, + showingCommands: false, + expectedQuit: false, + }, + { + name: "window size message - large screen", + msg: tea.WindowSizeMsg{Width: 100, Height: 50}, + expectedWidth: 100, + expectedHeight: 50, + expectedQuit: false, + }, + { + name: "window size message - small screen", + msg: tea.WindowSizeMsg{Width: 60, Height: 20}, + expectedWidth: 60, + expectedHeight: 20, + expectedQuit: false, + }, + { + name: "quit message", + msg: tea.QuitMsg{}, + expectedQuit: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close) + + m := &model{ + randomizer: mockRandom, + domain: "tunnl.live", + protocol: "http", + tunnelType: types.TunnelTypeHTTP, + port: 8080, + commandList: list.New([]list.Item{}, list.NewDefaultDelegate(), 80, 20), + interaction: mockInteraction.(*interaction), + showingComingSoon: tt.showingComingSoon, + editingSlug: tt.editingSlug, + showingCommands: tt.showingCommands, + width: tt.width, + height: tt.height, + keymap: keymap{ + quit: key.NewBinding( + key.WithKeys("q", "ctrl+c"), + key.WithHelp("q", "quit"), + ), + command: key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "commands"), + ), + random: key.NewBinding( + key.WithKeys("ctrl+r"), + key.WithHelp("ctrl+r", "random"), + ), + }, + } + + result, _ := m.Update(tt.msg) + updatedModel := result.(*model) + + if tt.expectedQuit { + assert.True(t, updatedModel.quitting) + } + + if windowMsg, ok := tt.msg.(tea.WindowSizeMsg); ok { + assert.Equal(t, windowMsg.Width, updatedModel.width) + assert.Equal(t, windowMsg.Height, updatedModel.height) + } + + if _, ok := tt.msg.(tickMsg); ok && tt.showingComingSoon { + assert.False(t, updatedModel.showingComingSoon) + } + }) + } +} + +func TestModel_View(t *testing.T) { + tests := []struct { + name string + quitting bool + showingComingSoon bool + editingSlug bool + showingCommands bool + expectedEmpty bool + }{ + { + name: "quitting returns empty string", + quitting: true, + expectedEmpty: true, + }, + { + name: "showing coming soon view", + showingComingSoon: true, + expectedEmpty: false, + }, + { + name: "editing slug view", + editingSlug: true, + expectedEmpty: false, + }, + { + name: "showing commands view", + showingCommands: true, + expectedEmpty: false, + }, + { + name: "dashboard view (default)", + expectedEmpty: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close) + + mockSlug.On("String").Return("test-slug") + + m := &model{ + randomizer: mockRandom, + domain: "tunnl.live", + protocol: "http", + tunnelType: types.TunnelTypeHTTP, + port: 8080, + commandList: list.New([]list.Item{}, list.NewDefaultDelegate(), 80, 20), + interaction: mockInteraction.(*interaction), + quitting: tt.quitting, + showingComingSoon: tt.showingComingSoon, + editingSlug: tt.editingSlug, + showingCommands: tt.showingCommands, + keymap: keymap{ + quit: key.NewBinding( + key.WithKeys("q", "ctrl+c"), + key.WithHelp("q", "quit"), + ), + command: key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "commands"), + ), + random: key.NewBinding( + key.WithKeys("ctrl+r"), + key.WithHelp("ctrl+r", "random"), + ), + }, + } + + view := m.View() + + if tt.expectedEmpty { + assert.Empty(t, view) + } else { + assert.NotEmpty(t, view) + } + }) + } +} + +func TestInteraction_Integration(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeCallCount := 0 + closeFunc := func() error { + closeCallCount++ + return nil + } + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + assert.NotNil(t, mockInteraction) + + mockInteraction.SetMode(types.InteractiveModeINTERACTIVE) + assert.Equal(t, types.InteractiveModeINTERACTIVE, mockInteraction.Mode()) + + mockChannel := &MockChannel{} + mockInteraction.SetChannel(mockChannel) + + mockChannel.On("Write", []byte("hello")).Return(5, nil) + err := mockInteraction.Send("hello") + assert.NoError(t, err) + mockChannel.AssertExpectations(t) + + mockInteraction.SetWH(80, 24) + + mockInteraction.Redraw() +} + +func TestModel_Update_KeyMessages(t *testing.T) { + tests := []struct { + name string + key tea.KeyMsg + showingComingSoon bool + editingSlug bool + showingCommands bool + description string + }{ + { + name: "key press while showing coming soon", + key: tea.KeyMsg{Type: tea.KeyEnter}, + showingComingSoon: true, + description: "should call comingSoonUpdate", + }, + { + name: "key press while editing slug", + key: tea.KeyMsg{Type: tea.KeyEnter}, + editingSlug: true, + description: "should call slugUpdate", + }, + { + name: "key press while showing commands", + key: tea.KeyMsg{Type: tea.KeyEnter}, + showingCommands: true, + description: "should call commandsUpdate", + }, + { + name: "key press in dashboard view", + key: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'c'}}, + description: "should call dashboardUpdate", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "user", mockCloser.Close) + + mockSlug.On("String").Return("test-slug").Maybe() + mockSessionRegistry.On("Update", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + m := &model{ + randomizer: mockRandom, + domain: "tunnl.live", + protocol: "http", + tunnelType: types.TunnelTypeHTTP, + port: 8080, + commandList: list.New([]list.Item{}, list.NewDefaultDelegate(), 80, 20), + interaction: mockInteraction.(*interaction), + showingComingSoon: tt.showingComingSoon, + editingSlug: tt.editingSlug, + showingCommands: tt.showingCommands, + keymap: keymap{ + quit: key.NewBinding( + key.WithKeys("q", "ctrl+c"), + key.WithHelp("q", "quit"), + ), + command: key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "commands"), + ), + random: key.NewBinding( + key.WithKeys("ctrl+r"), + key.WithHelp("ctrl+r", "random"), + ), + }, + } + + result, _ := m.Update(tt.key) + assert.NotNil(t, result) + }) + } +} + +func TestModel_SlugUpdate(t *testing.T) { + tests := []struct { + name string + tunnelType types.TunnelType + keyMsg tea.KeyMsg + inputValue string + setupMocks func(*MockSessionRegistry, *MockSlug, *MockRandom) + expectedEdit bool + expectedError string + shouldSetValue bool + }{ + { + name: "escape key cancels editing", + tunnelType: types.TunnelTypeHTTP, + keyMsg: tea.KeyMsg{Type: tea.KeyEsc}, + setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {}, + expectedEdit: false, + }, + { + name: "ctrl+c cancels editing", + tunnelType: types.TunnelTypeHTTP, + keyMsg: tea.KeyMsg{Type: tea.KeyCtrlC}, + setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {}, + expectedEdit: false, + }, + { + name: "enter key saves valid slug", + tunnelType: types.TunnelTypeHTTP, + keyMsg: tea.KeyMsg{Type: tea.KeyEnter}, + inputValue: "my-custom-slug", + setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) { + ms.On("String").Return("old-slug") + msr.On("Update", "testuser", + types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}, + types.SessionKey{Id: "my-custom-slug", Type: types.TunnelTypeHTTP}, + ).Return(nil) + }, + expectedEdit: false, + expectedError: "", + }, + { + name: "enter key with error shows error message", + tunnelType: types.TunnelTypeHTTP, + keyMsg: tea.KeyMsg{Type: tea.KeyEnter}, + inputValue: "invalid", + setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) { + ms.On("String").Return("old-slug") + msr.On("Update", "testuser", + types.SessionKey{Id: "old-slug", Type: types.TunnelTypeHTTP}, + types.SessionKey{Id: "invalid", Type: types.TunnelTypeHTTP}, + ).Return(assert.AnError) + }, + expectedEdit: true, + expectedError: assert.AnError.Error(), + }, + { + name: "ctrl+r generates random slug", + tunnelType: types.TunnelTypeHTTP, + keyMsg: tea.KeyMsg{Type: tea.KeyCtrlR}, + setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) { + mr.On("String", 20).Return("random-generated-slug", nil) + }, + expectedEdit: true, + shouldSetValue: true, + }, + { + name: "ctrl+r with error does nothing", + tunnelType: types.TunnelTypeHTTP, + keyMsg: tea.KeyMsg{Type: tea.KeyCtrlR}, + setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) { + mr.On("String", 20).Return("", assert.AnError) + }, + expectedEdit: true, + }, + { + name: "regular key updates input", + tunnelType: types.TunnelTypeHTTP, + keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'a'}}, + setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {}, + expectedEdit: true, + }, + { + name: "tcp tunnel exits editing immediately", + tunnelType: types.TunnelTypeTCP, + keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'a'}}, + setupMocks: func(msr *MockSessionRegistry, ms *MockSlug, mr *MockRandom) {}, + expectedEdit: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close) + + ti := textinput.New() + ti.SetValue(tt.inputValue) + + m := &model{ + randomizer: mockRandom, + domain: "tunnl.live", + protocol: "http", + tunnelType: tt.tunnelType, + port: 8080, + commandList: list.New([]list.Item{}, list.NewDefaultDelegate(), 80, 20), + slugInput: ti, + editingSlug: true, + interaction: mockInteraction.(*interaction), + keymap: keymap{ + quit: key.NewBinding( + key.WithKeys("q", "ctrl+c"), + key.WithHelp("q", "quit"), + ), + command: key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "commands"), + ), + random: key.NewBinding( + key.WithKeys("ctrl+r"), + key.WithHelp("ctrl+r", "random"), + ), + }, + } + + tt.setupMocks(mockSessionRegistry, mockSlug, mockRandom) + + result, _ := m.slugUpdate(tt.keyMsg) + resultModel := result.(*model) + + assert.Equal(t, tt.expectedEdit, resultModel.editingSlug) + if tt.expectedError != "" { + assert.Equal(t, tt.expectedError, resultModel.slugError) + } else if !tt.expectedEdit { + assert.Equal(t, "", resultModel.slugError) + } + + mockSessionRegistry.AssertExpectations(t) + mockSlug.AssertExpectations(t) + mockRandom.AssertExpectations(t) + }) + } +} + +func TestModel_SlugView(t *testing.T) { + tests := []struct { + name string + width int + tunnelType types.TunnelType + slugError string + contains string + }{ + { + name: "http tunnel - large screen", + width: 100, + tunnelType: types.TunnelTypeHTTP, + contains: "Subdomain", + }, + { + name: "http tunnel - small screen", + width: 50, + tunnelType: types.TunnelTypeHTTP, + contains: "Subdomain", + }, + { + name: "http tunnel - tiny screen", + width: 30, + tunnelType: types.TunnelTypeHTTP, + contains: "Subdomain", + }, + { + name: "http tunnel with error", + width: 100, + tunnelType: types.TunnelTypeHTTP, + slugError: "Slug already exists", + contains: "Slug already exists", + }, + { + name: "tcp tunnel - large screen", + width: 100, + tunnelType: types.TunnelTypeTCP, + contains: "TCP", + }, + { + name: "tcp tunnel - small screen", + width: 50, + tunnelType: types.TunnelTypeTCP, + contains: "TCP", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close) + + ti := textinput.New() + ti.SetValue("test-slug") + + m := &model{ + randomizer: mockRandom, + domain: "tunnl.live", + protocol: "http", + tunnelType: tt.tunnelType, + port: 8080, + slugInput: ti, + slugError: tt.slugError, + interaction: mockInteraction.(*interaction), + width: tt.width, + } + + view := m.slugView() + assert.NotEmpty(t, view) + assert.Contains(t, view, tt.contains) + }) + } +} + +func TestModel_ComingSoonUpdate(t *testing.T) { + tests := []struct { + name string + keyMsg tea.KeyMsg + }{ + { + name: "any key dismisses coming soon", + keyMsg: tea.KeyMsg{Type: tea.KeyEnter}, + }, + { + name: "escape key dismisses", + keyMsg: tea.KeyMsg{Type: tea.KeyEsc}, + }, + { + name: "space key dismisses", + keyMsg: tea.KeyMsg{Type: tea.KeySpace}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close) + + m := &model{ + interaction: mockInteraction.(*interaction), + showingComingSoon: true, + } + + result, _ := m.comingSoonUpdate(tt.keyMsg) + resultModel := result.(*model) + + assert.False(t, resultModel.showingComingSoon) + }) + } +} + +func TestModel_ComingSoonView(t *testing.T) { + tests := []struct { + name string + width int + }{ + { + name: "large screen", + width: 100, + }, + { + name: "medium screen", + width: 60, + }, + { + name: "small screen", + width: 50, + }, + { + name: "tiny screen", + width: 30, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close) + + m := &model{ + interaction: mockInteraction.(*interaction), + width: tt.width, + } + + view := m.comingSoonView() + assert.NotEmpty(t, view) + assert.Contains(t, view, "Coming") + }) + } +} + +func TestModel_CommandsUpdate(t *testing.T) { + tests := []struct { + name string + keyMsg tea.KeyMsg + selectedItem list.Item + expectCommands bool + expectEditSlug bool + expectComingSoon bool + }{ + { + name: "escape key closes commands", + keyMsg: tea.KeyMsg{Type: tea.KeyEsc}, + expectCommands: false, + }, + { + name: "q key closes commands", + keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'q'}}, + expectCommands: false, + }, + { + name: "enter on slug command starts editing", + keyMsg: tea.KeyMsg{Type: tea.KeyEnter}, + selectedItem: commandItem{name: "slug", desc: "Set custom subdomain"}, + expectCommands: false, + expectEditSlug: true, + }, + { + name: "enter on tunnel-type shows coming soon", + keyMsg: tea.KeyMsg{Type: tea.KeyEnter}, + selectedItem: commandItem{name: "tunnel-type", desc: "Change tunnel type"}, + expectCommands: false, + expectComingSoon: true, + }, + { + name: "arrow key navigates list", + keyMsg: tea.KeyMsg{Type: tea.KeyDown}, + expectCommands: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close) + + mockSlug.On("String").Return("current-slug").Maybe() + + items := []list.Item{ + commandItem{name: "slug", desc: "Set custom subdomain"}, + commandItem{name: "tunnel-type", desc: "Change tunnel type"}, + } + + delegate := list.NewDefaultDelegate() + commandList := list.New(items, delegate, 80, 20) + if tt.selectedItem != nil { + for i, item := range items { + if item.(commandItem).name == tt.selectedItem.(commandItem).name { + commandList.Select(i) + break + } + } + } + + ti := textinput.New() + + m := &model{ + randomizer: mockRandom, + interaction: mockInteraction.(*interaction), + showingCommands: true, + commandList: commandList, + slugInput: ti, + keymap: keymap{ + quit: key.NewBinding( + key.WithKeys("q", "ctrl+c"), + key.WithHelp("q", "quit"), + ), + command: key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "commands"), + ), + random: key.NewBinding( + key.WithKeys("ctrl+r"), + key.WithHelp("ctrl+r", "random"), + ), + }, + } + + result, _ := m.commandsUpdate(tt.keyMsg) + resultModel := result.(*model) + + assert.Equal(t, tt.expectCommands, resultModel.showingCommands) + if tt.expectEditSlug { + assert.True(t, resultModel.editingSlug) + } + if tt.expectComingSoon { + assert.True(t, resultModel.showingComingSoon) + } + }) + } +} + +func TestModel_CommandsView(t *testing.T) { + tests := []struct { + name string + width int + }{ + { + name: "large screen", + width: 100, + }, + { + name: "medium screen", + width: 60, + }, + { + name: "small screen", + width: 50, + }, + { + name: "tiny screen", + width: 30, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close) + + items := []list.Item{ + commandItem{name: "slug", desc: "Set custom subdomain"}, + commandItem{name: "tunnel-type", desc: "Change tunnel type"}, + } + + delegate := list.NewDefaultDelegate() + commandList := list.New(items, delegate, 80, 20) + + m := &model{ + interaction: mockInteraction.(*interaction), + commandList: commandList, + width: tt.width, + } + + view := m.commandsView() + assert.NotEmpty(t, view) + assert.Contains(t, view, "Commands") + }) + } +} + +func TestModel_DashboardUpdate(t *testing.T) { + tests := []struct { + name string + keyMsg tea.KeyMsg + expectQuit bool + expectCommands bool + }{ + { + name: "q key quits", + keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'q'}}, + expectQuit: true, + }, + { + name: "ctrl+c quits", + keyMsg: tea.KeyMsg{Type: tea.KeyCtrlC}, + expectQuit: true, + }, + { + name: "c key opens commands", + keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'c'}}, + expectCommands: true, + }, + { + name: "other keys do nothing", + keyMsg: tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close) + + m := &model{ + interaction: mockInteraction.(*interaction), + keymap: keymap{ + quit: key.NewBinding( + key.WithKeys("q", "ctrl+c"), + key.WithHelp("q", "quit"), + ), + command: key.NewBinding( + key.WithKeys("c"), + key.WithHelp("c", "commands"), + ), + random: key.NewBinding( + key.WithKeys("ctrl+r"), + key.WithHelp("ctrl+r", "random"), + ), + }, + } + + result, _ := m.dashboardUpdate(tt.keyMsg) + resultModel := result.(*model) + + if tt.expectQuit { + assert.True(t, resultModel.quitting) + } + if tt.expectCommands { + assert.True(t, resultModel.showingCommands) + } + }) + } +} + +func TestModel_DashboardView(t *testing.T) { + tests := []struct { + name string + width int + tunnelType types.TunnelType + protocol string + port uint16 + contains string + }{ + { + name: "http tunnel - large screen", + width: 100, + tunnelType: types.TunnelTypeHTTP, + protocol: "http", + contains: "http", + }, + { + name: "https tunnel - large screen", + width: 100, + tunnelType: types.TunnelTypeHTTP, + protocol: "https", + contains: "https", + }, + { + name: "tcp tunnel - large screen", + width: 100, + tunnelType: types.TunnelTypeTCP, + port: 8080, + contains: "tcp", + }, + { + name: "http tunnel - medium screen", + width: 70, + tunnelType: types.TunnelTypeHTTP, + protocol: "http", + contains: "http", + }, + { + name: "http tunnel - small screen", + width: 50, + tunnelType: types.TunnelTypeHTTP, + protocol: "http", + contains: "http", + }, + { + name: "http tunnel - tiny screen", + width: 30, + tunnelType: types.TunnelTypeHTTP, + protocol: "http", + contains: "http", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close) + + mockSlug.On("String").Return("test-slug") + + m := &model{ + randomizer: mockRandom, + domain: "tunnl.live", + protocol: tt.protocol, + tunnelType: tt.tunnelType, + port: tt.port, + interaction: mockInteraction.(*interaction), + width: tt.width, + } + + view := m.dashboardView() + assert.NotEmpty(t, view) + assert.Contains(t, view, tt.contains) + }) + } +} + +func TestGetResponsiveWidth(t *testing.T) { + tests := []struct { + name string + screenWidth int + padding int + minWidth int + maxWidth int + expected int + }{ + { + name: "screen wider than max", + screenWidth: 100, + padding: 10, + minWidth: 20, + maxWidth: 60, + expected: 60, + }, + { + name: "screen narrower than min", + screenWidth: 30, + padding: 10, + minWidth: 40, + maxWidth: 80, + expected: 40, + }, + { + name: "screen within range", + screenWidth: 70, + padding: 10, + minWidth: 20, + maxWidth: 80, + expected: 60, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getResponsiveWidth(tt.screenWidth, tt.padding, tt.minWidth, tt.maxWidth) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestShouldUseCompactLayout(t *testing.T) { + tests := []struct { + name string + width int + threshold int + expected bool + }{ + { + name: "width below threshold", + width: 50, + threshold: 60, + expected: true, + }, + { + name: "width at threshold", + width: 60, + threshold: 60, + expected: false, + }, + { + name: "width above threshold", + width: 70, + threshold: 60, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := shouldUseCompactLayout(tt.width, tt.threshold) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTruncateString(t *testing.T) { + tests := []struct { + name string + input string + maxLength int + expected string + }{ + { + name: "string shorter than max", + input: "short", + maxLength: 10, + expected: "short", + }, + { + name: "string equal to max", + input: "exactly10c", + maxLength: 10, + expected: "exactly10c", + }, + { + name: "string longer than max", + input: "this is a very long string", + maxLength: 10, + expected: "this is...", + }, + { + name: "very short max length", + input: "hello", + maxLength: 3, + expected: "hel", + }, + { + name: "max length less than 4", + input: "hello", + maxLength: 2, + expected: "he", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncateString(tt.input, tt.maxLength) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestBuildURL(t *testing.T) { + tests := []struct { + name string + protocol string + subdomain string + domain string + expected string + }{ + { + name: "http url", + protocol: "http", + subdomain: "test", + domain: "tunnl.live", + expected: "http://test.tunnl.live", + }, + { + name: "https url", + protocol: "https", + subdomain: "api", + domain: "myapp.io", + expected: "https://api.myapp.io", + }, + { + name: "custom subdomain", + protocol: "http", + subdomain: "my-custom-slug", + domain: "tunnl.live", + expected: "http://my-custom-slug.tunnl.live", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildURL(tt.protocol, tt.subdomain, tt.domain) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTickCmd(t *testing.T) { + tests := []struct { + name string + duration time.Duration + }{ + { + name: "5 second tick", + duration: 5 * time.Second, + }, + { + name: "1 second tick", + duration: 1 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := tickCmd(tt.duration) + assert.NotNil(t, cmd) + }) + } +} + +func TestGetPaddingValue(t *testing.T) { + tests := []struct { + name string + isVeryCompact bool + isCompact bool + expected int + }{ + { + name: "very compact layout", + isVeryCompact: true, + isCompact: false, + expected: 1, + }, + { + name: "compact layout", + isVeryCompact: false, + isCompact: true, + expected: 1, + }, + { + name: "normal layout", + isVeryCompact: false, + isCompact: false, + expected: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getPaddingValue(tt.isVeryCompact, tt.isCompact) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestGetMarginValue(t *testing.T) { + tests := []struct { + name string + isCompact bool + compactValue int + normalValue int + expected int + }{ + { + name: "compact layout", + isCompact: true, + compactValue: 1, + normalValue: 2, + expected: 1, + }, + { + name: "normal layout", + isCompact: false, + compactValue: 1, + normalValue: 2, + expected: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getMarginValue(tt.isCompact, tt.compactValue, tt.normalValue) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestCommandItem(t *testing.T) { + tests := []struct { + name string + item commandItem + wantName string + wantDesc string + }{ + { + name: "slug command", + item: commandItem{name: "slug", desc: "Set custom subdomain"}, + wantName: "slug", + wantDesc: "Set custom subdomain", + }, + { + name: "tunnel-type command", + item: commandItem{name: "tunnel-type", desc: "Change tunnel type"}, + wantName: "tunnel-type", + wantDesc: "Change tunnel type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.wantName, tt.item.FilterValue()) + assert.Equal(t, tt.wantName, tt.item.Title()) + assert.Equal(t, tt.wantDesc, tt.item.Description()) + }) + } +} + +func TestModel_GetTunnelURL(t *testing.T) { + tests := []struct { + name string + tunnelType types.TunnelType + protocol string + slug string + domain string + port uint16 + expected string + }{ + { + name: "http tunnel", + tunnelType: types.TunnelTypeHTTP, + protocol: "http", + slug: "my-app", + domain: "tunnl.live", + expected: "http://my-app.tunnl.live", + }, + { + name: "https tunnel", + tunnelType: types.TunnelTypeHTTP, + protocol: "https", + slug: "secure-app", + domain: "tunnl.live", + expected: "https://secure-app.tunnl.live", + }, + { + name: "tcp tunnel", + tunnelType: types.TunnelTypeTCP, + domain: "tunnl.live", + port: 8080, + expected: "tcp://tunnl.live:8080", + }, + { + name: "tcp tunnel with different port", + tunnelType: types.TunnelTypeTCP, + domain: "tunnl.live", + port: 3306, + expected: "tcp://tunnl.live:3306", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close) + + mockSlug.On("String").Return(tt.slug).Maybe() + + m := &model{ + domain: tt.domain, + protocol: tt.protocol, + tunnelType: tt.tunnelType, + port: tt.port, + interaction: mockInteraction.(*interaction), + } + + result := m.getTunnelURL() + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestModel_Init(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + mockCloser := &MockCloser{} + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", mockCloser.Close) + + m := &model{ + interaction: mockInteraction.(*interaction), + } + + cmd := m.Init() + assert.NotNil(t, cmd) +} + +func TestInteraction_Start_Interactive(t *testing.T) { + tests := []struct { + name string + tlsEnabled bool + tunnelType types.TunnelType + port uint16 + domain string + }{ + { + name: "interactive mode with http", + tlsEnabled: false, + tunnelType: types.TunnelTypeHTTP, + port: 8080, + domain: "tunnl.live", + }, + { + name: "interactive mode with https", + tlsEnabled: true, + tunnelType: types.TunnelTypeHTTP, + port: 8443, + domain: "secure.tunnl.live", + }, + { + name: "interactive mode with tcp", + tlsEnabled: false, + tunnelType: types.TunnelTypeTCP, + port: 3306, + domain: "db.tunnl.live", + }, + { + name: "interactive mode with tcp and tls enabled", + tlsEnabled: true, + tunnelType: types.TunnelTypeTCP, + port: 5432, + domain: "postgres.tunnl.live", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeCallCount := 0 + closeFunc := func() error { + closeCallCount++ + return nil + } + + mockConfig.On("Domain").Return(tt.domain) + mockConfig.On("TLSEnabled").Return(tt.tlsEnabled) + mockForwarder.On("TunnelType").Return(tt.tunnelType) + mockForwarder.On("ForwardedPort").Return(tt.port) + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + mockInteraction.SetMode(types.InteractiveModeINTERACTIVE) + + mockChannel := &MockChannel{} + mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe() + mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe() + mockInteraction.SetChannel(mockChannel) + + done := make(chan bool, 1) + go func() { + mockInteraction.Start() + done <- true + }() + + time.Sleep(50 * time.Millisecond) + + i := mockInteraction.(*interaction) + i.Stop() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Start() did not complete in time") + } + + assert.Equal(t, 1, closeCallCount, "close function should be called once") + + mockConfig.AssertExpectations(t) + mockForwarder.AssertExpectations(t) + }) + } +} + +func TestInteraction_Start_ProtocolSelection(t *testing.T) { + tests := []struct { + name string + tlsEnabled bool + expectedProto string + }{ + { + name: "http when TLS disabled", + tlsEnabled: false, + expectedProto: "http", + }, + { + name: "https when TLS enabled", + tlsEnabled: true, + expectedProto: "https", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeFunc := func() error { return nil } + + mockConfig.On("Domain").Return("tunnl.live") + mockConfig.On("TLSEnabled").Return(tt.tlsEnabled) + mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP) + mockForwarder.On("ForwardedPort").Return(uint16(8080)) + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + mockInteraction.SetMode(types.InteractiveModeINTERACTIVE) + + mockChannel := &MockChannel{} + mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe() + mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe() + mockInteraction.SetChannel(mockChannel) + + go func() { + mockInteraction.Start() + }() + + time.Sleep(50 * time.Millisecond) + + i := mockInteraction.(*interaction) + if i.program != nil { + assert.NotNil(t, i.program, "program should be initialized") + } + + i.Stop() + + mockConfig.AssertExpectations(t) + mockForwarder.AssertExpectations(t) + }) + } +} + +func TestInteraction_Stop(t *testing.T) { + tests := []struct { + name string + setupProgram bool + description string + }{ + { + name: "stop with active program", + setupProgram: true, + description: "should kill program and set to nil", + }, + { + name: "stop without program", + setupProgram: false, + description: "should not panic when program is nil", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeFunc := func() error { return nil } + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + i := mockInteraction.(*interaction) + + if tt.setupProgram { + mockConfig.On("Domain").Return("tunnl.live") + mockConfig.On("TLSEnabled").Return(false) + mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP) + mockForwarder.On("ForwardedPort").Return(uint16(8080)) + + mockInteraction.SetMode(types.InteractiveModeINTERACTIVE) + mockChannel := &MockChannel{} + mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe() + mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe() + mockInteraction.SetChannel(mockChannel) + + go func() { + mockInteraction.Start() + }() + + time.Sleep(50 * time.Millisecond) + } + + assert.NotPanics(t, func() { + i.Stop() + }) + + assert.Nil(t, i.program) + + select { + case <-i.ctx.Done(): + case <-time.After(100 * time.Millisecond): + t.Fatal("context should be cancelled after Stop()") + } + }) + } +} + +func TestInteraction_Start_CommandListSetup(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeFunc := func() error { return nil } + + mockConfig.On("Domain").Return("tunnl.live") + mockConfig.On("TLSEnabled").Return(false) + mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP) + mockForwarder.On("ForwardedPort").Return(uint16(8080)) + + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + mockInteraction.SetMode(types.InteractiveModeINTERACTIVE) + + mockChannel := &MockChannel{} + mockChannel.On("Read", mock.Anything).Return(0, nil) + mockChannel.On("Write", mock.Anything).Return(0, nil) + mockInteraction.SetChannel(mockChannel) + + go func() { + mockInteraction.Start() + }() + + time.Sleep(50 * time.Millisecond) + + i := mockInteraction.(*interaction) + + assert.NotNil(t, i.program, "program should be initialized") + + i.Stop() +} + +func TestInteraction_Start_TextInputSetup(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeFunc := func() error { return nil } + + mockConfig.On("Domain").Return("tunnl.live") + mockConfig.On("TLSEnabled").Return(false) + mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP) + mockForwarder.On("ForwardedPort").Return(uint16(8080)) + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + mockInteraction.SetMode(types.InteractiveModeINTERACTIVE) + + mockChannel := &MockChannel{} + mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe() + mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe() + mockInteraction.SetChannel(mockChannel) + + go func() { + mockInteraction.Start() + }() + + time.Sleep(50 * time.Millisecond) + + i := mockInteraction.(*interaction) + i.Stop() + + mockConfig.AssertExpectations(t) + mockForwarder.AssertExpectations(t) +} + +func TestInteraction_Start_CleanupOnExit(t *testing.T) { + tests := []struct { + name string + closeFunc CloseFunc + expectCloseCalled bool + }{ + { + name: "cleanup calls close function", + closeFunc: func() error { + return nil + }, + expectCloseCalled: true, + }, + { + name: "cleanup with nil close function", + closeFunc: nil, + expectCloseCalled: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + + closeCallCount := 0 + var closeFunc CloseFunc + if tt.closeFunc != nil { + closeFunc = func() error { + closeCallCount++ + return tt.closeFunc() + } + } + + mockConfig.On("Domain").Return("tunnl.live") + mockConfig.On("TLSEnabled").Return(false) + mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP) + mockForwarder.On("ForwardedPort").Return(uint16(8080)) + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + mockInteraction.SetMode(types.InteractiveModeINTERACTIVE) + + mockChannel := &MockChannel{} + mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe() + mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe() + mockInteraction.SetChannel(mockChannel) + + done := make(chan bool, 1) + go func() { + mockInteraction.Start() + done <- true + }() + + time.Sleep(50 * time.Millisecond) + + i := mockInteraction.(*interaction) + i.Stop() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("Start() did not complete") + } + + if tt.expectCloseCalled { + assert.Equal(t, 1, closeCallCount, "close function should be called") + } else { + assert.Equal(t, 0, closeCallCount, "close function should not be called when nil") + } + }) + } +} + +func TestInteraction_Start_WithDifferentChannels(t *testing.T) { + tests := []struct { + name string + setupChannel bool + }{ + { + name: "start with channel set", + setupChannel: true, + }, + { + name: "start with nil channel", + setupChannel: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeFunc := func() error { return nil } + + mockConfig.On("Domain").Return("tunnl.live") + mockConfig.On("TLSEnabled").Return(false) + mockForwarder.On("TunnelType").Return(types.TunnelTypeHTTP) + mockForwarder.On("ForwardedPort").Return(uint16(8080)) + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + mockInteraction.SetMode(types.InteractiveModeINTERACTIVE) + + if tt.setupChannel { + mockChannel := &MockChannel{} + mockChannel.On("Read", mock.Anything).Return(0, assert.AnError).Maybe() + mockChannel.On("Write", mock.Anything).Return(0, nil).Maybe() + mockInteraction.SetChannel(mockChannel) + } + + go func() { + mockInteraction.Start() + }() + + time.Sleep(50 * time.Millisecond) + + i := mockInteraction.(*interaction) + i.Stop() + + mockConfig.AssertExpectations(t) + mockForwarder.AssertExpectations(t) + }) + } +} + +func TestInteraction_Stop_ContextCancellation(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeFunc := func() error { return nil } + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + i := mockInteraction.(*interaction) + + select { + case <-i.ctx.Done(): + t.Fatal("context should not be cancelled initially") + default: + } + + i.Stop() + + select { + case <-i.ctx.Done(): + case <-time.After(100 * time.Millisecond): + t.Fatal("context should be cancelled after Stop()") + } + + assert.NotPanics(t, func() { + i.Stop() + }) +} + +func TestInteraction_Stop_MultipleCallsSafe(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeFunc := func() error { return nil } + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + i := mockInteraction.(*interaction) + + assert.NotPanics(t, func() { + i.Stop() + i.Stop() + i.Stop() + }) + + assert.Nil(t, i.program) +} + +func TestInteraction_Start_HeadlessMode_NoOp(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeFunc := func() error { return nil } + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + mockInteraction.SetMode(types.InteractiveModeHEADLESS) + + done := make(chan bool, 1) + go func() { + mockInteraction.Start() + done <- true + }() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + t.Fatal("headless mode should return immediately") + } + + i := mockInteraction.(*interaction) + assert.Nil(t, i.program, "program should not be created in headless mode") + + mockConfig.AssertNotCalled(t, "Domain") + mockConfig.AssertNotCalled(t, "TLSEnabled") + mockForwarder.AssertNotCalled(t, "TunnelType") + mockForwarder.AssertNotCalled(t, "ForwardedPort") +} + +func TestInteraction_New_ContextInitialization(t *testing.T) { + mockRandom := &MockRandom{} + mockConfig := &MockConfig{} + mockSlug := &MockSlug{} + mockForwarder := &MockForwarder{} + mockSessionRegistry := &MockSessionRegistry{} + closeFunc := func() error { return nil } + mockSlug.On("String").Return("test-slug") + + mockInteraction := New(mockRandom, mockConfig, mockSlug, mockForwarder, mockSessionRegistry, "testuser", closeFunc) + i := mockInteraction.(*interaction) + + assert.NotNil(t, i.ctx, "context should be initialized") + assert.NotNil(t, i.cancel, "cancel function should be initialized") + + select { + case <-i.ctx.Done(): + t.Fatal("context should not be cancelled initially") + default: + } +} diff --git a/session/interaction/model.go b/session/interaction/model.go index 189b0a1..c0d1672 100644 --- a/session/interaction/model.go +++ b/session/interaction/model.go @@ -3,6 +3,7 @@ package interaction import ( "fmt" "time" + "tunnel_pls/internal/random" "tunnel_pls/types" "github.com/charmbracelet/bubbles/help" @@ -22,6 +23,7 @@ func (i commandItem) Title() string { return i.name } func (i commandItem) Description() string { return i.desc } type model struct { + randomizer random.Random domain string protocol string tunnelType types.TunnelType @@ -40,6 +42,25 @@ type model struct { height int } +const ( + ColorPrimary = "#7D56F4" + ColorSecondary = "#04B575" + ColorGray = "#888888" + ColorDarkGray = "#666666" + ColorWhite = "#FAFAFA" + ColorError = "#FF0000" + ColorErrorBg = "#3D0000" + ColorWarning = "#FFA500" + ColorWarningBg = "#3D2000" +) + +const ( + BreakpointTiny = 50 + BreakpointSmall = 60 + BreakpointMedium = 70 + BreakpointLarge = 85 +) + func (m *model) getTunnelURL() string { if m.tunnelType == types.TunnelTypeHTTP { return buildURL(m.protocol, m.interaction.slug.String(), m.domain) diff --git a/session/interaction/slug.go b/session/interaction/slug.go index 2b871d4..ff57cb1 100644 --- a/session/interaction/slug.go +++ b/session/interaction/slug.go @@ -3,7 +3,6 @@ package interaction import ( "fmt" "strings" - "tunnel_pls/internal/random" "tunnel_pls/types" "github.com/charmbracelet/bubbles/key" @@ -22,7 +21,7 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } switch msg.String() { - case "esc": + case "esc", "ctrl+c": m.editingSlug = false m.slugError = "" return m, tea.Batch(tea.ClearScreen, textinput.Blink) @@ -41,19 +40,13 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { m.editingSlug = false m.slugError = "" return m, tea.Batch(tea.ClearScreen, textinput.Blink) - case "ctrl+c": - m.editingSlug = false - m.slugError = "" - return m, tea.Batch(tea.ClearScreen, textinput.Blink) default: if key.Matches(msg, m.keymap.random) { - newSubdomain, err := random.GenerateRandomString(20) + newSubdomain, err := m.randomizer.String(20) if err != nil { return m, cmd } m.slugInput.SetValue(newSubdomain) - m.slugError = "" - m.slugInput, cmd = m.slugInput.Update(msg) } m.slugError = "" m.slugInput, cmd = m.slugInput.Update(msg) @@ -62,163 +55,211 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } func (m *model) slugView() string { - isCompact := shouldUseCompactLayout(m.width, 70) - isVeryCompact := shouldUseCompactLayout(m.width, 50) + isCompact := shouldUseCompactLayout(m.width, BreakpointMedium) + isVeryCompact := shouldUseCompactLayout(m.width, BreakpointTiny) - var boxPadding int - var boxMargin int - if isVeryCompact { - boxPadding = 1 - boxMargin = 1 - } else if isCompact { - boxPadding = 1 - boxMargin = 1 - } else { - boxPadding = 2 - boxMargin = 2 + var b strings.Builder + b.WriteString(m.renderSlugTitle(isVeryCompact)) + + if m.tunnelType != types.TunnelTypeHTTP { + b.WriteString(m.renderTCPWarning(isVeryCompact, isCompact)) + return b.String() } + b.WriteString(m.renderSlugRules(isVeryCompact, isCompact)) + b.WriteString(m.renderSlugInstruction(isVeryCompact)) + b.WriteString(m.renderSlugInput(isVeryCompact, isCompact)) + b.WriteString(m.renderSlugPreview(isVeryCompact)) + b.WriteString(m.renderSlugHelp(isVeryCompact)) + + return b.String() +} + +func (m *model) renderSlugTitle(isVeryCompact bool) string { titleStyle := lipgloss.NewStyle(). Bold(true). - Foreground(lipgloss.Color("#7D56F4")). + Foreground(lipgloss.Color(ColorPrimary)). PaddingTop(1). PaddingBottom(1) - instructionStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#FAFAFA")). - MarginTop(1) + title := "🔧 Edit Subdomain" + if isVeryCompact { + title = "Edit Subdomain" + } - inputBoxStyle := lipgloss.NewStyle(). + return titleStyle.Render(title) + "\n\n" +} + +func (m *model) renderTCPWarning(isVeryCompact, isCompact bool) string { + boxPadding := getPaddingValue(isVeryCompact, isCompact) + boxMargin := getMarginValue(isCompact, 1, 2) + warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60) + + warningBoxStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(ColorWarning)). + Background(lipgloss.Color(ColorWarningBg)). + Bold(true). Border(lipgloss.RoundedBorder()). - BorderForeground(lipgloss.Color("#7D56F4")). + BorderForeground(lipgloss.Color(ColorWarning)). Padding(1, boxPadding). MarginTop(boxMargin). - MarginBottom(boxMargin) + MarginBottom(boxMargin). + Width(warningBoxWidth) helpStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#666666")). + Foreground(lipgloss.Color(ColorDarkGray)). Italic(true). MarginTop(1) - errorBoxStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#FF0000")). - Background(lipgloss.Color("#3D0000")). - Bold(true). - Border(lipgloss.RoundedBorder()). - BorderForeground(lipgloss.Color("#FF0000")). - Padding(0, boxPadding). - MarginTop(1). - MarginBottom(1) + warningText := m.getTCPWarningText(isVeryCompact) + helpText := m.getTCPHelpText(isVeryCompact) + var b strings.Builder + b.WriteString(warningBoxStyle.Render(warningText)) + b.WriteString("\n\n") + b.WriteString(helpStyle.Render(helpText)) + + return b.String() +} + +func (m *model) getTCPWarningText(isVeryCompact bool) string { + if isVeryCompact { + return "⚠️ TCP tunnels don't support custom subdomains." + } + return "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization." +} + +func (m *model) getTCPHelpText(isVeryCompact bool) string { + if isVeryCompact { + return "Press any key to go back" + } + return "Press Enter or Esc to go back" +} + +func (m *model) renderSlugRules(isVeryCompact, isCompact bool) string { + boxPadding := getPaddingValue(isVeryCompact, isCompact) rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60) + rulesBoxStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#FAFAFA")). + Foreground(lipgloss.Color(ColorWhite)). Border(lipgloss.RoundedBorder()). - BorderForeground(lipgloss.Color("#7D56F4")). + BorderForeground(lipgloss.Color(ColorPrimary)). Padding(0, boxPadding). MarginTop(1). MarginBottom(1). Width(rulesBoxWidth) - var b strings.Builder - var title string + rulesContent := m.getRulesContent(isVeryCompact, isCompact) + return rulesBoxStyle.Render(rulesContent) + "\n" +} + +func (m *model) getRulesContent(isVeryCompact, isCompact bool) string { if isVeryCompact { - title = "Edit Subdomain" - } else { - title = "🔧 Edit Subdomain" - } - b.WriteString(titleStyle.Render(title)) - b.WriteString("\n\n") - - if m.tunnelType != types.TunnelTypeHTTP { - warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60) - warningBoxStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#FFA500")). - Background(lipgloss.Color("#3D2000")). - Bold(true). - Border(lipgloss.RoundedBorder()). - BorderForeground(lipgloss.Color("#FFA500")). - Padding(1, boxPadding). - MarginTop(boxMargin). - MarginBottom(boxMargin). - Width(warningBoxWidth) - - var warningText string - if isVeryCompact { - warningText = "⚠️ TCP tunnels don't support custom subdomains." - } else { - warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization." - } - b.WriteString(warningBoxStyle.Render(warningText)) - b.WriteString("\n\n") - - var helpText string - if isVeryCompact { - helpText = "Press any key to go back" - } else { - helpText = "Press Enter or Esc to go back" - } - b.WriteString(helpStyle.Render(helpText)) - return b.String() + return "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -" } - var rulesContent string - if isVeryCompact { - rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -" - } else if isCompact { - rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -" - } else { - rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -" + if isCompact { + return "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -" } - b.WriteString(rulesBoxStyle.Render(rulesContent)) - b.WriteString("\n") - var instruction string + return "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -" +} + +func (m *model) renderSlugInstruction(isVeryCompact bool) string { + instructionStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(ColorWhite)). + MarginTop(1) + + instruction := "Enter your custom subdomain:" if isVeryCompact { instruction = "Custom subdomain:" - } else { - instruction = "Enter your custom subdomain:" } - b.WriteString(instructionStyle.Render(instruction)) - b.WriteString("\n") + + return instructionStyle.Render(instruction) + "\n" +} + +func (m *model) renderSlugInput(isVeryCompact, isCompact bool) string { + boxPadding := getPaddingValue(isVeryCompact, isCompact) + boxMargin := getMarginValue(isCompact, 1, 2) if m.slugError != "" { - errorInputBoxStyle := lipgloss.NewStyle(). - Border(lipgloss.RoundedBorder()). - BorderForeground(lipgloss.Color("#FF0000")). - Padding(1, boxPadding). - MarginTop(boxMargin). - MarginBottom(1) - b.WriteString(errorInputBoxStyle.Render(m.slugInput.View())) - b.WriteString("\n") - b.WriteString(errorBoxStyle.Render("❌ " + m.slugError)) - b.WriteString("\n") - } else { - b.WriteString(inputBoxStyle.Render(m.slugInput.View())) - b.WriteString("\n") + return m.renderErrorInput(boxPadding, boxMargin) } + return m.renderNormalInput(boxPadding, boxMargin) +} + +func (m *model) renderErrorInput(boxPadding, boxMargin int) string { + errorInputBoxStyle := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(ColorError)). + Padding(1, boxPadding). + MarginTop(boxMargin). + MarginBottom(1) + + errorBoxStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(ColorError)). + Background(lipgloss.Color(ColorErrorBg)). + Bold(true). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(ColorError)). + Padding(0, boxPadding). + MarginTop(1). + MarginBottom(1) + + var b strings.Builder + b.WriteString(errorInputBoxStyle.Render(m.slugInput.View())) + b.WriteString("\n") + b.WriteString(errorBoxStyle.Render("❌ " + m.slugError)) + b.WriteString("\n") + + return b.String() +} + +func (m *model) renderNormalInput(boxPadding, boxMargin int) string { + inputBoxStyle := lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(lipgloss.Color(ColorPrimary)). + Padding(1, boxPadding). + MarginTop(boxMargin). + MarginBottom(boxMargin) + + return inputBoxStyle.Render(m.slugInput.View()) + "\n" +} + +func (m *model) renderSlugPreview(isVeryCompact bool) string { previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain) previewWidth := getResponsiveWidth(m.width, 10, 30, 80) - if len(previewURL) > previewWidth-10 { + if isVeryCompact { previewURL = truncateString(previewURL, previewWidth-10) } previewStyle := lipgloss.NewStyle(). - Foreground(lipgloss.Color("#04B575")). + Foreground(lipgloss.Color(ColorSecondary)). Italic(true). Width(previewWidth) - b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL))) - b.WriteString("\n") - var helpText string + return previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)) + "\n" +} + +func (m *model) renderSlugHelp(isVeryCompact bool) string { + helpStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color(ColorDarkGray)). + Italic(true). + MarginTop(1) + + helpText := "Press Enter to save • CTRL+R for random • Esc to cancel" if isVeryCompact { helpText = "Enter: save • CTRL+R: random • Esc: cancel" - } else { - helpText = "Press Enter to save • CTRL+R for random • Esc to cancel" } - b.WriteString(helpStyle.Render(helpText)) - return b.String() + return helpStyle.Render(helpText) +} + +func getPaddingValue(isVeryCompact, isCompact bool) int { + if isVeryCompact || isCompact { + return 1 + } + return 2 } diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index f9f9d6e..2d9dd3e 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -2,6 +2,9 @@ package lifecycle import ( "errors" + "io" + "net" + "sync" "time" portUtil "tunnel_pls/internal/port" @@ -22,7 +25,9 @@ type SessionRegistry interface { } type lifecycle struct { + mu sync.Mutex status types.SessionStatus + closeErr error conn ssh.Conn channel ssh.Channel forwarder Forwarder @@ -49,6 +54,7 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti type Lifecycle interface { Connection() ssh.Conn + Channel() ssh.Channel PortRegistry() portUtil.Port User() string SetChannel(channel ssh.Channel) @@ -69,33 +75,48 @@ func (l *lifecycle) User() string { func (l *lifecycle) SetChannel(channel ssh.Channel) { l.channel = channel } + +func (l *lifecycle) Channel() ssh.Channel { + return l.channel +} + func (l *lifecycle) Connection() ssh.Conn { return l.conn } + func (l *lifecycle) SetStatus(status types.SessionStatus) { + l.mu.Lock() + defer l.mu.Unlock() l.status = status - if status == types.SessionStatusRUNNING && l.startedAt.IsZero() { - l.startedAt = time.Now() - } } -func closeIfNotNil(c interface{ Close() error }) error { - if c != nil { - return c.Close() - } - return nil +func (l *lifecycle) IsActive() bool { + l.mu.Lock() + defer l.mu.Unlock() + return l.status == types.SessionStatusRUNNING } func (l *lifecycle) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + if l.status == types.SessionStatusCLOSED { + return l.closeErr + } + l.status = types.SessionStatusCLOSED + var errs []error tunnelType := l.forwarder.TunnelType() - if err := closeIfNotNil(l.channel); err != nil { - errs = append(errs, err) + if l.channel != nil { + if err := l.channel.Close(); err != nil && !isClosedError(err) { + errs = append(errs, err) + } } - if err := closeIfNotNil(l.conn); err != nil { - errs = append(errs, err) + if l.conn != nil { + if err := l.conn.Close(); err != nil && !isClosedError(err) { + errs = append(errs, err) + } } clientSlug := l.slug.String() @@ -106,19 +127,19 @@ func (l *lifecycle) Close() error { l.sessionRegistry.Remove(key) if tunnelType == types.TunnelTypeTCP { - if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil { - errs = append(errs, err) - } - if err := l.forwarder.Close(); err != nil { - errs = append(errs, err) - } + errs = append(errs, l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false)) + errs = append(errs, l.forwarder.Close()) } - return errors.Join(errs...) + l.closeErr = errors.Join(errs...) + return l.closeErr } -func (l *lifecycle) IsActive() bool { - return l.status == types.SessionStatusRUNNING +func isClosedError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || err.Error() == "EOF" } func (l *lifecycle) StartedAt() time.Time { diff --git a/session/lifecycle/lifecycle_test.go b/session/lifecycle/lifecycle_test.go new file mode 100644 index 0000000..608b3a8 --- /dev/null +++ b/session/lifecycle/lifecycle_test.go @@ -0,0 +1,303 @@ +package lifecycle + +import ( + "context" + "errors" + "io" + "net" + "testing" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/ssh" +) + +type MockSessionRegistry struct { + mock.Mock +} + +func (m *MockSessionRegistry) Remove(key types.SessionKey) { + m.Called(key) +} + +type MockForwarder struct { + mock.Mock +} + +func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { + args := m.Called(origin) + return args.Get(0).([]byte) +} + +func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { + m.Called(dst, src) +} + +func (m *MockForwarder) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockForwarder) TunnelType() types.TunnelType { + args := m.Called() + return args.Get(0).(types.TunnelType) +} + +func (m *MockForwarder) ForwardedPort() uint16 { + args := m.Called() + return args.Get(0).(uint16) +} + +func (m *MockForwarder) SetType(tunnelType types.TunnelType) { + m.Called(tunnelType) +} + +func (m *MockForwarder) SetForwardedPort(port uint16) { + m.Called(port) +} + +func (m *MockForwarder) SetListener(listener net.Listener) { + m.Called(listener) +} + +func (m *MockForwarder) Listener() net.Listener { + args := m.Called() + return args.Get(0).(net.Listener) +} + +func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(ctx, origin) + if args.Get(0) == nil { + return nil, nil, args.Error(2) + } + return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) +} + +type MockPort struct { + mock.Mock +} + +func (m *MockPort) AddRange(startPort, endPort uint16) error { + return m.Called(startPort, endPort).Error(0) +} +func (m *MockPort) Unassigned() (uint16, bool) { + args := m.Called() + var port uint16 + if args.Get(0) != nil { + switch v := args.Get(0).(type) { + case int: + port = uint16(v) + case uint16: + port = v + case uint32: + port = uint16(v) + case int32: + port = uint16(v) + case float64: + port = uint16(v) + default: + port = uint16(args.Int(0)) + } + } + return port, args.Bool(1) +} +func (m *MockPort) SetStatus(port uint16, assigned bool) error { + return m.Called(port, assigned).Error(0) +} +func (m *MockPort) Claim(port uint16) bool { + return m.Called(port).Bool(0) +} + +type MockSlug struct { + mock.Mock +} + +func (ms *MockSlug) Set(slug string) { + ms.Called(slug) +} +func (ms *MockSlug) String() string { + return ms.Called().String(0) +} + +type MockSSHConn struct { + ssh.Conn + mock.Mock +} + +func (m *MockSSHConn) Close() error { + args := m.Called() + return args.Error(0) +} + +type MockSSHChannel struct { + ssh.Channel + mock.Mock +} + +func (m *MockSSHChannel) Close() error { + return m.Called().Error(0) +} + +func TestNew(t *testing.T) { + mockSSHConn := new(MockSSHConn) + mockForwarder := &MockForwarder{} + mockSlug := &MockSlug{} + mockPort := &MockPort{} + mockSessionRegistry := &MockSessionRegistry{} + + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") + + assert.NotNil(t, mockLifecycle.Connection()) + assert.NotNil(t, mockLifecycle.User()) + assert.NotNil(t, mockLifecycle.PortRegistry()) + assert.NotNil(t, mockLifecycle.StartedAt()) +} + +func TestLifecycle_User(t *testing.T) { + mockSSHConn := new(MockSSHConn) + mockForwarder := &MockForwarder{} + mockSlug := &MockSlug{} + mockPort := &MockPort{} + mockSessionRegistry := &MockSessionRegistry{} + + user := "mas-fuad" + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, user) + assert.Equal(t, user, mockLifecycle.User()) +} + +func TestLifecycle_SetChannel(t *testing.T) { + mockSSHConn := new(MockSSHConn) + mockForwarder := &MockForwarder{} + mockSlug := &MockSlug{} + mockPort := &MockPort{} + mockSessionRegistry := &MockSessionRegistry{} + + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") + + mockSSHChannel := &MockSSHChannel{} + + mockLifecycle.SetChannel(mockSSHChannel) + + assert.Equal(t, mockSSHChannel, mockLifecycle.Channel()) +} + +func TestLifecycle_SetStatus(t *testing.T) { + mockSSHConn := new(MockSSHConn) + mockForwarder := &MockForwarder{} + mockSlug := &MockSlug{} + mockPort := &MockPort{} + mockSessionRegistry := &MockSessionRegistry{} + + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") + + mockLifecycle.SetStatus(types.SessionStatusRUNNING) + assert.True(t, mockLifecycle.IsActive()) +} + +func TestLifecycle_IsActive(t *testing.T) { + mockSSHConn := new(MockSSHConn) + mockForwarder := &MockForwarder{} + mockSlug := &MockSlug{} + mockPort := &MockPort{} + mockSessionRegistry := &MockSessionRegistry{} + + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") + + mockLifecycle.SetStatus(types.SessionStatusRUNNING) + assert.True(t, mockLifecycle.IsActive()) +} + +func TestLifecycle_Close(t *testing.T) { + tests := []struct { + name string + tunnelType types.TunnelType + connCloseErr error + channelCloseErr error + expectErr bool + alreadyClosed bool + }{ + { + name: "Close HTTP forwarding success", + tunnelType: types.TunnelTypeHTTP, + expectErr: false, + }, + { + name: "Close TCP forwarding success", + tunnelType: types.TunnelTypeTCP, + expectErr: false, + }, + { + name: "Close with conn close error", + tunnelType: types.TunnelTypeHTTP, + connCloseErr: errors.New("conn close error"), + expectErr: true, + }, + { + name: "Close with channel close error", + tunnelType: types.TunnelTypeHTTP, + channelCloseErr: errors.New("channel close error"), + expectErr: true, + }, + { + name: "Close when already closed", + tunnelType: types.TunnelTypeHTTP, + alreadyClosed: true, + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSSHConn := &MockSSHConn{} + mockSSHConn.On("Close").Return(tt.connCloseErr) + + mockForwarder := &MockForwarder{} + mockForwarder.On("TunnelType").Return(tt.tunnelType) + if tt.tunnelType == types.TunnelTypeTCP { + mockForwarder.On("ForwardedPort").Return(uint16(8080)) + mockForwarder.On("Close").Return(nil) + } + + mockSlug := &MockSlug{} + mockSlug.On("String").Return("test-slug") + + mockPort := &MockPort{} + if tt.tunnelType == types.TunnelTypeTCP { + mockPort.On("SetStatus", uint16(8080), false).Return(nil) + } + + mockSessionRegistry := &MockSessionRegistry{} + mockSessionRegistry.On("Remove", mock.Anything).Return() + + mockSSHChannel := &MockSSHChannel{} + mockSSHChannel.On("Close").Return(tt.channelCloseErr) + + mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad") + + mockLifecycle.SetStatus(types.SessionStatusRUNNING) + mockLifecycle.SetChannel(mockSSHChannel) + + if tt.alreadyClosed { + err := mockLifecycle.Close() + assert.NoError(t, err) + } + + err := mockLifecycle.Close() + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + assert.False(t, mockLifecycle.IsActive()) + + mockSSHConn.AssertExpectations(t) + mockForwarder.AssertExpectations(t) + mockSlug.AssertExpectations(t) + mockPort.AssertExpectations(t) + mockSessionRegistry.AssertExpectations(t) + mockSSHChannel.AssertExpectations(t) + }) + } +} diff --git a/session/session.go b/session/session.go index b1895ab..cc27c4c 100644 --- a/session/session.go +++ b/session/session.go @@ -1,7 +1,6 @@ package session import ( - "bytes" "encoding/binary" "errors" "fmt" @@ -37,6 +36,7 @@ type Session interface { } type session struct { + randomizer random.Random config config.Config initialReq <-chan *ssh.Request sshChan <-chan ssh.NewChannel @@ -47,23 +47,35 @@ type session struct { registry registry.Registry } +type Config struct { + Randomizer random.Random + Config config.Config + Conn *ssh.ServerConn + InitialReq <-chan *ssh.Request + SshChan <-chan ssh.NewChannel + SessionRegistry registry.Registry + PortRegistry portUtil.Port + User string +} + var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} -func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session { +func New(conf *Config) Session { slugManager := slug.New() - forwarderManager := forwarder.New(config, slugManager, conn) - lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user) - interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close) + forwarderManager := forwarder.New(conf.Config, slugManager, conf.Conn) + lifecycleManager := lifecycle.New(conf.Conn, forwarderManager, slugManager, conf.PortRegistry, conf.SessionRegistry, conf.User) + interactionManager := interaction.New(conf.Randomizer, conf.Config, slugManager, forwarderManager, conf.SessionRegistry, conf.User, lifecycleManager.Close) return &session{ - config: config, - initialReq: initialReq, - sshChan: sshChan, + randomizer: conf.Randomizer, + config: conf.Config, + initialReq: conf.InitialReq, + sshChan: conf.SshChan, lifecycle: lifecycleManager, interaction: interactionManager, forwarder: forwarderManager, slug: slugManager, - registry: sessionRegistry, + registry: conf.SessionRegistry, } } @@ -85,12 +97,12 @@ func (s *session) Slug() slug.Slug { func (s *session) Detail() *types.Detail { tunnelTypeMap := map[types.TunnelType]string{ - types.TunnelTypeHTTP: "TunnelTypeHTTP", - types.TunnelTypeTCP: "TunnelTypeTCP", + types.TunnelTypeHTTP: "HTTP", + types.TunnelTypeTCP: "TCP", } tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()] if !ok { - tunnelType = "TunnelTypeUNKNOWN" + tunnelType = "UNKNOWN" } return &types.Detail{ @@ -113,7 +125,7 @@ func (s *session) Start() error { } if s.shouldRejectUnauthorized() { - return s.denyForwardingRequest(tcpipReq, nil, nil, fmt.Sprintf("headless forwarding only allowed on node mode")) + return s.denyForwardingRequest(tcpipReq, nil, nil, "headless forwarding only allowed on node mode") } if err := s.HandleTCPIPForward(tcpipReq); err != nil { @@ -160,13 +172,11 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error { } func (s *session) handleMissingForwardRequest() error { - err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort())) + err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort())) if err != nil { return err } - if err = s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) - } + return fmt.Errorf("no forwarding Request") } @@ -182,7 +192,6 @@ func (s *session) waitForSessionEnd() error { } if err := s.lifecycle.Close(); err != nil { - log.Printf("failed to close session: %v", err) return err } return nil @@ -227,8 +236,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error { for req := range GlobalRequest { switch req.Type { case "shell", "pty-req": - err := req.Reply(true, nil) - if err != nil { + if err := req.Reply(true, nil); err != nil { return err } case "window-change": @@ -237,8 +245,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error { } default: log.Println("Unknown request type:", req.Type) - err := req.Reply(false, nil) - if err != nil { + if err := req.Reply(false, nil); err != nil { return err } } @@ -246,24 +253,24 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error { return nil } -func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) { - address, err = readSSHString(payloadReader) - if err != nil { - return "", 0, err +func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) { + var forwardPayload struct { + BindAddr string + BindPort uint32 } - var rawPortToBind uint32 - if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil { - return "", 0, err + if err = ssh.Unmarshal(payload, &forwardPayload); err != nil { + return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err) } - if rawPortToBind > 65535 { + if forwardPayload.BindPort > 65535 { return "", 0, fmt.Errorf("port is larger than allowed port of 65535") } - port = uint16(rawPortToBind) + port = uint16(forwardPayload.BindPort) + if isBlockedPort(port) { - return "", 0, fmt.Errorf("port is block") + return "", 0, fmt.Errorf("port is blocked") } if port == 0 { @@ -271,10 +278,10 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, if !ok { return "", 0, fmt.Errorf("no available port") } - return address, unassigned, err + return forwardPayload.BindAddr, unassigned, nil } - return address, port, err + return forwardPayload.BindAddr, port, nil } func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error { @@ -282,37 +289,25 @@ func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, if key != nil { s.registry.Remove(*key) } + if listener != nil { - if err := listener.Close(); err != nil { - errs = append(errs, fmt.Errorf("close listener: %w", err)) - } - } - if err := req.Reply(false, nil); err != nil { - errs = append(errs, fmt.Errorf("reply request: %w", err)) - } - if err := s.lifecycle.Close(); err != nil { - errs = append(errs, fmt.Errorf("close session: %w", err)) + errs = append(errs, listener.Close()) } + + errs = append(errs, req.Reply(false, nil)) + errs = append(errs, s.lifecycle.Close()) errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg)) return errors.Join(errs...) } -func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) { - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.BigEndian, uint32(port)) - if err != nil { - return err - } - - err = req.Reply(true, buf.Bytes()) - if err != nil { - return err - } - return nil -} - func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error { - err := s.approveForwardingRequest(req, portToBind) + replyPayload := struct { + BoundPort uint32 + }{ + BoundPort: uint32(portToBind), + } + err := req.Reply(true, ssh.Marshal(replyPayload)) + if err != nil { return err } @@ -330,9 +325,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen } func (s *session) HandleTCPIPForward(req *ssh.Request) error { - reader := bytes.NewReader(req.Payload) - - address, port, err := s.parseForwardPayload(reader) + address, port, err := s.parseForwardPayload(req.Payload) if err != nil { return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error())) } @@ -346,7 +339,7 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) error { } func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { - randomString, err := random.GenerateRandomString(20) + randomString, err := s.randomizer.String(20) if err != nil { return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err)) } @@ -364,13 +357,13 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error { if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed { - return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) + return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind)) } tcpServer := transport.NewTCPServer(portToBind, s.forwarder) listener, err := tcpServer.Listen() if err != nil { - return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind)) + return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind)) } key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP} @@ -393,18 +386,6 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin return nil } -func readSSHString(reader io.Reader) (string, error) { - var length uint32 - if err := binary.Read(reader, binary.BigEndian, &length); err != nil { - return "", err - } - strBytes := make([]byte, length) - if _, err := reader.Read(strBytes); err != nil { - return "", err - } - return string(strBytes), nil -} - func isBlockedPort(port uint16) bool { if port == 80 || port == 443 { return false diff --git a/session/session_test.go b/session/session_test.go new file mode 100644 index 0000000..8a89125 --- /dev/null +++ b/session/session_test.go @@ -0,0 +1,1360 @@ +package session + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/binary" + "encoding/pem" + "fmt" + "io" + "net" + "strconv" + "strings" + "testing" + "time" + "tunnel_pls/internal/config" + "tunnel_pls/internal/registry" + "tunnel_pls/session/lifecycle" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +type mockRandom struct { + mock.Mock +} + +func (m *mockRandom) String(length int) (string, error) { + args := m.Called(length) + return args.String(0), args.Error(1) +} + +type mockConfig struct { + mock.Mock + config.Config +} + +func (m *mockConfig) Domain() string { return m.Called().String(0) } +func (m *mockConfig) SSHPort() string { return m.Called().String(0) } +func (m *mockConfig) Mode() types.ServerMode { + args := m.Called() + if args.Get(0) == nil { + return 0 + } + switch v := args.Get(0).(type) { + case types.ServerMode: + return v + case int: + return types.ServerMode(v) + default: + return types.ServerMode(args.Int(0)) + } +} +func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) } + +type mockRegistry struct { + mock.Mock + registry.Registry + removedKey types.SessionKey +} + +func (m *mockRegistry) Register(key types.SessionKey, session registry.Session) bool { + return m.Called(key, session).Bool(0) +} + +func (m *mockRegistry) Remove(key types.SessionKey) { + m.removedKey = key +} + +type mockPort struct { + mock.Mock +} + +func (m *mockPort) AddRange(startPort, endPort uint16) error { + return m.Called(startPort, endPort).Error(0) +} +func (m *mockPort) Unassigned() (uint16, bool) { + args := m.Called() + var port uint16 + if args.Get(0) != nil { + switch v := args.Get(0).(type) { + case int: + port = uint16(v) + case uint16: + port = v + case uint32: + port = uint16(v) + case int32: + port = uint16(v) + case float64: + port = uint16(v) + default: + port = uint16(args.Int(0)) + } + } + return port, args.Bool(1) +} +func (m *mockPort) SetStatus(port uint16, assigned bool) error { + return m.Called(port, assigned).Error(0) +} +func (m *mockPort) Claim(port uint16) bool { + return m.Called(port).Bool(0) +} + +type mockSSHConn struct { + ssh.Conn + mock.Mock +} + +func (m *mockSSHConn) Wait() error { + return m.Called().Error(0) +} + +func (m *mockSSHConn) Close() error { + return m.Called().Error(0) +} + +func (m *mockSSHConn) User() string { + return m.Called().String(0) +} + +func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, sChans <-chan ssh.NewChannel, cConn ssh.Conn, cleanup func()) { + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + privDER := x509.MarshalPKCS1PrivateKey(key) + privBlock := pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: privDER, + } + pk, err := ssh.ParsePrivateKey(pem.EncodeToMemory(&privBlock)) + require.NoError(t, err) + + sCfg := &ssh.ServerConfig{ + NoClientAuth: true, + } + sCfg.AddHostKey(pk) + + cCfg := &ssh.ClientConfig{ + User: "test", + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + + var sConnObj *ssh.ServerConn + var sChansChan <-chan ssh.NewChannel + var sReqsChan <-chan *ssh.Request + + errChan := make(chan error, 1) + go func() { + conn, err := l.Accept() + if err != nil { + errChan <- err + return + } + sConnObj, sChansChan, sReqsChan, err = ssh.NewServerConn(conn, sCfg) + errChan <- err + }() + + conn, err := net.Dial("tcp", l.Addr().String()) + require.NoError(t, err) + cConnObj, cChans, cReqs, err := ssh.NewClientConn(conn, "pipe", cCfg) + require.NoError(t, err) + + go ssh.DiscardRequests(cReqs) + go func() { + for newChan := range cChans { + if newChan.ChannelType() == "session" { + continue + } + err = newChan.Reject(ssh.Prohibited, "") + assert.NoError(t, err) + } + }() + + select { + case err := <-errChan: + require.NoError(t, err) + case <-time.After(5 * time.Second): + t.Fatal("SSH handshake timed out") + } + + return sConnObj, sReqsChan, sChansChan, cConnObj, func() { + _ = cConnObj.Close() + _ = sConnObj.Close() + _ = l.Close() + } +} + +func TestNew(t *testing.T) { + conf := &Config{ + Randomizer: &mockRandom{}, + Config: &mockConfig{}, + Conn: &ssh.ServerConn{}, + InitialReq: make(chan *ssh.Request), + SshChan: make(chan ssh.NewChannel), + SessionRegistry: &mockRegistry{}, + PortRegistry: &mockPort{}, + User: "testuser", + } + + s := New(conf) + assert.NotNil(t, s) + assert.NotNil(t, s.Lifecycle()) + assert.NotNil(t, s.Interaction()) + assert.NotNil(t, s.Forwarder()) + assert.NotNil(t, s.Slug()) +} + +func TestDetail(t *testing.T) { + conf := &Config{ + Randomizer: &mockRandom{}, + Config: &mockConfig{}, + Conn: &ssh.ServerConn{}, + InitialReq: make(chan *ssh.Request), + SshChan: make(chan ssh.NewChannel), + SessionRegistry: &mockRegistry{}, + PortRegistry: &mockPort{}, + User: "testuser", + } + + s := New(conf).(*session) + s.forwarder.SetType(types.TunnelTypeHTTP) + s.slug.Set("test-slug") + s.lifecycle.SetStatus(types.SessionStatusRUNNING) + + detail := s.Detail() + assert.Equal(t, "HTTP", detail.ForwardingType) + assert.Equal(t, "test-slug", detail.Slug) + assert.Equal(t, "testuser", detail.UserID) + assert.True(t, detail.Active) + + s.forwarder.SetType(types.TunnelTypeTCP) + detail = s.Detail() + assert.Equal(t, "TCP", detail.ForwardingType) + + s.forwarder.SetType(types.TunnelTypeUNKNOWN) + detail = s.Detail() + assert.Equal(t, "UNKNOWN", detail.ForwardingType) +} + +func TestIsBlockedPort(t *testing.T) { + tests := []struct { + port uint16 + expected bool + }{ + {80, false}, + {443, false}, + {22, true}, + {1023, true}, + {1024, false}, + {1080, true}, + {3306, true}, + {8080, true}, + {0, false}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("Port %d", tt.port), func(t *testing.T) { + assert.Equal(t, tt.expected, isBlockedPort(tt.port)) + }) + } +} + +func TestHandleGlobalRequest(t *testing.T) { + _, sReqs, _, cConn, cleanup := setupSSH(t) + defer cleanup() + + conf := &Config{ + Randomizer: &mockRandom{}, + Config: &mockConfig{}, + Conn: &ssh.ServerConn{}, + InitialReq: make(chan *ssh.Request), + SshChan: make(chan ssh.NewChannel), + SessionRegistry: &mockRegistry{}, + PortRegistry: &mockPort{}, + User: "testuser", + } + s := New(conf).(*session) + + done := make(chan struct{}) + go func() { + _ = s.HandleGlobalRequest(sReqs) + close(done) + }() + + tests := []struct { + name string + reqType string + payload []byte + wantReply bool + expected bool + }{ + {"shell", "shell", nil, true, true}, + {"pty-req", "pty-req", nil, true, true}, + {"window-change valid", "window-change", make([]byte, 16), true, true}, + {"window-change invalid", "window-change", make([]byte, 4), true, false}, + {"unknown", "unknown", nil, true, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ok, _, err := cConn.SendRequest(tt.reqType, tt.wantReply, tt.payload) + assert.NoError(t, err) + assert.Equal(t, tt.expected, ok) + }) + } + + err := cConn.Close() + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("HandleGlobalRequest timed out after cConn.Close()") + } +} + +func TestHandleTCPIPForward_Table(t *testing.T) { + setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) { + sConn, sReqs, _, cConn, cleanup := setupSSH(t) + mRegistry := &mockRegistry{} + mPort := &mockPort{} + mRandom := &mockRandom{} + conf := &Config{ + Randomizer: mRandom, + Config: &mockConfig{}, + Conn: sConn, + InitialReq: make(chan *ssh.Request), + SshChan: make(chan ssh.NewChannel), + SessionRegistry: mRegistry, + PortRegistry: mPort, + User: "testuser", + } + s := New(conf).(*session) + return s, mRegistry, mPort, mRandom, sConn, sReqs, cConn, cleanup + } + + t.Run("HTTP Forward Success", func(t *testing.T) { + s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mRandom.On("String", 20).Return("test-slug-1234567890", nil) + mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 80) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + var req *ssh.Request + select { + case req = <-sReqs: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for tcpip-forward request") + } + + err := s.HandleTCPIPForward(req) + assert.NoError(t, err) + assert.Equal(t, "test-slug-1234567890", s.slug.String()) + }) + + t.Run("TCP Forward Success", func(t *testing.T) { + s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mPort.On("Claim", mock.Anything).Return(true) + mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 0) + + mPort.On("Unassigned").Return(uint16(12345), true) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + var req *ssh.Request + select { + case req = <-sReqs: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for tcpip-forward request") + } + + err := s.HandleTCPIPForward(req) + assert.NoError(t, err) + assert.Equal(t, uint16(12345), s.forwarder.ForwardedPort()) + }) + + t.Run("Invalid Payload", func(t *testing.T) { + s, _, _, _, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + payload := []byte{0, 0, 0} + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + var req *ssh.Request + select { + case req = <-sReqs: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for tcpip-forward request") + } + + err := s.HandleTCPIPForward(req) + assert.Error(t, err) + }) + + t.Run("Blocked Port", func(t *testing.T) { + s, _, _, _, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 22) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + var req *ssh.Request + select { + case req = <-sReqs: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for tcpip-forward request") + } + + err := s.HandleTCPIPForward(req) + assert.Error(t, err) + }) +} + +func TestStart_Table(t *testing.T) { + setup := func(t *testing.T) (*session, *Config, ssh.Conn, func()) { + sConn, sReqs, sChans, cConn, cleanup := setupSSH(t) + mRegistry := &mockRegistry{} + mPort := &mockPort{} + mRandom := &mockRandom{} + mConfig := &mockConfig{} + mConfig.On("Mode").Return(types.ServerModeSTANDALONE) + mConfig.On("Domain").Return("example.com") + mConfig.On("SSHPort").Return("2222") + + conf := &Config{ + Randomizer: mRandom, + Config: mConfig, + Conn: sConn, + InitialReq: sReqs, + SshChan: sChans, + SessionRegistry: mRegistry, + PortRegistry: mPort, + User: "testuser", + } + s := New(conf).(*session) + return s, conf, cConn, cleanup + } + + t.Run("Full Success TCP", func(t *testing.T) { + s, conf, cConn, cleanup := setup(t) + defer cleanup() + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 0) + + conf.PortRegistry.(*mockPort).On("Claim", mock.Anything).Return(true) + conf.PortRegistry.(*mockPort).On("Unassigned").Return(uint16(0), true) + conf.PortRegistry.(*mockPort).On("SetStatus", mock.AnythingOfType("uint16"), mock.Anything).Return(nil) + conf.SessionRegistry.(*mockRegistry).On("Register", mock.Anything, mock.Anything).Return(true) + conf.Config.(*mockConfig).On("TLSEnabled").Return(false) + go func() { + time.Sleep(200 * time.Millisecond) + ch, reqs, err := cConn.OpenChannel("session", nil) + if err == nil { + go ssh.DiscardRequests(reqs) + time.Sleep(200 * time.Millisecond) + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + time.Sleep(200 * time.Millisecond) + write, err := ch.Write([]byte("q")) + assert.NoError(t, err) + assert.NotZero(t, write) + time.Sleep(100 * time.Millisecond) + _ = ch.Close() + } + _ = cConn.Close() + }() + + err := s.Start() + assert.NoError(t, err) + }) + + t.Run("Headless mode success", func(t *testing.T) { + s, conf, cConn, cleanup := setup(t) + defer cleanup() + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 80) + + conf.Randomizer.(*mockRandom).On("String", 20).Return("headless-slug", nil) + conf.SessionRegistry.(*mockRegistry).On("Register", mock.Anything, mock.Anything).Return(true) + + go func() { + time.Sleep(600 * time.Millisecond) + _, _, err := cConn.SendRequest("tcpip-forward", true, payload) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + err = cConn.Close() + assert.NoError(t, err) + + }() + + err := s.Start() + assert.NoError(t, err) + }) + + t.Run("Missing Forward Request", func(t *testing.T) { + s, _, cConn, cleanup := setup(t) + defer cleanup() + + go func() { + time.Sleep(1200 * time.Millisecond) + _ = cConn.Close() + }() + + err := s.Start() + assert.Error(t, err) + assert.Contains(t, err.Error(), "no forwarding Request") + }) + + t.Run("Unauthorized Headless", func(t *testing.T) { + _, conf, cConn, cleanup := setup(t) + defer cleanup() + + conf.User = "UNAUTHORIZED" + s := New(conf).(*session) + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 80) + + go func() { + time.Sleep(600 * time.Millisecond) + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + err := s.Start() + assert.Error(t, err) + }) +} + +func TestForwardingFailures(t *testing.T) { + setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) { + sConn, sReqs, _, cConn, cleanup := setupSSH(t) + mRegistry := &mockRegistry{} + mPort := &mockPort{} + mRandom := &mockRandom{} + conf := &Config{ + Randomizer: mRandom, + Config: &mockConfig{}, + Conn: sConn, + InitialReq: make(chan *ssh.Request), + SshChan: make(chan ssh.NewChannel), + SessionRegistry: mRegistry, + PortRegistry: mPort, + User: "testuser", + } + s := New(conf).(*session) + return s, mRegistry, mPort, mRandom, sConn, sReqs, cConn, cleanup + } + + t.Run("HTTP Registration Failed", func(t *testing.T) { + s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mRandom.On("String", 20).Return("test-slug", nil) + mRegistry.On("Register", mock.Anything, mock.Anything).Return(false) + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 80) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + var req *ssh.Request + select { + case req = <-sReqs: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for tcpip-forward request") + } + + err := s.HandleTCPIPForward(req) + assert.Error(t, err) + }) + + t.Run("TCP Port Claim Failed", func(t *testing.T) { + s, _, mPort, _, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mPort.On("Claim", mock.Anything).Return(false) + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 1234) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + var req *ssh.Request + select { + case req = <-sReqs: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for tcpip-forward request") + } + + err := s.HandleTCPIPForward(req) + assert.Error(t, err) + }) + + t.Run("HTTP Randomizer Error", func(t *testing.T) { + s, _, _, mRandom, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mRandom.On("String", 20).Return("", fmt.Errorf("random error")) + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 80) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + req := <-sReqs + err := s.HandleTCPIPForward(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "random error") + }) + + t.Run("Port Registry No Port", func(t *testing.T) { + s, _, mPort, _, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mPort.On("Unassigned").Return(uint16(0), false) + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 0) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + req := <-sReqs + err := s.HandleTCPIPForward(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no available port") + }) + + t.Run("Port too large", func(t *testing.T) { + s, _, _, _, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 70000) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + req := <-sReqs + err := s.HandleTCPIPForward(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "port is larger than allowed") + }) + + t.Run("TCP Registration Failed", func(t *testing.T) { + s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mPort.On("Claim", mock.Anything).Return(true) + mRegistry.On("Register", mock.Anything, mock.Anything).Return(false) + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 1234) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + req := <-sReqs + err := s.HandleTCPIPForward(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "Failed to register TunnelTypeTCP client") + }) + + t.Run("Finalize Forwarding Failure", func(t *testing.T) { + s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mRandom.On("String", 20).Return("test-slug", nil) + mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], 80) + + go func() { + _, _, err := cConn.SendRequest("tcpip-forward", true, payload) + assert.Error(t, err, io.EOF) + }() + + req := <-sReqs + err := cConn.Close() + assert.NoError(t, err) + + time.Sleep(50 * time.Millisecond) + + err = s.HandleTCPIPForward(req) + assert.Error(t, err) + }) + + t.Run("TCP Listen Failure", func(t *testing.T) { + s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mPort.On("Claim", mock.Anything).Return(true) + mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + + l, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + t.Fatal(err) + } + defer func(l net.Listener) { + err = l.Close() + assert.NoError(t, err) + }(l) + _, portStr, _ := net.SplitHostPort(l.Addr().String()) + port, _ := strconv.Atoi(portStr) + + payload := make([]byte, 4+9+4) + binary.BigEndian.PutUint32(payload[0:4], 9) + copy(payload[4:13], "localhost") + binary.BigEndian.PutUint32(payload[13:17], uint32(port)) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, payload) + }() + + req := <-sReqs + err = s.HandleTCPIPForward(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "is already in use or restricted") + }) +} + +func TestSetupInteractiveMode_Error(t *testing.T) { + sConn, _, sChans, _, cleanup := setupSSH(t) + defer cleanup() + + conf := &Config{ + Randomizer: &mockRandom{}, + Config: &mockConfig{}, + Conn: sConn, + InitialReq: make(chan *ssh.Request), + SshChan: sChans, + SessionRegistry: &mockRegistry{}, + PortRegistry: &mockPort{}, + User: "testuser", + } + s := New(conf).(*session) + + mockChan := &mockNewChanFail{} + err := s.setupInteractiveMode(mockChan) + if err == nil { + t.Error("expected error, got nil") + } +} + +type mockNewChanFail struct { + ssh.NewChannel +} + +func (m *mockNewChanFail) Accept() (ssh.Channel, <-chan *ssh.Request, error) { + return nil, nil, fmt.Errorf("accept failed") +} + +func TestWaitForTCPIPForward_EdgeCases(t *testing.T) { + t.Run("Wrong Request Type", func(t *testing.T) { + _, sReqs, _, cConn, cleanup := setupSSH(t) + defer cleanup() + + s := &session{initialReq: sReqs} + + go func() { + _, _, _ = cConn.SendRequest("not-tcpip-forward", true, nil) + }() + + req := s.waitForTCPIPForward() + if req != nil { + t.Error("expected nil request") + } + }) + + t.Run("Channel Closed", func(t *testing.T) { + initialReq := make(chan *ssh.Request) + s := &session{initialReq: initialReq} + close(initialReq) + + req := s.waitForTCPIPForward() + if req != nil { + t.Error("expected nil request") + } + }) +} + +func TestSetupSessionMode_ChannelClosed(t *testing.T) { + sshChan := make(chan ssh.NewChannel) + s := &session{sshChan: sshChan} + close(sshChan) + + err := s.setupSessionMode() + if err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestStart_SetupSessionModeError(t *testing.T) { + sshChan := make(chan ssh.NewChannel, 1) + conf := &Config{ + Randomizer: &mockRandom{}, + Config: &mockConfig{}, + Conn: &ssh.ServerConn{}, + InitialReq: make(chan *ssh.Request), + SshChan: sshChan, + SessionRegistry: &mockRegistry{}, + PortRegistry: &mockPort{}, + User: "testuser", + } + s := New(conf).(*session) + + mockChan := &mockNewChanFail{} + sshChan <- mockChan + + err := s.Start() + if err == nil { + t.Error("expected error, got nil") + } +} + +func TestWaitForSessionEnd_Error(t *testing.T) { + mConn := &mockSSHConn{} + mConn.On("Wait").Return(fmt.Errorf("wait error")) + mConn.On("Close").Return(nil) + + mForwarder := &mockLifecycleForwarder{} + mForwarder.On("TunnelType").Return(types.TunnelTypeTCP) + mForwarder.On("ForwardedPort").Return(uint16(80)) + mForwarder.On("Close").Return(fmt.Errorf("close error")) + + mSlug := &mockLifecycleSlug{} + mSlug.On("String").Return("slug") + + mPort := &mockPort{} + mPort.On("SetStatus", mock.Anything, mock.Anything).Return(nil) + + mRegistry := &mockRegistry{} + mRegistry.On("Remove", mock.Anything).Return() + + l := lifecycle.New(mConn, mForwarder, mSlug, mPort, mRegistry, "testuser") + s := &session{ + lifecycle: l, + } + + err := s.waitForSessionEnd() + assert.Error(t, err) +} + +type mockLifecycleForwarder struct { + mock.Mock + lifecycle.Forwarder +} + +func (m *mockLifecycleForwarder) TunnelType() types.TunnelType { + return m.Called().Get(0).(types.TunnelType) +} +func (m *mockLifecycleForwarder) ForwardedPort() uint16 { + args := m.Called() + if args.Get(0) == nil { + return 0 + } + switch v := args.Get(0).(type) { + case uint16: + return v + case uint32: + return uint16(v) + case uint64: + return uint16(v) + case uint8: + return uint16(v) + case uint: + return uint16(v) + case int: + return uint16(v) + case int8: + return uint16(v) + case int16: + return uint16(v) + case int32: + return uint16(v) + case int64: + return uint16(v) + case float32: + return uint16(v) + case float64: + return uint16(v) + default: + return uint16(args.Int(0)) + } +} +func (m *mockLifecycleForwarder) Close() error { + return m.Called().Error(0) +} + +type mockLifecycleSlug struct { + mock.Mock +} + +func (m *mockLifecycleSlug) String() string { return m.Called().String(0) } +func (m *mockLifecycleSlug) Set(slug string) { + m.Called(slug) +} + +func TestHandleMissingForwardRequest(t *testing.T) { + mConn := &mockSSHConn{} + mConfig := &mockConfig{} + mConfig.On("Domain").Return("example.com") + mConfig.On("SSHPort").Return("2222") + mConn.On("Close").Return(nil) + + conf := &Config{ + Randomizer: &mockRandom{}, + Config: mConfig, + Conn: &ssh.ServerConn{Conn: mConn}, + InitialReq: make(chan *ssh.Request), + SshChan: make(chan ssh.NewChannel), + SessionRegistry: &mockRegistry{}, + PortRegistry: &mockPort{}, + User: "testuser", + } + + s := New(conf).(*session) + + err := s.handleMissingForwardRequest() + if err == nil { + t.Error("expected error, got nil") + } +} + +func TestParseForwardPayload_Errors(t *testing.T) { + s := &session{} + + t.Run("Short Address", func(t *testing.T) { + _, _, err := s.parseForwardPayload([]byte{0, 0, 0, 4}) + if err == nil { + t.Error("expected error, got nil") + } + }) + + t.Run("Short Port", func(t *testing.T) { + payload := append([]byte{0, 0, 0, 4}, []byte("addr")...) + _, _, err := s.parseForwardPayload(payload) + if err == nil { + t.Error("expected error, got nil") + } + }) + + t.Run("Blocked Port", func(t *testing.T) { + payload := append([]byte{0, 0, 0, 4}, []byte("addr")...) + portBuf := make([]byte, 4) + binary.BigEndian.PutUint32(portBuf, 22) + payload = append(payload, portBuf...) + _, _, err := s.parseForwardPayload(payload) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "port is block") { + t.Errorf("expected error to contain %q, got %q", "port is block", err.Error()) + } + }) +} + +func TestDenyForwardingRequest_TunnelNotSetupYet(t *testing.T) { + sConn, sReqs, _, cConn, cleanup := setupSSH(t) + defer cleanup() + + mRegistry := &mockRegistry{} + mPort := &mockPort{} + mRandom := &mockRandom{} + conf := &Config{ + Randomizer: mRandom, + Config: &mockConfig{}, + Conn: sConn, + InitialReq: sReqs, + SshChan: make(chan ssh.NewChannel), + SessionRegistry: mRegistry, + PortRegistry: mPort, + User: "testuser", + } + s := New(conf).(*session) + + go func() { + _, _, _ = cConn.SendRequest("tcpip-forward", true, nil) + }() + + var req *ssh.Request + select { + case req = <-sReqs: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + key := &types.SessionKey{Id: "", Type: types.TunnelTypeUNKNOWN} + err := s.denyForwardingRequest(req, key, &mockCloser{}, "test error") + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "test error") { + t.Errorf("expected error to contain %q, got %q", "test error", err.Error()) + } + assert.Equal(t, *key, mRegistry.removedKey) +} + +func TestDenyForwardingRequest_Full(t *testing.T) { + setup := func(t *testing.T) (*session, *mockRegistry, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) { + sConn, sReqs, _, cConn, cleanup := setupSSH(t) + mRegistry := &mockRegistry{} + conf := &Config{ + Randomizer: &mockRandom{}, + Config: &mockConfig{}, + Conn: sConn, + InitialReq: sReqs, + SshChan: make(chan ssh.NewChannel), + SessionRegistry: mRegistry, + PortRegistry: &mockPort{}, + User: "testuser", + } + s := New(conf).(*session) + return s, mRegistry, sConn, sReqs, cConn, cleanup + } + + getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request { + go func() { + _, _, _ = client.SendRequest("tcpip-forward", true, nil) + }() + select { + case req, ok := <-serverReqs: + if !ok { + t.Fatal("channel closed") + } + return req + case <-time.After(2 * time.Second): + t.Fatal("timeout getting request") + return nil + } + } + + t.Run("All Success", func(t *testing.T) { + s, mRegistry, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + req := getReq(t, cConn, sReqs) + key := &types.SessionKey{Id: "test", Type: types.TunnelTypeHTTP} + + s.slug.Set("test") + s.forwarder.SetType(types.TunnelTypeHTTP) + + mCloser := &mockCloser{} + err := s.denyForwardingRequest(req, key, mCloser, "error") + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "error") { + t.Errorf("expected error to contain %q, got %q", "error", err.Error()) + } + assert.Equal(t, *key, mRegistry.removedKey) + }) + + t.Run("Listener Close error", func(t *testing.T) { + s, _, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + req := getReq(t, cConn, sReqs) + mCloser := &mockCloser{err: fmt.Errorf("close error")} + err := s.denyForwardingRequest(req, nil, mCloser, "error") + assert.Error(t, err, net.ErrClosed) + }) + + t.Run("Reply error", func(t *testing.T) { + s, _, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + req := getReq(t, cConn, sReqs) + err := cConn.Close() + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + + err = s.denyForwardingRequest(req, nil, nil, assert.AnError.Error()) + assert.Error(t, err, assert.AnError) + }) +} + +func TestHandleTCPForward_Failures(t *testing.T) { + setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) { + sConn, sReqs, _, cConn, cleanup := setupSSH(t) + mRegistry := &mockRegistry{} + mPort := &mockPort{} + conf := &Config{ + Randomizer: &mockRandom{}, + Config: &mockConfig{}, + Conn: sConn, + InitialReq: sReqs, + SshChan: make(chan ssh.NewChannel), + SessionRegistry: mRegistry, + PortRegistry: mPort, + User: "testuser", + } + s := New(conf).(*session) + return s, mRegistry, mPort, sConn, sReqs, cConn, cleanup + } + + getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request { + go func() { + _, _, _ = client.SendRequest("tcpip-forward", true, nil) + }() + select { + case req, ok := <-serverReqs: + if !ok { + t.Fatal("channel closed") + } + return req + case <-time.After(2 * time.Second): + t.Fatal("timeout getting request") + return nil + } + } + + t.Run("Port Claim fail", func(t *testing.T) { + s, _, mPort, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mPort.On("Claim", mock.Anything).Return(false) + err := s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", 1234) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "already in use") { + t.Errorf("expected error to contain %q, got %q", "already in use", err.Error()) + } + }) + + t.Run("Listen fail", func(t *testing.T) { + s, _, mPort, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mPort.On("Claim", mock.Anything).Return(true) + l, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + t.Fatal(err) + } + defer func(l net.Listener) { + err = l.Close() + assert.NoError(t, err) + }(l) + port := uint16(l.Addr().(*net.TCPAddr).Port) + + err = s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", port) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "already in use") { + t.Errorf("expected error to contain %q, got %q", "already in use", err.Error()) + } + }) + + t.Run("Registry Register fail", func(t *testing.T) { + s, mRegistry, mPort, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mPort.On("Claim", mock.Anything).Return(true) + mRegistry.On("Register", mock.Anything, mock.Anything).Return(false) + err := s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", 0) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "Failed to register") { + t.Errorf("expected error to contain %q, got %q", "Failed to register", err.Error()) + } + }) + + t.Run("Finalize fail (Reply fail)", func(t *testing.T) { + s, mRegistry, mPort, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mPort.On("Claim", mock.Anything).Return(true) + mRegistry.On("Register", mock.Anything, mock.Anything).Return(true) + req := getReq(t, cConn, sReqs) + err := cConn.Close() + assert.NoError(t, err) + time.Sleep(100 * time.Millisecond) + + err = s.HandleTCPForward(req, "localhost", 0) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "Failed to finalize forwarding") { + t.Errorf("expected error to contain %q, got %q", "Failed to finalize forwarding", err.Error()) + } + }) +} + +func TestHandleHTTPForward_Failures(t *testing.T) { + setup := func(t *testing.T) (*session, *mockRegistry, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) { + sConn, sReqs, _, cConn, cleanup := setupSSH(t) + mRegistry := &mockRegistry{} + mRandom := &mockRandom{} + s := New(&Config{ + Randomizer: mRandom, + Config: &mockConfig{}, + Conn: sConn, + InitialReq: sReqs, + SshChan: make(chan ssh.NewChannel), + SessionRegistry: mRegistry, + PortRegistry: &mockPort{}, + User: "testuser", + }).(*session) + return s, mRegistry, mRandom, sConn, sReqs, cConn, cleanup + } + + getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request { + go func() { _, _, _ = client.SendRequest("tcpip-forward", true, nil) }() + return <-serverReqs + } + + t.Run("Random fail", func(t *testing.T) { + s, _, mRandom, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mRandom.On("String", 20).Return("", fmt.Errorf("random error")) + err := s.HandleHTTPForward(getReq(t, cConn, sReqs), 80) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "Failed to create slug") { + t.Errorf("expected error to contain %q, got %q", "Failed to create slug", err.Error()) + } + }) + + t.Run("Register fail", func(t *testing.T) { + s, mRegistry, mRandom, _, sReqs, cConn, cleanup := setup(t) + defer cleanup() + mRandom.On("String", 20).Return("slug", nil) + mRegistry.On("Register", mock.Anything, mock.Anything).Return(false) + err := s.HandleHTTPForward(getReq(t, cConn, sReqs), 80) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "Failed to register") { + t.Errorf("expected error to contain %q, got %q", "Failed to register", err.Error()) + } + }) +} + +func TestHandleGlobalRequest_Failures(t *testing.T) { + _, sReqs, _, cConn, cleanup := setupSSH(t) + defer cleanup() + + conf := &Config{ + Randomizer: &mockRandom{}, + Config: &mockConfig{}, + Conn: &ssh.ServerConn{}, + InitialReq: make(chan *ssh.Request), + SshChan: make(chan ssh.NewChannel), + SessionRegistry: &mockRegistry{}, + PortRegistry: &mockPort{}, + User: "testuser", + } + s := New(conf).(*session) + + done := make(chan struct{}) + go func() { + _ = s.HandleGlobalRequest(sReqs) + close(done) + }() + + tests := []struct { + name string + reqType string + payload []byte + wantReply bool + expected bool + }{ + {"shell", "shell", nil, true, true}, + {"pty-req", "pty-req", nil, true, true}, + {"window-change valid", "window-change", make([]byte, 16), true, true}, + {"window-change invalid", "window-change", make([]byte, 4), true, false}, + {"unknown", "unknown", nil, true, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ok, _, err := cConn.SendRequest(tt.reqType, tt.wantReply, tt.payload) + assert.NoError(t, err) + assert.Equal(t, tt.expected, ok) + }) + } + + err := cConn.Close() + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("HandleGlobalRequest timed out after cConn.Close()") + } +} + +func TestSetupInteractiveMode_GlobalRequestError(t *testing.T) { + sConn, _, sChans, _, cleanup := setupSSH(t) + defer cleanup() + + conf := &Config{ + Randomizer: &mockRandom{}, + Config: &mockConfig{}, + Conn: sConn, + InitialReq: make(chan *ssh.Request), + SshChan: sChans, + SessionRegistry: &mockRegistry{}, + PortRegistry: &mockPort{}, + User: "testuser", + } + s := New(conf).(*session) + + mockChan := &mockNewChanFail{} + err := s.setupInteractiveMode(mockChan) + if err == nil { + t.Error("expected error, got nil") + } +} + +type mockCloser struct { + err error +} + +func (m *mockCloser) Close() error { return m.err } diff --git a/session/slug/slug_test.go b/session/slug/slug_test.go new file mode 100644 index 0000000..c7af138 --- /dev/null +++ b/session/slug/slug_test.go @@ -0,0 +1,99 @@ +package slug + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type SlugTestSuite struct { + suite.Suite + slug Slug +} + +func (suite *SlugTestSuite) SetupTest() { + suite.slug = New() +} + +func TestNew(t *testing.T) { + s := New() + + assert.NotNil(t, s, "New() should return a non-nil Slug") + assert.Implements(t, (*Slug)(nil), s, "New() should return a type that implements Slug interface") + assert.Equal(t, "", s.String(), "New() should initialize with empty string") +} + +func (suite *SlugTestSuite) TestString() { + assert.Equal(suite.T(), "", suite.slug.String(), "String() should return empty string initially") + + suite.slug.Set("test-slug") + assert.Equal(suite.T(), "test-slug", suite.slug.String(), "String() should return the set value") +} + +func (suite *SlugTestSuite) TestSet() { + testCases := []struct { + name string + input string + expected string + }{ + { + name: "simple slug", + input: "hello-world", + expected: "hello-world", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "slug with numbers", + input: "test-123", + expected: "test-123", + }, + { + name: "slug with special characters", + input: "hello_world-123", + expected: "hello_world-123", + }, + { + name: "overwrite existing slug", + input: "new-slug", + expected: "new-slug", + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + suite.slug.Set(tc.input) + assert.Equal(suite.T(), tc.expected, suite.slug.String()) + }) + } +} + +func (suite *SlugTestSuite) TestMultipleSet() { + suite.slug.Set("first-slug") + assert.Equal(suite.T(), "first-slug", suite.slug.String()) + + suite.slug.Set("second-slug") + assert.Equal(suite.T(), "second-slug", suite.slug.String()) + + suite.slug.Set("") + assert.Equal(suite.T(), "", suite.slug.String()) +} + +func TestSlugIsolation(t *testing.T) { + slug1 := New() + slug2 := New() + + slug1.Set("slug-one") + slug2.Set("slug-two") + + assert.Equal(t, "slug-one", slug1.String(), "First slug should maintain its value") + assert.Equal(t, "slug-two", slug2.String(), "Second slug should maintain its value") +} + +func TestSlugTestSuite(t *testing.T) { + suite.Run(t, new(SlugTestSuite)) +} diff --git a/sonar-project.properties b/sonar-project.properties deleted file mode 100644 index 277a293..0000000 --- a/sonar-project.properties +++ /dev/null @@ -1 +0,0 @@ -sonar.projectKey=tunnel-please \ No newline at end of file diff --git a/types/types.go b/types/types.go index 34ccfb4..77d6ac4 100644 --- a/types/types.go +++ b/types/types.go @@ -7,6 +7,7 @@ type SessionStatus int const ( SessionStatusINITIALIZING SessionStatus = iota SessionStatusRUNNING + SessionStatusCLOSED ) type InteractiveMode int