1 Commits

Author SHA1 Message Date
78b7b894d9 chore(deps): update actions/checkout action to v6
SonarQube Scan / SonarQube Trigger (push) Successful in 49s
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 48s
2026-01-22 16:01:02 +00:00
60 changed files with 1247 additions and 15052 deletions
+77 -37
View File
@@ -2,38 +2,24 @@ name: Docker Build and Push
on: on:
push: push:
branches:
- main
- staging
tags: tags:
- 'v*' - 'v*'
paths:
- '**.go'
- 'go.mod'
- 'go.sum'
- 'Dockerfile'
- 'Dockerfile.*'
- '.dockerignore'
- '.gitea/workflows/build.yml'
jobs: jobs:
test: build-and-push-branches:
name: Run Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
if: github.ref_type == 'branch'
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: 'stable'
cache: false
- name: Install dependencies
run: go mod download
- name: Run go vet
run: go vet ./...
- name: Run tests
run: go test -v -p 4 ./...
build-and-push:
name: Build and Push Docker Image
runs-on: ubuntu-latest
needs: test
steps: steps:
- name: Checkout repository - name: Checkout repository
@@ -42,7 +28,64 @@ jobs:
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: Log in to Docker Registry - name: Log in to Docker Hub
uses: docker/login-action@v3
with:
registry: git.fossy.my.id
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Set version variables
id: vars
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "VERSION=dev-main" >> $GITHUB_OUTPUT
else
echo "VERSION=dev-staging" >> $GITHUB_OUTPUT
fi
echo "BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT
echo "COMMIT=${{ github.sha }}" >> $GITHUB_OUTPUT
- name: Build and push Docker image for main
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: |
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:latest
platforms: linux/amd64,linux/arm64
build-args: |
VERSION=${{ steps.vars.outputs.VERSION }}
BUILD_DATE=${{ steps.vars.outputs.BUILD_DATE }}
COMMIT=${{ steps.vars.outputs.COMMIT }}
if: github.ref == 'refs/heads/main'
- name: Build and push Docker image for staging
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: |
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:staging
platforms: linux/amd64,linux/arm64
build-args: |
VERSION=${{ steps.vars.outputs.VERSION }}
BUILD_DATE=${{ steps.vars.outputs.BUILD_DATE }}
COMMIT=${{ steps.vars.outputs.COMMIT }}
if: github.ref == 'refs/heads/staging'
build-and-push-tags:
runs-on: ubuntu-latest
if: github.ref_type == 'tag' && startsWith(github.ref, 'refs/tags/v')
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v3 uses: docker/login-action@v3
with: with:
registry: git.fossy.my.id registry: git.fossy.my.id
@@ -60,35 +103,32 @@ jobs:
if echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.]+)?$'; then if echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.]+)?$'; then
MAJOR=$(echo "$VERSION" | cut -d. -f1) MAJOR=$(echo "$VERSION" | cut -d. -f1)
MINOR=$(echo "$VERSION" | cut -d. -f2) MINOR=$(echo "$VERSION" | cut -d. -f2)
PATCH=$(echo "$VERSION" | cut -d. -f3 | cut -d- -f1)
echo "MAJOR=$MAJOR" >> $GITHUB_OUTPUT echo "MAJOR=$MAJOR" >> $GITHUB_OUTPUT
echo "MINOR=$MINOR" >> $GITHUB_OUTPUT echo "MINOR=$MINOR" >> $GITHUB_OUTPUT
echo "PATCH=$PATCH" >> $GITHUB_OUTPUT
if echo "$VERSION" | grep -q '-'; then if echo "$VERSION" | grep -q '-'; then
PRERELEASE_TAG=$(echo "$VERSION" | cut -d- -f2 | cut -d. -f1)
echo "IS_PRERELEASE=true" >> $GITHUB_OUTPUT echo "IS_PRERELEASE=true" >> $GITHUB_OUTPUT
echo "PRERELEASE_TAG=$PRERELEASE_TAG" >> $GITHUB_OUTPUT echo "ADDITIONAL_TAG=staging" >> $GITHUB_OUTPUT
else else
echo "IS_PRERELEASE=false" >> $GITHUB_OUTPUT echo "IS_PRERELEASE=false" >> $GITHUB_OUTPUT
echo "ADDITIONAL_TAG=latest" >> $GITHUB_OUTPUT
fi fi
else else
echo "Invalid version format: $VERSION" echo "Invalid version format: $VERSION"
exit 1 exit 1
fi fi
- name: Build and push Docker image (release) - name: Build and push Docker image for release
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:
context: . context: .
push: true push: true
tags: | tags: |
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.VERSION }} git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.VERSION }}
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:release
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.MAJOR }}.${{ steps.version.outputs.MINOR }} git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.MAJOR }}.${{ steps.version.outputs.MINOR }}
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.MAJOR }} git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.MAJOR }}
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:latest git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:${{ steps.version.outputs.ADDITIONAL_TAG }}
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
build-args: | build-args: |
VERSION=${{ steps.version.outputs.VERSION }} VERSION=${{ steps.version.outputs.VERSION }}
@@ -96,14 +136,14 @@ jobs:
COMMIT=${{ steps.version.outputs.COMMIT }} COMMIT=${{ steps.version.outputs.COMMIT }}
if: steps.version.outputs.IS_PRERELEASE == 'false' if: steps.version.outputs.IS_PRERELEASE == 'false'
- name: Build and push Docker image (pre-release) - name: Build and push Docker image for pre-release
uses: docker/build-push-action@v6 uses: docker/build-push-action@v6
with: with:
context: . context: .
push: true push: true
tags: | tags: |
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.VERSION }} git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.VERSION }}
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:staging git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:${{ steps.version.outputs.ADDITIONAL_TAG }}
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
build-args: | build-args: |
VERSION=${{ steps.version.outputs.VERSION }} VERSION=${{ steps.version.outputs.VERSION }}
+2 -42
View File
@@ -1,9 +1,7 @@
on: on:
push: push:
branches: pull_request:
- main types: [opened, synchronize, reopened]
- staging
- 'feat/**'
name: SonarQube Scan name: SonarQube Scan
jobs: jobs:
@@ -15,46 +13,8 @@ jobs:
uses: actions/checkout@v6 uses: actions/checkout@v6
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: 'stable'
cache: false
- name: Install dependencies
run: go mod tidy
- name: Run go vet
run: go vet ./... 2>&1 | tee vet-results.txt
- name: Run tests with coverage
run: |
go test ./... -v -p 4 -coverprofile=coverage
- name: Run GolangCI-Lint Analysis
uses: golangci/golangci-lint-action@v9
with:
skip-cache: true
version: v2.6
args: >
--issues-exit-code=0
--output.text.path=stdout
--output.checkstyle.path=golangci-lint-report.xml
- name: SonarQube Scan - name: SonarQube Scan
uses: SonarSource/sonarqube-scan-action@v7.0.0 uses: SonarSource/sonarqube-scan-action@v7.0.0
env: env:
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }} SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }} SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
with:
args: >
-Dsonar.projectKey=tunnel-please
-Dsonar.go.coverage.reportPaths=coverage
-Dsonar.test.inclusions=**/*_test.go
-Dsonar.test.exclusions=**/vendor/**
-Dsonar.exclusions=**/*_test.go,**/vendor/**,**/golangci-lint-report.xml
-Dsonar.go.govet.reportPaths=vet-results.txt
-Dsonar.go.golangci-lint.reportPaths=golangci-lint-report.xml
-Dsonar.sources=./
-Dsonar.tests=./
-36
View File
@@ -1,36 +0,0 @@
name: Tests
on:
pull_request:
types: [opened, synchronize, reopened]
issue_comment:
types: [created]
jobs:
test:
name: Run Tests
runs-on: ubuntu-latest
if: |
github.event_name == 'pull_request' ||
(github.event_name == 'issue_comment' &&
github.event.issue.pull_request != null &&
contains(github.event.comment.body, '/retest'))
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: 'stable'
cache: false
- name: Install dependencies
run: go mod download
- name: Run go vet
run: go vet ./...
- name: Run tests
run: go test -v -p 4 ./...
Vendored
-2
View File
@@ -5,5 +5,3 @@ id_rsa*
tmp tmp
certs certs
app app
coverage
test-results.json
+2 -5
View File
@@ -1,4 +1,4 @@
FROM golang:1.25.7-alpine AS go_builder FROM golang:1.25.6-alpine AS go_builder
ARG VERSION=dev ARG VERSION=dev
ARG BUILD_DATE=unknown ARG BUILD_DATE=unknown
@@ -22,10 +22,7 @@ RUN --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
CGO_ENABLED=0 GOOS=linux \ CGO_ENABLED=0 GOOS=linux \
go build -trimpath \ go build -trimpath \
-ldflags="-w -s \ -ldflags="-w -s -X tunnel_pls/version.Version=${VERSION} -X tunnel_pls/version.BuildDate=${BUILD_DATE} -X tunnel_pls/version.Commit=${COMMIT}" \
-X tunnel_pls/internal/version.Version=${VERSION} \
-X tunnel_pls/internal/version.BuildDate=${BUILD_DATE} \
-X tunnel_pls/internal/version.Commit=${COMMIT}" \
-o /app/tunnel_pls \ -o /app/tunnel_pls \
. .
+116 -40
View File
@@ -1,22 +1,6 @@
<div align="center">
<img alt="gopher" title="gopher" src="./docs/images/gopher.png" width="325" />
# Tunnel Please # Tunnel Please
A lightweight SSH-based tunnel server A lightweight SSH-based tunnel server written in Go that enables secure TCP and HTTP forwarding with an interactive terminal interface for managing connections and custom subdomains.
<br/><br/>
[![Coverage](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=coverage&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Lines of Code](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=ncloc&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Quality Gate Status](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=alert_status&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Security Issues](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_security_issues&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Maintainability Rating](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_maintainability_rating&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Reliability Rating](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_reliability_rating&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Security Rating](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_security_rating&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
</div>
## Features ## Features
@@ -33,32 +17,108 @@ A lightweight SSH-based tunnel server
The following environment variables can be configured in the `.env` file: The following environment variables can be configured in the `.env` file:
| Variable | Description | Default | Required | | Variable | Description | Default | Required |
|---------------------|-----------------------------------------------------------------------------|-------------------------|---------------------| |----------|-------------|---------|----------|
| `DOMAIN` | Domain name for subdomain routing | `localhost` | No | | `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
| `PORT` | SSH server port | `2200` | No | | `PORT` | SSH server port | `2200` | No |
| `HTTP_PORT` | HTTP server port | `8080` | No | | `HTTP_PORT` | HTTP server port | `8080` | No |
| `HTTPS_PORT` | HTTPS server port | `8443` | No | | `HTTPS_PORT` | HTTPS server port | `8443` | No |
| `KEY_LOC` | Path to the private key file | `certs/privkey.pem` | No | | `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No | | `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No | | `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@<DOMAIN>` | No |
| `TLS_STORAGE_PATH` | Path to store TLS certificates | `certs/tls/` | No | | `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | - | Yes (if auto-cert) |
| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@<DOMAIN>` | No | | `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | `-` | Yes (if auto-cert) | | `CORS_LIST` | Comma-separated list of allowed CORS origins | - | No |
| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No | | `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
| `CORS_LIST` | Comma-separated list of allowed CORS origins | `-` | No | | `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No | | `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No | | `PPROF_PORT` | Port for pprof server | `6060` | No |
| `MAX_HEADER_SIZE` | Maximum size of HTTP headers in bytes (4096-131072) | `4096` | No | | `MODE` | Runtime mode: `standalone` (default, no gRPC/auth) or `node` (enable gRPC + auth) | `standalone` | No |
| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No | | `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No |
| `PPROF_PORT` | Port for pprof server | `6060` | No | | `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No |
| `MODE` | Runtime mode: `standalone` or `node` | `standalone` | No | | `NODE_TOKEN` | Authentication token sent to controller in `node` mode | - (required in `node`) | Yes (node mode) |
| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No |
| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No |
| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | `-` | Yes (node mode) |
**Note:** All environment variables now use UPPERCASE naming. The application includes sensible defaults for all variables, so you can run it without a `.env` file for basic functionality. **Note:** All environment variables now use UPPERCASE naming. The application includes sensible defaults for all variables, so you can run it without a `.env` file for basic functionality.
### Automatic TLS Certificate Management
The server supports automatic TLS certificate generation and renewal using [CertMagic](https://github.com/caddyserver/certmagic) with Cloudflare DNS-01 challenge. This is required for wildcard certificate support (`*.yourdomain.com`).
**Certificate Storage:**
- TLS certificates are stored in `certs/tls/` (relative to application directory)
- User-provided certificates: `certs/tls/cert.pem` and `certs/tls/privkey.pem`
- CertMagic automatic certificates: `certs/tls/certmagic/`
- SSH keys are stored separately in `certs/ssh/`
**How it works:**
1. If user-provided certificates exist at `certs/tls/cert.pem` and `certs/tls/privkey.pem` and cover both `DOMAIN` and `*.DOMAIN`, they will be used
2. If certificates are missing, expired, expiring within 30 days, or don't cover the required domains, CertMagic will automatically obtain new certificates from Let's Encrypt
3. Certificates are automatically renewed before expiration
4. User-provided certificates support hot-reload (changes detected every 30 seconds)
**Cloudflare API Token Setup:**
To use automatic certificate generation, you need a Cloudflare API token with the following permissions:
1. Go to [Cloudflare Dashboard](https://dash.cloudflare.com/profile/api-tokens)
2. Click "Create Token"
3. Use "Create Custom Token" with these permissions:
- **Zone → Zone → Read** (for all zones or specific zone)
- **Zone → DNS → Edit** (for all zones or specific zone)
4. Copy the token and set it as `CF_API_TOKEN` environment variable
**Example configuration for automatic certificates:**
```env
DOMAIN=example.com
TLS_ENABLED=true
CF_API_TOKEN=your_cloudflare_api_token_here
ACME_EMAIL=admin@example.com
# ACME_STAGING=true # Uncomment for testing to avoid rate limits
```
### SSH Key Auto-Generation
The application will automatically generate a new 4096-bit RSA key pair at `certs/ssh/id_rsa` if it doesn't exist. This makes it easier to get started without manually creating SSH keys. SSH keys are stored separately from TLS certificates.
### Memory Optimization
The application uses a buffer pool with controlled buffer sizes to prevent excessive memory usage under high concurrent loads. The `BUFFER_SIZE` environment variable controls the size of buffers used for io.Copy operations:
- **Default:** 32768 bytes (32 KB) - Good balance for most scenarios
- **Minimum:** 4096 bytes (4 KB) - Lower memory usage, more CPU overhead
- **Maximum:** 1048576 bytes (1 MB) - Higher throughput, more memory usage
**Recommended settings based on load:**
- **Low traffic (<100 concurrent):** `BUFFER_SIZE=32768` (default)
- **High traffic (>100 concurrent):** `BUFFER_SIZE=16384` or `BUFFER_SIZE=8192`
- **Very high traffic (>1000 concurrent):** `BUFFER_SIZE=8192` or `BUFFER_SIZE=4096`
The buffer pool reuses buffers across connections, preventing memory fragmentation and reducing garbage collection pressure.
### Profiling with pprof
To enable profiling for performance analysis:
1. Set `PPROF_ENABLED=true` in your `.env` file
2. Optionally set `PPROF_PORT` to your desired port (default: 6060)
3. Access profiling data at `http://localhost:6060/debug/pprof/`
Common pprof endpoints:
- `/debug/pprof/` - Index page with available profiles
- `/debug/pprof/heap` - Memory allocation profile
- `/debug/pprof/goroutine` - Stack traces of all current goroutines
- `/debug/pprof/profile` - CPU profile (30-second sample by default)
- `/debug/pprof/trace` - Execution trace
Example usage with `go tool pprof`:
```bash
# Analyze CPU profile
go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
# Analyze memory heap
go tool pprof http://localhost:6060/debug/pprof/heap
```
## Docker Deployment ## Docker Deployment
Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`. Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`.
@@ -137,6 +197,22 @@ docker-compose -f docker-compose.tcp.yml up -d
docker-compose -f docker-compose.root.yml down docker-compose -f docker-compose.root.yml down
``` ```
### Volume Management
All configurations use a named volume `certs` for persistent storage:
- SSH keys: `/app/certs/ssh/`
- TLS certificates: `/app/certs/tls/`
To backup certificates:
```bash
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar czf /backup/certs-backup.tar.gz -C /data .
```
To restore certificates:
```bash
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar xzf /backup/certs-backup.tar.gz -C /data
```
### Recommendation ### Recommendation
**Use `docker-compose.root.yml`** for production deployments if you need: **Use `docker-compose.root.yml`** for production deployments if you need:
Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 MiB

+5 -10
View File
@@ -5,13 +5,12 @@ go 1.25.5
require ( require (
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0
github.com/caddyserver/certmagic v0.25.1 github.com/caddyserver/certmagic v0.25.1
github.com/charmbracelet/bubbles v0.21.1 github.com/charmbracelet/bubbles v0.21.0
github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0 github.com/charmbracelet/lipgloss v1.1.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/libdns/cloudflare v0.2.2 github.com/libdns/cloudflare v0.2.2
github.com/muesli/termenv v0.16.0 github.com/muesli/termenv v0.16.0
github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.47.0 golang.org/x/crypto v0.47.0
google.golang.org/grpc v1.78.0 google.golang.org/grpc v1.78.0
google.golang.org/protobuf v1.36.11 google.golang.org/protobuf v1.36.11
@@ -22,13 +21,12 @@ require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/caddyserver/zerossl v0.1.4 // indirect github.com/caddyserver/zerossl v0.1.4 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.5 // indirect github.com/charmbracelet/x/ansi v0.11.3 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect github.com/charmbracelet/x/cellbuf v0.0.14 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect github.com/clipperhouse/displaywidth v0.6.2 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/libdns/libdns v1.1.1 // indirect github.com/libdns/libdns v1.1.1 // indirect
@@ -40,10 +38,8 @@ require (
github.com/miekg/dns v1.1.69 // indirect github.com/miekg/dns v1.1.69 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/cancelreader v0.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect github.com/rivo/uniseg v0.4.7 // indirect
github.com/sahilm/fuzzy v0.1.1 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/zeebo/blake3 v0.2.4 // indirect github.com/zeebo/blake3 v0.2.4 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
@@ -56,5 +52,4 @@ require (
golang.org/x/text v0.33.0 // indirect golang.org/x/text v0.33.0 // indirect
golang.org/x/tools v0.40.0 // indirect golang.org/x/tools v0.40.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )
-24
View File
@@ -12,8 +12,6 @@ github.com/caddyserver/zerossl v0.1.4 h1:CVJOE3MZeFisCERZjkxIcsqIH4fnFdlYWnPYeFt
github.com/caddyserver/zerossl v0.1.4/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/caddyserver/zerossl v0.1.4/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
github.com/charmbracelet/bubbles v0.21.1 h1:nj0decPiixaZeL9diI4uzzQTkkz1kYY8+jgzCZXSmW0=
github.com/charmbracelet/bubbles v0.21.1/go.mod h1:HHvIYRCpbkCJw2yo0vNX1O5loCwSr9/mWS8GYSg50Sk=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk= github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
@@ -22,27 +20,18 @@ github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoF
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.3 h1:6DcVaqWI82BBVM/atTyq6yBoRLZFBsnoDoX9GCu2YOI= github.com/charmbracelet/x/ansi v0.11.3 h1:6DcVaqWI82BBVM/atTyq6yBoRLZFBsnoDoX9GCu2YOI=
github.com/charmbracelet/x/ansi v0.11.3/go.mod h1:yI7Zslym9tCJcedxz5+WBq+eUGMJT0bM06Fqy1/Y4dI= github.com/charmbracelet/x/ansi v0.11.3/go.mod h1:yI7Zslym9tCJcedxz5+WBq+eUGMJT0bM06Fqy1/Y4dI=
github.com/charmbracelet/x/ansi v0.11.5 h1:NBWeBpj/lJPE3Q5l+Lusa4+mH6v7487OP8K0r1IhRg4=
github.com/charmbracelet/x/ansi v0.11.5/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.14 h1:iUEMryGyFTelKW3THW4+FfPgi4fkmKnnaLOXuc+/Kj4= github.com/charmbracelet/x/cellbuf v0.0.14 h1:iUEMryGyFTelKW3THW4+FfPgi4fkmKnnaLOXuc+/Kj4=
github.com/charmbracelet/x/cellbuf v0.0.14/go.mod h1:P447lJl49ywBbil/KjCk2HexGh4tEY9LH0/1QrZZ9rA= github.com/charmbracelet/x/cellbuf v0.0.14/go.mod h1:P447lJl49ywBbil/KjCk2HexGh4tEY9LH0/1QrZZ9rA=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ= github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ=
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U= github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.6.2 h1:ZDpTkFfpHOKte4RG5O/BOyf3ysnvFswpyYrV7z2uAKo= github.com/clipperhouse/displaywidth v0.6.2 h1:ZDpTkFfpHOKte4RG5O/BOyf3ysnvFswpyYrV7z2uAKo=
github.com/clipperhouse/displaywidth v0.6.2/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= github.com/clipperhouse/displaywidth v0.6.2/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
@@ -91,18 +80,8 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
@@ -159,8 +138,5 @@ google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
-196
View File
@@ -1,196 +0,0 @@
package bootstrap
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/key"
"tunnel_pls/internal/port"
"tunnel_pls/internal/random"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/internal/version"
"tunnel_pls/server"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type Bootstrap struct {
Randomizer random.Random
Config config.Config
SessionRegistry registry.Registry
Port port.Port
GrpcClient client.Client
ErrChan chan error
SignalChan chan os.Signal
}
func New(config config.Config, port port.Port) (*Bootstrap, error) {
randomizer := random.New()
sessionRegistry := registry.NewRegistry()
if err := port.AddRange(config.AllowedPortsStart(), config.AllowedPortsEnd()); err != nil {
return nil, err
}
grpcClient, err := client.New(config, sessionRegistry)
if err != nil {
return nil, err
}
errChan := make(chan error, 5)
signalChan := make(chan os.Signal, 1)
return &Bootstrap{
Randomizer: randomizer,
Config: config,
SessionRegistry: sessionRegistry,
Port: port,
GrpcClient: grpcClient,
ErrChan: errChan,
SignalChan: signalChan,
}, nil
}
func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) {
sshCfg := &ssh.ServerConfig{
NoClientAuth: true,
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
}
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
return nil, fmt.Errorf("generate ssh key: %w", err)
}
privateBytes, err := os.ReadFile(sshKeyPath)
if err != nil {
return nil, fmt.Errorf("read private key: %w", err)
}
private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
sshCfg.AddHostKey(private)
return sshCfg, nil
}
func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, errChan chan<- error) error {
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
defer healthCancel()
if err := b.GrpcClient.CheckServerHealth(healthCtx); err != nil {
return fmt.Errorf("gRPC health check failed: %w", err)
}
go func() {
if err := b.GrpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
}
}()
return nil
}
func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
httpserver := transport.NewHTTPServer(conf, registry)
ln, err := httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
if err = httpserver.Serve(ln); err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
}
}
func startHTTPSServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
tlsCfg, err := transport.NewTLSConfig(conf)
if err != nil {
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
return
}
httpsServer := transport.NewHTTPSServer(conf, registry, tlsCfg)
ln, err := httpsServer.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
return
}
if err = httpsServer.Serve(ln); err != nil {
errChan <- fmt.Errorf("error when serving https server: %w", err)
}
}
func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, errChan chan<- error) {
sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, conf.SSHPort())
if err != nil {
errChan <- err
return
}
sshServer.Start()
errChan <- sshServer.Close()
}
func startPprof(pprofPort string, errChan chan<- error) {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
errChan <- fmt.Errorf("pprof server error: %v", err)
}
}
func (b *Bootstrap) Run() error {
sshConfig, err := newSSHConfig(b.Config.KeyLoc())
if err != nil {
return fmt.Errorf("failed to create SSH config: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
signal.Notify(b.SignalChan, os.Interrupt, syscall.SIGTERM)
if b.Config.Mode() == types.ServerModeNODE {
err = b.startGRPCClient(ctx, b.Config, b.ErrChan)
if err != nil {
return fmt.Errorf("failed to start gRPC client: %w", err)
}
defer func(grpcClient client.Client) {
err = grpcClient.Close()
if err != nil {
log.Printf("failed to close gRPC client")
}
}(b.GrpcClient)
}
go startHTTPServer(b.Config, b.SessionRegistry, b.ErrChan)
if b.Config.TLSEnabled() {
go startHTTPSServer(b.Config, b.SessionRegistry, b.ErrChan)
}
go func() {
startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, b.GrpcClient, b.Port, b.ErrChan)
}()
if b.Config.PprofEnabled() {
go startPprof(b.Config.PprofPort(), b.ErrChan)
}
log.Println("All services started successfully")
select {
case err = <-b.ErrChan:
return fmt.Errorf("service error: %w", err)
case sig := <-b.SignalChan:
log.Printf("Received signal %s, initiating graceful shutdown", sig)
cancel()
return nil
}
}
-558
View File
@@ -1,558 +0,0 @@
package bootstrap
import (
"context"
"fmt"
"net"
"net/http"
_ "net/http/pprof"
"os"
"path/filepath"
"strconv"
"testing"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
)
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
args := m.Called(user, oldKey, newKey)
return args.Error(0)
}
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
args := m.Called(key, session)
return args.Bool(0)
}
func (m *MockSessionRegistry) Remove(key registry.Key) {
m.Called(key)
}
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
args := m.Called(user)
return args.Get(0).([]registry.Session)
}
func (m *MockSessionRegistry) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
type MockRandom struct {
mock.Mock
}
func (m *MockRandom) String(length int) (string, error) {
args := m.Called(length)
return args.String(0), args.Error(1)
}
type MockConfig struct {
mock.Mock
}
func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode {
args := m.Called()
if args.Get(0) == nil {
return 0
}
switch v := args.Get(0).(type) {
case types.ServerMode:
return v
case int:
return types.ServerMode(v)
default:
return types.ServerMode(args.Int(0))
}
}
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type MockPort struct {
mock.Mock
}
func (m *MockPort) AddRange(startPort, endPort uint16) error {
return m.Called(startPort, endPort).Error(0)
}
func (m *MockPort) Unassigned() (uint16, bool) {
args := m.Called()
var mPort uint16
if args.Get(0) != nil {
switch v := args.Get(0).(type) {
case int:
mPort = uint16(v)
case uint16:
mPort = v
case uint32:
mPort = uint16(v)
case int32:
mPort = uint16(v)
case float64:
mPort = uint16(v)
default:
mPort = uint16(args.Int(0))
}
}
return mPort, args.Bool(1)
}
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
}
func (m *MockPort) Claim(port uint16) bool {
return m.Called(port).Bool(0)
}
type MockGRPCClient struct {
mock.Mock
}
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
args := m.Called()
return args.Get(0).(*grpc.ClientConn)
}
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
args := m.Called(ctx, token)
return args.Bool(0), args.String(1), args.Error(2)
}
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
args := m.Called(ctx, domain, token)
return args.Error(0)
}
func (m *MockGRPCClient) Close() error {
args := m.Called()
return args.Error(0)
}
func TestNew(t *testing.T) {
tests := []struct {
name string
setupConfig func() config.Config
setupPort func() port.Port
wantErr bool
errContains string
}{
{
name: "Success New with default value",
wantErr: false,
},
{
name: "Error when AddRange fails",
setupPort: func() port.Port {
mockPort := &MockPort{}
mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range"))
return mockPort
},
wantErr: true,
errContains: "invalid port range",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var mockPort port.Port
if tt.setupPort != nil {
mockPort = tt.setupPort()
} else {
mockPort = port.New()
}
var mockConfig config.Config
if tt.setupConfig != nil {
mockConfig = tt.setupConfig()
} else {
var err error
mockConfig, err = config.MustLoad()
assert.NoError(t, err)
}
bootstrap, err := New(mockConfig, mockPort)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, bootstrap)
} else {
assert.NoError(t, err)
assert.NotNil(t, bootstrap)
assert.NotNil(t, bootstrap.Randomizer)
assert.NotNil(t, bootstrap.SessionRegistry)
assert.NotNil(t, bootstrap.Config)
assert.NotNil(t, bootstrap.Port)
assert.NotNil(t, bootstrap.ErrChan)
assert.NotNil(t, bootstrap.SignalChan)
}
})
}
}
func randomAvailablePort() (string, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return "", err
}
defer func(listener net.Listener) {
_ = listener.Close()
}(listener)
mPort := listener.Addr().(*net.TCPAddr).Port
return strconv.Itoa(mPort), nil
}
func TestRun(t *testing.T) {
mockRandom := &MockRandom{}
mockErrChan := make(chan error, 1)
mockSignalChan := make(chan os.Signal, 1)
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
tmpDir := t.TempDir()
keyLoc := filepath.Join(tmpDir, "key.key")
tests := []struct {
name string
setupConfig func() *MockConfig
setupGrpcClient func() *MockGRPCClient
needCerts bool
expectError bool
}{
{
name: "successful run and termination",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: false,
},
{
name: "error from SSH server invalid port",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("invalid")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: true,
},
{
name: "error from HTTP server invalid port",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("invalid")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: true,
},
{
name: "error from HTTPS server invalid port",
setupConfig: func() *MockConfig {
tempDir := os.TempDir()
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("invalid")
mockConfig.On("TLSEnabled").Return(true)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("TLSStoragePath").Return(tempDir)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: true,
},
{
name: "grpc health check failed",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("invalid")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed"))
return mockGRPCClient
},
expectError: true,
},
{
name: "successful run with pprof enabled",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
pprofPort, _ := randomAvailablePort()
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(true)
mockConfig.On("PprofPort").Return(pprofPort)
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: false,
}, {
name: "successful run in NODE mode with signal",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockGRPCClient.On("Close").Return(nil)
return mockGRPCClient
},
expectError: false,
}, {
name: "successful run in NODE mode with signal buf error when closing",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy"))
return mockGRPCClient
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockConfig := tt.setupConfig()
mockGRPCClient := &MockGRPCClient{}
bootstrap := &Bootstrap{
Randomizer: mockRandom,
Config: mockConfig,
SessionRegistry: mockSessionRegistry,
Port: mockPort,
ErrChan: mockErrChan,
SignalChan: mockSignalChan,
GrpcClient: mockGRPCClient,
}
if tt.setupGrpcClient != nil {
bootstrap.GrpcClient = tt.setupGrpcClient()
}
done := make(chan error, 1)
go func() {
done <- bootstrap.Run()
}()
if tt.expectError {
err := <-done
assert.Error(t, err)
} else if tt.name == "successful run with pprof enabled" {
time.Sleep(200 * time.Millisecond)
fmt.Println(mockConfig.PprofPort())
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
err = resp.Body.Close()
assert.NoError(t, err)
mockSignalChan <- os.Interrupt
err = <-done
assert.NoError(t, err)
} else {
time.Sleep(time.Second)
mockSignalChan <- os.Interrupt
err := <-done
assert.NoError(t, err)
}
})
}
}
-7
View File
@@ -9,11 +9,8 @@ type Config interface {
HTTPPort() string HTTPPort() string
HTTPSPort() string HTTPSPort() string
KeyLoc() string
TLSEnabled() bool TLSEnabled() bool
TLSRedirect() bool TLSRedirect() bool
TLSStoragePath() string
ACMEEmail() string ACMEEmail() string
CFAPIToken() string CFAPIToken() string
@@ -23,7 +20,6 @@ type Config interface {
AllowedPortsEnd() uint16 AllowedPortsEnd() uint16
BufferSize() int BufferSize() int
HeaderSize() int
PprofEnabled() bool PprofEnabled() bool
PprofPort() string PprofPort() string
@@ -51,17 +47,14 @@ func (c *config) Domain() string { return c.domain }
func (c *config) SSHPort() string { return c.sshPort } func (c *config) SSHPort() string { return c.sshPort }
func (c *config) HTTPPort() string { return c.httpPort } func (c *config) HTTPPort() string { return c.httpPort }
func (c *config) HTTPSPort() string { return c.httpsPort } func (c *config) HTTPSPort() string { return c.httpsPort }
func (c *config) KeyLoc() string { return c.keyLoc }
func (c *config) TLSEnabled() bool { return c.tlsEnabled } func (c *config) TLSEnabled() bool { return c.tlsEnabled }
func (c *config) TLSRedirect() bool { return c.tlsRedirect } func (c *config) TLSRedirect() bool { return c.tlsRedirect }
func (c *config) TLSStoragePath() string { return c.tlsStoragePath }
func (c *config) ACMEEmail() string { return c.acmeEmail } func (c *config) ACMEEmail() string { return c.acmeEmail }
func (c *config) CFAPIToken() string { return c.cfAPIToken } func (c *config) CFAPIToken() string { return c.cfAPIToken }
func (c *config) ACMEStaging() bool { return c.acmeStaging } func (c *config) ACMEStaging() bool { return c.acmeStaging }
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart } func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd } func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
func (c *config) BufferSize() int { return c.bufferSize } func (c *config) BufferSize() int { return c.bufferSize }
func (c *config) HeaderSize() int { return c.headerSize }
func (c *config) PprofEnabled() bool { return c.pprofEnabled } func (c *config) PprofEnabled() bool { return c.pprofEnabled }
func (c *config) PprofPort() string { return c.pprofPort } func (c *config) PprofPort() string { return c.pprofPort }
func (c *config) Mode() types.ServerMode { return c.mode } func (c *config) Mode() types.ServerMode { return c.mode }
-405
View File
@@ -1,405 +0,0 @@
package config
import (
"os"
"testing"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
)
func TestGetenv(t *testing.T) {
tests := []struct {
name string
key string
val string
def string
expected string
}{
{
name: "returns existing env",
key: "TEST_ENV_EXIST",
val: "value",
def: "default",
expected: "value",
},
{
name: "returns default when env missing",
key: "TEST_ENV_MISSING",
val: "",
def: "default",
expected: "default",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.val != "" {
t.Setenv(tt.key, tt.val)
} else {
err := os.Unsetenv(tt.key)
assert.NoError(t, err)
}
assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
})
}
}
func TestGetenvBool(t *testing.T) {
tests := []struct {
name string
key string
val string
def bool
expected bool
}{
{
name: "returns true when env is true",
key: "TEST_BOOL_TRUE",
val: "true",
def: false,
expected: true,
},
{
name: "returns false when env is false",
key: "TEST_BOOL_FALSE",
val: "false",
def: true,
expected: false,
},
{
name: "returns default when env missing",
key: "TEST_BOOL_MISSING",
val: "",
def: true,
expected: true,
},
{
name: "returns false when env is not true",
key: "TEST_BOOL_INVALID",
val: "yes",
def: true,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.val != "" {
t.Setenv(tt.key, tt.val)
} else {
err := os.Unsetenv(tt.key)
assert.NoError(t, err)
}
assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
})
}
}
func TestParseMode(t *testing.T) {
tests := []struct {
name string
mode string
expect types.ServerMode
expectErr bool
}{
{"standalone", "standalone", types.ServerModeSTANDALONE, false},
{"node", "node", types.ServerModeNODE, false},
{"uppercase", "STANDALONE", types.ServerModeSTANDALONE, false},
{"invalid", "invalid", 0, true},
{"empty (default)", "", types.ServerModeSTANDALONE, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.mode != "" {
t.Setenv("MODE", tt.mode)
} else {
err := os.Unsetenv("MODE")
assert.NoError(t, err)
}
mode, err := parseMode()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expect, mode)
}
})
}
}
func TestParseAllowedPorts(t *testing.T) {
tests := []struct {
name string
val string
start uint16
end uint16
expectErr bool
}{
{"valid range", "1000-2000", 1000, 2000, false},
{"empty", "", 0, 0, false},
{"invalid format - no dash", "1000", 0, 0, true},
{"invalid format - too many dashes", "1000-2000-3000", 0, 0, true},
{"invalid start port", "abc-2000", 0, 0, true},
{"invalid end port", "1000-abc", 0, 0, true},
{"out of range start", "70000-80000", 0, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.val != "" {
t.Setenv("ALLOWED_PORTS", tt.val)
} else {
err := os.Unsetenv("ALLOWED_PORTS")
assert.NoError(t, err)
}
start, end, err := parseAllowedPorts()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.start, start)
assert.Equal(t, tt.end, end)
}
})
}
}
func TestParseBufferSize(t *testing.T) {
tests := []struct {
name string
val string
expect int
}{
{"valid size", "8192", 8192},
{"default size", "", 32768},
{"too small", "1024", 4096},
{"too large", "2000000", 4096},
{"invalid format", "abc", 4096},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.val != "" {
t.Setenv("BUFFER_SIZE", tt.val)
} else {
err := os.Unsetenv("BUFFER_SIZE")
assert.NoError(t, err)
}
size := parseBufferSize()
assert.Equal(t, tt.expect, size)
})
}
}
func TestParseHeaderSize(t *testing.T) {
tests := []struct {
name string
val string
expect int
}{
{"valid size", "8192", 8192},
{"default size", "", 4096},
{"too small", "1024", 4096},
{"too large", "2000000", 4096},
{"invalid format", "abc", 4096},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.val != "" {
t.Setenv("MAX_HEADER_SIZE", tt.val)
} else {
err := os.Unsetenv("MAX_HEADER_SIZE")
assert.NoError(t, err)
}
size := parseHeaderSize()
assert.Equal(t, tt.expect, size)
})
}
}
func TestParse(t *testing.T) {
tests := []struct {
name string
envs map[string]string
expectErr bool
}{
{
name: "minimal valid config",
envs: map[string]string{
"DOMAIN": "example.com",
},
expectErr: false,
},
{
name: "TLS enabled without token",
envs: map[string]string{
"TLS_ENABLED": "true",
},
expectErr: true,
},
{
name: "TLS enabled with token",
envs: map[string]string{
"TLS_ENABLED": "true",
"CF_API_TOKEN": "secret",
},
expectErr: false,
},
{
name: "Node mode without token",
envs: map[string]string{
"MODE": "node",
},
expectErr: true,
},
{
name: "Node mode with token",
envs: map[string]string{
"MODE": "node",
"NODE_TOKEN": "token",
},
expectErr: false,
},
{
name: "invalid mode",
envs: map[string]string{
"MODE": "invalid",
},
expectErr: true,
},
{
name: "invalid allowed ports",
envs: map[string]string{
"ALLOWED_PORTS": "1000",
},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Clearenv()
for k, v := range tt.envs {
t.Setenv(k, v)
}
cfg, err := parse()
if tt.expectErr {
assert.Error(t, err)
assert.Nil(t, cfg)
} else {
assert.NoError(t, err)
assert.NotNil(t, cfg)
}
})
}
}
func TestGetters(t *testing.T) {
envs := map[string]string{
"DOMAIN": "example.com",
"PORT": "2222",
"HTTP_PORT": "80",
"HTTPS_PORT": "443",
"KEY_LOC": "certs/ssh/id_rsa",
"TLS_ENABLED": "true",
"TLS_REDIRECT": "true",
"TLS_STORAGE_PATH": "certs/tls/",
"ACME_EMAIL": "test@example.com",
"CF_API_TOKEN": "token",
"ACME_STAGING": "true",
"ALLOWED_PORTS": "1000-2000",
"BUFFER_SIZE": "16384",
"MAX_HEADER_SIZE": "4096",
"PPROF_ENABLED": "true",
"PPROF_PORT": "7070",
"MODE": "standalone",
"GRPC_ADDRESS": "127.0.0.1",
"GRPC_PORT": "9090",
"NODE_TOKEN": "ntoken",
}
os.Clearenv()
for k, v := range envs {
t.Setenv(k, v)
}
cfg, err := parse()
assert.NoError(t, err)
assert.Equal(t, "example.com", cfg.Domain())
assert.Equal(t, "2222", cfg.SSHPort())
assert.Equal(t, "80", cfg.HTTPPort())
assert.Equal(t, "443", cfg.HTTPSPort())
assert.Equal(t, "certs/ssh/id_rsa", cfg.KeyLoc())
assert.Equal(t, true, cfg.TLSEnabled())
assert.Equal(t, true, cfg.TLSRedirect())
assert.Equal(t, "certs/tls/", cfg.TLSStoragePath())
assert.Equal(t, "test@example.com", cfg.ACMEEmail())
assert.Equal(t, "token", cfg.CFAPIToken())
assert.Equal(t, true, cfg.ACMEStaging())
assert.Equal(t, uint16(1000), cfg.AllowedPortsStart())
assert.Equal(t, uint16(2000), cfg.AllowedPortsEnd())
assert.Equal(t, 16384, cfg.BufferSize())
assert.Equal(t, 4096, cfg.HeaderSize())
assert.Equal(t, true, cfg.PprofEnabled())
assert.Equal(t, "7070", cfg.PprofPort())
assert.Equal(t, types.ServerMode(types.ServerModeSTANDALONE), cfg.Mode())
assert.Equal(t, "127.0.0.1", cfg.GRPCAddress())
assert.Equal(t, "9090", cfg.GRPCPort())
assert.Equal(t, "ntoken", cfg.NodeToken())
}
func TestMustLoad(t *testing.T) {
t.Run("success", func(t *testing.T) {
os.Clearenv()
t.Setenv("DOMAIN", "example.com")
cfg, err := MustLoad()
assert.NoError(t, err)
assert.NotNil(t, cfg)
})
t.Run("loadEnvFile error", func(t *testing.T) {
err := os.Mkdir(".env", 0755)
assert.NoError(t, err)
defer func() {
err = os.Remove(".env")
assert.NoError(t, err)
}()
cfg, err := MustLoad()
assert.Error(t, err)
assert.Nil(t, cfg)
})
t.Run("parse error", func(t *testing.T) {
os.Clearenv()
t.Setenv("MODE", "invalid")
cfg, err := MustLoad()
assert.Error(t, err)
assert.Nil(t, cfg)
})
}
func TestLoadEnvFile(t *testing.T) {
t.Run("file exists", func(t *testing.T) {
err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
assert.NoError(t, err)
defer func() {
err = os.Remove(".env")
assert.NoError(t, err)
}()
err = loadEnvFile()
assert.NoError(t, err)
assert.Equal(t, "true", os.Getenv("TEST_ENV_FILE"))
})
t.Run("file missing", func(t *testing.T) {
_ = os.Remove(".env")
err := loadEnvFile()
assert.NoError(t, err)
})
}
+5 -25
View File
@@ -18,20 +18,17 @@ type config struct {
httpPort string httpPort string
httpsPort string httpsPort string
keyLoc string tlsEnabled bool
tlsRedirect bool
tlsEnabled bool acmeEmail string
tlsRedirect bool cfAPIToken string
tlsStoragePath string acmeStaging bool
acmeEmail string
cfAPIToken string
acmeStaging bool
allowedPortsStart uint16 allowedPortsStart uint16
allowedPortsEnd uint16 allowedPortsEnd uint16
bufferSize int bufferSize int
headerSize int
pprofEnabled bool pprofEnabled bool
pprofPort string pprofPort string
@@ -54,11 +51,8 @@ func parse() (*config, error) {
httpPort := getenv("HTTP_PORT", "8080") httpPort := getenv("HTTP_PORT", "8080")
httpsPort := getenv("HTTPS_PORT", "8443") httpsPort := getenv("HTTPS_PORT", "8443")
keyLoc := getenv("KEY_LOC", "certs/privkey.pem")
tlsEnabled := getenvBool("TLS_ENABLED", false) tlsEnabled := getenvBool("TLS_ENABLED", false)
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false) tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
tlsStoragePath := getenv("TLS_STORAGE_PATH", "certs/tls/")
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain) acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
acmeStaging := getenvBool("ACME_STAGING", false) acmeStaging := getenvBool("ACME_STAGING", false)
@@ -74,7 +68,6 @@ func parse() (*config, error) {
} }
bufferSize := parseBufferSize() bufferSize := parseBufferSize()
headerSize := parseHeaderSize()
pprofEnabled := getenvBool("PPROF_ENABLED", false) pprofEnabled := getenvBool("PPROF_ENABLED", false)
pprofPort := getenv("PPROF_PORT", "6060") pprofPort := getenv("PPROF_PORT", "6060")
@@ -92,17 +85,14 @@ func parse() (*config, error) {
sshPort: sshPort, sshPort: sshPort,
httpPort: httpPort, httpPort: httpPort,
httpsPort: httpsPort, httpsPort: httpsPort,
keyLoc: keyLoc,
tlsEnabled: tlsEnabled, tlsEnabled: tlsEnabled,
tlsRedirect: tlsRedirect, tlsRedirect: tlsRedirect,
tlsStoragePath: tlsStoragePath,
acmeEmail: acmeEmail, acmeEmail: acmeEmail,
cfAPIToken: cfToken, cfAPIToken: cfToken,
acmeStaging: acmeStaging, acmeStaging: acmeStaging,
allowedPortsStart: start, allowedPortsStart: start,
allowedPortsEnd: end, allowedPortsEnd: end,
bufferSize: bufferSize, bufferSize: bufferSize,
headerSize: headerSize,
pprofEnabled: pprofEnabled, pprofEnabled: pprofEnabled,
pprofPort: pprofPort, pprofPort: pprofPort,
mode: mode, mode: mode,
@@ -164,16 +154,6 @@ func parseBufferSize() int {
return size return size
} }
func parseHeaderSize() int {
raw := getenv("MAX_HEADER_SIZE", "4096")
size, err := strconv.Atoi(raw)
if err != nil || size < 4096 || size > 131072 {
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
return 4096
}
return size
}
func getenv(key, def string) string { func getenv(key, def string) string {
if v := os.Getenv(key); v != "" { if v := os.Getenv(key); v != "" {
return v return v
+122 -132
View File
@@ -38,15 +38,7 @@ type client struct {
closing bool closing bool
} }
var ( func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) {
grpcNewClient = grpc.NewClient
healthNewHealthClient = grpc_health_v1.NewHealthClient
initialBackoff = time.Second
)
func New(config config.Config, sessionRegistry registry.Registry) (Client, error) {
address := fmt.Sprintf("%s:%s", config.GRPCAddress(), config.GRPCPort())
var opts []grpc.DialOption var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
@@ -66,7 +58,7 @@ func New(config config.Config, sessionRegistry registry.Registry) (Client, error
), ),
) )
conn, err := grpcNewClient(address, opts...) conn, err := grpc.NewClient(address, opts...)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err) return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err)
} }
@@ -85,100 +77,85 @@ func New(config config.Config, sessionRegistry registry.Registry) (Client, error
} }
func (c *client) SubscribeEvents(ctx context.Context, identity, authToken string) error { func (c *client) SubscribeEvents(ctx context.Context, identity, authToken string) error {
backoff := initialBackoff const (
baseBackoff = time.Second
maxBackoff = 30 * time.Second
)
for { backoff := baseBackoff
if err := c.subscribeAndProcess(ctx, identity, authToken, &backoff); err != nil { wait := func() error {
return err if backoff <= 0 {
return nil
}
select {
case <-time.After(backoff):
return nil
case <-ctx.Done():
return ctx.Err()
} }
} }
} growBackoff := func() {
backoff *= 2
func (c *client) subscribeAndProcess(ctx context.Context, identity, authToken string, backoff *time.Duration) error { if backoff > maxBackoff {
subscribe, err := c.eventService.Subscribe(ctx) backoff = maxBackoff
if err != nil { }
return c.handleSubscribeError(ctx, err, backoff)
} }
err = subscribe.Send(&proto.Node{ for {
Type: proto.EventType_AUTHENTICATION, subscribe, err := c.eventService.Subscribe(ctx)
Payload: &proto.Node_AuthEvent{ if err != nil {
AuthEvent: &proto.Authentication{ if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
Identity: identity, return err
AuthToken: authToken, }
if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated {
return err
}
if err = wait(); err != nil {
return err
}
growBackoff()
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
continue
}
err = subscribe.Send(&proto.Node{
Type: proto.EventType_AUTHENTICATION,
Payload: &proto.Node_AuthEvent{
AuthEvent: &proto.Authentication{
Identity: identity,
AuthToken: authToken,
},
}, },
}, })
})
if err != nil { if err != nil {
return c.handleAuthError(ctx, err, backoff) log.Println("Authentication failed to send to gRPC server:", err)
} if c.isConnectionError(err) {
if err = wait(); err != nil {
return err
}
growBackoff()
continue
}
return err
}
log.Println("Authentication Successfully sent to gRPC server")
backoff = baseBackoff
log.Println("Authentication Successfully sent to gRPC server") if err = c.processEventStream(subscribe); err != nil {
*backoff = time.Second if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
return err
return c.handleStreamError(ctx, c.processEventStream(subscribe), backoff) }
} if c.isConnectionError(err) {
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
func (c *client) handleSubscribeError(ctx context.Context, err error, backoff *time.Duration) error { if err = wait(); err != nil {
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil { return err
return err }
} growBackoff()
if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated { continue
return err }
} return err
if err = c.wait(ctx, *backoff); err != nil { }
return err
}
c.growBackoff(backoff)
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
return nil
}
func (c *client) handleAuthError(ctx context.Context, err error, backoff *time.Duration) error {
log.Println("Authentication failed to send to gRPC server:", err)
if !c.isConnectionError(err) {
return err
}
if err := c.wait(ctx, *backoff); err != nil {
return err
}
c.growBackoff(backoff)
return nil
}
func (c *client) handleStreamError(ctx context.Context, err error, backoff *time.Duration) error {
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
return err
}
if !c.isConnectionError(err) {
return err
}
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
if err := c.wait(ctx, *backoff); err != nil {
return err
}
c.growBackoff(backoff)
return nil
}
func (c *client) wait(ctx context.Context, duration time.Duration) error {
if duration <= 0 {
return nil
}
select {
case <-time.After(duration):
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (c *client) growBackoff(backoff *time.Duration) {
const maxBackoff = 30 * time.Second
*backoff *= 2
if *backoff > maxBackoff {
*backoff = maxBackoff
} }
} }
@@ -214,20 +191,35 @@ func (c *client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, pr
func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
slugEvent := evt.GetSlugEvent() slugEvent := evt.GetSlugEvent()
user := slugEvent.GetUser() user := slugEvent.GetUser()
oldKey := types.SessionKey{Id: slugEvent.GetOld(), Type: types.TunnelTypeHTTP} oldSlug := slugEvent.GetOld()
newKey := types.SessionKey{Id: slugEvent.GetNew(), Type: types.TunnelTypeHTTP} newSlug := slugEvent.GetNew()
userSession, err := c.sessionRegistry.Get(oldKey) userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP})
if err != nil { if err != nil {
return c.sendSlugChangeResponse(subscribe, false, err.Error()) return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()},
},
}, "slug change failure response")
} }
if err = c.sessionRegistry.Update(user, oldKey, newKey); err != nil { if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil {
return c.sendSlugChangeResponse(subscribe, false, err.Error()) return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()},
},
}, "slug change failure response")
} }
userSession.Interaction().Redraw() userSession.Interaction().Redraw()
return c.sendSlugChangeResponse(subscribe, true, "") return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{Success: true, Message: ""},
},
}, "slug change success response")
} }
func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
@@ -246,7 +238,12 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node
}) })
} }
return c.sendGetSessionsResponse(subscribe, details) return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_GET_SESSIONS,
Payload: &proto.Node_GetSessionsEvent{
GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
},
}, "send get sessions response")
} }
func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
@@ -256,46 +253,39 @@ func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto
tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType()) tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType())
if err != nil { if err != nil {
return c.sendTerminateSessionResponse(subscribe, false, err.Error()) return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
},
}, "terminate session invalid tunnel type")
} }
userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType}) userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType})
if err != nil { if err != nil {
return c.sendTerminateSessionResponse(subscribe, false, err.Error()) return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
},
}, "terminate session fetch failed")
} }
if err = userSession.Lifecycle().Close(); err != nil { if err = userSession.Lifecycle().Close(); err != nil {
return c.sendTerminateSessionResponse(subscribe, false, err.Error()) return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
},
}, "terminate session close failed")
} }
return c.sendTerminateSessionResponse(subscribe, true, "")
}
func (c *client) sendSlugChangeResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{Success: success, Message: message},
},
}, "slug change response")
}
func (c *client) sendGetSessionsResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], details []*proto.Detail) error {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_GET_SESSIONS,
Payload: &proto.Node_GetSessionsEvent{
GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
},
}, "send get sessions response")
}
func (c *client) sendTerminateSessionResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION, Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{ Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: success, Message: message}, TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: true, Message: ""},
}, },
}, "terminate session response") }, "terminate session success response")
} }
func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error { func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error {
@@ -336,7 +326,7 @@ func (c *client) AuthorizeConn(ctx context.Context, token string) (authorized bo
} }
func (c *client) CheckServerHealth(ctx context.Context) error { func (c *client) CheckServerHealth(ctx context.Context) error {
healthClient := healthNewHealthClient(c.ClientConn()) healthClient := grpc_health_v1.NewHealthClient(c.ClientConn())
resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{ resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{
Service: "", Service: "",
}) })
File diff suppressed because it is too large Load Diff
-227
View File
@@ -1,227 +0,0 @@
package header
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewRequest(t *testing.T) {
tests := []struct {
name string
data []byte
expectErr bool
errContains string
expectMethod string
expectPath string
expectVersion string
expectHeaders map[string]string
}{
{
name: "success",
data: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\nX-Custom: value\r\n\r\n"),
expectErr: false,
expectMethod: "GET",
expectPath: "/path",
expectVersion: "HTTP/1.1",
expectHeaders: map[string]string{
"Host": "example.com",
"X-Custom": "value",
},
},
{
name: "no CRLF in start line",
data: []byte("GET /path HTTP/1.1"),
expectErr: true,
errContains: "no CRLF found in start line",
},
{
name: "invalid start line - missing method",
data: []byte("INVALID\r\n\r\n"),
expectErr: true,
errContains: "invalid start line: missing method",
},
{
name: "invalid start line - missing version",
data: []byte("GET /path\r\n\r\n"),
expectErr: true,
errContains: "invalid start line: missing version",
},
{
name: "invalid start line - multiple spaces",
data: []byte("GET /path HTTP/1.1\r\n\r\n"),
expectErr: false,
expectMethod: "GET",
expectPath: "",
expectVersion: "/path HTTP/1.1",
expectHeaders: map[string]string{},
},
{
name: "start line with trailing space",
data: []byte("GET / HTTP/1.1 \r\n\r\n"),
expectErr: false,
expectMethod: "GET",
expectPath: "/",
expectVersion: "HTTP/1.1 ",
expectHeaders: map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := NewRequest(tt.data)
if tt.expectErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, req)
} else {
assert.NoError(t, err)
assert.NotNil(t, req)
assert.Equal(t, tt.expectMethod, req.Method())
assert.Equal(t, tt.expectPath, req.Path())
assert.Equal(t, tt.expectVersion, req.Version())
for k, v := range tt.expectHeaders {
assert.Equal(t, v, req.Value(k))
}
}
})
}
}
func TestRequestHeaderMethods(t *testing.T) {
data := []byte("GET / HTTP/1.1\r\nHost: original\r\n\r\n")
req, _ := NewRequest(data)
req.Set("Host", "updated")
req.Set("X-New", "new-value")
assert.Equal(t, "updated", req.Value("Host"))
assert.Equal(t, "new-value", req.Value("X-New"))
assert.Equal(t, "", req.Value("Non-Existent"))
req.Remove("X-New")
assert.Equal(t, "", req.Value("X-New"))
final := req.Finalize()
assert.Contains(t, string(final), "GET / HTTP/1.1\r\n")
assert.Contains(t, string(final), "Host: updated\r\n")
assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
}
func TestNewResponse(t *testing.T) {
tests := []struct {
name string
data []byte
expectErr bool
errContains string
expectHeaders map[string]string
}{
{
name: "success",
data: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"),
expectErr: false,
expectHeaders: map[string]string{
"Content-Length": "0",
},
},
{
name: "invalid response - no CRLF",
data: []byte("HTTP/1.1 200 OK"),
expectErr: true,
errContains: "no CRLF found in start line",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := NewResponse(tt.data)
if tt.expectErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, resp)
} else {
assert.NoError(t, err)
assert.NotNil(t, resp)
for k, v := range tt.expectHeaders {
assert.Equal(t, v, resp.Value(k))
}
}
})
}
}
func TestResponseHeaderMethods(t *testing.T) {
data := []byte("HTTP/1.1 200 OK\r\nServer: old\r\n\r\n")
resp, _ := NewResponse(data)
resp.Set("Server", "new")
resp.Set("X-Res", "val")
assert.Equal(t, "new", resp.Value("Server"))
assert.Equal(t, "val", resp.Value("X-Res"))
resp.Remove("X-Res")
assert.Equal(t, "", resp.Value("X-Res"))
final := resp.Finalize()
assert.Contains(t, string(final), "HTTP/1.1 200 OK\r\n")
assert.Contains(t, string(final), "Server: new\r\n")
assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
}
func TestSetRemainingHeaders(t *testing.T) {
tests := []struct {
name string
data []byte
initialHeaders map[string]string
expectHeaders map[string]string
}{
{
name: "various header formats",
data: []byte("K1: V1\r\nK2:V2\r\n K3 : V3 \r\nNoColon\r\n\r\n"),
expectHeaders: map[string]string{
"K1": "V1",
"K2": "V2",
"K3": "V3",
},
},
{
name: "no trailing CRLF",
data: []byte("K1: V1"),
expectHeaders: map[string]string{
"K1": "V1",
},
},
{
name: "empty lines",
data: []byte("\r\nK1: V1"),
expectHeaders: map[string]string{},
},
{
name: "headers with only colon",
data: []byte(": value\r\nkey:\r\n"),
expectHeaders: map[string]string{
"": "value",
"key": "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &requestHeader{headers: make(map[string]string)}
if tt.initialHeaders != nil {
req.headers = tt.initialHeaders
}
setRemainingHeaders(tt.data, req)
assert.Equal(t, len(tt.expectHeaders), len(req.headers))
for k, v := range tt.expectHeaders {
assert.Equal(t, v, req.headers[k])
}
})
}
}
+71
View File
@@ -1,6 +1,7 @@
package header package header
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
) )
@@ -35,6 +36,31 @@ func setRemainingHeaders(remaining []byte, header interface {
} }
} }
func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) {
header := &requestHeader{
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
}
startLine := headerData[:lineEnd]
header.startLine = startLine
var err error
header.method, header.path, header.version, err = parseStartLine(startLine)
if err != nil {
return nil, err
}
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func parseStartLine(startLine []byte) (method, path, version string, err error) { func parseStartLine(startLine []byte) (method, path, version string, err error) {
firstSpace := bytes.IndexByte(startLine, ' ') firstSpace := bytes.IndexByte(startLine, ' ')
if firstSpace == -1 { if firstSpace == -1 {
@@ -54,6 +80,51 @@ func parseStartLine(startLine []byte) (method, path, version string, err error)
return method, path, version, nil return method, path, version, nil
} }
func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) {
header := &requestHeader{
headers: make(map[string]string, 16),
}
startLineBytes, err := br.ReadSlice('\n')
if err != nil {
return nil, err
}
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
header.startLine = make([]byte, len(startLineBytes))
copy(header.startLine, startLineBytes)
header.method, header.path, header.version, err = parseStartLine(header.startLine)
if err != nil {
return nil, err
}
for {
lineBytes, err := br.ReadSlice('\n')
if err != nil {
return nil, err
}
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
if len(lineBytes) == 0 {
break
}
colonIdx := bytes.IndexByte(lineBytes, ':')
if colonIdx == -1 {
continue
}
key := bytes.TrimSpace(lineBytes[:colonIdx])
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
header.headers[string(key)] = string(value)
}
return header, nil
}
func finalize(startLine []byte, headers map[string]string) []byte { func finalize(startLine []byte, headers map[string]string) []byte {
size := len(startLine) + 2 size := len(startLine) + 2
for key, val := range headers { for key, val := range headers {
+9 -23
View File
@@ -1,33 +1,19 @@
package header package header
import ( import (
"bytes" "bufio"
"fmt" "fmt"
) )
func NewRequest(headerData []byte) (RequestHeader, error) { func NewRequest(r interface{}) (RequestHeader, error) {
header := &requestHeader{ switch v := r.(type) {
headers: make(map[string]string, 16), case []byte:
return parseHeadersFromBytes(v)
case *bufio.Reader:
return parseHeadersFromReader(v)
default:
return nil, fmt.Errorf("unsupported type: %T", r)
} }
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
}
startLine := headerData[:lineEnd]
header.startLine = startLine
var err error
header.method, header.path, header.version, err = parseStartLine(startLine)
if err != nil {
return nil, err
}
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
} }
func (req *requestHeader) Value(key string) string { func (req *requestHeader) Value(key string) string {
+2 -4
View File
@@ -30,6 +30,7 @@ type http struct {
remoteAddr net.Addr remoteAddr net.Addr
writer io.Writer writer io.Writer
reader io.Reader reader io.Reader
headerBuf []byte
buf []byte buf []byte
respHeader header.ResponseHeader respHeader header.ResponseHeader
reqHeader header.RequestHeader reqHeader header.RequestHeader
@@ -71,10 +72,7 @@ func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
} }
func (hs *http) Close() error { func (hs *http) Close() error {
if closer, ok := hs.writer.(io.Closer); ok { return hs.writer.(io.Closer).Close()
return closer.Close()
}
return nil
} }
func (hs *http) CloseWrite() error { func (hs *http) CloseWrite() error {
-765
View File
@@ -1,765 +0,0 @@
package stream
import (
"bytes"
"io"
"strings"
"testing"
"tunnel_pls/internal/http/header"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockAddr struct {
mock.Mock
}
func (m *MockAddr) String() string {
args := m.Called()
return args.String(0)
}
func (m *MockAddr) Network() string {
args := m.Called()
return args.String(0)
}
type MockRequestMiddleware struct {
mock.Mock
}
func (m *MockRequestMiddleware) HandleRequest(h header.RequestHeader) error {
args := m.Called(h)
return args.Error(0)
}
type MockResponseMiddleware struct {
mock.Mock
}
func (m *MockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error {
args := m.Called(h, body)
return args.Error(0)
}
type MockReadWriter struct {
mock.Mock
bytes.Buffer
}
func (m *MockReadWriter) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockReadWriter) Write(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockReadWriter) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockReadWriter) CloseWrite() error {
args := m.Called()
return args.Error(0)
}
type MockReadWriterOnlyCloser struct {
mock.Mock
bytes.Buffer
}
func (m *MockReadWriterOnlyCloser) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockReadWriterOnlyCloser) Write(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockReadWriterOnlyCloser) Close() error {
args := m.Called()
return args.Error(0)
}
type MockWriterOnly struct {
mock.Mock
}
func (m *MockWriterOnly) Write(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockWriterOnly) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
type MockReader struct {
mock.Mock
}
func (m *MockReader) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
type MockWriter struct {
mock.Mock
}
func (m *MockWriter) Write(p []byte) (int, error) {
ret := m.Called(p)
var n int
var err error
switch v := ret.Get(0).(type) {
case func([]byte) int:
n = v(p)
case int:
n = v
default:
n = len(p)
}
switch v := ret.Get(1).(type) {
case func([]byte) error:
err = v(p)
case error:
err = v
default:
err = nil
}
return n, err
}
func (m *MockWriter) Close() error {
args := m.Called()
return args.Error(0)
}
func TestHTTPMethods(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
rw := new(MockReadWriter)
hs := New(rw, rw, addr)
assert.Equal(t, addr, hs.RemoteAddr())
reqMW := new(MockRequestMiddleware)
hs.UseRequestMiddleware(reqMW)
assert.Equal(t, 1, len(hs.RequestMiddlewares()))
assert.Equal(t, reqMW, hs.RequestMiddlewares()[0])
respMW := new(MockResponseMiddleware)
hs.UseResponseMiddleware(respMW)
assert.Equal(t, 1, len(hs.ResponseMiddlewares()))
assert.Equal(t, respMW, hs.ResponseMiddlewares()[0])
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
hs.SetRequestHeader(reqH)
}
func TestApplyMiddlewares(t *testing.T) {
tests := []struct {
name string
setup func(HTTP, *MockRequestMiddleware, *MockResponseMiddleware)
apply func(HTTP, header.RequestHeader, header.ResponseHeader) error
verify func(*testing.T, header.RequestHeader, header.ResponseHeader)
expectErr bool
}{
{
name: "apply request middleware success",
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.RequestHeader)
h.Set("X-Middleware", "true")
}).Return(nil)
hs.UseRequestMiddleware(reqMW)
},
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyRequestMiddlewares(reqH)
},
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
assert.Equal(t, "true", reqH.Value("X-Middleware"))
},
},
{
name: "apply response middleware success",
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.ResponseHeader)
h.Set("X-Resp-Middleware", "true")
}).Return(nil)
hs.UseResponseMiddleware(respMW)
},
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
},
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
assert.Equal(t, "true", respH.Value("X-Resp-Middleware"))
},
},
{
name: "apply request middleware error",
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
reqMW.On("HandleRequest", mock.Anything).Return(assert.AnError)
hs.UseRequestMiddleware(reqMW)
},
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyRequestMiddlewares(reqH)
},
expectErr: true,
},
{
name: "apply response middleware error",
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(assert.AnError)
hs.UseResponseMiddleware(respMW)
},
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n"))
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
rw := new(MockReadWriter)
hs := New(rw, rw, addr)
reqMW := new(MockRequestMiddleware)
respMW := new(MockResponseMiddleware)
tt.setup(hs, reqMW, respMW)
err := tt.apply(hs, reqH, respH)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.verify != nil {
tt.verify(t, reqH, respH)
}
}
reqMW.AssertExpectations(t)
respMW.AssertExpectations(t)
})
}
}
func TestCloseMethods(t *testing.T) {
tests := []struct {
name string
setup func() (io.Writer, io.Reader)
op func(HTTP) error
verify func(*testing.T, io.Writer)
}{
{
name: "Close success",
setup: func() (io.Writer, io.Reader) {
rw := new(MockReadWriter)
rw.On("Close").Return(nil)
return rw, rw
},
op: func(hs HTTP) error { return hs.Close() },
verify: func(t *testing.T, w io.Writer) {
w.(*MockReadWriter).AssertCalled(t, "Close")
},
},
{
name: "CloseWrite with CloseWrite implementation",
setup: func() (io.Writer, io.Reader) {
rw := new(MockReadWriter)
rw.On("CloseWrite").Return(nil)
return rw, rw
},
op: func(hs HTTP) error { return hs.CloseWrite() },
verify: func(t *testing.T, w io.Writer) {
w.(*MockReadWriter).AssertCalled(t, "CloseWrite")
},
},
{
name: "CloseWrite fallback to Close",
setup: func() (io.Writer, io.Reader) {
rw := new(MockReadWriterOnlyCloser)
rw.On("Close").Return(nil)
return rw, rw
},
op: func(hs HTTP) error { return hs.CloseWrite() },
verify: func(t *testing.T, w io.Writer) {
w.(*MockReadWriterOnlyCloser).AssertCalled(t, "Close")
},
},
{
name: "Close with No Closer",
setup: func() (io.Writer, io.Reader) {
w := new(MockWriterOnly)
r := new(MockReader)
return w, r
},
op: func(hs HTTP) error { return hs.Close() },
},
{
name: "CloseWrite with No CloseWrite and No Closer",
setup: func() (io.Writer, io.Reader) {
w := new(MockWriterOnly)
r := new(MockReader)
return w, r
},
op: func(hs HTTP) error { return hs.CloseWrite() },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
w, r := tt.setup()
hs := New(w, r, addr)
assert.NotPanics(t, func() {
err := tt.op(hs)
assert.NoError(t, err)
})
if tt.verify != nil {
tt.verify(t, w)
}
})
}
}
func TestSplitHeaderAndBody(t *testing.T) {
tests := []struct {
name string
data []byte
delimiterIdx int
expectHeader []byte
expectBody []byte
}{
{
name: "standard",
data: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nBodyContent"),
delimiterIdx: 31,
expectHeader: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"),
expectBody: []byte("BodyContent"),
},
{
name: "empty body",
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
delimiterIdx: 15,
expectHeader: []byte("HTTP/1.1 200 OK\r\n\r\n"),
expectBody: []byte(""),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, b := splitHeaderAndBody(tt.data, tt.delimiterIdx)
assert.Equal(t, tt.expectHeader, h)
assert.Equal(t, tt.expectBody, b)
})
}
}
func TestIsHTTPHeader(t *testing.T) {
tests := []struct {
name string
buf []byte
expect bool
}{
{
name: "valid request",
buf: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n"),
expect: true,
},
{
name: "valid response",
buf: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"),
expect: true,
},
{
name: "invalid start line",
buf: []byte("NOT_HTTP /path\r\nHost: example.com\r\n\r\n"),
expect: false,
},
{
name: "invalid header line (no colon)",
buf: []byte("GET / HTTP/1.1\r\nInvalidHeaderLine\r\n\r\n"),
expect: false,
},
{
name: "invalid header line (colon at 0)",
buf: []byte("GET / HTTP/1.1\r\n: value\r\n\r\n"),
expect: false,
},
{
name: "empty header section",
buf: []byte("GET / HTTP/1.1\r\n\r\n"),
expect: true,
},
{
name: "multiple headers",
buf: []byte("GET / HTTP/1.1\r\nH1: V1\r\nH2: V2\r\n\r\n"),
expect: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isHTTPHeader(tt.buf)
assert.Equal(t, tt.expect, result)
})
}
}
func TestRead(t *testing.T) {
tests := []struct {
name string
input []byte
readLen int
expectContent string
expectRead int
expectErr bool
middlewareErr error
isHTTP bool
}{
{
name: "valid http request",
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\nBody"),
readLen: 100,
expectContent: "Body",
expectRead: 54,
isHTTP: true,
},
{
name: "non-http data",
input: []byte("Some random data\r\n\r\nMore data"),
readLen: 100,
expectContent: "Some random data\r\n\r\nMore data",
expectRead: 29,
isHTTP: false,
},
{
name: "no delimiter",
input: []byte("Partial data without delimiter"),
readLen: 100,
expectContent: "Partial data without delimiter",
expectRead: 30,
isHTTP: false,
},
{
name: "middleware error",
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\n"),
readLen: 100,
middlewareErr: assert.AnError,
expectErr: true,
isHTTP: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
reader := new(MockReader)
writer := new(MockWriterOnly)
if tt.expectErr || tt.name == "valid http request" {
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
copy(p, tt.input)
}).Return(len(tt.input), io.EOF).Once()
} else {
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
copy(p, tt.input)
}).Return(len(tt.input), nil).Once()
}
hs := New(writer, reader, addr)
reqMW := new(MockRequestMiddleware)
if tt.isHTTP {
if tt.middlewareErr != nil {
reqMW.On("HandleRequest", mock.Anything).Return(tt.middlewareErr)
} else {
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.RequestHeader)
h.Set("X-Middleware", "true")
}).Return(nil)
}
}
hs.UseRequestMiddleware(reqMW)
p := make([]byte, tt.readLen)
n, err := hs.Read(p)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectRead, n)
if tt.name == "valid http request" {
content := string(p[:n])
assert.Contains(t, content, "GET / HTTP/1.1\r\n")
assert.Contains(t, content, "Host: test\r\n")
assert.Contains(t, content, "X-Middleware: true\r\n")
assert.True(t, bytes.HasSuffix(p[:n], []byte("\r\n\r\nBody")))
} else {
assert.Equal(t, tt.expectContent, string(p[:n]))
}
}
if tt.isHTTP {
reqMW.AssertExpectations(t)
}
reader.AssertExpectations(t)
})
}
}
func TestWrite(t *testing.T) {
tests := []struct {
name string
writes [][]byte
expectWritten string
expectErr bool
middlewareErr error
isHTTP bool
}{
{
name: "valid http response in one write",
writes: [][]byte{
[]byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"),
},
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
isHTTP: true,
},
{
name: "valid http response in multiple writes",
writes: [][]byte{
[]byte("HTTP/1.1 200 OK\r\n"),
[]byte("Content-Length: 4\r\n\r\n"),
[]byte("Body"),
},
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
isHTTP: true,
},
{
name: "non-http data",
writes: [][]byte{
[]byte("Random data with delimiter\r\n\r\nFlush"),
},
expectWritten: "Random data with delimiter\r\n\r\nFlush",
isHTTP: false,
},
{
name: "bypass buffering",
writes: [][]byte{
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
},
expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" +
"HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n",
isHTTP: true,
},
{
name: "middleware error",
writes: [][]byte{
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
},
middlewareErr: assert.AnError,
expectErr: true,
isHTTP: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
var writtenData bytes.Buffer
writer := new(MockWriter)
writer.On("Write", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
writtenData.Write(p)
}).Return(func(p []byte) int {
return len(p)
}, nil)
reader := new(MockReader)
hs := New(writer, reader, addr)
respMW := new(MockResponseMiddleware)
if tt.isHTTP {
if tt.middlewareErr != nil {
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(tt.middlewareErr)
} else {
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.ResponseHeader)
h.Set("X-Resp-Middleware", "true")
}).Return(nil)
}
}
hs.UseResponseMiddleware(respMW)
var totalN int
var err error
for _, w := range tt.writes {
var n int
n, err = hs.Write(w)
if err != nil {
break
}
totalN += n
}
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
written := writtenData.String()
if strings.HasPrefix(tt.expectWritten, "HTTP/") {
assert.Contains(t, written, "HTTP/1.1 200 OK\r\n")
assert.Contains(t, written, "X-Resp-Middleware: true\r\n")
if strings.Contains(tt.expectWritten, "Content-Length: 4") {
assert.Contains(t, written, "Content-Length: 4\r\n")
}
assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n"))
} else {
assert.Equal(t, tt.expectWritten, written)
}
}
if tt.isHTTP {
respMW.AssertExpectations(t)
}
if tt.middlewareErr == nil {
writer.AssertExpectations(t)
}
})
}
}
func TestWriteErrors(t *testing.T) {
tests := []struct {
name string
setup func() (io.Writer, io.Reader)
data []byte
}{
{
name: "write error in writeHeaderAndBody",
setup: func() (io.Writer, io.Reader) {
writer := new(MockWriter)
writer.On("Write", mock.Anything).Return(0, assert.AnError)
reader := new(MockReader)
return writer, reader
},
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
},
{
name: "write error in writeHeaderAndBody second write",
setup: func() (io.Writer, io.Reader) {
writer := new(MockWriter)
writer.On("Write", mock.Anything).Return(len([]byte("HTTP/1.1 200 OK\r\n\r\n")), nil).Once()
writer.On("Write", mock.Anything).Return(0, assert.AnError).Once()
reader := new(MockReader)
return writer, reader
},
data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"),
},
{
name: "write error in writeRawBuffer",
setup: func() (io.Writer, io.Reader) {
writer := new(MockWriter)
writer.On("Write", mock.Anything).Return(0, assert.AnError)
reader := new(MockReader)
return writer, reader
},
data: []byte("Not HTTP\r\n\r\nFlush"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
w, r := tt.setup()
hs := New(w, r, addr)
_, err := hs.Write(tt.data)
assert.Error(t, err)
w.(*MockWriter).AssertExpectations(t)
})
}
}
func TestReadEOF(t *testing.T) {
tests := []struct {
name string
setup func() io.Reader
expectN int
expectErr error
expectContent string
}{
{
name: "read eof",
setup: func() io.Reader {
reader := new(MockReader)
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
copy(p, "data")
}).Return(4, io.EOF)
return reader
},
expectN: 4,
expectErr: io.EOF,
expectContent: "data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
reader := tt.setup()
hs := New(nil, reader, addr)
p := make([]byte, 100)
n, err := hs.Read(p)
assert.Equal(t, tt.expectN, n)
assert.Equal(t, tt.expectErr, err)
assert.Equal(t, tt.expectContent, string(p[:n]))
reader.(*MockReader).AssertExpectations(t)
})
}
}
+9 -28
View File
@@ -5,8 +5,6 @@ import (
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors"
"io"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
@@ -14,20 +12,7 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
var (
rsaGenerateKey = rsa.GenerateKey
pemEncode = pem.Encode
sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
return ssh.NewPublicKey(key)
}
pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
return w.Write(data)
}
osOpenFile = os.OpenFile
)
func GenerateSSHKeyIfNotExist(keyPath string) error { func GenerateSSHKeyIfNotExist(keyPath string) error {
var errGroup = make([]error, 0)
if _, err := os.Stat(keyPath); err == nil { if _, err := os.Stat(keyPath); err == nil {
log.Printf("SSH key already exists at %s", keyPath) log.Printf("SSH key already exists at %s", keyPath)
return nil return nil
@@ -35,7 +20,7 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
log.Printf("SSH key not found at %s, generating new key pair...", keyPath) log.Printf("SSH key not found at %s, generating new key pair...", keyPath)
privateKey, err := rsaGenerateKey(rand.Reader, 4096) privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil { if err != nil {
return err return err
} }
@@ -50,37 +35,33 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
return err return err
} }
privateKeyFile, err := osOpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) privateKeyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil { if err != nil {
return err return err
} }
defer func(privateKeyFile *os.File) { defer privateKeyFile.Close()
errGroup = append(errGroup, privateKeyFile.Close())
}(privateKeyFile)
if err := pemEncode(privateKeyFile, privateKeyPEM); err != nil { if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil {
return err return err
} }
publicKey, err := sshNewPublicKey(&privateKey.PublicKey) publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
if err != nil { if err != nil {
return err return err
} }
pubKeyPath := keyPath + ".pub" pubKeyPath := keyPath + ".pub"
pubKeyFile, err := osOpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) pubKeyFile, err := os.OpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil { if err != nil {
return err return err
} }
defer func(pubKeyFile *os.File) { defer pubKeyFile.Close()
errGroup = append(errGroup, pubKeyFile.Close())
}(pubKeyFile)
_, err = pubKeyWrite(pubKeyFile, ssh.MarshalAuthorizedKey(publicKey)) _, err = pubKeyFile.Write(ssh.MarshalAuthorizedKey(publicKey))
if err != nil { if err != nil {
return err return err
} }
log.Printf("SSH key pair generated successfully at %s and %s", keyPath, pubKeyPath) log.Printf("SSH key pair generated successfully at %s and %s", keyPath, pubKeyPath)
return errors.Join(errGroup...) return nil
} }
-235
View File
@@ -1,235 +0,0 @@
package key
import (
"crypto/rsa"
"encoding/pem"
"errors"
"io"
"os"
"path/filepath"
"testing"
"golang.org/x/crypto/ssh"
)
func TestGenerateSSHKeyIfNotExist(t *testing.T) {
tempDir := t.TempDir()
tests := []struct {
name string
setup func(t *testing.T, tempDir string) string
mockSetup func() func()
wantErr bool
errStr string
verify func(t *testing.T, keyPath string)
}{
{
name: "GenerateNewKey",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "id_rsa")
},
verify: func(t *testing.T, keyPath string) {
pubKeyPath := keyPath + ".pub"
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
t.Errorf("Private key file not created")
}
if _, err := os.Stat(pubKeyPath); os.IsNotExist(err) {
t.Errorf("Public key file not created")
}
privateKeyBytes, err := os.ReadFile(keyPath)
if err != nil {
t.Fatalf("Failed to read private key: %v", err)
}
if _, err = ssh.ParseRawPrivateKey(privateKeyBytes); err != nil {
t.Errorf("Failed to parse private key: %v", err)
}
publicKeyBytes, err := os.ReadFile(pubKeyPath)
if err != nil {
t.Fatalf("Failed to read public key: %v", err)
}
if _, _, _, _, err = ssh.ParseAuthorizedKey(publicKeyBytes); err != nil {
t.Errorf("Failed to parse public key: %v", err)
}
},
},
{
name: "DoNotOverwriteExistingKey",
setup: func(t *testing.T, tempDir string) string {
keyPath := filepath.Join(tempDir, "existing_id_rsa")
dummyPrivate := "dummy private"
dummyPublic := "dummy public"
if err := os.WriteFile(keyPath, []byte(dummyPrivate), 0600); err != nil {
t.Fatalf("Failed to create dummy private key: %v", err)
}
if err := os.WriteFile(keyPath+".pub", []byte(dummyPublic), 0644); err != nil {
t.Fatalf("Failed to create dummy public key: %v", err)
}
return keyPath
},
verify: func(t *testing.T, keyPath string) {
gotPrivate, _ := os.ReadFile(keyPath)
if string(gotPrivate) != "dummy private" {
t.Errorf("Private key was overwritten")
}
gotPublic, _ := os.ReadFile(keyPath + ".pub")
if string(gotPublic) != "dummy public" {
t.Errorf("Public key was overwritten")
}
},
},
{
name: "CreateNestedDirectories",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "nested", "dir", "id_rsa")
},
verify: func(t *testing.T, keyPath string) {
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
t.Errorf("Private key file not created in nested directory")
}
},
},
{
name: "FailureMkdirAll",
setup: func(t *testing.T, tempDir string) string {
dirPath := filepath.Join(tempDir, "file_as_dir")
if err := os.WriteFile(dirPath, []byte("not a dir"), 0644); err != nil {
t.Fatalf("Failed to create file: %v", err)
}
return filepath.Join(dirPath, "id_rsa")
},
wantErr: true,
},
{
name: "PrivateExistsPublicMissing",
setup: func(t *testing.T, tempDir string) string {
keyPath := filepath.Join(tempDir, "partial_id_rsa")
if err := os.WriteFile(keyPath, []byte("private"), 0600); err != nil {
t.Fatalf("Failed to create private key: %v", err)
}
return keyPath
},
verify: func(t *testing.T, keyPath string) {
if _, err := os.Stat(keyPath + ".pub"); !os.IsNotExist(err) {
t.Errorf("Public key should NOT have been created if private key existed")
}
},
},
{
name: "FailureRSAGenerateKey",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_rsa")
},
mockSetup: func() func() {
old := rsaGenerateKey
rsaGenerateKey = func(random io.Reader, bits int) (*rsa.PrivateKey, error) {
return nil, errors.New("rsa error")
}
return func() { rsaGenerateKey = old }
},
wantErr: true,
errStr: "rsa error",
},
{
name: "FailureOpenFilePrivate",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_open_private")
},
mockSetup: func() func() {
old := osOpenFile
osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
return nil, errors.New("open error")
}
return func() { osOpenFile = old }
},
wantErr: true,
errStr: "open error",
},
{
name: "FailurePemEncode",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_pem")
},
mockSetup: func() func() {
old := pemEncode
pemEncode = func(out io.Writer, b *pem.Block) error {
return errors.New("pem error")
}
return func() { pemEncode = old }
},
wantErr: true,
errStr: "pem error",
},
{
name: "FailureSSHNewPublicKey",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_ssh")
},
mockSetup: func() func() {
old := sshNewPublicKey
sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
return nil, errors.New("ssh error")
}
return func() { sshNewPublicKey = old }
},
wantErr: true,
errStr: "ssh error",
},
{
name: "FailureOpenFilePublic",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_open_public")
},
mockSetup: func() func() {
old := osOpenFile
osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
if filepath.Ext(name) == ".pub" {
return nil, errors.New("open pub error")
}
return os.OpenFile(name, flag, perm)
}
return func() { osOpenFile = old }
},
wantErr: true,
errStr: "open pub error",
},
{
name: "FailurePubKeyWrite",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_write")
},
mockSetup: func() func() {
old := pubKeyWrite
pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
return 0, errors.New("write error")
}
return func() { pubKeyWrite = old }
},
wantErr: true,
errStr: "write error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
keyPath := tt.setup(t, tempDir)
if tt.mockSetup != nil {
cleanup := tt.mockSetup()
defer cleanup()
}
err := GenerateSSHKeyIfNotExist(keyPath)
if (err != nil) != tt.wantErr {
t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr && tt.errStr != "" && err != nil && err.Error() != tt.errStr {
t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErrStr %v", err, tt.errStr)
}
if tt.verify != nil {
tt.verify(t, keyPath)
}
})
}
}
-126
View File
@@ -1,126 +0,0 @@
package middleware
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type mockRequestHeader struct {
mock.Mock
}
func (m *mockRequestHeader) Value(key string) string {
return m.Called(key).String(0)
}
func (m *mockRequestHeader) Set(key string, value string) {
m.Called(key, value)
}
func (m *mockRequestHeader) Remove(key string) {
m.Called(key)
}
func (m *mockRequestHeader) Finalize() []byte {
return m.Called().Get(0).([]byte)
}
func (m *mockRequestHeader) Method() string {
return m.Called().String(0)
}
func (m *mockRequestHeader) Path() string {
return m.Called().String(0)
}
func (m *mockRequestHeader) Version() string {
return m.Called().String(0)
}
func TestForwardedFor_HandleRequest(t *testing.T) {
tests := []struct {
name string
addr net.Addr
expectedHost string
expectError bool
}{
{
name: "valid IPv4 address",
addr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 8080},
expectedHost: "192.168.1.100",
expectError: false,
},
{
name: "valid IPv6 address",
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 8080},
expectedHost: "2001:db8::ff00:42:8329",
expectError: false,
},
{
name: "invalid address format",
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
expectedHost: "",
expectError: true,
},
{
name: "valid IPv4 address with port",
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
expectedHost: "127.0.0.1",
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ff := NewForwardedFor(tc.addr)
reqHeader := new(mockRequestHeader)
if !tc.expectError {
reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return()
}
err := ff.HandleRequest(reqHeader)
if tc.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
reqHeader.AssertExpectations(t)
}
})
}
}
func TestNewForwardedFor(t *testing.T) {
tests := []struct {
name string
addr net.Addr
expectAddr net.Addr
}{
{
name: "IPv4 address",
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
expectAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
},
{
name: "IPv6 address",
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
expectAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
},
{
name: "Unix address",
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
expectAddr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ff := NewForwardedFor(tc.addr)
assert.Equal(t, tc.expectAddr.String(), ff.addr.String())
})
}
}
@@ -1,70 +0,0 @@
package middleware
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type mockResponseHeader struct {
mock.Mock
}
func (m *mockResponseHeader) Value(key string) string {
return m.Called(key).String(0)
}
func (m *mockResponseHeader) Set(key string, value string) {
m.Called(key, value)
}
func (m *mockResponseHeader) Remove(key string) {
m.Called(key)
}
func (m *mockResponseHeader) Finalize() []byte {
return m.Called().Get(0).([]byte)
}
func TestTunnelFingerprintHandleResponse(t *testing.T) {
tests := []struct {
name string
expected map[string]string
body []byte
wantErr error
}{
{
name: "Sets Server Header",
expected: map[string]string{"Server": "Tunnel Please"},
body: []byte("Sample body"),
wantErr: nil,
},
{
name: "Overwrites Server Header",
expected: map[string]string{"Server": "Tunnel Please"},
body: nil,
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockHeader := new(mockResponseHeader)
for k, v := range tt.expected {
mockHeader.On("Set", k, v).Return()
}
tunnelFingerprint := NewTunnelFingerprint()
err := tunnelFingerprint.HandleResponse(mockHeader, tt.body)
assert.ErrorIs(t, err, tt.wantErr)
mockHeader.AssertExpectations(t)
})
}
}
func TestNewTunnelFingerprint(t *testing.T) {
instance := NewTunnelFingerprint()
assert.NotNil(t, instance)
}
-114
View File
@@ -1,114 +0,0 @@
package port
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddRange(t *testing.T) {
tests := []struct {
name string
startPort uint16
endPort uint16
wantErr bool
}{
{"normal range", 1000, 1002, false},
{"invalid range", 2000, 1999, true},
{"single port range", 3000, 3000, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm := New()
err := pm.AddRange(tt.startPort, tt.endPort)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestUnassigned(t *testing.T) {
pm := New()
_ = pm.AddRange(1000, 1002)
tests := []struct {
name string
status map[uint16]bool
want uint16
wantOk bool
}{
{"all unassigned", map[uint16]bool{1000: false, 1001: false, 1002: false}, 1000, true},
{"some assigned", map[uint16]bool{1000: true, 1001: false, 1002: true}, 1001, true},
{"all assigned", map[uint16]bool{1000: true, 1001: true, 1002: true}, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for k, v := range tt.status {
_ = pm.SetStatus(k, v)
}
got, gotOk := pm.Unassigned()
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.wantOk, gotOk)
})
}
}
func TestSetStatus(t *testing.T) {
pm := New()
_ = pm.AddRange(1000, 1002)
tests := []struct {
name string
port uint16
assigned bool
}{
{"assign port 1000", 1000, true},
{"unassign port 1001", 1001, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := pm.SetStatus(tt.port, tt.assigned)
assert.NoError(t, err)
status, ok := pm.(*port).ports[tt.port]
assert.True(t, ok)
assert.Equal(t, tt.assigned, status)
})
}
}
func TestClaim(t *testing.T) {
pm := New()
_ = pm.AddRange(1000, 1002)
tests := []struct {
name string
port uint16
status bool
want bool
}{
{"claim unassigned port", 1000, false, true},
{"claim already assigned port", 1001, true, false},
{"claim non-existent port", 5000, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, exists := pm.(*port).ports[tt.port]; exists {
_ = pm.SetStatus(tt.port, tt.status)
}
got := pm.Claim(tt.port)
assert.Equal(t, tt.want, got)
finalState := pm.(*port).ports[tt.port]
assert.True(t, finalState)
})
}
}
+3 -26
View File
@@ -1,35 +1,12 @@
package random package random
import ( import "crypto/rand"
"crypto/rand"
"fmt"
"io"
)
var ( func GenerateRandomString(length int) (string, error) {
ErrInvalidLength = fmt.Errorf("invalid length")
)
type Random interface {
String(length int) (string, error)
}
type random struct {
reader io.Reader
}
func New() Random {
return &random{reader: rand.Reader}
}
func (ran *random) String(length int) (string, error) {
if length < 0 {
return "", ErrInvalidLength
}
const charset = "abcdefghijklmnopqrstuvwxyz0123456789" const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, length) b := make([]byte, length)
if _, err := ran.reader.Read(b); err != nil { if _, err := rand.Read(b); err != nil {
return "", err return "", err
} }
-70
View File
@@ -1,70 +0,0 @@
package random
import (
"io"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRandom_String(t *testing.T) {
tests := []struct {
name string
length int
wantErr bool
}{
{"ValidLengthZero", 0, false},
{"ValidPositiveLength", 10, false},
{"NegativeLength", -1, true},
{"VeryLargeLength", 1_000_000, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
randomizer := New()
result, err := randomizer.String(tt.length)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Len(t, result, tt.length)
}
})
}
}
func TestRandomWithFailingReader_String(t *testing.T) {
errBrainrot := assert.AnError
tests := []struct {
name string
reader io.Reader
expectErr error
}{
{
name: "failing reader",
reader: func() io.Reader {
return &failingReader{err: errBrainrot}
}(),
expectErr: errBrainrot,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
randomizer := &random{reader: tt.reader}
result, err := randomizer.String(20)
assert.ErrorIs(t, err, tt.expectErr)
assert.Empty(t, result)
})
}
}
type failingReader struct {
err error
}
func (f *failingReader) Read(p []byte) (int, error) {
return 0, f.err
}
+6 -4
View File
@@ -94,13 +94,12 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
return ErrInvalidSlug return ErrInvalidSlug
} }
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
return ErrSlugInUse
}
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
return ErrSlugInUse
}
client, ok := r.byUser[user][oldKey] client, ok := r.byUser[user][oldKey]
if !ok { if !ok {
return ErrSessionNotFound return ErrSessionNotFound
@@ -112,6 +111,9 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
client.Slug().Set(newKey.Id) client.Slug().Set(newKey.Id)
r.slugIndex[newKey] = user r.slugIndex[newKey] = user
if r.byUser[user] == nil {
r.byUser[user] = make(map[Key]Session)
}
r.byUser[user][newKey] = client r.byUser[user][newKey] = client
return nil return nil
} }
-695
View File
@@ -1,695 +0,0 @@
package registry
import (
"sync"
"testing"
"time"
"tunnel_pls/internal/port"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
type mockSession struct {
mock.Mock
}
func (m *mockSession) Lifecycle() lifecycle.Lifecycle {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(lifecycle.Lifecycle)
}
func (m *mockSession) Interaction() interaction.Interaction {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(interaction.Interaction)
}
func (m *mockSession) Forwarder() forwarder.Forwarder {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(forwarder.Forwarder)
}
func (m *mockSession) Slug() slug.Slug {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(slug.Slug)
}
func (m *mockSession) Detail() *types.Detail {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(*types.Detail)
}
type mockLifecycle struct {
mock.Mock
}
func (ml *mockLifecycle) Channel() ssh.Channel {
args := ml.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(ssh.Channel)
}
func (ml *mockLifecycle) Connection() ssh.Conn {
args := ml.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(ssh.Conn)
}
func (ml *mockLifecycle) PortRegistry() port.Port {
args := ml.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(port.Port)
}
func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { ml.Called(channel) }
func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { ml.Called(status) }
func (ml *mockLifecycle) IsActive() bool { return ml.Called().Bool(0) }
func (ml *mockLifecycle) StartedAt() time.Time { return ml.Called().Get(0).(time.Time) }
func (ml *mockLifecycle) Close() error { return ml.Called().Error(0) }
func (ml *mockLifecycle) User() string { return ml.Called().String(0) }
type mockSlug struct {
mock.Mock
}
func (ms *mockSlug) Set(slug string) { ms.Called(slug) }
func (ms *mockSlug) String() string { return ms.Called().String(0) }
func createMockSession(user ...string) *mockSession {
u := "user1"
if len(user) > 0 {
u = user[0]
}
m := new(mockSession)
ml := new(mockLifecycle)
ml.On("User").Return(u).Maybe()
m.On("Lifecycle").Return(ml).Maybe()
ms := new(mockSlug)
ms.On("Set", mock.Anything).Maybe()
m.On("Slug").Return(ms).Maybe()
m.On("Interaction").Return(nil).Maybe()
m.On("Forwarder").Return(nil).Maybe()
m.On("Detail").Return(nil).Maybe()
return m
}
func TestNewRegistry(t *testing.T) {
r := NewRegistry()
require.NotNil(t, r)
}
func TestRegistry_Get(t *testing.T) {
tests := []struct {
name string
setupFunc func(r *registry)
key types.SessionKey
wantErr error
wantResult bool
}{
{
name: "session found",
setupFunc: func(r *registry) {
user := "user1"
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
session := createMockSession(user)
r.mu.Lock()
defer r.mu.Unlock()
r.byUser[user] = map[types.SessionKey]Session{
key: session,
}
r.slugIndex[key] = user
},
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
wantErr: nil,
wantResult: true,
},
{
name: "session not found in slugIndex",
setupFunc: func(r *registry) {},
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
wantErr: ErrSessionNotFound,
},
{
name: "session not found in byUser",
setupFunc: func(r *registry) {
r.mu.Lock()
defer r.mu.Unlock()
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
},
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
wantErr: ErrSessionNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
byUser: make(map[string]map[types.SessionKey]Session),
slugIndex: make(map[types.SessionKey]string),
mu: sync.RWMutex{},
}
tt.setupFunc(r)
session, err := r.Get(tt.key)
assert.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantResult, session != nil)
})
}
}
func TestRegistry_GetWithUser(t *testing.T) {
tests := []struct {
name string
setupFunc func(r *registry)
user string
key types.SessionKey
wantErr error
wantResult bool
}{
{
name: "session found",
setupFunc: func(r *registry) {
user := "user1"
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser[user] = map[types.SessionKey]Session{
key: session,
}
r.slugIndex[key] = user
},
user: "user1",
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
wantErr: nil,
wantResult: true,
},
{
name: "session not found in slugIndex",
setupFunc: func(r *registry) {},
user: "user1",
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
wantErr: ErrSessionNotFound,
},
{
name: "session not found in byUser",
setupFunc: func(r *registry) {
r.mu.Lock()
defer r.mu.Unlock()
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
},
user: "user1",
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
wantErr: ErrSessionNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
byUser: make(map[string]map[types.SessionKey]Session),
slugIndex: make(map[types.SessionKey]string),
mu: sync.RWMutex{},
}
tt.setupFunc(r)
session, err := r.GetWithUser(tt.user, tt.key)
assert.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantResult, session != nil)
})
}
}
func TestRegistry_Update(t *testing.T) {
tests := []struct {
name string
user string
setupFunc func(r *registry) (oldKey, newKey types.SessionKey)
wantErr error
}{
{
name: "change slug success",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
session := createMockSession("user1")
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
}
r.slugIndex[oldKey] = "user1"
return oldKey, newKey
},
wantErr: nil,
},
{
name: "change slug to already used slug",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
newKey: session,
}
r.slugIndex[oldKey] = "user1"
r.slugIndex[newKey] = "user1"
return oldKey, newKey
},
wantErr: ErrSlugInUse,
},
{
name: "change slug to forbidden slug",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
}
r.slugIndex[oldKey] = "user1"
return oldKey, newKey
},
wantErr: ErrForbiddenSlug,
},
{
name: "change slug to invalid slug",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
}
r.slugIndex[oldKey] = "user1"
return oldKey, newKey
},
wantErr: ErrInvalidSlug,
},
{
name: "change slug but session not found",
user: "user2",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
}
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
return oldKey, newKey
},
wantErr: ErrSessionNotFound,
},
{
name: "change slug but session is not in the map",
user: "user2",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
}
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
return oldKey, newKey
},
wantErr: ErrSessionNotFound,
},
{
name: "change slug with same slug",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
}
r.slugIndex[oldKey] = "user1"
return oldKey, newKey
},
wantErr: ErrSlugUnchanged,
},
{
name: "tcp tunnel cannot change slug",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
newKey := oldKey
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
}
r.slugIndex[oldKey] = "user1"
return oldKey, newKey
},
wantErr: ErrSlugChangeNotAllowed,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
r := &registry{
byUser: make(map[string]map[types.SessionKey]Session),
slugIndex: make(map[types.SessionKey]string),
mu: sync.RWMutex{},
}
oldKey, newKey := tt.setupFunc(r)
err := r.Update(tt.user, oldKey, newKey)
assert.ErrorIs(t, err, tt.wantErr)
if err == nil {
r.mu.RLock()
defer r.mu.RUnlock()
_, ok := r.byUser[tt.user][newKey]
assert.True(t, ok, "newKey not found in registry")
_, ok = r.byUser[tt.user][oldKey]
assert.False(t, ok, "oldKey still exists in registry")
}
})
}
}
func TestRegistry_Register(t *testing.T) {
tests := []struct {
name string
user string
setupFunc func(r *registry) Key
wantOK bool
}{
{
name: "register new key successfully",
user: "user1",
setupFunc: func(r *registry) Key {
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
return key
},
wantOK: true,
},
{
name: "register already existing key fails",
user: "user1",
setupFunc: func(r *registry) Key {
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
r.byUser["user1"] = map[Key]Session{key: session}
r.slugIndex[key] = "user1"
r.mu.Unlock()
return key
},
wantOK: false,
},
{
name: "register multiple keys for same user",
user: "user1",
setupFunc: func(r *registry) Key {
firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
r.byUser["user1"] = map[Key]Session{firstKey: session}
r.slugIndex[firstKey] = "user1"
r.mu.Unlock()
return types.SessionKey{Id: "second", Type: types.TunnelTypeHTTP}
},
wantOK: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
r := &registry{
byUser: make(map[string]map[Key]Session),
slugIndex: make(map[Key]string),
mu: sync.RWMutex{},
}
key := tt.setupFunc(r)
session := createMockSession()
ok := r.Register(key, session)
assert.Equal(t, tt.wantOK, ok)
if ok {
r.mu.RLock()
defer r.mu.RUnlock()
assert.Equal(t, session, r.byUser[tt.user][key], "session not stored in byUser")
assert.Equal(t, tt.user, r.slugIndex[key], "slugIndex not updated")
}
})
}
}
func TestRegistry_GetAllSessionFromUser(t *testing.T) {
tests := []struct {
name string
setupFunc func(r *registry) string
expectN int
}{
{
name: "user has no sessions",
setupFunc: func(r *registry) string {
return "user1"
},
expectN: 0,
},
{
name: "user has multiple sessions",
setupFunc: func(r *registry) string {
user := "user1"
key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
r.mu.Lock()
r.byUser[user] = map[Key]Session{
key1: createMockSession(),
key2: createMockSession(),
}
r.mu.Unlock()
return user
},
expectN: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
byUser: make(map[string]map[Key]Session),
slugIndex: make(map[Key]string),
mu: sync.RWMutex{},
}
user := tt.setupFunc(r)
sessions := r.GetAllSessionFromUser(user)
assert.Len(t, sessions, tt.expectN)
})
}
}
func TestRegistry_Remove(t *testing.T) {
tests := []struct {
name string
setupFunc func(r *registry) (string, types.SessionKey)
key types.SessionKey
verify func(*testing.T, *registry, string, types.SessionKey)
}{
{
name: "remove existing key",
setupFunc: func(r *registry) (string, types.SessionKey) {
user := "user1"
key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
r.byUser[user] = map[Key]Session{key: session}
r.slugIndex[key] = user
r.mu.Unlock()
return user, key
},
verify: func(t *testing.T, r *registry, user string, key types.SessionKey) {
_, ok := r.byUser[user][key]
assert.False(t, ok, "expected key to be removed from byUser")
_, ok = r.slugIndex[key]
assert.False(t, ok, "expected key to be removed from slugIndex")
_, ok = r.byUser[user]
assert.False(t, ok, "expected user to be removed from byUser map")
},
},
{
name: "remove non-existing key",
setupFunc: func(r *registry) (string, types.SessionKey) {
return "", types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
byUser: make(map[string]map[Key]Session),
slugIndex: make(map[Key]string),
mu: sync.RWMutex{},
}
user, key := tt.setupFunc(r)
if user == "" {
key = tt.key
}
r.Remove(key)
if tt.verify != nil {
tt.verify(t, r, user, key)
}
})
}
}
func TestIsValidSlug(t *testing.T) {
tests := []struct {
slug string
want bool
}{
{"abc", true},
{"abc-123", true},
{"a", false},
{"verybigdihsixsevenlabubu", false},
{"-iamsigma", false},
{"ligma-", false},
{"invalid$", false},
{"valid-slug1", true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.slug, func(t *testing.T) {
got := isValidSlug(tt.slug)
if got != tt.want {
t.Errorf("isValidSlug(%q) = %v; want %v", tt.slug, got, tt.want)
}
})
}
}
func TestIsValidSlugChar(t *testing.T) {
tests := []struct {
char byte
want bool
}{
{'a', true},
{'z', true},
{'0', true},
{'9', true},
{'-', true},
{'A', false},
{'$', false},
}
for _, tt := range tests {
tt := tt
t.Run(string(tt.char), func(t *testing.T) {
got := isValidSlugChar(tt.char)
if got != tt.want {
t.Errorf("isValidSlugChar(%q) = %v; want %v", tt.char, got, tt.want)
}
})
}
}
func TestIsForbiddenSlug(t *testing.T) {
forbiddenSlugs = map[string]struct{}{
"admin": {},
"root": {},
}
tests := []struct {
slug string
want bool
}{
{"admin", true},
{"root", true},
{"user", false},
{"guest", false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.slug, func(t *testing.T) {
got := isForbiddenSlug(tt.slug)
if got != tt.want {
t.Errorf("isForbiddenSlug(%q) = %v; want %v", tt.slug, got, tt.want)
}
})
}
}
+7 -8
View File
@@ -4,28 +4,27 @@ import (
"errors" "errors"
"log" "log"
"net" "net"
"tunnel_pls/internal/config"
"tunnel_pls/internal/registry" "tunnel_pls/internal/registry"
) )
type httpServer struct { type httpServer struct {
handler *httpHandler handler *httpHandler
config config.Config port string
} }
func NewHTTPServer(config config.Config, sessionRegistry registry.Registry) Transport { func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &httpServer{ return &httpServer{
handler: newHTTPHandler(config, sessionRegistry), handler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
config: config, port: port,
} }
} }
func (ht *httpServer) Listen() (net.Listener, error) { func (ht *httpServer) Listen() (net.Listener, error) {
return net.Listen("tcp", ":"+ht.config.HTTPPort()) return net.Listen("tcp", ":"+ht.port)
} }
func (ht *httpServer) Serve(listener net.Listener) error { func (ht *httpServer) Serve(listener net.Listener) error {
log.Printf("HTTP server is starting on port %s", ht.config.HTTPPort()) log.Printf("HTTP server is starting on port %s", ht.port)
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
@@ -36,6 +35,6 @@ func (ht *httpServer) Serve(listener net.Listener) error {
continue continue
} }
go ht.handler.Handler(conn, false) go ht.handler.handler(conn, false)
} }
} }
-135
View File
@@ -1,135 +0,0 @@
package transport
import (
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestNewHTTPServer(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
assert.NotNil(t, srv)
httpSrv, ok := srv.(*httpServer)
assert.True(t, ok)
assert.Equal(t, msr, httpSrv.handler.sessionRegistry)
assert.NotNil(t, srv)
}
func TestHTTPServer_Listen(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
listener, err := srv.Listen()
assert.NoError(t, err)
assert.NotNil(t, listener)
err = listener.Close()
assert.NoError(t, err)
}
func TestHTTPServer_Serve(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
go func() {
time.Sleep(100 * time.Millisecond)
err = listener.Close()
assert.NoError(t, err)
}()
err = srv.Serve(listener)
assert.True(t, errors.Is(err, net.ErrClosed))
}
func TestHTTPServer_Serve_AcceptError(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
err := srv.Serve(ml)
assert.True(t, errors.Is(err, net.ErrClosed))
ml.AssertExpectations(t)
}
func TestHTTPServer_Serve_Success(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
mockConfig.On("HeaderSize").Return(4096)
mockConfig.On("TLSRedirect").Return(false)
srv := NewHTTPServer(mockConfig, msr)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
listenerport := listener.Addr().(*net.TCPAddr).Port
go func() {
_ = srv.Serve(listener)
}()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
assert.NoError(t, err)
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
time.Sleep(100 * time.Millisecond)
err = conn.Close()
assert.NoError(t, err)
err = listener.Close()
assert.NoError(t, err)
}
type mockListener struct {
mock.Mock
}
func (m *mockListener) Accept() (net.Conn, error) {
args := m.Called()
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(net.Conn), args.Error(1)
}
func (m *mockListener) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *mockListener) Addr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
+70 -40
View File
@@ -1,8 +1,7 @@
package transport package transport
import ( import (
"bytes" "bufio"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -11,7 +10,6 @@ import (
"net/http" "net/http"
"strings" "strings"
"time" "time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/http/header" "tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream" "tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware" "tunnel_pls/internal/middleware"
@@ -22,14 +20,16 @@ import (
) )
type httpHandler struct { type httpHandler struct {
config config.Config domain string
sessionRegistry registry.Registry sessionRegistry registry.Registry
redirectTLS bool
} }
func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler { func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
return &httpHandler{ return &httpHandler{
config: config, domain: domain,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS,
} }
} }
@@ -52,28 +52,13 @@ func (hh *httpHandler) badRequest(conn net.Conn) error {
return nil return nil
} }
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) { func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
defer hh.closeConnection(conn) defer hh.closeConnection(conn)
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second)) dstReader := bufio.NewReader(conn)
buf := make([]byte, hh.config.HeaderSize()) reqhf, err := header.NewRequest(dstReader)
n, err := conn.Read(buf)
if err != nil {
_ = hh.badRequest(conn)
return
}
if idx := bytes.Index(buf[:n], []byte("\r\n\r\n")); idx == -1 {
_ = hh.badRequest(conn)
return
}
_ = conn.SetReadDeadline(time.Time{})
reqhf, err := header.NewRequest(buf[:n])
if err != nil { if err != nil {
log.Printf("Error creating request header: %v", err) log.Printf("Error creating request header: %v", err)
_ = hh.badRequest(conn)
return return
} }
@@ -84,7 +69,7 @@ func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
} }
if hh.shouldRedirectToTLS(isTLS) { if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.config.Domain())) _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
return return
} }
@@ -92,16 +77,13 @@ func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
return return
} }
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ sshSession, err := hh.getSession(slug)
Id: slug,
Type: types.TunnelTypeHTTP,
})
if err != nil { if err != nil {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug)) _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
return return
} }
hw := stream.New(conn, conn, conn.RemoteAddr()) hw := stream.New(conn, dstReader, conn.RemoteAddr())
defer func(hw stream.HTTP) { defer func(hw stream.HTTP) {
err = hw.Close() err = hw.Close()
if err != nil { if err != nil {
@@ -120,14 +102,14 @@ func (hh *httpHandler) closeConnection(conn net.Conn) {
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) { func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
host := strings.Split(reqhf.Value("Host"), ".") host := strings.Split(reqhf.Value("Host"), ".")
if len(host) <= 1 { if len(host) < 1 {
return "", errors.New("invalid host") return "", errors.New("invalid host")
} }
return host[0], nil return host[0], nil
} }
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool { func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
return !isTLS && hh.config.TLSRedirect() return !isTLS && hh.redirectTLS
} }
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
@@ -146,22 +128,29 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
)) ))
if err != nil { if err != nil {
log.Println("Failed to write 200 OK:", err) log.Println("Failed to write 200 OK:", err)
return true
} }
return true return true
} }
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
defer cancel() Id: slug,
channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, hw.RemoteAddr()) Type: types.TunnelTypeHTTP,
})
if err != nil { if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err) return nil, err
}
return sshSession, nil
}
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
channel, err := hh.openForwardedChannel(hw, sshSession)
if err != nil {
log.Printf("Failed to establish channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw)
return return
} }
go ssh.DiscardRequests(reqs)
defer func() { defer func() {
err = channel.Close() err = channel.Close()
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {
@@ -178,6 +167,47 @@ func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.Requ
sshSession.Forwarder().HandleConnection(hw, channel) sshSession.Forwarder().HandleConnection(hw, channel)
} }
func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
select {
case resultChan <- channelResult{channel, reqs, err}:
default:
hh.cleanupUnusedChannel(channel, reqs)
}
}()
select {
case result := <-resultChan:
if result.err != nil {
return nil, result.err
}
go ssh.DiscardRequests(result.reqs)
return result.channel, nil
case <-time.After(5 * time.Second):
return nil, errors.New("timeout opening forwarded-tcpip channel")
}
}
func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
if channel != nil {
if err := channel.Close(); err != nil {
log.Printf("Failed to close unused channel: %v", err)
}
go ssh.DiscardRequests(reqs)
}
}
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) { func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
fingerprintMiddleware := middleware.NewTunnelFingerprint() fingerprintMiddleware := middleware.NewTunnelFingerprint()
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr()) forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
-717
View File
@@ -1,717 +0,0 @@
package transport
import (
"bytes"
"context"
"fmt"
"io"
"net"
"strings"
"sync"
"testing"
"time"
"tunnel_pls/internal/registry"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
args := m.Called(user, oldKey, newKey)
return args.Error(0)
}
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
args := m.Called(key, session)
return args.Bool(0)
}
func (m *MockSessionRegistry) Remove(key registry.Key) {
m.Called(key)
}
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
args := m.Called(user)
return args.Get(0).([]registry.Session)
}
func (m *MockSessionRegistry) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
type MockSession struct {
mock.Mock
}
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
args := m.Called()
return args.Get(0).(lifecycle.Lifecycle)
}
func (m *MockSession) Interaction() interaction.Interaction {
args := m.Called()
return args.Get(0).(interaction.Interaction)
}
func (m *MockSession) Forwarder() forwarder.Forwarder {
args := m.Called()
return args.Get(0).(forwarder.Forwarder)
}
func (m *MockSession) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
func (m *MockSession) Detail() *types.Detail {
args := m.Called()
return args.Get(0).(*types.Detail)
}
type MockSSHChannel struct {
ssh.Channel
mock.Mock
}
func (m *MockSSHChannel) Write(data []byte) (int, error) {
args := m.Called(data)
return args.Int(0), args.Error(1)
}
func (m *MockSSHChannel) Close() error {
args := m.Called()
return args.Error(0)
}
type MockForwarder struct {
mock.Mock
}
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
m.Called(dst, src)
}
func (m *MockForwarder) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockForwarder) TunnelType() types.TunnelType {
args := m.Called()
return args.Get(0).(types.TunnelType)
}
func (m *MockForwarder) ForwardedPort() uint16 {
args := m.Called()
return uint16(args.Int(0))
}
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
m.Called(tunnelType)
}
func (m *MockForwarder) SetForwardedPort(port uint16) {
m.Called(port)
}
func (m *MockForwarder) SetListener(listener net.Listener) {
m.Called(listener)
}
func (m *MockForwarder) Listener() net.Listener {
args := m.Called()
return args.Get(0).(net.Listener)
}
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
args := m.Called(ctx, origin)
if args.Get(0) == nil {
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
type MockConn struct {
mock.Mock
ReadBuffer *bytes.Buffer
}
func (m *MockConn) LocalAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
func (m *MockConn) SetDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetReadDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetWriteDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) Read(b []byte) (n int, err error) {
if m.ReadBuffer != nil {
return m.ReadBuffer.Read(b)
}
args := m.Called(b)
return args.Int(0), args.Error(1)
}
func (m *MockConn) Write(b []byte) (n int, err error) {
args := m.Called(b)
if args.Int(0) == -1 {
return len(b), args.Error(1)
}
return args.Int(0), args.Error(1)
}
func (m *MockConn) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockConn) RemoteAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
type wrappedConn struct {
net.Conn
remoteAddr net.Addr
}
func (c *wrappedConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func TestNewHTTPHandler(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
mockConfig.On("Domain").Return("domain")
mockConfig.On("TLSRedirect").Return(false)
hh := newHTTPHandler(mockConfig, msr)
assert.NotNil(t, hh)
assert.Equal(t, msr, hh.sessionRegistry)
}
func TestHandler(t *testing.T) {
tests := []struct {
name string
isTLS bool
redirectTLS bool
request []byte
expected []byte
setupMocks func(*MockSessionRegistry)
setupConn func() (net.Conn, net.Conn)
expectError bool
}{
{
name: "bad request - invalid host",
isTLS: false,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "bad request - missing host",
isTLS: false,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\n\r\n"),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "isTLS true and redirectTLS true - no redirect",
isTLS: true,
redirectTLS: true,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "redirect to TLS",
isTLS: false,
redirectTLS: true,
request: []byte("GET / HTTP/1.1\r\nHost: tunnel.example.com\r\n\r\n"),
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnel.example.com/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "handle ping request",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "session not found",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return((registry.Session)(nil), fmt.Errorf("session not found"))
},
},
{
name: "bad request - invalid http",
isTLS: false,
redirectTLS: false,
request: []byte("INVALID\r\n\r\n"),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "bad request - header too large",
isTLS: false,
redirectTLS: false,
request: []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: test.domain\r\n%s\r\n\r\n", strings.Repeat("test", 10000))),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "bad request - no request",
isTLS: false,
redirectTLS: false,
request: []byte(""),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "forwarding - open channel fails",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
},
},
{
name: "forwarding - send initial request fails",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mockSSHChannel.On("Close").Return(nil)
go func() {
for range reqCh {
}
}()
},
},
{
name: "forwarding - success",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(nil)
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
w := args.Get(0).(io.ReadWriter)
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
})
go func() {
for range reqCh {
}
}()
},
},
{
name: "redirect - write failure",
isTLS: false,
redirectTLS: true,
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(-1, fmt.Errorf("write error"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "bad request - write failure",
isTLS: false,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "read error - connection failure",
isTLS: false,
redirectTLS: false,
request: []byte(""),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Read", mock.Anything).Return(0, fmt.Errorf("connection reset by peer"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "handle ping request - write failure",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "close connection - error",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(182, nil)
mc.On("Close").Return(fmt.Errorf("close error"))
return mc, nil
},
},
{
name: "forwarding - stream close error",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(nil)
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return()
},
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
mc.On("RemoteAddr").Return(addr)
return mc, nil
},
},
{
name: "forwarding - middleware failure",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
return k.Id == "test"
})).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Close").Return(nil)
},
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Close").Return(nil).Times(2)
mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
return mc, nil
},
},
{
name: "forwarding - channel close error",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
w := args.Get(0).(io.ReadWriter)
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
})
},
},
{
name: "forwarding - open channel timeout",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
ctx := args.Get(0).(context.Context)
<-ctx.Done()
}).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockSessionRegistry := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
mockConfig.On("HeaderSize").Return(4096)
mockConfig.On("TLSRedirect").Return(true)
hh := &httpHandler{
sessionRegistry: mockSessionRegistry,
config: mockConfig,
}
if tt.setupMocks != nil {
tt.setupMocks(mockSessionRegistry)
}
var serverConn, clientConn net.Conn
if tt.setupConn != nil {
serverConn, clientConn = tt.setupConn()
} else {
serverConn, clientConn = net.Pipe()
}
if clientConn != nil {
defer func(clientConn net.Conn) {
err := clientConn.Close()
assert.NoError(t, err)
}(clientConn)
}
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
var wrappedServerConn net.Conn
if _, ok := serverConn.(*MockConn); ok {
wrappedServerConn = serverConn
} else {
wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr}
}
responseChan := make(chan []byte, 1)
doneChan := make(chan struct{})
if clientConn != nil {
go func() {
defer close(doneChan)
var res []byte
for {
buf := make([]byte, 4096)
n, err := clientConn.Read(buf)
if err != nil {
if err != io.EOF {
t.Logf("Error reading response: %v", err)
}
break
}
res = append(res, buf[:n]...)
if len(tt.expected) > 0 && len(res) >= len(tt.expected) {
break
}
}
responseChan <- res
}()
go func() {
_, err := clientConn.Write(tt.request)
if err != nil {
t.Logf("Error writing request: %v", err)
}
}()
} else {
close(responseChan)
close(doneChan)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
hh.Handler(wrappedServerConn, tt.isTLS)
}()
select {
case response := <-responseChan:
if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" {
resStr := string(response)
assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n"))
assert.Contains(t, resStr, "Content-Length: 5\r\n")
assert.Contains(t, resStr, "Server: Tunnel Please\r\n")
assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello"))
} else {
assert.Equal(t, string(tt.expected), string(response))
}
case <-time.After(10 * time.Second):
if clientConn != nil {
t.Fatal("Test timeout - no response received")
}
}
wg.Wait()
if clientConn != nil {
<-doneChan
}
mockSessionRegistry.AssertExpectations(t)
if mc, ok := serverConn.(*MockConn); ok {
mc.AssertExpectations(t)
}
})
}
}
+9 -8
View File
@@ -5,30 +5,31 @@ import (
"errors" "errors"
"log" "log"
"net" "net"
"tunnel_pls/internal/config"
"tunnel_pls/internal/registry" "tunnel_pls/internal/registry"
) )
type https struct { type https struct {
config config.Config
tlsConfig *tls.Config tlsConfig *tls.Config
httpHandler *httpHandler httpHandler *httpHandler
domain string
port string
} }
func NewHTTPSServer(config config.Config, sessionRegistry registry.Registry, tlsConfig *tls.Config) Transport { func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport {
return &https{ return &https{
config: config,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
httpHandler: newHTTPHandler(config, sessionRegistry), httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
domain: domain,
port: port,
} }
} }
func (ht *https) Listen() (net.Listener, error) { func (ht *https) Listen() (net.Listener, error) {
return tls.Listen("tcp", ":"+ht.config.HTTPSPort(), ht.tlsConfig) return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig)
} }
func (ht *https) Serve(listener net.Listener) error { func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.config.HTTPSPort()) log.Printf("HTTPS server is starting on port %s", ht.port)
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
@@ -39,6 +40,6 @@ func (ht *https) Serve(listener net.Listener) error {
continue continue
} }
go ht.httpHandler.Handler(conn, true) go ht.httpHandler.handler(conn, true)
} }
} }
-120
View File
@@ -1,120 +0,0 @@
package transport
import (
"crypto/tls"
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewHTTPSServer(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
tlsConfig := &tls.Config{}
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
assert.NotNil(t, srv)
httpsSrv, ok := srv.(*https)
assert.True(t, ok)
assert.Equal(t, tlsConfig, httpsSrv.tlsConfig)
assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry)
}
func TestHTTPSServer_Listen(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
tlsConfig := &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, nil
},
}
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
listener, err := srv.Listen()
if err != nil {
t.Skip("Skipping tls.Listen test as it requires valid certificates/setup:", err)
return
}
assert.NotNil(t, listener)
err = listener.Close()
assert.NoError(t, err)
}
func TestHTTPSServer_Serve(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
go func() {
time.Sleep(100 * time.Millisecond)
err = listener.Close()
assert.NoError(t, err)
}()
err = srv.Serve(listener)
assert.True(t, errors.Is(err, net.ErrClosed))
}
func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
err := srv.Serve(ml)
assert.True(t, errors.Is(err, net.ErrClosed))
ml.AssertExpectations(t)
}
func TestHTTPSServer_Serve_Success(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
mockConfig.On("HeaderSize").Return(4096)
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
listenerport := listener.Addr().(*net.TCPAddr).Port
go func() {
_ = srv.Serve(listener)
}()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
assert.NoError(t, err)
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
time.Sleep(100 * time.Millisecond)
err = conn.Close()
assert.NoError(t, err)
err = listener.Close()
assert.NoError(t, err)
}
+8 -9
View File
@@ -1,28 +1,27 @@
package transport package transport
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type tcp struct { type tcp struct {
port uint16 port uint16
forwarder Forwarder forwarder forwarder
} }
type Forwarder interface { type forwarder interface {
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
HandleConnection(dst io.ReadWriter, src ssh.Channel) HandleConnection(dst io.ReadWriter, src ssh.Channel)
} }
func NewTCPServer(port uint16, forwarder Forwarder) Transport { func NewTCPServer(port uint16, forwarder forwarder) Transport {
return &tcp{ return &tcp{
port: port, port: port,
forwarder: forwarder, forwarder: forwarder,
@@ -54,11 +53,11 @@ func (tt *tcp) handleTcp(conn net.Conn) {
log.Printf("Failed to close connection: %v", err) log.Printf("Failed to close connection: %v", err)
} }
}() }()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
defer cancel() channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload)
channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, conn.RemoteAddr())
if err != nil { if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err) log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return return
} }
-146
View File
@@ -1,146 +0,0 @@
package transport
import (
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/crypto/ssh"
)
func TestNewTCPServer(t *testing.T) {
mf := new(MockForwarder)
port := uint16(9000)
srv := NewTCPServer(port, mf)
assert.NotNil(t, srv)
tcpSrv, ok := srv.(*tcp)
assert.True(t, ok)
assert.Equal(t, port, tcpSrv.port)
assert.Equal(t, mf, tcpSrv.forwarder)
}
func TestTCPServer_Listen(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
listener, err := srv.Listen()
assert.NoError(t, err)
assert.NotNil(t, listener)
err = listener.Close()
assert.NoError(t, err)
}
func TestTCPServer_Serve(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
go func() {
time.Sleep(100 * time.Millisecond)
err = listener.Close()
assert.NoError(t, err)
}()
err = srv.Serve(listener)
assert.Nil(t, err)
}
func TestTCPServer_Serve_AcceptError(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
err := srv.Serve(ml)
assert.Nil(t, err)
ml.AssertExpectations(t)
}
func TestTCPServer_Serve_Success(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
reqs := make(chan *ssh.Request)
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
go func() {
_ = srv.Serve(listener)
}()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = conn.Close()
assert.NoError(t, err)
err = listener.Close()
assert.NoError(t, err)
mf.AssertExpectations(t)
}
func TestTCPServer_handleTcp_Success(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf).(*tcp)
serverConn, clientConn := net.Pipe()
defer func(clientConn net.Conn) {
err := clientConn.Close()
assert.NoError(t, err)
}(clientConn)
reqs := make(chan *ssh.Request)
mockChannel := new(MockSSHChannel)
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
mf.On("HandleConnection", serverConn, mockChannel).Return()
srv.handleTcp(serverConn)
mf.AssertExpectations(t)
}
func TestTCPServer_handleTcp_CloseError(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf).(*tcp)
mc := new(MockConn)
mc.On("Close").Return(errors.New("close error"))
mc.On("RemoteAddr").Return(&net.TCPAddr{})
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
srv.handleTcp(mc)
mc.AssertExpectations(t)
}
func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf).(*tcp)
serverConn, clientConn := net.Pipe()
defer func(clientConn net.Conn) {
err := clientConn.Close()
assert.NoError(t, err)
}(clientConn)
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
srv.handleTcp(serverConn)
mf.AssertExpectations(t)
}
+150 -276
View File
@@ -8,7 +8,6 @@ import (
"fmt" "fmt"
"log" "log"
"os" "os"
"path/filepath"
"sync" "sync"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
@@ -17,22 +16,13 @@ import (
"github.com/libdns/cloudflare" "github.com/libdns/cloudflare"
) )
func NewTLSConfig(config config.Config) (*tls.Config, error) { type TLSManager interface {
var initErr error userCertsExistAndValid() bool
loadUserCerts() error
tlsManagerOnce.Do(func() { startCertWatcher()
tm := createTLSManager(config) initCertMagic() error
initErr = tm.initialize() getTLSConfig() *tls.Config
if initErr == nil { getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
globalTLSManager = tm
}
})
if initErr != nil {
return nil, initErr
}
return globalTLSManager.getTLSConfig(), nil
} }
type tlsManager struct { type tlsManager struct {
@@ -50,60 +40,52 @@ type tlsManager struct {
useCertMagic bool useCertMagic bool
} }
var globalTLSManager *tlsManager var globalTLSManager TLSManager
var tlsManagerOnce sync.Once var tlsManagerOnce sync.Once
func createTLSManager(cfg config.Config) *tlsManager { func NewTLSConfig(config config.Config) (*tls.Config, error) {
storagePath := cfg.TLSStoragePath() var initErr error
cleanBase := filepath.Clean(storagePath)
return &tlsManager{ tlsManagerOnce.Do(func() {
config: cfg, certPath := "certs/tls/cert.pem"
certPath: filepath.Join(cleanBase, "cert.pem"), keyPath := "certs/tls/privkey.pem"
keyPath: filepath.Join(cleanBase, "privkey.pem"), storagePath := "certs/tls/certmagic"
storagePath: filepath.Join(cleanBase, "certmagic"),
}
}
func (tm *tlsManager) initialize() error { tm := &tlsManager{
if tm.userCertsExistAndValid() { config: config,
return tm.initializeWithUserCerts() certPath: certPath,
} keyPath: keyPath,
return tm.initializeWithCertMagic() storagePath: storagePath,
} }
func (tm *tlsManager) initializeWithUserCerts() error { if tm.userCertsExistAndValid() {
log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath) log.Printf("Using user-provided certificates from %s and %s", certPath, keyPath)
if err := tm.loadUserCerts(); err != nil {
initErr = fmt.Errorf("failed to load user certificates: %w", err)
return
}
tm.useCertMagic = false
tm.startCertWatcher()
} else {
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain())
if err := tm.initCertMagic(); err != nil {
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
return
}
tm.useCertMagic = true
}
if err := tm.loadUserCerts(); err != nil { globalTLSManager = tm
return fmt.Errorf("failed to load user certificates: %w", err) })
if initErr != nil {
return nil, initErr
} }
tm.useCertMagic = false return globalTLSManager.getTLSConfig(), nil
tm.startCertWatcher()
return nil
}
func (tm *tlsManager) initializeWithCertMagic() error {
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic",
tm.config.Domain(), tm.config.Domain())
if err := tm.initCertMagic(); err != nil {
return fmt.Errorf("failed to initialize CertMagic: %w", err)
}
tm.useCertMagic = true
return nil
} }
func (tm *tlsManager) userCertsExistAndValid() bool { func (tm *tlsManager) userCertsExistAndValid() bool {
if !tm.certFilesExist() {
return false
}
return validateCertDomains(tm.certPath, tm.config.Domain())
}
func (tm *tlsManager) certFilesExist() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath) log.Printf("Certificate file not found: %s", tm.certPath)
return false return false
@@ -112,7 +94,66 @@ func (tm *tlsManager) certFilesExist() bool {
log.Printf("Key file not found: %s", tm.keyPath) log.Printf("Key file not found: %s", tm.keyPath)
return false return false
} }
return true
return ValidateCertDomains(tm.certPath, tm.config.Domain())
}
func ValidateCertDomains(certPath, domain string) bool {
certPEM, err := os.ReadFile(certPath)
if err != nil {
log.Printf("Failed to read certificate: %v", err)
return false
}
block, _ := pem.Decode(certPEM)
if block == nil {
log.Printf("Failed to decode PEM block from certificate")
return false
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Printf("Failed to parse certificate: %v", err)
return false
}
if time.Now().After(cert.NotAfter) {
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
return false
}
if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) {
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
return false
}
var certDomains []string
if cert.Subject.CommonName != "" {
certDomains = append(certDomains, cert.Subject.CommonName)
}
certDomains = append(certDomains, cert.DNSNames...)
hasBase := false
hasWildcard := false
wildcardDomain := "*." + domain
for _, d := range certDomains {
if d == domain {
hasBase = true
}
if d == wildcardDomain {
hasWildcard = true
}
}
if !hasBase {
log.Printf("Certificate does not cover base domain: %s", domain)
}
if !hasWildcard {
log.Printf("Certificate does not cover wildcard domain: %s", wildcardDomain)
}
return hasBase && hasWildcard
} }
func (tm *tlsManager) loadUserCerts() error { func (tm *tlsManager) loadUserCerts() error {
@@ -131,34 +172,62 @@ func (tm *tlsManager) loadUserCerts() error {
func (tm *tlsManager) startCertWatcher() { func (tm *tlsManager) startCertWatcher() {
go func() { go func() {
watcher := newCertWatcher(tm) var lastCertMod, lastKeyMod time.Time
watcher.watch()
if info, err := os.Stat(tm.certPath); err == nil {
lastCertMod = info.ModTime()
}
if info, err := os.Stat(tm.keyPath); err == nil {
lastKeyMod = info.ModTime()
}
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
certInfo, certErr := os.Stat(tm.certPath)
keyInfo, keyErr := os.Stat(tm.keyPath)
if certErr != nil || keyErr != nil {
continue
}
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
log.Printf("Certificate files changed, reloading...")
if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
log.Printf("New certificates don't cover required domains")
if err := tm.initCertMagic(); err != nil {
log.Printf("Failed to initialize CertMagic: %v", err)
continue
}
tm.useCertMagic = true
return
}
if err := tm.loadUserCerts(); err != nil {
log.Printf("Failed to reload certificates: %v", err)
continue
}
lastCertMod = certInfo.ModTime()
lastKeyMod = keyInfo.ModTime()
log.Printf("Certificates reloaded successfully")
}
}
}() }()
} }
func (tm *tlsManager) initCertMagic() error { func (tm *tlsManager) initCertMagic() error {
if err := tm.createStorageDirectory(); err != nil { if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
return err return fmt.Errorf("failed to create cert storage directory: %w", err)
} }
if tm.config.CFAPIToken() == "" { if tm.config.CFAPIToken() == "" {
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation") return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
} }
magic := tm.createCertMagicConfig()
tm.magic = magic
return tm.obtainCertificates(magic)
}
func (tm *tlsManager) createStorageDirectory() error {
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
return fmt.Errorf("failed to create cert storage directory: %w", err)
}
return nil
}
func (tm *tlsManager) createCertMagicConfig() *certmagic.Config {
cfProvider := &cloudflare.Provider{ cfProvider := &cloudflare.Provider{
APIToken: tm.config.CFAPIToken(), APIToken: tm.config.CFAPIToken(),
} }
@@ -175,13 +244,6 @@ func (tm *tlsManager) createCertMagicConfig() *certmagic.Config {
Storage: storage, Storage: storage,
}) })
acmeIssuer := tm.createACMEIssuer(magic, cfProvider)
magic.Issuers = []certmagic.Issuer{acmeIssuer}
return magic
}
func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer {
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{ acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
Email: tm.config.ACMEEmail(), Email: tm.config.ACMEEmail(),
Agreed: true, Agreed: true,
@@ -200,10 +262,9 @@ func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *clou
log.Printf("Using Let's Encrypt production server") log.Printf("Using Let's Encrypt production server")
} }
return acmeIssuer magic.Issuers = []certmagic.Issuer{acmeIssuer}
} tm.magic = magic
func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error {
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()} domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
log.Printf("Requesting certificates for: %v", domains) log.Printf("Requesting certificates for: %v", domains)
@@ -246,190 +307,3 @@ func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certifica
return tm.userCert, nil return tm.userCert, nil
} }
func validateCertDomains(certPath, domain string) bool {
cert, err := loadAndParseCertificate(certPath)
if err != nil {
return false
}
if !isCertificateValid(cert) {
return false
}
return certCoversRequiredDomains(cert, domain)
}
func loadAndParseCertificate(certPath string) (*x509.Certificate, error) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
log.Printf("Failed to read certificate: %v", err)
return nil, err
}
block, _ := pem.Decode(certPEM)
if block == nil {
log.Printf("Failed to decode PEM block from certificate")
return nil, fmt.Errorf("failed to decode PEM block")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Printf("Failed to parse certificate: %v", err)
return nil, err
}
return cert, nil
}
func isCertificateValid(cert *x509.Certificate) bool {
now := time.Now()
if now.After(cert.NotAfter) {
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
return false
}
thirtyDaysFromNow := now.Add(30 * 24 * time.Hour)
if thirtyDaysFromNow.After(cert.NotAfter) {
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
return false
}
return true
}
func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool {
certDomains := extractCertDomains(cert)
hasBase, hasWildcard := checkDomainCoverage(certDomains, domain)
logDomainCoverage(hasBase, hasWildcard, domain)
return hasBase && hasWildcard
}
func extractCertDomains(cert *x509.Certificate) []string {
var domains []string
if cert.Subject.CommonName != "" {
domains = append(domains, cert.Subject.CommonName)
}
domains = append(domains, cert.DNSNames...)
return domains
}
func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) {
wildcardDomain := "*." + domain
for _, d := range certDomains {
if d == domain {
hasBase = true
}
if d == wildcardDomain {
hasWildcard = true
}
}
return hasBase, hasWildcard
}
func logDomainCoverage(hasBase, hasWildcard bool, domain string) {
if !hasBase {
log.Printf("Certificate does not cover base domain: %s", domain)
}
if !hasWildcard {
log.Printf("Certificate does not cover wildcard domain: *.%s", domain)
}
}
type certWatcher struct {
tm *tlsManager
lastCertMod time.Time
lastKeyMod time.Time
}
func newCertWatcher(tm *tlsManager) *certWatcher {
watcher := &certWatcher{tm: tm}
watcher.initializeModTimes()
return watcher
}
func (cw *certWatcher) initializeModTimes() {
if info, err := os.Stat(cw.tm.certPath); err == nil {
cw.lastCertMod = info.ModTime()
}
if info, err := os.Stat(cw.tm.keyPath); err == nil {
cw.lastKeyMod = info.ModTime()
}
}
func (cw *certWatcher) watch() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
if cw.checkAndReloadCerts() {
return
}
}
}
func (cw *certWatcher) checkAndReloadCerts() bool {
certInfo, keyInfo, err := cw.getFileInfo()
if err != nil {
return false
}
if !cw.filesModified(certInfo, keyInfo) {
return false
}
return cw.handleCertificateChange(certInfo, keyInfo)
}
func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) {
certInfo, certErr := os.Stat(cw.tm.certPath)
keyInfo, keyErr := os.Stat(cw.tm.keyPath)
if certErr != nil || keyErr != nil {
return nil, nil, fmt.Errorf("file stat error")
}
return certInfo, keyInfo, nil
}
func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool {
return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod)
}
func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool {
log.Printf("Certificate files changed, reloading...")
if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) {
return cw.switchToCertMagic()
}
if err := cw.tm.loadUserCerts(); err != nil {
log.Printf("Failed to reload certificates: %v", err)
return false
}
cw.updateModTimes(certInfo, keyInfo)
log.Printf("Certificates reloaded successfully")
return false
}
func (cw *certWatcher) switchToCertMagic() bool {
log.Printf("New certificates don't cover required domains")
if err := cw.tm.initCertMagic(); err != nil {
log.Printf("Failed to initialize CertMagic: %v", err)
return false
}
cw.tm.useCertMagic = true
return true
}
func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) {
cw.lastCertMod = certInfo.ModTime()
cw.lastKeyMod = keyInfo.ModTime()
}
File diff suppressed because it is too large Load Diff
-4
View File
@@ -8,7 +8,3 @@ type Transport interface {
Listen() (net.Listener, error) Listen() (net.Listener, error)
Serve(listener net.Listener) error Serve(listener net.Listener) error
} }
type HTTP interface {
Handler(conn net.Conn, isTLS bool)
}
-84
View File
@@ -1,84 +0,0 @@
package version
import (
"fmt"
"testing"
)
func TestVersionFunctions(t *testing.T) {
origVersion := Version
origBuildDate := BuildDate
origCommit := Commit
defer func() {
Version = origVersion
BuildDate = origBuildDate
Commit = origCommit
}()
tests := []struct {
name string
version string
buildDate string
commit string
wantFull string
wantShort string
}{
{
name: "Default dev version",
version: "dev",
buildDate: "unknown",
commit: "unknown",
wantFull: "tunnel_pls dev (commit: unknown, built: unknown)",
wantShort: "dev",
},
{
name: "Release version",
version: "v1.0.0",
buildDate: "2026-01-23",
commit: "abcdef123",
wantFull: "tunnel_pls v1.0.0 (commit: abcdef123, built: 2026-01-23)",
wantShort: "v1.0.0",
},
{
name: "Empty values",
version: "",
buildDate: "",
commit: "",
wantFull: "tunnel_pls (commit: , built: )",
wantShort: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Version = tt.version
BuildDate = tt.buildDate
Commit = tt.commit
gotFull := GetVersion()
if gotFull != tt.wantFull {
t.Errorf("GetVersion() = %q, want %q", gotFull, tt.wantFull)
}
gotShort := GetShortVersion()
if gotShort != tt.wantShort {
t.Errorf("GetShortVersion() = %q, want %q", gotShort, tt.wantShort)
}
})
}
}
func TestGetVersion_Format(t *testing.T) {
v := "1.2.3"
c := "brainrot"
d := "now"
Version = v
Commit = c
BuildDate = d
expected := fmt.Sprintf("tunnel_pls %s (commit: %s, built: %s)", v, c, d)
if GetVersion() != expected {
t.Errorf("GetVersion() formatting mismatch")
}
}
+149 -6
View File
@@ -1,13 +1,27 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"net"
"net/http"
_ "net/http/pprof"
"os" "os"
"tunnel_pls/internal/bootstrap" "os/signal"
"syscall"
"time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/key"
"tunnel_pls/internal/port" "tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/internal/version" "tunnel_pls/internal/version"
"tunnel_pls/server"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
) )
func main() { func main() {
@@ -18,19 +32,148 @@ func main() {
log.SetOutput(os.Stdout) log.SetOutput(os.Stdout)
log.SetFlags(log.LstdFlags | log.Lshortfile) log.SetFlags(log.LstdFlags | log.Lshortfile)
log.Printf("Starting %s", version.GetVersion()) log.Printf("Starting %s", version.GetVersion())
conf, err := config.MustLoad() conf, err := config.MustLoad()
if err != nil { if err != nil {
log.Fatalf("Config load error: %v", err) log.Fatalf("Failed to load configuration: %s", err)
return
} }
boot, err := bootstrap.New(conf, port.New()) sshConfig := &ssh.ServerConfig{
NoClientAuth: true,
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
}
sshKeyPath := "certs/ssh/id_rsa"
if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
log.Fatalf("Failed to generate SSH key: %s", err)
}
privateBytes, err := os.ReadFile(sshKeyPath)
if err != nil { if err != nil {
log.Fatalf("Startup error: %v", err) log.Fatalf("Failed to load private key: %s", err)
} }
if err = boot.Run(); err != nil { private, err := ssh.ParsePrivateKey(privateBytes)
log.Fatalf("Application error: %v", err) if err != nil {
log.Fatalf("Failed to parse private key: %s", err)
}
sshConfig.AddHostKey(private)
sessionRegistry := registry.NewRegistry()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
errChan := make(chan error, 2)
shutdownChan := make(chan os.Signal, 1)
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
var grpcClient client.Client
if conf.Mode() == types.ServerModeNODE {
grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort())
grpcClient, err = client.New(conf, grpcAddr, sessionRegistry)
if err != nil {
log.Fatalf("failed to create grpc client: %v", err)
}
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
if err = grpcClient.CheckServerHealth(healthCtx); err != nil {
healthCancel()
log.Fatalf("gRPC health check failed: %v", err)
}
healthCancel()
go func() {
if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
}
}()
}
go func() {
var httpListener net.Listener
httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect())
httpListener, err = httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpserver.Serve(httpListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
if conf.TLSEnabled() {
go func() {
var httpsListener net.Listener
tlsConfig, _ := transport.NewTLSConfig(conf)
httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig)
httpsListener, err = httpsServer.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpsServer.Serve(httpsListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
}
portManager := port.New()
err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd())
if err != nil {
log.Fatalf("Failed to initialize port manager: %s", err)
return
}
var app server.Server
go func() {
app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort())
if err != nil {
errChan <- fmt.Errorf("failed to start server: %s", err)
return
}
app.Start()
}()
if conf.PprofEnabled() {
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort())
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
}
select {
case err = <-errChan:
log.Printf("error happen : %s", err)
case sig := <-shutdownChan:
log.Printf("received signal %s, shutting down", sig)
}
cancel()
if app != nil {
if err = app.Close(); err != nil {
log.Printf("failed to close server : %s", err)
}
}
if grpcClient != nil {
if err = grpcClient.Close(); err != nil {
log.Printf("failed to close grpc conn : %s", err)
}
} }
} }
+5 -17
View File
@@ -4,14 +4,12 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"log" "log"
"net" "net"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/port" "tunnel_pls/internal/port"
"tunnel_pls/internal/random"
"tunnel_pls/internal/registry" "tunnel_pls/internal/registry"
"tunnel_pls/session" "tunnel_pls/session"
@@ -23,7 +21,6 @@ type Server interface {
Close() error Close() error
} }
type server struct { type server struct {
randomizer random.Random
config config.Config config config.Config
sshPort string sshPort string
sshListener net.Listener sshListener net.Listener
@@ -33,14 +30,13 @@ type server struct {
portRegistry port.Port portRegistry port.Port
} }
func New(randomizer random.Random, config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) { func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort)) listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &server{ return &server{
randomizer: randomizer,
config: config, config: config,
sshPort: sshPort, sshPort: sshPort,
sshListener: listener, sshListener: listener,
@@ -86,7 +82,7 @@ func (s *server) handleConnection(conn net.Conn) {
defer func(sshConn *ssh.ServerConn) { defer func(sshConn *ssh.ServerConn) {
err = sshConn.Close() err = sshConn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("failed to close SSH server: %v", err) log.Printf("failed to close SSH server: %v", err)
} }
}(sshConn) }(sshConn)
@@ -99,19 +95,11 @@ func (s *server) handleConnection(conn net.Conn) {
cancel() cancel()
} }
log.Println("SSH connection established:", sshConn.User()) log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(&session.Config{ sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
Randomizer: s.randomizer,
Config: s.config,
Conn: sshConn,
InitialReq: forwardingReqs,
SshChan: chans,
SessionRegistry: s.sessionRegistry,
PortRegistry: s.portRegistry,
User: user,
})
err = sshSession.Start() err = sshSession.Start()
if err != nil { if err != nil {
log.Printf("SSH session ended with error: %s", err.Error()) log.Printf("SSH session ended with error: %v", err)
return return
} }
return
} }
-880
View File
@@ -1,880 +0,0 @@
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"net"
"testing"
"time"
"tunnel_pls/internal/registry"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
)
type MockRandom struct {
mock.Mock
}
func (m *MockRandom) String(length int) (string, error) {
args := m.Called(length)
return args.String(0), args.Error(1)
}
type MockConfig struct {
mock.Mock
}
func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode {
args := m.Called()
if args.Get(0) == nil {
return 0
}
switch v := args.Get(0).(type) {
case types.ServerMode:
return v
case int:
return types.ServerMode(v)
default:
return types.ServerMode(args.Int(0))
}
}
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
args := m.Called(user, oldKey, newKey)
return args.Error(0)
}
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
args := m.Called(key, session)
return args.Bool(0)
}
func (m *MockSessionRegistry) Remove(key registry.Key) {
m.Called(key)
}
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
args := m.Called(user)
return args.Get(0).([]registry.Session)
}
func (m *MockSessionRegistry) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
type MockGRPCClient struct {
mock.Mock
}
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
args := m.Called()
return args.Get(0).(*grpc.ClientConn)
}
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
args := m.Called(ctx, token)
return args.Bool(0), args.String(1), args.Error(2)
}
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
args := m.Called(ctx, domain, token)
return args.Error(0)
}
func (m *MockGRPCClient) Close() error {
args := m.Called()
return args.Error(0)
}
type MockPort struct {
mock.Mock
}
func (m *MockPort) AddRange(startPort, endPort uint16) error {
return m.Called(startPort, endPort).Error(0)
}
func (m *MockPort) Unassigned() (uint16, bool) {
args := m.Called()
return uint16(args.Int(0)), args.Bool(1)
}
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
}
func (m *MockPort) Claim(port uint16) bool {
return m.Called(port).Bool(0)
}
type MockListener struct {
mock.Mock
}
func (m *MockListener) Accept() (net.Conn, error) {
args := m.Called()
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(net.Conn), args.Error(1)
}
func (m *MockListener) Close() error {
return m.Called().Error(0)
}
func (m *MockListener) Addr() net.Addr {
return m.Called().Get(0).(net.Addr)
}
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
key, _ := rsa.GenerateKey(rand.Reader, 2048)
signer, _ := ssh.NewSignerFromKey(key)
config := &ssh.ServerConfig{
NoClientAuth: true,
}
config.AddHostKey(signer)
return config, signer
}
func TestNew(t *testing.T) {
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
tests := []struct {
name string
port string
wantErr bool
}{
{
name: "success",
port: "0",
wantErr: false,
},
{
name: "invalid port",
port: "invalid",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, err := New(mr, mc, sc, mreg, mg, mp, tt.port)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, s)
} else {
assert.NoError(t, err)
assert.NotNil(t, s)
_ = s.Close()
}
})
}
t.Run("port already in use", func(t *testing.T) {
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
port := l.Addr().(*net.TCPAddr).Port
defer func(l net.Listener) {
err = l.Close()
assert.NoError(t, err)
}(l)
s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
assert.Error(t, err)
assert.Nil(t, s)
})
}
func TestClose(t *testing.T) {
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
t.Run("successful close", func(t *testing.T) {
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
err := s.Close()
assert.NoError(t, err)
})
t.Run("close already closed listener", func(t *testing.T) {
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
_ = s.Close()
err := s.Close()
assert.Error(t, err)
})
t.Run("close with nil listener", func(t *testing.T) {
s := &server{
sshListener: nil,
}
defer func() {
if r := recover(); r != nil {
assert.NotNil(t, r)
}
}()
_ = s.Close()
t.Fatal("expected panic for nil listener")
})
}
func TestStart(t *testing.T) {
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
t.Run("normal stop", func(t *testing.T) {
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
go func() {
time.Sleep(100 * time.Millisecond)
_ = s.Close()
}()
s.Start()
})
t.Run("accept error - temporary error continues loop", func(t *testing.T) {
ml := new(MockListener)
s := &server{
sshListener: ml,
sshPort: "0",
}
ml.On("Accept").Return(nil, errors.New("temporary error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
s.Start()
ml.AssertExpectations(t)
})
t.Run("accept error - immediate close", func(t *testing.T) {
ml := new(MockListener)
s := &server{
sshListener: ml,
sshPort: "0",
}
ml.On("Accept").Return(nil, net.ErrClosed).Once()
s.Start()
ml.AssertExpectations(t)
})
t.Run("accept success - connection fails SSH handshake", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
mockListener := &MockListener{}
mockListener.On("Accept").Return(serverConn, nil).Once()
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshListener: mockListener,
sshConfig: sshConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
go s.Start()
time.Sleep(50 * time.Millisecond)
err := clientConn.Close()
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
})
t.Run("accept success - valid SSH connection without auth", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
mockListener := &MockListener{}
mockListener.On("Accept").Return(serverConn, nil).Once()
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshListener: mockListener,
sshConfig: sshConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
go s.Start()
time.Sleep(50 * time.Millisecond)
err := clientConn.Close()
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
})
}
func TestHandleConnection(t *testing.T) {
t.Run("SSH handshake fails - connection closed", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: sshConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
err := clientConn.Close()
assert.NoError(t, err)
s.handleConnection(serverConn)
})
// SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS
// GONNA IMPLEMENT THIS UNIT TEST LATER
//t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) {
// mockRandom := &MockRandom{}
// mockConfig := &MockConfig{}
// mockSessionRegistry := &MockSessionRegistry{}
// mockGrpcClient := &MockGRPCClient{}
// mockPort := &MockPort{}
//
// sshConfig, _ := getTestSSHConfig()
//
// serverConn, clientConn := net.Pipe()
//
// s := &server{
// randomizer: mockRandom,
// config: mockConfig,
// sshPort: "0",
// sshConfig: sshConfig,
// grpcClient: mockGrpcClient,
// sessionRegistry: mockSessionRegistry,
// portRegistry: mockPort,
// }
//
// done := make(chan bool, 1)
//
// go func() {
// s.handleConnection(serverConn)
// done <- true
// }()
//
// go func() {
// clientConn.Write([]byte("invalid ssh protocol\n"))
// clientConn.Close()
// }()
//
// select {
// case <-done:
// case <-time.After(1 * time.Second):
// t.Fatal("handleConnection did not complete in time")
// }
//})
t.Run("SSH connection established without gRPC client", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
mockConfig.On("Domain").Return("test.com")
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("SSHPort").Return("2200")
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer func(listener net.Listener) {
err = listener.Close()
assert.NoError(t, err)
}(listener)
serverAddr := listener.Addr().String()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
s.handleConnection(conn)
done <- true
}()
time.Sleep(50 * time.Millisecond)
clientConfig := &ssh.ClientConfig{
User: "testuser",
Auth: []ssh.AuthMethod{ssh.Password("password")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
go func() {
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
if err != nil {
t.Logf("Client dial failed: %v", err)
return
}
defer func(client *ssh.Client) {
err = client.Close()
assert.NoError(t, err)
}(client)
type forwardPayload struct {
BindAddr string
BindPort uint32
}
payload := ssh.Marshal(forwardPayload{
BindAddr: "localhost",
BindPort: 80,
})
_, _, err = client.SendRequest("tcpip-forward", true, payload)
if err != nil {
t.Logf("Forward request failed: %v", err)
}
time.Sleep(500 * time.Millisecond)
}()
select {
case <-done:
t.Log("handleConnection completed")
case <-time.After(5 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
t.Run("SSH connection established with gRPC authorization", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
mockConfig.On("Domain").Return("test.com")
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("SSHPort").Return("2200")
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer func(listener net.Listener) {
err = listener.Close()
assert.NoError(t, err)
}(listener)
serverAddr := listener.Addr().String()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
s.handleConnection(conn)
done <- true
}()
time.Sleep(50 * time.Millisecond)
clientConfig := &ssh.ClientConfig{
User: "testuser",
Auth: []ssh.AuthMethod{ssh.Password("password")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
go func() {
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
if err != nil {
t.Logf("Client dial failed: %v", err)
return
}
defer func(client *ssh.Client) {
err = client.Close()
assert.NoError(t, err)
}(client)
type forwardPayload struct {
BindAddr string
BindPort uint32
}
payload := ssh.Marshal(forwardPayload{
BindAddr: "localhost",
BindPort: 80,
})
_, _, err = client.SendRequest("tcpip-forward", true, payload)
if err != nil {
t.Logf("Forward request failed: %v", err)
}
time.Sleep(500 * time.Millisecond)
}()
select {
case <-done:
mockGrpcClient.AssertExpectations(t)
case <-time.After(5 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
t.Run("SSH connection with gRPC authorization error", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
mockConfig.On("Domain").Return("test.com")
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("SSHPort").Return("2200")
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer func(listener net.Listener) {
err = listener.Close()
assert.NoError(t, err)
}(listener)
serverAddr := listener.Addr().String()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
s.handleConnection(conn)
done <- true
}()
time.Sleep(50 * time.Millisecond)
clientConfig := &ssh.ClientConfig{
User: "testuser",
Auth: []ssh.AuthMethod{ssh.Password("password")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
go func() {
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
if err != nil {
t.Logf("Client dial failed: %v", err)
return
}
defer func(client *ssh.Client) {
_ = client.Close()
}(client)
type forwardPayload struct {
BindAddr string
BindPort uint32
}
payload := ssh.Marshal(forwardPayload{
BindAddr: "localhost",
BindPort: 8080,
})
_, _, err = client.SendRequest("tcpip-forward", true, payload)
if err != nil {
t.Logf("Forward request failed: %v", err)
}
time.Sleep(500 * time.Millisecond)
}()
select {
case <-done:
mockGrpcClient.AssertExpectations(t)
case <-time.After(5 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
t.Run("connection cleanup on close", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
s.handleConnection(serverConn)
done <- true
}()
err := clientConn.Close()
assert.NoError(t, err)
select {
case <-done:
case <-time.After(1 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
}
func TestIntegration(t *testing.T) {
t.Run("full server lifecycle", func(t *testing.T) {
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
s, err := New(mr, mc, sc, mreg, mg, mp, "0")
assert.NoError(t, err)
assert.NotNil(t, s)
go func() {
time.Sleep(100 * time.Millisecond)
err := s.Close()
assert.NoError(t, err)
}()
s.Start()
})
t.Run("multiple connections", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
conn1Server, conn1Client := net.Pipe()
conn2Server, conn2Client := net.Pipe()
mockListener := &MockListener{}
mockListener.On("Accept").Return(conn1Server, nil).Once()
mockListener.On("Accept").Return(conn2Server, nil).Once()
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshListener: mockListener,
sshConfig: sshConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
go s.Start()
time.Sleep(50 * time.Millisecond)
_ = conn1Client.Close()
time.Sleep(50 * time.Millisecond)
_ = conn2Client.Close()
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
})
}
func TestErrorHandling(t *testing.T) {
t.Run("write error during SSH handshake", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
err := clientConn.Close()
assert.NoError(t, err)
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: sshConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
s.handleConnection(serverConn)
})
}
+66 -27
View File
@@ -1,7 +1,8 @@
package forwarder package forwarder
import ( import (
"context" "bytes"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -9,6 +10,7 @@ import (
"net" "net"
"strconv" "strconv"
"sync" "sync"
"time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
@@ -24,7 +26,9 @@ type Forwarder interface {
TunnelType() types.TunnelType TunnelType() types.TunnelType
ForwardedPort() uint16 ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel) HandleConnection(dst io.ReadWriter, src ssh.Channel)
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
WriteBadGatewayResponse(dst io.Writer)
Close() error Close() error
} }
type forwarder struct { type forwarder struct {
@@ -46,21 +50,19 @@ func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
bufferPool: sync.Pool{ bufferPool: sync.Pool{
New: func() interface{} { New: func() interface{} {
bufSize := config.BufferSize() bufSize := config.BufferSize()
buf := make([]byte, bufSize) return make([]byte, bufSize)
return &buf
}, },
}, },
} }
} }
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := f.bufferPool.Get().(*[]byte) buf := f.bufferPool.Get().([]byte)
defer f.bufferPool.Put(buf) defer f.bufferPool.Put(buf)
return io.CopyBuffer(dst, src, *buf) return io.CopyBuffer(dst, src, buf)
} }
func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) { func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
payload := createForwardedTCPIPPayload(origin, f.forwardedPort)
type channelResult struct { type channelResult struct {
channel ssh.Channel channel ssh.Channel
reqs <-chan *ssh.Request reqs <-chan *ssh.Request
@@ -72,9 +74,13 @@ func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload) channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
select { select {
case resultChan <- channelResult{channel, reqs, err}: case resultChan <- channelResult{channel, reqs, err}:
case <-ctx.Done(): default:
if channel != nil { if channel != nil {
_ = channel.Close() err = channel.Close()
if err != nil {
log.Printf("Failed to close unused channel: %v", err)
return
}
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
} }
} }
@@ -83,8 +89,8 @@ func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (
select { select {
case result := <-resultChan: case result := <-resultChan:
return result.channel, result.reqs, result.err return result.channel, result.reqs, result.err
case <-ctx.Done(): case <-time.After(5 * time.Second):
return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err()) return nil, nil, errors.New("timeout opening forwarded-tcpip channel")
} }
} }
@@ -113,7 +119,10 @@ func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string)
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
defer func() { defer func() {
_, _ = io.Copy(io.Discard, src) _, err := io.Copy(io.Discard, src)
if err != nil {
log.Printf("Failed to discard connection: %v", err)
}
}() }()
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -164,6 +173,14 @@ func (f *forwarder) Listener() net.Listener {
return f.listener return f.listener
} }
func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
_, err := dst.Write(types.BadGatewayResponse)
if err != nil {
log.Printf("failed to write Bad Gateway response: %v", err)
return
}
}
func (f *forwarder) Close() error { func (f *forwarder) Close() error {
if f.Listener() != nil { if f.Listener() != nil {
return f.listener.Close() return f.listener.Close()
@@ -171,21 +188,43 @@ func (f *forwarder) Close() error {
return nil return nil
} }
func createForwardedTCPIPPayload(origin net.Addr, destPort uint16) []byte { func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
host, portStr, _ := net.SplitHostPort(origin.String()) var buf bytes.Buffer
port, _ := strconv.Atoi(portStr)
forwardPayload := struct { host, originPort := parseAddr(origin.String())
DestAddr string
DestPort uint32 writeSSHString(&buf, "localhost")
OriginAddr string err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
OriginPort uint32 if err != nil {
}{ log.Printf("Failed to write string to buffer: %v", err)
DestAddr: "localhost", return nil
DestPort: uint32(destPort),
OriginAddr: host,
OriginPort: uint32(port),
} }
return ssh.Marshal(forwardPayload) writeSSHString(&buf, host)
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return nil
}
return buf.Bytes()
}
func parseAddr(addr string) (string, uint16) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
return "0.0.0.0", uint16(0)
}
port, _ := strconv.Atoi(portStr)
return host, uint16(port)
}
func writeSSHString(buffer *bytes.Buffer, str string) {
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return
}
buffer.WriteString(str)
} }
File diff suppressed because it is too large Load Diff
+17 -20
View File
@@ -10,37 +10,34 @@ import (
"github.com/charmbracelet/lipgloss" "github.com/charmbracelet/lipgloss"
) )
func (m *model) handleCommandSelection(item commandItem) (tea.Model, tea.Cmd) {
switch item.name {
case "slug":
m.showingCommands = false
m.editingSlug = true
m.slugInput.SetValue(m.interaction.slug.String())
m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "tunnel-type":
m.showingCommands = false
m.showingComingSoon = true
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
default:
m.showingCommands = false
return m, nil
}
}
func (m *model) commandsUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) { func (m *model) commandsUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd var cmd tea.Cmd
switch { switch {
case key.Matches(msg, m.keymap.quit), msg.String() == "esc": case key.Matches(msg, m.keymap.quit):
m.showingCommands = false m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case msg.String() == "enter": case msg.String() == "enter":
selectedItem := m.commandList.SelectedItem() selectedItem := m.commandList.SelectedItem()
if selectedItem != nil { if selectedItem != nil {
item := selectedItem.(commandItem) item := selectedItem.(commandItem)
return m.handleCommandSelection(item) if item.name == "slug" {
m.showingCommands = false
m.editingSlug = true
m.slugInput.SetValue(m.interaction.slug.String())
m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} else if item.name == "tunnel-type" {
m.showingCommands = false
m.showingComingSoon = true
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
}
m.showingCommands = false
return m, nil
} }
case msg.String() == "esc":
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} }
m.commandList, cmd = m.commandList.Update(msg) m.commandList, cmd = m.commandList.Update(msg)
return m, cmd return m, cmd
+110 -140
View File
@@ -23,194 +23,164 @@ func (m *model) dashboardUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
} }
func (m *model) dashboardView() string { func (m *model) dashboardView() string {
isCompact := shouldUseCompactLayout(m.width, BreakpointLarge) titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1)
subtitleStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Italic(true)
urlStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Underline(true).
Italic(true)
urlBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Bold(true).
Italic(true)
keyHintStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Bold(true)
var b strings.Builder var b strings.Builder
b.WriteString(m.renderHeader(isCompact))
b.WriteString(m.renderUserInfo(isCompact))
b.WriteString(m.renderQuickActions(isCompact))
b.WriteString(m.renderFooter(isCompact))
return b.String() isCompact := shouldUseCompactLayout(m.width, 85)
}
func (m *model) renderHeader(isCompact bool) string { var asciiArtMargin int
var b strings.Builder if isCompact {
asciiArtMargin = 0
} else {
asciiArtMargin = 1
}
asciiArtMargin := getMarginValue(isCompact, 0, 1)
asciiArtStyle := lipgloss.NewStyle(). asciiArtStyle := lipgloss.NewStyle().
Bold(true). Bold(true).
Foreground(lipgloss.Color(ColorPrimary)). Foreground(lipgloss.Color("#7D56F4")).
MarginBottom(asciiArtMargin) MarginBottom(asciiArtMargin)
b.WriteString(asciiArtStyle.Render(m.getASCIIArt())) var asciiArt string
b.WriteString("\n") if shouldUseCompactLayout(m.width, 50) {
asciiArt = "TUNNEL PLS"
if !shouldUseCompactLayout(m.width, BreakpointSmall) { } else if isCompact {
b.WriteString(m.renderSubtitle()) asciiArt = `
} else {
b.WriteString("\n")
}
return b.String()
}
func (m *model) getASCIIArt() string {
if shouldUseCompactLayout(m.width, BreakpointTiny) {
return "TUNNEL PLS"
}
if shouldUseCompactLayout(m.width, BreakpointLarge) {
return `
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀ ▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀` █ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
} } else {
asciiArt = `
return `
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗ ████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝ ╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗ ██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║ ██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║ ██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝` ╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
} }
func (m *model) renderSubtitle() string { b.WriteString(asciiArtStyle.Render(asciiArt))
subtitleStyle := lipgloss.NewStyle(). b.WriteString("\n")
Foreground(lipgloss.Color(ColorGray)).
Italic(true)
urlStyle := lipgloss.NewStyle(). if !shouldUseCompactLayout(m.width, 60) {
Foreground(lipgloss.Color(ColorPrimary)). b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • "))
Underline(true). b.WriteString(urlStyle.Render("https://fossy.my.id"))
Italic(true) b.WriteString("\n\n")
} else {
b.WriteString("\n")
}
return subtitleStyle.Render("Secure tunnel service by Bagas • ") +
urlStyle.Render("https://fossy.my.id") + "\n\n"
}
func (m *model) renderUserInfo(isCompact bool) string {
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80) boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
boxPadding := getMarginValue(isCompact, 1, 2) var boxPadding int
boxMargin := getMarginValue(isCompact, 1, 2) var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
responsiveInfoBox := lipgloss.NewStyle(). responsiveInfoBox := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()). Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorPrimary)). BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding). Padding(1, boxPadding).
MarginTop(boxMargin). MarginTop(boxMargin).
MarginBottom(boxMargin). MarginBottom(boxMargin).
Width(boxMaxWidth) Width(boxMaxWidth)
infoContent := m.getUserInfoContent(isCompact) authenticatedUser := m.interaction.user
return responsiveInfoBox.Render(infoContent) + "\n"
}
func (m *model) getUserInfoContent(isCompact bool) string {
userInfoStyle := lipgloss.NewStyle(). userInfoStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWhite)). Foreground(lipgloss.Color("#FAFAFA")).
Bold(true) Bold(true)
sectionHeaderStyle := lipgloss.NewStyle(). sectionHeaderStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorGray)). Foreground(lipgloss.Color("#888888")).
Bold(true) Bold(true)
addressStyle := lipgloss.NewStyle(). addressStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWhite)) Foreground(lipgloss.Color("#FAFAFA"))
urlBoxStyle := lipgloss.NewStyle(). var infoContent string
Foreground(lipgloss.Color(ColorSecondary)). if shouldUseCompactLayout(m.width, 70) {
Bold(true). infoContent = fmt.Sprintf("👤 %s\n\n%s\n%s",
Italic(true)
authenticatedUser := m.interaction.user
tunnelURL := urlBoxStyle.Render(m.getTunnelURL())
if isCompact {
return fmt.Sprintf("👤 %s\n\n%s\n%s",
userInfoStyle.Render(authenticatedUser), userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"), sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(fmt.Sprintf(" %s", tunnelURL))) addressStyle.Render(fmt.Sprintf(" %s", urlBoxStyle.Render(m.getTunnelURL()))))
} else {
infoContent = fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(urlBoxStyle.Render(m.getTunnelURL())))
} }
return fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s", b.WriteString(responsiveInfoBox.Render(infoContent))
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(tunnelURL))
}
func (m *model) renderQuickActions(isCompact bool) string {
var b strings.Builder
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color(ColorPrimary)).
PaddingTop(1)
b.WriteString(titleStyle.Render(m.getQuickActionsTitle()))
b.WriteString("\n") b.WriteString("\n")
featureMargin := getMarginValue(isCompact, 1, 2) var quickActionsTitle string
featureStyle := lipgloss.NewStyle(). if shouldUseCompactLayout(m.width, 50) {
Foreground(lipgloss.Color(ColorWhite)). quickActionsTitle = "Actions"
} else if isCompact {
quickActionsTitle = "Quick Actions"
} else {
quickActionsTitle = "✨ Quick Actions"
}
b.WriteString(titleStyle.Render(quickActionsTitle))
b.WriteString("\n")
var featureMargin int
if isCompact {
featureMargin = 1
} else {
featureMargin = 2
}
compactFeatureStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginLeft(featureMargin) MarginLeft(featureMargin)
keyHintStyle := lipgloss.NewStyle(). var commandsText string
Foreground(lipgloss.Color(ColorPrimary)). var quitText string
Bold(true) if shouldUseCompactLayout(m.width, 60) {
commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]"))
} else {
commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]"))
}
commands := m.getActionCommands(keyHintStyle) b.WriteString(compactFeatureStyle.Render(commandsText))
b.WriteString(featureStyle.Render(commands.commandsText))
b.WriteString("\n") b.WriteString("\n")
b.WriteString(featureStyle.Render(commands.quitText)) b.WriteString(compactFeatureStyle.Render(quitText))
if !shouldUseCompactLayout(m.width, 70) {
b.WriteString("\n\n")
footerStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true)
b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings"))
}
return b.String() return b.String()
} }
func (m *model) getQuickActionsTitle() string {
if shouldUseCompactLayout(m.width, BreakpointTiny) {
return "Actions"
}
if shouldUseCompactLayout(m.width, BreakpointLarge) {
return "Quick Actions"
}
return "✨ Quick Actions"
}
type actionCommands struct {
commandsText string
quitText string
}
func (m *model) getActionCommands(keyHintStyle lipgloss.Style) actionCommands {
if shouldUseCompactLayout(m.width, BreakpointSmall) {
return actionCommands{
commandsText: fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]")),
quitText: fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]")),
}
}
return actionCommands{
commandsText: fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]")),
quitText: fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]")),
}
}
func (m *model) renderFooter(isCompact bool) string {
if isCompact {
return ""
}
footerStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorDarkGray)).
Italic(true)
return "\n\n" + footerStyle.Render("Press 'C' to customize your tunnel settings")
}
func getMarginValue(isCompact bool, compactValue, normalValue int) int {
if isCompact {
return compactValue
}
return normalValue
}
+6 -22
View File
@@ -3,9 +3,7 @@ package interaction
import ( import (
"context" "context"
"log" "log"
"sync"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/random"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
@@ -41,7 +39,6 @@ type Forwarder interface {
type CloseFunc func() error type CloseFunc func() error
type interaction struct { type interaction struct {
randomizer random.Random
config config.Config config config.Config
channel ssh.Channel channel ssh.Channel
slug slug.Slug slug slug.Slug
@@ -53,7 +50,6 @@ type interaction struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
mode types.InteractiveMode mode types.InteractiveMode
programMu sync.Mutex
} }
func (i *interaction) SetMode(m types.InteractiveMode) { func (i *interaction) SetMode(m types.InteractiveMode) {
@@ -80,10 +76,9 @@ func (i *interaction) SetWH(w, h int) {
} }
} }
func New(randomizer random.Random, config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction { func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &interaction{ return &interaction{
randomizer: randomizer,
config: config, config: config,
channel: nil, channel: nil,
slug: slug, slug: slug,
@@ -105,10 +100,6 @@ func (i *interaction) Stop() {
if i.cancel != nil { if i.cancel != nil {
i.cancel() i.cancel()
} }
i.programMu.Lock()
defer i.programMu.Unlock()
if i.program != nil { if i.program != nil {
i.program.Kill() i.program.Kill()
i.program = nil i.program = nil
@@ -219,7 +210,6 @@ func (i *interaction) Start() {
ti.Width = 50 ti.Width = 50
m := &model{ m := &model{
randomizer: i.randomizer,
domain: i.config.Domain(), domain: i.config.Domain(),
protocol: protocol, protocol: protocol,
tunnelType: tunnelType, tunnelType: tunnelType,
@@ -244,7 +234,6 @@ func (i *interaction) Start() {
help: help.New(), help: help.New(),
} }
i.programMu.Lock()
i.program = tea.NewProgram( i.program = tea.NewProgram(
m, m,
tea.WithInput(i.channel), tea.WithInput(i.channel),
@@ -255,21 +244,16 @@ func (i *interaction) Start() {
tea.WithoutSignalHandler(), tea.WithoutSignalHandler(),
tea.WithFPS(30), tea.WithFPS(30),
) )
i.programMu.Unlock()
_, err := i.program.Run() _, err := i.program.Run()
if err != nil { if err != nil {
log.Printf("Cannot close tea: %s \n", err) log.Printf("Cannot close tea: %s \n", err)
} }
i.program.Kill()
i.programMu.Lock() i.program = nil
if i.program != nil {
i.program.Kill()
i.program = nil
}
i.programMu.Unlock()
if i.closeFunc != nil { if i.closeFunc != nil {
_ = i.closeFunc() if err := i.closeFunc(); err != nil {
log.Printf("Cannot close session: %s \n", err)
}
} }
} }
File diff suppressed because it is too large Load Diff
-21
View File
@@ -3,7 +3,6 @@ package interaction
import ( import (
"fmt" "fmt"
"time" "time"
"tunnel_pls/internal/random"
"tunnel_pls/types" "tunnel_pls/types"
"github.com/charmbracelet/bubbles/help" "github.com/charmbracelet/bubbles/help"
@@ -23,7 +22,6 @@ func (i commandItem) Title() string { return i.name }
func (i commandItem) Description() string { return i.desc } func (i commandItem) Description() string { return i.desc }
type model struct { type model struct {
randomizer random.Random
domain string domain string
protocol string protocol string
tunnelType types.TunnelType tunnelType types.TunnelType
@@ -42,25 +40,6 @@ type model struct {
height int height int
} }
const (
ColorPrimary = "#7D56F4"
ColorSecondary = "#04B575"
ColorGray = "#888888"
ColorDarkGray = "#666666"
ColorWhite = "#FAFAFA"
ColorError = "#FF0000"
ColorErrorBg = "#3D0000"
ColorWarning = "#FFA500"
ColorWarningBg = "#3D2000"
)
const (
BreakpointTiny = 50
BreakpointSmall = 60
BreakpointMedium = 70
BreakpointLarge = 85
)
func (m *model) getTunnelURL() string { func (m *model) getTunnelURL() string {
if m.tunnelType == types.TunnelTypeHTTP { if m.tunnelType == types.TunnelTypeHTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain) return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
+119 -160
View File
@@ -3,6 +3,7 @@ package interaction
import ( import (
"fmt" "fmt"
"strings" "strings"
"tunnel_pls/internal/random"
"tunnel_pls/types" "tunnel_pls/types"
"github.com/charmbracelet/bubbles/key" "github.com/charmbracelet/bubbles/key"
@@ -21,7 +22,7 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
} }
switch msg.String() { switch msg.String() {
case "esc", "ctrl+c": case "esc":
m.editingSlug = false m.editingSlug = false
m.slugError = "" m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
@@ -40,13 +41,19 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.editingSlug = false m.editingSlug = false
m.slugError = "" m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink) return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "ctrl+c":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
default: default:
if key.Matches(msg, m.keymap.random) { if key.Matches(msg, m.keymap.random) {
newSubdomain, err := m.randomizer.String(20) newSubdomain, err := random.GenerateRandomString(20)
if err != nil { if err != nil {
return m, cmd return m, cmd
} }
m.slugInput.SetValue(newSubdomain) m.slugInput.SetValue(newSubdomain)
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
} }
m.slugError = "" m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg) m.slugInput, cmd = m.slugInput.Update(msg)
@@ -55,211 +62,163 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
} }
func (m *model) slugView() string { func (m *model) slugView() string {
isCompact := shouldUseCompactLayout(m.width, BreakpointMedium) isCompact := shouldUseCompactLayout(m.width, 70)
isVeryCompact := shouldUseCompactLayout(m.width, BreakpointTiny) isVeryCompact := shouldUseCompactLayout(m.width, 50)
var b strings.Builder var boxPadding int
b.WriteString(m.renderSlugTitle(isVeryCompact)) var boxMargin int
if isVeryCompact {
if m.tunnelType != types.TunnelTypeHTTP { boxPadding = 1
b.WriteString(m.renderTCPWarning(isVeryCompact, isCompact)) boxMargin = 1
return b.String() } else if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
} }
b.WriteString(m.renderSlugRules(isVeryCompact, isCompact))
b.WriteString(m.renderSlugInstruction(isVeryCompact))
b.WriteString(m.renderSlugInput(isVeryCompact, isCompact))
b.WriteString(m.renderSlugPreview(isVeryCompact))
b.WriteString(m.renderSlugHelp(isVeryCompact))
return b.String()
}
func (m *model) renderSlugTitle(isVeryCompact bool) string {
titleStyle := lipgloss.NewStyle(). titleStyle := lipgloss.NewStyle().
Bold(true). Bold(true).
Foreground(lipgloss.Color(ColorPrimary)). Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1). PaddingTop(1).
PaddingBottom(1) PaddingBottom(1)
title := "🔧 Edit Subdomain" instructionStyle := lipgloss.NewStyle().
if isVeryCompact { Foreground(lipgloss.Color("#FAFAFA")).
title = "Edit Subdomain" MarginTop(1)
}
return titleStyle.Render(title) + "\n\n" inputBoxStyle := lipgloss.NewStyle().
}
func (m *model) renderTCPWarning(isVeryCompact, isCompact bool) string {
boxPadding := getPaddingValue(isVeryCompact, isCompact)
boxMargin := getMarginValue(isCompact, 1, 2)
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWarning)).
Background(lipgloss.Color(ColorWarningBg)).
Bold(true).
Border(lipgloss.RoundedBorder()). Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorWarning)). BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding). Padding(1, boxPadding).
MarginTop(boxMargin). MarginTop(boxMargin).
MarginBottom(boxMargin). MarginBottom(boxMargin)
Width(warningBoxWidth)
helpStyle := lipgloss.NewStyle(). helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorDarkGray)). Foreground(lipgloss.Color("#666666")).
Italic(true). Italic(true).
MarginTop(1) MarginTop(1)
warningText := m.getTCPWarningText(isVeryCompact) errorBoxStyle := lipgloss.NewStyle().
helpText := m.getTCPHelpText(isVeryCompact) Foreground(lipgloss.Color("#FF0000")).
Background(lipgloss.Color("#3D0000")).
var b strings.Builder Bold(true).
b.WriteString(warningBoxStyle.Render(warningText))
b.WriteString("\n\n")
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
func (m *model) getTCPWarningText(isVeryCompact bool) string {
if isVeryCompact {
return "⚠️ TCP tunnels don't support custom subdomains."
}
return "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
}
func (m *model) getTCPHelpText(isVeryCompact bool) string {
if isVeryCompact {
return "Press any key to go back"
}
return "Press Enter or Esc to go back"
}
func (m *model) renderSlugRules(isVeryCompact, isCompact bool) string {
boxPadding := getPaddingValue(isVeryCompact, isCompact)
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
rulesBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWhite)).
Border(lipgloss.RoundedBorder()). Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorPrimary)). BorderForeground(lipgloss.Color("#FF0000")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1)
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
rulesBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(0, boxPadding). Padding(0, boxPadding).
MarginTop(1). MarginTop(1).
MarginBottom(1). MarginBottom(1).
Width(rulesBoxWidth) Width(rulesBoxWidth)
rulesContent := m.getRulesContent(isVeryCompact, isCompact) var b strings.Builder
return rulesBoxStyle.Render(rulesContent) + "\n" var title string
}
func (m *model) getRulesContent(isVeryCompact, isCompact bool) string {
if isVeryCompact { if isVeryCompact {
return "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -" title = "Edit Subdomain"
} else {
title = "🔧 Edit Subdomain"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
if m.tunnelType != types.TunnelTypeHTTP {
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFA500")).
Background(lipgloss.Color("#3D2000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FFA500")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(warningBoxWidth)
var warningText string
if isVeryCompact {
warningText = "⚠️ TCP tunnels don't support custom subdomains."
} else {
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
}
b.WriteString(warningBoxStyle.Render(warningText))
b.WriteString("\n\n")
var helpText string
if isVeryCompact {
helpText = "Press any key to go back"
} else {
helpText = "Press Enter or Esc to go back"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
} }
if isCompact { var rulesContent string
return "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -" if isVeryCompact {
rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
} else if isCompact {
rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
} else {
rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
} }
b.WriteString(rulesBoxStyle.Render(rulesContent))
b.WriteString("\n")
return "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -" var instruction string
}
func (m *model) renderSlugInstruction(isVeryCompact bool) string {
instructionStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWhite)).
MarginTop(1)
instruction := "Enter your custom subdomain:"
if isVeryCompact { if isVeryCompact {
instruction = "Custom subdomain:" instruction = "Custom subdomain:"
} else {
instruction = "Enter your custom subdomain:"
} }
b.WriteString(instructionStyle.Render(instruction))
return instructionStyle.Render(instruction) + "\n" b.WriteString("\n")
}
func (m *model) renderSlugInput(isVeryCompact, isCompact bool) string {
boxPadding := getPaddingValue(isVeryCompact, isCompact)
boxMargin := getMarginValue(isCompact, 1, 2)
if m.slugError != "" { if m.slugError != "" {
return m.renderErrorInput(boxPadding, boxMargin) errorInputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(1)
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
b.WriteString("\n")
} else {
b.WriteString(inputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
} }
return m.renderNormalInput(boxPadding, boxMargin)
}
func (m *model) renderErrorInput(boxPadding, boxMargin int) string {
errorInputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorError)).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(1)
errorBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorError)).
Background(lipgloss.Color(ColorErrorBg)).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorError)).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1)
var b strings.Builder
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
b.WriteString("\n")
return b.String()
}
func (m *model) renderNormalInput(boxPadding, boxMargin int) string {
inputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorPrimary)).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin)
return inputBoxStyle.Render(m.slugInput.View()) + "\n"
}
func (m *model) renderSlugPreview(isVeryCompact bool) string {
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain) previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
previewWidth := getResponsiveWidth(m.width, 10, 30, 80) previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
if isVeryCompact { if len(previewURL) > previewWidth-10 {
previewURL = truncateString(previewURL, previewWidth-10) previewURL = truncateString(previewURL, previewWidth-10)
} }
previewStyle := lipgloss.NewStyle(). previewStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorSecondary)). Foreground(lipgloss.Color("#04B575")).
Italic(true). Italic(true).
Width(previewWidth) Width(previewWidth)
b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)))
b.WriteString("\n")
return previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)) + "\n" var helpText string
}
func (m *model) renderSlugHelp(isVeryCompact bool) string {
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorDarkGray)).
Italic(true).
MarginTop(1)
helpText := "Press Enter to save • CTRL+R for random • Esc to cancel"
if isVeryCompact { if isVeryCompact {
helpText = "Enter: save • CTRL+R: random • Esc: cancel" helpText = "Enter: save • CTRL+R: random • Esc: cancel"
} else {
helpText = "Press Enter to save • CTRL+R for random • Esc to cancel"
} }
b.WriteString(helpStyle.Render(helpText))
return helpStyle.Render(helpText) return b.String()
}
func getPaddingValue(isVeryCompact, isCompact bool) int {
if isVeryCompact || isCompact {
return 1
}
return 2
} }
+21 -42
View File
@@ -2,9 +2,6 @@ package lifecycle
import ( import (
"errors" "errors"
"io"
"net"
"sync"
"time" "time"
portUtil "tunnel_pls/internal/port" portUtil "tunnel_pls/internal/port"
@@ -25,9 +22,7 @@ type SessionRegistry interface {
} }
type lifecycle struct { type lifecycle struct {
mu sync.Mutex
status types.SessionStatus status types.SessionStatus
closeErr error
conn ssh.Conn conn ssh.Conn
channel ssh.Channel channel ssh.Channel
forwarder Forwarder forwarder Forwarder
@@ -54,7 +49,6 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti
type Lifecycle interface { type Lifecycle interface {
Connection() ssh.Conn Connection() ssh.Conn
Channel() ssh.Channel
PortRegistry() portUtil.Port PortRegistry() portUtil.Port
User() string User() string
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
@@ -75,48 +69,33 @@ func (l *lifecycle) User() string {
func (l *lifecycle) SetChannel(channel ssh.Channel) { func (l *lifecycle) SetChannel(channel ssh.Channel) {
l.channel = channel l.channel = channel
} }
func (l *lifecycle) Channel() ssh.Channel {
return l.channel
}
func (l *lifecycle) Connection() ssh.Conn { func (l *lifecycle) Connection() ssh.Conn {
return l.conn return l.conn
} }
func (l *lifecycle) SetStatus(status types.SessionStatus) { func (l *lifecycle) SetStatus(status types.SessionStatus) {
l.mu.Lock()
defer l.mu.Unlock()
l.status = status l.status = status
if status == types.SessionStatusRUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now()
}
} }
func (l *lifecycle) IsActive() bool { func closeIfNotNil(c interface{ Close() error }) error {
l.mu.Lock() if c != nil {
defer l.mu.Unlock() return c.Close()
return l.status == types.SessionStatusRUNNING }
return nil
} }
func (l *lifecycle) Close() error { func (l *lifecycle) Close() error {
l.mu.Lock()
defer l.mu.Unlock()
if l.status == types.SessionStatusCLOSED {
return l.closeErr
}
l.status = types.SessionStatusCLOSED
var errs []error var errs []error
tunnelType := l.forwarder.TunnelType() tunnelType := l.forwarder.TunnelType()
if l.channel != nil { if err := closeIfNotNil(l.channel); err != nil {
if err := l.channel.Close(); err != nil && !isClosedError(err) { errs = append(errs, err)
errs = append(errs, err)
}
} }
if l.conn != nil { if err := closeIfNotNil(l.conn); err != nil {
if err := l.conn.Close(); err != nil && !isClosedError(err) { errs = append(errs, err)
errs = append(errs, err)
}
} }
clientSlug := l.slug.String() clientSlug := l.slug.String()
@@ -127,19 +106,19 @@ func (l *lifecycle) Close() error {
l.sessionRegistry.Remove(key) l.sessionRegistry.Remove(key)
if tunnelType == types.TunnelTypeTCP { if tunnelType == types.TunnelTypeTCP {
errs = append(errs, l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false)) if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil {
errs = append(errs, l.forwarder.Close()) errs = append(errs, err)
}
if err := l.forwarder.Close(); err != nil {
errs = append(errs, err)
}
} }
l.closeErr = errors.Join(errs...) return errors.Join(errs...)
return l.closeErr
} }
func isClosedError(err error) bool { func (l *lifecycle) IsActive() bool {
if err == nil { return l.status == types.SessionStatusRUNNING
return false
}
return errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || err.Error() == "EOF"
} }
func (l *lifecycle) StartedAt() time.Time { func (l *lifecycle) StartedAt() time.Time {
-303
View File
@@ -1,303 +0,0 @@
package lifecycle
import (
"context"
"errors"
"io"
"net"
"testing"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/crypto/ssh"
)
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Remove(key types.SessionKey) {
m.Called(key)
}
type MockForwarder struct {
mock.Mock
}
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
args := m.Called(origin)
return args.Get(0).([]byte)
}
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
m.Called(dst, src)
}
func (m *MockForwarder) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockForwarder) TunnelType() types.TunnelType {
args := m.Called()
return args.Get(0).(types.TunnelType)
}
func (m *MockForwarder) ForwardedPort() uint16 {
args := m.Called()
return args.Get(0).(uint16)
}
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
m.Called(tunnelType)
}
func (m *MockForwarder) SetForwardedPort(port uint16) {
m.Called(port)
}
func (m *MockForwarder) SetListener(listener net.Listener) {
m.Called(listener)
}
func (m *MockForwarder) Listener() net.Listener {
args := m.Called()
return args.Get(0).(net.Listener)
}
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
args := m.Called(ctx, origin)
if args.Get(0) == nil {
return nil, nil, args.Error(2)
}
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
type MockPort struct {
mock.Mock
}
func (m *MockPort) AddRange(startPort, endPort uint16) error {
return m.Called(startPort, endPort).Error(0)
}
func (m *MockPort) Unassigned() (uint16, bool) {
args := m.Called()
var port uint16
if args.Get(0) != nil {
switch v := args.Get(0).(type) {
case int:
port = uint16(v)
case uint16:
port = v
case uint32:
port = uint16(v)
case int32:
port = uint16(v)
case float64:
port = uint16(v)
default:
port = uint16(args.Int(0))
}
}
return port, args.Bool(1)
}
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
}
func (m *MockPort) Claim(port uint16) bool {
return m.Called(port).Bool(0)
}
type MockSlug struct {
mock.Mock
}
func (ms *MockSlug) Set(slug string) {
ms.Called(slug)
}
func (ms *MockSlug) String() string {
return ms.Called().String(0)
}
type MockSSHConn struct {
ssh.Conn
mock.Mock
}
func (m *MockSSHConn) Close() error {
args := m.Called()
return args.Error(0)
}
type MockSSHChannel struct {
ssh.Channel
mock.Mock
}
func (m *MockSSHChannel) Close() error {
return m.Called().Error(0)
}
func TestNew(t *testing.T) {
mockSSHConn := new(MockSSHConn)
mockForwarder := &MockForwarder{}
mockSlug := &MockSlug{}
mockPort := &MockPort{}
mockSessionRegistry := &MockSessionRegistry{}
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
assert.NotNil(t, mockLifecycle.Connection())
assert.NotNil(t, mockLifecycle.User())
assert.NotNil(t, mockLifecycle.PortRegistry())
assert.NotNil(t, mockLifecycle.StartedAt())
}
func TestLifecycle_User(t *testing.T) {
mockSSHConn := new(MockSSHConn)
mockForwarder := &MockForwarder{}
mockSlug := &MockSlug{}
mockPort := &MockPort{}
mockSessionRegistry := &MockSessionRegistry{}
user := "mas-fuad"
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, user)
assert.Equal(t, user, mockLifecycle.User())
}
func TestLifecycle_SetChannel(t *testing.T) {
mockSSHConn := new(MockSSHConn)
mockForwarder := &MockForwarder{}
mockSlug := &MockSlug{}
mockPort := &MockPort{}
mockSessionRegistry := &MockSessionRegistry{}
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
mockSSHChannel := &MockSSHChannel{}
mockLifecycle.SetChannel(mockSSHChannel)
assert.Equal(t, mockSSHChannel, mockLifecycle.Channel())
}
func TestLifecycle_SetStatus(t *testing.T) {
mockSSHConn := new(MockSSHConn)
mockForwarder := &MockForwarder{}
mockSlug := &MockSlug{}
mockPort := &MockPort{}
mockSessionRegistry := &MockSessionRegistry{}
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
assert.True(t, mockLifecycle.IsActive())
}
func TestLifecycle_IsActive(t *testing.T) {
mockSSHConn := new(MockSSHConn)
mockForwarder := &MockForwarder{}
mockSlug := &MockSlug{}
mockPort := &MockPort{}
mockSessionRegistry := &MockSessionRegistry{}
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
assert.True(t, mockLifecycle.IsActive())
}
func TestLifecycle_Close(t *testing.T) {
tests := []struct {
name string
tunnelType types.TunnelType
connCloseErr error
channelCloseErr error
expectErr bool
alreadyClosed bool
}{
{
name: "Close HTTP forwarding success",
tunnelType: types.TunnelTypeHTTP,
expectErr: false,
},
{
name: "Close TCP forwarding success",
tunnelType: types.TunnelTypeTCP,
expectErr: false,
},
{
name: "Close with conn close error",
tunnelType: types.TunnelTypeHTTP,
connCloseErr: errors.New("conn close error"),
expectErr: true,
},
{
name: "Close with channel close error",
tunnelType: types.TunnelTypeHTTP,
channelCloseErr: errors.New("channel close error"),
expectErr: true,
},
{
name: "Close when already closed",
tunnelType: types.TunnelTypeHTTP,
alreadyClosed: true,
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockSSHConn := &MockSSHConn{}
mockSSHConn.On("Close").Return(tt.connCloseErr)
mockForwarder := &MockForwarder{}
mockForwarder.On("TunnelType").Return(tt.tunnelType)
if tt.tunnelType == types.TunnelTypeTCP {
mockForwarder.On("ForwardedPort").Return(uint16(8080))
mockForwarder.On("Close").Return(nil)
}
mockSlug := &MockSlug{}
mockSlug.On("String").Return("test-slug")
mockPort := &MockPort{}
if tt.tunnelType == types.TunnelTypeTCP {
mockPort.On("SetStatus", uint16(8080), false).Return(nil)
}
mockSessionRegistry := &MockSessionRegistry{}
mockSessionRegistry.On("Remove", mock.Anything).Return()
mockSSHChannel := &MockSSHChannel{}
mockSSHChannel.On("Close").Return(tt.channelCloseErr)
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
mockLifecycle.SetChannel(mockSSHChannel)
if tt.alreadyClosed {
err := mockLifecycle.Close()
assert.NoError(t, err)
}
err := mockLifecycle.Close()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.False(t, mockLifecycle.IsActive())
mockSSHConn.AssertExpectations(t)
mockForwarder.AssertExpectations(t)
mockSlug.AssertExpectations(t)
mockPort.AssertExpectations(t)
mockSessionRegistry.AssertExpectations(t)
mockSSHChannel.AssertExpectations(t)
})
}
}
+75 -56
View File
@@ -1,6 +1,7 @@
package session package session
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@@ -36,7 +37,6 @@ type Session interface {
} }
type session struct { type session struct {
randomizer random.Random
config config.Config config config.Config
initialReq <-chan *ssh.Request initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel sshChan <-chan ssh.NewChannel
@@ -47,35 +47,23 @@ type session struct {
registry registry.Registry registry registry.Registry
} }
type Config struct {
Randomizer random.Random
Config config.Config
Conn *ssh.ServerConn
InitialReq <-chan *ssh.Request
SshChan <-chan ssh.NewChannel
SessionRegistry registry.Registry
PortRegistry portUtil.Port
User string
}
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017} var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func New(conf *Config) Session { func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
slugManager := slug.New() slugManager := slug.New()
forwarderManager := forwarder.New(conf.Config, slugManager, conf.Conn) forwarderManager := forwarder.New(config, slugManager, conn)
lifecycleManager := lifecycle.New(conf.Conn, forwarderManager, slugManager, conf.PortRegistry, conf.SessionRegistry, conf.User) lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
interactionManager := interaction.New(conf.Randomizer, conf.Config, slugManager, forwarderManager, conf.SessionRegistry, conf.User, lifecycleManager.Close) interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
return &session{ return &session{
randomizer: conf.Randomizer, config: config,
config: conf.Config, initialReq: initialReq,
initialReq: conf.InitialReq, sshChan: sshChan,
sshChan: conf.SshChan,
lifecycle: lifecycleManager, lifecycle: lifecycleManager,
interaction: interactionManager, interaction: interactionManager,
forwarder: forwarderManager, forwarder: forwarderManager,
slug: slugManager, slug: slugManager,
registry: conf.SessionRegistry, registry: sessionRegistry,
} }
} }
@@ -97,12 +85,12 @@ func (s *session) Slug() slug.Slug {
func (s *session) Detail() *types.Detail { func (s *session) Detail() *types.Detail {
tunnelTypeMap := map[types.TunnelType]string{ tunnelTypeMap := map[types.TunnelType]string{
types.TunnelTypeHTTP: "HTTP", types.TunnelTypeHTTP: "TunnelTypeHTTP",
types.TunnelTypeTCP: "TCP", types.TunnelTypeTCP: "TunnelTypeTCP",
} }
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()] tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
if !ok { if !ok {
tunnelType = "UNKNOWN" tunnelType = "TunnelTypeUNKNOWN"
} }
return &types.Detail{ return &types.Detail{
@@ -125,7 +113,7 @@ func (s *session) Start() error {
} }
if s.shouldRejectUnauthorized() { if s.shouldRejectUnauthorized() {
return s.denyForwardingRequest(tcpipReq, nil, nil, "headless forwarding only allowed on node mode") return s.denyForwardingRequest(tcpipReq, nil, nil, fmt.Sprintf("headless forwarding only allowed on node mode"))
} }
if err := s.HandleTCPIPForward(tcpipReq); err != nil { if err := s.HandleTCPIPForward(tcpipReq); err != nil {
@@ -172,11 +160,13 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
} }
func (s *session) handleMissingForwardRequest() error { func (s *session) handleMissingForwardRequest() error {
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort())) err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
if err != nil { if err != nil {
return err return err
} }
if err = s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("no forwarding Request") return fmt.Errorf("no forwarding Request")
} }
@@ -192,6 +182,7 @@ func (s *session) waitForSessionEnd() error {
} }
if err := s.lifecycle.Close(); err != nil { if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return err return err
} }
return nil return nil
@@ -236,7 +227,8 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
for req := range GlobalRequest { for req := range GlobalRequest {
switch req.Type { switch req.Type {
case "shell", "pty-req": case "shell", "pty-req":
if err := req.Reply(true, nil); err != nil { err := req.Reply(true, nil)
if err != nil {
return err return err
} }
case "window-change": case "window-change":
@@ -245,7 +237,8 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
} }
default: default:
log.Println("Unknown request type:", req.Type) log.Println("Unknown request type:", req.Type)
if err := req.Reply(false, nil); err != nil { err := req.Reply(false, nil)
if err != nil {
return err return err
} }
} }
@@ -253,24 +246,24 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
return nil return nil
} }
func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) { func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) {
var forwardPayload struct { address, err = readSSHString(payloadReader)
BindAddr string if err != nil {
BindPort uint32 return "", 0, err
} }
if err = ssh.Unmarshal(payload, &forwardPayload); err != nil { var rawPortToBind uint32
return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err) if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil {
return "", 0, err
} }
if forwardPayload.BindPort > 65535 { if rawPortToBind > 65535 {
return "", 0, fmt.Errorf("port is larger than allowed port of 65535") return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
} }
port = uint16(forwardPayload.BindPort) port = uint16(rawPortToBind)
if isBlockedPort(port) { if isBlockedPort(port) {
return "", 0, fmt.Errorf("port is blocked") return "", 0, fmt.Errorf("port is block")
} }
if port == 0 { if port == 0 {
@@ -278,10 +271,10 @@ func (s *session) parseForwardPayload(payload []byte) (address string, port uint
if !ok { if !ok {
return "", 0, fmt.Errorf("no available port") return "", 0, fmt.Errorf("no available port")
} }
return forwardPayload.BindAddr, unassigned, nil return address, unassigned, err
} }
return forwardPayload.BindAddr, port, nil return address, port, err
} }
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error { func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
@@ -289,25 +282,37 @@ func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey,
if key != nil { if key != nil {
s.registry.Remove(*key) s.registry.Remove(*key)
} }
if listener != nil { if listener != nil {
errs = append(errs, listener.Close()) if err := listener.Close(); err != nil {
errs = append(errs, fmt.Errorf("close listener: %w", err))
}
}
if err := req.Reply(false, nil); err != nil {
errs = append(errs, fmt.Errorf("reply request: %w", err))
}
if err := s.lifecycle.Close(); err != nil {
errs = append(errs, fmt.Errorf("close session: %w", err))
} }
errs = append(errs, req.Reply(false, nil))
errs = append(errs, s.lifecycle.Close())
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg)) errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
return errors.Join(errs...) return errors.Join(errs...)
} }
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error { func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) {
replyPayload := struct { buf := new(bytes.Buffer)
BoundPort uint32 err = binary.Write(buf, binary.BigEndian, uint32(port))
}{ if err != nil {
BoundPort: uint32(portToBind), return err
} }
err := req.Reply(true, ssh.Marshal(replyPayload))
err = req.Reply(true, buf.Bytes())
if err != nil {
return err
}
return nil
}
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
err := s.approveForwardingRequest(req, portToBind)
if err != nil { if err != nil {
return err return err
} }
@@ -325,7 +330,9 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
} }
func (s *session) HandleTCPIPForward(req *ssh.Request) error { func (s *session) HandleTCPIPForward(req *ssh.Request) error {
address, port, err := s.parseForwardPayload(req.Payload) reader := bytes.NewReader(req.Payload)
address, port, err := s.parseForwardPayload(reader)
if err != nil { if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error())) return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
} }
@@ -339,7 +346,7 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) error {
} }
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error { func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
randomString, err := s.randomizer.String(20) randomString, err := random.GenerateRandomString(20)
if err != nil { if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err)) return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
} }
@@ -357,13 +364,13 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error { func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed { if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind)) return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
} }
tcpServer := transport.NewTCPServer(portToBind, s.forwarder) tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
listener, err := tcpServer.Listen() listener, err := tcpServer.Listen()
if err != nil { if err != nil {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind)) return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
} }
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP} key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
@@ -386,6 +393,18 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
return nil return nil
} }
func readSSHString(reader io.Reader) (string, error) {
var length uint32
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
return "", err
}
strBytes := make([]byte, length)
if _, err := reader.Read(strBytes); err != nil {
return "", err
}
return string(strBytes), nil
}
func isBlockedPort(port uint16) bool { func isBlockedPort(port uint16) bool {
if port == 80 || port == 443 { if port == 80 || port == 443 {
return false return false
File diff suppressed because it is too large Load Diff
-99
View File
@@ -1,99 +0,0 @@
package slug
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type SlugTestSuite struct {
suite.Suite
slug Slug
}
func (suite *SlugTestSuite) SetupTest() {
suite.slug = New()
}
func TestNew(t *testing.T) {
s := New()
assert.NotNil(t, s, "New() should return a non-nil Slug")
assert.Implements(t, (*Slug)(nil), s, "New() should return a type that implements Slug interface")
assert.Equal(t, "", s.String(), "New() should initialize with empty string")
}
func (suite *SlugTestSuite) TestString() {
assert.Equal(suite.T(), "", suite.slug.String(), "String() should return empty string initially")
suite.slug.Set("test-slug")
assert.Equal(suite.T(), "test-slug", suite.slug.String(), "String() should return the set value")
}
func (suite *SlugTestSuite) TestSet() {
testCases := []struct {
name string
input string
expected string
}{
{
name: "simple slug",
input: "hello-world",
expected: "hello-world",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "slug with numbers",
input: "test-123",
expected: "test-123",
},
{
name: "slug with special characters",
input: "hello_world-123",
expected: "hello_world-123",
},
{
name: "overwrite existing slug",
input: "new-slug",
expected: "new-slug",
},
}
for _, tc := range testCases {
suite.Run(tc.name, func() {
suite.slug.Set(tc.input)
assert.Equal(suite.T(), tc.expected, suite.slug.String())
})
}
}
func (suite *SlugTestSuite) TestMultipleSet() {
suite.slug.Set("first-slug")
assert.Equal(suite.T(), "first-slug", suite.slug.String())
suite.slug.Set("second-slug")
assert.Equal(suite.T(), "second-slug", suite.slug.String())
suite.slug.Set("")
assert.Equal(suite.T(), "", suite.slug.String())
}
func TestSlugIsolation(t *testing.T) {
slug1 := New()
slug2 := New()
slug1.Set("slug-one")
slug2.Set("slug-two")
assert.Equal(t, "slug-one", slug1.String(), "First slug should maintain its value")
assert.Equal(t, "slug-two", slug2.String(), "Second slug should maintain its value")
}
func TestSlugTestSuite(t *testing.T) {
suite.Run(t, new(SlugTestSuite))
}
+1
View File
@@ -0,0 +1 @@
sonar.projectKey=tunnel-please
-1
View File
@@ -7,7 +7,6 @@ type SessionStatus int
const ( const (
SessionStatusINITIALIZING SessionStatus = iota SessionStatusINITIALIZING SessionStatus = iota
SessionStatusRUNNING SessionStatusRUNNING
SessionStatusCLOSED
) )
type InteractiveMode int type InteractiveMode int