Compare commits
63 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5b05723e93 | |||
| 22ad935299 | |||
| ebd915e18e | |||
| 728691d119 | |||
| 1344afd1b2 | |||
| 4cbee5079c | |||
| 0b071dfde7 | |||
| 6062c2e11d | |||
| 2a2d484e91 | |||
| 9377233515 | |||
| fab625e13a | |||
| 1ed845bf2d | |||
| 67378aabda | |||
| a26d1672d9 | |||
| 7f44cc7bc0 | |||
| a3f6baa6ae | |||
| 6def82a095 | |||
| 354da27424 | |||
| ee1dc3c3cd | |||
| 65df01fee5 | |||
| 79fd292a77 | |||
| 4041681be6 | |||
| 2ee24c8d51 | |||
| 384bb98f48 | |||
| 9785a97973 | |||
| b8c6359820 | |||
| 8fee8bf92e | |||
| 04c9ddbc13 | |||
| 211745dc26 | |||
| 09aa92a0ae | |||
| 1ed9f3631f | |||
| bd826d6d06 | |||
| 2f5c44ff01 | |||
| d0e052524c | |||
| 24b9872aa4 | |||
| 8b84373036 | |||
| e796ab5328 | |||
| efdfc4ce95 | |||
| 1dc929cc25 | |||
| 14abac6579 | |||
| 21179da4b5 | |||
| 32f8be2891 | |||
| 5af7af3139 | |||
| f4848e9754 | |||
| d2e508c8ef | |||
| 5499b7d08a | |||
| 58f1fdabe1 | |||
| c1fb588cf4 | |||
| 3029996773 | |||
| 3fd179d32b | |||
| a598a10e94 | |||
| 29cabe42d3 | |||
| e534972abc | |||
| a55ff5f6ab | |||
| 50b4127cb3 | |||
| 7e635721fb | |||
| 016df9caee | |||
| d91eecb2a0 | |||
| 961a905542 | |||
| 634c8321ef | |||
| 9f4c24a3f3 | |||
| 1408b80917 | |||
| 2bc20dd991 |
+40
-80
@@ -2,24 +2,38 @@ name: Docker Build and Push
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- staging
|
||||
tags:
|
||||
- 'v*'
|
||||
paths:
|
||||
- '**.go'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Dockerfile'
|
||||
- 'Dockerfile.*'
|
||||
- '.dockerignore'
|
||||
- '.gitea/workflows/build.yml'
|
||||
|
||||
jobs:
|
||||
build-and-push-branches:
|
||||
test:
|
||||
name: Run Tests
|
||||
runs-on: ubuntu-latest
|
||||
if: github.ref_type == 'branch'
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 'stable'
|
||||
cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Run tests
|
||||
run: go test -v -p 4 ./...
|
||||
|
||||
|
||||
build-and-push:
|
||||
name: Build and Push Docker Image
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -28,64 +42,7 @@ jobs:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: git.fossy.my.id
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Set version variables
|
||||
id: vars
|
||||
run: |
|
||||
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
|
||||
echo "VERSION=dev-main" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "VERSION=dev-staging" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
echo "BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT
|
||||
echo "COMMIT=${{ github.sha }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Build and push Docker image for main
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: |
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:latest
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
VERSION=${{ steps.vars.outputs.VERSION }}
|
||||
BUILD_DATE=${{ steps.vars.outputs.BUILD_DATE }}
|
||||
COMMIT=${{ steps.vars.outputs.COMMIT }}
|
||||
if: github.ref == 'refs/heads/main'
|
||||
|
||||
- name: Build and push Docker image for staging
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: |
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:staging
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
VERSION=${{ steps.vars.outputs.VERSION }}
|
||||
BUILD_DATE=${{ steps.vars.outputs.BUILD_DATE }}
|
||||
COMMIT=${{ steps.vars.outputs.COMMIT }}
|
||||
if: github.ref == 'refs/heads/staging'
|
||||
|
||||
build-and-push-tags:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.ref_type == 'tag' && startsWith(github.ref, 'refs/tags/v')
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
- name: Log in to Docker Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: git.fossy.my.id
|
||||
@@ -103,32 +60,35 @@ jobs:
|
||||
if echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.]+)?$'; then
|
||||
MAJOR=$(echo "$VERSION" | cut -d. -f1)
|
||||
MINOR=$(echo "$VERSION" | cut -d. -f2)
|
||||
|
||||
PATCH=$(echo "$VERSION" | cut -d. -f3 | cut -d- -f1)
|
||||
|
||||
echo "MAJOR=$MAJOR" >> $GITHUB_OUTPUT
|
||||
echo "MINOR=$MINOR" >> $GITHUB_OUTPUT
|
||||
|
||||
echo "PATCH=$PATCH" >> $GITHUB_OUTPUT
|
||||
|
||||
if echo "$VERSION" | grep -q '-'; then
|
||||
PRERELEASE_TAG=$(echo "$VERSION" | cut -d- -f2 | cut -d. -f1)
|
||||
echo "IS_PRERELEASE=true" >> $GITHUB_OUTPUT
|
||||
echo "ADDITIONAL_TAG=staging" >> $GITHUB_OUTPUT
|
||||
echo "PRERELEASE_TAG=$PRERELEASE_TAG" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "IS_PRERELEASE=false" >> $GITHUB_OUTPUT
|
||||
echo "ADDITIONAL_TAG=latest" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
else
|
||||
echo "Invalid version format: $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Build and push Docker image for release
|
||||
- name: Build and push Docker image (release)
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: |
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.VERSION }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:release
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.MAJOR }}.${{ steps.version.outputs.MINOR }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.MAJOR }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:${{ steps.version.outputs.ADDITIONAL_TAG }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:latest
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
VERSION=${{ steps.version.outputs.VERSION }}
|
||||
@@ -136,17 +96,17 @@ jobs:
|
||||
COMMIT=${{ steps.version.outputs.COMMIT }}
|
||||
if: steps.version.outputs.IS_PRERELEASE == 'false'
|
||||
|
||||
- name: Build and push Docker image for pre-release
|
||||
- name: Build and push Docker image (pre-release)
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: |
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.VERSION }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:${{ steps.version.outputs.ADDITIONAL_TAG }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:staging
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
VERSION=${{ steps.version.outputs.VERSION }}
|
||||
BUILD_DATE=${{ steps.version.outputs.BUILD_DATE }}
|
||||
COMMIT=${{ steps.version.outputs.COMMIT }}
|
||||
if: steps.version.outputs.IS_PRERELEASE == 'true'
|
||||
if: steps.version.outputs.IS_PRERELEASE == 'true'
|
||||
@@ -0,0 +1,60 @@
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- staging
|
||||
- 'feat/**'
|
||||
|
||||
name: SonarQube Scan
|
||||
jobs:
|
||||
sonarqube:
|
||||
name: SonarQube Trigger
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checking out
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 'stable'
|
||||
cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: go mod tidy
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./... 2>&1 | tee vet-results.txt
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: |
|
||||
go test ./... -v -p 4 -coverprofile=coverage
|
||||
|
||||
- name: Run GolangCI-Lint Analysis
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
skip-cache: true
|
||||
version: v2.6
|
||||
args: >
|
||||
--issues-exit-code=0
|
||||
--output.text.path=stdout
|
||||
--output.checkstyle.path=golangci-lint-report.xml
|
||||
|
||||
- name: SonarQube Scan
|
||||
uses: SonarSource/sonarqube-scan-action@v7.0.0
|
||||
env:
|
||||
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
|
||||
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
|
||||
with:
|
||||
args: >
|
||||
-Dsonar.projectKey=tunnel-please
|
||||
-Dsonar.go.coverage.reportPaths=coverage
|
||||
-Dsonar.test.inclusions=**/*_test.go
|
||||
-Dsonar.test.exclusions=**/vendor/**
|
||||
-Dsonar.exclusions=**/*_test.go,**/vendor/**,**/golangci-lint-report.xml
|
||||
-Dsonar.go.govet.reportPaths=vet-results.txt
|
||||
-Dsonar.go.golangci-lint.reportPaths=golangci-lint-report.xml
|
||||
-Dsonar.sources=./
|
||||
-Dsonar.tests=./
|
||||
@@ -0,0 +1,36 @@
|
||||
name: Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run Tests
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event_name == 'pull_request' ||
|
||||
(github.event_name == 'issue_comment' &&
|
||||
github.event.issue.pull_request != null &&
|
||||
contains(github.event.comment.body, '/retest'))
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 'stable'
|
||||
cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Run tests
|
||||
run: go test -v -p 4 ./...
|
||||
Vendored
+3
-1
@@ -4,4 +4,6 @@ id_rsa*
|
||||
.env
|
||||
tmp
|
||||
certs
|
||||
app
|
||||
app
|
||||
coverage
|
||||
test-results.json
|
||||
+4
-1
@@ -22,7 +22,10 @@ RUN --mount=type=cache,target=/go/pkg/mod \
|
||||
--mount=type=cache,target=/root/.cache/go-build \
|
||||
CGO_ENABLED=0 GOOS=linux \
|
||||
go build -trimpath \
|
||||
-ldflags="-w -s -X tunnel_pls/version.Version=${VERSION} -X tunnel_pls/version.BuildDate=${BUILD_DATE} -X tunnel_pls/version.Commit=${COMMIT}" \
|
||||
-ldflags="-w -s \
|
||||
-X tunnel_pls/internal/version.Version=${VERSION} \
|
||||
-X tunnel_pls/internal/version.BuildDate=${BUILD_DATE} \
|
||||
-X tunnel_pls/internal/version.Commit=${COMMIT}" \
|
||||
-o /app/tunnel_pls \
|
||||
.
|
||||
|
||||
|
||||
@@ -1,6 +1,22 @@
|
||||
<div align="center">
|
||||
|
||||
<img alt="gopher" title="gopher" src="./docs/images/gopher.png" width="325" />
|
||||
|
||||
# Tunnel Please
|
||||
|
||||
A lightweight SSH-based tunnel server written in Go that enables secure TCP and HTTP forwarding with an interactive terminal interface for managing connections and custom subdomains.
|
||||
A lightweight SSH-based tunnel server
|
||||
|
||||
<br/><br/>
|
||||
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
|
||||
</div>
|
||||
|
||||
## Features
|
||||
|
||||
@@ -17,108 +33,32 @@ A lightweight SSH-based tunnel server written in Go that enables secure TCP and
|
||||
|
||||
The following environment variables can be configured in the `.env` file:
|
||||
|
||||
| Variable | Description | Default | Required |
|
||||
|----------|-------------|---------|----------|
|
||||
| `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
|
||||
| `PORT` | SSH server port | `2200` | No |
|
||||
| `HTTP_PORT` | HTTP server port | `8080` | No |
|
||||
| `HTTPS_PORT` | HTTPS server port | `8443` | No |
|
||||
| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
|
||||
| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
|
||||
| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@<DOMAIN>` | No |
|
||||
| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | - | Yes (if auto-cert) |
|
||||
| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
|
||||
| `CORS_LIST` | Comma-separated list of allowed CORS origins | - | No |
|
||||
| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
|
||||
| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
|
||||
| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
|
||||
| `PPROF_PORT` | Port for pprof server | `6060` | No |
|
||||
| `MODE` | Runtime mode: `standalone` (default, no gRPC/auth) or `node` (enable gRPC + auth) | `standalone` | No |
|
||||
| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No |
|
||||
| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No |
|
||||
| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | - (required in `node`) | Yes (node mode) |
|
||||
| Variable | Description | Default | Required |
|
||||
|---------------------|-----------------------------------------------------------------------------|-------------------------|---------------------|
|
||||
| `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
|
||||
| `PORT` | SSH server port | `2200` | No |
|
||||
| `HTTP_PORT` | HTTP server port | `8080` | No |
|
||||
| `HTTPS_PORT` | HTTPS server port | `8443` | No |
|
||||
| `KEY_LOC` | Path to the private key file | `certs/privkey.pem` | No |
|
||||
| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
|
||||
| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
|
||||
| `TLS_STORAGE_PATH` | Path to store TLS certificates | `certs/tls/` | No |
|
||||
| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@<DOMAIN>` | No |
|
||||
| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | `-` | Yes (if auto-cert) |
|
||||
| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
|
||||
| `CORS_LIST` | Comma-separated list of allowed CORS origins | `-` | No |
|
||||
| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
|
||||
| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
|
||||
| `MAX_HEADER_SIZE` | Maximum size of HTTP headers in bytes (4096-131072) | `4096` | No |
|
||||
| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
|
||||
| `PPROF_PORT` | Port for pprof server | `6060` | No |
|
||||
| `MODE` | Runtime mode: `standalone` or `node` | `standalone` | No |
|
||||
| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No |
|
||||
| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No |
|
||||
| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | `-` | Yes (node mode) |
|
||||
|
||||
**Note:** All environment variables now use UPPERCASE naming. The application includes sensible defaults for all variables, so you can run it without a `.env` file for basic functionality.
|
||||
|
||||
### Automatic TLS Certificate Management
|
||||
|
||||
The server supports automatic TLS certificate generation and renewal using [CertMagic](https://github.com/caddyserver/certmagic) with Cloudflare DNS-01 challenge. This is required for wildcard certificate support (`*.yourdomain.com`).
|
||||
|
||||
**Certificate Storage:**
|
||||
- TLS certificates are stored in `certs/tls/` (relative to application directory)
|
||||
- User-provided certificates: `certs/tls/cert.pem` and `certs/tls/privkey.pem`
|
||||
- CertMagic automatic certificates: `certs/tls/certmagic/`
|
||||
- SSH keys are stored separately in `certs/ssh/`
|
||||
|
||||
**How it works:**
|
||||
1. If user-provided certificates exist at `certs/tls/cert.pem` and `certs/tls/privkey.pem` and cover both `DOMAIN` and `*.DOMAIN`, they will be used
|
||||
2. If certificates are missing, expired, expiring within 30 days, or don't cover the required domains, CertMagic will automatically obtain new certificates from Let's Encrypt
|
||||
3. Certificates are automatically renewed before expiration
|
||||
4. User-provided certificates support hot-reload (changes detected every 30 seconds)
|
||||
|
||||
**Cloudflare API Token Setup:**
|
||||
|
||||
To use automatic certificate generation, you need a Cloudflare API token with the following permissions:
|
||||
|
||||
1. Go to [Cloudflare Dashboard](https://dash.cloudflare.com/profile/api-tokens)
|
||||
2. Click "Create Token"
|
||||
3. Use "Create Custom Token" with these permissions:
|
||||
- **Zone → Zone → Read** (for all zones or specific zone)
|
||||
- **Zone → DNS → Edit** (for all zones or specific zone)
|
||||
4. Copy the token and set it as `CF_API_TOKEN` environment variable
|
||||
|
||||
**Example configuration for automatic certificates:**
|
||||
```env
|
||||
DOMAIN=example.com
|
||||
TLS_ENABLED=true
|
||||
CF_API_TOKEN=your_cloudflare_api_token_here
|
||||
ACME_EMAIL=admin@example.com
|
||||
# ACME_STAGING=true # Uncomment for testing to avoid rate limits
|
||||
```
|
||||
|
||||
### SSH Key Auto-Generation
|
||||
|
||||
The application will automatically generate a new 4096-bit RSA key pair at `certs/ssh/id_rsa` if it doesn't exist. This makes it easier to get started without manually creating SSH keys. SSH keys are stored separately from TLS certificates.
|
||||
|
||||
### Memory Optimization
|
||||
|
||||
The application uses a buffer pool with controlled buffer sizes to prevent excessive memory usage under high concurrent loads. The `BUFFER_SIZE` environment variable controls the size of buffers used for io.Copy operations:
|
||||
|
||||
- **Default:** 32768 bytes (32 KB) - Good balance for most scenarios
|
||||
- **Minimum:** 4096 bytes (4 KB) - Lower memory usage, more CPU overhead
|
||||
- **Maximum:** 1048576 bytes (1 MB) - Higher throughput, more memory usage
|
||||
|
||||
**Recommended settings based on load:**
|
||||
- **Low traffic (<100 concurrent):** `BUFFER_SIZE=32768` (default)
|
||||
- **High traffic (>100 concurrent):** `BUFFER_SIZE=16384` or `BUFFER_SIZE=8192`
|
||||
- **Very high traffic (>1000 concurrent):** `BUFFER_SIZE=8192` or `BUFFER_SIZE=4096`
|
||||
|
||||
The buffer pool reuses buffers across connections, preventing memory fragmentation and reducing garbage collection pressure.
|
||||
|
||||
### Profiling with pprof
|
||||
|
||||
To enable profiling for performance analysis:
|
||||
|
||||
1. Set `PPROF_ENABLED=true` in your `.env` file
|
||||
2. Optionally set `PPROF_PORT` to your desired port (default: 6060)
|
||||
3. Access profiling data at `http://localhost:6060/debug/pprof/`
|
||||
|
||||
Common pprof endpoints:
|
||||
- `/debug/pprof/` - Index page with available profiles
|
||||
- `/debug/pprof/heap` - Memory allocation profile
|
||||
- `/debug/pprof/goroutine` - Stack traces of all current goroutines
|
||||
- `/debug/pprof/profile` - CPU profile (30-second sample by default)
|
||||
- `/debug/pprof/trace` - Execution trace
|
||||
|
||||
Example usage with `go tool pprof`:
|
||||
```bash
|
||||
# Analyze CPU profile
|
||||
go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
|
||||
|
||||
# Analyze memory heap
|
||||
go tool pprof http://localhost:6060/debug/pprof/heap
|
||||
```
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`.
|
||||
@@ -197,22 +137,6 @@ docker-compose -f docker-compose.tcp.yml up -d
|
||||
docker-compose -f docker-compose.root.yml down
|
||||
```
|
||||
|
||||
### Volume Management
|
||||
|
||||
All configurations use a named volume `certs` for persistent storage:
|
||||
- SSH keys: `/app/certs/ssh/`
|
||||
- TLS certificates: `/app/certs/tls/`
|
||||
|
||||
To backup certificates:
|
||||
```bash
|
||||
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar czf /backup/certs-backup.tar.gz -C /data .
|
||||
```
|
||||
|
||||
To restore certificates:
|
||||
```bash
|
||||
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar xzf /backup/certs-backup.tar.gz -C /data
|
||||
```
|
||||
|
||||
### Recommendation
|
||||
|
||||
**Use `docker-compose.root.yml`** for production deployments if you need:
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 2.0 MiB |
@@ -11,6 +11,7 @@ require (
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/libdns/cloudflare v0.2.2
|
||||
github.com/muesli/termenv v0.16.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/crypto v0.47.0
|
||||
google.golang.org/grpc v1.78.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
@@ -27,6 +28,7 @@ require (
|
||||
github.com/clipperhouse/displaywidth v0.6.2 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/libdns/libdns v1.1.1 // indirect
|
||||
@@ -38,8 +40,10 @@ require (
|
||||
github.com/miekg/dns v1.1.69 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/sahilm/fuzzy v0.1.1 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/zeebo/blake3 v0.2.4 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
@@ -52,4 +56,5 @@ require (
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
golang.org/x/tools v0.40.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -32,6 +32,7 @@ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfa
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
@@ -80,8 +81,18 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
|
||||
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
|
||||
@@ -138,5 +149,8 @@ google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
|
||||
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/grpc/client"
|
||||
"tunnel_pls/internal/key"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/internal/transport"
|
||||
"tunnel_pls/internal/version"
|
||||
"tunnel_pls/server"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Bootstrap struct {
|
||||
Randomizer random.Random
|
||||
Config config.Config
|
||||
SessionRegistry registry.Registry
|
||||
Port port.Port
|
||||
GrpcClient client.Client
|
||||
ErrChan chan error
|
||||
SignalChan chan os.Signal
|
||||
}
|
||||
|
||||
func New(config config.Config, port port.Port) (*Bootstrap, error) {
|
||||
randomizer := random.New()
|
||||
sessionRegistry := registry.NewRegistry()
|
||||
|
||||
if err := port.AddRange(config.AllowedPortsStart(), config.AllowedPortsEnd()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
grpcClient, err := client.New(config, sessionRegistry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errChan := make(chan error, 5)
|
||||
signalChan := make(chan os.Signal, 1)
|
||||
|
||||
return &Bootstrap{
|
||||
Randomizer: randomizer,
|
||||
Config: config,
|
||||
SessionRegistry: sessionRegistry,
|
||||
Port: port,
|
||||
GrpcClient: grpcClient,
|
||||
ErrChan: errChan,
|
||||
SignalChan: signalChan,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) {
|
||||
sshCfg := &ssh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
|
||||
}
|
||||
|
||||
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
|
||||
return nil, fmt.Errorf("generate ssh key: %w", err)
|
||||
}
|
||||
privateBytes, err := os.ReadFile(sshKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read private key: %w", err)
|
||||
}
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
sshCfg.AddHostKey(private)
|
||||
return sshCfg, nil
|
||||
}
|
||||
|
||||
func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, errChan chan<- error) error {
|
||||
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer healthCancel()
|
||||
if err := b.GrpcClient.CheckServerHealth(healthCtx); err != nil {
|
||||
return fmt.Errorf("gRPC health check failed: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := b.GrpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
|
||||
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
|
||||
httpserver := transport.NewHTTPServer(conf, registry)
|
||||
ln, err := httpserver.Listen()
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to start http server: %w", err)
|
||||
return
|
||||
}
|
||||
if err = httpserver.Serve(ln); err != nil {
|
||||
errChan <- fmt.Errorf("error when serving http server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
func startHTTPSServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
|
||||
tlsCfg, err := transport.NewTLSConfig(conf)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
|
||||
return
|
||||
}
|
||||
httpsServer := transport.NewHTTPSServer(conf, registry, tlsCfg)
|
||||
ln, err := httpsServer.Listen()
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
|
||||
return
|
||||
}
|
||||
if err = httpsServer.Serve(ln); err != nil {
|
||||
errChan <- fmt.Errorf("error when serving https server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, errChan chan<- error) {
|
||||
sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, conf.SSHPort())
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
sshServer.Start()
|
||||
|
||||
errChan <- sshServer.Close()
|
||||
}
|
||||
|
||||
func startPprof(pprofPort string, errChan chan<- error) {
|
||||
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
|
||||
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
|
||||
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||
errChan <- fmt.Errorf("pprof server error: %v", err)
|
||||
}
|
||||
}
|
||||
func (b *Bootstrap) Run() error {
|
||||
sshConfig, err := newSSHConfig(b.Config.KeyLoc())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create SSH config: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
signal.Notify(b.SignalChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
if b.Config.Mode() == types.ServerModeNODE {
|
||||
err = b.startGRPCClient(ctx, b.Config, b.ErrChan)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start gRPC client: %w", err)
|
||||
}
|
||||
defer func(grpcClient client.Client) {
|
||||
err = grpcClient.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close gRPC client")
|
||||
}
|
||||
}(b.GrpcClient)
|
||||
}
|
||||
|
||||
go startHTTPServer(b.Config, b.SessionRegistry, b.ErrChan)
|
||||
|
||||
if b.Config.TLSEnabled() {
|
||||
go startHTTPSServer(b.Config, b.SessionRegistry, b.ErrChan)
|
||||
}
|
||||
|
||||
go func() {
|
||||
startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, b.GrpcClient, b.Port, b.ErrChan)
|
||||
}()
|
||||
|
||||
if b.Config.PprofEnabled() {
|
||||
go startPprof(b.Config.PprofPort(), b.ErrChan)
|
||||
}
|
||||
|
||||
log.Println("All services started successfully")
|
||||
|
||||
select {
|
||||
case err = <-b.ErrChan:
|
||||
return fmt.Errorf("service error: %w", err)
|
||||
case sig := <-b.SignalChan:
|
||||
log.Printf("Received signal %s, initiating graceful shutdown", sig)
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,558 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||
args := m.Called(user, oldKey, newKey)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||
args := m.Called(key, session)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
args := m.Called(user)
|
||||
return args.Get(0).([]registry.Session)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
type MockRandom struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRandom) String(length int) (string, error) {
|
||||
args := m.Called(length)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
type MockConfig struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockConfig) Domain() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) Mode() types.ServerMode {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return 0
|
||||
}
|
||||
switch v := args.Get(0).(type) {
|
||||
case types.ServerMode:
|
||||
return v
|
||||
case int:
|
||||
return types.ServerMode(v)
|
||||
default:
|
||||
return types.ServerMode(args.Int(0))
|
||||
}
|
||||
}
|
||||
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
||||
|
||||
type MockPort struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
||||
return m.Called(startPort, endPort).Error(0)
|
||||
}
|
||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
||||
args := m.Called()
|
||||
var mPort uint16
|
||||
if args.Get(0) != nil {
|
||||
switch v := args.Get(0).(type) {
|
||||
case int:
|
||||
mPort = uint16(v)
|
||||
case uint16:
|
||||
mPort = v
|
||||
case uint32:
|
||||
mPort = uint16(v)
|
||||
case int32:
|
||||
mPort = uint16(v)
|
||||
case float64:
|
||||
mPort = uint16(v)
|
||||
default:
|
||||
mPort = uint16(args.Int(0))
|
||||
}
|
||||
}
|
||||
return mPort, args.Bool(1)
|
||||
}
|
||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
||||
return m.Called(port, assigned).Error(0)
|
||||
}
|
||||
func (m *MockPort) Claim(port uint16) bool {
|
||||
return m.Called(port).Bool(0)
|
||||
}
|
||||
|
||||
type MockGRPCClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*grpc.ClientConn)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
||||
args := m.Called(ctx, token)
|
||||
return args.Bool(0), args.String(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
|
||||
args := m.Called(ctx, domain, token)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupConfig func() config.Config
|
||||
setupPort func() port.Port
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Success New with default value",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Error when AddRange fails",
|
||||
setupPort: func() port.Port {
|
||||
mockPort := &MockPort{}
|
||||
mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range"))
|
||||
return mockPort
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "invalid port range",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var mockPort port.Port
|
||||
if tt.setupPort != nil {
|
||||
mockPort = tt.setupPort()
|
||||
} else {
|
||||
mockPort = port.New()
|
||||
}
|
||||
|
||||
var mockConfig config.Config
|
||||
if tt.setupConfig != nil {
|
||||
mockConfig = tt.setupConfig()
|
||||
} else {
|
||||
var err error
|
||||
mockConfig, err = config.MustLoad()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
bootstrap, err := New(mockConfig, mockPort)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, bootstrap)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, bootstrap)
|
||||
assert.NotNil(t, bootstrap.Randomizer)
|
||||
assert.NotNil(t, bootstrap.SessionRegistry)
|
||||
assert.NotNil(t, bootstrap.Config)
|
||||
assert.NotNil(t, bootstrap.Port)
|
||||
assert.NotNil(t, bootstrap.ErrChan)
|
||||
assert.NotNil(t, bootstrap.SignalChan)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func randomAvailablePort() (string, error) {
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func(listener net.Listener) {
|
||||
_ = listener.Close()
|
||||
}(listener)
|
||||
|
||||
mPort := listener.Addr().(*net.TCPAddr).Port
|
||||
return strconv.Itoa(mPort), nil
|
||||
}
|
||||
|
||||
func TestRun(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockErrChan := make(chan error, 1)
|
||||
mockSignalChan := make(chan os.Signal, 1)
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
keyLoc := filepath.Join(tmpDir, "key.key")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupConfig func() *MockConfig
|
||||
setupGrpcClient func() *MockGRPCClient
|
||||
needCerts bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful run and termination",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "error from SSH server invalid port",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("invalid")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "error from HTTP server invalid port",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("invalid")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "error from HTTPS server invalid port",
|
||||
setupConfig: func() *MockConfig {
|
||||
tempDir := os.TempDir()
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("invalid")
|
||||
mockConfig.On("TLSEnabled").Return(true)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("TLSStoragePath").Return(tempDir)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "grpc health check failed",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("invalid")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
setupGrpcClient: func() *MockGRPCClient {
|
||||
mockGRPCClient := &MockGRPCClient{}
|
||||
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed"))
|
||||
return mockGRPCClient
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "successful run with pprof enabled",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
pprofPort, _ := randomAvailablePort()
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(true)
|
||||
mockConfig.On("PprofPort").Return(pprofPort)
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
expectError: false,
|
||||
}, {
|
||||
name: "successful run in NODE mode with signal",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
setupGrpcClient: func() *MockGRPCClient {
|
||||
mockGRPCClient := &MockGRPCClient{}
|
||||
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
|
||||
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
mockGRPCClient.On("Close").Return(nil)
|
||||
return mockGRPCClient
|
||||
},
|
||||
expectError: false,
|
||||
}, {
|
||||
name: "successful run in NODE mode with signal buf error when closing",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
setupGrpcClient: func() *MockGRPCClient {
|
||||
mockGRPCClient := &MockGRPCClient{}
|
||||
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
|
||||
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy"))
|
||||
return mockGRPCClient
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockConfig := tt.setupConfig()
|
||||
mockGRPCClient := &MockGRPCClient{}
|
||||
bootstrap := &Bootstrap{
|
||||
Randomizer: mockRandom,
|
||||
Config: mockConfig,
|
||||
SessionRegistry: mockSessionRegistry,
|
||||
Port: mockPort,
|
||||
ErrChan: mockErrChan,
|
||||
SignalChan: mockSignalChan,
|
||||
GrpcClient: mockGRPCClient,
|
||||
}
|
||||
|
||||
if tt.setupGrpcClient != nil {
|
||||
bootstrap.GrpcClient = tt.setupGrpcClient()
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- bootstrap.Run()
|
||||
}()
|
||||
|
||||
if tt.expectError {
|
||||
err := <-done
|
||||
assert.Error(t, err)
|
||||
} else if tt.name == "successful run with pprof enabled" {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
fmt.Println(mockConfig.PprofPort())
|
||||
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
err = resp.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
mockSignalChan <- os.Interrupt
|
||||
err = <-done
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
time.Sleep(time.Second)
|
||||
mockSignalChan <- os.Interrupt
|
||||
err := <-done
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+60
-23
@@ -1,33 +1,70 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
import "tunnel_pls/types"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
type Config interface {
|
||||
Domain() string
|
||||
SSHPort() string
|
||||
|
||||
func Load() error {
|
||||
if _, err := os.Stat(".env"); err == nil {
|
||||
return godotenv.Load(".env")
|
||||
}
|
||||
return nil
|
||||
HTTPPort() string
|
||||
HTTPSPort() string
|
||||
|
||||
KeyLoc() string
|
||||
|
||||
TLSEnabled() bool
|
||||
TLSRedirect() bool
|
||||
TLSStoragePath() string
|
||||
|
||||
ACMEEmail() string
|
||||
CFAPIToken() string
|
||||
ACMEStaging() bool
|
||||
|
||||
AllowedPortsStart() uint16
|
||||
AllowedPortsEnd() uint16
|
||||
|
||||
BufferSize() int
|
||||
HeaderSize() int
|
||||
|
||||
PprofEnabled() bool
|
||||
PprofPort() string
|
||||
|
||||
Mode() types.ServerMode
|
||||
GRPCAddress() string
|
||||
GRPCPort() string
|
||||
NodeToken() string
|
||||
}
|
||||
|
||||
func Getenv(key, defaultValue string) string {
|
||||
val := os.Getenv(key)
|
||||
if val == "" {
|
||||
val = defaultValue
|
||||
func MustLoad() (Config, error) {
|
||||
if err := loadEnvFile(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val
|
||||
cfg, err := parse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func GetBufferSize() int {
|
||||
sizeStr := Getenv("BUFFER_SIZE", "32768")
|
||||
size, err := strconv.Atoi(sizeStr)
|
||||
if err != nil || size < 4096 || size > 1048576 {
|
||||
return 32768
|
||||
}
|
||||
return size
|
||||
}
|
||||
func (c *config) Domain() string { return c.domain }
|
||||
func (c *config) SSHPort() string { return c.sshPort }
|
||||
func (c *config) HTTPPort() string { return c.httpPort }
|
||||
func (c *config) HTTPSPort() string { return c.httpsPort }
|
||||
func (c *config) KeyLoc() string { return c.keyLoc }
|
||||
func (c *config) TLSEnabled() bool { return c.tlsEnabled }
|
||||
func (c *config) TLSRedirect() bool { return c.tlsRedirect }
|
||||
func (c *config) TLSStoragePath() string { return c.tlsStoragePath }
|
||||
func (c *config) ACMEEmail() string { return c.acmeEmail }
|
||||
func (c *config) CFAPIToken() string { return c.cfAPIToken }
|
||||
func (c *config) ACMEStaging() bool { return c.acmeStaging }
|
||||
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
|
||||
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
|
||||
func (c *config) BufferSize() int { return c.bufferSize }
|
||||
func (c *config) HeaderSize() int { return c.headerSize }
|
||||
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
|
||||
func (c *config) PprofPort() string { return c.pprofPort }
|
||||
func (c *config) Mode() types.ServerMode { return c.mode }
|
||||
func (c *config) GRPCAddress() string { return c.grpcAddress }
|
||||
func (c *config) GRPCPort() string { return c.grpcPort }
|
||||
func (c *config) NodeToken() string { return c.nodeToken }
|
||||
|
||||
@@ -0,0 +1,405 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetenv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
val string
|
||||
def string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "returns existing env",
|
||||
key: "TEST_ENV_EXIST",
|
||||
val: "value",
|
||||
def: "default",
|
||||
expected: "value",
|
||||
},
|
||||
{
|
||||
name: "returns default when env missing",
|
||||
key: "TEST_ENV_MISSING",
|
||||
val: "",
|
||||
def: "default",
|
||||
expected: "default",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv(tt.key, tt.val)
|
||||
} else {
|
||||
err := os.Unsetenv(tt.key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetenvBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
val string
|
||||
def bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "returns true when env is true",
|
||||
key: "TEST_BOOL_TRUE",
|
||||
val: "true",
|
||||
def: false,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "returns false when env is false",
|
||||
key: "TEST_BOOL_FALSE",
|
||||
val: "false",
|
||||
def: true,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "returns default when env missing",
|
||||
key: "TEST_BOOL_MISSING",
|
||||
val: "",
|
||||
def: true,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "returns false when env is not true",
|
||||
key: "TEST_BOOL_INVALID",
|
||||
val: "yes",
|
||||
def: true,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv(tt.key, tt.val)
|
||||
} else {
|
||||
err := os.Unsetenv(tt.key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
expect types.ServerMode
|
||||
expectErr bool
|
||||
}{
|
||||
{"standalone", "standalone", types.ServerModeSTANDALONE, false},
|
||||
{"node", "node", types.ServerModeNODE, false},
|
||||
{"uppercase", "STANDALONE", types.ServerModeSTANDALONE, false},
|
||||
{"invalid", "invalid", 0, true},
|
||||
{"empty (default)", "", types.ServerModeSTANDALONE, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.mode != "" {
|
||||
t.Setenv("MODE", tt.mode)
|
||||
} else {
|
||||
err := os.Unsetenv("MODE")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
mode, err := parseMode()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expect, mode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAllowedPorts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
start uint16
|
||||
end uint16
|
||||
expectErr bool
|
||||
}{
|
||||
{"valid range", "1000-2000", 1000, 2000, false},
|
||||
{"empty", "", 0, 0, false},
|
||||
{"invalid format - no dash", "1000", 0, 0, true},
|
||||
{"invalid format - too many dashes", "1000-2000-3000", 0, 0, true},
|
||||
{"invalid start port", "abc-2000", 0, 0, true},
|
||||
{"invalid end port", "1000-abc", 0, 0, true},
|
||||
{"out of range start", "70000-80000", 0, 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv("ALLOWED_PORTS", tt.val)
|
||||
} else {
|
||||
err := os.Unsetenv("ALLOWED_PORTS")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
start, end, err := parseAllowedPorts()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.start, start)
|
||||
assert.Equal(t, tt.end, end)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBufferSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
expect int
|
||||
}{
|
||||
{"valid size", "8192", 8192},
|
||||
{"default size", "", 32768},
|
||||
{"too small", "1024", 4096},
|
||||
{"too large", "2000000", 4096},
|
||||
{"invalid format", "abc", 4096},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv("BUFFER_SIZE", tt.val)
|
||||
} else {
|
||||
err := os.Unsetenv("BUFFER_SIZE")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
size := parseBufferSize()
|
||||
assert.Equal(t, tt.expect, size)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHeaderSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
expect int
|
||||
}{
|
||||
{"valid size", "8192", 8192},
|
||||
{"default size", "", 4096},
|
||||
{"too small", "1024", 4096},
|
||||
{"too large", "2000000", 4096},
|
||||
{"invalid format", "abc", 4096},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv("MAX_HEADER_SIZE", tt.val)
|
||||
} else {
|
||||
err := os.Unsetenv("MAX_HEADER_SIZE")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
size := parseHeaderSize()
|
||||
assert.Equal(t, tt.expect, size)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envs map[string]string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "minimal valid config",
|
||||
envs: map[string]string{
|
||||
"DOMAIN": "example.com",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "TLS enabled without token",
|
||||
envs: map[string]string{
|
||||
"TLS_ENABLED": "true",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "TLS enabled with token",
|
||||
envs: map[string]string{
|
||||
"TLS_ENABLED": "true",
|
||||
"CF_API_TOKEN": "secret",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Node mode without token",
|
||||
envs: map[string]string{
|
||||
"MODE": "node",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Node mode with token",
|
||||
envs: map[string]string{
|
||||
"MODE": "node",
|
||||
"NODE_TOKEN": "token",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid mode",
|
||||
envs: map[string]string{
|
||||
"MODE": "invalid",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid allowed ports",
|
||||
envs: map[string]string{
|
||||
"ALLOWED_PORTS": "1000",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
os.Clearenv()
|
||||
for k, v := range tt.envs {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
cfg, err := parse()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, cfg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetters(t *testing.T) {
|
||||
envs := map[string]string{
|
||||
"DOMAIN": "example.com",
|
||||
"PORT": "2222",
|
||||
"HTTP_PORT": "80",
|
||||
"HTTPS_PORT": "443",
|
||||
"KEY_LOC": "certs/ssh/id_rsa",
|
||||
"TLS_ENABLED": "true",
|
||||
"TLS_REDIRECT": "true",
|
||||
"TLS_STORAGE_PATH": "certs/tls/",
|
||||
"ACME_EMAIL": "test@example.com",
|
||||
"CF_API_TOKEN": "token",
|
||||
"ACME_STAGING": "true",
|
||||
"ALLOWED_PORTS": "1000-2000",
|
||||
"BUFFER_SIZE": "16384",
|
||||
"MAX_HEADER_SIZE": "4096",
|
||||
"PPROF_ENABLED": "true",
|
||||
"PPROF_PORT": "7070",
|
||||
"MODE": "standalone",
|
||||
"GRPC_ADDRESS": "127.0.0.1",
|
||||
"GRPC_PORT": "9090",
|
||||
"NODE_TOKEN": "ntoken",
|
||||
}
|
||||
|
||||
os.Clearenv()
|
||||
for k, v := range envs {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
cfg, err := parse()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "example.com", cfg.Domain())
|
||||
assert.Equal(t, "2222", cfg.SSHPort())
|
||||
assert.Equal(t, "80", cfg.HTTPPort())
|
||||
assert.Equal(t, "443", cfg.HTTPSPort())
|
||||
assert.Equal(t, "certs/ssh/id_rsa", cfg.KeyLoc())
|
||||
assert.Equal(t, true, cfg.TLSEnabled())
|
||||
assert.Equal(t, true, cfg.TLSRedirect())
|
||||
assert.Equal(t, "certs/tls/", cfg.TLSStoragePath())
|
||||
assert.Equal(t, "test@example.com", cfg.ACMEEmail())
|
||||
assert.Equal(t, "token", cfg.CFAPIToken())
|
||||
assert.Equal(t, true, cfg.ACMEStaging())
|
||||
assert.Equal(t, uint16(1000), cfg.AllowedPortsStart())
|
||||
assert.Equal(t, uint16(2000), cfg.AllowedPortsEnd())
|
||||
assert.Equal(t, 16384, cfg.BufferSize())
|
||||
assert.Equal(t, 4096, cfg.HeaderSize())
|
||||
assert.Equal(t, true, cfg.PprofEnabled())
|
||||
assert.Equal(t, "7070", cfg.PprofPort())
|
||||
assert.Equal(t, types.ServerMode(types.ServerModeSTANDALONE), cfg.Mode())
|
||||
assert.Equal(t, "127.0.0.1", cfg.GRPCAddress())
|
||||
assert.Equal(t, "9090", cfg.GRPCPort())
|
||||
assert.Equal(t, "ntoken", cfg.NodeToken())
|
||||
}
|
||||
|
||||
func TestMustLoad(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
os.Clearenv()
|
||||
t.Setenv("DOMAIN", "example.com")
|
||||
cfg, err := MustLoad()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg)
|
||||
})
|
||||
|
||||
t.Run("loadEnvFile error", func(t *testing.T) {
|
||||
err := os.Mkdir(".env", 0755)
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
err = os.Remove(".env")
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
cfg, err := MustLoad()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, cfg)
|
||||
})
|
||||
|
||||
t.Run("parse error", func(t *testing.T) {
|
||||
os.Clearenv()
|
||||
t.Setenv("MODE", "invalid")
|
||||
cfg, err := MustLoad()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, cfg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadEnvFile(t *testing.T) {
|
||||
t.Run("file exists", func(t *testing.T) {
|
||||
err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
err = os.Remove(".env")
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
err = loadEnvFile()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "true", os.Getenv("TEST_ENV_FILE"))
|
||||
})
|
||||
|
||||
t.Run("file missing", func(t *testing.T) {
|
||||
_ = os.Remove(".env")
|
||||
err := loadEnvFile()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
domain string
|
||||
sshPort string
|
||||
|
||||
httpPort string
|
||||
httpsPort string
|
||||
|
||||
keyLoc string
|
||||
|
||||
tlsEnabled bool
|
||||
tlsRedirect bool
|
||||
tlsStoragePath string
|
||||
acmeEmail string
|
||||
cfAPIToken string
|
||||
acmeStaging bool
|
||||
|
||||
allowedPortsStart uint16
|
||||
allowedPortsEnd uint16
|
||||
|
||||
bufferSize int
|
||||
headerSize int
|
||||
|
||||
pprofEnabled bool
|
||||
pprofPort string
|
||||
|
||||
mode types.ServerMode
|
||||
grpcAddress string
|
||||
grpcPort string
|
||||
nodeToken string
|
||||
}
|
||||
|
||||
func parse() (*config, error) {
|
||||
mode, err := parseMode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
domain := getenv("DOMAIN", "localhost")
|
||||
sshPort := getenv("PORT", "2200")
|
||||
|
||||
httpPort := getenv("HTTP_PORT", "8080")
|
||||
httpsPort := getenv("HTTPS_PORT", "8443")
|
||||
|
||||
keyLoc := getenv("KEY_LOC", "certs/privkey.pem")
|
||||
|
||||
tlsEnabled := getenvBool("TLS_ENABLED", false)
|
||||
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
|
||||
tlsStoragePath := getenv("TLS_STORAGE_PATH", "certs/tls/")
|
||||
|
||||
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
|
||||
acmeStaging := getenvBool("ACME_STAGING", false)
|
||||
|
||||
cfToken := getenv("CF_API_TOKEN", "")
|
||||
if tlsEnabled && cfToken == "" {
|
||||
return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled")
|
||||
}
|
||||
|
||||
start, end, err := parseAllowedPorts()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bufferSize := parseBufferSize()
|
||||
headerSize := parseHeaderSize()
|
||||
|
||||
pprofEnabled := getenvBool("PPROF_ENABLED", false)
|
||||
pprofPort := getenv("PPROF_PORT", "6060")
|
||||
|
||||
grpcHost := getenv("GRPC_ADDRESS", "localhost")
|
||||
grpcPort := getenv("GRPC_PORT", "8080")
|
||||
|
||||
nodeToken := getenv("NODE_TOKEN", "")
|
||||
if mode == types.ServerModeNODE && nodeToken == "" {
|
||||
return nil, fmt.Errorf("NODE_TOKEN is required in node mode")
|
||||
}
|
||||
|
||||
return &config{
|
||||
domain: domain,
|
||||
sshPort: sshPort,
|
||||
httpPort: httpPort,
|
||||
httpsPort: httpsPort,
|
||||
keyLoc: keyLoc,
|
||||
tlsEnabled: tlsEnabled,
|
||||
tlsRedirect: tlsRedirect,
|
||||
tlsStoragePath: tlsStoragePath,
|
||||
acmeEmail: acmeEmail,
|
||||
cfAPIToken: cfToken,
|
||||
acmeStaging: acmeStaging,
|
||||
allowedPortsStart: start,
|
||||
allowedPortsEnd: end,
|
||||
bufferSize: bufferSize,
|
||||
headerSize: headerSize,
|
||||
pprofEnabled: pprofEnabled,
|
||||
pprofPort: pprofPort,
|
||||
mode: mode,
|
||||
grpcAddress: grpcHost,
|
||||
grpcPort: grpcPort,
|
||||
nodeToken: nodeToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func loadEnvFile() error {
|
||||
if _, err := os.Stat(".env"); err == nil {
|
||||
return godotenv.Load(".env")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseMode() (types.ServerMode, error) {
|
||||
switch strings.ToLower(getenv("MODE", "standalone")) {
|
||||
case "standalone":
|
||||
return types.ServerModeSTANDALONE, nil
|
||||
case "node":
|
||||
return types.ServerModeNODE, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid MODE value")
|
||||
}
|
||||
}
|
||||
|
||||
func parseAllowedPorts() (uint16, uint16, error) {
|
||||
raw := getenv("ALLOWED_PORTS", "")
|
||||
if raw == "" {
|
||||
return 0, 0, nil
|
||||
}
|
||||
|
||||
parts := strings.Split(raw, "-")
|
||||
if len(parts) != 2 {
|
||||
return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format")
|
||||
}
|
||||
|
||||
start, err := strconv.ParseUint(parts[0], 10, 16)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
end, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return uint16(start), uint16(end), nil
|
||||
}
|
||||
|
||||
func parseBufferSize() int {
|
||||
raw := getenv("BUFFER_SIZE", "32768")
|
||||
size, err := strconv.Atoi(raw)
|
||||
if err != nil || size < 4096 || size > 1048576 {
|
||||
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
|
||||
return 4096
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func parseHeaderSize() int {
|
||||
raw := getenv("MAX_HEADER_SIZE", "4096")
|
||||
size, err := strconv.Atoi(raw)
|
||||
if err != nil || size < 4096 || size > 131072 {
|
||||
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
|
||||
return 4096
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func getenv(key, def string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func getenvBool(key string, def bool) bool {
|
||||
val := os.Getenv(key)
|
||||
if val == "" {
|
||||
return def
|
||||
}
|
||||
return val == "true"
|
||||
}
|
||||
+139
-127
@@ -29,6 +29,7 @@ type Client interface {
|
||||
CheckServerHealth(ctx context.Context) error
|
||||
}
|
||||
type client struct {
|
||||
config config.Config
|
||||
conn *grpc.ClientConn
|
||||
address string
|
||||
sessionRegistry registry.Registry
|
||||
@@ -37,7 +38,15 @@ type client struct {
|
||||
closing bool
|
||||
}
|
||||
|
||||
func New(address string, sessionRegistry registry.Registry) (Client, error) {
|
||||
var (
|
||||
grpcNewClient = grpc.NewClient
|
||||
healthNewHealthClient = grpc_health_v1.NewHealthClient
|
||||
initialBackoff = time.Second
|
||||
)
|
||||
|
||||
func New(config config.Config, sessionRegistry registry.Registry) (Client, error) {
|
||||
address := fmt.Sprintf("%s:%s", config.GRPCAddress(), config.GRPCPort())
|
||||
|
||||
var opts []grpc.DialOption
|
||||
|
||||
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
@@ -57,7 +66,7 @@ func New(address string, sessionRegistry registry.Registry) (Client, error) {
|
||||
),
|
||||
)
|
||||
|
||||
conn, err := grpc.NewClient(address, opts...)
|
||||
conn, err := grpcNewClient(address, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err)
|
||||
}
|
||||
@@ -66,6 +75,7 @@ func New(address string, sessionRegistry registry.Registry) (Client, error) {
|
||||
authorizeConnectionService := proto.NewUserServiceClient(conn)
|
||||
|
||||
return &client{
|
||||
config: config,
|
||||
conn: conn,
|
||||
address: address,
|
||||
sessionRegistry: sessionRegistry,
|
||||
@@ -75,85 +85,100 @@ func New(address string, sessionRegistry registry.Registry) (Client, error) {
|
||||
}
|
||||
|
||||
func (c *client) SubscribeEvents(ctx context.Context, identity, authToken string) error {
|
||||
const (
|
||||
baseBackoff = time.Second
|
||||
maxBackoff = 30 * time.Second
|
||||
)
|
||||
|
||||
backoff := baseBackoff
|
||||
wait := func() error {
|
||||
if backoff <= 0 {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-time.After(backoff):
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
growBackoff := func() {
|
||||
backoff *= 2
|
||||
if backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
}
|
||||
backoff := initialBackoff
|
||||
|
||||
for {
|
||||
subscribe, err := c.eventService.Subscribe(ctx)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
||||
return err
|
||||
}
|
||||
if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated {
|
||||
return err
|
||||
}
|
||||
if err = wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
growBackoff()
|
||||
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
|
||||
continue
|
||||
if err := c.subscribeAndProcess(ctx, identity, authToken, &backoff); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = subscribe.Send(&proto.Node{
|
||||
Type: proto.EventType_AUTHENTICATION,
|
||||
Payload: &proto.Node_AuthEvent{
|
||||
AuthEvent: &proto.Authentication{
|
||||
Identity: identity,
|
||||
AuthToken: authToken,
|
||||
},
|
||||
func (c *client) subscribeAndProcess(ctx context.Context, identity, authToken string, backoff *time.Duration) error {
|
||||
subscribe, err := c.eventService.Subscribe(ctx)
|
||||
if err != nil {
|
||||
return c.handleSubscribeError(ctx, err, backoff)
|
||||
}
|
||||
|
||||
err = subscribe.Send(&proto.Node{
|
||||
Type: proto.EventType_AUTHENTICATION,
|
||||
Payload: &proto.Node_AuthEvent{
|
||||
AuthEvent: &proto.Authentication{
|
||||
Identity: identity,
|
||||
AuthToken: authToken,
|
||||
},
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Println("Authentication failed to send to gRPC server:", err)
|
||||
if c.isConnectionError(err) {
|
||||
if err = wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
growBackoff()
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
log.Println("Authentication Successfully sent to gRPC server")
|
||||
backoff = baseBackoff
|
||||
if err != nil {
|
||||
return c.handleAuthError(ctx, err, backoff)
|
||||
}
|
||||
|
||||
if err = c.processEventStream(subscribe); err != nil {
|
||||
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
||||
return err
|
||||
}
|
||||
if c.isConnectionError(err) {
|
||||
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
|
||||
if err = wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
growBackoff()
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
log.Println("Authentication Successfully sent to gRPC server")
|
||||
*backoff = time.Second
|
||||
|
||||
return c.handleStreamError(ctx, c.processEventStream(subscribe), backoff)
|
||||
}
|
||||
|
||||
func (c *client) handleSubscribeError(ctx context.Context, err error, backoff *time.Duration) error {
|
||||
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
||||
return err
|
||||
}
|
||||
if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated {
|
||||
return err
|
||||
}
|
||||
if err = c.wait(ctx, *backoff); err != nil {
|
||||
return err
|
||||
}
|
||||
c.growBackoff(backoff)
|
||||
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleAuthError(ctx context.Context, err error, backoff *time.Duration) error {
|
||||
log.Println("Authentication failed to send to gRPC server:", err)
|
||||
if !c.isConnectionError(err) {
|
||||
return err
|
||||
}
|
||||
if err := c.wait(ctx, *backoff); err != nil {
|
||||
return err
|
||||
}
|
||||
c.growBackoff(backoff)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleStreamError(ctx context.Context, err error, backoff *time.Duration) error {
|
||||
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
||||
return err
|
||||
}
|
||||
if !c.isConnectionError(err) {
|
||||
return err
|
||||
}
|
||||
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
|
||||
if err := c.wait(ctx, *backoff); err != nil {
|
||||
return err
|
||||
}
|
||||
c.growBackoff(backoff)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) wait(ctx context.Context, duration time.Duration) error {
|
||||
if duration <= 0 {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-time.After(duration):
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) growBackoff(backoff *time.Duration) {
|
||||
const maxBackoff = 30 * time.Second
|
||||
*backoff *= 2
|
||||
if *backoff > maxBackoff {
|
||||
*backoff = maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,35 +214,20 @@ func (c *client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, pr
|
||||
func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
|
||||
slugEvent := evt.GetSlugEvent()
|
||||
user := slugEvent.GetUser()
|
||||
oldSlug := slugEvent.GetOld()
|
||||
newSlug := slugEvent.GetNew()
|
||||
oldKey := types.SessionKey{Id: slugEvent.GetOld(), Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: slugEvent.GetNew(), Type: types.TunnelTypeHTTP}
|
||||
|
||||
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP})
|
||||
userSession, err := c.sessionRegistry.Get(oldKey)
|
||||
if err != nil {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
||||
Payload: &proto.Node_SlugEventResponse{
|
||||
SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()},
|
||||
},
|
||||
}, "slug change failure response")
|
||||
return c.sendSlugChangeResponse(subscribe, false, err.Error())
|
||||
}
|
||||
|
||||
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
||||
Payload: &proto.Node_SlugEventResponse{
|
||||
SlugEventResponse: &proto.SlugChangeEventResponse{Success: false, Message: err.Error()},
|
||||
},
|
||||
}, "slug change failure response")
|
||||
if err = c.sessionRegistry.Update(user, oldKey, newKey); err != nil {
|
||||
return c.sendSlugChangeResponse(subscribe, false, err.Error())
|
||||
}
|
||||
|
||||
userSession.Interaction().Redraw()
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
||||
Payload: &proto.Node_SlugEventResponse{
|
||||
SlugEventResponse: &proto.SlugChangeEventResponse{Success: true, Message: ""},
|
||||
},
|
||||
}, "slug change success response")
|
||||
return c.sendSlugChangeResponse(subscribe, true, "")
|
||||
}
|
||||
|
||||
func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
|
||||
@@ -227,7 +237,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node
|
||||
for _, ses := range sessions {
|
||||
detail := ses.Detail()
|
||||
details = append(details, &proto.Detail{
|
||||
Node: config.Getenv("DOMAIN", "localhost"),
|
||||
Node: c.config.Domain(),
|
||||
ForwardingType: detail.ForwardingType,
|
||||
Slug: detail.Slug,
|
||||
UserId: detail.UserID,
|
||||
@@ -236,12 +246,7 @@ func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node
|
||||
})
|
||||
}
|
||||
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_GET_SESSIONS,
|
||||
Payload: &proto.Node_GetSessionsEvent{
|
||||
GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
|
||||
},
|
||||
}, "send get sessions response")
|
||||
return c.sendGetSessionsResponse(subscribe, details)
|
||||
}
|
||||
|
||||
func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
|
||||
@@ -251,39 +256,46 @@ func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto
|
||||
|
||||
tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType())
|
||||
if err != nil {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_TERMINATE_SESSION,
|
||||
Payload: &proto.Node_TerminateSessionEventResponse{
|
||||
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
|
||||
},
|
||||
}, "terminate session invalid tunnel type")
|
||||
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
|
||||
}
|
||||
|
||||
userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType})
|
||||
if err != nil {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_TERMINATE_SESSION,
|
||||
Payload: &proto.Node_TerminateSessionEventResponse{
|
||||
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
|
||||
},
|
||||
}, "terminate session fetch failed")
|
||||
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
|
||||
}
|
||||
|
||||
if err = userSession.Lifecycle().Close(); err != nil {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_TERMINATE_SESSION,
|
||||
Payload: &proto.Node_TerminateSessionEventResponse{
|
||||
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: false, Message: err.Error()},
|
||||
},
|
||||
}, "terminate session close failed")
|
||||
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
|
||||
}
|
||||
|
||||
return c.sendTerminateSessionResponse(subscribe, true, "")
|
||||
}
|
||||
|
||||
func (c *client) sendSlugChangeResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
||||
Payload: &proto.Node_SlugEventResponse{
|
||||
SlugEventResponse: &proto.SlugChangeEventResponse{Success: success, Message: message},
|
||||
},
|
||||
}, "slug change response")
|
||||
}
|
||||
|
||||
func (c *client) sendGetSessionsResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], details []*proto.Detail) error {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_GET_SESSIONS,
|
||||
Payload: &proto.Node_GetSessionsEvent{
|
||||
GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
|
||||
},
|
||||
}, "send get sessions response")
|
||||
}
|
||||
|
||||
func (c *client) sendTerminateSessionResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_TERMINATE_SESSION,
|
||||
Payload: &proto.Node_TerminateSessionEventResponse{
|
||||
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: true, Message: ""},
|
||||
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: success, Message: message},
|
||||
},
|
||||
}, "terminate session success response")
|
||||
}, "terminate session response")
|
||||
}
|
||||
|
||||
func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error {
|
||||
@@ -299,11 +311,11 @@ func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E
|
||||
func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
|
||||
switch t {
|
||||
case proto.TunnelType_HTTP:
|
||||
return types.HTTP, nil
|
||||
return types.TunnelTypeHTTP, nil
|
||||
case proto.TunnelType_TCP:
|
||||
return types.TCP, nil
|
||||
return types.TunnelTypeTCP, nil
|
||||
default:
|
||||
return types.UNKNOWN, fmt.Errorf("unknown tunnel type received")
|
||||
return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,7 +336,7 @@ func (c *client) AuthorizeConn(ctx context.Context, token string) (authorized bo
|
||||
}
|
||||
|
||||
func (c *client) CheckServerHealth(ctx context.Context) error {
|
||||
healthClient := grpc_health_v1.NewHealthClient(c.ClientConn())
|
||||
healthClient := healthNewHealthClient(c.ClientConn())
|
||||
resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{
|
||||
Service: "",
|
||||
})
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -17,9 +17,9 @@ type RequestHeader interface {
|
||||
Set(key string, value string)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
GetMethod() string
|
||||
GetPath() string
|
||||
GetVersion() string
|
||||
Method() string
|
||||
Path() string
|
||||
Version() string
|
||||
}
|
||||
type requestHeader struct {
|
||||
method string
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
package header
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expectErr bool
|
||||
errContains string
|
||||
expectMethod string
|
||||
expectPath string
|
||||
expectVersion string
|
||||
expectHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
data: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\nX-Custom: value\r\n\r\n"),
|
||||
expectErr: false,
|
||||
expectMethod: "GET",
|
||||
expectPath: "/path",
|
||||
expectVersion: "HTTP/1.1",
|
||||
expectHeaders: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Custom": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no CRLF in start line",
|
||||
data: []byte("GET /path HTTP/1.1"),
|
||||
expectErr: true,
|
||||
errContains: "no CRLF found in start line",
|
||||
},
|
||||
{
|
||||
name: "invalid start line - missing method",
|
||||
data: []byte("INVALID\r\n\r\n"),
|
||||
expectErr: true,
|
||||
errContains: "invalid start line: missing method",
|
||||
},
|
||||
{
|
||||
name: "invalid start line - missing version",
|
||||
data: []byte("GET /path\r\n\r\n"),
|
||||
expectErr: true,
|
||||
errContains: "invalid start line: missing version",
|
||||
},
|
||||
{
|
||||
name: "invalid start line - multiple spaces",
|
||||
data: []byte("GET /path HTTP/1.1\r\n\r\n"),
|
||||
expectErr: false,
|
||||
expectMethod: "GET",
|
||||
expectPath: "",
|
||||
expectVersion: "/path HTTP/1.1",
|
||||
expectHeaders: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "start line with trailing space",
|
||||
data: []byte("GET / HTTP/1.1 \r\n\r\n"),
|
||||
expectErr: false,
|
||||
expectMethod: "GET",
|
||||
expectPath: "/",
|
||||
expectVersion: "HTTP/1.1 ",
|
||||
expectHeaders: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := NewRequest(tt.data)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, req)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, req)
|
||||
assert.Equal(t, tt.expectMethod, req.Method())
|
||||
assert.Equal(t, tt.expectPath, req.Path())
|
||||
assert.Equal(t, tt.expectVersion, req.Version())
|
||||
for k, v := range tt.expectHeaders {
|
||||
assert.Equal(t, v, req.Value(k))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestHeaderMethods(t *testing.T) {
|
||||
data := []byte("GET / HTTP/1.1\r\nHost: original\r\n\r\n")
|
||||
req, _ := NewRequest(data)
|
||||
|
||||
req.Set("Host", "updated")
|
||||
req.Set("X-New", "new-value")
|
||||
assert.Equal(t, "updated", req.Value("Host"))
|
||||
assert.Equal(t, "new-value", req.Value("X-New"))
|
||||
|
||||
assert.Equal(t, "", req.Value("Non-Existent"))
|
||||
|
||||
req.Remove("X-New")
|
||||
assert.Equal(t, "", req.Value("X-New"))
|
||||
|
||||
final := req.Finalize()
|
||||
assert.Contains(t, string(final), "GET / HTTP/1.1\r\n")
|
||||
assert.Contains(t, string(final), "Host: updated\r\n")
|
||||
assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
|
||||
}
|
||||
|
||||
func TestNewResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expectErr bool
|
||||
errContains string
|
||||
expectHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
data: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"),
|
||||
expectErr: false,
|
||||
expectHeaders: map[string]string{
|
||||
"Content-Length": "0",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid response - no CRLF",
|
||||
data: []byte("HTTP/1.1 200 OK"),
|
||||
expectErr: true,
|
||||
errContains: "no CRLF found in start line",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, err := NewResponse(tt.data)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, resp)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
for k, v := range tt.expectHeaders {
|
||||
assert.Equal(t, v, resp.Value(k))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseHeaderMethods(t *testing.T) {
|
||||
data := []byte("HTTP/1.1 200 OK\r\nServer: old\r\n\r\n")
|
||||
resp, _ := NewResponse(data)
|
||||
|
||||
resp.Set("Server", "new")
|
||||
resp.Set("X-Res", "val")
|
||||
assert.Equal(t, "new", resp.Value("Server"))
|
||||
assert.Equal(t, "val", resp.Value("X-Res"))
|
||||
|
||||
resp.Remove("X-Res")
|
||||
assert.Equal(t, "", resp.Value("X-Res"))
|
||||
|
||||
final := resp.Finalize()
|
||||
assert.Contains(t, string(final), "HTTP/1.1 200 OK\r\n")
|
||||
assert.Contains(t, string(final), "Server: new\r\n")
|
||||
assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
|
||||
}
|
||||
|
||||
func TestSetRemainingHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
initialHeaders map[string]string
|
||||
expectHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "various header formats",
|
||||
data: []byte("K1: V1\r\nK2:V2\r\n K3 : V3 \r\nNoColon\r\n\r\n"),
|
||||
expectHeaders: map[string]string{
|
||||
"K1": "V1",
|
||||
"K2": "V2",
|
||||
"K3": "V3",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no trailing CRLF",
|
||||
data: []byte("K1: V1"),
|
||||
expectHeaders: map[string]string{
|
||||
"K1": "V1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty lines",
|
||||
data: []byte("\r\nK1: V1"),
|
||||
expectHeaders: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "headers with only colon",
|
||||
data: []byte(": value\r\nkey:\r\n"),
|
||||
expectHeaders: map[string]string{
|
||||
"": "value",
|
||||
"key": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &requestHeader{headers: make(map[string]string)}
|
||||
if tt.initialHeaders != nil {
|
||||
req.headers = tt.initialHeaders
|
||||
}
|
||||
setRemainingHeaders(tt.data, req)
|
||||
assert.Equal(t, len(tt.expectHeaders), len(req.headers))
|
||||
for k, v := range tt.expectHeaders {
|
||||
assert.Equal(t, v, req.headers[k])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package header
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
@@ -36,31 +35,6 @@ func setRemainingHeaders(remaining []byte, header interface {
|
||||
}
|
||||
}
|
||||
|
||||
func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) {
|
||||
header := &requestHeader{
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
|
||||
lineEnd := bytes.Index(headerData, []byte("\r\n"))
|
||||
if lineEnd == -1 {
|
||||
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
|
||||
}
|
||||
|
||||
startLine := headerData[:lineEnd]
|
||||
header.startLine = startLine
|
||||
var err error
|
||||
header.method, header.path, header.version, err = parseStartLine(startLine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
remaining := headerData[lineEnd+2:]
|
||||
|
||||
setRemainingHeaders(remaining, header)
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func parseStartLine(startLine []byte) (method, path, version string, err error) {
|
||||
firstSpace := bytes.IndexByte(startLine, ' ')
|
||||
if firstSpace == -1 {
|
||||
@@ -80,51 +54,6 @@ func parseStartLine(startLine []byte) (method, path, version string, err error)
|
||||
return method, path, version, nil
|
||||
}
|
||||
|
||||
func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) {
|
||||
header := &requestHeader{
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
|
||||
startLineBytes, err := br.ReadSlice('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
|
||||
header.startLine = make([]byte, len(startLineBytes))
|
||||
copy(header.startLine, startLineBytes)
|
||||
|
||||
header.method, header.path, header.version, err = parseStartLine(header.startLine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
lineBytes, err := br.ReadSlice('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
|
||||
|
||||
if len(lineBytes) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
colonIdx := bytes.IndexByte(lineBytes, ':')
|
||||
if colonIdx == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := bytes.TrimSpace(lineBytes[:colonIdx])
|
||||
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
|
||||
|
||||
header.headers[string(key)] = string(value)
|
||||
}
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func finalize(startLine []byte, headers map[string]string) []byte {
|
||||
size := len(startLine) + 2
|
||||
for key, val := range headers {
|
||||
|
||||
@@ -1,19 +1,33 @@
|
||||
package header
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func NewRequest(r interface{}) (RequestHeader, error) {
|
||||
switch v := r.(type) {
|
||||
case []byte:
|
||||
return parseHeadersFromBytes(v)
|
||||
case *bufio.Reader:
|
||||
return parseHeadersFromReader(v)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type: %T", r)
|
||||
func NewRequest(headerData []byte) (RequestHeader, error) {
|
||||
header := &requestHeader{
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
|
||||
lineEnd := bytes.Index(headerData, []byte("\r\n"))
|
||||
if lineEnd == -1 {
|
||||
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
|
||||
}
|
||||
|
||||
startLine := headerData[:lineEnd]
|
||||
header.startLine = startLine
|
||||
var err error
|
||||
header.method, header.path, header.version, err = parseStartLine(startLine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
remaining := headerData[lineEnd+2:]
|
||||
|
||||
setRemainingHeaders(remaining, header)
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func (req *requestHeader) Value(key string) string {
|
||||
@@ -32,15 +46,15 @@ func (req *requestHeader) Remove(key string) {
|
||||
delete(req.headers, key)
|
||||
}
|
||||
|
||||
func (req *requestHeader) GetMethod() string {
|
||||
func (req *requestHeader) Method() string {
|
||||
return req.method
|
||||
}
|
||||
|
||||
func (req *requestHeader) GetPath() string {
|
||||
func (req *requestHeader) Path() string {
|
||||
return req.path
|
||||
}
|
||||
|
||||
func (req *requestHeader) GetVersion() string {
|
||||
func (req *requestHeader) Version() string {
|
||||
return req.version
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ type http struct {
|
||||
remoteAddr net.Addr
|
||||
writer io.Writer
|
||||
reader io.Reader
|
||||
headerBuf []byte
|
||||
buf []byte
|
||||
respHeader header.ResponseHeader
|
||||
reqHeader header.RequestHeader
|
||||
@@ -72,7 +71,10 @@ func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
|
||||
}
|
||||
|
||||
func (hs *http) Close() error {
|
||||
return hs.writer.(io.Closer).Close()
|
||||
if closer, ok := hs.writer.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *http) CloseWrite() error {
|
||||
|
||||
@@ -0,0 +1,765 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tunnel_pls/internal/http/header"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockAddr struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAddr) String() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
func (m *MockAddr) Network() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
type MockRequestMiddleware struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRequestMiddleware) HandleRequest(h header.RequestHeader) error {
|
||||
args := m.Called(h)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockResponseMiddleware struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error {
|
||||
args := m.Called(h, body)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockReadWriter struct {
|
||||
mock.Mock
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockReadWriter) Read(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockReadWriter) Write(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockReadWriter) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockReadWriter) CloseWrite() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockReadWriterOnlyCloser struct {
|
||||
mock.Mock
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockReadWriterOnlyCloser) Read(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockReadWriterOnlyCloser) Write(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockReadWriterOnlyCloser) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockWriterOnly struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockWriterOnly) Write(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockWriterOnly) Read(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
type MockReader struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockReader) Read(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
type MockWriter struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockWriter) Write(p []byte) (int, error) {
|
||||
ret := m.Called(p)
|
||||
|
||||
var n int
|
||||
var err error
|
||||
|
||||
switch v := ret.Get(0).(type) {
|
||||
case func([]byte) int:
|
||||
n = v(p)
|
||||
case int:
|
||||
n = v
|
||||
default:
|
||||
n = len(p)
|
||||
}
|
||||
|
||||
switch v := ret.Get(1).(type) {
|
||||
case func([]byte) error:
|
||||
err = v(p)
|
||||
case error:
|
||||
err = v
|
||||
default:
|
||||
err = nil
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (m *MockWriter) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestHTTPMethods(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
rw := new(MockReadWriter)
|
||||
hs := New(rw, rw, addr)
|
||||
|
||||
assert.Equal(t, addr, hs.RemoteAddr())
|
||||
|
||||
reqMW := new(MockRequestMiddleware)
|
||||
hs.UseRequestMiddleware(reqMW)
|
||||
assert.Equal(t, 1, len(hs.RequestMiddlewares()))
|
||||
assert.Equal(t, reqMW, hs.RequestMiddlewares()[0])
|
||||
|
||||
respMW := new(MockResponseMiddleware)
|
||||
hs.UseResponseMiddleware(respMW)
|
||||
assert.Equal(t, 1, len(hs.ResponseMiddlewares()))
|
||||
assert.Equal(t, respMW, hs.ResponseMiddlewares()[0])
|
||||
|
||||
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
hs.SetRequestHeader(reqH)
|
||||
}
|
||||
|
||||
func TestApplyMiddlewares(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(HTTP, *MockRequestMiddleware, *MockResponseMiddleware)
|
||||
apply func(HTTP, header.RequestHeader, header.ResponseHeader) error
|
||||
verify func(*testing.T, header.RequestHeader, header.ResponseHeader)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "apply request middleware success",
|
||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
|
||||
h := args.Get(0).(header.RequestHeader)
|
||||
h.Set("X-Middleware", "true")
|
||||
}).Return(nil)
|
||||
hs.UseRequestMiddleware(reqMW)
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyRequestMiddlewares(reqH)
|
||||
},
|
||||
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
|
||||
assert.Equal(t, "true", reqH.Value("X-Middleware"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "apply response middleware success",
|
||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
h := args.Get(0).(header.ResponseHeader)
|
||||
h.Set("X-Resp-Middleware", "true")
|
||||
}).Return(nil)
|
||||
hs.UseResponseMiddleware(respMW)
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
||||
},
|
||||
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
|
||||
assert.Equal(t, "true", respH.Value("X-Resp-Middleware"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "apply request middleware error",
|
||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||
reqMW.On("HandleRequest", mock.Anything).Return(assert.AnError)
|
||||
hs.UseRequestMiddleware(reqMW)
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyRequestMiddlewares(reqH)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "apply response middleware error",
|
||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(assert.AnError)
|
||||
hs.UseResponseMiddleware(respMW)
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
||||
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
rw := new(MockReadWriter)
|
||||
hs := New(rw, rw, addr)
|
||||
|
||||
reqMW := new(MockRequestMiddleware)
|
||||
respMW := new(MockResponseMiddleware)
|
||||
tt.setup(hs, reqMW, respMW)
|
||||
|
||||
err := tt.apply(hs, reqH, respH)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, reqH, respH)
|
||||
}
|
||||
}
|
||||
|
||||
reqMW.AssertExpectations(t)
|
||||
respMW.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseMethods(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() (io.Writer, io.Reader)
|
||||
op func(HTTP) error
|
||||
verify func(*testing.T, io.Writer)
|
||||
}{
|
||||
{
|
||||
name: "Close success",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
rw := new(MockReadWriter)
|
||||
rw.On("Close").Return(nil)
|
||||
return rw, rw
|
||||
},
|
||||
op: func(hs HTTP) error { return hs.Close() },
|
||||
verify: func(t *testing.T, w io.Writer) {
|
||||
w.(*MockReadWriter).AssertCalled(t, "Close")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CloseWrite with CloseWrite implementation",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
rw := new(MockReadWriter)
|
||||
rw.On("CloseWrite").Return(nil)
|
||||
return rw, rw
|
||||
},
|
||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||
verify: func(t *testing.T, w io.Writer) {
|
||||
w.(*MockReadWriter).AssertCalled(t, "CloseWrite")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CloseWrite fallback to Close",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
rw := new(MockReadWriterOnlyCloser)
|
||||
rw.On("Close").Return(nil)
|
||||
return rw, rw
|
||||
},
|
||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||
verify: func(t *testing.T, w io.Writer) {
|
||||
w.(*MockReadWriterOnlyCloser).AssertCalled(t, "Close")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Close with No Closer",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
w := new(MockWriterOnly)
|
||||
r := new(MockReader)
|
||||
return w, r
|
||||
},
|
||||
op: func(hs HTTP) error { return hs.Close() },
|
||||
},
|
||||
{
|
||||
name: "CloseWrite with No CloseWrite and No Closer",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
w := new(MockWriterOnly)
|
||||
r := new(MockReader)
|
||||
return w, r
|
||||
},
|
||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
w, r := tt.setup()
|
||||
hs := New(w, r, addr)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
err := tt.op(hs)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, w)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitHeaderAndBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
delimiterIdx int
|
||||
expectHeader []byte
|
||||
expectBody []byte
|
||||
}{
|
||||
{
|
||||
name: "standard",
|
||||
data: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nBodyContent"),
|
||||
delimiterIdx: 31,
|
||||
expectHeader: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"),
|
||||
expectBody: []byte("BodyContent"),
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
delimiterIdx: 15,
|
||||
expectHeader: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
expectBody: []byte(""),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, b := splitHeaderAndBody(tt.data, tt.delimiterIdx)
|
||||
assert.Equal(t, tt.expectHeader, h)
|
||||
assert.Equal(t, tt.expectBody, b)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsHTTPHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
buf []byte
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
buf: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "valid response",
|
||||
buf: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "invalid start line",
|
||||
buf: []byte("NOT_HTTP /path\r\nHost: example.com\r\n\r\n"),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header line (no colon)",
|
||||
buf: []byte("GET / HTTP/1.1\r\nInvalidHeaderLine\r\n\r\n"),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header line (colon at 0)",
|
||||
buf: []byte("GET / HTTP/1.1\r\n: value\r\n\r\n"),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "empty header section",
|
||||
buf: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "multiple headers",
|
||||
buf: []byte("GET / HTTP/1.1\r\nH1: V1\r\nH2: V2\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isHTTPHeader(tt.buf)
|
||||
assert.Equal(t, tt.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
readLen int
|
||||
expectContent string
|
||||
expectRead int
|
||||
expectErr bool
|
||||
middlewareErr error
|
||||
isHTTP bool
|
||||
}{
|
||||
{
|
||||
name: "valid http request",
|
||||
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\nBody"),
|
||||
readLen: 100,
|
||||
expectContent: "Body",
|
||||
expectRead: 54,
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
name: "non-http data",
|
||||
input: []byte("Some random data\r\n\r\nMore data"),
|
||||
readLen: 100,
|
||||
expectContent: "Some random data\r\n\r\nMore data",
|
||||
expectRead: 29,
|
||||
isHTTP: false,
|
||||
},
|
||||
{
|
||||
name: "no delimiter",
|
||||
input: []byte("Partial data without delimiter"),
|
||||
readLen: 100,
|
||||
expectContent: "Partial data without delimiter",
|
||||
expectRead: 30,
|
||||
isHTTP: false,
|
||||
},
|
||||
{
|
||||
name: "middleware error",
|
||||
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\n"),
|
||||
readLen: 100,
|
||||
middlewareErr: assert.AnError,
|
||||
expectErr: true,
|
||||
isHTTP: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
reader := new(MockReader)
|
||||
writer := new(MockWriterOnly)
|
||||
|
||||
if tt.expectErr || tt.name == "valid http request" {
|
||||
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
||||
p := args.Get(0).([]byte)
|
||||
copy(p, tt.input)
|
||||
}).Return(len(tt.input), io.EOF).Once()
|
||||
} else {
|
||||
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
||||
p := args.Get(0).([]byte)
|
||||
copy(p, tt.input)
|
||||
}).Return(len(tt.input), nil).Once()
|
||||
}
|
||||
|
||||
hs := New(writer, reader, addr)
|
||||
|
||||
reqMW := new(MockRequestMiddleware)
|
||||
if tt.isHTTP {
|
||||
if tt.middlewareErr != nil {
|
||||
reqMW.On("HandleRequest", mock.Anything).Return(tt.middlewareErr)
|
||||
} else {
|
||||
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
|
||||
h := args.Get(0).(header.RequestHeader)
|
||||
h.Set("X-Middleware", "true")
|
||||
}).Return(nil)
|
||||
}
|
||||
}
|
||||
hs.UseRequestMiddleware(reqMW)
|
||||
|
||||
p := make([]byte, tt.readLen)
|
||||
n, err := hs.Read(p)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectRead, n)
|
||||
if tt.name == "valid http request" {
|
||||
content := string(p[:n])
|
||||
assert.Contains(t, content, "GET / HTTP/1.1\r\n")
|
||||
assert.Contains(t, content, "Host: test\r\n")
|
||||
assert.Contains(t, content, "X-Middleware: true\r\n")
|
||||
assert.True(t, bytes.HasSuffix(p[:n], []byte("\r\n\r\nBody")))
|
||||
} else {
|
||||
assert.Equal(t, tt.expectContent, string(p[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
if tt.isHTTP {
|
||||
reqMW.AssertExpectations(t)
|
||||
}
|
||||
reader.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
writes [][]byte
|
||||
expectWritten string
|
||||
expectErr bool
|
||||
middlewareErr error
|
||||
isHTTP bool
|
||||
}{
|
||||
{
|
||||
name: "valid http response in one write",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"),
|
||||
},
|
||||
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
name: "valid http response in multiple writes",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\n"),
|
||||
[]byte("Content-Length: 4\r\n\r\n"),
|
||||
[]byte("Body"),
|
||||
},
|
||||
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
name: "non-http data",
|
||||
writes: [][]byte{
|
||||
[]byte("Random data with delimiter\r\n\r\nFlush"),
|
||||
},
|
||||
expectWritten: "Random data with delimiter\r\n\r\nFlush",
|
||||
isHTTP: false,
|
||||
},
|
||||
{
|
||||
name: "bypass buffering",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
},
|
||||
expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" +
|
||||
"HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n",
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
name: "middleware error",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
},
|
||||
middlewareErr: assert.AnError,
|
||||
expectErr: true,
|
||||
isHTTP: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
var writtenData bytes.Buffer
|
||||
writer := new(MockWriter)
|
||||
|
||||
writer.On("Write", mock.Anything).Run(func(args mock.Arguments) {
|
||||
p := args.Get(0).([]byte)
|
||||
writtenData.Write(p)
|
||||
}).Return(func(p []byte) int {
|
||||
return len(p)
|
||||
}, nil)
|
||||
|
||||
reader := new(MockReader)
|
||||
hs := New(writer, reader, addr)
|
||||
|
||||
respMW := new(MockResponseMiddleware)
|
||||
if tt.isHTTP {
|
||||
if tt.middlewareErr != nil {
|
||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(tt.middlewareErr)
|
||||
} else {
|
||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
h := args.Get(0).(header.ResponseHeader)
|
||||
h.Set("X-Resp-Middleware", "true")
|
||||
}).Return(nil)
|
||||
}
|
||||
}
|
||||
hs.UseResponseMiddleware(respMW)
|
||||
|
||||
var totalN int
|
||||
var err error
|
||||
for _, w := range tt.writes {
|
||||
var n int
|
||||
n, err = hs.Write(w)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
totalN += n
|
||||
}
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
written := writtenData.String()
|
||||
if strings.HasPrefix(tt.expectWritten, "HTTP/") {
|
||||
assert.Contains(t, written, "HTTP/1.1 200 OK\r\n")
|
||||
assert.Contains(t, written, "X-Resp-Middleware: true\r\n")
|
||||
if strings.Contains(tt.expectWritten, "Content-Length: 4") {
|
||||
assert.Contains(t, written, "Content-Length: 4\r\n")
|
||||
}
|
||||
assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n"))
|
||||
} else {
|
||||
assert.Equal(t, tt.expectWritten, written)
|
||||
}
|
||||
}
|
||||
|
||||
if tt.isHTTP {
|
||||
respMW.AssertExpectations(t)
|
||||
}
|
||||
if tt.middlewareErr == nil {
|
||||
writer.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() (io.Writer, io.Reader)
|
||||
data []byte
|
||||
}{
|
||||
{
|
||||
name: "write error in writeHeaderAndBody",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
writer := new(MockWriter)
|
||||
writer.On("Write", mock.Anything).Return(0, assert.AnError)
|
||||
reader := new(MockReader)
|
||||
return writer, reader
|
||||
},
|
||||
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
},
|
||||
{
|
||||
name: "write error in writeHeaderAndBody second write",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
writer := new(MockWriter)
|
||||
writer.On("Write", mock.Anything).Return(len([]byte("HTTP/1.1 200 OK\r\n\r\n")), nil).Once()
|
||||
writer.On("Write", mock.Anything).Return(0, assert.AnError).Once()
|
||||
reader := new(MockReader)
|
||||
return writer, reader
|
||||
},
|
||||
data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"),
|
||||
},
|
||||
{
|
||||
name: "write error in writeRawBuffer",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
writer := new(MockWriter)
|
||||
writer.On("Write", mock.Anything).Return(0, assert.AnError)
|
||||
reader := new(MockReader)
|
||||
return writer, reader
|
||||
},
|
||||
data: []byte("Not HTTP\r\n\r\nFlush"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
w, r := tt.setup()
|
||||
hs := New(w, r, addr)
|
||||
|
||||
_, err := hs.Write(tt.data)
|
||||
assert.Error(t, err)
|
||||
|
||||
w.(*MockWriter).AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadEOF(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() io.Reader
|
||||
expectN int
|
||||
expectErr error
|
||||
expectContent string
|
||||
}{
|
||||
{
|
||||
name: "read eof",
|
||||
setup: func() io.Reader {
|
||||
reader := new(MockReader)
|
||||
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
||||
p := args.Get(0).([]byte)
|
||||
copy(p, "data")
|
||||
}).Return(4, io.EOF)
|
||||
return reader
|
||||
},
|
||||
expectN: 4,
|
||||
expectErr: io.EOF,
|
||||
expectContent: "data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
reader := tt.setup()
|
||||
hs := New(nil, reader, addr)
|
||||
|
||||
p := make([]byte, 100)
|
||||
n, err := hs.Read(p)
|
||||
|
||||
assert.Equal(t, tt.expectN, n)
|
||||
assert.Equal(t, tt.expectErr, err)
|
||||
assert.Equal(t, tt.expectContent, string(p[:n]))
|
||||
|
||||
reader.(*MockReader).AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
+28
-9
@@ -5,6 +5,8 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -12,7 +14,20 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
rsaGenerateKey = rsa.GenerateKey
|
||||
pemEncode = pem.Encode
|
||||
sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
|
||||
return ssh.NewPublicKey(key)
|
||||
}
|
||||
pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
|
||||
return w.Write(data)
|
||||
}
|
||||
osOpenFile = os.OpenFile
|
||||
)
|
||||
|
||||
func GenerateSSHKeyIfNotExist(keyPath string) error {
|
||||
var errGroup = make([]error, 0)
|
||||
if _, err := os.Stat(keyPath); err == nil {
|
||||
log.Printf("SSH key already exists at %s", keyPath)
|
||||
return nil
|
||||
@@ -20,7 +35,7 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
|
||||
|
||||
log.Printf("SSH key not found at %s, generating new key pair...", keyPath)
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
privateKey, err := rsaGenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -35,33 +50,37 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
privateKeyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
privateKeyFile, err := osOpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer privateKeyFile.Close()
|
||||
defer func(privateKeyFile *os.File) {
|
||||
errGroup = append(errGroup, privateKeyFile.Close())
|
||||
}(privateKeyFile)
|
||||
|
||||
if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil {
|
||||
if err := pemEncode(privateKeyFile, privateKeyPEM); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
|
||||
publicKey, err := sshNewPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pubKeyPath := keyPath + ".pub"
|
||||
pubKeyFile, err := os.OpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
pubKeyFile, err := osOpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer pubKeyFile.Close()
|
||||
defer func(pubKeyFile *os.File) {
|
||||
errGroup = append(errGroup, pubKeyFile.Close())
|
||||
}(pubKeyFile)
|
||||
|
||||
_, err = pubKeyFile.Write(ssh.MarshalAuthorizedKey(publicKey))
|
||||
_, err = pubKeyWrite(pubKeyFile, ssh.MarshalAuthorizedKey(publicKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("SSH key pair generated successfully at %s and %s", keyPath, pubKeyPath)
|
||||
return nil
|
||||
return errors.Join(errGroup...)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestGenerateSSHKeyIfNotExist(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(t *testing.T, tempDir string) string
|
||||
mockSetup func() func()
|
||||
wantErr bool
|
||||
errStr string
|
||||
verify func(t *testing.T, keyPath string)
|
||||
}{
|
||||
{
|
||||
name: "GenerateNewKey",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "id_rsa")
|
||||
},
|
||||
verify: func(t *testing.T, keyPath string) {
|
||||
pubKeyPath := keyPath + ".pub"
|
||||
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
||||
t.Errorf("Private key file not created")
|
||||
}
|
||||
if _, err := os.Stat(pubKeyPath); os.IsNotExist(err) {
|
||||
t.Errorf("Public key file not created")
|
||||
}
|
||||
privateKeyBytes, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read private key: %v", err)
|
||||
}
|
||||
if _, err = ssh.ParseRawPrivateKey(privateKeyBytes); err != nil {
|
||||
t.Errorf("Failed to parse private key: %v", err)
|
||||
}
|
||||
publicKeyBytes, err := os.ReadFile(pubKeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read public key: %v", err)
|
||||
}
|
||||
if _, _, _, _, err = ssh.ParseAuthorizedKey(publicKeyBytes); err != nil {
|
||||
t.Errorf("Failed to parse public key: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DoNotOverwriteExistingKey",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
keyPath := filepath.Join(tempDir, "existing_id_rsa")
|
||||
dummyPrivate := "dummy private"
|
||||
dummyPublic := "dummy public"
|
||||
if err := os.WriteFile(keyPath, []byte(dummyPrivate), 0600); err != nil {
|
||||
t.Fatalf("Failed to create dummy private key: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath+".pub", []byte(dummyPublic), 0644); err != nil {
|
||||
t.Fatalf("Failed to create dummy public key: %v", err)
|
||||
}
|
||||
return keyPath
|
||||
},
|
||||
verify: func(t *testing.T, keyPath string) {
|
||||
gotPrivate, _ := os.ReadFile(keyPath)
|
||||
if string(gotPrivate) != "dummy private" {
|
||||
t.Errorf("Private key was overwritten")
|
||||
}
|
||||
gotPublic, _ := os.ReadFile(keyPath + ".pub")
|
||||
if string(gotPublic) != "dummy public" {
|
||||
t.Errorf("Public key was overwritten")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CreateNestedDirectories",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "nested", "dir", "id_rsa")
|
||||
},
|
||||
verify: func(t *testing.T, keyPath string) {
|
||||
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
||||
t.Errorf("Private key file not created in nested directory")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "FailureMkdirAll",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
dirPath := filepath.Join(tempDir, "file_as_dir")
|
||||
if err := os.WriteFile(dirPath, []byte("not a dir"), 0644); err != nil {
|
||||
t.Fatalf("Failed to create file: %v", err)
|
||||
}
|
||||
return filepath.Join(dirPath, "id_rsa")
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "PrivateExistsPublicMissing",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
keyPath := filepath.Join(tempDir, "partial_id_rsa")
|
||||
if err := os.WriteFile(keyPath, []byte("private"), 0600); err != nil {
|
||||
t.Fatalf("Failed to create private key: %v", err)
|
||||
}
|
||||
return keyPath
|
||||
},
|
||||
verify: func(t *testing.T, keyPath string) {
|
||||
if _, err := os.Stat(keyPath + ".pub"); !os.IsNotExist(err) {
|
||||
t.Errorf("Public key should NOT have been created if private key existed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "FailureRSAGenerateKey",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_rsa")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := rsaGenerateKey
|
||||
rsaGenerateKey = func(random io.Reader, bits int) (*rsa.PrivateKey, error) {
|
||||
return nil, errors.New("rsa error")
|
||||
}
|
||||
return func() { rsaGenerateKey = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "rsa error",
|
||||
},
|
||||
{
|
||||
name: "FailureOpenFilePrivate",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_open_private")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := osOpenFile
|
||||
osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
|
||||
return nil, errors.New("open error")
|
||||
}
|
||||
return func() { osOpenFile = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "open error",
|
||||
},
|
||||
{
|
||||
name: "FailurePemEncode",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_pem")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := pemEncode
|
||||
pemEncode = func(out io.Writer, b *pem.Block) error {
|
||||
return errors.New("pem error")
|
||||
}
|
||||
return func() { pemEncode = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "pem error",
|
||||
},
|
||||
{
|
||||
name: "FailureSSHNewPublicKey",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_ssh")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := sshNewPublicKey
|
||||
sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
|
||||
return nil, errors.New("ssh error")
|
||||
}
|
||||
return func() { sshNewPublicKey = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "ssh error",
|
||||
},
|
||||
{
|
||||
name: "FailureOpenFilePublic",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_open_public")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := osOpenFile
|
||||
osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
|
||||
if filepath.Ext(name) == ".pub" {
|
||||
return nil, errors.New("open pub error")
|
||||
}
|
||||
return os.OpenFile(name, flag, perm)
|
||||
}
|
||||
return func() { osOpenFile = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "open pub error",
|
||||
},
|
||||
{
|
||||
name: "FailurePubKeyWrite",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_write")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := pubKeyWrite
|
||||
pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
|
||||
return 0, errors.New("write error")
|
||||
}
|
||||
return func() { pubKeyWrite = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "write error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
keyPath := tt.setup(t, tempDir)
|
||||
if tt.mockSetup != nil {
|
||||
cleanup := tt.mockSetup()
|
||||
defer cleanup()
|
||||
}
|
||||
|
||||
err := GenerateSSHKeyIfNotExist(keyPath)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr && tt.errStr != "" && err != nil && err.Error() != tt.errStr {
|
||||
t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErrStr %v", err, tt.errStr)
|
||||
}
|
||||
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, keyPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type mockRequestHeader struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Value(key string) string {
|
||||
return m.Called(key).String(0)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Set(key string, value string) {
|
||||
m.Called(key, value)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Remove(key string) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Finalize() []byte {
|
||||
return m.Called().Get(0).([]byte)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Method() string {
|
||||
return m.Called().String(0)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Path() string {
|
||||
return m.Called().String(0)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Version() string {
|
||||
return m.Called().String(0)
|
||||
}
|
||||
|
||||
func TestForwardedFor_HandleRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
expectedHost string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid IPv4 address",
|
||||
addr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 8080},
|
||||
expectedHost: "192.168.1.100",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid IPv6 address",
|
||||
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 8080},
|
||||
expectedHost: "2001:db8::ff00:42:8329",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid address format",
|
||||
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
||||
expectedHost: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "valid IPv4 address with port",
|
||||
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
|
||||
expectedHost: "127.0.0.1",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ff := NewForwardedFor(tc.addr)
|
||||
reqHeader := new(mockRequestHeader)
|
||||
|
||||
if !tc.expectError {
|
||||
reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return()
|
||||
}
|
||||
|
||||
err := ff.HandleRequest(reqHeader)
|
||||
|
||||
if tc.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
reqHeader.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForwardedFor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
expectAddr net.Addr
|
||||
}{
|
||||
{
|
||||
name: "IPv4 address",
|
||||
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
|
||||
expectAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
|
||||
expectAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
|
||||
},
|
||||
{
|
||||
name: "Unix address",
|
||||
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
||||
expectAddr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ff := NewForwardedFor(tc.addr)
|
||||
assert.Equal(t, tc.expectAddr.String(), ff.addr.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type mockResponseHeader struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockResponseHeader) Value(key string) string {
|
||||
return m.Called(key).String(0)
|
||||
}
|
||||
|
||||
func (m *mockResponseHeader) Set(key string, value string) {
|
||||
m.Called(key, value)
|
||||
}
|
||||
|
||||
func (m *mockResponseHeader) Remove(key string) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *mockResponseHeader) Finalize() []byte {
|
||||
return m.Called().Get(0).([]byte)
|
||||
}
|
||||
|
||||
func TestTunnelFingerprintHandleResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expected map[string]string
|
||||
body []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "Sets Server Header",
|
||||
expected: map[string]string{"Server": "Tunnel Please"},
|
||||
body: []byte("Sample body"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Overwrites Server Header",
|
||||
expected: map[string]string{"Server": "Tunnel Please"},
|
||||
body: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockHeader := new(mockResponseHeader)
|
||||
for k, v := range tt.expected {
|
||||
mockHeader.On("Set", k, v).Return()
|
||||
}
|
||||
|
||||
tunnelFingerprint := NewTunnelFingerprint()
|
||||
|
||||
err := tunnelFingerprint.HandleResponse(mockHeader, tt.body)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
mockHeader.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTunnelFingerprint(t *testing.T) {
|
||||
instance := NewTunnelFingerprint()
|
||||
assert.NotNil(t, instance)
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
package port
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAddRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
startPort uint16
|
||||
endPort uint16
|
||||
wantErr bool
|
||||
}{
|
||||
{"normal range", 1000, 1002, false},
|
||||
{"invalid range", 2000, 1999, true},
|
||||
{"single port range", 3000, 3000, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pm := New()
|
||||
err := pm.AddRange(tt.startPort, tt.endPort)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnassigned(t *testing.T) {
|
||||
pm := New()
|
||||
_ = pm.AddRange(1000, 1002)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
status map[uint16]bool
|
||||
want uint16
|
||||
wantOk bool
|
||||
}{
|
||||
{"all unassigned", map[uint16]bool{1000: false, 1001: false, 1002: false}, 1000, true},
|
||||
{"some assigned", map[uint16]bool{1000: true, 1001: false, 1002: true}, 1001, true},
|
||||
{"all assigned", map[uint16]bool{1000: true, 1001: true, 1002: true}, 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for k, v := range tt.status {
|
||||
_ = pm.SetStatus(k, v)
|
||||
}
|
||||
got, gotOk := pm.Unassigned()
|
||||
assert.Equal(t, tt.want, got)
|
||||
assert.Equal(t, tt.wantOk, gotOk)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetStatus(t *testing.T) {
|
||||
pm := New()
|
||||
_ = pm.AddRange(1000, 1002)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port uint16
|
||||
assigned bool
|
||||
}{
|
||||
{"assign port 1000", 1000, true},
|
||||
{"unassign port 1001", 1001, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := pm.SetStatus(tt.port, tt.assigned)
|
||||
assert.NoError(t, err)
|
||||
|
||||
status, ok := pm.(*port).ports[tt.port]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tt.assigned, status)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaim(t *testing.T) {
|
||||
pm := New()
|
||||
_ = pm.AddRange(1000, 1002)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port uint16
|
||||
status bool
|
||||
want bool
|
||||
}{
|
||||
{"claim unassigned port", 1000, false, true},
|
||||
{"claim already assigned port", 1001, true, false},
|
||||
{"claim non-existent port", 5000, false, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if _, exists := pm.(*port).ports[tt.port]; exists {
|
||||
_ = pm.SetStatus(tt.port, tt.status)
|
||||
}
|
||||
|
||||
got := pm.Claim(tt.port)
|
||||
assert.Equal(t, tt.want, got)
|
||||
|
||||
finalState := pm.(*port).ports[tt.port]
|
||||
assert.True(t, finalState)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,35 @@
|
||||
package random
|
||||
|
||||
import "crypto/rand"
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
func GenerateRandomString(length int) (string, error) {
|
||||
var (
|
||||
ErrInvalidLength = fmt.Errorf("invalid length")
|
||||
)
|
||||
|
||||
type Random interface {
|
||||
String(length int) (string, error)
|
||||
}
|
||||
|
||||
type random struct {
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func New() Random {
|
||||
return &random{reader: rand.Reader}
|
||||
}
|
||||
|
||||
func (ran *random) String(length int) (string, error) {
|
||||
if length < 0 {
|
||||
return "", ErrInvalidLength
|
||||
}
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, length)
|
||||
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
if _, err := ran.reader.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
package random
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRandom_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ValidLengthZero", 0, false},
|
||||
{"ValidPositiveLength", 10, false},
|
||||
{"NegativeLength", -1, true},
|
||||
{"VeryLargeLength", 1_000_000, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
randomizer := New()
|
||||
|
||||
result, err := randomizer.String(tt.length)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, tt.length)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomWithFailingReader_String(t *testing.T) {
|
||||
errBrainrot := assert.AnError
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
reader io.Reader
|
||||
expectErr error
|
||||
}{
|
||||
{
|
||||
name: "failing reader",
|
||||
reader: func() io.Reader {
|
||||
return &failingReader{err: errBrainrot}
|
||||
}(),
|
||||
expectErr: errBrainrot,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
randomizer := &random{reader: tt.reader}
|
||||
result, err := randomizer.String(20)
|
||||
assert.ErrorIs(t, err, tt.expectErr)
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type failingReader struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *failingReader) Read(p []byte) (int, error) {
|
||||
return 0, f.err
|
||||
}
|
||||
@@ -34,6 +34,15 @@ type registry struct {
|
||||
slugIndex map[Key]string
|
||||
}
|
||||
|
||||
var (
|
||||
ErrSessionNotFound = fmt.Errorf("session not found")
|
||||
ErrSlugInUse = fmt.Errorf("slug already in use")
|
||||
ErrInvalidSlug = fmt.Errorf("invalid slug")
|
||||
ErrForbiddenSlug = fmt.Errorf("forbidden slug")
|
||||
ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type")
|
||||
ErrSlugUnchanged = fmt.Errorf("slug is unchanged")
|
||||
)
|
||||
|
||||
func NewRegistry() Registry {
|
||||
return ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
@@ -47,12 +56,12 @@ func (r *registry) Get(key Key) (session Session, err error) {
|
||||
|
||||
userID, ok := r.slugIndex[key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Session not found")
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
|
||||
client, ok := r.byUser[userID][key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Session not found")
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
@@ -63,37 +72,38 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error
|
||||
|
||||
client, ok := r.byUser[user][key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("Session not found")
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (r *registry) Update(user string, oldKey, newKey Key) error {
|
||||
if oldKey.Type != newKey.Type {
|
||||
return fmt.Errorf("tunnel type cannot change")
|
||||
return ErrSlugUnchanged
|
||||
}
|
||||
|
||||
if newKey.Type != types.HTTP {
|
||||
return fmt.Errorf("non http tunnel cannot change slug")
|
||||
if newKey.Type != types.TunnelTypeHTTP {
|
||||
return ErrSlugChangeNotAllowed
|
||||
}
|
||||
|
||||
if isForbiddenSlug(newKey.Id) {
|
||||
return fmt.Errorf("this subdomain is reserved. Please choose a different one")
|
||||
return ErrForbiddenSlug
|
||||
}
|
||||
|
||||
if !isValidSlug(newKey.Id) {
|
||||
return fmt.Errorf("invalid subdomain. Follow the rules")
|
||||
return ErrInvalidSlug
|
||||
}
|
||||
|
||||
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
||||
return ErrSlugInUse
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
||||
return fmt.Errorf("someone already uses this subdomain")
|
||||
}
|
||||
client, ok := r.byUser[user][oldKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("Session not found")
|
||||
return ErrSessionNotFound
|
||||
}
|
||||
|
||||
delete(r.byUser[user], oldKey)
|
||||
@@ -102,9 +112,6 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
|
||||
client.Slug().Set(newKey.Id)
|
||||
r.slugIndex[newKey] = user
|
||||
|
||||
if r.byUser[user] == nil {
|
||||
r.byUser[user] = make(map[Key]Session)
|
||||
}
|
||||
r.byUser[user][newKey] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,695 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type mockSession struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockSession) Lifecycle() lifecycle.Lifecycle {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(lifecycle.Lifecycle)
|
||||
}
|
||||
func (m *mockSession) Interaction() interaction.Interaction {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(interaction.Interaction)
|
||||
}
|
||||
func (m *mockSession) Forwarder() forwarder.Forwarder {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(forwarder.Forwarder)
|
||||
}
|
||||
func (m *mockSession) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
func (m *mockSession) Detail() *types.Detail {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(*types.Detail)
|
||||
}
|
||||
|
||||
type mockLifecycle struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (ml *mockLifecycle) Channel() ssh.Channel {
|
||||
args := ml.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(ssh.Channel)
|
||||
}
|
||||
|
||||
func (ml *mockLifecycle) Connection() ssh.Conn {
|
||||
args := ml.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(ssh.Conn)
|
||||
}
|
||||
|
||||
func (ml *mockLifecycle) PortRegistry() port.Port {
|
||||
args := ml.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(port.Port)
|
||||
}
|
||||
|
||||
func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { ml.Called(channel) }
|
||||
func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { ml.Called(status) }
|
||||
func (ml *mockLifecycle) IsActive() bool { return ml.Called().Bool(0) }
|
||||
func (ml *mockLifecycle) StartedAt() time.Time { return ml.Called().Get(0).(time.Time) }
|
||||
func (ml *mockLifecycle) Close() error { return ml.Called().Error(0) }
|
||||
func (ml *mockLifecycle) User() string { return ml.Called().String(0) }
|
||||
|
||||
type mockSlug struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (ms *mockSlug) Set(slug string) { ms.Called(slug) }
|
||||
func (ms *mockSlug) String() string { return ms.Called().String(0) }
|
||||
|
||||
func createMockSession(user ...string) *mockSession {
|
||||
u := "user1"
|
||||
if len(user) > 0 {
|
||||
u = user[0]
|
||||
}
|
||||
m := new(mockSession)
|
||||
ml := new(mockLifecycle)
|
||||
ml.On("User").Return(u).Maybe()
|
||||
m.On("Lifecycle").Return(ml).Maybe()
|
||||
ms := new(mockSlug)
|
||||
ms.On("Set", mock.Anything).Maybe()
|
||||
m.On("Slug").Return(ms).Maybe()
|
||||
m.On("Interaction").Return(nil).Maybe()
|
||||
m.On("Forwarder").Return(nil).Maybe()
|
||||
m.On("Detail").Return(nil).Maybe()
|
||||
return m
|
||||
}
|
||||
|
||||
func TestNewRegistry(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
require.NotNil(t, r)
|
||||
}
|
||||
|
||||
func TestRegistry_Get(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(r *registry)
|
||||
key types.SessionKey
|
||||
wantErr error
|
||||
wantResult bool
|
||||
}{
|
||||
{
|
||||
name: "session found",
|
||||
setupFunc: func(r *registry) {
|
||||
user := "user1"
|
||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession(user)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser[user] = map[types.SessionKey]Session{
|
||||
key: session,
|
||||
}
|
||||
r.slugIndex[key] = user
|
||||
},
|
||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||
wantErr: nil,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "session not found in slugIndex",
|
||||
setupFunc: func(r *registry) {},
|
||||
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
{
|
||||
name: "session not found in byUser",
|
||||
setupFunc: func(r *registry) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
|
||||
},
|
||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[types.SessionKey]Session),
|
||||
slugIndex: make(map[types.SessionKey]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
tt.setupFunc(r)
|
||||
|
||||
session, err := r.Get(tt.key)
|
||||
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantResult, session != nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_GetWithUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(r *registry)
|
||||
user string
|
||||
key types.SessionKey
|
||||
wantErr error
|
||||
wantResult bool
|
||||
}{
|
||||
{
|
||||
name: "session found",
|
||||
setupFunc: func(r *registry) {
|
||||
user := "user1"
|
||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser[user] = map[types.SessionKey]Session{
|
||||
key: session,
|
||||
}
|
||||
r.slugIndex[key] = user
|
||||
},
|
||||
user: "user1",
|
||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||
wantErr: nil,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "session not found in slugIndex",
|
||||
setupFunc: func(r *registry) {},
|
||||
user: "user1",
|
||||
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
{
|
||||
name: "session not found in byUser",
|
||||
setupFunc: func(r *registry) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
|
||||
},
|
||||
user: "user1",
|
||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[types.SessionKey]Session),
|
||||
slugIndex: make(map[types.SessionKey]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
tt.setupFunc(r)
|
||||
|
||||
session, err := r.GetWithUser(tt.user, tt.key)
|
||||
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantResult, session != nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Update(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user string
|
||||
setupFunc func(r *registry) (oldKey, newKey types.SessionKey)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "change slug success",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession("user1")
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "change slug to already used slug",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
newKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
r.slugIndex[newKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrSlugInUse,
|
||||
},
|
||||
{
|
||||
name: "change slug to forbidden slug",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrForbiddenSlug,
|
||||
},
|
||||
{
|
||||
name: "change slug to invalid slug",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrInvalidSlug,
|
||||
},
|
||||
{
|
||||
name: "change slug but session not found",
|
||||
user: "user2",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
|
||||
}
|
||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
{
|
||||
name: "change slug but session is not in the map",
|
||||
user: "user2",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
|
||||
}
|
||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
{
|
||||
name: "change slug with same slug",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrSlugUnchanged,
|
||||
},
|
||||
{
|
||||
name: "tcp tunnel cannot change slug",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
||||
newKey := oldKey
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrSlugChangeNotAllowed,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[types.SessionKey]Session),
|
||||
slugIndex: make(map[types.SessionKey]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
|
||||
oldKey, newKey := tt.setupFunc(r)
|
||||
|
||||
err := r.Update(tt.user, oldKey, newKey)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
if err == nil {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
_, ok := r.byUser[tt.user][newKey]
|
||||
assert.True(t, ok, "newKey not found in registry")
|
||||
_, ok = r.byUser[tt.user][oldKey]
|
||||
assert.False(t, ok, "oldKey still exists in registry")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Register(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user string
|
||||
setupFunc func(r *registry) Key
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "register new key successfully",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) Key {
|
||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
return key
|
||||
},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "register already existing key fails",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) Key {
|
||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
r.byUser["user1"] = map[Key]Session{key: session}
|
||||
r.slugIndex[key] = "user1"
|
||||
r.mu.Unlock()
|
||||
|
||||
return key
|
||||
},
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "register multiple keys for same user",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) Key {
|
||||
firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
r.mu.Lock()
|
||||
r.byUser["user1"] = map[Key]Session{firstKey: session}
|
||||
r.slugIndex[firstKey] = "user1"
|
||||
r.mu.Unlock()
|
||||
|
||||
return types.SessionKey{Id: "second", Type: types.TunnelTypeHTTP}
|
||||
},
|
||||
wantOK: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
|
||||
key := tt.setupFunc(r)
|
||||
session := createMockSession()
|
||||
|
||||
ok := r.Register(key, session)
|
||||
assert.Equal(t, tt.wantOK, ok)
|
||||
|
||||
if ok {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
assert.Equal(t, session, r.byUser[tt.user][key], "session not stored in byUser")
|
||||
assert.Equal(t, tt.user, r.slugIndex[key], "slugIndex not updated")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_GetAllSessionFromUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(r *registry) string
|
||||
expectN int
|
||||
}{
|
||||
{
|
||||
name: "user has no sessions",
|
||||
setupFunc: func(r *registry) string {
|
||||
return "user1"
|
||||
},
|
||||
expectN: 0,
|
||||
},
|
||||
{
|
||||
name: "user has multiple sessions",
|
||||
setupFunc: func(r *registry) string {
|
||||
user := "user1"
|
||||
key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||
key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
|
||||
r.mu.Lock()
|
||||
r.byUser[user] = map[Key]Session{
|
||||
key1: createMockSession(),
|
||||
key2: createMockSession(),
|
||||
}
|
||||
r.mu.Unlock()
|
||||
return user
|
||||
},
|
||||
expectN: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
user := tt.setupFunc(r)
|
||||
sessions := r.GetAllSessionFromUser(user)
|
||||
assert.Len(t, sessions, tt.expectN)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Remove(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(r *registry) (string, types.SessionKey)
|
||||
key types.SessionKey
|
||||
verify func(*testing.T, *registry, string, types.SessionKey)
|
||||
}{
|
||||
{
|
||||
name: "remove existing key",
|
||||
setupFunc: func(r *registry) (string, types.SessionKey) {
|
||||
user := "user1"
|
||||
key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
r.mu.Lock()
|
||||
r.byUser[user] = map[Key]Session{key: session}
|
||||
r.slugIndex[key] = user
|
||||
r.mu.Unlock()
|
||||
return user, key
|
||||
},
|
||||
verify: func(t *testing.T, r *registry, user string, key types.SessionKey) {
|
||||
_, ok := r.byUser[user][key]
|
||||
assert.False(t, ok, "expected key to be removed from byUser")
|
||||
_, ok = r.slugIndex[key]
|
||||
assert.False(t, ok, "expected key to be removed from slugIndex")
|
||||
_, ok = r.byUser[user]
|
||||
assert.False(t, ok, "expected user to be removed from byUser map")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove non-existing key",
|
||||
setupFunc: func(r *registry) (string, types.SessionKey) {
|
||||
return "", types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
user, key := tt.setupFunc(r)
|
||||
if user == "" {
|
||||
key = tt.key
|
||||
}
|
||||
r.Remove(key)
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, r, user, key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidSlug(t *testing.T) {
|
||||
tests := []struct {
|
||||
slug string
|
||||
want bool
|
||||
}{
|
||||
{"abc", true},
|
||||
{"abc-123", true},
|
||||
{"a", false},
|
||||
{"verybigdihsixsevenlabubu", false},
|
||||
{"-iamsigma", false},
|
||||
{"ligma-", false},
|
||||
{"invalid$", false},
|
||||
{"valid-slug1", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.slug, func(t *testing.T) {
|
||||
got := isValidSlug(tt.slug)
|
||||
if got != tt.want {
|
||||
t.Errorf("isValidSlug(%q) = %v; want %v", tt.slug, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidSlugChar(t *testing.T) {
|
||||
tests := []struct {
|
||||
char byte
|
||||
want bool
|
||||
}{
|
||||
{'a', true},
|
||||
{'z', true},
|
||||
{'0', true},
|
||||
{'9', true},
|
||||
{'-', true},
|
||||
{'A', false},
|
||||
{'$', false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(string(tt.char), func(t *testing.T) {
|
||||
got := isValidSlugChar(tt.char)
|
||||
if got != tt.want {
|
||||
t.Errorf("isValidSlugChar(%q) = %v; want %v", tt.char, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsForbiddenSlug(t *testing.T) {
|
||||
forbiddenSlugs = map[string]struct{}{
|
||||
"admin": {},
|
||||
"root": {},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
slug string
|
||||
want bool
|
||||
}{
|
||||
{"admin", true},
|
||||
{"root", true},
|
||||
{"user", false},
|
||||
{"guest", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.slug, func(t *testing.T) {
|
||||
got := isForbiddenSlug(tt.slug)
|
||||
if got != tt.want {
|
||||
t.Errorf("isForbiddenSlug(%q) = %v; want %v", tt.slug, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,27 +4,28 @@ import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/registry"
|
||||
)
|
||||
|
||||
type httpServer struct {
|
||||
handler *httpHandler
|
||||
port string
|
||||
config config.Config
|
||||
}
|
||||
|
||||
func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
|
||||
func NewHTTPServer(config config.Config, sessionRegistry registry.Registry) Transport {
|
||||
return &httpServer{
|
||||
handler: newHTTPHandler(sessionRegistry, redirectTLS),
|
||||
port: port,
|
||||
handler: newHTTPHandler(config, sessionRegistry),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (ht *httpServer) Listen() (net.Listener, error) {
|
||||
return net.Listen("tcp", ":"+ht.port)
|
||||
return net.Listen("tcp", ":"+ht.config.HTTPPort())
|
||||
}
|
||||
|
||||
func (ht *httpServer) Serve(listener net.Listener) error {
|
||||
log.Printf("HTTP server is starting on port %s", ht.port)
|
||||
log.Printf("HTTP server is starting on port %s", ht.config.HTTPPort())
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
@@ -35,6 +36,6 @@ func (ht *httpServer) Serve(listener net.Listener) error {
|
||||
continue
|
||||
}
|
||||
|
||||
go ht.handler.handler(conn, false)
|
||||
go ht.handler.Handler(conn, false)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestNewHTTPServer(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
httpSrv, ok := srv.(*httpServer)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, msr, httpSrv.handler.sessionRegistry)
|
||||
assert.NotNil(t, srv)
|
||||
}
|
||||
|
||||
func TestHTTPServer_Listen(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, listener)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHTTPServer_Serve(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
err = srv.Serve(listener)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
}
|
||||
|
||||
func TestHTTPServer_Serve_AcceptError(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
err := srv.Serve(ml)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
ml.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestHTTPServer_Serve_Success(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
mockConfig.On("HeaderSize").Return(4096)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
listenerport := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
}
|
||||
|
||||
type mockListener struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockListener) Accept() (net.Conn, error) {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(net.Conn), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockListener) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockListener) Addr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -20,14 +22,14 @@ import (
|
||||
)
|
||||
|
||||
type httpHandler struct {
|
||||
config config.Config
|
||||
sessionRegistry registry.Registry
|
||||
redirectTLS bool
|
||||
}
|
||||
|
||||
func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
|
||||
func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler {
|
||||
return &httpHandler{
|
||||
config: config,
|
||||
sessionRegistry: sessionRegistry,
|
||||
redirectTLS: redirectTLS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,13 +52,28 @@ func (hh *httpHandler) badRequest(conn net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
||||
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
|
||||
defer hh.closeConnection(conn)
|
||||
|
||||
dstReader := bufio.NewReader(conn)
|
||||
reqhf, err := header.NewRequest(dstReader)
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
buf := make([]byte, hh.config.HeaderSize())
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
_ = hh.badRequest(conn)
|
||||
return
|
||||
}
|
||||
|
||||
if idx := bytes.Index(buf[:n], []byte("\r\n\r\n")); idx == -1 {
|
||||
_ = hh.badRequest(conn)
|
||||
return
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
|
||||
reqhf, err := header.NewRequest(buf[:n])
|
||||
if err != nil {
|
||||
log.Printf("Error creating request header: %v", err)
|
||||
_ = hh.badRequest(conn)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -67,7 +84,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
||||
}
|
||||
|
||||
if hh.shouldRedirectToTLS(isTLS) {
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")))
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.config.Domain()))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -75,13 +92,16 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
||||
return
|
||||
}
|
||||
|
||||
sshSession, err := hh.getSession(slug)
|
||||
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
|
||||
Id: slug,
|
||||
Type: types.TunnelTypeHTTP,
|
||||
})
|
||||
if err != nil {
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
|
||||
return
|
||||
}
|
||||
|
||||
hw := stream.New(conn, dstReader, conn.RemoteAddr())
|
||||
hw := stream.New(conn, conn, conn.RemoteAddr())
|
||||
defer func(hw stream.HTTP) {
|
||||
err = hw.Close()
|
||||
if err != nil {
|
||||
@@ -100,14 +120,14 @@ func (hh *httpHandler) closeConnection(conn net.Conn) {
|
||||
|
||||
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
|
||||
host := strings.Split(reqhf.Value("Host"), ".")
|
||||
if len(host) < 1 {
|
||||
if len(host) <= 1 {
|
||||
return "", errors.New("invalid host")
|
||||
}
|
||||
return host[0], nil
|
||||
}
|
||||
|
||||
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
|
||||
return !isTLS && hh.redirectTLS
|
||||
return !isTLS && hh.config.TLSRedirect()
|
||||
}
|
||||
|
||||
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
||||
@@ -126,34 +146,29 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
||||
))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 200 OK:", err)
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
|
||||
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
|
||||
Id: slug,
|
||||
Type: types.HTTP,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sshSession, nil
|
||||
}
|
||||
|
||||
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
||||
channel, err := hh.openForwardedChannel(hw, sshSession)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, hw.RemoteAddr())
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
defer func() {
|
||||
err = channel.Close()
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error closing forwarded channel: %v", err)
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
log.Printf("Failed to establish channel: %v", err)
|
||||
sshSession.Forwarder().WriteBadGatewayResponse(hw)
|
||||
return
|
||||
}
|
||||
|
||||
hh.setupMiddlewares(hw)
|
||||
|
||||
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
|
||||
@@ -163,47 +178,6 @@ func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.Requ
|
||||
sshSession.Forwarder().HandleConnection(hw, channel)
|
||||
}
|
||||
|
||||
func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) {
|
||||
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
|
||||
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
reqs <-chan *ssh.Request
|
||||
err error
|
||||
}
|
||||
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
|
||||
select {
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
default:
|
||||
hh.cleanupUnusedChannel(channel, reqs)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result.err != nil {
|
||||
return nil, result.err
|
||||
}
|
||||
go ssh.DiscardRequests(result.reqs)
|
||||
return result.channel, nil
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, errors.New("timeout opening forwarded-tcpip channel")
|
||||
}
|
||||
}
|
||||
|
||||
func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
|
||||
if channel != nil {
|
||||
if err := channel.Close(); err != nil {
|
||||
log.Printf("Failed to close unused channel: %v", err)
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
}
|
||||
}
|
||||
|
||||
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
|
||||
fingerprintMiddleware := middleware.NewTunnelFingerprint()
|
||||
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
|
||||
|
||||
@@ -0,0 +1,717 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||
args := m.Called(user, oldKey, newKey)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||
args := m.Called(key, session)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
args := m.Called(user)
|
||||
return args.Get(0).([]registry.Session)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
|
||||
args := m.Called()
|
||||
return args.Get(0).(lifecycle.Lifecycle)
|
||||
}
|
||||
|
||||
func (m *MockSession) Interaction() interaction.Interaction {
|
||||
args := m.Called()
|
||||
return args.Get(0).(interaction.Interaction)
|
||||
}
|
||||
|
||||
func (m *MockSession) Forwarder() forwarder.Forwarder {
|
||||
args := m.Called()
|
||||
return args.Get(0).(forwarder.Forwarder)
|
||||
}
|
||||
|
||||
func (m *MockSession) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
func (m *MockSession) Detail() *types.Detail {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*types.Detail)
|
||||
}
|
||||
|
||||
type MockSSHChannel struct {
|
||||
ssh.Channel
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSSHChannel) Write(data []byte) (int, error) {
|
||||
args := m.Called(data)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSSHChannel) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockForwarder struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||
m.Called(dst, src)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) TunnelType() types.TunnelType {
|
||||
args := m.Called()
|
||||
return args.Get(0).(types.TunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) ForwardedPort() uint16 {
|
||||
args := m.Called()
|
||||
return uint16(args.Int(0))
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
|
||||
m.Called(tunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetForwardedPort(port uint16) {
|
||||
m.Called(port)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetListener(listener net.Listener) {
|
||||
m.Called(listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Listener() net.Listener {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
args := m.Called(ctx, origin)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
|
||||
type MockConn struct {
|
||||
mock.Mock
|
||||
ReadBuffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
if m.ReadBuffer != nil {
|
||||
return m.ReadBuffer.Read(b)
|
||||
}
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
if args.Int(0) == -1 {
|
||||
return len(b), args.Error(1)
|
||||
}
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
type wrappedConn struct {
|
||||
net.Conn
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (c *wrappedConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func TestNewHTTPHandler(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("Domain").Return("domain")
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
hh := newHTTPHandler(mockConfig, msr)
|
||||
assert.NotNil(t, hh)
|
||||
assert.Equal(t, msr, hh.sessionRegistry)
|
||||
}
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isTLS bool
|
||||
redirectTLS bool
|
||||
request []byte
|
||||
expected []byte
|
||||
setupMocks func(*MockSessionRegistry)
|
||||
setupConn func() (net.Conn, net.Conn)
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "bad request - invalid host",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - missing host",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "isTLS true and redirectTLS true - no redirect",
|
||||
isTLS: true,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redirect to TLS",
|
||||
isTLS: false,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: tunnel.example.com\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnel.example.com/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle ping request",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "session not found",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return((registry.Session)(nil), fmt.Errorf("session not found"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - invalid http",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("INVALID\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - header too large",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: test.domain\r\n%s\r\n\r\n", strings.Repeat("test", 10000))),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - no request",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte(""),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - open channel fails",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - send initial request fails",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
go func() {
|
||||
for range reqCh {
|
||||
}
|
||||
}()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - success",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.ReadWriter)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
|
||||
})
|
||||
|
||||
go func() {
|
||||
for range reqCh {
|
||||
}
|
||||
}()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redirect - write failure",
|
||||
isTLS: false,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(-1, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - write failure",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read error - connection failure",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte(""),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Read", mock.Anything).Return(0, fmt.Errorf("connection reset by peer"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle ping request - write failure",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "close connection - error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(182, nil)
|
||||
mc.On("Close").Return(fmt.Errorf("close error"))
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - stream close error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return()
|
||||
},
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
|
||||
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
mc.On("RemoteAddr").Return(addr)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - middleware failure",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
|
||||
return k.Id == "test"
|
||||
})).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
},
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Close").Return(nil).Times(2)
|
||||
mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - channel close error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.ReadWriter)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
|
||||
})
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - open channel timeout",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
ctx := args.Get(0).(context.Context)
|
||||
<-ctx.Done()
|
||||
}).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSessionRegistry := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
mockConfig.On("HeaderSize").Return(4096)
|
||||
mockConfig.On("TLSRedirect").Return(true)
|
||||
hh := &httpHandler{
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
config: mockConfig,
|
||||
}
|
||||
|
||||
if tt.setupMocks != nil {
|
||||
tt.setupMocks(mockSessionRegistry)
|
||||
}
|
||||
|
||||
var serverConn, clientConn net.Conn
|
||||
if tt.setupConn != nil {
|
||||
serverConn, clientConn = tt.setupConn()
|
||||
} else {
|
||||
serverConn, clientConn = net.Pipe()
|
||||
}
|
||||
|
||||
if clientConn != nil {
|
||||
defer func(clientConn net.Conn) {
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
}(clientConn)
|
||||
}
|
||||
|
||||
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
var wrappedServerConn net.Conn
|
||||
if _, ok := serverConn.(*MockConn); ok {
|
||||
wrappedServerConn = serverConn
|
||||
} else {
|
||||
wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr}
|
||||
}
|
||||
|
||||
responseChan := make(chan []byte, 1)
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
if clientConn != nil {
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
var res []byte
|
||||
for {
|
||||
buf := make([]byte, 4096)
|
||||
n, err := clientConn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
t.Logf("Error reading response: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
res = append(res, buf[:n]...)
|
||||
if len(tt.expected) > 0 && len(res) >= len(tt.expected) {
|
||||
break
|
||||
}
|
||||
}
|
||||
responseChan <- res
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := clientConn.Write(tt.request)
|
||||
if err != nil {
|
||||
t.Logf("Error writing request: %v", err)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
close(responseChan)
|
||||
close(doneChan)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
hh.Handler(wrappedServerConn, tt.isTLS)
|
||||
}()
|
||||
|
||||
select {
|
||||
case response := <-responseChan:
|
||||
if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" {
|
||||
resStr := string(response)
|
||||
assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n"))
|
||||
assert.Contains(t, resStr, "Content-Length: 5\r\n")
|
||||
assert.Contains(t, resStr, "Server: Tunnel Please\r\n")
|
||||
assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello"))
|
||||
} else {
|
||||
assert.Equal(t, string(tt.expected), string(response))
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
if clientConn != nil {
|
||||
t.Fatal("Test timeout - no response received")
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if clientConn != nil {
|
||||
<-doneChan
|
||||
}
|
||||
|
||||
mockSessionRegistry.AssertExpectations(t)
|
||||
if mc, ok := serverConn.(*MockConn); ok {
|
||||
mc.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+11
-15
@@ -5,34 +5,30 @@ import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/registry"
|
||||
)
|
||||
|
||||
type https struct {
|
||||
config config.Config
|
||||
tlsConfig *tls.Config
|
||||
httpHandler *httpHandler
|
||||
domain string
|
||||
port string
|
||||
}
|
||||
|
||||
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
|
||||
func NewHTTPSServer(config config.Config, sessionRegistry registry.Registry, tlsConfig *tls.Config) Transport {
|
||||
return &https{
|
||||
httpHandler: newHTTPHandler(sessionRegistry, redirectTLS),
|
||||
domain: domain,
|
||||
port: port,
|
||||
config: config,
|
||||
tlsConfig: tlsConfig,
|
||||
httpHandler: newHTTPHandler(config, sessionRegistry),
|
||||
}
|
||||
}
|
||||
|
||||
func (ht *https) Listen() (net.Listener, error) {
|
||||
tlsConfig, err := NewTLSConfig(ht.domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tls.Listen("tcp", ":"+ht.port, tlsConfig)
|
||||
|
||||
return tls.Listen("tcp", ":"+ht.config.HTTPSPort(), ht.tlsConfig)
|
||||
}
|
||||
|
||||
func (ht *https) Serve(listener net.Listener) error {
|
||||
log.Printf("HTTPS server is starting on port %s", ht.port)
|
||||
log.Printf("HTTPS server is starting on port %s", ht.config.HTTPSPort())
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
@@ -43,6 +39,6 @@ func (ht *https) Serve(listener net.Listener) error {
|
||||
continue
|
||||
}
|
||||
|
||||
go ht.httpHandler.handler(conn, true)
|
||||
go ht.httpHandler.Handler(conn, true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewHTTPSServer(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
tlsConfig := &tls.Config{}
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
httpsSrv, ok := srv.(*https)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tlsConfig, httpsSrv.tlsConfig)
|
||||
assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry)
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Listen(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
if err != nil {
|
||||
t.Skip("Skipping tls.Listen test as it requires valid certificates/setup:", err)
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, listener)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Serve(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
err = srv.Serve(listener)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
err := srv.Serve(ml)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
ml.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Serve_Success(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
mockConfig.On("HeaderSize").Return(4096)
|
||||
|
||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
listenerport := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -1,27 +1,28 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type tcp struct {
|
||||
port uint16
|
||||
forwarder forwarder
|
||||
forwarder Forwarder
|
||||
}
|
||||
|
||||
type forwarder interface {
|
||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
||||
type Forwarder interface {
|
||||
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
||||
}
|
||||
|
||||
func NewTCPServer(port uint16, forwarder forwarder) Transport {
|
||||
func NewTCPServer(port uint16, forwarder Forwarder) Transport {
|
||||
return &tcp{
|
||||
port: port,
|
||||
forwarder: forwarder,
|
||||
@@ -53,11 +54,11 @@ func (tt *tcp) handleTcp(conn net.Conn) {
|
||||
log.Printf("Failed to close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||
channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, conn.RemoteAddr())
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestNewTCPServer(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
port := uint16(9000)
|
||||
|
||||
srv := NewTCPServer(port, mf)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
tcpSrv, ok := srv.(*tcp)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, port, tcpSrv.port)
|
||||
assert.Equal(t, mf, tcpSrv.forwarder)
|
||||
}
|
||||
|
||||
func TestTCPServer_Listen(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, listener)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTCPServer_Serve(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
err = srv.Serve(listener)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestTCPServer_Serve_AcceptError(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
err := srv.Serve(ml)
|
||||
assert.Nil(t, err)
|
||||
ml.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_Serve_Success(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
reqs := make(chan *ssh.Request)
|
||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
|
||||
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
mf.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_handleTcp_Success(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf).(*tcp)
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer func(clientConn net.Conn) {
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
}(clientConn)
|
||||
|
||||
reqs := make(chan *ssh.Request)
|
||||
mockChannel := new(MockSSHChannel)
|
||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
|
||||
|
||||
mf.On("HandleConnection", serverConn, mockChannel).Return()
|
||||
|
||||
srv.handleTcp(serverConn)
|
||||
|
||||
mf.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_handleTcp_CloseError(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf).(*tcp)
|
||||
|
||||
mc := new(MockConn)
|
||||
mc.On("Close").Return(errors.New("close error"))
|
||||
mc.On("RemoteAddr").Return(&net.TCPAddr{})
|
||||
|
||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||
|
||||
srv.handleTcp(mc)
|
||||
mc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf).(*tcp)
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer func(clientConn net.Conn) {
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
}(clientConn)
|
||||
|
||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||
|
||||
srv.handleTcp(serverConn)
|
||||
|
||||
mf.AssertExpectations(t)
|
||||
}
|
||||
+285
-180
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
@@ -16,17 +17,27 @@ import (
|
||||
"github.com/libdns/cloudflare"
|
||||
)
|
||||
|
||||
type TLSManager interface {
|
||||
userCertsExistAndValid() bool
|
||||
loadUserCerts() error
|
||||
startCertWatcher()
|
||||
initCertMagic() error
|
||||
getTLSConfig() *tls.Config
|
||||
getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
func NewTLSConfig(config config.Config) (*tls.Config, error) {
|
||||
var initErr error
|
||||
|
||||
tlsManagerOnce.Do(func() {
|
||||
tm := createTLSManager(config)
|
||||
initErr = tm.initialize()
|
||||
if initErr == nil {
|
||||
globalTLSManager = tm
|
||||
}
|
||||
})
|
||||
|
||||
if initErr != nil {
|
||||
return nil, initErr
|
||||
}
|
||||
|
||||
return globalTLSManager.getTLSConfig(), nil
|
||||
}
|
||||
|
||||
type tlsManager struct {
|
||||
domain string
|
||||
config config.Config
|
||||
|
||||
certPath string
|
||||
keyPath string
|
||||
storagePath string
|
||||
@@ -39,64 +50,60 @@ type tlsManager struct {
|
||||
useCertMagic bool
|
||||
}
|
||||
|
||||
var globalTLSManager TLSManager
|
||||
var globalTLSManager *tlsManager
|
||||
var tlsManagerOnce sync.Once
|
||||
|
||||
func NewTLSConfig(domain string) (*tls.Config, error) {
|
||||
var initErr error
|
||||
func createTLSManager(cfg config.Config) *tlsManager {
|
||||
storagePath := cfg.TLSStoragePath()
|
||||
cleanBase := filepath.Clean(storagePath)
|
||||
|
||||
tlsManagerOnce.Do(func() {
|
||||
certPath := "certs/tls/cert.pem"
|
||||
keyPath := "certs/tls/privkey.pem"
|
||||
storagePath := "certs/tls/certmagic"
|
||||
|
||||
tm := &tlsManager{
|
||||
domain: domain,
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
storagePath: storagePath,
|
||||
}
|
||||
|
||||
if tm.userCertsExistAndValid() {
|
||||
log.Printf("Using user-provided certificates from %s and %s", certPath, keyPath)
|
||||
if err := tm.loadUserCerts(); err != nil {
|
||||
initErr = fmt.Errorf("failed to load user certificates: %w", err)
|
||||
return
|
||||
}
|
||||
tm.useCertMagic = false
|
||||
tm.startCertWatcher()
|
||||
} else {
|
||||
if !isACMEConfigComplete() {
|
||||
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
|
||||
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
|
||||
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
|
||||
if err := tm.initCertMagic(); err != nil {
|
||||
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
|
||||
return
|
||||
}
|
||||
tm.useCertMagic = true
|
||||
}
|
||||
|
||||
globalTLSManager = tm
|
||||
})
|
||||
|
||||
if initErr != nil {
|
||||
return nil, initErr
|
||||
return &tlsManager{
|
||||
config: cfg,
|
||||
certPath: filepath.Join(cleanBase, "cert.pem"),
|
||||
keyPath: filepath.Join(cleanBase, "privkey.pem"),
|
||||
storagePath: filepath.Join(cleanBase, "certmagic"),
|
||||
}
|
||||
|
||||
return globalTLSManager.getTLSConfig(), nil
|
||||
}
|
||||
|
||||
func isACMEConfigComplete() bool {
|
||||
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
|
||||
return cfAPIToken != ""
|
||||
func (tm *tlsManager) initialize() error {
|
||||
if tm.userCertsExistAndValid() {
|
||||
return tm.initializeWithUserCerts()
|
||||
}
|
||||
return tm.initializeWithCertMagic()
|
||||
}
|
||||
|
||||
func (tm *tlsManager) initializeWithUserCerts() error {
|
||||
log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath)
|
||||
|
||||
if err := tm.loadUserCerts(); err != nil {
|
||||
return fmt.Errorf("failed to load user certificates: %w", err)
|
||||
}
|
||||
|
||||
tm.useCertMagic = false
|
||||
tm.startCertWatcher()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) initializeWithCertMagic() error {
|
||||
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic",
|
||||
tm.config.Domain(), tm.config.Domain())
|
||||
|
||||
if err := tm.initCertMagic(); err != nil {
|
||||
return fmt.Errorf("failed to initialize CertMagic: %w", err)
|
||||
}
|
||||
|
||||
tm.useCertMagic = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) userCertsExistAndValid() bool {
|
||||
if !tm.certFilesExist() {
|
||||
return false
|
||||
}
|
||||
return validateCertDomains(tm.certPath, tm.config.Domain())
|
||||
}
|
||||
|
||||
func (tm *tlsManager) certFilesExist() bool {
|
||||
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
||||
log.Printf("Certificate file not found: %s", tm.certPath)
|
||||
return false
|
||||
@@ -105,66 +112,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
|
||||
log.Printf("Key file not found: %s", tm.keyPath)
|
||||
return false
|
||||
}
|
||||
|
||||
return ValidateCertDomains(tm.certPath, tm.domain)
|
||||
}
|
||||
|
||||
func ValidateCertDomains(certPath, domain string) bool {
|
||||
certPEM, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read certificate: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
log.Printf("Failed to decode PEM block from certificate")
|
||||
return false
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse certificate: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().After(cert.NotAfter) {
|
||||
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) {
|
||||
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
var certDomains []string
|
||||
if cert.Subject.CommonName != "" {
|
||||
certDomains = append(certDomains, cert.Subject.CommonName)
|
||||
}
|
||||
certDomains = append(certDomains, cert.DNSNames...)
|
||||
|
||||
hasBase := false
|
||||
hasWildcard := false
|
||||
wildcardDomain := "*." + domain
|
||||
|
||||
for _, d := range certDomains {
|
||||
if d == domain {
|
||||
hasBase = true
|
||||
}
|
||||
if d == wildcardDomain {
|
||||
hasWildcard = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasBase {
|
||||
log.Printf("Certificate does not cover base domain: %s", domain)
|
||||
}
|
||||
if !hasWildcard {
|
||||
log.Printf("Certificate does not cover wildcard domain: %s", wildcardDomain)
|
||||
}
|
||||
|
||||
return hasBase && hasWildcard
|
||||
return true
|
||||
}
|
||||
|
||||
func (tm *tlsManager) loadUserCerts() error {
|
||||
@@ -183,74 +131,36 @@ func (tm *tlsManager) loadUserCerts() error {
|
||||
|
||||
func (tm *tlsManager) startCertWatcher() {
|
||||
go func() {
|
||||
var lastCertMod, lastKeyMod time.Time
|
||||
|
||||
if info, err := os.Stat(tm.certPath); err == nil {
|
||||
lastCertMod = info.ModTime()
|
||||
}
|
||||
if info, err := os.Stat(tm.keyPath); err == nil {
|
||||
lastKeyMod = info.ModTime()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
certInfo, certErr := os.Stat(tm.certPath)
|
||||
keyInfo, keyErr := os.Stat(tm.keyPath)
|
||||
|
||||
if certErr != nil || keyErr != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
|
||||
log.Printf("Certificate files changed, reloading...")
|
||||
|
||||
if !ValidateCertDomains(tm.certPath, tm.domain) {
|
||||
log.Printf("New certificates don't cover required domains")
|
||||
|
||||
if !isACMEConfigComplete() {
|
||||
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("Switching to CertMagic for automatic certificate management")
|
||||
if err := tm.initCertMagic(); err != nil {
|
||||
log.Printf("Failed to initialize CertMagic: %v", err)
|
||||
continue
|
||||
}
|
||||
tm.useCertMagic = true
|
||||
return
|
||||
}
|
||||
|
||||
if err := tm.loadUserCerts(); err != nil {
|
||||
log.Printf("Failed to reload certificates: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
lastCertMod = certInfo.ModTime()
|
||||
lastKeyMod = keyInfo.ModTime()
|
||||
log.Printf("Certificates reloaded successfully")
|
||||
}
|
||||
}
|
||||
watcher := newCertWatcher(tm)
|
||||
watcher.watch()
|
||||
}()
|
||||
}
|
||||
|
||||
func (tm *tlsManager) initCertMagic() error {
|
||||
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
||||
if err := tm.createStorageDirectory(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain)
|
||||
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
|
||||
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
|
||||
|
||||
if cfAPIToken == "" {
|
||||
if tm.config.CFAPIToken() == "" {
|
||||
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
||||
}
|
||||
|
||||
magic := tm.createCertMagicConfig()
|
||||
tm.magic = magic
|
||||
|
||||
return tm.obtainCertificates(magic)
|
||||
}
|
||||
|
||||
func (tm *tlsManager) createStorageDirectory() error {
|
||||
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) createCertMagicConfig() *certmagic.Config {
|
||||
cfProvider := &cloudflare.Provider{
|
||||
APIToken: cfAPIToken,
|
||||
APIToken: tm.config.CFAPIToken(),
|
||||
}
|
||||
|
||||
storage := &certmagic.FileStorage{Path: tm.storagePath}
|
||||
@@ -265,8 +175,15 @@ func (tm *tlsManager) initCertMagic() error {
|
||||
Storage: storage,
|
||||
})
|
||||
|
||||
acmeIssuer := tm.createACMEIssuer(magic, cfProvider)
|
||||
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
||||
|
||||
return magic
|
||||
}
|
||||
|
||||
func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer {
|
||||
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
||||
Email: acmeEmail,
|
||||
Email: tm.config.ACMEEmail(),
|
||||
Agreed: true,
|
||||
DNS01Solver: &certmagic.DNS01Solver{
|
||||
DNSManager: certmagic.DNSManager{
|
||||
@@ -275,7 +192,7 @@ func (tm *tlsManager) initCertMagic() error {
|
||||
},
|
||||
})
|
||||
|
||||
if acmeStaging {
|
||||
if tm.config.ACMEStaging() {
|
||||
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
|
||||
log.Printf("Using Let's Encrypt staging server")
|
||||
} else {
|
||||
@@ -283,10 +200,11 @@ func (tm *tlsManager) initCertMagic() error {
|
||||
log.Printf("Using Let's Encrypt production server")
|
||||
}
|
||||
|
||||
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
||||
tm.magic = magic
|
||||
return acmeIssuer
|
||||
}
|
||||
|
||||
domains := []string{tm.domain, "*." + tm.domain}
|
||||
func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error {
|
||||
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
|
||||
log.Printf("Requesting certificates for: %v", domains)
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -328,3 +246,190 @@ func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certifica
|
||||
|
||||
return tm.userCert, nil
|
||||
}
|
||||
|
||||
func validateCertDomains(certPath, domain string) bool {
|
||||
cert, err := loadAndParseCertificate(certPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !isCertificateValid(cert) {
|
||||
return false
|
||||
}
|
||||
|
||||
return certCoversRequiredDomains(cert, domain)
|
||||
}
|
||||
|
||||
func loadAndParseCertificate(certPath string) (*x509.Certificate, error) {
|
||||
certPEM, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read certificate: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
log.Printf("Failed to decode PEM block from certificate")
|
||||
return nil, fmt.Errorf("failed to decode PEM block")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse certificate: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func isCertificateValid(cert *x509.Certificate) bool {
|
||||
now := time.Now()
|
||||
|
||||
if now.After(cert.NotAfter) {
|
||||
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
thirtyDaysFromNow := now.Add(30 * 24 * time.Hour)
|
||||
if thirtyDaysFromNow.After(cert.NotAfter) {
|
||||
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool {
|
||||
certDomains := extractCertDomains(cert)
|
||||
hasBase, hasWildcard := checkDomainCoverage(certDomains, domain)
|
||||
|
||||
logDomainCoverage(hasBase, hasWildcard, domain)
|
||||
return hasBase && hasWildcard
|
||||
}
|
||||
|
||||
func extractCertDomains(cert *x509.Certificate) []string {
|
||||
var domains []string
|
||||
if cert.Subject.CommonName != "" {
|
||||
domains = append(domains, cert.Subject.CommonName)
|
||||
}
|
||||
domains = append(domains, cert.DNSNames...)
|
||||
return domains
|
||||
}
|
||||
|
||||
func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) {
|
||||
wildcardDomain := "*." + domain
|
||||
|
||||
for _, d := range certDomains {
|
||||
if d == domain {
|
||||
hasBase = true
|
||||
}
|
||||
if d == wildcardDomain {
|
||||
hasWildcard = true
|
||||
}
|
||||
}
|
||||
|
||||
return hasBase, hasWildcard
|
||||
}
|
||||
|
||||
func logDomainCoverage(hasBase, hasWildcard bool, domain string) {
|
||||
if !hasBase {
|
||||
log.Printf("Certificate does not cover base domain: %s", domain)
|
||||
}
|
||||
if !hasWildcard {
|
||||
log.Printf("Certificate does not cover wildcard domain: *.%s", domain)
|
||||
}
|
||||
}
|
||||
|
||||
type certWatcher struct {
|
||||
tm *tlsManager
|
||||
lastCertMod time.Time
|
||||
lastKeyMod time.Time
|
||||
}
|
||||
|
||||
func newCertWatcher(tm *tlsManager) *certWatcher {
|
||||
watcher := &certWatcher{tm: tm}
|
||||
watcher.initializeModTimes()
|
||||
return watcher
|
||||
}
|
||||
|
||||
func (cw *certWatcher) initializeModTimes() {
|
||||
if info, err := os.Stat(cw.tm.certPath); err == nil {
|
||||
cw.lastCertMod = info.ModTime()
|
||||
}
|
||||
if info, err := os.Stat(cw.tm.keyPath); err == nil {
|
||||
cw.lastKeyMod = info.ModTime()
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *certWatcher) watch() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if cw.checkAndReloadCerts() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *certWatcher) checkAndReloadCerts() bool {
|
||||
certInfo, keyInfo, err := cw.getFileInfo()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !cw.filesModified(certInfo, keyInfo) {
|
||||
return false
|
||||
}
|
||||
|
||||
return cw.handleCertificateChange(certInfo, keyInfo)
|
||||
}
|
||||
|
||||
func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) {
|
||||
certInfo, certErr := os.Stat(cw.tm.certPath)
|
||||
keyInfo, keyErr := os.Stat(cw.tm.keyPath)
|
||||
|
||||
if certErr != nil || keyErr != nil {
|
||||
return nil, nil, fmt.Errorf("file stat error")
|
||||
}
|
||||
|
||||
return certInfo, keyInfo, nil
|
||||
}
|
||||
|
||||
func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool {
|
||||
return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod)
|
||||
}
|
||||
|
||||
func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool {
|
||||
log.Printf("Certificate files changed, reloading...")
|
||||
|
||||
if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) {
|
||||
return cw.switchToCertMagic()
|
||||
}
|
||||
|
||||
if err := cw.tm.loadUserCerts(); err != nil {
|
||||
log.Printf("Failed to reload certificates: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
cw.updateModTimes(certInfo, keyInfo)
|
||||
log.Printf("Certificates reloaded successfully")
|
||||
return false
|
||||
}
|
||||
|
||||
func (cw *certWatcher) switchToCertMagic() bool {
|
||||
log.Printf("New certificates don't cover required domains")
|
||||
|
||||
if err := cw.tm.initCertMagic(); err != nil {
|
||||
log.Printf("Failed to initialize CertMagic: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
cw.tm.useCertMagic = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) {
|
||||
cw.lastCertMod = certInfo.ModTime()
|
||||
cw.lastKeyMod = keyInfo.ModTime()
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -8,3 +8,7 @@ type Transport interface {
|
||||
Listen() (net.Listener, error)
|
||||
Serve(listener net.Listener) error
|
||||
}
|
||||
|
||||
type HTTP interface {
|
||||
Handler(conn net.Conn, isTLS bool)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVersionFunctions(t *testing.T) {
|
||||
origVersion := Version
|
||||
origBuildDate := BuildDate
|
||||
origCommit := Commit
|
||||
defer func() {
|
||||
Version = origVersion
|
||||
BuildDate = origBuildDate
|
||||
Commit = origCommit
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
version string
|
||||
buildDate string
|
||||
commit string
|
||||
wantFull string
|
||||
wantShort string
|
||||
}{
|
||||
{
|
||||
name: "Default dev version",
|
||||
version: "dev",
|
||||
buildDate: "unknown",
|
||||
commit: "unknown",
|
||||
wantFull: "tunnel_pls dev (commit: unknown, built: unknown)",
|
||||
wantShort: "dev",
|
||||
},
|
||||
{
|
||||
name: "Release version",
|
||||
version: "v1.0.0",
|
||||
buildDate: "2026-01-23",
|
||||
commit: "abcdef123",
|
||||
wantFull: "tunnel_pls v1.0.0 (commit: abcdef123, built: 2026-01-23)",
|
||||
wantShort: "v1.0.0",
|
||||
},
|
||||
{
|
||||
name: "Empty values",
|
||||
version: "",
|
||||
buildDate: "",
|
||||
commit: "",
|
||||
wantFull: "tunnel_pls (commit: , built: )",
|
||||
wantShort: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
Version = tt.version
|
||||
BuildDate = tt.buildDate
|
||||
Commit = tt.commit
|
||||
|
||||
gotFull := GetVersion()
|
||||
if gotFull != tt.wantFull {
|
||||
t.Errorf("GetVersion() = %q, want %q", gotFull, tt.wantFull)
|
||||
}
|
||||
|
||||
gotShort := GetShortVersion()
|
||||
if gotShort != tt.wantShort {
|
||||
t.Errorf("GetShortVersion() = %q, want %q", gotShort, tt.wantShort)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetVersion_Format(t *testing.T) {
|
||||
v := "1.2.3"
|
||||
c := "brainrot"
|
||||
d := "now"
|
||||
|
||||
Version = v
|
||||
Commit = c
|
||||
BuildDate = d
|
||||
|
||||
expected := fmt.Sprintf("tunnel_pls %s (commit: %s, built: %s)", v, c, d)
|
||||
if GetVersion() != expected {
|
||||
t.Errorf("GetVersion() formatting mismatch")
|
||||
}
|
||||
}
|
||||
@@ -1,28 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
"tunnel_pls/internal/bootstrap"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/grpc/client"
|
||||
"tunnel_pls/internal/key"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/internal/transport"
|
||||
"tunnel_pls/internal/version"
|
||||
"tunnel_pls/server"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -33,184 +18,19 @@ func main() {
|
||||
|
||||
log.SetOutput(os.Stdout)
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
log.Printf("Starting %s", version.GetVersion())
|
||||
|
||||
err := config.Load()
|
||||
conf, err := config.MustLoad()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load configuration: %s", err)
|
||||
return
|
||||
log.Fatalf("Config load error: %v", err)
|
||||
}
|
||||
|
||||
mode := strings.ToLower(config.Getenv("MODE", "standalone"))
|
||||
isNodeMode := mode == "node"
|
||||
|
||||
pprofEnabled := config.Getenv("PPROF_ENABLED", "false")
|
||||
if pprofEnabled == "true" {
|
||||
pprofPort := config.Getenv("PPROF_PORT", "6060")
|
||||
go func() {
|
||||
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
|
||||
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
|
||||
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||
log.Printf("pprof server error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
sshConfig := &ssh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
|
||||
}
|
||||
|
||||
sshKeyPath := "certs/ssh/id_rsa"
|
||||
if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
|
||||
log.Fatalf("Failed to generate SSH key: %s", err)
|
||||
}
|
||||
|
||||
privateBytes, err := os.ReadFile(sshKeyPath)
|
||||
boot, err := bootstrap.New(conf, port.New())
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load private key: %s", err)
|
||||
log.Fatalf("Startup error: %v", err)
|
||||
}
|
||||
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse private key: %s", err)
|
||||
}
|
||||
|
||||
sshConfig.AddHostKey(private)
|
||||
sessionRegistry := registry.NewRegistry()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
shutdownChan := make(chan os.Signal, 1)
|
||||
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
var grpcClient client.Client
|
||||
if isNodeMode {
|
||||
grpcHost := config.Getenv("GRPC_ADDRESS", "localhost")
|
||||
grpcPort := config.Getenv("GRPC_PORT", "8080")
|
||||
grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort)
|
||||
nodeToken := config.Getenv("NODE_TOKEN", "")
|
||||
if nodeToken == "" {
|
||||
log.Fatalf("NODE_TOKEN is required in node mode")
|
||||
}
|
||||
|
||||
grpcClient, err = client.New(grpcAddr, sessionRegistry)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create grpc client: %v", err)
|
||||
}
|
||||
|
||||
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
if err = grpcClient.CheckServerHealth(healthCtx); err != nil {
|
||||
healthCancel()
|
||||
log.Fatalf("gRPC health check failed: %v", err)
|
||||
}
|
||||
healthCancel()
|
||||
|
||||
go func() {
|
||||
identity := config.Getenv("DOMAIN", "localhost")
|
||||
if err = grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil {
|
||||
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
portManager := port.New()
|
||||
rawRange := config.Getenv("ALLOWED_PORTS", "")
|
||||
if rawRange != "" {
|
||||
splitRange := strings.Split(rawRange, "-")
|
||||
if len(splitRange) == 2 {
|
||||
var start, end uint64
|
||||
start, err = strconv.ParseUint(splitRange[0], 10, 16)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse start port: %s", err)
|
||||
}
|
||||
|
||||
end, err = strconv.ParseUint(splitRange[1], 10, 16)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse end port: %s", err)
|
||||
}
|
||||
|
||||
if err = portManager.AddRange(uint16(start), uint16(end)); err != nil {
|
||||
log.Fatalf("Failed to add port range: %s", err)
|
||||
}
|
||||
log.Printf("PortRegistry range configured: %d-%d", start, end)
|
||||
} else {
|
||||
log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange)
|
||||
}
|
||||
}
|
||||
|
||||
tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true"
|
||||
redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true"
|
||||
|
||||
go func() {
|
||||
httpPort := config.Getenv("HTTP_PORT", "8080")
|
||||
|
||||
var httpListener net.Listener
|
||||
httpserver := transport.NewHTTPServer(httpPort, sessionRegistry, redirectTLS)
|
||||
httpListener, err = httpserver.Listen()
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to start http server: %w", err)
|
||||
return
|
||||
}
|
||||
err = httpserver.Serve(httpListener)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error when serving http server: %w", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
if tlsEnabled {
|
||||
go func() {
|
||||
httpsPort := config.Getenv("HTTPS_PORT", "8443")
|
||||
domain := config.Getenv("DOMAIN", "localhost")
|
||||
|
||||
var httpListener net.Listener
|
||||
httpserver := transport.NewHTTPSServer(domain, httpsPort, sessionRegistry, redirectTLS)
|
||||
httpListener, err = httpserver.Listen()
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to start http server: %w", err)
|
||||
return
|
||||
}
|
||||
err = httpserver.Serve(httpListener)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("error when serving http server: %w", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
var app server.Server
|
||||
go func() {
|
||||
sshPort := config.Getenv("PORT", "2200")
|
||||
app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager, sshPort)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to start server: %s", err)
|
||||
return
|
||||
}
|
||||
app.Start()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err = <-errChan:
|
||||
log.Printf("error happen : %s", err)
|
||||
case sig := <-shutdownChan:
|
||||
log.Printf("received signal %s, shutting down", sig)
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
if app != nil {
|
||||
if err = app.Close(); err != nil {
|
||||
log.Printf("failed to close server : %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if grpcClient != nil {
|
||||
if err = grpcClient.Close(); err != nil {
|
||||
log.Printf("failed to close grpc conn : %s", err)
|
||||
}
|
||||
if err = boot.Run(); err != nil {
|
||||
log.Fatalf("Application error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
+23
-8
@@ -4,11 +4,14 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/grpc/client"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session"
|
||||
|
||||
@@ -20,24 +23,28 @@ type Server interface {
|
||||
Close() error
|
||||
}
|
||||
type server struct {
|
||||
randomizer random.Random
|
||||
config config.Config
|
||||
sshPort string
|
||||
sshListener net.Listener
|
||||
config *ssh.ServerConfig
|
||||
sshConfig *ssh.ServerConfig
|
||||
grpcClient client.Client
|
||||
sessionRegistry registry.Registry
|
||||
portRegistry port.Port
|
||||
}
|
||||
|
||||
func New(sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
|
||||
func New(randomizer random.Random, config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &server{
|
||||
randomizer: randomizer,
|
||||
config: config,
|
||||
sshPort: sshPort,
|
||||
sshListener: listener,
|
||||
config: sshConfig,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: grpcClient,
|
||||
sessionRegistry: sessionRegistry,
|
||||
portRegistry: portRegistry,
|
||||
@@ -66,7 +73,7 @@ func (s *server) Close() error {
|
||||
}
|
||||
|
||||
func (s *server) handleConnection(conn net.Conn) {
|
||||
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
|
||||
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
||||
if err != nil {
|
||||
log.Printf("failed to establish SSH connection: %v", err)
|
||||
err = conn.Close()
|
||||
@@ -79,7 +86,7 @@ func (s *server) handleConnection(conn net.Conn) {
|
||||
|
||||
defer func(sshConn *ssh.ServerConn) {
|
||||
err = sshConn.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
|
||||
log.Printf("failed to close SSH server: %v", err)
|
||||
}
|
||||
}(sshConn)
|
||||
@@ -92,11 +99,19 @@ func (s *server) handleConnection(conn net.Conn) {
|
||||
cancel()
|
||||
}
|
||||
log.Println("SSH connection established:", sshConn.User())
|
||||
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
|
||||
sshSession := session.New(&session.Config{
|
||||
Randomizer: s.randomizer,
|
||||
Config: s.config,
|
||||
Conn: sshConn,
|
||||
InitialReq: forwardingReqs,
|
||||
SshChan: chans,
|
||||
SessionRegistry: s.sessionRegistry,
|
||||
PortRegistry: s.portRegistry,
|
||||
User: user,
|
||||
})
|
||||
err = sshSession.Start()
|
||||
if err != nil {
|
||||
log.Printf("SSH session ended with error: %v", err)
|
||||
log.Printf("SSH session ended with error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -0,0 +1,880 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type MockRandom struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRandom) String(length int) (string, error) {
|
||||
args := m.Called(length)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
type MockConfig struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockConfig) Domain() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) Mode() types.ServerMode {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return 0
|
||||
}
|
||||
switch v := args.Get(0).(type) {
|
||||
case types.ServerMode:
|
||||
return v
|
||||
case int:
|
||||
return types.ServerMode(v)
|
||||
default:
|
||||
return types.ServerMode(args.Int(0))
|
||||
}
|
||||
}
|
||||
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
||||
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||
args := m.Called(user, oldKey, newKey)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||
args := m.Called(key, session)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
args := m.Called(user)
|
||||
return args.Get(0).([]registry.Session)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
type MockGRPCClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*grpc.ClientConn)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
||||
args := m.Called(ctx, token)
|
||||
return args.Bool(0), args.String(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
|
||||
args := m.Called(ctx, domain, token)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockPort struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
||||
return m.Called(startPort, endPort).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
||||
args := m.Called()
|
||||
return uint16(args.Int(0)), args.Bool(1)
|
||||
}
|
||||
|
||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
||||
return m.Called(port, assigned).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPort) Claim(port uint16) bool {
|
||||
return m.Called(port).Bool(0)
|
||||
}
|
||||
|
||||
type MockListener struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockListener) Accept() (net.Conn, error) {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(net.Conn), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockListener) Close() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
func (m *MockListener) Addr() net.Addr {
|
||||
return m.Called().Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
|
||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
signer, _ := ssh.NewSignerFromKey(key)
|
||||
config := &ssh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
}
|
||||
config.AddHostKey(signer)
|
||||
return config, signer
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
port: "0",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid port",
|
||||
port: "invalid",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, err := New(mr, mc, sc, mreg, mg, mp, tt.port)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
_ = s.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("port already in use", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := l.Addr().(*net.TCPAddr).Port
|
||||
defer func(l net.Listener) {
|
||||
err = l.Close()
|
||||
assert.NoError(t, err)
|
||||
}(l)
|
||||
|
||||
s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
t.Run("successful close", func(t *testing.T) {
|
||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
err := s.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("close already closed listener", func(t *testing.T) {
|
||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
_ = s.Close()
|
||||
err := s.Close()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("close with nil listener", func(t *testing.T) {
|
||||
s := &server{
|
||||
sshListener: nil,
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
assert.NotNil(t, r)
|
||||
}
|
||||
}()
|
||||
_ = s.Close()
|
||||
t.Fatal("expected panic for nil listener")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
t.Run("normal stop", func(t *testing.T) {
|
||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_ = s.Close()
|
||||
}()
|
||||
s.Start()
|
||||
})
|
||||
|
||||
t.Run("accept error - temporary error continues loop", func(t *testing.T) {
|
||||
ml := new(MockListener)
|
||||
s := &server{
|
||||
sshListener: ml,
|
||||
sshPort: "0",
|
||||
}
|
||||
|
||||
ml.On("Accept").Return(nil, errors.New("temporary error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s.Start()
|
||||
ml.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("accept error - immediate close", func(t *testing.T) {
|
||||
ml := new(MockListener)
|
||||
s := &server{
|
||||
sshListener: ml,
|
||||
sshPort: "0",
|
||||
}
|
||||
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s.Start()
|
||||
ml.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("accept success - connection fails SSH handshake", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
mockListener := &MockListener{}
|
||||
mockListener.On("Accept").Return(serverConn, nil).Once()
|
||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshListener: mockListener,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("accept success - valid SSH connection without auth", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
mockListener := &MockListener{}
|
||||
mockListener.On("Accept").Return(serverConn, nil).Once()
|
||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshListener: mockListener,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleConnection(t *testing.T) {
|
||||
t.Run("SSH handshake fails - connection closed", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
s.handleConnection(serverConn)
|
||||
})
|
||||
|
||||
// SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS
|
||||
// GONNA IMPLEMENT THIS UNIT TEST LATER
|
||||
|
||||
//t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) {
|
||||
// mockRandom := &MockRandom{}
|
||||
// mockConfig := &MockConfig{}
|
||||
// mockSessionRegistry := &MockSessionRegistry{}
|
||||
// mockGrpcClient := &MockGRPCClient{}
|
||||
// mockPort := &MockPort{}
|
||||
//
|
||||
// sshConfig, _ := getTestSSHConfig()
|
||||
//
|
||||
// serverConn, clientConn := net.Pipe()
|
||||
//
|
||||
// s := &server{
|
||||
// randomizer: mockRandom,
|
||||
// config: mockConfig,
|
||||
// sshPort: "0",
|
||||
// sshConfig: sshConfig,
|
||||
// grpcClient: mockGrpcClient,
|
||||
// sessionRegistry: mockSessionRegistry,
|
||||
// portRegistry: mockPort,
|
||||
// }
|
||||
//
|
||||
// done := make(chan bool, 1)
|
||||
//
|
||||
// go func() {
|
||||
// s.handleConnection(serverConn)
|
||||
// done <- true
|
||||
// }()
|
||||
//
|
||||
// go func() {
|
||||
// clientConn.Write([]byte("invalid ssh protocol\n"))
|
||||
// clientConn.Close()
|
||||
// }()
|
||||
//
|
||||
// select {
|
||||
// case <-done:
|
||||
// case <-time.After(1 * time.Second):
|
||||
// t.Fatal("handleConnection did not complete in time")
|
||||
// }
|
||||
//})
|
||||
|
||||
t.Run("SSH connection established without gRPC client", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
mockConfig.On("Domain").Return("test.com")
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("SSHPort").Return("2200")
|
||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func(listener net.Listener) {
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}(listener)
|
||||
|
||||
serverAddr := listener.Addr().String()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleConnection(conn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
||||
if err != nil {
|
||||
t.Logf("Client dial failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(client *ssh.Client) {
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}(client)
|
||||
|
||||
type forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
payload := ssh.Marshal(forwardPayload{
|
||||
BindAddr: "localhost",
|
||||
BindPort: 80,
|
||||
})
|
||||
|
||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
||||
if err != nil {
|
||||
t.Logf("Forward request failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("handleConnection completed")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SSH connection established with gRPC authorization", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
|
||||
mockConfig.On("Domain").Return("test.com")
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("SSHPort").Return("2200")
|
||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func(listener net.Listener) {
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}(listener)
|
||||
|
||||
serverAddr := listener.Addr().String()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleConnection(conn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
||||
if err != nil {
|
||||
t.Logf("Client dial failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(client *ssh.Client) {
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}(client)
|
||||
|
||||
type forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
payload := ssh.Marshal(forwardPayload{
|
||||
BindAddr: "localhost",
|
||||
BindPort: 80,
|
||||
})
|
||||
|
||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
||||
if err != nil {
|
||||
t.Logf("Forward request failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
mockGrpcClient.AssertExpectations(t)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SSH connection with gRPC authorization error", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
|
||||
mockConfig.On("Domain").Return("test.com")
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("SSHPort").Return("2200")
|
||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func(listener net.Listener) {
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}(listener)
|
||||
|
||||
serverAddr := listener.Addr().String()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleConnection(conn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
||||
if err != nil {
|
||||
t.Logf("Client dial failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(client *ssh.Client) {
|
||||
_ = client.Close()
|
||||
}(client)
|
||||
|
||||
type forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
payload := ssh.Marshal(forwardPayload{
|
||||
BindAddr: "localhost",
|
||||
BindPort: 8080,
|
||||
})
|
||||
|
||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
||||
if err != nil {
|
||||
t.Logf("Forward request failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
mockGrpcClient.AssertExpectations(t)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("connection cleanup on close", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
s.handleConnection(serverConn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
t.Run("full server lifecycle", func(t *testing.T) {
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
s, err := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err := s.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
s.Start()
|
||||
})
|
||||
|
||||
t.Run("multiple connections", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
conn1Server, conn1Client := net.Pipe()
|
||||
conn2Server, conn2Client := net.Pipe()
|
||||
|
||||
mockListener := &MockListener{}
|
||||
mockListener.On("Accept").Return(conn1Server, nil).Once()
|
||||
mockListener.On("Accept").Return(conn2Server, nil).Once()
|
||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshListener: mockListener,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = conn1Client.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = conn2Client.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorHandling(t *testing.T) {
|
||||
t.Run("write error during SSH handshake", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
s.handleConnection(serverConn)
|
||||
})
|
||||
}
|
||||
@@ -1,8 +1,7 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,7 +9,6 @@ import (
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
@@ -18,37 +16,6 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
bufSize := config.GetBufferSize()
|
||||
return make([]byte, bufSize)
|
||||
},
|
||||
}
|
||||
|
||||
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||
buf := bufferPool.Get().([]byte)
|
||||
defer bufferPool.Put(buf)
|
||||
return io.CopyBuffer(dst, src, buf)
|
||||
}
|
||||
|
||||
type forwarder struct {
|
||||
listener net.Listener
|
||||
tunnelType types.TunnelType
|
||||
forwardedPort uint16
|
||||
slug slug.Slug
|
||||
conn ssh.Conn
|
||||
}
|
||||
|
||||
func New(slug slug.Slug, conn ssh.Conn) Forwarder {
|
||||
return &forwarder{
|
||||
listener: nil,
|
||||
tunnelType: types.UNKNOWN,
|
||||
forwardedPort: 0,
|
||||
slug: slug,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
type Forwarder interface {
|
||||
SetType(tunnelType types.TunnelType)
|
||||
SetForwardedPort(port uint16)
|
||||
@@ -57,13 +24,43 @@ type Forwarder interface {
|
||||
TunnelType() types.TunnelType
|
||||
ForwardedPort() uint16
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
||||
WriteBadGatewayResponse(dst io.Writer)
|
||||
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
|
||||
Close() error
|
||||
}
|
||||
type forwarder struct {
|
||||
listener net.Listener
|
||||
tunnelType types.TunnelType
|
||||
forwardedPort uint16
|
||||
slug slug.Slug
|
||||
conn ssh.Conn
|
||||
bufferPool sync.Pool
|
||||
}
|
||||
|
||||
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
|
||||
return &forwarder{
|
||||
listener: nil,
|
||||
tunnelType: types.TunnelTypeUNKNOWN,
|
||||
forwardedPort: 0,
|
||||
slug: slug,
|
||||
conn: conn,
|
||||
bufferPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
bufSize := config.BufferSize()
|
||||
buf := make([]byte, bufSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||
buf := f.bufferPool.Get().(*[]byte)
|
||||
defer f.bufferPool.Put(buf)
|
||||
return io.CopyBuffer(dst, src, *buf)
|
||||
}
|
||||
|
||||
func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
payload := createForwardedTCPIPPayload(origin, f.forwardedPort)
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
reqs <-chan *ssh.Request
|
||||
@@ -75,13 +72,9 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s
|
||||
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
||||
select {
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
if channel != nil {
|
||||
err = channel.Close()
|
||||
if err != nil {
|
||||
log.Printf("Failed to close unused channel: %v", err)
|
||||
return
|
||||
}
|
||||
_ = channel.Close()
|
||||
go ssh.DiscardRequests(reqs)
|
||||
}
|
||||
}
|
||||
@@ -90,8 +83,8 @@ func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *s
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
return result.channel, result.reqs, result.err
|
||||
case <-time.After(5 * time.Second):
|
||||
return nil, nil, errors.New("timeout opening forwarded-tcpip channel")
|
||||
case <-ctx.Done():
|
||||
return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -107,7 +100,7 @@ func closeWriter(w io.Writer) error {
|
||||
|
||||
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
|
||||
var errs []error
|
||||
_, err := copyWithBuffer(dst, src)
|
||||
_, err := f.copyWithBuffer(dst, src)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
|
||||
}
|
||||
@@ -120,10 +113,7 @@ func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string)
|
||||
|
||||
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||
defer func() {
|
||||
_, err := io.Copy(io.Discard, src)
|
||||
if err != nil {
|
||||
log.Printf("Failed to discard connection: %v", err)
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, src)
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
@@ -174,14 +164,6 @@ func (f *forwarder) Listener() net.Listener {
|
||||
return f.listener
|
||||
}
|
||||
|
||||
func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||
_, err := dst.Write(types.BadGatewayResponse)
|
||||
if err != nil {
|
||||
log.Printf("failed to write Bad Gateway response: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarder) Close() error {
|
||||
if f.Listener() != nil {
|
||||
return f.listener.Close()
|
||||
@@ -189,43 +171,21 @@ func (f *forwarder) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
host, originPort := parseAddr(origin.String())
|
||||
|
||||
writeSSHString(&buf, "localhost")
|
||||
err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
writeSSHString(&buf, host)
|
||||
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func parseAddr(addr string) (string, uint16) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
||||
return "0.0.0.0", uint16(0)
|
||||
}
|
||||
func createForwardedTCPIPPayload(origin net.Addr, destPort uint16) []byte {
|
||||
host, portStr, _ := net.SplitHostPort(origin.String())
|
||||
port, _ := strconv.Atoi(portStr)
|
||||
return host, uint16(port)
|
||||
}
|
||||
|
||||
func writeSSHString(buffer *bytes.Buffer, str string) {
|
||||
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return
|
||||
forwardPayload := struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}{
|
||||
DestAddr: "localhost",
|
||||
DestPort: uint32(destPort),
|
||||
OriginAddr: host,
|
||||
OriginPort: uint32(port),
|
||||
}
|
||||
buffer.WriteString(str)
|
||||
|
||||
return ssh.Marshal(forwardPayload)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,34 +10,37 @@ import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
func (m *model) handleCommandSelection(item commandItem) (tea.Model, tea.Cmd) {
|
||||
switch item.name {
|
||||
case "slug":
|
||||
m.showingCommands = false
|
||||
m.editingSlug = true
|
||||
m.slugInput.SetValue(m.interaction.slug.String())
|
||||
m.slugInput.Focus()
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case "tunnel-type":
|
||||
m.showingCommands = false
|
||||
m.showingComingSoon = true
|
||||
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
|
||||
default:
|
||||
m.showingCommands = false
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) commandsUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
|
||||
switch {
|
||||
case key.Matches(msg, m.keymap.quit):
|
||||
case key.Matches(msg, m.keymap.quit), msg.String() == "esc":
|
||||
m.showingCommands = false
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case msg.String() == "enter":
|
||||
selectedItem := m.commandList.SelectedItem()
|
||||
if selectedItem != nil {
|
||||
item := selectedItem.(commandItem)
|
||||
if item.name == "slug" {
|
||||
m.showingCommands = false
|
||||
m.editingSlug = true
|
||||
m.slugInput.SetValue(m.interaction.slug.String())
|
||||
m.slugInput.Focus()
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
} else if item.name == "tunnel-type" {
|
||||
m.showingCommands = false
|
||||
m.showingComingSoon = true
|
||||
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
|
||||
}
|
||||
m.showingCommands = false
|
||||
return m, nil
|
||||
return m.handleCommandSelection(item)
|
||||
}
|
||||
case msg.String() == "esc":
|
||||
m.showingCommands = false
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
}
|
||||
m.commandList, cmd = m.commandList.Update(msg)
|
||||
return m, cmd
|
||||
|
||||
+140
-110
@@ -23,164 +23,194 @@ func (m *model) dashboardUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
func (m *model) dashboardView() string {
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
PaddingTop(1)
|
||||
|
||||
subtitleStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#888888")).
|
||||
Italic(true)
|
||||
|
||||
urlStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
Underline(true).
|
||||
Italic(true)
|
||||
|
||||
urlBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#04B575")).
|
||||
Bold(true).
|
||||
Italic(true)
|
||||
|
||||
keyHintStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
Bold(true)
|
||||
isCompact := shouldUseCompactLayout(m.width, BreakpointLarge)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(m.renderHeader(isCompact))
|
||||
b.WriteString(m.renderUserInfo(isCompact))
|
||||
b.WriteString(m.renderQuickActions(isCompact))
|
||||
b.WriteString(m.renderFooter(isCompact))
|
||||
|
||||
isCompact := shouldUseCompactLayout(m.width, 85)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
var asciiArtMargin int
|
||||
if isCompact {
|
||||
asciiArtMargin = 0
|
||||
} else {
|
||||
asciiArtMargin = 1
|
||||
}
|
||||
func (m *model) renderHeader(isCompact bool) string {
|
||||
var b strings.Builder
|
||||
|
||||
asciiArtMargin := getMarginValue(isCompact, 0, 1)
|
||||
asciiArtStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
Foreground(lipgloss.Color(ColorPrimary)).
|
||||
MarginBottom(asciiArtMargin)
|
||||
|
||||
var asciiArt string
|
||||
if shouldUseCompactLayout(m.width, 50) {
|
||||
asciiArt = "TUNNEL PLS"
|
||||
} else if isCompact {
|
||||
asciiArt = `
|
||||
b.WriteString(asciiArtStyle.Render(m.getASCIIArt()))
|
||||
b.WriteString("\n")
|
||||
|
||||
if !shouldUseCompactLayout(m.width, BreakpointSmall) {
|
||||
b.WriteString(m.renderSubtitle())
|
||||
} else {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) getASCIIArt() string {
|
||||
if shouldUseCompactLayout(m.width, BreakpointTiny) {
|
||||
return "TUNNEL PLS"
|
||||
}
|
||||
|
||||
if shouldUseCompactLayout(m.width, BreakpointLarge) {
|
||||
return `
|
||||
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
|
||||
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
|
||||
} else {
|
||||
asciiArt = `
|
||||
}
|
||||
|
||||
return `
|
||||
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
|
||||
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
|
||||
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
|
||||
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
|
||||
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
|
||||
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString(asciiArtStyle.Render(asciiArt))
|
||||
b.WriteString("\n")
|
||||
func (m *model) renderSubtitle() string {
|
||||
subtitleStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorGray)).
|
||||
Italic(true)
|
||||
|
||||
if !shouldUseCompactLayout(m.width, 60) {
|
||||
b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • "))
|
||||
b.WriteString(urlStyle.Render("https://fossy.my.id"))
|
||||
b.WriteString("\n\n")
|
||||
} else {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
urlStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorPrimary)).
|
||||
Underline(true).
|
||||
Italic(true)
|
||||
|
||||
return subtitleStyle.Render("Secure tunnel service by Bagas • ") +
|
||||
urlStyle.Render("https://fossy.my.id") + "\n\n"
|
||||
}
|
||||
|
||||
func (m *model) renderUserInfo(isCompact bool) string {
|
||||
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
|
||||
var boxPadding int
|
||||
var boxMargin int
|
||||
if isCompact {
|
||||
boxPadding = 1
|
||||
boxMargin = 1
|
||||
} else {
|
||||
boxPadding = 2
|
||||
boxMargin = 2
|
||||
}
|
||||
boxPadding := getMarginValue(isCompact, 1, 2)
|
||||
boxMargin := getMarginValue(isCompact, 1, 2)
|
||||
|
||||
responsiveInfoBox := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||
BorderForeground(lipgloss.Color(ColorPrimary)).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin).
|
||||
Width(boxMaxWidth)
|
||||
|
||||
authenticatedUser := m.interaction.user
|
||||
infoContent := m.getUserInfoContent(isCompact)
|
||||
return responsiveInfoBox.Render(infoContent) + "\n"
|
||||
}
|
||||
|
||||
func (m *model) getUserInfoContent(isCompact bool) string {
|
||||
userInfoStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
Foreground(lipgloss.Color(ColorWhite)).
|
||||
Bold(true)
|
||||
|
||||
sectionHeaderStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#888888")).
|
||||
Foreground(lipgloss.Color(ColorGray)).
|
||||
Bold(true)
|
||||
|
||||
addressStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA"))
|
||||
Foreground(lipgloss.Color(ColorWhite))
|
||||
|
||||
var infoContent string
|
||||
if shouldUseCompactLayout(m.width, 70) {
|
||||
infoContent = fmt.Sprintf("👤 %s\n\n%s\n%s",
|
||||
urlBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorSecondary)).
|
||||
Bold(true).
|
||||
Italic(true)
|
||||
|
||||
authenticatedUser := m.interaction.user
|
||||
tunnelURL := urlBoxStyle.Render(m.getTunnelURL())
|
||||
|
||||
if isCompact {
|
||||
return fmt.Sprintf("👤 %s\n\n%s\n%s",
|
||||
userInfoStyle.Render(authenticatedUser),
|
||||
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
|
||||
addressStyle.Render(fmt.Sprintf(" %s", urlBoxStyle.Render(m.getTunnelURL()))))
|
||||
} else {
|
||||
infoContent = fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
|
||||
userInfoStyle.Render(authenticatedUser),
|
||||
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
|
||||
addressStyle.Render(urlBoxStyle.Render(m.getTunnelURL())))
|
||||
addressStyle.Render(fmt.Sprintf(" %s", tunnelURL)))
|
||||
}
|
||||
|
||||
b.WriteString(responsiveInfoBox.Render(infoContent))
|
||||
return fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
|
||||
userInfoStyle.Render(authenticatedUser),
|
||||
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
|
||||
addressStyle.Render(tunnelURL))
|
||||
}
|
||||
|
||||
func (m *model) renderQuickActions(isCompact bool) string {
|
||||
var b strings.Builder
|
||||
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color(ColorPrimary)).
|
||||
PaddingTop(1)
|
||||
|
||||
b.WriteString(titleStyle.Render(m.getQuickActionsTitle()))
|
||||
b.WriteString("\n")
|
||||
|
||||
var quickActionsTitle string
|
||||
if shouldUseCompactLayout(m.width, 50) {
|
||||
quickActionsTitle = "Actions"
|
||||
} else if isCompact {
|
||||
quickActionsTitle = "Quick Actions"
|
||||
} else {
|
||||
quickActionsTitle = "✨ Quick Actions"
|
||||
}
|
||||
b.WriteString(titleStyle.Render(quickActionsTitle))
|
||||
b.WriteString("\n")
|
||||
|
||||
var featureMargin int
|
||||
if isCompact {
|
||||
featureMargin = 1
|
||||
} else {
|
||||
featureMargin = 2
|
||||
}
|
||||
|
||||
compactFeatureStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
featureMargin := getMarginValue(isCompact, 1, 2)
|
||||
featureStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorWhite)).
|
||||
MarginLeft(featureMargin)
|
||||
|
||||
var commandsText string
|
||||
var quitText string
|
||||
if shouldUseCompactLayout(m.width, 60) {
|
||||
commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]"))
|
||||
quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]"))
|
||||
} else {
|
||||
commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]"))
|
||||
quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]"))
|
||||
}
|
||||
keyHintStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorPrimary)).
|
||||
Bold(true)
|
||||
|
||||
b.WriteString(compactFeatureStyle.Render(commandsText))
|
||||
commands := m.getActionCommands(keyHintStyle)
|
||||
b.WriteString(featureStyle.Render(commands.commandsText))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(compactFeatureStyle.Render(quitText))
|
||||
|
||||
if !shouldUseCompactLayout(m.width, 70) {
|
||||
b.WriteString("\n\n")
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#666666")).
|
||||
Italic(true)
|
||||
b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings"))
|
||||
}
|
||||
b.WriteString(featureStyle.Render(commands.quitText))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) getQuickActionsTitle() string {
|
||||
if shouldUseCompactLayout(m.width, BreakpointTiny) {
|
||||
return "Actions"
|
||||
}
|
||||
if shouldUseCompactLayout(m.width, BreakpointLarge) {
|
||||
return "Quick Actions"
|
||||
}
|
||||
return "✨ Quick Actions"
|
||||
}
|
||||
|
||||
type actionCommands struct {
|
||||
commandsText string
|
||||
quitText string
|
||||
}
|
||||
|
||||
func (m *model) getActionCommands(keyHintStyle lipgloss.Style) actionCommands {
|
||||
if shouldUseCompactLayout(m.width, BreakpointSmall) {
|
||||
return actionCommands{
|
||||
commandsText: fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]")),
|
||||
quitText: fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]")),
|
||||
}
|
||||
}
|
||||
|
||||
return actionCommands{
|
||||
commandsText: fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]")),
|
||||
quitText: fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]")),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) renderFooter(isCompact bool) string {
|
||||
if isCompact {
|
||||
return ""
|
||||
}
|
||||
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorDarkGray)).
|
||||
Italic(true)
|
||||
|
||||
return "\n\n" + footerStyle.Render("Press 'C' to customize your tunnel settings")
|
||||
}
|
||||
|
||||
func getMarginValue(isCompact bool, compactValue, normalValue int) int {
|
||||
if isCompact {
|
||||
return compactValue
|
||||
}
|
||||
return normalValue
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ package interaction
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
@@ -18,9 +20,9 @@ import (
|
||||
)
|
||||
|
||||
type Interaction interface {
|
||||
Mode() types.Mode
|
||||
Mode() types.InteractiveMode
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetMode(m types.Mode)
|
||||
SetMode(m types.InteractiveMode)
|
||||
SetWH(w, h int)
|
||||
Start()
|
||||
Redraw()
|
||||
@@ -39,6 +41,8 @@ type Forwarder interface {
|
||||
|
||||
type CloseFunc func() error
|
||||
type interaction struct {
|
||||
randomizer random.Random
|
||||
config config.Config
|
||||
channel ssh.Channel
|
||||
slug slug.Slug
|
||||
forwarder Forwarder
|
||||
@@ -48,14 +52,15 @@ type interaction struct {
|
||||
program *tea.Program
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mode types.Mode
|
||||
mode types.InteractiveMode
|
||||
programMu sync.Mutex
|
||||
}
|
||||
|
||||
func (i *interaction) SetMode(m types.Mode) {
|
||||
func (i *interaction) SetMode(m types.InteractiveMode) {
|
||||
i.mode = m
|
||||
}
|
||||
|
||||
func (i *interaction) Mode() types.Mode {
|
||||
func (i *interaction) Mode() types.InteractiveMode {
|
||||
return i.mode
|
||||
}
|
||||
|
||||
@@ -75,9 +80,11 @@ func (i *interaction) SetWH(w, h int) {
|
||||
}
|
||||
}
|
||||
|
||||
func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
|
||||
func New(randomizer random.Random, config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &interaction{
|
||||
randomizer: randomizer,
|
||||
config: config,
|
||||
channel: nil,
|
||||
slug: slug,
|
||||
forwarder: forwarder,
|
||||
@@ -98,6 +105,10 @@ func (i *interaction) Stop() {
|
||||
if i.cancel != nil {
|
||||
i.cancel()
|
||||
}
|
||||
|
||||
i.programMu.Lock()
|
||||
defer i.programMu.Unlock()
|
||||
|
||||
if i.program != nil {
|
||||
i.program.Kill()
|
||||
i.program = nil
|
||||
@@ -174,14 +185,13 @@ func (m *model) View() string {
|
||||
}
|
||||
|
||||
func (i *interaction) Start() {
|
||||
if i.mode == types.HEADLESS {
|
||||
if i.mode == types.InteractiveModeHEADLESS {
|
||||
return
|
||||
}
|
||||
lipgloss.SetColorProfile(termenv.TrueColor)
|
||||
|
||||
domain := config.Getenv("DOMAIN", "localhost")
|
||||
protocol := "http"
|
||||
if config.Getenv("TLS_ENABLED", "false") == "true" {
|
||||
if i.config.TLSEnabled() {
|
||||
protocol = "https"
|
||||
}
|
||||
|
||||
@@ -209,7 +219,8 @@ func (i *interaction) Start() {
|
||||
ti.Width = 50
|
||||
|
||||
m := &model{
|
||||
domain: domain,
|
||||
randomizer: i.randomizer,
|
||||
domain: i.config.Domain(),
|
||||
protocol: protocol,
|
||||
tunnelType: tunnelType,
|
||||
port: port,
|
||||
@@ -233,6 +244,7 @@ func (i *interaction) Start() {
|
||||
help: help.New(),
|
||||
}
|
||||
|
||||
i.programMu.Lock()
|
||||
i.program = tea.NewProgram(
|
||||
m,
|
||||
tea.WithInput(i.channel),
|
||||
@@ -243,16 +255,21 @@ func (i *interaction) Start() {
|
||||
tea.WithoutSignalHandler(),
|
||||
tea.WithFPS(30),
|
||||
)
|
||||
i.programMu.Unlock()
|
||||
|
||||
_, err := i.program.Run()
|
||||
if err != nil {
|
||||
log.Printf("Cannot close tea: %s \n", err)
|
||||
}
|
||||
i.program.Kill()
|
||||
i.program = nil
|
||||
|
||||
i.programMu.Lock()
|
||||
if i.program != nil {
|
||||
i.program.Kill()
|
||||
i.program = nil
|
||||
}
|
||||
i.programMu.Unlock()
|
||||
|
||||
if i.closeFunc != nil {
|
||||
if err := i.closeFunc(); err != nil {
|
||||
log.Printf("Cannot close session: %s \n", err)
|
||||
}
|
||||
_ = i.closeFunc()
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@ package interaction
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/charmbracelet/bubbles/help"
|
||||
@@ -22,6 +23,7 @@ func (i commandItem) Title() string { return i.name }
|
||||
func (i commandItem) Description() string { return i.desc }
|
||||
|
||||
type model struct {
|
||||
randomizer random.Random
|
||||
domain string
|
||||
protocol string
|
||||
tunnelType types.TunnelType
|
||||
@@ -40,8 +42,27 @@ type model struct {
|
||||
height int
|
||||
}
|
||||
|
||||
const (
|
||||
ColorPrimary = "#7D56F4"
|
||||
ColorSecondary = "#04B575"
|
||||
ColorGray = "#888888"
|
||||
ColorDarkGray = "#666666"
|
||||
ColorWhite = "#FAFAFA"
|
||||
ColorError = "#FF0000"
|
||||
ColorErrorBg = "#3D0000"
|
||||
ColorWarning = "#FFA500"
|
||||
ColorWarningBg = "#3D2000"
|
||||
)
|
||||
|
||||
const (
|
||||
BreakpointTiny = 50
|
||||
BreakpointSmall = 60
|
||||
BreakpointMedium = 70
|
||||
BreakpointLarge = 85
|
||||
)
|
||||
|
||||
func (m *model) getTunnelURL() string {
|
||||
if m.tunnelType == types.HTTP {
|
||||
if m.tunnelType == types.TunnelTypeHTTP {
|
||||
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
|
||||
}
|
||||
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
|
||||
|
||||
+160
-119
@@ -3,7 +3,6 @@ package interaction
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
@@ -15,14 +14,14 @@ import (
|
||||
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
|
||||
if m.tunnelType != types.HTTP {
|
||||
if m.tunnelType != types.TunnelTypeHTTP {
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
case "esc", "ctrl+c":
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
@@ -30,10 +29,10 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
inputValue := m.slugInput.Value()
|
||||
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
|
||||
Id: m.interaction.slug.String(),
|
||||
Type: types.HTTP,
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}, types.SessionKey{
|
||||
Id: inputValue,
|
||||
Type: types.HTTP,
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}); err != nil {
|
||||
m.slugError = err.Error()
|
||||
return m, nil
|
||||
@@ -41,19 +40,13 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case "ctrl+c":
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
default:
|
||||
if key.Matches(msg, m.keymap.random) {
|
||||
newSubdomain, err := random.GenerateRandomString(20)
|
||||
newSubdomain, err := m.randomizer.String(20)
|
||||
if err != nil {
|
||||
return m, cmd
|
||||
}
|
||||
m.slugInput.SetValue(newSubdomain)
|
||||
m.slugError = ""
|
||||
m.slugInput, cmd = m.slugInput.Update(msg)
|
||||
}
|
||||
m.slugError = ""
|
||||
m.slugInput, cmd = m.slugInput.Update(msg)
|
||||
@@ -62,163 +55,211 @@ func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
}
|
||||
|
||||
func (m *model) slugView() string {
|
||||
isCompact := shouldUseCompactLayout(m.width, 70)
|
||||
isVeryCompact := shouldUseCompactLayout(m.width, 50)
|
||||
isCompact := shouldUseCompactLayout(m.width, BreakpointMedium)
|
||||
isVeryCompact := shouldUseCompactLayout(m.width, BreakpointTiny)
|
||||
|
||||
var boxPadding int
|
||||
var boxMargin int
|
||||
if isVeryCompact {
|
||||
boxPadding = 1
|
||||
boxMargin = 1
|
||||
} else if isCompact {
|
||||
boxPadding = 1
|
||||
boxMargin = 1
|
||||
} else {
|
||||
boxPadding = 2
|
||||
boxMargin = 2
|
||||
var b strings.Builder
|
||||
b.WriteString(m.renderSlugTitle(isVeryCompact))
|
||||
|
||||
if m.tunnelType != types.TunnelTypeHTTP {
|
||||
b.WriteString(m.renderTCPWarning(isVeryCompact, isCompact))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
b.WriteString(m.renderSlugRules(isVeryCompact, isCompact))
|
||||
b.WriteString(m.renderSlugInstruction(isVeryCompact))
|
||||
b.WriteString(m.renderSlugInput(isVeryCompact, isCompact))
|
||||
b.WriteString(m.renderSlugPreview(isVeryCompact))
|
||||
b.WriteString(m.renderSlugHelp(isVeryCompact))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) renderSlugTitle(isVeryCompact bool) string {
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
Foreground(lipgloss.Color(ColorPrimary)).
|
||||
PaddingTop(1).
|
||||
PaddingBottom(1)
|
||||
|
||||
instructionStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
MarginTop(1)
|
||||
title := "🔧 Edit Subdomain"
|
||||
if isVeryCompact {
|
||||
title = "Edit Subdomain"
|
||||
}
|
||||
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
return titleStyle.Render(title) + "\n\n"
|
||||
}
|
||||
|
||||
func (m *model) renderTCPWarning(isVeryCompact, isCompact bool) string {
|
||||
boxPadding := getPaddingValue(isVeryCompact, isCompact)
|
||||
boxMargin := getMarginValue(isCompact, 1, 2)
|
||||
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||
|
||||
warningBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorWarning)).
|
||||
Background(lipgloss.Color(ColorWarningBg)).
|
||||
Bold(true).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||
BorderForeground(lipgloss.Color(ColorWarning)).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin)
|
||||
MarginBottom(boxMargin).
|
||||
Width(warningBoxWidth)
|
||||
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#666666")).
|
||||
Foreground(lipgloss.Color(ColorDarkGray)).
|
||||
Italic(true).
|
||||
MarginTop(1)
|
||||
|
||||
errorBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FF0000")).
|
||||
Background(lipgloss.Color("#3D0000")).
|
||||
Bold(true).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#FF0000")).
|
||||
Padding(0, boxPadding).
|
||||
MarginTop(1).
|
||||
MarginBottom(1)
|
||||
warningText := m.getTCPWarningText(isVeryCompact)
|
||||
helpText := m.getTCPHelpText(isVeryCompact)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(warningBoxStyle.Render(warningText))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(helpStyle.Render(helpText))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) getTCPWarningText(isVeryCompact bool) string {
|
||||
if isVeryCompact {
|
||||
return "⚠️ TCP tunnels don't support custom subdomains."
|
||||
}
|
||||
return "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
|
||||
}
|
||||
|
||||
func (m *model) getTCPHelpText(isVeryCompact bool) string {
|
||||
if isVeryCompact {
|
||||
return "Press any key to go back"
|
||||
}
|
||||
return "Press Enter or Esc to go back"
|
||||
}
|
||||
|
||||
func (m *model) renderSlugRules(isVeryCompact, isCompact bool) string {
|
||||
boxPadding := getPaddingValue(isVeryCompact, isCompact)
|
||||
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||
|
||||
rulesBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
Foreground(lipgloss.Color(ColorWhite)).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||
BorderForeground(lipgloss.Color(ColorPrimary)).
|
||||
Padding(0, boxPadding).
|
||||
MarginTop(1).
|
||||
MarginBottom(1).
|
||||
Width(rulesBoxWidth)
|
||||
|
||||
var b strings.Builder
|
||||
var title string
|
||||
rulesContent := m.getRulesContent(isVeryCompact, isCompact)
|
||||
return rulesBoxStyle.Render(rulesContent) + "\n"
|
||||
}
|
||||
|
||||
func (m *model) getRulesContent(isVeryCompact, isCompact bool) string {
|
||||
if isVeryCompact {
|
||||
title = "Edit Subdomain"
|
||||
} else {
|
||||
title = "🔧 Edit Subdomain"
|
||||
}
|
||||
b.WriteString(titleStyle.Render(title))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if m.tunnelType != types.HTTP {
|
||||
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||
warningBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FFA500")).
|
||||
Background(lipgloss.Color("#3D2000")).
|
||||
Bold(true).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#FFA500")).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin).
|
||||
Width(warningBoxWidth)
|
||||
|
||||
var warningText string
|
||||
if isVeryCompact {
|
||||
warningText = "⚠️ TCP tunnels don't support custom subdomains."
|
||||
} else {
|
||||
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
|
||||
}
|
||||
b.WriteString(warningBoxStyle.Render(warningText))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
var helpText string
|
||||
if isVeryCompact {
|
||||
helpText = "Press any key to go back"
|
||||
} else {
|
||||
helpText = "Press Enter or Esc to go back"
|
||||
}
|
||||
b.WriteString(helpStyle.Render(helpText))
|
||||
return b.String()
|
||||
return "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
|
||||
}
|
||||
|
||||
var rulesContent string
|
||||
if isVeryCompact {
|
||||
rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
|
||||
} else if isCompact {
|
||||
rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
|
||||
} else {
|
||||
rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
|
||||
if isCompact {
|
||||
return "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
|
||||
}
|
||||
b.WriteString(rulesBoxStyle.Render(rulesContent))
|
||||
b.WriteString("\n")
|
||||
|
||||
var instruction string
|
||||
return "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
|
||||
}
|
||||
|
||||
func (m *model) renderSlugInstruction(isVeryCompact bool) string {
|
||||
instructionStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorWhite)).
|
||||
MarginTop(1)
|
||||
|
||||
instruction := "Enter your custom subdomain:"
|
||||
if isVeryCompact {
|
||||
instruction = "Custom subdomain:"
|
||||
} else {
|
||||
instruction = "Enter your custom subdomain:"
|
||||
}
|
||||
b.WriteString(instructionStyle.Render(instruction))
|
||||
b.WriteString("\n")
|
||||
|
||||
return instructionStyle.Render(instruction) + "\n"
|
||||
}
|
||||
|
||||
func (m *model) renderSlugInput(isVeryCompact, isCompact bool) string {
|
||||
boxPadding := getPaddingValue(isVeryCompact, isCompact)
|
||||
boxMargin := getMarginValue(isCompact, 1, 2)
|
||||
|
||||
if m.slugError != "" {
|
||||
errorInputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#FF0000")).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(1)
|
||||
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
|
||||
b.WriteString("\n")
|
||||
} else {
|
||||
b.WriteString(inputBoxStyle.Render(m.slugInput.View()))
|
||||
b.WriteString("\n")
|
||||
return m.renderErrorInput(boxPadding, boxMargin)
|
||||
}
|
||||
|
||||
return m.renderNormalInput(boxPadding, boxMargin)
|
||||
}
|
||||
|
||||
func (m *model) renderErrorInput(boxPadding, boxMargin int) string {
|
||||
errorInputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color(ColorError)).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(1)
|
||||
|
||||
errorBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorError)).
|
||||
Background(lipgloss.Color(ColorErrorBg)).
|
||||
Bold(true).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color(ColorError)).
|
||||
Padding(0, boxPadding).
|
||||
MarginTop(1).
|
||||
MarginBottom(1)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) renderNormalInput(boxPadding, boxMargin int) string {
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color(ColorPrimary)).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin)
|
||||
|
||||
return inputBoxStyle.Render(m.slugInput.View()) + "\n"
|
||||
}
|
||||
|
||||
func (m *model) renderSlugPreview(isVeryCompact bool) string {
|
||||
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
|
||||
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
|
||||
|
||||
if len(previewURL) > previewWidth-10 {
|
||||
if isVeryCompact {
|
||||
previewURL = truncateString(previewURL, previewWidth-10)
|
||||
}
|
||||
|
||||
previewStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#04B575")).
|
||||
Foreground(lipgloss.Color(ColorSecondary)).
|
||||
Italic(true).
|
||||
Width(previewWidth)
|
||||
b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)))
|
||||
b.WriteString("\n")
|
||||
|
||||
var helpText string
|
||||
return previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)) + "\n"
|
||||
}
|
||||
|
||||
func (m *model) renderSlugHelp(isVeryCompact bool) string {
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorDarkGray)).
|
||||
Italic(true).
|
||||
MarginTop(1)
|
||||
|
||||
helpText := "Press Enter to save • CTRL+R for random • Esc to cancel"
|
||||
if isVeryCompact {
|
||||
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
|
||||
} else {
|
||||
helpText = "Press Enter to save • CTRL+R for random • Esc to cancel"
|
||||
}
|
||||
b.WriteString(helpStyle.Render(helpText))
|
||||
|
||||
return b.String()
|
||||
return helpStyle.Render(helpText)
|
||||
}
|
||||
|
||||
func getPaddingValue(isVeryCompact, isCompact bool) int {
|
||||
if isVeryCompact || isCompact {
|
||||
return 1
|
||||
}
|
||||
return 2
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
@@ -24,7 +25,9 @@ type SessionRegistry interface {
|
||||
}
|
||||
|
||||
type lifecycle struct {
|
||||
status types.Status
|
||||
mu sync.Mutex
|
||||
status types.SessionStatus
|
||||
closeErr error
|
||||
conn ssh.Conn
|
||||
channel ssh.Channel
|
||||
forwarder Forwarder
|
||||
@@ -37,7 +40,7 @@ type lifecycle struct {
|
||||
|
||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
|
||||
return &lifecycle{
|
||||
status: types.INITIALIZING,
|
||||
status: types.SessionStatusINITIALIZING,
|
||||
conn: conn,
|
||||
channel: nil,
|
||||
forwarder: forwarder,
|
||||
@@ -51,10 +54,11 @@ func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUti
|
||||
|
||||
type Lifecycle interface {
|
||||
Connection() ssh.Conn
|
||||
Channel() ssh.Channel
|
||||
PortRegistry() portUtil.Port
|
||||
User() string
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetStatus(status types.Status)
|
||||
SetStatus(status types.SessionStatus)
|
||||
IsActive() bool
|
||||
StartedAt() time.Time
|
||||
Close() error
|
||||
@@ -71,37 +75,47 @@ func (l *lifecycle) User() string {
|
||||
func (l *lifecycle) SetChannel(channel ssh.Channel) {
|
||||
l.channel = channel
|
||||
}
|
||||
|
||||
func (l *lifecycle) Channel() ssh.Channel {
|
||||
return l.channel
|
||||
}
|
||||
|
||||
func (l *lifecycle) Connection() ssh.Conn {
|
||||
return l.conn
|
||||
}
|
||||
func (l *lifecycle) SetStatus(status types.Status) {
|
||||
|
||||
func (l *lifecycle) SetStatus(status types.SessionStatus) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.status = status
|
||||
if status == types.RUNNING && l.startedAt.IsZero() {
|
||||
l.startedAt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *lifecycle) IsActive() bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return l.status == types.SessionStatusRUNNING
|
||||
}
|
||||
|
||||
func (l *lifecycle) Close() error {
|
||||
var firstErr error
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if l.status == types.SessionStatusCLOSED {
|
||||
return l.closeErr
|
||||
}
|
||||
l.status = types.SessionStatusCLOSED
|
||||
|
||||
var errs []error
|
||||
tunnelType := l.forwarder.TunnelType()
|
||||
|
||||
if err := l.forwarder.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
firstErr = err
|
||||
}
|
||||
|
||||
if l.channel != nil {
|
||||
if err := l.channel.Close(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if err := l.channel.Close(); err != nil && !isClosedError(err) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if l.conn != nil {
|
||||
if err := l.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if err := l.conn.Close(); err != nil && !isClosedError(err) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,17 +126,20 @@ func (l *lifecycle) Close() error {
|
||||
}
|
||||
l.sessionRegistry.Remove(key)
|
||||
|
||||
if tunnelType == types.TCP {
|
||||
if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
if tunnelType == types.TunnelTypeTCP {
|
||||
errs = append(errs, l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false))
|
||||
errs = append(errs, l.forwarder.Close())
|
||||
}
|
||||
|
||||
return firstErr
|
||||
l.closeErr = errors.Join(errs...)
|
||||
return l.closeErr
|
||||
}
|
||||
|
||||
func (l *lifecycle) IsActive() bool {
|
||||
return l.status == types.RUNNING
|
||||
func isClosedError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || err.Error() == "EOF"
|
||||
}
|
||||
|
||||
func (l *lifecycle) StartedAt() time.Time {
|
||||
|
||||
@@ -0,0 +1,303 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Remove(key types.SessionKey) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
type MockForwarder struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||
args := m.Called(origin)
|
||||
return args.Get(0).([]byte)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||
m.Called(dst, src)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) TunnelType() types.TunnelType {
|
||||
args := m.Called()
|
||||
return args.Get(0).(types.TunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) ForwardedPort() uint16 {
|
||||
args := m.Called()
|
||||
return args.Get(0).(uint16)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
|
||||
m.Called(tunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetForwardedPort(port uint16) {
|
||||
m.Called(port)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetListener(listener net.Listener) {
|
||||
m.Called(listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Listener() net.Listener {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
args := m.Called(ctx, origin)
|
||||
if args.Get(0) == nil {
|
||||
return nil, nil, args.Error(2)
|
||||
}
|
||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
|
||||
type MockPort struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
||||
return m.Called(startPort, endPort).Error(0)
|
||||
}
|
||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
||||
args := m.Called()
|
||||
var port uint16
|
||||
if args.Get(0) != nil {
|
||||
switch v := args.Get(0).(type) {
|
||||
case int:
|
||||
port = uint16(v)
|
||||
case uint16:
|
||||
port = v
|
||||
case uint32:
|
||||
port = uint16(v)
|
||||
case int32:
|
||||
port = uint16(v)
|
||||
case float64:
|
||||
port = uint16(v)
|
||||
default:
|
||||
port = uint16(args.Int(0))
|
||||
}
|
||||
}
|
||||
return port, args.Bool(1)
|
||||
}
|
||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
||||
return m.Called(port, assigned).Error(0)
|
||||
}
|
||||
func (m *MockPort) Claim(port uint16) bool {
|
||||
return m.Called(port).Bool(0)
|
||||
}
|
||||
|
||||
type MockSlug struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (ms *MockSlug) Set(slug string) {
|
||||
ms.Called(slug)
|
||||
}
|
||||
func (ms *MockSlug) String() string {
|
||||
return ms.Called().String(0)
|
||||
}
|
||||
|
||||
type MockSSHConn struct {
|
||||
ssh.Conn
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSSHConn) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockSSHChannel struct {
|
||||
ssh.Channel
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSSHChannel) Close() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockSlug := &MockSlug{}
|
||||
mockPort := &MockPort{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||
|
||||
assert.NotNil(t, mockLifecycle.Connection())
|
||||
assert.NotNil(t, mockLifecycle.User())
|
||||
assert.NotNil(t, mockLifecycle.PortRegistry())
|
||||
assert.NotNil(t, mockLifecycle.StartedAt())
|
||||
}
|
||||
|
||||
func TestLifecycle_User(t *testing.T) {
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockSlug := &MockSlug{}
|
||||
mockPort := &MockPort{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
|
||||
user := "mas-fuad"
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, user)
|
||||
assert.Equal(t, user, mockLifecycle.User())
|
||||
}
|
||||
|
||||
func TestLifecycle_SetChannel(t *testing.T) {
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockSlug := &MockSlug{}
|
||||
mockPort := &MockPort{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||
|
||||
mockSSHChannel := &MockSSHChannel{}
|
||||
|
||||
mockLifecycle.SetChannel(mockSSHChannel)
|
||||
|
||||
assert.Equal(t, mockSSHChannel, mockLifecycle.Channel())
|
||||
}
|
||||
|
||||
func TestLifecycle_SetStatus(t *testing.T) {
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockSlug := &MockSlug{}
|
||||
mockPort := &MockPort{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||
|
||||
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||
assert.True(t, mockLifecycle.IsActive())
|
||||
}
|
||||
|
||||
func TestLifecycle_IsActive(t *testing.T) {
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockSlug := &MockSlug{}
|
||||
mockPort := &MockPort{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||
|
||||
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||
assert.True(t, mockLifecycle.IsActive())
|
||||
}
|
||||
|
||||
func TestLifecycle_Close(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tunnelType types.TunnelType
|
||||
connCloseErr error
|
||||
channelCloseErr error
|
||||
expectErr bool
|
||||
alreadyClosed bool
|
||||
}{
|
||||
{
|
||||
name: "Close HTTP forwarding success",
|
||||
tunnelType: types.TunnelTypeHTTP,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Close TCP forwarding success",
|
||||
tunnelType: types.TunnelTypeTCP,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Close with conn close error",
|
||||
tunnelType: types.TunnelTypeHTTP,
|
||||
connCloseErr: errors.New("conn close error"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Close with channel close error",
|
||||
tunnelType: types.TunnelTypeHTTP,
|
||||
channelCloseErr: errors.New("channel close error"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Close when already closed",
|
||||
tunnelType: types.TunnelTypeHTTP,
|
||||
alreadyClosed: true,
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSSHConn := &MockSSHConn{}
|
||||
mockSSHConn.On("Close").Return(tt.connCloseErr)
|
||||
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockForwarder.On("TunnelType").Return(tt.tunnelType)
|
||||
if tt.tunnelType == types.TunnelTypeTCP {
|
||||
mockForwarder.On("ForwardedPort").Return(uint16(8080))
|
||||
mockForwarder.On("Close").Return(nil)
|
||||
}
|
||||
|
||||
mockSlug := &MockSlug{}
|
||||
mockSlug.On("String").Return("test-slug")
|
||||
|
||||
mockPort := &MockPort{}
|
||||
if tt.tunnelType == types.TunnelTypeTCP {
|
||||
mockPort.On("SetStatus", uint16(8080), false).Return(nil)
|
||||
}
|
||||
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return()
|
||||
|
||||
mockSSHChannel := &MockSSHChannel{}
|
||||
mockSSHChannel.On("Close").Return(tt.channelCloseErr)
|
||||
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||
|
||||
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||
mockLifecycle.SetChannel(mockSSHChannel)
|
||||
|
||||
if tt.alreadyClosed {
|
||||
err := mockLifecycle.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
err := mockLifecycle.Close()
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.False(t, mockLifecycle.IsActive())
|
||||
|
||||
mockSSHConn.AssertExpectations(t)
|
||||
mockForwarder.AssertExpectations(t)
|
||||
mockSlug.AssertExpectations(t)
|
||||
mockPort.AssertExpectations(t)
|
||||
mockSessionRegistry.AssertExpectations(t)
|
||||
mockSSHChannel.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
+67
-84
@@ -1,7 +1,6 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -37,6 +36,8 @@ type Session interface {
|
||||
}
|
||||
|
||||
type session struct {
|
||||
randomizer random.Random
|
||||
config config.Config
|
||||
initialReq <-chan *ssh.Request
|
||||
sshChan <-chan ssh.NewChannel
|
||||
lifecycle lifecycle.Lifecycle
|
||||
@@ -46,22 +47,35 @@ type session struct {
|
||||
registry registry.Registry
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Randomizer random.Random
|
||||
Config config.Config
|
||||
Conn *ssh.ServerConn
|
||||
InitialReq <-chan *ssh.Request
|
||||
SshChan <-chan ssh.NewChannel
|
||||
SessionRegistry registry.Registry
|
||||
PortRegistry portUtil.Port
|
||||
User string
|
||||
}
|
||||
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
|
||||
func New(conf *Config) Session {
|
||||
slugManager := slug.New()
|
||||
forwarderManager := forwarder.New(slugManager, conn)
|
||||
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
|
||||
interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
|
||||
forwarderManager := forwarder.New(conf.Config, slugManager, conf.Conn)
|
||||
lifecycleManager := lifecycle.New(conf.Conn, forwarderManager, slugManager, conf.PortRegistry, conf.SessionRegistry, conf.User)
|
||||
interactionManager := interaction.New(conf.Randomizer, conf.Config, slugManager, forwarderManager, conf.SessionRegistry, conf.User, lifecycleManager.Close)
|
||||
|
||||
return &session{
|
||||
initialReq: initialReq,
|
||||
sshChan: sshChan,
|
||||
randomizer: conf.Randomizer,
|
||||
config: conf.Config,
|
||||
initialReq: conf.InitialReq,
|
||||
sshChan: conf.SshChan,
|
||||
lifecycle: lifecycleManager,
|
||||
interaction: interactionManager,
|
||||
forwarder: forwarderManager,
|
||||
slug: slugManager,
|
||||
registry: sessionRegistry,
|
||||
registry: conf.SessionRegistry,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,8 +97,8 @@ func (s *session) Slug() slug.Slug {
|
||||
|
||||
func (s *session) Detail() *types.Detail {
|
||||
tunnelTypeMap := map[types.TunnelType]string{
|
||||
types.HTTP: "HTTP",
|
||||
types.TCP: "TCP",
|
||||
types.TunnelTypeHTTP: "HTTP",
|
||||
types.TunnelTypeTCP: "TCP",
|
||||
}
|
||||
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
|
||||
if !ok {
|
||||
@@ -111,7 +125,7 @@ func (s *session) Start() error {
|
||||
}
|
||||
|
||||
if s.shouldRejectUnauthorized() {
|
||||
return s.denyForwardingRequest(tcpipReq, nil, nil, fmt.Sprintf("headless forwarding only allowed on node mode"))
|
||||
return s.denyForwardingRequest(tcpipReq, nil, nil, "headless forwarding only allowed on node mode")
|
||||
}
|
||||
|
||||
if err := s.HandleTCPIPForward(tcpipReq); err != nil {
|
||||
@@ -131,7 +145,7 @@ func (s *session) setupSessionMode() error {
|
||||
}
|
||||
return s.setupInteractiveMode(channel)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
s.interaction.SetMode(types.HEADLESS)
|
||||
s.interaction.SetMode(types.InteractiveModeHEADLESS)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -152,25 +166,23 @@ func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
|
||||
|
||||
s.lifecycle.SetChannel(ch)
|
||||
s.interaction.SetChannel(ch)
|
||||
s.interaction.SetMode(types.INTERACTIVE)
|
||||
s.interaction.SetMode(types.InteractiveModeINTERACTIVE)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleMissingForwardRequest() error {
|
||||
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
|
||||
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = s.lifecycle.Close(); err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("no forwarding Request")
|
||||
}
|
||||
|
||||
func (s *session) shouldRejectUnauthorized() bool {
|
||||
return s.interaction.Mode() == types.HEADLESS &&
|
||||
config.Getenv("MODE", "standalone") == "standalone" &&
|
||||
return s.interaction.Mode() == types.InteractiveModeHEADLESS &&
|
||||
s.config.Mode() == types.ServerModeSTANDALONE &&
|
||||
s.lifecycle.User() == "UNAUTHORIZED"
|
||||
}
|
||||
|
||||
@@ -180,7 +192,6 @@ func (s *session) waitForSessionEnd() error {
|
||||
}
|
||||
|
||||
if err := s.lifecycle.Close(); err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -225,8 +236,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
||||
for req := range GlobalRequest {
|
||||
switch req.Type {
|
||||
case "shell", "pty-req":
|
||||
err := req.Reply(true, nil)
|
||||
if err != nil {
|
||||
if err := req.Reply(true, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
case "window-change":
|
||||
@@ -235,8 +245,7 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
||||
}
|
||||
default:
|
||||
log.Println("Unknown request type:", req.Type)
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
if err := req.Reply(false, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -244,24 +253,24 @@ func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) {
|
||||
address, err = readSSHString(payloadReader)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) {
|
||||
var forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
var rawPortToBind uint32
|
||||
if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil {
|
||||
return "", 0, err
|
||||
if err = ssh.Unmarshal(payload, &forwardPayload); err != nil {
|
||||
return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err)
|
||||
}
|
||||
|
||||
if rawPortToBind > 65535 {
|
||||
if forwardPayload.BindPort > 65535 {
|
||||
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
|
||||
}
|
||||
|
||||
port = uint16(rawPortToBind)
|
||||
port = uint16(forwardPayload.BindPort)
|
||||
|
||||
if isBlockedPort(port) {
|
||||
return "", 0, fmt.Errorf("port is block")
|
||||
return "", 0, fmt.Errorf("port is blocked")
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
@@ -269,10 +278,10 @@ func (s *session) parseForwardPayload(payloadReader io.Reader) (address string,
|
||||
if !ok {
|
||||
return "", 0, fmt.Errorf("no available port")
|
||||
}
|
||||
return address, unassigned, err
|
||||
return forwardPayload.BindAddr, unassigned, nil
|
||||
}
|
||||
|
||||
return address, port, err
|
||||
return forwardPayload.BindAddr, port, nil
|
||||
}
|
||||
|
||||
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
|
||||
@@ -280,37 +289,25 @@ func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey,
|
||||
if key != nil {
|
||||
s.registry.Remove(*key)
|
||||
}
|
||||
|
||||
if listener != nil {
|
||||
if err := listener.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("close listener: %w", err))
|
||||
}
|
||||
}
|
||||
if err := req.Reply(false, nil); err != nil {
|
||||
errs = append(errs, fmt.Errorf("reply request: %w", err))
|
||||
}
|
||||
if err := s.lifecycle.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("close session: %w", err))
|
||||
errs = append(errs, listener.Close())
|
||||
}
|
||||
|
||||
errs = append(errs, req.Reply(false, nil))
|
||||
errs = append(errs, s.lifecycle.Close())
|
||||
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) {
|
||||
buf := new(bytes.Buffer)
|
||||
err = binary.Write(buf, binary.BigEndian, uint32(port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = req.Reply(true, buf.Bytes())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
|
||||
err := s.approveForwardingRequest(req, portToBind)
|
||||
replyPayload := struct {
|
||||
BoundPort uint32
|
||||
}{
|
||||
BoundPort: uint32(portToBind),
|
||||
}
|
||||
err := req.Reply(true, ssh.Marshal(replyPayload))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -318,7 +315,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
|
||||
s.forwarder.SetType(tunnelType)
|
||||
s.forwarder.SetForwardedPort(portToBind)
|
||||
s.slug.Set(slug)
|
||||
s.lifecycle.SetStatus(types.RUNNING)
|
||||
s.lifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||
|
||||
if listener != nil {
|
||||
s.forwarder.SetListener(listener)
|
||||
@@ -328,9 +325,7 @@ func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listen
|
||||
}
|
||||
|
||||
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
|
||||
reader := bytes.NewReader(req.Payload)
|
||||
|
||||
address, port, err := s.parseForwardPayload(reader)
|
||||
address, port, err := s.parseForwardPayload(req.Payload)
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
|
||||
}
|
||||
@@ -344,16 +339,16 @@ func (s *session) HandleTCPIPForward(req *ssh.Request) error {
|
||||
}
|
||||
|
||||
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
|
||||
randomString, err := random.GenerateRandomString(20)
|
||||
randomString, err := s.randomizer.String(20)
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
|
||||
}
|
||||
key := types.SessionKey{Id: randomString, Type: types.HTTP}
|
||||
key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP}
|
||||
if !s.registry.Register(key, s) {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
|
||||
}
|
||||
|
||||
err = s.finalizeForwarding(req, portToBind, nil, types.HTTP, key.Id)
|
||||
err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id)
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
||||
}
|
||||
@@ -362,21 +357,21 @@ func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
|
||||
|
||||
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
|
||||
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||
}
|
||||
|
||||
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
|
||||
listener, err := tcpServer.Listen()
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
|
||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||
}
|
||||
|
||||
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP}
|
||||
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
|
||||
if !s.registry.Register(key, s) {
|
||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TCP client with id: %s", key.Id))
|
||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id))
|
||||
}
|
||||
|
||||
err = s.finalizeForwarding(req, portToBind, listener, types.TCP, key.Id)
|
||||
err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id)
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
||||
}
|
||||
@@ -391,18 +386,6 @@ func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uin
|
||||
return nil
|
||||
}
|
||||
|
||||
func readSSHString(reader io.Reader) (string, error) {
|
||||
var length uint32
|
||||
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
|
||||
return "", err
|
||||
}
|
||||
strBytes := make([]byte, length)
|
||||
if _, err := reader.Read(strBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(strBytes), nil
|
||||
}
|
||||
|
||||
func isBlockedPort(port uint16) bool {
|
||||
if port == 80 || port == 443 {
|
||||
return false
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,99 @@
|
||||
package slug
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type SlugTestSuite struct {
|
||||
suite.Suite
|
||||
slug Slug
|
||||
}
|
||||
|
||||
func (suite *SlugTestSuite) SetupTest() {
|
||||
suite.slug = New()
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
s := New()
|
||||
|
||||
assert.NotNil(t, s, "New() should return a non-nil Slug")
|
||||
assert.Implements(t, (*Slug)(nil), s, "New() should return a type that implements Slug interface")
|
||||
assert.Equal(t, "", s.String(), "New() should initialize with empty string")
|
||||
}
|
||||
|
||||
func (suite *SlugTestSuite) TestString() {
|
||||
assert.Equal(suite.T(), "", suite.slug.String(), "String() should return empty string initially")
|
||||
|
||||
suite.slug.Set("test-slug")
|
||||
assert.Equal(suite.T(), "test-slug", suite.slug.String(), "String() should return the set value")
|
||||
}
|
||||
|
||||
func (suite *SlugTestSuite) TestSet() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple slug",
|
||||
input: "hello-world",
|
||||
expected: "hello-world",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "slug with numbers",
|
||||
input: "test-123",
|
||||
expected: "test-123",
|
||||
},
|
||||
{
|
||||
name: "slug with special characters",
|
||||
input: "hello_world-123",
|
||||
expected: "hello_world-123",
|
||||
},
|
||||
{
|
||||
name: "overwrite existing slug",
|
||||
input: "new-slug",
|
||||
expected: "new-slug",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
suite.Run(tc.name, func() {
|
||||
suite.slug.Set(tc.input)
|
||||
assert.Equal(suite.T(), tc.expected, suite.slug.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *SlugTestSuite) TestMultipleSet() {
|
||||
suite.slug.Set("first-slug")
|
||||
assert.Equal(suite.T(), "first-slug", suite.slug.String())
|
||||
|
||||
suite.slug.Set("second-slug")
|
||||
assert.Equal(suite.T(), "second-slug", suite.slug.String())
|
||||
|
||||
suite.slug.Set("")
|
||||
assert.Equal(suite.T(), "", suite.slug.String())
|
||||
}
|
||||
|
||||
func TestSlugIsolation(t *testing.T) {
|
||||
slug1 := New()
|
||||
slug2 := New()
|
||||
|
||||
slug1.Set("slug-one")
|
||||
slug2.Set("slug-two")
|
||||
|
||||
assert.Equal(t, "slug-one", slug1.String(), "First slug should maintain its value")
|
||||
assert.Equal(t, "slug-two", slug2.String(), "Second slug should maintain its value")
|
||||
}
|
||||
|
||||
func TestSlugTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(SlugTestSuite))
|
||||
}
|
||||
+17
-9
@@ -2,26 +2,34 @@ package types
|
||||
|
||||
import "time"
|
||||
|
||||
type Status int
|
||||
type SessionStatus int
|
||||
|
||||
const (
|
||||
INITIALIZING Status = iota
|
||||
RUNNING
|
||||
SessionStatusINITIALIZING SessionStatus = iota
|
||||
SessionStatusRUNNING
|
||||
SessionStatusCLOSED
|
||||
)
|
||||
|
||||
type Mode int
|
||||
type InteractiveMode int
|
||||
|
||||
const (
|
||||
INTERACTIVE Mode = iota
|
||||
HEADLESS
|
||||
InteractiveModeINTERACTIVE InteractiveMode = iota + 1
|
||||
InteractiveModeHEADLESS
|
||||
)
|
||||
|
||||
type TunnelType int
|
||||
|
||||
const (
|
||||
UNKNOWN TunnelType = iota
|
||||
HTTP
|
||||
TCP
|
||||
TunnelTypeUNKNOWN TunnelType = iota
|
||||
TunnelTypeHTTP
|
||||
TunnelTypeTCP
|
||||
)
|
||||
|
||||
type ServerMode int
|
||||
|
||||
const (
|
||||
ServerModeSTANDALONE = iota + 1
|
||||
ServerModeNODE
|
||||
)
|
||||
|
||||
type SessionKey struct {
|
||||
|
||||
Reference in New Issue
Block a user