Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5edb3c8086 | |||
| 5b603d8317 | |||
| 8fd9f8b567 | |||
| 30e84ac3b7 | |||
| fd6ffc2500 | |||
| e1cd4ed981 | |||
| 96d2b88f95 |
@@ -0,0 +1,21 @@
|
|||||||
|
name: renovate
|
||||||
|
|
||||||
|
on:
|
||||||
|
schedule:
|
||||||
|
- cron: "0 0 * * *"
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- staging
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
renovate:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container: git.fossy.my.id/renovate-clanker/renovate:latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- run: renovate
|
||||||
|
env:
|
||||||
|
RENOVATE_CONFIG_FILE: ${{ gitea.workspace }}/renovate-config.js
|
||||||
|
LOG_LEVEL: "debug"
|
||||||
|
RENOVATE_TOKEN: ${{ secrets.RENOVATE_TOKEN }}
|
||||||
|
GITHUB_COM_TOKEN: ${{ secrets.COM_TOKEN }}
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
on:
|
|
||||||
push:
|
|
||||||
pull_request:
|
|
||||||
types: [opened, synchronize, reopened]
|
|
||||||
|
|
||||||
name: SonarQube Scan
|
|
||||||
jobs:
|
|
||||||
sonarqube:
|
|
||||||
name: SonarQube Trigger
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checking out
|
|
||||||
uses: actions/checkout@v6
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
|
|
||||||
- name: Set up Go
|
|
||||||
uses: actions/setup-go@v6
|
|
||||||
with:
|
|
||||||
go-version: '1.25.6'
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: go mod tidy
|
|
||||||
|
|
||||||
- name: Run go vet
|
|
||||||
run: go vet ./... 2>&1 | tee vet-results.txt
|
|
||||||
|
|
||||||
- name: Run tests with coverage
|
|
||||||
run: |
|
|
||||||
go test ./... -v -coverprofile=coverage
|
|
||||||
|
|
||||||
- name: Run GolangCI-Lint Analysis
|
|
||||||
uses: golangci/golangci-lint-action@v9
|
|
||||||
with:
|
|
||||||
skip-cache: true
|
|
||||||
version: v2.6
|
|
||||||
args: >
|
|
||||||
--issues-exit-code=0
|
|
||||||
--output.text.path=stdout
|
|
||||||
--output.checkstyle.path=golangci-lint-report.xml
|
|
||||||
|
|
||||||
- name: SonarQube Scan
|
|
||||||
uses: SonarSource/sonarqube-scan-action@v7.0.0
|
|
||||||
env:
|
|
||||||
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
|
|
||||||
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
|
|
||||||
with:
|
|
||||||
args: >
|
|
||||||
-Dsonar.projectKey=tunnel-please
|
|
||||||
-Dsonar.go.coverage.reportPaths=coverage
|
|
||||||
-Dsonar.test.inclusions=**/*_test.go
|
|
||||||
-Dsonar.test.exclusions=**/vendor/**
|
|
||||||
-Dsonar.exclusions=**/*_test.go,**/vendor/**,**/golangci-lint-report.xml
|
|
||||||
-Dsonar.go.govet.reportPaths=vet-results.txt
|
|
||||||
-Dsonar.go.golangci-lint.reportPaths=golangci-lint-report.xml
|
|
||||||
-Dsonar.sources=./
|
|
||||||
-Dsonar.tests=./
|
|
||||||
Vendored
+1
-3
@@ -4,6 +4,4 @@ id_rsa*
|
|||||||
.env
|
.env
|
||||||
tmp
|
tmp
|
||||||
certs
|
certs
|
||||||
app
|
app
|
||||||
coverage
|
|
||||||
test-results.json
|
|
||||||
+2
-5
@@ -1,4 +1,4 @@
|
|||||||
FROM golang:1.25.6-alpine AS go_builder
|
FROM golang:1.25.5-alpine AS go_builder
|
||||||
|
|
||||||
ARG VERSION=dev
|
ARG VERSION=dev
|
||||||
ARG BUILD_DATE=unknown
|
ARG BUILD_DATE=unknown
|
||||||
@@ -22,10 +22,7 @@ RUN --mount=type=cache,target=/go/pkg/mod \
|
|||||||
--mount=type=cache,target=/root/.cache/go-build \
|
--mount=type=cache,target=/root/.cache/go-build \
|
||||||
CGO_ENABLED=0 GOOS=linux \
|
CGO_ENABLED=0 GOOS=linux \
|
||||||
go build -trimpath \
|
go build -trimpath \
|
||||||
-ldflags="-w -s \
|
-ldflags="-w -s -X tunnel_pls/version.Version=${VERSION} -X tunnel_pls/version.BuildDate=${BUILD_DATE} -X tunnel_pls/version.Commit=${COMMIT}" \
|
||||||
-X tunnel_pls/internal/version.Version=${VERSION} \
|
|
||||||
-X tunnel_pls/internal/version.BuildDate=${BUILD_DATE} \
|
|
||||||
-X tunnel_pls/internal/version.Commit=${COMMIT}" \
|
|
||||||
-o /app/tunnel_pls \
|
-o /app/tunnel_pls \
|
||||||
.
|
.
|
||||||
|
|
||||||
|
|||||||
@@ -1,22 +1,6 @@
|
|||||||
<div align="center">
|
|
||||||
|
|
||||||
<img alt="gopher" title="gopher" src="./docs/images/gopher.png" width="325" />
|
|
||||||
|
|
||||||
# Tunnel Please
|
# Tunnel Please
|
||||||
|
|
||||||
A lightweight SSH-based tunnel server
|
A lightweight SSH-based tunnel server written in Go that enables secure TCP and HTTP forwarding with an interactive terminal interface for managing connections and custom subdomains.
|
||||||
|
|
||||||
<br/><br/>
|
|
||||||
|
|
||||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
|
||||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
|
||||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
|
||||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
|
||||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
|
||||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
|
||||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
|
||||||
|
|
||||||
</div>
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
@@ -33,32 +17,108 @@ A lightweight SSH-based tunnel server
|
|||||||
|
|
||||||
The following environment variables can be configured in the `.env` file:
|
The following environment variables can be configured in the `.env` file:
|
||||||
|
|
||||||
| Variable | Description | Default | Required |
|
| Variable | Description | Default | Required |
|
||||||
|---------------------|-----------------------------------------------------------------------------|-------------------------|---------------------|
|
|----------|-------------|---------|----------|
|
||||||
| `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
|
| `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
|
||||||
| `PORT` | SSH server port | `2200` | No |
|
| `PORT` | SSH server port | `2200` | No |
|
||||||
| `HTTP_PORT` | HTTP server port | `8080` | No |
|
| `HTTP_PORT` | HTTP server port | `8080` | No |
|
||||||
| `HTTPS_PORT` | HTTPS server port | `8443` | No |
|
| `HTTPS_PORT` | HTTPS server port | `8443` | No |
|
||||||
| `KEY_LOC` | Path to the private key file | `certs/privkey.pem` | No |
|
| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
|
||||||
| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
|
| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
|
||||||
| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
|
| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@<DOMAIN>` | No |
|
||||||
| `TLS_STORAGE_PATH` | Path to store TLS certificates | `certs/tls/` | No |
|
| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | - | Yes (if auto-cert) |
|
||||||
| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@<DOMAIN>` | No |
|
| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
|
||||||
| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | `-` | Yes (if auto-cert) |
|
| `CORS_LIST` | Comma-separated list of allowed CORS origins | - | No |
|
||||||
| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
|
| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
|
||||||
| `CORS_LIST` | Comma-separated list of allowed CORS origins | `-` | No |
|
| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
|
||||||
| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
|
| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
|
||||||
| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
|
| `PPROF_PORT` | Port for pprof server | `6060` | No |
|
||||||
| `MAX_HEADER_SIZE` | Maximum size of HTTP headers in bytes (4096-131072) | `4096` | No |
|
| `MODE` | Runtime mode: `standalone` (default, no gRPC/auth) or `node` (enable gRPC + auth) | `standalone` | No |
|
||||||
| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
|
| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No |
|
||||||
| `PPROF_PORT` | Port for pprof server | `6060` | No |
|
| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No |
|
||||||
| `MODE` | Runtime mode: `standalone` or `node` | `standalone` | No |
|
| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | - (required in `node`) | Yes (node mode) |
|
||||||
| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No |
|
|
||||||
| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No |
|
|
||||||
| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | `-` | Yes (node mode) |
|
|
||||||
|
|
||||||
**Note:** All environment variables now use UPPERCASE naming. The application includes sensible defaults for all variables, so you can run it without a `.env` file for basic functionality.
|
**Note:** All environment variables now use UPPERCASE naming. The application includes sensible defaults for all variables, so you can run it without a `.env` file for basic functionality.
|
||||||
|
|
||||||
|
### Automatic TLS Certificate Management
|
||||||
|
|
||||||
|
The server supports automatic TLS certificate generation and renewal using [CertMagic](https://github.com/caddyserver/certmagic) with Cloudflare DNS-01 challenge. This is required for wildcard certificate support (`*.yourdomain.com`).
|
||||||
|
|
||||||
|
**Certificate Storage:**
|
||||||
|
- TLS certificates are stored in `certs/tls/` (relative to application directory)
|
||||||
|
- User-provided certificates: `certs/tls/cert.pem` and `certs/tls/privkey.pem`
|
||||||
|
- CertMagic automatic certificates: `certs/tls/certmagic/`
|
||||||
|
- SSH keys are stored separately in `certs/ssh/`
|
||||||
|
|
||||||
|
**How it works:**
|
||||||
|
1. If user-provided certificates exist at `certs/tls/cert.pem` and `certs/tls/privkey.pem` and cover both `DOMAIN` and `*.DOMAIN`, they will be used
|
||||||
|
2. If certificates are missing, expired, expiring within 30 days, or don't cover the required domains, CertMagic will automatically obtain new certificates from Let's Encrypt
|
||||||
|
3. Certificates are automatically renewed before expiration
|
||||||
|
4. User-provided certificates support hot-reload (changes detected every 30 seconds)
|
||||||
|
|
||||||
|
**Cloudflare API Token Setup:**
|
||||||
|
|
||||||
|
To use automatic certificate generation, you need a Cloudflare API token with the following permissions:
|
||||||
|
|
||||||
|
1. Go to [Cloudflare Dashboard](https://dash.cloudflare.com/profile/api-tokens)
|
||||||
|
2. Click "Create Token"
|
||||||
|
3. Use "Create Custom Token" with these permissions:
|
||||||
|
- **Zone → Zone → Read** (for all zones or specific zone)
|
||||||
|
- **Zone → DNS → Edit** (for all zones or specific zone)
|
||||||
|
4. Copy the token and set it as `CF_API_TOKEN` environment variable
|
||||||
|
|
||||||
|
**Example configuration for automatic certificates:**
|
||||||
|
```env
|
||||||
|
DOMAIN=example.com
|
||||||
|
TLS_ENABLED=true
|
||||||
|
CF_API_TOKEN=your_cloudflare_api_token_here
|
||||||
|
ACME_EMAIL=admin@example.com
|
||||||
|
# ACME_STAGING=true # Uncomment for testing to avoid rate limits
|
||||||
|
```
|
||||||
|
|
||||||
|
### SSH Key Auto-Generation
|
||||||
|
|
||||||
|
The application will automatically generate a new 4096-bit RSA key pair at `certs/ssh/id_rsa` if it doesn't exist. This makes it easier to get started without manually creating SSH keys. SSH keys are stored separately from TLS certificates.
|
||||||
|
|
||||||
|
### Memory Optimization
|
||||||
|
|
||||||
|
The application uses a buffer pool with controlled buffer sizes to prevent excessive memory usage under high concurrent loads. The `BUFFER_SIZE` environment variable controls the size of buffers used for io.Copy operations:
|
||||||
|
|
||||||
|
- **Default:** 32768 bytes (32 KB) - Good balance for most scenarios
|
||||||
|
- **Minimum:** 4096 bytes (4 KB) - Lower memory usage, more CPU overhead
|
||||||
|
- **Maximum:** 1048576 bytes (1 MB) - Higher throughput, more memory usage
|
||||||
|
|
||||||
|
**Recommended settings based on load:**
|
||||||
|
- **Low traffic (<100 concurrent):** `BUFFER_SIZE=32768` (default)
|
||||||
|
- **High traffic (>100 concurrent):** `BUFFER_SIZE=16384` or `BUFFER_SIZE=8192`
|
||||||
|
- **Very high traffic (>1000 concurrent):** `BUFFER_SIZE=8192` or `BUFFER_SIZE=4096`
|
||||||
|
|
||||||
|
The buffer pool reuses buffers across connections, preventing memory fragmentation and reducing garbage collection pressure.
|
||||||
|
|
||||||
|
### Profiling with pprof
|
||||||
|
|
||||||
|
To enable profiling for performance analysis:
|
||||||
|
|
||||||
|
1. Set `PPROF_ENABLED=true` in your `.env` file
|
||||||
|
2. Optionally set `PPROF_PORT` to your desired port (default: 6060)
|
||||||
|
3. Access profiling data at `http://localhost:6060/debug/pprof/`
|
||||||
|
|
||||||
|
Common pprof endpoints:
|
||||||
|
- `/debug/pprof/` - Index page with available profiles
|
||||||
|
- `/debug/pprof/heap` - Memory allocation profile
|
||||||
|
- `/debug/pprof/goroutine` - Stack traces of all current goroutines
|
||||||
|
- `/debug/pprof/profile` - CPU profile (30-second sample by default)
|
||||||
|
- `/debug/pprof/trace` - Execution trace
|
||||||
|
|
||||||
|
Example usage with `go tool pprof`:
|
||||||
|
```bash
|
||||||
|
# Analyze CPU profile
|
||||||
|
go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
|
||||||
|
|
||||||
|
# Analyze memory heap
|
||||||
|
go tool pprof http://localhost:6060/debug/pprof/heap
|
||||||
|
```
|
||||||
|
|
||||||
## Docker Deployment
|
## Docker Deployment
|
||||||
|
|
||||||
Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`.
|
Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`.
|
||||||
@@ -137,6 +197,22 @@ docker-compose -f docker-compose.tcp.yml up -d
|
|||||||
docker-compose -f docker-compose.root.yml down
|
docker-compose -f docker-compose.root.yml down
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Volume Management
|
||||||
|
|
||||||
|
All configurations use a named volume `certs` for persistent storage:
|
||||||
|
- SSH keys: `/app/certs/ssh/`
|
||||||
|
- TLS certificates: `/app/certs/tls/`
|
||||||
|
|
||||||
|
To backup certificates:
|
||||||
|
```bash
|
||||||
|
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar czf /backup/certs-backup.tar.gz -C /data .
|
||||||
|
```
|
||||||
|
|
||||||
|
To restore certificates:
|
||||||
|
```bash
|
||||||
|
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar xzf /backup/certs-backup.tar.gz -C /data
|
||||||
|
```
|
||||||
|
|
||||||
### Recommendation
|
### Recommendation
|
||||||
|
|
||||||
**Use `docker-compose.root.yml`** for production deployments if you need:
|
**Use `docker-compose.root.yml`** for production deployments if you need:
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 2.0 MiB |
@@ -3,16 +3,15 @@ module tunnel_pls
|
|||||||
go 1.25.5
|
go 1.25.5
|
||||||
|
|
||||||
require (
|
require (
|
||||||
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0
|
git.fossy.my.id/bagas/tunnel-please-grpc v1.2.0
|
||||||
github.com/caddyserver/certmagic v0.25.1
|
github.com/caddyserver/certmagic v0.25.0
|
||||||
github.com/charmbracelet/bubbles v0.21.0
|
github.com/charmbracelet/bubbles v0.21.0
|
||||||
github.com/charmbracelet/bubbletea v1.3.10
|
github.com/charmbracelet/bubbletea v1.3.10
|
||||||
github.com/charmbracelet/lipgloss v1.1.0
|
github.com/charmbracelet/lipgloss v1.1.0
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/libdns/cloudflare v0.2.2
|
github.com/libdns/cloudflare v0.2.2
|
||||||
github.com/muesli/termenv v0.16.0
|
github.com/muesli/termenv v0.16.0
|
||||||
github.com/stretchr/testify v1.11.1
|
golang.org/x/crypto v0.46.0
|
||||||
golang.org/x/crypto v0.47.0
|
|
||||||
google.golang.org/grpc v1.78.0
|
google.golang.org/grpc v1.78.0
|
||||||
google.golang.org/protobuf v1.36.11
|
google.golang.org/protobuf v1.36.11
|
||||||
)
|
)
|
||||||
@@ -20,7 +19,7 @@ require (
|
|||||||
require (
|
require (
|
||||||
github.com/atotto/clipboard v0.1.4 // indirect
|
github.com/atotto/clipboard v0.1.4 // indirect
|
||||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||||
github.com/caddyserver/zerossl v0.1.4 // indirect
|
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||||
github.com/charmbracelet/x/ansi v0.11.3 // indirect
|
github.com/charmbracelet/x/ansi v0.11.3 // indirect
|
||||||
github.com/charmbracelet/x/cellbuf v0.0.14 // indirect
|
github.com/charmbracelet/x/cellbuf v0.0.14 // indirect
|
||||||
@@ -28,7 +27,6 @@ require (
|
|||||||
github.com/clipperhouse/displaywidth v0.6.2 // indirect
|
github.com/clipperhouse/displaywidth v0.6.2 // indirect
|
||||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||||
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
|
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
|
||||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
github.com/libdns/libdns v1.1.1 // indirect
|
github.com/libdns/libdns v1.1.1 // indirect
|
||||||
@@ -40,10 +38,8 @@ require (
|
|||||||
github.com/miekg/dns v1.1.69 // indirect
|
github.com/miekg/dns v1.1.69 // indirect
|
||||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
|
||||||
github.com/rivo/uniseg v0.4.7 // indirect
|
github.com/rivo/uniseg v0.4.7 // indirect
|
||||||
github.com/sahilm/fuzzy v0.1.1 // indirect
|
github.com/sahilm/fuzzy v0.1.1 // indirect
|
||||||
github.com/stretchr/objx v0.5.2 // indirect
|
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||||
github.com/zeebo/blake3 v0.2.4 // indirect
|
github.com/zeebo/blake3 v0.2.4 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
@@ -52,9 +48,8 @@ require (
|
|||||||
golang.org/x/mod v0.31.0 // indirect
|
golang.org/x/mod v0.31.0 // indirect
|
||||||
golang.org/x/net v0.48.0 // indirect
|
golang.org/x/net v0.48.0 // indirect
|
||||||
golang.org/x/sync v0.19.0 // indirect
|
golang.org/x/sync v0.19.0 // indirect
|
||||||
golang.org/x/sys v0.40.0 // indirect
|
golang.org/x/sys v0.39.0 // indirect
|
||||||
golang.org/x/text v0.33.0 // indirect
|
golang.org/x/text v0.32.0 // indirect
|
||||||
golang.org/x/tools v0.40.0 // indirect
|
golang.org/x/tools v0.40.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 h1:3xszIhck4wo9CoeRq9vnkar4PhY7kz9QrR30qj2XszA=
|
git.fossy.my.id/bagas/tunnel-please-grpc v1.2.0 h1:BS1dJU3wa2ILgTGwkV95Knle0il0OQtErGqyb6xV7SU=
|
||||||
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0/go.mod h1:Weh6ZujgWmT8XxD3Qba7sJ6r5eyUMB9XSWynqdyOoLo=
|
git.fossy.my.id/bagas/tunnel-please-grpc v1.2.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY=
|
||||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||||
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
|
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
|
||||||
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
|
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
|
||||||
github.com/caddyserver/certmagic v0.25.1 h1:4sIKKbOt5pg6+sL7tEwymE1x2bj6CHr80da1CRRIPbY=
|
github.com/caddyserver/certmagic v0.25.0 h1:VMleO/XA48gEWes5l+Fh6tRWo9bHkhwAEhx63i+F5ic=
|
||||||
github.com/caddyserver/certmagic v0.25.1/go.mod h1:VhyvndxtVton/Fo/wKhRoC46Rbw1fmjvQ3GjHYSQTEY=
|
github.com/caddyserver/certmagic v0.25.0/go.mod h1:m9yB7Mud24OQbPHOiipAoyKPn9pKHhpSJxXR1jydBxA=
|
||||||
github.com/caddyserver/zerossl v0.1.4 h1:CVJOE3MZeFisCERZjkxIcsqIH4fnFdlYWnPYeFtBHRw=
|
github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA=
|
||||||
github.com/caddyserver/zerossl v0.1.4/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
|
github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
|
||||||
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
|
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
|
||||||
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
|
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
|
||||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||||
@@ -32,7 +32,6 @@ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfa
|
|||||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||||
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
|
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
|
||||||
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||||
@@ -81,18 +80,8 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
|||||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||||
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
|
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
|
||||||
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
|
||||||
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
|
||||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
|
||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
|
||||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
|
||||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||||
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
|
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
|
||||||
@@ -121,8 +110,8 @@ go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
|||||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||||
go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U=
|
go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U=
|
||||||
go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ=
|
go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ=
|
||||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||||
@@ -133,12 +122,12 @@ golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
|||||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
|
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||||
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
|
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||||
@@ -149,8 +138,5 @@ google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
|
|||||||
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
|
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
|
||||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
|
||||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
||||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
|||||||
@@ -1,196 +0,0 @@
|
|||||||
package bootstrap
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
"tunnel_pls/internal/config"
|
|
||||||
"tunnel_pls/internal/grpc/client"
|
|
||||||
"tunnel_pls/internal/key"
|
|
||||||
"tunnel_pls/internal/port"
|
|
||||||
"tunnel_pls/internal/random"
|
|
||||||
"tunnel_pls/internal/registry"
|
|
||||||
"tunnel_pls/internal/transport"
|
|
||||||
"tunnel_pls/internal/version"
|
|
||||||
"tunnel_pls/server"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Bootstrap struct {
|
|
||||||
Randomizer random.Random
|
|
||||||
Config config.Config
|
|
||||||
SessionRegistry registry.Registry
|
|
||||||
Port port.Port
|
|
||||||
GrpcClient client.Client
|
|
||||||
ErrChan chan error
|
|
||||||
SignalChan chan os.Signal
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(config config.Config, port port.Port) (*Bootstrap, error) {
|
|
||||||
randomizer := random.New()
|
|
||||||
sessionRegistry := registry.NewRegistry()
|
|
||||||
|
|
||||||
if err := port.AddRange(config.AllowedPortsStart(), config.AllowedPortsEnd()); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
grpcClient, err := client.New(config, sessionRegistry)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
errChan := make(chan error, 5)
|
|
||||||
signalChan := make(chan os.Signal, 1)
|
|
||||||
|
|
||||||
return &Bootstrap{
|
|
||||||
Randomizer: randomizer,
|
|
||||||
Config: config,
|
|
||||||
SessionRegistry: sessionRegistry,
|
|
||||||
Port: port,
|
|
||||||
GrpcClient: grpcClient,
|
|
||||||
ErrChan: errChan,
|
|
||||||
SignalChan: signalChan,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) {
|
|
||||||
sshCfg := &ssh.ServerConfig{
|
|
||||||
NoClientAuth: true,
|
|
||||||
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
|
|
||||||
return nil, fmt.Errorf("generate ssh key: %w", err)
|
|
||||||
}
|
|
||||||
privateBytes, err := os.ReadFile(sshKeyPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("read private key: %w", err)
|
|
||||||
}
|
|
||||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse private key: %w", err)
|
|
||||||
}
|
|
||||||
sshCfg.AddHostKey(private)
|
|
||||||
return sshCfg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, errChan chan<- error) error {
|
|
||||||
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
|
|
||||||
defer healthCancel()
|
|
||||||
if err := b.GrpcClient.CheckServerHealth(healthCtx); err != nil {
|
|
||||||
return fmt.Errorf("gRPC health check failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
if err := b.GrpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
|
|
||||||
httpserver := transport.NewHTTPServer(conf, registry)
|
|
||||||
ln, err := httpserver.Listen()
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to start http server: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err = httpserver.Serve(ln); err != nil {
|
|
||||||
errChan <- fmt.Errorf("error when serving http server: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func startHTTPSServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
|
|
||||||
tlsCfg, err := transport.NewTLSConfig(conf)
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
httpsServer := transport.NewHTTPSServer(conf, registry, tlsCfg)
|
|
||||||
ln, err := httpsServer.Listen()
|
|
||||||
if err != nil {
|
|
||||||
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err = httpsServer.Serve(ln); err != nil {
|
|
||||||
errChan <- fmt.Errorf("error when serving https server: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, errChan chan<- error) {
|
|
||||||
sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, conf.SSHPort())
|
|
||||||
if err != nil {
|
|
||||||
errChan <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sshServer.Start()
|
|
||||||
|
|
||||||
errChan <- sshServer.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func startPprof(pprofPort string, errChan chan<- error) {
|
|
||||||
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
|
|
||||||
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
|
|
||||||
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
|
||||||
errChan <- fmt.Errorf("pprof server error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func (b *Bootstrap) Run() error {
|
|
||||||
sshConfig, err := newSSHConfig(b.Config.KeyLoc())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create SSH config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
signal.Notify(b.SignalChan, os.Interrupt, syscall.SIGTERM)
|
|
||||||
|
|
||||||
if b.Config.Mode() == types.ServerModeNODE {
|
|
||||||
err = b.startGRPCClient(ctx, b.Config, b.ErrChan)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to start gRPC client: %w", err)
|
|
||||||
}
|
|
||||||
defer func(grpcClient client.Client) {
|
|
||||||
err = grpcClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed to close gRPC client")
|
|
||||||
}
|
|
||||||
}(b.GrpcClient)
|
|
||||||
}
|
|
||||||
|
|
||||||
go startHTTPServer(b.Config, b.SessionRegistry, b.ErrChan)
|
|
||||||
|
|
||||||
if b.Config.TLSEnabled() {
|
|
||||||
go startHTTPSServer(b.Config, b.SessionRegistry, b.ErrChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, b.GrpcClient, b.Port, b.ErrChan)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if b.Config.PprofEnabled() {
|
|
||||||
go startPprof(b.Config.PprofPort(), b.ErrChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println("All services started successfully")
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-b.ErrChan:
|
|
||||||
return fmt.Errorf("service error: %w", err)
|
|
||||||
case sig := <-b.SignalChan:
|
|
||||||
log.Printf("Received signal %s, initiating graceful shutdown", sig)
|
|
||||||
cancel()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,558 +0,0 @@
|
|||||||
package bootstrap
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
_ "net/http/pprof"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
"tunnel_pls/internal/config"
|
|
||||||
"tunnel_pls/internal/port"
|
|
||||||
"tunnel_pls/internal/registry"
|
|
||||||
"tunnel_pls/session/slug"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockSessionRegistry struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
|
||||||
args := m.Called(key)
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Error(1)
|
|
||||||
}
|
|
||||||
return args.Get(0).(registry.Session), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
|
||||||
args := m.Called(user, key)
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Error(1)
|
|
||||||
}
|
|
||||||
return args.Get(0).(registry.Session), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
|
||||||
args := m.Called(user, oldKey, newKey)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
|
||||||
args := m.Called(key, session)
|
|
||||||
return args.Bool(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
|
||||||
m.Called(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
|
||||||
args := m.Called(user)
|
|
||||||
return args.Get(0).([]registry.Session)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(slug.Slug)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockRandom struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockRandom) String(length int) (string, error) {
|
|
||||||
args := m.Called(length)
|
|
||||||
return args.String(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockConfig struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockConfig) Domain() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
|
||||||
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
|
|
||||||
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
|
||||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
|
||||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
|
||||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
|
||||||
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
|
|
||||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
|
||||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) Mode() types.ServerMode {
|
|
||||||
args := m.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
switch v := args.Get(0).(type) {
|
|
||||||
case types.ServerMode:
|
|
||||||
return v
|
|
||||||
case int:
|
|
||||||
return types.ServerMode(v)
|
|
||||||
default:
|
|
||||||
return types.ServerMode(args.Int(0))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
|
||||||
|
|
||||||
type MockPort struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
|
||||||
return m.Called(startPort, endPort).Error(0)
|
|
||||||
}
|
|
||||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
|
||||||
args := m.Called()
|
|
||||||
var mPort uint16
|
|
||||||
if args.Get(0) != nil {
|
|
||||||
switch v := args.Get(0).(type) {
|
|
||||||
case int:
|
|
||||||
mPort = uint16(v)
|
|
||||||
case uint16:
|
|
||||||
mPort = v
|
|
||||||
case uint32:
|
|
||||||
mPort = uint16(v)
|
|
||||||
case int32:
|
|
||||||
mPort = uint16(v)
|
|
||||||
case float64:
|
|
||||||
mPort = uint16(v)
|
|
||||||
default:
|
|
||||||
mPort = uint16(args.Int(0))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return mPort, args.Bool(1)
|
|
||||||
}
|
|
||||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
|
||||||
return m.Called(port, assigned).Error(0)
|
|
||||||
}
|
|
||||||
func (m *MockPort) Claim(port uint16) bool {
|
|
||||||
return m.Called(port).Bool(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockGRPCClient struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(*grpc.ClientConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
|
||||||
args := m.Called(ctx, token)
|
|
||||||
return args.Bool(0), args.String(1), args.Error(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
|
|
||||||
args := m.Called(ctx)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
|
|
||||||
args := m.Called(ctx, domain, token)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockGRPCClient) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setupConfig func() config.Config
|
|
||||||
setupPort func() port.Port
|
|
||||||
wantErr bool
|
|
||||||
errContains string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Success New with default value",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error when AddRange fails",
|
|
||||||
setupPort: func() port.Port {
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range"))
|
|
||||||
return mockPort
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
errContains: "invalid port range",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
var mockPort port.Port
|
|
||||||
if tt.setupPort != nil {
|
|
||||||
mockPort = tt.setupPort()
|
|
||||||
} else {
|
|
||||||
mockPort = port.New()
|
|
||||||
}
|
|
||||||
|
|
||||||
var mockConfig config.Config
|
|
||||||
if tt.setupConfig != nil {
|
|
||||||
mockConfig = tt.setupConfig()
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
mockConfig, err = config.MustLoad()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
bootstrap, err := New(mockConfig, mockPort)
|
|
||||||
|
|
||||||
if tt.wantErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
if tt.errContains != "" {
|
|
||||||
assert.Contains(t, err.Error(), tt.errContains)
|
|
||||||
}
|
|
||||||
assert.Nil(t, bootstrap)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, bootstrap)
|
|
||||||
assert.NotNil(t, bootstrap.Randomizer)
|
|
||||||
assert.NotNil(t, bootstrap.SessionRegistry)
|
|
||||||
assert.NotNil(t, bootstrap.Config)
|
|
||||||
assert.NotNil(t, bootstrap.Port)
|
|
||||||
assert.NotNil(t, bootstrap.ErrChan)
|
|
||||||
assert.NotNil(t, bootstrap.SignalChan)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func randomAvailablePort() (string, error) {
|
|
||||||
listener, err := net.Listen("tcp", "localhost:0")
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer func(listener net.Listener) {
|
|
||||||
_ = listener.Close()
|
|
||||||
}(listener)
|
|
||||||
|
|
||||||
mPort := listener.Addr().(*net.TCPAddr).Port
|
|
||||||
return strconv.Itoa(mPort), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun(t *testing.T) {
|
|
||||||
mockRandom := &MockRandom{}
|
|
||||||
mockErrChan := make(chan error, 1)
|
|
||||||
mockSignalChan := make(chan os.Signal, 1)
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
|
|
||||||
tmpDir := t.TempDir()
|
|
||||||
keyLoc := filepath.Join(tmpDir, "key.key")
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setupConfig func() *MockConfig
|
|
||||||
setupGrpcClient func() *MockGRPCClient
|
|
||||||
needCerts bool
|
|
||||||
expectError bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "successful run and termination",
|
|
||||||
setupConfig: func() *MockConfig {
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("SSHPort").Return("0")
|
|
||||||
mockConfig.On("HTTPPort").Return("0")
|
|
||||||
mockConfig.On("HTTPSPort").Return("0")
|
|
||||||
mockConfig.On("TLSEnabled").Return(false)
|
|
||||||
mockConfig.On("TLSRedirect").Return(false)
|
|
||||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
|
||||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
|
||||||
mockConfig.On("ACMEStaging").Return(true)
|
|
||||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
|
||||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
|
||||||
mockConfig.On("BufferSize").Return(4096)
|
|
||||||
mockConfig.On("PprofEnabled").Return(false)
|
|
||||||
mockConfig.On("PprofPort").Return("0")
|
|
||||||
mockConfig.On("GRPCAddress").Return("localhost")
|
|
||||||
mockConfig.On("GRPCPort").Return("0")
|
|
||||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
|
||||||
return mockConfig
|
|
||||||
},
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error from SSH server invalid port",
|
|
||||||
setupConfig: func() *MockConfig {
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("SSHPort").Return("invalid")
|
|
||||||
mockConfig.On("HTTPPort").Return("0")
|
|
||||||
mockConfig.On("HTTPSPort").Return("0")
|
|
||||||
mockConfig.On("TLSEnabled").Return(false)
|
|
||||||
mockConfig.On("TLSRedirect").Return(false)
|
|
||||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
|
||||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
|
||||||
mockConfig.On("ACMEStaging").Return(true)
|
|
||||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
|
||||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
|
||||||
mockConfig.On("BufferSize").Return(4096)
|
|
||||||
mockConfig.On("PprofEnabled").Return(false)
|
|
||||||
mockConfig.On("PprofPort").Return("0")
|
|
||||||
mockConfig.On("GRPCAddress").Return("localhost")
|
|
||||||
mockConfig.On("GRPCPort").Return("0")
|
|
||||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
|
||||||
return mockConfig
|
|
||||||
},
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error from HTTP server invalid port",
|
|
||||||
setupConfig: func() *MockConfig {
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("SSHPort").Return("0")
|
|
||||||
mockConfig.On("HTTPPort").Return("invalid")
|
|
||||||
mockConfig.On("HTTPSPort").Return("0")
|
|
||||||
mockConfig.On("TLSEnabled").Return(false)
|
|
||||||
mockConfig.On("TLSRedirect").Return(false)
|
|
||||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
|
||||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
|
||||||
mockConfig.On("ACMEStaging").Return(true)
|
|
||||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
|
||||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
|
||||||
mockConfig.On("BufferSize").Return(4096)
|
|
||||||
mockConfig.On("PprofEnabled").Return(false)
|
|
||||||
mockConfig.On("PprofPort").Return("0")
|
|
||||||
mockConfig.On("GRPCAddress").Return("localhost")
|
|
||||||
mockConfig.On("GRPCPort").Return("0")
|
|
||||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
|
||||||
return mockConfig
|
|
||||||
},
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error from HTTPS server invalid port",
|
|
||||||
setupConfig: func() *MockConfig {
|
|
||||||
tempDir := os.TempDir()
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("SSHPort").Return("0")
|
|
||||||
mockConfig.On("HTTPPort").Return("0")
|
|
||||||
mockConfig.On("HTTPSPort").Return("invalid")
|
|
||||||
mockConfig.On("TLSEnabled").Return(true)
|
|
||||||
mockConfig.On("TLSRedirect").Return(false)
|
|
||||||
mockConfig.On("TLSStoragePath").Return(tempDir)
|
|
||||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
|
||||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
|
||||||
mockConfig.On("ACMEStaging").Return(true)
|
|
||||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
|
||||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
|
||||||
mockConfig.On("BufferSize").Return(4096)
|
|
||||||
mockConfig.On("PprofEnabled").Return(false)
|
|
||||||
mockConfig.On("PprofPort").Return("0")
|
|
||||||
mockConfig.On("GRPCAddress").Return("localhost")
|
|
||||||
mockConfig.On("GRPCPort").Return("0")
|
|
||||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
|
||||||
return mockConfig
|
|
||||||
},
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "grpc health check failed",
|
|
||||||
setupConfig: func() *MockConfig {
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("SSHPort").Return("0")
|
|
||||||
mockConfig.On("HTTPPort").Return("0")
|
|
||||||
mockConfig.On("HTTPSPort").Return("0")
|
|
||||||
mockConfig.On("TLSEnabled").Return(false)
|
|
||||||
mockConfig.On("TLSRedirect").Return(false)
|
|
||||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
|
||||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
|
||||||
mockConfig.On("ACMEStaging").Return(true)
|
|
||||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
|
||||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
|
||||||
mockConfig.On("BufferSize").Return(4096)
|
|
||||||
mockConfig.On("PprofEnabled").Return(false)
|
|
||||||
mockConfig.On("PprofPort").Return("0")
|
|
||||||
mockConfig.On("GRPCAddress").Return("localhost")
|
|
||||||
mockConfig.On("GRPCPort").Return("invalid")
|
|
||||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
|
||||||
return mockConfig
|
|
||||||
},
|
|
||||||
setupGrpcClient: func() *MockGRPCClient {
|
|
||||||
mockGRPCClient := &MockGRPCClient{}
|
|
||||||
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed"))
|
|
||||||
return mockGRPCClient
|
|
||||||
},
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "successful run with pprof enabled",
|
|
||||||
setupConfig: func() *MockConfig {
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
pprofPort, _ := randomAvailablePort()
|
|
||||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("SSHPort").Return("0")
|
|
||||||
mockConfig.On("HTTPPort").Return("0")
|
|
||||||
mockConfig.On("HTTPSPort").Return("0")
|
|
||||||
mockConfig.On("TLSEnabled").Return(false)
|
|
||||||
mockConfig.On("TLSRedirect").Return(false)
|
|
||||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
|
||||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
|
||||||
mockConfig.On("ACMEStaging").Return(true)
|
|
||||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
|
||||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
|
||||||
mockConfig.On("BufferSize").Return(4096)
|
|
||||||
mockConfig.On("PprofEnabled").Return(true)
|
|
||||||
mockConfig.On("PprofPort").Return(pprofPort)
|
|
||||||
mockConfig.On("GRPCAddress").Return("localhost")
|
|
||||||
mockConfig.On("GRPCPort").Return("0")
|
|
||||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
|
||||||
return mockConfig
|
|
||||||
},
|
|
||||||
expectError: false,
|
|
||||||
}, {
|
|
||||||
name: "successful run in NODE mode with signal",
|
|
||||||
setupConfig: func() *MockConfig {
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("SSHPort").Return("0")
|
|
||||||
mockConfig.On("HTTPPort").Return("0")
|
|
||||||
mockConfig.On("HTTPSPort").Return("0")
|
|
||||||
mockConfig.On("TLSEnabled").Return(false)
|
|
||||||
mockConfig.On("TLSRedirect").Return(false)
|
|
||||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
|
||||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
|
||||||
mockConfig.On("ACMEStaging").Return(true)
|
|
||||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
|
||||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
|
||||||
mockConfig.On("BufferSize").Return(4096)
|
|
||||||
mockConfig.On("PprofEnabled").Return(false)
|
|
||||||
mockConfig.On("PprofPort").Return("0")
|
|
||||||
mockConfig.On("GRPCAddress").Return("localhost")
|
|
||||||
mockConfig.On("GRPCPort").Return("0")
|
|
||||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
|
||||||
return mockConfig
|
|
||||||
},
|
|
||||||
setupGrpcClient: func() *MockGRPCClient {
|
|
||||||
mockGRPCClient := &MockGRPCClient{}
|
|
||||||
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
|
|
||||||
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
|
||||||
mockGRPCClient.On("Close").Return(nil)
|
|
||||||
return mockGRPCClient
|
|
||||||
},
|
|
||||||
expectError: false,
|
|
||||||
}, {
|
|
||||||
name: "successful run in NODE mode with signal buf error when closing",
|
|
||||||
setupConfig: func() *MockConfig {
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("SSHPort").Return("0")
|
|
||||||
mockConfig.On("HTTPPort").Return("0")
|
|
||||||
mockConfig.On("HTTPSPort").Return("0")
|
|
||||||
mockConfig.On("TLSEnabled").Return(false)
|
|
||||||
mockConfig.On("TLSRedirect").Return(false)
|
|
||||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
|
||||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
|
||||||
mockConfig.On("ACMEStaging").Return(true)
|
|
||||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
|
||||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
|
||||||
mockConfig.On("BufferSize").Return(4096)
|
|
||||||
mockConfig.On("PprofEnabled").Return(false)
|
|
||||||
mockConfig.On("PprofPort").Return("0")
|
|
||||||
mockConfig.On("GRPCAddress").Return("localhost")
|
|
||||||
mockConfig.On("GRPCPort").Return("0")
|
|
||||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
|
||||||
return mockConfig
|
|
||||||
},
|
|
||||||
setupGrpcClient: func() *MockGRPCClient {
|
|
||||||
mockGRPCClient := &MockGRPCClient{}
|
|
||||||
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
|
|
||||||
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
|
||||||
mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy"))
|
|
||||||
return mockGRPCClient
|
|
||||||
},
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mockConfig := tt.setupConfig()
|
|
||||||
mockGRPCClient := &MockGRPCClient{}
|
|
||||||
bootstrap := &Bootstrap{
|
|
||||||
Randomizer: mockRandom,
|
|
||||||
Config: mockConfig,
|
|
||||||
SessionRegistry: mockSessionRegistry,
|
|
||||||
Port: mockPort,
|
|
||||||
ErrChan: mockErrChan,
|
|
||||||
SignalChan: mockSignalChan,
|
|
||||||
GrpcClient: mockGRPCClient,
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.setupGrpcClient != nil {
|
|
||||||
bootstrap.GrpcClient = tt.setupGrpcClient()
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
done <- bootstrap.Run()
|
|
||||||
}()
|
|
||||||
|
|
||||||
if tt.expectError {
|
|
||||||
err := <-done
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else if tt.name == "successful run with pprof enabled" {
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
fmt.Println(mockConfig.PprofPort())
|
|
||||||
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, 200, resp.StatusCode)
|
|
||||||
err = resp.Body.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
mockSignalChan <- os.Interrupt
|
|
||||||
err = <-done
|
|
||||||
assert.NoError(t, err)
|
|
||||||
} else {
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
mockSignalChan <- os.Interrupt
|
|
||||||
err := <-done
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+25
-60
@@ -1,70 +1,35 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import "tunnel_pls/types"
|
import (
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
type Config interface {
|
"github.com/joho/godotenv"
|
||||||
Domain() string
|
)
|
||||||
SSHPort() string
|
|
||||||
|
|
||||||
HTTPPort() string
|
func init() {
|
||||||
HTTPSPort() string
|
if _, err := os.Stat(".env"); err == nil {
|
||||||
|
if err := godotenv.Load(".env"); err != nil {
|
||||||
KeyLoc() string
|
log.Printf("Warning: Failed to load .env file: %s", err)
|
||||||
|
}
|
||||||
TLSEnabled() bool
|
}
|
||||||
TLSRedirect() bool
|
|
||||||
TLSStoragePath() string
|
|
||||||
|
|
||||||
ACMEEmail() string
|
|
||||||
CFAPIToken() string
|
|
||||||
ACMEStaging() bool
|
|
||||||
|
|
||||||
AllowedPortsStart() uint16
|
|
||||||
AllowedPortsEnd() uint16
|
|
||||||
|
|
||||||
BufferSize() int
|
|
||||||
HeaderSize() int
|
|
||||||
|
|
||||||
PprofEnabled() bool
|
|
||||||
PprofPort() string
|
|
||||||
|
|
||||||
Mode() types.ServerMode
|
|
||||||
GRPCAddress() string
|
|
||||||
GRPCPort() string
|
|
||||||
NodeToken() string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustLoad() (Config, error) {
|
func Getenv(key, defaultValue string) string {
|
||||||
if err := loadEnvFile(); err != nil {
|
val := os.Getenv(key)
|
||||||
return nil, err
|
if val == "" {
|
||||||
|
val = defaultValue
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg, err := parse()
|
return val
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return cfg, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *config) Domain() string { return c.domain }
|
func GetBufferSize() int {
|
||||||
func (c *config) SSHPort() string { return c.sshPort }
|
sizeStr := Getenv("BUFFER_SIZE", "32768")
|
||||||
func (c *config) HTTPPort() string { return c.httpPort }
|
size, err := strconv.Atoi(sizeStr)
|
||||||
func (c *config) HTTPSPort() string { return c.httpsPort }
|
if err != nil || size < 4096 || size > 1048576 {
|
||||||
func (c *config) KeyLoc() string { return c.keyLoc }
|
return 32768
|
||||||
func (c *config) TLSEnabled() bool { return c.tlsEnabled }
|
}
|
||||||
func (c *config) TLSRedirect() bool { return c.tlsRedirect }
|
return size
|
||||||
func (c *config) TLSStoragePath() string { return c.tlsStoragePath }
|
}
|
||||||
func (c *config) ACMEEmail() string { return c.acmeEmail }
|
|
||||||
func (c *config) CFAPIToken() string { return c.cfAPIToken }
|
|
||||||
func (c *config) ACMEStaging() bool { return c.acmeStaging }
|
|
||||||
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
|
|
||||||
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
|
|
||||||
func (c *config) BufferSize() int { return c.bufferSize }
|
|
||||||
func (c *config) HeaderSize() int { return c.headerSize }
|
|
||||||
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
|
|
||||||
func (c *config) PprofPort() string { return c.pprofPort }
|
|
||||||
func (c *config) Mode() types.ServerMode { return c.mode }
|
|
||||||
func (c *config) GRPCAddress() string { return c.grpcAddress }
|
|
||||||
func (c *config) GRPCPort() string { return c.grpcPort }
|
|
||||||
func (c *config) NodeToken() string { return c.nodeToken }
|
|
||||||
|
|||||||
@@ -1,405 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetenv(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
key string
|
|
||||||
val string
|
|
||||||
def string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "returns existing env",
|
|
||||||
key: "TEST_ENV_EXIST",
|
|
||||||
val: "value",
|
|
||||||
def: "default",
|
|
||||||
expected: "value",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "returns default when env missing",
|
|
||||||
key: "TEST_ENV_MISSING",
|
|
||||||
val: "",
|
|
||||||
def: "default",
|
|
||||||
expected: "default",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.val != "" {
|
|
||||||
t.Setenv(tt.key, tt.val)
|
|
||||||
} else {
|
|
||||||
err := os.Unsetenv(tt.key)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetenvBool(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
key string
|
|
||||||
val string
|
|
||||||
def bool
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "returns true when env is true",
|
|
||||||
key: "TEST_BOOL_TRUE",
|
|
||||||
val: "true",
|
|
||||||
def: false,
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "returns false when env is false",
|
|
||||||
key: "TEST_BOOL_FALSE",
|
|
||||||
val: "false",
|
|
||||||
def: true,
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "returns default when env missing",
|
|
||||||
key: "TEST_BOOL_MISSING",
|
|
||||||
val: "",
|
|
||||||
def: true,
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "returns false when env is not true",
|
|
||||||
key: "TEST_BOOL_INVALID",
|
|
||||||
val: "yes",
|
|
||||||
def: true,
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.val != "" {
|
|
||||||
t.Setenv(tt.key, tt.val)
|
|
||||||
} else {
|
|
||||||
err := os.Unsetenv(tt.key)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseMode(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
mode string
|
|
||||||
expect types.ServerMode
|
|
||||||
expectErr bool
|
|
||||||
}{
|
|
||||||
{"standalone", "standalone", types.ServerModeSTANDALONE, false},
|
|
||||||
{"node", "node", types.ServerModeNODE, false},
|
|
||||||
{"uppercase", "STANDALONE", types.ServerModeSTANDALONE, false},
|
|
||||||
{"invalid", "invalid", 0, true},
|
|
||||||
{"empty (default)", "", types.ServerModeSTANDALONE, false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.mode != "" {
|
|
||||||
t.Setenv("MODE", tt.mode)
|
|
||||||
} else {
|
|
||||||
err := os.Unsetenv("MODE")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
mode, err := parseMode()
|
|
||||||
if tt.expectErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.expect, mode)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseAllowedPorts(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
val string
|
|
||||||
start uint16
|
|
||||||
end uint16
|
|
||||||
expectErr bool
|
|
||||||
}{
|
|
||||||
{"valid range", "1000-2000", 1000, 2000, false},
|
|
||||||
{"empty", "", 0, 0, false},
|
|
||||||
{"invalid format - no dash", "1000", 0, 0, true},
|
|
||||||
{"invalid format - too many dashes", "1000-2000-3000", 0, 0, true},
|
|
||||||
{"invalid start port", "abc-2000", 0, 0, true},
|
|
||||||
{"invalid end port", "1000-abc", 0, 0, true},
|
|
||||||
{"out of range start", "70000-80000", 0, 0, true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.val != "" {
|
|
||||||
t.Setenv("ALLOWED_PORTS", tt.val)
|
|
||||||
} else {
|
|
||||||
err := os.Unsetenv("ALLOWED_PORTS")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
start, end, err := parseAllowedPorts()
|
|
||||||
if tt.expectErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.start, start)
|
|
||||||
assert.Equal(t, tt.end, end)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseBufferSize(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
val string
|
|
||||||
expect int
|
|
||||||
}{
|
|
||||||
{"valid size", "8192", 8192},
|
|
||||||
{"default size", "", 32768},
|
|
||||||
{"too small", "1024", 4096},
|
|
||||||
{"too large", "2000000", 4096},
|
|
||||||
{"invalid format", "abc", 4096},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.val != "" {
|
|
||||||
t.Setenv("BUFFER_SIZE", tt.val)
|
|
||||||
} else {
|
|
||||||
err := os.Unsetenv("BUFFER_SIZE")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
size := parseBufferSize()
|
|
||||||
assert.Equal(t, tt.expect, size)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParseHeaderSize(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
val string
|
|
||||||
expect int
|
|
||||||
}{
|
|
||||||
{"valid size", "8192", 8192},
|
|
||||||
{"default size", "", 4096},
|
|
||||||
{"too small", "1024", 4096},
|
|
||||||
{"too large", "2000000", 4096},
|
|
||||||
{"invalid format", "abc", 4096},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.val != "" {
|
|
||||||
t.Setenv("MAX_HEADER_SIZE", tt.val)
|
|
||||||
} else {
|
|
||||||
err := os.Unsetenv("MAX_HEADER_SIZE")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
size := parseHeaderSize()
|
|
||||||
assert.Equal(t, tt.expect, size)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParse(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
envs map[string]string
|
|
||||||
expectErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "minimal valid config",
|
|
||||||
envs: map[string]string{
|
|
||||||
"DOMAIN": "example.com",
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "TLS enabled without token",
|
|
||||||
envs: map[string]string{
|
|
||||||
"TLS_ENABLED": "true",
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "TLS enabled with token",
|
|
||||||
envs: map[string]string{
|
|
||||||
"TLS_ENABLED": "true",
|
|
||||||
"CF_API_TOKEN": "secret",
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Node mode without token",
|
|
||||||
envs: map[string]string{
|
|
||||||
"MODE": "node",
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Node mode with token",
|
|
||||||
envs: map[string]string{
|
|
||||||
"MODE": "node",
|
|
||||||
"NODE_TOKEN": "token",
|
|
||||||
},
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid mode",
|
|
||||||
envs: map[string]string{
|
|
||||||
"MODE": "invalid",
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid allowed ports",
|
|
||||||
envs: map[string]string{
|
|
||||||
"ALLOWED_PORTS": "1000",
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
os.Clearenv()
|
|
||||||
for k, v := range tt.envs {
|
|
||||||
t.Setenv(k, v)
|
|
||||||
}
|
|
||||||
cfg, err := parse()
|
|
||||||
if tt.expectErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, cfg)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, cfg)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetters(t *testing.T) {
|
|
||||||
envs := map[string]string{
|
|
||||||
"DOMAIN": "example.com",
|
|
||||||
"PORT": "2222",
|
|
||||||
"HTTP_PORT": "80",
|
|
||||||
"HTTPS_PORT": "443",
|
|
||||||
"KEY_LOC": "certs/ssh/id_rsa",
|
|
||||||
"TLS_ENABLED": "true",
|
|
||||||
"TLS_REDIRECT": "true",
|
|
||||||
"TLS_STORAGE_PATH": "certs/tls/",
|
|
||||||
"ACME_EMAIL": "test@example.com",
|
|
||||||
"CF_API_TOKEN": "token",
|
|
||||||
"ACME_STAGING": "true",
|
|
||||||
"ALLOWED_PORTS": "1000-2000",
|
|
||||||
"BUFFER_SIZE": "16384",
|
|
||||||
"MAX_HEADER_SIZE": "4096",
|
|
||||||
"PPROF_ENABLED": "true",
|
|
||||||
"PPROF_PORT": "7070",
|
|
||||||
"MODE": "standalone",
|
|
||||||
"GRPC_ADDRESS": "127.0.0.1",
|
|
||||||
"GRPC_PORT": "9090",
|
|
||||||
"NODE_TOKEN": "ntoken",
|
|
||||||
}
|
|
||||||
|
|
||||||
os.Clearenv()
|
|
||||||
for k, v := range envs {
|
|
||||||
t.Setenv(k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfg, err := parse()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "example.com", cfg.Domain())
|
|
||||||
assert.Equal(t, "2222", cfg.SSHPort())
|
|
||||||
assert.Equal(t, "80", cfg.HTTPPort())
|
|
||||||
assert.Equal(t, "443", cfg.HTTPSPort())
|
|
||||||
assert.Equal(t, "certs/ssh/id_rsa", cfg.KeyLoc())
|
|
||||||
assert.Equal(t, true, cfg.TLSEnabled())
|
|
||||||
assert.Equal(t, true, cfg.TLSRedirect())
|
|
||||||
assert.Equal(t, "certs/tls/", cfg.TLSStoragePath())
|
|
||||||
assert.Equal(t, "test@example.com", cfg.ACMEEmail())
|
|
||||||
assert.Equal(t, "token", cfg.CFAPIToken())
|
|
||||||
assert.Equal(t, true, cfg.ACMEStaging())
|
|
||||||
assert.Equal(t, uint16(1000), cfg.AllowedPortsStart())
|
|
||||||
assert.Equal(t, uint16(2000), cfg.AllowedPortsEnd())
|
|
||||||
assert.Equal(t, 16384, cfg.BufferSize())
|
|
||||||
assert.Equal(t, 4096, cfg.HeaderSize())
|
|
||||||
assert.Equal(t, true, cfg.PprofEnabled())
|
|
||||||
assert.Equal(t, "7070", cfg.PprofPort())
|
|
||||||
assert.Equal(t, types.ServerMode(types.ServerModeSTANDALONE), cfg.Mode())
|
|
||||||
assert.Equal(t, "127.0.0.1", cfg.GRPCAddress())
|
|
||||||
assert.Equal(t, "9090", cfg.GRPCPort())
|
|
||||||
assert.Equal(t, "ntoken", cfg.NodeToken())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMustLoad(t *testing.T) {
|
|
||||||
t.Run("success", func(t *testing.T) {
|
|
||||||
os.Clearenv()
|
|
||||||
t.Setenv("DOMAIN", "example.com")
|
|
||||||
cfg, err := MustLoad()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, cfg)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("loadEnvFile error", func(t *testing.T) {
|
|
||||||
err := os.Mkdir(".env", 0755)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
err = os.Remove(".env")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
cfg, err := MustLoad()
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, cfg)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("parse error", func(t *testing.T) {
|
|
||||||
os.Clearenv()
|
|
||||||
t.Setenv("MODE", "invalid")
|
|
||||||
cfg, err := MustLoad()
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, cfg)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadEnvFile(t *testing.T) {
|
|
||||||
t.Run("file exists", func(t *testing.T) {
|
|
||||||
err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
err = os.Remove(".env")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = loadEnvFile()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "true", os.Getenv("TEST_ENV_FILE"))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("file missing", func(t *testing.T) {
|
|
||||||
_ = os.Remove(".env")
|
|
||||||
err := loadEnvFile()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,190 +0,0 @@
|
|||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"github.com/joho/godotenv"
|
|
||||||
)
|
|
||||||
|
|
||||||
type config struct {
|
|
||||||
domain string
|
|
||||||
sshPort string
|
|
||||||
|
|
||||||
httpPort string
|
|
||||||
httpsPort string
|
|
||||||
|
|
||||||
keyLoc string
|
|
||||||
|
|
||||||
tlsEnabled bool
|
|
||||||
tlsRedirect bool
|
|
||||||
tlsStoragePath string
|
|
||||||
acmeEmail string
|
|
||||||
cfAPIToken string
|
|
||||||
acmeStaging bool
|
|
||||||
|
|
||||||
allowedPortsStart uint16
|
|
||||||
allowedPortsEnd uint16
|
|
||||||
|
|
||||||
bufferSize int
|
|
||||||
headerSize int
|
|
||||||
|
|
||||||
pprofEnabled bool
|
|
||||||
pprofPort string
|
|
||||||
|
|
||||||
mode types.ServerMode
|
|
||||||
grpcAddress string
|
|
||||||
grpcPort string
|
|
||||||
nodeToken string
|
|
||||||
}
|
|
||||||
|
|
||||||
func parse() (*config, error) {
|
|
||||||
mode, err := parseMode()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
domain := getenv("DOMAIN", "localhost")
|
|
||||||
sshPort := getenv("PORT", "2200")
|
|
||||||
|
|
||||||
httpPort := getenv("HTTP_PORT", "8080")
|
|
||||||
httpsPort := getenv("HTTPS_PORT", "8443")
|
|
||||||
|
|
||||||
keyLoc := getenv("KEY_LOC", "certs/privkey.pem")
|
|
||||||
|
|
||||||
tlsEnabled := getenvBool("TLS_ENABLED", false)
|
|
||||||
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
|
|
||||||
tlsStoragePath := getenv("TLS_STORAGE_PATH", "certs/tls/")
|
|
||||||
|
|
||||||
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
|
|
||||||
acmeStaging := getenvBool("ACME_STAGING", false)
|
|
||||||
|
|
||||||
cfToken := getenv("CF_API_TOKEN", "")
|
|
||||||
if tlsEnabled && cfToken == "" {
|
|
||||||
return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled")
|
|
||||||
}
|
|
||||||
|
|
||||||
start, end, err := parseAllowedPorts()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
bufferSize := parseBufferSize()
|
|
||||||
headerSize := parseHeaderSize()
|
|
||||||
|
|
||||||
pprofEnabled := getenvBool("PPROF_ENABLED", false)
|
|
||||||
pprofPort := getenv("PPROF_PORT", "6060")
|
|
||||||
|
|
||||||
grpcHost := getenv("GRPC_ADDRESS", "localhost")
|
|
||||||
grpcPort := getenv("GRPC_PORT", "8080")
|
|
||||||
|
|
||||||
nodeToken := getenv("NODE_TOKEN", "")
|
|
||||||
if mode == types.ServerModeNODE && nodeToken == "" {
|
|
||||||
return nil, fmt.Errorf("NODE_TOKEN is required in node mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &config{
|
|
||||||
domain: domain,
|
|
||||||
sshPort: sshPort,
|
|
||||||
httpPort: httpPort,
|
|
||||||
httpsPort: httpsPort,
|
|
||||||
keyLoc: keyLoc,
|
|
||||||
tlsEnabled: tlsEnabled,
|
|
||||||
tlsRedirect: tlsRedirect,
|
|
||||||
tlsStoragePath: tlsStoragePath,
|
|
||||||
acmeEmail: acmeEmail,
|
|
||||||
cfAPIToken: cfToken,
|
|
||||||
acmeStaging: acmeStaging,
|
|
||||||
allowedPortsStart: start,
|
|
||||||
allowedPortsEnd: end,
|
|
||||||
bufferSize: bufferSize,
|
|
||||||
headerSize: headerSize,
|
|
||||||
pprofEnabled: pprofEnabled,
|
|
||||||
pprofPort: pprofPort,
|
|
||||||
mode: mode,
|
|
||||||
grpcAddress: grpcHost,
|
|
||||||
grpcPort: grpcPort,
|
|
||||||
nodeToken: nodeToken,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadEnvFile() error {
|
|
||||||
if _, err := os.Stat(".env"); err == nil {
|
|
||||||
return godotenv.Load(".env")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseMode() (types.ServerMode, error) {
|
|
||||||
switch strings.ToLower(getenv("MODE", "standalone")) {
|
|
||||||
case "standalone":
|
|
||||||
return types.ServerModeSTANDALONE, nil
|
|
||||||
case "node":
|
|
||||||
return types.ServerModeNODE, nil
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("invalid MODE value")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseAllowedPorts() (uint16, uint16, error) {
|
|
||||||
raw := getenv("ALLOWED_PORTS", "")
|
|
||||||
if raw == "" {
|
|
||||||
return 0, 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strings.Split(raw, "-")
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format")
|
|
||||||
}
|
|
||||||
|
|
||||||
start, err := strconv.ParseUint(parts[0], 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
end, err := strconv.ParseUint(parts[1], 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return uint16(start), uint16(end), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseBufferSize() int {
|
|
||||||
raw := getenv("BUFFER_SIZE", "32768")
|
|
||||||
size, err := strconv.Atoi(raw)
|
|
||||||
if err != nil || size < 4096 || size > 1048576 {
|
|
||||||
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
|
|
||||||
return 4096
|
|
||||||
}
|
|
||||||
return size
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseHeaderSize() int {
|
|
||||||
raw := getenv("MAX_HEADER_SIZE", "4096")
|
|
||||||
size, err := strconv.Atoi(raw)
|
|
||||||
if err != nil || size < 4096 || size > 131072 {
|
|
||||||
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
|
|
||||||
return 4096
|
|
||||||
}
|
|
||||||
return size
|
|
||||||
}
|
|
||||||
|
|
||||||
func getenv(key, def string) string {
|
|
||||||
if v := os.Getenv(key); v != "" {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
|
|
||||||
func getenvBool(key string, def bool) bool {
|
|
||||||
val := os.Getenv(key)
|
|
||||||
if val == "" {
|
|
||||||
return def
|
|
||||||
}
|
|
||||||
return val == "true"
|
|
||||||
}
|
|
||||||
+260
-261
@@ -2,18 +2,20 @@ package client
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"time"
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
"tunnel_pls/internal/registry"
|
|
||||||
"tunnel_pls/types"
|
"tunnel_pls/session"
|
||||||
|
|
||||||
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
|
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/health/grpc_health_v1"
|
"google.golang.org/grpc/health/grpc_health_v1"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -21,43 +23,84 @@ import (
|
|||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client interface {
|
type GrpcConfig struct {
|
||||||
SubscribeEvents(ctx context.Context, identity, authToken string) error
|
Address string
|
||||||
ClientConn() *grpc.ClientConn
|
UseTLS bool
|
||||||
AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error)
|
InsecureSkipVerify bool
|
||||||
Close() error
|
Timeout time.Duration
|
||||||
CheckServerHealth(ctx context.Context) error
|
KeepAlive bool
|
||||||
|
MaxRetries int
|
||||||
|
KeepAliveTime time.Duration
|
||||||
|
KeepAliveTimeout time.Duration
|
||||||
|
PermitWithoutStream bool
|
||||||
}
|
}
|
||||||
type client struct {
|
|
||||||
config config.Config
|
type Client struct {
|
||||||
conn *grpc.ClientConn
|
conn *grpc.ClientConn
|
||||||
address string
|
config *GrpcConfig
|
||||||
sessionRegistry registry.Registry
|
sessionRegistry session.Registry
|
||||||
|
slugService proto.SlugChangeClient
|
||||||
eventService proto.EventServiceClient
|
eventService proto.EventServiceClient
|
||||||
authorizeConnectionService proto.UserServiceClient
|
authorizeConnectionService proto.UserServiceClient
|
||||||
closing bool
|
closing bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
func DefaultConfig() *GrpcConfig {
|
||||||
grpcNewClient = grpc.NewClient
|
return &GrpcConfig{
|
||||||
healthNewHealthClient = grpc_health_v1.NewHealthClient
|
Address: "localhost:50051",
|
||||||
initialBackoff = time.Second
|
UseTLS: false,
|
||||||
)
|
InsecureSkipVerify: false,
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
KeepAlive: true,
|
||||||
|
MaxRetries: 3,
|
||||||
|
KeepAliveTime: 2 * time.Minute,
|
||||||
|
KeepAliveTimeout: 10 * time.Second,
|
||||||
|
PermitWithoutStream: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func New(config config.Config, sessionRegistry registry.Registry) (Client, error) {
|
func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error) {
|
||||||
address := fmt.Sprintf("%s:%s", config.GRPCAddress(), config.GRPCPort())
|
if config == nil {
|
||||||
|
config = DefaultConfig()
|
||||||
|
} else {
|
||||||
|
defaults := DefaultConfig()
|
||||||
|
if config.Address == "" {
|
||||||
|
config.Address = defaults.Address
|
||||||
|
}
|
||||||
|
if config.Timeout == 0 {
|
||||||
|
config.Timeout = defaults.Timeout
|
||||||
|
}
|
||||||
|
if config.MaxRetries == 0 {
|
||||||
|
config.MaxRetries = defaults.MaxRetries
|
||||||
|
}
|
||||||
|
if config.KeepAliveTime == 0 {
|
||||||
|
config.KeepAliveTime = defaults.KeepAliveTime
|
||||||
|
}
|
||||||
|
if config.KeepAliveTimeout == 0 {
|
||||||
|
config.KeepAliveTimeout = defaults.KeepAliveTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var opts []grpc.DialOption
|
var opts []grpc.DialOption
|
||||||
|
|
||||||
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
if config.UseTLS {
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
kaParams := keepalive.ClientParameters{
|
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||||
Time: 2 * time.Minute,
|
}
|
||||||
Timeout: 10 * time.Second,
|
creds := credentials.NewTLS(tlsConfig)
|
||||||
PermitWithoutStream: false,
|
opts = append(opts, grpc.WithTransportCredentials(creds))
|
||||||
|
} else {
|
||||||
|
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
}
|
}
|
||||||
|
|
||||||
opts = append(opts, grpc.WithKeepaliveParams(kaParams))
|
if config.KeepAlive {
|
||||||
|
kaParams := keepalive.ClientParameters{
|
||||||
|
Time: config.KeepAliveTime,
|
||||||
|
Timeout: config.KeepAliveTimeout,
|
||||||
|
PermitWithoutStream: config.PermitWithoutStream,
|
||||||
|
}
|
||||||
|
opts = append(opts, grpc.WithKeepaliveParams(kaParams))
|
||||||
|
}
|
||||||
|
|
||||||
opts = append(opts,
|
opts = append(opts,
|
||||||
grpc.WithDefaultCallOptions(
|
grpc.WithDefaultCallOptions(
|
||||||
@@ -66,264 +109,216 @@ func New(config config.Config, sessionRegistry registry.Registry) (Client, error
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
conn, err := grpcNewClient(address, opts...)
|
conn, err := grpc.NewClient(config.Address, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err)
|
return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", config.Address, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slugService := proto.NewSlugChangeClient(conn)
|
||||||
eventService := proto.NewEventServiceClient(conn)
|
eventService := proto.NewEventServiceClient(conn)
|
||||||
authorizeConnectionService := proto.NewUserServiceClient(conn)
|
authorizeConnectionService := proto.NewUserServiceClient(conn)
|
||||||
|
|
||||||
return &client{
|
return &Client{
|
||||||
config: config,
|
|
||||||
conn: conn,
|
conn: conn,
|
||||||
address: address,
|
config: config,
|
||||||
|
slugService: slugService,
|
||||||
sessionRegistry: sessionRegistry,
|
sessionRegistry: sessionRegistry,
|
||||||
eventService: eventService,
|
eventService: eventService,
|
||||||
authorizeConnectionService: authorizeConnectionService,
|
authorizeConnectionService: authorizeConnectionService,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) SubscribeEvents(ctx context.Context, identity, authToken string) error {
|
func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string) error {
|
||||||
backoff := initialBackoff
|
const (
|
||||||
|
baseBackoff = time.Second
|
||||||
|
maxBackoff = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
backoff := baseBackoff
|
||||||
|
wait := func() error {
|
||||||
|
if backoff <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-time.After(backoff):
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
growBackoff := func() {
|
||||||
|
backoff *= 2
|
||||||
|
if backoff > maxBackoff {
|
||||||
|
backoff = maxBackoff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if err := c.subscribeAndProcess(ctx, identity, authToken, &backoff); err != nil {
|
subscribe, err := c.eventService.Subscribe(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = wait(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
growBackoff()
|
||||||
|
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = subscribe.Send(&proto.Node{
|
||||||
|
Type: proto.EventType_AUTHENTICATION,
|
||||||
|
Payload: &proto.Node_AuthEvent{
|
||||||
|
AuthEvent: &proto.Authentication{
|
||||||
|
Identity: identity,
|
||||||
|
AuthToken: authToken,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Authentication failed to send to gRPC server:", err)
|
||||||
|
if c.isConnectionError(err) {
|
||||||
|
if err = wait(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
growBackoff()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Println("Authentication Successfully sent to gRPC server")
|
||||||
|
backoff = baseBackoff
|
||||||
|
|
||||||
|
if err = c.processEventStream(subscribe); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if c.isConnectionError(err) {
|
||||||
|
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
|
||||||
|
if err = wait(); err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
growBackoff()
|
||||||
|
continue
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) subscribeAndProcess(ctx context.Context, identity, authToken string, backoff *time.Duration) error {
|
func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error {
|
||||||
subscribe, err := c.eventService.Subscribe(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return c.handleSubscribeError(ctx, err, backoff)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = subscribe.Send(&proto.Node{
|
|
||||||
Type: proto.EventType_AUTHENTICATION,
|
|
||||||
Payload: &proto.Node_AuthEvent{
|
|
||||||
AuthEvent: &proto.Authentication{
|
|
||||||
Identity: identity,
|
|
||||||
AuthToken: authToken,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return c.handleAuthError(ctx, err, backoff)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Println("Authentication Successfully sent to gRPC server")
|
|
||||||
*backoff = time.Second
|
|
||||||
|
|
||||||
return c.handleStreamError(ctx, c.processEventStream(subscribe), backoff)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) handleSubscribeError(ctx context.Context, err error, backoff *time.Duration) error {
|
|
||||||
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err = c.wait(ctx, *backoff); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.growBackoff(backoff)
|
|
||||||
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) handleAuthError(ctx context.Context, err error, backoff *time.Duration) error {
|
|
||||||
log.Println("Authentication failed to send to gRPC server:", err)
|
|
||||||
if !c.isConnectionError(err) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := c.wait(ctx, *backoff); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.growBackoff(backoff)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) handleStreamError(ctx context.Context, err error, backoff *time.Duration) error {
|
|
||||||
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if !c.isConnectionError(err) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
|
|
||||||
if err := c.wait(ctx, *backoff); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.growBackoff(backoff)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) wait(ctx context.Context, duration time.Duration) error {
|
|
||||||
if duration <= 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-time.After(duration):
|
|
||||||
return nil
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) growBackoff(backoff *time.Duration) {
|
|
||||||
const maxBackoff = 30 * time.Second
|
|
||||||
*backoff *= 2
|
|
||||||
if *backoff > maxBackoff {
|
|
||||||
*backoff = maxBackoff
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error {
|
|
||||||
handlers := c.eventHandlers(subscribe)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
recv, err := subscribe.Recv()
|
recv, err := subscribe.Recv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
switch recv.GetType() {
|
||||||
handler, ok := handlers[recv.GetType()]
|
case proto.EventType_SLUG_CHANGE:
|
||||||
if !ok {
|
oldSlug := recv.GetSlugEvent().GetOld()
|
||||||
|
newSlug := recv.GetSlugEvent().GetNew()
|
||||||
|
sess, err := c.sessionRegistry.Get(oldSlug)
|
||||||
|
if err != nil {
|
||||||
|
errSend := subscribe.Send(&proto.Node{
|
||||||
|
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
||||||
|
Payload: &proto.Node_SlugEventResponse{
|
||||||
|
SlugEventResponse: &proto.SlugChangeEventResponse{
|
||||||
|
Success: false,
|
||||||
|
Message: err.Error(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if errSend != nil {
|
||||||
|
if c.isConnectionError(errSend) {
|
||||||
|
return errSend
|
||||||
|
}
|
||||||
|
log.Printf("non-connection send error for slug change failure: %v", errSend)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = c.sessionRegistry.Update(oldSlug, newSlug)
|
||||||
|
if err != nil {
|
||||||
|
errSend := subscribe.Send(&proto.Node{
|
||||||
|
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
||||||
|
Payload: &proto.Node_SlugEventResponse{
|
||||||
|
SlugEventResponse: &proto.SlugChangeEventResponse{
|
||||||
|
Success: false,
|
||||||
|
Message: err.Error(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if errSend != nil {
|
||||||
|
if c.isConnectionError(errSend) {
|
||||||
|
return errSend
|
||||||
|
}
|
||||||
|
log.Printf("non-connection send error for slug change failure: %v", errSend)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sess.GetInteraction().Redraw()
|
||||||
|
err = subscribe.Send(&proto.Node{
|
||||||
|
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
||||||
|
Payload: &proto.Node_SlugEventResponse{
|
||||||
|
SlugEventResponse: &proto.SlugChangeEventResponse{
|
||||||
|
Success: true,
|
||||||
|
Message: "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if c.isConnectionError(err) {
|
||||||
|
log.Printf("connection error sending slug change success: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("non-connection send error for slug change success: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
case proto.EventType_GET_SESSIONS:
|
||||||
|
sessions := c.sessionRegistry.GetAllSessionFromUser(recv.GetGetSessionsEvent().GetIdentity())
|
||||||
|
var details []*proto.Detail
|
||||||
|
for _, ses := range sessions {
|
||||||
|
detail := ses.Detail()
|
||||||
|
details = append(details, &proto.Detail{
|
||||||
|
Node: config.Getenv("domain", "localhost"),
|
||||||
|
ForwardingType: detail.ForwardingType,
|
||||||
|
Slug: detail.Slug,
|
||||||
|
UserId: detail.UserID,
|
||||||
|
Active: detail.Active,
|
||||||
|
StartedAt: timestamppb.New(detail.StartedAt),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
err = subscribe.Send(&proto.Node{
|
||||||
|
Type: proto.EventType_GET_SESSIONS,
|
||||||
|
Payload: &proto.Node_GetSessionsEvent{
|
||||||
|
GetSessionsEvent: &proto.GetSessionsResponse{
|
||||||
|
Details: details,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
if c.isConnectionError(err) {
|
||||||
|
log.Printf("connection error sending sessions success: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Printf("non-connection send error for sessions success: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
default:
|
||||||
log.Printf("Unknown event type received: %v", recv.GetType())
|
log.Printf("Unknown event type received: %v", recv.GetType())
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = handler(recv); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) map[proto.EventType]func(*proto.Events) error {
|
func (c *Client) GetConnection() *grpc.ClientConn {
|
||||||
return map[proto.EventType]func(*proto.Events) error{
|
|
||||||
proto.EventType_SLUG_CHANGE: func(evt *proto.Events) error { return c.handleSlugChange(subscribe, evt) },
|
|
||||||
proto.EventType_GET_SESSIONS: func(evt *proto.Events) error { return c.handleGetSessions(subscribe, evt) },
|
|
||||||
proto.EventType_TERMINATE_SESSION: func(evt *proto.Events) error { return c.handleTerminateSession(subscribe, evt) },
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
|
|
||||||
slugEvent := evt.GetSlugEvent()
|
|
||||||
user := slugEvent.GetUser()
|
|
||||||
oldKey := types.SessionKey{Id: slugEvent.GetOld(), Type: types.TunnelTypeHTTP}
|
|
||||||
newKey := types.SessionKey{Id: slugEvent.GetNew(), Type: types.TunnelTypeHTTP}
|
|
||||||
|
|
||||||
userSession, err := c.sessionRegistry.Get(oldKey)
|
|
||||||
if err != nil {
|
|
||||||
return c.sendSlugChangeResponse(subscribe, false, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = c.sessionRegistry.Update(user, oldKey, newKey); err != nil {
|
|
||||||
return c.sendSlugChangeResponse(subscribe, false, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
userSession.Interaction().Redraw()
|
|
||||||
return c.sendSlugChangeResponse(subscribe, true, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
|
|
||||||
sessions := c.sessionRegistry.GetAllSessionFromUser(evt.GetGetSessionsEvent().GetIdentity())
|
|
||||||
|
|
||||||
var details []*proto.Detail
|
|
||||||
for _, ses := range sessions {
|
|
||||||
detail := ses.Detail()
|
|
||||||
details = append(details, &proto.Detail{
|
|
||||||
Node: c.config.Domain(),
|
|
||||||
ForwardingType: detail.ForwardingType,
|
|
||||||
Slug: detail.Slug,
|
|
||||||
UserId: detail.UserID,
|
|
||||||
Active: detail.Active,
|
|
||||||
StartedAt: timestamppb.New(detail.StartedAt),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.sendGetSessionsResponse(subscribe, details)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
|
|
||||||
terminate := evt.GetTerminateSessionEvent()
|
|
||||||
user := terminate.GetUser()
|
|
||||||
slug := terminate.GetSlug()
|
|
||||||
|
|
||||||
tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType())
|
|
||||||
if err != nil {
|
|
||||||
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType})
|
|
||||||
if err != nil {
|
|
||||||
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = userSession.Lifecycle().Close(); err != nil {
|
|
||||||
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.sendTerminateSessionResponse(subscribe, true, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) sendSlugChangeResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
|
|
||||||
return c.sendNode(subscribe, &proto.Node{
|
|
||||||
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
|
||||||
Payload: &proto.Node_SlugEventResponse{
|
|
||||||
SlugEventResponse: &proto.SlugChangeEventResponse{Success: success, Message: message},
|
|
||||||
},
|
|
||||||
}, "slug change response")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) sendGetSessionsResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], details []*proto.Detail) error {
|
|
||||||
return c.sendNode(subscribe, &proto.Node{
|
|
||||||
Type: proto.EventType_GET_SESSIONS,
|
|
||||||
Payload: &proto.Node_GetSessionsEvent{
|
|
||||||
GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
|
|
||||||
},
|
|
||||||
}, "send get sessions response")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) sendTerminateSessionResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
|
|
||||||
return c.sendNode(subscribe, &proto.Node{
|
|
||||||
Type: proto.EventType_TERMINATE_SESSION,
|
|
||||||
Payload: &proto.Node_TerminateSessionEventResponse{
|
|
||||||
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: success, Message: message},
|
|
||||||
},
|
|
||||||
}, "terminate session response")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error {
|
|
||||||
if err := subscribe.Send(node); err != nil {
|
|
||||||
if c.isConnectionError(err) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Printf("%s: %v", context, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
|
|
||||||
switch t {
|
|
||||||
case proto.TunnelType_HTTP:
|
|
||||||
return types.TunnelTypeHTTP, nil
|
|
||||||
case proto.TunnelType_TCP:
|
|
||||||
return types.TunnelTypeTCP, nil
|
|
||||||
default:
|
|
||||||
return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) ClientConn() *grpc.ClientConn {
|
|
||||||
return c.conn
|
return c.conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
||||||
check, err := c.authorizeConnectionService.Check(ctx, &proto.CheckRequest{AuthToken: token})
|
check, err := c.authorizeConnectionService.Check(ctx, &proto.CheckRequest{AuthToken: token})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, "UNAUTHORIZED", err
|
return false, "UNAUTHORIZED", err
|
||||||
@@ -335,8 +330,17 @@ func (c *client) AuthorizeConn(ctx context.Context, token string) (authorized bo
|
|||||||
return true, check.GetUser(), nil
|
return true, check.GetUser(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) CheckServerHealth(ctx context.Context) error {
|
func (c *Client) Close() error {
|
||||||
healthClient := healthNewHealthClient(c.ClientConn())
|
if c.conn != nil {
|
||||||
|
log.Printf("Closing gRPC connection to %s", c.config.Address)
|
||||||
|
c.closing = true
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) CheckServerHealth(ctx context.Context) error {
|
||||||
|
healthClient := grpc_health_v1.NewHealthClient(c.GetConnection())
|
||||||
resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{
|
resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{
|
||||||
Service: "",
|
Service: "",
|
||||||
})
|
})
|
||||||
@@ -349,16 +353,11 @@ func (c *client) CheckServerHealth(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) Close() error {
|
func (c *Client) GetConfig() *GrpcConfig {
|
||||||
if c.conn != nil {
|
return c.config
|
||||||
log.Printf("Closing gRPC connection to %s", c.address)
|
|
||||||
c.closing = true
|
|
||||||
return c.conn.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) isConnectionError(err error) bool {
|
func (c *Client) isConnectionError(err error) bool {
|
||||||
if c.closing {
|
if c.closing {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,30 +0,0 @@
|
|||||||
package header
|
|
||||||
|
|
||||||
type ResponseHeader interface {
|
|
||||||
Value(key string) string
|
|
||||||
Set(key string, value string)
|
|
||||||
Remove(key string)
|
|
||||||
Finalize() []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type responseHeader struct {
|
|
||||||
startLine []byte
|
|
||||||
headers map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
type RequestHeader interface {
|
|
||||||
Value(key string) string
|
|
||||||
Set(key string, value string)
|
|
||||||
Remove(key string)
|
|
||||||
Finalize() []byte
|
|
||||||
Method() string
|
|
||||||
Path() string
|
|
||||||
Version() string
|
|
||||||
}
|
|
||||||
type requestHeader struct {
|
|
||||||
method string
|
|
||||||
path string
|
|
||||||
version string
|
|
||||||
startLine []byte
|
|
||||||
headers map[string]string
|
|
||||||
}
|
|
||||||
@@ -1,227 +0,0 @@
|
|||||||
package header
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewRequest(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
data []byte
|
|
||||||
expectErr bool
|
|
||||||
errContains string
|
|
||||||
expectMethod string
|
|
||||||
expectPath string
|
|
||||||
expectVersion string
|
|
||||||
expectHeaders map[string]string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "success",
|
|
||||||
data: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\nX-Custom: value\r\n\r\n"),
|
|
||||||
expectErr: false,
|
|
||||||
expectMethod: "GET",
|
|
||||||
expectPath: "/path",
|
|
||||||
expectVersion: "HTTP/1.1",
|
|
||||||
expectHeaders: map[string]string{
|
|
||||||
"Host": "example.com",
|
|
||||||
"X-Custom": "value",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no CRLF in start line",
|
|
||||||
data: []byte("GET /path HTTP/1.1"),
|
|
||||||
expectErr: true,
|
|
||||||
errContains: "no CRLF found in start line",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid start line - missing method",
|
|
||||||
data: []byte("INVALID\r\n\r\n"),
|
|
||||||
expectErr: true,
|
|
||||||
errContains: "invalid start line: missing method",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid start line - missing version",
|
|
||||||
data: []byte("GET /path\r\n\r\n"),
|
|
||||||
expectErr: true,
|
|
||||||
errContains: "invalid start line: missing version",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid start line - multiple spaces",
|
|
||||||
data: []byte("GET /path HTTP/1.1\r\n\r\n"),
|
|
||||||
expectErr: false,
|
|
||||||
expectMethod: "GET",
|
|
||||||
expectPath: "",
|
|
||||||
expectVersion: "/path HTTP/1.1",
|
|
||||||
expectHeaders: map[string]string{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "start line with trailing space",
|
|
||||||
data: []byte("GET / HTTP/1.1 \r\n\r\n"),
|
|
||||||
expectErr: false,
|
|
||||||
expectMethod: "GET",
|
|
||||||
expectPath: "/",
|
|
||||||
expectVersion: "HTTP/1.1 ",
|
|
||||||
expectHeaders: map[string]string{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
req, err := NewRequest(tt.data)
|
|
||||||
if tt.expectErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
if tt.errContains != "" {
|
|
||||||
assert.Contains(t, err.Error(), tt.errContains)
|
|
||||||
}
|
|
||||||
assert.Nil(t, req)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, req)
|
|
||||||
assert.Equal(t, tt.expectMethod, req.Method())
|
|
||||||
assert.Equal(t, tt.expectPath, req.Path())
|
|
||||||
assert.Equal(t, tt.expectVersion, req.Version())
|
|
||||||
for k, v := range tt.expectHeaders {
|
|
||||||
assert.Equal(t, v, req.Value(k))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRequestHeaderMethods(t *testing.T) {
|
|
||||||
data := []byte("GET / HTTP/1.1\r\nHost: original\r\n\r\n")
|
|
||||||
req, _ := NewRequest(data)
|
|
||||||
|
|
||||||
req.Set("Host", "updated")
|
|
||||||
req.Set("X-New", "new-value")
|
|
||||||
assert.Equal(t, "updated", req.Value("Host"))
|
|
||||||
assert.Equal(t, "new-value", req.Value("X-New"))
|
|
||||||
|
|
||||||
assert.Equal(t, "", req.Value("Non-Existent"))
|
|
||||||
|
|
||||||
req.Remove("X-New")
|
|
||||||
assert.Equal(t, "", req.Value("X-New"))
|
|
||||||
|
|
||||||
final := req.Finalize()
|
|
||||||
assert.Contains(t, string(final), "GET / HTTP/1.1\r\n")
|
|
||||||
assert.Contains(t, string(final), "Host: updated\r\n")
|
|
||||||
assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewResponse(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
data []byte
|
|
||||||
expectErr bool
|
|
||||||
errContains string
|
|
||||||
expectHeaders map[string]string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "success",
|
|
||||||
data: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"),
|
|
||||||
expectErr: false,
|
|
||||||
expectHeaders: map[string]string{
|
|
||||||
"Content-Length": "0",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid response - no CRLF",
|
|
||||||
data: []byte("HTTP/1.1 200 OK"),
|
|
||||||
expectErr: true,
|
|
||||||
errContains: "no CRLF found in start line",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
resp, err := NewResponse(tt.data)
|
|
||||||
if tt.expectErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
if tt.errContains != "" {
|
|
||||||
assert.Contains(t, err.Error(), tt.errContains)
|
|
||||||
}
|
|
||||||
assert.Nil(t, resp)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, resp)
|
|
||||||
for k, v := range tt.expectHeaders {
|
|
||||||
assert.Equal(t, v, resp.Value(k))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResponseHeaderMethods(t *testing.T) {
|
|
||||||
data := []byte("HTTP/1.1 200 OK\r\nServer: old\r\n\r\n")
|
|
||||||
resp, _ := NewResponse(data)
|
|
||||||
|
|
||||||
resp.Set("Server", "new")
|
|
||||||
resp.Set("X-Res", "val")
|
|
||||||
assert.Equal(t, "new", resp.Value("Server"))
|
|
||||||
assert.Equal(t, "val", resp.Value("X-Res"))
|
|
||||||
|
|
||||||
resp.Remove("X-Res")
|
|
||||||
assert.Equal(t, "", resp.Value("X-Res"))
|
|
||||||
|
|
||||||
final := resp.Finalize()
|
|
||||||
assert.Contains(t, string(final), "HTTP/1.1 200 OK\r\n")
|
|
||||||
assert.Contains(t, string(final), "Server: new\r\n")
|
|
||||||
assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetRemainingHeaders(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
data []byte
|
|
||||||
initialHeaders map[string]string
|
|
||||||
expectHeaders map[string]string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "various header formats",
|
|
||||||
data: []byte("K1: V1\r\nK2:V2\r\n K3 : V3 \r\nNoColon\r\n\r\n"),
|
|
||||||
expectHeaders: map[string]string{
|
|
||||||
"K1": "V1",
|
|
||||||
"K2": "V2",
|
|
||||||
"K3": "V3",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no trailing CRLF",
|
|
||||||
data: []byte("K1: V1"),
|
|
||||||
expectHeaders: map[string]string{
|
|
||||||
"K1": "V1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty lines",
|
|
||||||
data: []byte("\r\nK1: V1"),
|
|
||||||
expectHeaders: map[string]string{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "headers with only colon",
|
|
||||||
data: []byte(": value\r\nkey:\r\n"),
|
|
||||||
expectHeaders: map[string]string{
|
|
||||||
"": "value",
|
|
||||||
"key": "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
req := &requestHeader{headers: make(map[string]string)}
|
|
||||||
if tt.initialHeaders != nil {
|
|
||||||
req.headers = tt.initialHeaders
|
|
||||||
}
|
|
||||||
setRemainingHeaders(tt.data, req)
|
|
||||||
assert.Equal(t, len(tt.expectHeaders), len(req.headers))
|
|
||||||
for k, v := range tt.expectHeaders {
|
|
||||||
assert.Equal(t, v, req.headers[k])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,77 +0,0 @@
|
|||||||
package header
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setRemainingHeaders(remaining []byte, header interface {
|
|
||||||
Set(key string, value string)
|
|
||||||
}) {
|
|
||||||
for len(remaining) > 0 {
|
|
||||||
lineEnd := bytes.Index(remaining, []byte("\r\n"))
|
|
||||||
if lineEnd == -1 {
|
|
||||||
lineEnd = len(remaining)
|
|
||||||
}
|
|
||||||
|
|
||||||
line := remaining[:lineEnd]
|
|
||||||
|
|
||||||
if len(line) == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
colonIdx := bytes.IndexByte(line, ':')
|
|
||||||
if colonIdx != -1 {
|
|
||||||
key := bytes.TrimSpace(line[:colonIdx])
|
|
||||||
value := bytes.TrimSpace(line[colonIdx+1:])
|
|
||||||
header.Set(string(key), string(value))
|
|
||||||
}
|
|
||||||
|
|
||||||
if lineEnd == len(remaining) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
remaining = remaining[lineEnd+2:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseStartLine(startLine []byte) (method, path, version string, err error) {
|
|
||||||
firstSpace := bytes.IndexByte(startLine, ' ')
|
|
||||||
if firstSpace == -1 {
|
|
||||||
return "", "", "", fmt.Errorf("invalid start line: missing method")
|
|
||||||
}
|
|
||||||
|
|
||||||
secondSpace := bytes.IndexByte(startLine[firstSpace+1:], ' ')
|
|
||||||
if secondSpace == -1 {
|
|
||||||
return "", "", "", fmt.Errorf("invalid start line: missing version")
|
|
||||||
}
|
|
||||||
secondSpace += firstSpace + 1
|
|
||||||
|
|
||||||
method = string(startLine[:firstSpace])
|
|
||||||
path = string(startLine[firstSpace+1 : secondSpace])
|
|
||||||
version = string(startLine[secondSpace+1:])
|
|
||||||
|
|
||||||
return method, path, version, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func finalize(startLine []byte, headers map[string]string) []byte {
|
|
||||||
size := len(startLine) + 2
|
|
||||||
for key, val := range headers {
|
|
||||||
size += len(key) + 2 + len(val) + 2
|
|
||||||
}
|
|
||||||
size += 2
|
|
||||||
|
|
||||||
buf := make([]byte, 0, size)
|
|
||||||
buf = append(buf, startLine...)
|
|
||||||
buf = append(buf, '\r', '\n')
|
|
||||||
|
|
||||||
for key, val := range headers {
|
|
||||||
buf = append(buf, key...)
|
|
||||||
buf = append(buf, ':', ' ')
|
|
||||||
buf = append(buf, val...)
|
|
||||||
buf = append(buf, '\r', '\n')
|
|
||||||
}
|
|
||||||
|
|
||||||
buf = append(buf, '\r', '\n')
|
|
||||||
return buf
|
|
||||||
}
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
package header
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewRequest(headerData []byte) (RequestHeader, error) {
|
|
||||||
header := &requestHeader{
|
|
||||||
headers: make(map[string]string, 16),
|
|
||||||
}
|
|
||||||
|
|
||||||
lineEnd := bytes.Index(headerData, []byte("\r\n"))
|
|
||||||
if lineEnd == -1 {
|
|
||||||
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
|
|
||||||
}
|
|
||||||
|
|
||||||
startLine := headerData[:lineEnd]
|
|
||||||
header.startLine = startLine
|
|
||||||
var err error
|
|
||||||
header.method, header.path, header.version, err = parseStartLine(startLine)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
remaining := headerData[lineEnd+2:]
|
|
||||||
|
|
||||||
setRemainingHeaders(remaining, header)
|
|
||||||
|
|
||||||
return header, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Value(key string) string {
|
|
||||||
val, ok := req.headers[key]
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Set(key string, value string) {
|
|
||||||
req.headers[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Remove(key string) {
|
|
||||||
delete(req.headers, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Method() string {
|
|
||||||
return req.method
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Path() string {
|
|
||||||
return req.path
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Version() string {
|
|
||||||
return req.version
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Finalize() []byte {
|
|
||||||
return finalize(req.startLine, req.headers)
|
|
||||||
}
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
package header
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewResponse(headerData []byte) (ResponseHeader, error) {
|
|
||||||
header := &responseHeader{
|
|
||||||
startLine: nil,
|
|
||||||
headers: make(map[string]string, 16),
|
|
||||||
}
|
|
||||||
|
|
||||||
lineEnd := bytes.Index(headerData, []byte("\r\n"))
|
|
||||||
if lineEnd == -1 {
|
|
||||||
return nil, fmt.Errorf("invalid response: no CRLF found in start line")
|
|
||||||
}
|
|
||||||
|
|
||||||
header.startLine = headerData[:lineEnd]
|
|
||||||
remaining := headerData[lineEnd+2:]
|
|
||||||
setRemainingHeaders(remaining, header)
|
|
||||||
|
|
||||||
return header, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (resp *responseHeader) Value(key string) string {
|
|
||||||
return resp.headers[key]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (resp *responseHeader) Set(key string, value string) {
|
|
||||||
resp.headers[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (resp *responseHeader) Remove(key string) {
|
|
||||||
delete(resp.headers, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (resp *responseHeader) Finalize() []byte {
|
|
||||||
return finalize(resp.startLine, resp.headers)
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package stream
|
|
||||||
|
|
||||||
import "bytes"
|
|
||||||
|
|
||||||
func splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) {
|
|
||||||
headerByte := data[:delimiterIdx+len(DELIMITER)]
|
|
||||||
body := data[delimiterIdx+len(DELIMITER):]
|
|
||||||
return headerByte, body
|
|
||||||
}
|
|
||||||
|
|
||||||
func isHTTPHeader(buf []byte) bool {
|
|
||||||
lines := bytes.Split(buf, []byte("\r\n"))
|
|
||||||
|
|
||||||
startLine := string(lines[0])
|
|
||||||
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, line := range lines[1:] {
|
|
||||||
if len(line) == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
colonIdx := bytes.IndexByte(line, ':')
|
|
||||||
if colonIdx <= 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
package stream
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"tunnel_pls/internal/http/header"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (hs *http) Read(p []byte) (int, error) {
|
|
||||||
tmp := make([]byte, len(p))
|
|
||||||
read, err := hs.reader.Read(tmp)
|
|
||||||
if read == 0 && err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
tmp = tmp[:read]
|
|
||||||
|
|
||||||
headerEndIdx := bytes.Index(tmp, DELIMITER)
|
|
||||||
if headerEndIdx == -1 {
|
|
||||||
return handleNoDelimiter(p, tmp, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
headerByte, bodyByte := splitHeaderAndBody(tmp, headerEndIdx)
|
|
||||||
|
|
||||||
if !isHTTPHeader(headerByte) {
|
|
||||||
copy(p, tmp)
|
|
||||||
return read, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return hs.processHTTPRequest(p, headerByte, bodyByte)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) processHTTPRequest(p, headerByte, bodyByte []byte) (int, error) {
|
|
||||||
reqhf, err := header.NewRequest(headerByte)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = hs.ApplyRequestMiddlewares(reqhf); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.reqHeader = reqhf
|
|
||||||
combined := append(reqhf.Finalize(), bodyByte...)
|
|
||||||
return copy(p, combined), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleNoDelimiter(p, tmp []byte, err error) (int, error) {
|
|
||||||
copy(p, tmp)
|
|
||||||
return len(tmp), err
|
|
||||||
}
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
package stream
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"regexp"
|
|
||||||
"tunnel_pls/internal/http/header"
|
|
||||||
"tunnel_pls/internal/middleware"
|
|
||||||
)
|
|
||||||
|
|
||||||
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
|
|
||||||
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
|
|
||||||
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
|
|
||||||
|
|
||||||
type HTTP interface {
|
|
||||||
io.ReadWriteCloser
|
|
||||||
CloseWrite() error
|
|
||||||
RemoteAddr() net.Addr
|
|
||||||
UseResponseMiddleware(mw middleware.ResponseMiddleware)
|
|
||||||
UseRequestMiddleware(mw middleware.RequestMiddleware)
|
|
||||||
SetRequestHeader(header header.RequestHeader)
|
|
||||||
RequestMiddlewares() []middleware.RequestMiddleware
|
|
||||||
ResponseMiddlewares() []middleware.ResponseMiddleware
|
|
||||||
ApplyResponseMiddlewares(resphf header.ResponseHeader, body []byte) error
|
|
||||||
ApplyRequestMiddlewares(reqhf header.RequestHeader) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type http struct {
|
|
||||||
remoteAddr net.Addr
|
|
||||||
writer io.Writer
|
|
||||||
reader io.Reader
|
|
||||||
buf []byte
|
|
||||||
respHeader header.ResponseHeader
|
|
||||||
reqHeader header.RequestHeader
|
|
||||||
respMW []middleware.ResponseMiddleware
|
|
||||||
reqMW []middleware.RequestMiddleware
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTP {
|
|
||||||
return &http{
|
|
||||||
remoteAddr: remoteAddr,
|
|
||||||
writer: writer,
|
|
||||||
reader: reader,
|
|
||||||
buf: make([]byte, 0, 4096),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) RemoteAddr() net.Addr {
|
|
||||||
return hs.remoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) UseResponseMiddleware(mw middleware.ResponseMiddleware) {
|
|
||||||
hs.respMW = append(hs.respMW, mw)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) UseRequestMiddleware(mw middleware.RequestMiddleware) {
|
|
||||||
hs.reqMW = append(hs.reqMW, mw)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) SetRequestHeader(header header.RequestHeader) {
|
|
||||||
hs.reqHeader = header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) RequestMiddlewares() []middleware.RequestMiddleware {
|
|
||||||
return hs.reqMW
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
|
|
||||||
return hs.respMW
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) Close() error {
|
|
||||||
if closer, ok := hs.writer.(io.Closer); ok {
|
|
||||||
return closer.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) CloseWrite() error {
|
|
||||||
if closer, ok := hs.writer.(interface{ CloseWrite() error }); ok {
|
|
||||||
return closer.CloseWrite()
|
|
||||||
}
|
|
||||||
return hs.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) ApplyRequestMiddlewares(reqhf header.RequestHeader) error {
|
|
||||||
for _, m := range hs.RequestMiddlewares() {
|
|
||||||
if err := m.HandleRequest(reqhf); err != nil {
|
|
||||||
log.Printf("Error when applying request middleware: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) ApplyResponseMiddlewares(resphf header.ResponseHeader, bodyByte []byte) error {
|
|
||||||
for _, m := range hs.ResponseMiddlewares() {
|
|
||||||
if err := m.HandleResponse(resphf, bodyByte); err != nil {
|
|
||||||
log.Printf("Cannot apply middleware: %s\n", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,765 +0,0 @@
|
|||||||
package stream
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"tunnel_pls/internal/http/header"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockAddr struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockAddr) String() string {
|
|
||||||
args := m.Called()
|
|
||||||
return args.String(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockAddr) Network() string {
|
|
||||||
args := m.Called()
|
|
||||||
return args.String(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockRequestMiddleware struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockRequestMiddleware) HandleRequest(h header.RequestHeader) error {
|
|
||||||
args := m.Called(h)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockResponseMiddleware struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error {
|
|
||||||
args := m.Called(h, body)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockReadWriter struct {
|
|
||||||
mock.Mock
|
|
||||||
bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockReadWriter) Read(p []byte) (int, error) {
|
|
||||||
args := m.Called(p)
|
|
||||||
return args.Int(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockReadWriter) Write(p []byte) (int, error) {
|
|
||||||
args := m.Called(p)
|
|
||||||
return args.Int(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockReadWriter) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockReadWriter) CloseWrite() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockReadWriterOnlyCloser struct {
|
|
||||||
mock.Mock
|
|
||||||
bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockReadWriterOnlyCloser) Read(p []byte) (int, error) {
|
|
||||||
args := m.Called(p)
|
|
||||||
return args.Int(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockReadWriterOnlyCloser) Write(p []byte) (int, error) {
|
|
||||||
args := m.Called(p)
|
|
||||||
return args.Int(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockReadWriterOnlyCloser) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockWriterOnly struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWriterOnly) Write(p []byte) (int, error) {
|
|
||||||
args := m.Called(p)
|
|
||||||
return args.Int(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWriterOnly) Read(p []byte) (int, error) {
|
|
||||||
args := m.Called(p)
|
|
||||||
return args.Int(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockReader struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockReader) Read(p []byte) (int, error) {
|
|
||||||
args := m.Called(p)
|
|
||||||
return args.Int(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockWriter struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWriter) Write(p []byte) (int, error) {
|
|
||||||
ret := m.Called(p)
|
|
||||||
|
|
||||||
var n int
|
|
||||||
var err error
|
|
||||||
|
|
||||||
switch v := ret.Get(0).(type) {
|
|
||||||
case func([]byte) int:
|
|
||||||
n = v(p)
|
|
||||||
case int:
|
|
||||||
n = v
|
|
||||||
default:
|
|
||||||
n = len(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v := ret.Get(1).(type) {
|
|
||||||
case func([]byte) error:
|
|
||||||
err = v(p)
|
|
||||||
case error:
|
|
||||||
err = v
|
|
||||||
default:
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockWriter) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPMethods(t *testing.T) {
|
|
||||||
addr := new(MockAddr)
|
|
||||||
addr.On("String").Return("1.2.3.4:1234")
|
|
||||||
|
|
||||||
rw := new(MockReadWriter)
|
|
||||||
hs := New(rw, rw, addr)
|
|
||||||
|
|
||||||
assert.Equal(t, addr, hs.RemoteAddr())
|
|
||||||
|
|
||||||
reqMW := new(MockRequestMiddleware)
|
|
||||||
hs.UseRequestMiddleware(reqMW)
|
|
||||||
assert.Equal(t, 1, len(hs.RequestMiddlewares()))
|
|
||||||
assert.Equal(t, reqMW, hs.RequestMiddlewares()[0])
|
|
||||||
|
|
||||||
respMW := new(MockResponseMiddleware)
|
|
||||||
hs.UseResponseMiddleware(respMW)
|
|
||||||
assert.Equal(t, 1, len(hs.ResponseMiddlewares()))
|
|
||||||
assert.Equal(t, respMW, hs.ResponseMiddlewares()[0])
|
|
||||||
|
|
||||||
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
|
|
||||||
hs.SetRequestHeader(reqH)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestApplyMiddlewares(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func(HTTP, *MockRequestMiddleware, *MockResponseMiddleware)
|
|
||||||
apply func(HTTP, header.RequestHeader, header.ResponseHeader) error
|
|
||||||
verify func(*testing.T, header.RequestHeader, header.ResponseHeader)
|
|
||||||
expectErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "apply request middleware success",
|
|
||||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
|
||||||
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
h := args.Get(0).(header.RequestHeader)
|
|
||||||
h.Set("X-Middleware", "true")
|
|
||||||
}).Return(nil)
|
|
||||||
hs.UseRequestMiddleware(reqMW)
|
|
||||||
},
|
|
||||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
|
||||||
return hs.ApplyRequestMiddlewares(reqH)
|
|
||||||
},
|
|
||||||
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
|
|
||||||
assert.Equal(t, "true", reqH.Value("X-Middleware"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "apply response middleware success",
|
|
||||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
|
||||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
h := args.Get(0).(header.ResponseHeader)
|
|
||||||
h.Set("X-Resp-Middleware", "true")
|
|
||||||
}).Return(nil)
|
|
||||||
hs.UseResponseMiddleware(respMW)
|
|
||||||
},
|
|
||||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
|
||||||
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
|
||||||
},
|
|
||||||
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
|
|
||||||
assert.Equal(t, "true", respH.Value("X-Resp-Middleware"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "apply request middleware error",
|
|
||||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
|
||||||
reqMW.On("HandleRequest", mock.Anything).Return(assert.AnError)
|
|
||||||
hs.UseRequestMiddleware(reqMW)
|
|
||||||
},
|
|
||||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
|
||||||
return hs.ApplyRequestMiddlewares(reqH)
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "apply response middleware error",
|
|
||||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
|
||||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(assert.AnError)
|
|
||||||
hs.UseResponseMiddleware(respMW)
|
|
||||||
},
|
|
||||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
|
||||||
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
|
||||||
},
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
|
|
||||||
respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
|
||||||
|
|
||||||
addr := new(MockAddr)
|
|
||||||
addr.On("String").Return("1.2.3.4:1234")
|
|
||||||
|
|
||||||
rw := new(MockReadWriter)
|
|
||||||
hs := New(rw, rw, addr)
|
|
||||||
|
|
||||||
reqMW := new(MockRequestMiddleware)
|
|
||||||
respMW := new(MockResponseMiddleware)
|
|
||||||
tt.setup(hs, reqMW, respMW)
|
|
||||||
|
|
||||||
err := tt.apply(hs, reqH, respH)
|
|
||||||
if tt.expectErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
if tt.verify != nil {
|
|
||||||
tt.verify(t, reqH, respH)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
reqMW.AssertExpectations(t)
|
|
||||||
respMW.AssertExpectations(t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloseMethods(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func() (io.Writer, io.Reader)
|
|
||||||
op func(HTTP) error
|
|
||||||
verify func(*testing.T, io.Writer)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Close success",
|
|
||||||
setup: func() (io.Writer, io.Reader) {
|
|
||||||
rw := new(MockReadWriter)
|
|
||||||
rw.On("Close").Return(nil)
|
|
||||||
return rw, rw
|
|
||||||
},
|
|
||||||
op: func(hs HTTP) error { return hs.Close() },
|
|
||||||
verify: func(t *testing.T, w io.Writer) {
|
|
||||||
w.(*MockReadWriter).AssertCalled(t, "Close")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "CloseWrite with CloseWrite implementation",
|
|
||||||
setup: func() (io.Writer, io.Reader) {
|
|
||||||
rw := new(MockReadWriter)
|
|
||||||
rw.On("CloseWrite").Return(nil)
|
|
||||||
return rw, rw
|
|
||||||
},
|
|
||||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
|
||||||
verify: func(t *testing.T, w io.Writer) {
|
|
||||||
w.(*MockReadWriter).AssertCalled(t, "CloseWrite")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "CloseWrite fallback to Close",
|
|
||||||
setup: func() (io.Writer, io.Reader) {
|
|
||||||
rw := new(MockReadWriterOnlyCloser)
|
|
||||||
rw.On("Close").Return(nil)
|
|
||||||
return rw, rw
|
|
||||||
},
|
|
||||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
|
||||||
verify: func(t *testing.T, w io.Writer) {
|
|
||||||
w.(*MockReadWriterOnlyCloser).AssertCalled(t, "Close")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Close with No Closer",
|
|
||||||
setup: func() (io.Writer, io.Reader) {
|
|
||||||
w := new(MockWriterOnly)
|
|
||||||
r := new(MockReader)
|
|
||||||
return w, r
|
|
||||||
},
|
|
||||||
op: func(hs HTTP) error { return hs.Close() },
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "CloseWrite with No CloseWrite and No Closer",
|
|
||||||
setup: func() (io.Writer, io.Reader) {
|
|
||||||
w := new(MockWriterOnly)
|
|
||||||
r := new(MockReader)
|
|
||||||
return w, r
|
|
||||||
},
|
|
||||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
addr := new(MockAddr)
|
|
||||||
addr.On("String").Return("1.2.3.4:1234")
|
|
||||||
|
|
||||||
w, r := tt.setup()
|
|
||||||
hs := New(w, r, addr)
|
|
||||||
|
|
||||||
assert.NotPanics(t, func() {
|
|
||||||
err := tt.op(hs)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
if tt.verify != nil {
|
|
||||||
tt.verify(t, w)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSplitHeaderAndBody(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
data []byte
|
|
||||||
delimiterIdx int
|
|
||||||
expectHeader []byte
|
|
||||||
expectBody []byte
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "standard",
|
|
||||||
data: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nBodyContent"),
|
|
||||||
delimiterIdx: 31,
|
|
||||||
expectHeader: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"),
|
|
||||||
expectBody: []byte("BodyContent"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty body",
|
|
||||||
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
|
||||||
delimiterIdx: 15,
|
|
||||||
expectHeader: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
|
||||||
expectBody: []byte(""),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
h, b := splitHeaderAndBody(tt.data, tt.delimiterIdx)
|
|
||||||
assert.Equal(t, tt.expectHeader, h)
|
|
||||||
assert.Equal(t, tt.expectBody, b)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsHTTPHeader(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
buf []byte
|
|
||||||
expect bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid request",
|
|
||||||
buf: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n"),
|
|
||||||
expect: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid response",
|
|
||||||
buf: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"),
|
|
||||||
expect: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid start line",
|
|
||||||
buf: []byte("NOT_HTTP /path\r\nHost: example.com\r\n\r\n"),
|
|
||||||
expect: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid header line (no colon)",
|
|
||||||
buf: []byte("GET / HTTP/1.1\r\nInvalidHeaderLine\r\n\r\n"),
|
|
||||||
expect: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid header line (colon at 0)",
|
|
||||||
buf: []byte("GET / HTTP/1.1\r\n: value\r\n\r\n"),
|
|
||||||
expect: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty header section",
|
|
||||||
buf: []byte("GET / HTTP/1.1\r\n\r\n"),
|
|
||||||
expect: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple headers",
|
|
||||||
buf: []byte("GET / HTTP/1.1\r\nH1: V1\r\nH2: V2\r\n\r\n"),
|
|
||||||
expect: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := isHTTPHeader(tt.buf)
|
|
||||||
assert.Equal(t, tt.expect, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRead(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input []byte
|
|
||||||
readLen int
|
|
||||||
expectContent string
|
|
||||||
expectRead int
|
|
||||||
expectErr bool
|
|
||||||
middlewareErr error
|
|
||||||
isHTTP bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid http request",
|
|
||||||
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\nBody"),
|
|
||||||
readLen: 100,
|
|
||||||
expectContent: "Body",
|
|
||||||
expectRead: 54,
|
|
||||||
isHTTP: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-http data",
|
|
||||||
input: []byte("Some random data\r\n\r\nMore data"),
|
|
||||||
readLen: 100,
|
|
||||||
expectContent: "Some random data\r\n\r\nMore data",
|
|
||||||
expectRead: 29,
|
|
||||||
isHTTP: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no delimiter",
|
|
||||||
input: []byte("Partial data without delimiter"),
|
|
||||||
readLen: 100,
|
|
||||||
expectContent: "Partial data without delimiter",
|
|
||||||
expectRead: 30,
|
|
||||||
isHTTP: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "middleware error",
|
|
||||||
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\n"),
|
|
||||||
readLen: 100,
|
|
||||||
middlewareErr: assert.AnError,
|
|
||||||
expectErr: true,
|
|
||||||
isHTTP: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
addr := new(MockAddr)
|
|
||||||
addr.On("String").Return("1.2.3.4:1234")
|
|
||||||
|
|
||||||
reader := new(MockReader)
|
|
||||||
writer := new(MockWriterOnly)
|
|
||||||
|
|
||||||
if tt.expectErr || tt.name == "valid http request" {
|
|
||||||
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
p := args.Get(0).([]byte)
|
|
||||||
copy(p, tt.input)
|
|
||||||
}).Return(len(tt.input), io.EOF).Once()
|
|
||||||
} else {
|
|
||||||
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
p := args.Get(0).([]byte)
|
|
||||||
copy(p, tt.input)
|
|
||||||
}).Return(len(tt.input), nil).Once()
|
|
||||||
}
|
|
||||||
|
|
||||||
hs := New(writer, reader, addr)
|
|
||||||
|
|
||||||
reqMW := new(MockRequestMiddleware)
|
|
||||||
if tt.isHTTP {
|
|
||||||
if tt.middlewareErr != nil {
|
|
||||||
reqMW.On("HandleRequest", mock.Anything).Return(tt.middlewareErr)
|
|
||||||
} else {
|
|
||||||
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
h := args.Get(0).(header.RequestHeader)
|
|
||||||
h.Set("X-Middleware", "true")
|
|
||||||
}).Return(nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hs.UseRequestMiddleware(reqMW)
|
|
||||||
|
|
||||||
p := make([]byte, tt.readLen)
|
|
||||||
n, err := hs.Read(p)
|
|
||||||
|
|
||||||
if tt.expectErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, tt.expectRead, n)
|
|
||||||
if tt.name == "valid http request" {
|
|
||||||
content := string(p[:n])
|
|
||||||
assert.Contains(t, content, "GET / HTTP/1.1\r\n")
|
|
||||||
assert.Contains(t, content, "Host: test\r\n")
|
|
||||||
assert.Contains(t, content, "X-Middleware: true\r\n")
|
|
||||||
assert.True(t, bytes.HasSuffix(p[:n], []byte("\r\n\r\nBody")))
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, tt.expectContent, string(p[:n]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.isHTTP {
|
|
||||||
reqMW.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
reader.AssertExpectations(t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWrite(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
writes [][]byte
|
|
||||||
expectWritten string
|
|
||||||
expectErr bool
|
|
||||||
middlewareErr error
|
|
||||||
isHTTP bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid http response in one write",
|
|
||||||
writes: [][]byte{
|
|
||||||
[]byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"),
|
|
||||||
},
|
|
||||||
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
|
||||||
isHTTP: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid http response in multiple writes",
|
|
||||||
writes: [][]byte{
|
|
||||||
[]byte("HTTP/1.1 200 OK\r\n"),
|
|
||||||
[]byte("Content-Length: 4\r\n\r\n"),
|
|
||||||
[]byte("Body"),
|
|
||||||
},
|
|
||||||
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
|
||||||
isHTTP: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-http data",
|
|
||||||
writes: [][]byte{
|
|
||||||
[]byte("Random data with delimiter\r\n\r\nFlush"),
|
|
||||||
},
|
|
||||||
expectWritten: "Random data with delimiter\r\n\r\nFlush",
|
|
||||||
isHTTP: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bypass buffering",
|
|
||||||
writes: [][]byte{
|
|
||||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
|
||||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
|
||||||
},
|
|
||||||
expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" +
|
|
||||||
"HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n",
|
|
||||||
isHTTP: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "middleware error",
|
|
||||||
writes: [][]byte{
|
|
||||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
|
||||||
},
|
|
||||||
middlewareErr: assert.AnError,
|
|
||||||
expectErr: true,
|
|
||||||
isHTTP: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
addr := new(MockAddr)
|
|
||||||
addr.On("String").Return("1.2.3.4:1234")
|
|
||||||
|
|
||||||
var writtenData bytes.Buffer
|
|
||||||
writer := new(MockWriter)
|
|
||||||
|
|
||||||
writer.On("Write", mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
p := args.Get(0).([]byte)
|
|
||||||
writtenData.Write(p)
|
|
||||||
}).Return(func(p []byte) int {
|
|
||||||
return len(p)
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
reader := new(MockReader)
|
|
||||||
hs := New(writer, reader, addr)
|
|
||||||
|
|
||||||
respMW := new(MockResponseMiddleware)
|
|
||||||
if tt.isHTTP {
|
|
||||||
if tt.middlewareErr != nil {
|
|
||||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(tt.middlewareErr)
|
|
||||||
} else {
|
|
||||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
h := args.Get(0).(header.ResponseHeader)
|
|
||||||
h.Set("X-Resp-Middleware", "true")
|
|
||||||
}).Return(nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
hs.UseResponseMiddleware(respMW)
|
|
||||||
|
|
||||||
var totalN int
|
|
||||||
var err error
|
|
||||||
for _, w := range tt.writes {
|
|
||||||
var n int
|
|
||||||
n, err = hs.Write(w)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
totalN += n
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.expectErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
written := writtenData.String()
|
|
||||||
if strings.HasPrefix(tt.expectWritten, "HTTP/") {
|
|
||||||
assert.Contains(t, written, "HTTP/1.1 200 OK\r\n")
|
|
||||||
assert.Contains(t, written, "X-Resp-Middleware: true\r\n")
|
|
||||||
if strings.Contains(tt.expectWritten, "Content-Length: 4") {
|
|
||||||
assert.Contains(t, written, "Content-Length: 4\r\n")
|
|
||||||
}
|
|
||||||
assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n"))
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, tt.expectWritten, written)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.isHTTP {
|
|
||||||
respMW.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
if tt.middlewareErr == nil {
|
|
||||||
writer.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteErrors(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func() (io.Writer, io.Reader)
|
|
||||||
data []byte
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "write error in writeHeaderAndBody",
|
|
||||||
setup: func() (io.Writer, io.Reader) {
|
|
||||||
writer := new(MockWriter)
|
|
||||||
writer.On("Write", mock.Anything).Return(0, assert.AnError)
|
|
||||||
reader := new(MockReader)
|
|
||||||
return writer, reader
|
|
||||||
},
|
|
||||||
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "write error in writeHeaderAndBody second write",
|
|
||||||
setup: func() (io.Writer, io.Reader) {
|
|
||||||
writer := new(MockWriter)
|
|
||||||
writer.On("Write", mock.Anything).Return(len([]byte("HTTP/1.1 200 OK\r\n\r\n")), nil).Once()
|
|
||||||
writer.On("Write", mock.Anything).Return(0, assert.AnError).Once()
|
|
||||||
reader := new(MockReader)
|
|
||||||
return writer, reader
|
|
||||||
},
|
|
||||||
data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "write error in writeRawBuffer",
|
|
||||||
setup: func() (io.Writer, io.Reader) {
|
|
||||||
writer := new(MockWriter)
|
|
||||||
writer.On("Write", mock.Anything).Return(0, assert.AnError)
|
|
||||||
reader := new(MockReader)
|
|
||||||
return writer, reader
|
|
||||||
},
|
|
||||||
data: []byte("Not HTTP\r\n\r\nFlush"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
addr := new(MockAddr)
|
|
||||||
addr.On("String").Return("1.2.3.4:1234")
|
|
||||||
|
|
||||||
w, r := tt.setup()
|
|
||||||
hs := New(w, r, addr)
|
|
||||||
|
|
||||||
_, err := hs.Write(tt.data)
|
|
||||||
assert.Error(t, err)
|
|
||||||
|
|
||||||
w.(*MockWriter).AssertExpectations(t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadEOF(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func() io.Reader
|
|
||||||
expectN int
|
|
||||||
expectErr error
|
|
||||||
expectContent string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "read eof",
|
|
||||||
setup: func() io.Reader {
|
|
||||||
reader := new(MockReader)
|
|
||||||
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
p := args.Get(0).([]byte)
|
|
||||||
copy(p, "data")
|
|
||||||
}).Return(4, io.EOF)
|
|
||||||
return reader
|
|
||||||
},
|
|
||||||
expectN: 4,
|
|
||||||
expectErr: io.EOF,
|
|
||||||
expectContent: "data",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
addr := new(MockAddr)
|
|
||||||
addr.On("String").Return("1.2.3.4:1234")
|
|
||||||
|
|
||||||
reader := tt.setup()
|
|
||||||
hs := New(nil, reader, addr)
|
|
||||||
|
|
||||||
p := make([]byte, 100)
|
|
||||||
n, err := hs.Read(p)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.expectN, n)
|
|
||||||
assert.Equal(t, tt.expectErr, err)
|
|
||||||
assert.Equal(t, tt.expectContent, string(p[:n]))
|
|
||||||
|
|
||||||
reader.(*MockReader).AssertExpectations(t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
package stream
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"tunnel_pls/internal/http/header"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (hs *http) Write(p []byte) (int, error) {
|
|
||||||
if hs.shouldBypassBuffering(p) {
|
|
||||||
hs.respHeader = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if hs.respHeader != nil {
|
|
||||||
return hs.writer.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.buf = append(hs.buf, p...)
|
|
||||||
|
|
||||||
headerEndIdx := bytes.Index(hs.buf, DELIMITER)
|
|
||||||
if headerEndIdx == -1 {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return hs.processBufferedResponse(p, headerEndIdx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) shouldBypassBuffering(p []byte) bool {
|
|
||||||
return hs.respHeader != nil && len(hs.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) processBufferedResponse(p []byte, delimiterIdx int) (int, error) {
|
|
||||||
headerByte, bodyByte := splitHeaderAndBody(hs.buf, delimiterIdx)
|
|
||||||
|
|
||||||
if !isHTTPHeader(headerByte) {
|
|
||||||
return hs.writeRawBuffer()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := hs.processHTTPResponse(headerByte, bodyByte); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.buf = nil
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) writeRawBuffer() (int, error) {
|
|
||||||
_, err := hs.writer.Write(hs.buf)
|
|
||||||
length := len(hs.buf)
|
|
||||||
hs.buf = nil
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return length, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) processHTTPResponse(headerByte, bodyByte []byte) error {
|
|
||||||
resphf, err := header.NewResponse(headerByte)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = hs.ApplyResponseMiddlewares(resphf, bodyByte); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
hs.respHeader = resphf
|
|
||||||
finalHeader := resphf.Finalize()
|
|
||||||
|
|
||||||
if err = hs.writeHeaderAndBody(finalHeader, bodyByte); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hs *http) writeHeaderAndBody(header, bodyByte []byte) error {
|
|
||||||
if _, err := hs.writer.Write(header); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(bodyByte) > 0 {
|
|
||||||
if _, err := hs.writer.Write(bodyByte); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
+9
-28
@@ -5,8 +5,6 @@ import (
|
|||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -14,20 +12,7 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
|
||||||
rsaGenerateKey = rsa.GenerateKey
|
|
||||||
pemEncode = pem.Encode
|
|
||||||
sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
|
|
||||||
return ssh.NewPublicKey(key)
|
|
||||||
}
|
|
||||||
pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
|
|
||||||
return w.Write(data)
|
|
||||||
}
|
|
||||||
osOpenFile = os.OpenFile
|
|
||||||
)
|
|
||||||
|
|
||||||
func GenerateSSHKeyIfNotExist(keyPath string) error {
|
func GenerateSSHKeyIfNotExist(keyPath string) error {
|
||||||
var errGroup = make([]error, 0)
|
|
||||||
if _, err := os.Stat(keyPath); err == nil {
|
if _, err := os.Stat(keyPath); err == nil {
|
||||||
log.Printf("SSH key already exists at %s", keyPath)
|
log.Printf("SSH key already exists at %s", keyPath)
|
||||||
return nil
|
return nil
|
||||||
@@ -35,7 +20,7 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
|
|||||||
|
|
||||||
log.Printf("SSH key not found at %s, generating new key pair...", keyPath)
|
log.Printf("SSH key not found at %s, generating new key pair...", keyPath)
|
||||||
|
|
||||||
privateKey, err := rsaGenerateKey(rand.Reader, 4096)
|
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -50,37 +35,33 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKeyFile, err := osOpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
privateKeyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func(privateKeyFile *os.File) {
|
defer privateKeyFile.Close()
|
||||||
errGroup = append(errGroup, privateKeyFile.Close())
|
|
||||||
}(privateKeyFile)
|
|
||||||
|
|
||||||
if err := pemEncode(privateKeyFile, privateKeyPEM); err != nil {
|
if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
publicKey, err := sshNewPublicKey(&privateKey.PublicKey)
|
publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pubKeyPath := keyPath + ".pub"
|
pubKeyPath := keyPath + ".pub"
|
||||||
pubKeyFile, err := osOpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
pubKeyFile, err := os.OpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func(pubKeyFile *os.File) {
|
defer pubKeyFile.Close()
|
||||||
errGroup = append(errGroup, pubKeyFile.Close())
|
|
||||||
}(pubKeyFile)
|
|
||||||
|
|
||||||
_, err = pubKeyWrite(pubKeyFile, ssh.MarshalAuthorizedKey(publicKey))
|
_, err = pubKeyFile.Write(ssh.MarshalAuthorizedKey(publicKey))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("SSH key pair generated successfully at %s and %s", keyPath, pubKeyPath)
|
log.Printf("SSH key pair generated successfully at %s and %s", keyPath, pubKeyPath)
|
||||||
return errors.Join(errGroup...)
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,235 +0,0 @@
|
|||||||
package key
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/pem"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGenerateSSHKeyIfNotExist(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setup func(t *testing.T, tempDir string) string
|
|
||||||
mockSetup func() func()
|
|
||||||
wantErr bool
|
|
||||||
errStr string
|
|
||||||
verify func(t *testing.T, keyPath string)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "GenerateNewKey",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
return filepath.Join(tempDir, "id_rsa")
|
|
||||||
},
|
|
||||||
verify: func(t *testing.T, keyPath string) {
|
|
||||||
pubKeyPath := keyPath + ".pub"
|
|
||||||
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
|
||||||
t.Errorf("Private key file not created")
|
|
||||||
}
|
|
||||||
if _, err := os.Stat(pubKeyPath); os.IsNotExist(err) {
|
|
||||||
t.Errorf("Public key file not created")
|
|
||||||
}
|
|
||||||
privateKeyBytes, err := os.ReadFile(keyPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to read private key: %v", err)
|
|
||||||
}
|
|
||||||
if _, err = ssh.ParseRawPrivateKey(privateKeyBytes); err != nil {
|
|
||||||
t.Errorf("Failed to parse private key: %v", err)
|
|
||||||
}
|
|
||||||
publicKeyBytes, err := os.ReadFile(pubKeyPath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to read public key: %v", err)
|
|
||||||
}
|
|
||||||
if _, _, _, _, err = ssh.ParseAuthorizedKey(publicKeyBytes); err != nil {
|
|
||||||
t.Errorf("Failed to parse public key: %v", err)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DoNotOverwriteExistingKey",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
keyPath := filepath.Join(tempDir, "existing_id_rsa")
|
|
||||||
dummyPrivate := "dummy private"
|
|
||||||
dummyPublic := "dummy public"
|
|
||||||
if err := os.WriteFile(keyPath, []byte(dummyPrivate), 0600); err != nil {
|
|
||||||
t.Fatalf("Failed to create dummy private key: %v", err)
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(keyPath+".pub", []byte(dummyPublic), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to create dummy public key: %v", err)
|
|
||||||
}
|
|
||||||
return keyPath
|
|
||||||
},
|
|
||||||
verify: func(t *testing.T, keyPath string) {
|
|
||||||
gotPrivate, _ := os.ReadFile(keyPath)
|
|
||||||
if string(gotPrivate) != "dummy private" {
|
|
||||||
t.Errorf("Private key was overwritten")
|
|
||||||
}
|
|
||||||
gotPublic, _ := os.ReadFile(keyPath + ".pub")
|
|
||||||
if string(gotPublic) != "dummy public" {
|
|
||||||
t.Errorf("Public key was overwritten")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "CreateNestedDirectories",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
return filepath.Join(tempDir, "nested", "dir", "id_rsa")
|
|
||||||
},
|
|
||||||
verify: func(t *testing.T, keyPath string) {
|
|
||||||
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
|
||||||
t.Errorf("Private key file not created in nested directory")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "FailureMkdirAll",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
dirPath := filepath.Join(tempDir, "file_as_dir")
|
|
||||||
if err := os.WriteFile(dirPath, []byte("not a dir"), 0644); err != nil {
|
|
||||||
t.Fatalf("Failed to create file: %v", err)
|
|
||||||
}
|
|
||||||
return filepath.Join(dirPath, "id_rsa")
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "PrivateExistsPublicMissing",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
keyPath := filepath.Join(tempDir, "partial_id_rsa")
|
|
||||||
if err := os.WriteFile(keyPath, []byte("private"), 0600); err != nil {
|
|
||||||
t.Fatalf("Failed to create private key: %v", err)
|
|
||||||
}
|
|
||||||
return keyPath
|
|
||||||
},
|
|
||||||
verify: func(t *testing.T, keyPath string) {
|
|
||||||
if _, err := os.Stat(keyPath + ".pub"); !os.IsNotExist(err) {
|
|
||||||
t.Errorf("Public key should NOT have been created if private key existed")
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "FailureRSAGenerateKey",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
return filepath.Join(tempDir, "fail_rsa")
|
|
||||||
},
|
|
||||||
mockSetup: func() func() {
|
|
||||||
old := rsaGenerateKey
|
|
||||||
rsaGenerateKey = func(random io.Reader, bits int) (*rsa.PrivateKey, error) {
|
|
||||||
return nil, errors.New("rsa error")
|
|
||||||
}
|
|
||||||
return func() { rsaGenerateKey = old }
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
errStr: "rsa error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "FailureOpenFilePrivate",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
return filepath.Join(tempDir, "fail_open_private")
|
|
||||||
},
|
|
||||||
mockSetup: func() func() {
|
|
||||||
old := osOpenFile
|
|
||||||
osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
|
|
||||||
return nil, errors.New("open error")
|
|
||||||
}
|
|
||||||
return func() { osOpenFile = old }
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
errStr: "open error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "FailurePemEncode",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
return filepath.Join(tempDir, "fail_pem")
|
|
||||||
},
|
|
||||||
mockSetup: func() func() {
|
|
||||||
old := pemEncode
|
|
||||||
pemEncode = func(out io.Writer, b *pem.Block) error {
|
|
||||||
return errors.New("pem error")
|
|
||||||
}
|
|
||||||
return func() { pemEncode = old }
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
errStr: "pem error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "FailureSSHNewPublicKey",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
return filepath.Join(tempDir, "fail_ssh")
|
|
||||||
},
|
|
||||||
mockSetup: func() func() {
|
|
||||||
old := sshNewPublicKey
|
|
||||||
sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
|
|
||||||
return nil, errors.New("ssh error")
|
|
||||||
}
|
|
||||||
return func() { sshNewPublicKey = old }
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
errStr: "ssh error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "FailureOpenFilePublic",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
return filepath.Join(tempDir, "fail_open_public")
|
|
||||||
},
|
|
||||||
mockSetup: func() func() {
|
|
||||||
old := osOpenFile
|
|
||||||
osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
|
|
||||||
if filepath.Ext(name) == ".pub" {
|
|
||||||
return nil, errors.New("open pub error")
|
|
||||||
}
|
|
||||||
return os.OpenFile(name, flag, perm)
|
|
||||||
}
|
|
||||||
return func() { osOpenFile = old }
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
errStr: "open pub error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "FailurePubKeyWrite",
|
|
||||||
setup: func(t *testing.T, tempDir string) string {
|
|
||||||
return filepath.Join(tempDir, "fail_write")
|
|
||||||
},
|
|
||||||
mockSetup: func() func() {
|
|
||||||
old := pubKeyWrite
|
|
||||||
pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
|
|
||||||
return 0, errors.New("write error")
|
|
||||||
}
|
|
||||||
return func() { pubKeyWrite = old }
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
errStr: "write error",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
keyPath := tt.setup(t, tempDir)
|
|
||||||
if tt.mockSetup != nil {
|
|
||||||
cleanup := tt.mockSetup()
|
|
||||||
defer cleanup()
|
|
||||||
}
|
|
||||||
|
|
||||||
err := GenerateSSHKeyIfNotExist(keyPath)
|
|
||||||
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if tt.wantErr && tt.errStr != "" && err != nil && err.Error() != tt.errStr {
|
|
||||||
t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErrStr %v", err, tt.errStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.verify != nil {
|
|
||||||
tt.verify(t, keyPath)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"tunnel_pls/internal/http/header"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ForwardedFor struct {
|
|
||||||
addr net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewForwardedFor(addr net.Addr) *ForwardedFor {
|
|
||||||
return &ForwardedFor{addr: addr}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ff *ForwardedFor) HandleRequest(header header.RequestHeader) error {
|
|
||||||
host, _, err := net.SplitHostPort(ff.addr.String())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
header.Set("X-Forwarded-For", host)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,126 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockRequestHeader struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockRequestHeader) Value(key string) string {
|
|
||||||
return m.Called(key).String(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockRequestHeader) Set(key string, value string) {
|
|
||||||
m.Called(key, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockRequestHeader) Remove(key string) {
|
|
||||||
m.Called(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockRequestHeader) Finalize() []byte {
|
|
||||||
return m.Called().Get(0).([]byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockRequestHeader) Method() string {
|
|
||||||
return m.Called().String(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockRequestHeader) Path() string {
|
|
||||||
return m.Called().String(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockRequestHeader) Version() string {
|
|
||||||
return m.Called().String(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestForwardedFor_HandleRequest(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
addr net.Addr
|
|
||||||
expectedHost string
|
|
||||||
expectError bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid IPv4 address",
|
|
||||||
addr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 8080},
|
|
||||||
expectedHost: "192.168.1.100",
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid IPv6 address",
|
|
||||||
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 8080},
|
|
||||||
expectedHost: "2001:db8::ff00:42:8329",
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid address format",
|
|
||||||
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
|
||||||
expectedHost: "",
|
|
||||||
expectError: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid IPv4 address with port",
|
|
||||||
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
|
|
||||||
expectedHost: "127.0.0.1",
|
|
||||||
expectError: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
ff := NewForwardedFor(tc.addr)
|
|
||||||
reqHeader := new(mockRequestHeader)
|
|
||||||
|
|
||||||
if !tc.expectError {
|
|
||||||
reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return()
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ff.HandleRequest(reqHeader)
|
|
||||||
|
|
||||||
if tc.expectError {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
reqHeader.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewForwardedFor(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
addr net.Addr
|
|
||||||
expectAddr net.Addr
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "IPv4 address",
|
|
||||||
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
|
|
||||||
expectAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv6 address",
|
|
||||||
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
|
|
||||||
expectAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Unix address",
|
|
||||||
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
|
||||||
expectAddr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
ff := NewForwardedFor(tc.addr)
|
|
||||||
assert.Equal(t, tc.expectAddr.String(), ff.addr.String())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"tunnel_pls/internal/http/header"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RequestMiddleware interface {
|
|
||||||
HandleRequest(header header.RequestHeader) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type ResponseMiddleware interface {
|
|
||||||
HandleResponse(header header.ResponseHeader, body []byte) error
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"tunnel_pls/internal/http/header"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TunnelFingerprint struct{}
|
|
||||||
|
|
||||||
func NewTunnelFingerprint() *TunnelFingerprint {
|
|
||||||
return &TunnelFingerprint{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *TunnelFingerprint) HandleResponse(header header.ResponseHeader, body []byte) error {
|
|
||||||
header.Set("Server", "Tunnel Please")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
package middleware
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockResponseHeader struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockResponseHeader) Value(key string) string {
|
|
||||||
return m.Called(key).String(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockResponseHeader) Set(key string, value string) {
|
|
||||||
m.Called(key, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockResponseHeader) Remove(key string) {
|
|
||||||
m.Called(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockResponseHeader) Finalize() []byte {
|
|
||||||
return m.Called().Get(0).([]byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTunnelFingerprintHandleResponse(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
expected map[string]string
|
|
||||||
body []byte
|
|
||||||
wantErr error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Sets Server Header",
|
|
||||||
expected: map[string]string{"Server": "Tunnel Please"},
|
|
||||||
body: []byte("Sample body"),
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Overwrites Server Header",
|
|
||||||
expected: map[string]string{"Server": "Tunnel Please"},
|
|
||||||
body: nil,
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mockHeader := new(mockResponseHeader)
|
|
||||||
for k, v := range tt.expected {
|
|
||||||
mockHeader.On("Set", k, v).Return()
|
|
||||||
}
|
|
||||||
|
|
||||||
tunnelFingerprint := NewTunnelFingerprint()
|
|
||||||
|
|
||||||
err := tunnelFingerprint.HandleResponse(mockHeader, tt.body)
|
|
||||||
assert.ErrorIs(t, err, tt.wantErr)
|
|
||||||
mockHeader.AssertExpectations(t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewTunnelFingerprint(t *testing.T) {
|
|
||||||
instance := NewTunnelFingerprint()
|
|
||||||
assert.NotNil(t, instance)
|
|
||||||
}
|
|
||||||
+49
-36
@@ -3,40 +3,63 @@ package port
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"tunnel_pls/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Port interface {
|
type Manager interface {
|
||||||
AddRange(startPort, endPort uint16) error
|
AddPortRange(startPort, endPort uint16) error
|
||||||
Unassigned() (uint16, bool)
|
GetUnassignedPort() (uint16, bool)
|
||||||
SetStatus(port uint16, assigned bool) error
|
SetPortStatus(port uint16, assigned bool) error
|
||||||
Claim(port uint16) (claimed bool)
|
GetPortStatus(port uint16) (bool, bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type port struct {
|
type manager struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ports map[uint16]bool
|
ports map[uint16]bool
|
||||||
sortedPorts []uint16
|
sortedPorts []uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func New() Port {
|
var Default Manager = &manager{
|
||||||
return &port{
|
ports: make(map[uint16]bool),
|
||||||
ports: make(map[uint16]bool),
|
sortedPorts: []uint16{},
|
||||||
sortedPorts: []uint16{},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *port) AddRange(startPort, endPort uint16) error {
|
func init() {
|
||||||
|
rawRange := config.Getenv("ALLOWED_PORTS", "")
|
||||||
|
if rawRange == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
splitRange := strings.Split(rawRange, "-")
|
||||||
|
if len(splitRange) != 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
start, err := strconv.ParseUint(splitRange[0], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
end, err := strconv.ParseUint(splitRange[1], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = Default.AddPortRange(uint16(start), uint16(end))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *manager) AddPortRange(startPort, endPort uint16) error {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
if startPort > endPort {
|
if startPort > endPort {
|
||||||
return fmt.Errorf("start port cannot be greater than end port")
|
return fmt.Errorf("start port cannot be greater than end port")
|
||||||
}
|
}
|
||||||
for index := startPort; index <= endPort; index++ {
|
for port := startPort; port <= endPort; port++ {
|
||||||
if _, exists := pm.ports[index]; !exists {
|
if _, exists := pm.ports[port]; !exists {
|
||||||
pm.ports[index] = false
|
pm.ports[port] = false
|
||||||
pm.sortedPorts = append(pm.sortedPorts, index)
|
pm.sortedPorts = append(pm.sortedPorts, port)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sort.Slice(pm.sortedPorts, func(i, j int) bool {
|
sort.Slice(pm.sortedPorts, func(i, j int) bool {
|
||||||
@@ -45,19 +68,20 @@ func (pm *port) AddRange(startPort, endPort uint16) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *port) Unassigned() (uint16, bool) {
|
func (pm *manager) GetUnassignedPort() (uint16, bool) {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
for _, index := range pm.sortedPorts {
|
for _, port := range pm.sortedPorts {
|
||||||
if !pm.ports[index] {
|
if !pm.ports[port] {
|
||||||
return index, true
|
pm.ports[port] = true
|
||||||
|
return port, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0, false
|
return 0, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *port) SetStatus(port uint16, assigned bool) error {
|
func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
|
||||||
pm.mu.Lock()
|
pm.mu.Lock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
@@ -65,21 +89,10 @@ func (pm *port) SetStatus(port uint16, assigned bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *port) Claim(port uint16) (claimed bool) {
|
func (pm *manager) GetPortStatus(port uint16) (bool, bool) {
|
||||||
pm.mu.Lock()
|
pm.mu.RLock()
|
||||||
defer pm.mu.Unlock()
|
defer pm.mu.RUnlock()
|
||||||
|
|
||||||
status, exists := pm.ports[port]
|
status, exists := pm.ports[port]
|
||||||
|
return status, exists
|
||||||
if exists && status {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
pm.ports[port] = true
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
pm.ports[port] = true
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,114 +0,0 @@
|
|||||||
package port
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestAddRange(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
startPort uint16
|
|
||||||
endPort uint16
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"normal range", 1000, 1002, false},
|
|
||||||
{"invalid range", 2000, 1999, true},
|
|
||||||
{"single port range", 3000, 3000, false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
pm := New()
|
|
||||||
err := pm.AddRange(tt.startPort, tt.endPort)
|
|
||||||
if tt.wantErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnassigned(t *testing.T) {
|
|
||||||
pm := New()
|
|
||||||
_ = pm.AddRange(1000, 1002)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
status map[uint16]bool
|
|
||||||
want uint16
|
|
||||||
wantOk bool
|
|
||||||
}{
|
|
||||||
{"all unassigned", map[uint16]bool{1000: false, 1001: false, 1002: false}, 1000, true},
|
|
||||||
{"some assigned", map[uint16]bool{1000: true, 1001: false, 1002: true}, 1001, true},
|
|
||||||
{"all assigned", map[uint16]bool{1000: true, 1001: true, 1002: true}, 0, false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
for k, v := range tt.status {
|
|
||||||
_ = pm.SetStatus(k, v)
|
|
||||||
}
|
|
||||||
got, gotOk := pm.Unassigned()
|
|
||||||
assert.Equal(t, tt.want, got)
|
|
||||||
assert.Equal(t, tt.wantOk, gotOk)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetStatus(t *testing.T) {
|
|
||||||
pm := New()
|
|
||||||
_ = pm.AddRange(1000, 1002)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
port uint16
|
|
||||||
assigned bool
|
|
||||||
}{
|
|
||||||
{"assign port 1000", 1000, true},
|
|
||||||
{"unassign port 1001", 1001, false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
err := pm.SetStatus(tt.port, tt.assigned)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
status, ok := pm.(*port).ports[tt.port]
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, tt.assigned, status)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClaim(t *testing.T) {
|
|
||||||
pm := New()
|
|
||||||
_ = pm.AddRange(1000, 1002)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
port uint16
|
|
||||||
status bool
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"claim unassigned port", 1000, false, true},
|
|
||||||
{"claim already assigned port", 1001, true, false},
|
|
||||||
{"claim non-existent port", 5000, false, true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if _, exists := pm.(*port).ports[tt.port]; exists {
|
|
||||||
_ = pm.SetStatus(tt.port, tt.status)
|
|
||||||
}
|
|
||||||
|
|
||||||
got := pm.Claim(tt.port)
|
|
||||||
assert.Equal(t, tt.want, got)
|
|
||||||
|
|
||||||
finalState := pm.(*port).ports[tt.port]
|
|
||||||
assert.True(t, finalState)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+11
-34
@@ -1,41 +1,18 @@
|
|||||||
package random
|
package random
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
mathrand "math/rand"
|
||||||
"fmt"
|
"strings"
|
||||||
"io"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
func GenerateRandomString(length int) string {
|
||||||
ErrInvalidLength = fmt.Errorf("invalid length")
|
const charset = "abcdefghijklmnopqrstuvwxyz"
|
||||||
)
|
seededRand := mathrand.New(mathrand.NewSource(time.Now().UnixNano() + int64(mathrand.Intn(9999))))
|
||||||
|
var result strings.Builder
|
||||||
type Random interface {
|
for i := 0; i < length; i++ {
|
||||||
String(length int) (string, error)
|
randomIndex := seededRand.Intn(len(charset))
|
||||||
}
|
result.WriteString(string(charset[randomIndex]))
|
||||||
|
|
||||||
type random struct {
|
|
||||||
reader io.Reader
|
|
||||||
}
|
|
||||||
|
|
||||||
func New() Random {
|
|
||||||
return &random{reader: rand.Reader}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ran *random) String(length int) (string, error) {
|
|
||||||
if length < 0 {
|
|
||||||
return "", ErrInvalidLength
|
|
||||||
}
|
}
|
||||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
return result.String()
|
||||||
b := make([]byte, length)
|
|
||||||
|
|
||||||
if _, err := ran.reader.Read(b); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range b {
|
|
||||||
b[i] = charset[int(b[i])%len(charset)]
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(b), nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,70 +0,0 @@
|
|||||||
package random
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestRandom_String(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
length int
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{"ValidLengthZero", 0, false},
|
|
||||||
{"ValidPositiveLength", 10, false},
|
|
||||||
{"NegativeLength", -1, true},
|
|
||||||
{"VeryLargeLength", 1_000_000, false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
randomizer := New()
|
|
||||||
|
|
||||||
result, err := randomizer.String(tt.length)
|
|
||||||
if tt.wantErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Len(t, result, tt.length)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRandomWithFailingReader_String(t *testing.T) {
|
|
||||||
errBrainrot := assert.AnError
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
reader io.Reader
|
|
||||||
expectErr error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "failing reader",
|
|
||||||
reader: func() io.Reader {
|
|
||||||
return &failingReader{err: errBrainrot}
|
|
||||||
}(),
|
|
||||||
expectErr: errBrainrot,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
randomizer := &random{reader: tt.reader}
|
|
||||||
result, err := randomizer.String(20)
|
|
||||||
assert.ErrorIs(t, err, tt.expectErr)
|
|
||||||
assert.Empty(t, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type failingReader struct {
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *failingReader) Read(p []byte) (int, error) {
|
|
||||||
return 0, f.err
|
|
||||||
}
|
|
||||||
@@ -1,695 +0,0 @@
|
|||||||
package registry
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
"tunnel_pls/internal/port"
|
|
||||||
"tunnel_pls/session/forwarder"
|
|
||||||
"tunnel_pls/session/interaction"
|
|
||||||
"tunnel_pls/session/lifecycle"
|
|
||||||
"tunnel_pls/session/slug"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockSession struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockSession) Lifecycle() lifecycle.Lifecycle {
|
|
||||||
args := m.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return args.Get(0).(lifecycle.Lifecycle)
|
|
||||||
}
|
|
||||||
func (m *mockSession) Interaction() interaction.Interaction {
|
|
||||||
args := m.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return args.Get(0).(interaction.Interaction)
|
|
||||||
}
|
|
||||||
func (m *mockSession) Forwarder() forwarder.Forwarder {
|
|
||||||
args := m.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return args.Get(0).(forwarder.Forwarder)
|
|
||||||
}
|
|
||||||
func (m *mockSession) Slug() slug.Slug {
|
|
||||||
args := m.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return args.Get(0).(slug.Slug)
|
|
||||||
}
|
|
||||||
func (m *mockSession) Detail() *types.Detail {
|
|
||||||
args := m.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return args.Get(0).(*types.Detail)
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockLifecycle struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ml *mockLifecycle) Channel() ssh.Channel {
|
|
||||||
args := ml.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return args.Get(0).(ssh.Channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ml *mockLifecycle) Connection() ssh.Conn {
|
|
||||||
args := ml.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return args.Get(0).(ssh.Conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ml *mockLifecycle) PortRegistry() port.Port {
|
|
||||||
args := ml.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return args.Get(0).(port.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { ml.Called(channel) }
|
|
||||||
func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { ml.Called(status) }
|
|
||||||
func (ml *mockLifecycle) IsActive() bool { return ml.Called().Bool(0) }
|
|
||||||
func (ml *mockLifecycle) StartedAt() time.Time { return ml.Called().Get(0).(time.Time) }
|
|
||||||
func (ml *mockLifecycle) Close() error { return ml.Called().Error(0) }
|
|
||||||
func (ml *mockLifecycle) User() string { return ml.Called().String(0) }
|
|
||||||
|
|
||||||
type mockSlug struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ms *mockSlug) Set(slug string) { ms.Called(slug) }
|
|
||||||
func (ms *mockSlug) String() string { return ms.Called().String(0) }
|
|
||||||
|
|
||||||
func createMockSession(user ...string) *mockSession {
|
|
||||||
u := "user1"
|
|
||||||
if len(user) > 0 {
|
|
||||||
u = user[0]
|
|
||||||
}
|
|
||||||
m := new(mockSession)
|
|
||||||
ml := new(mockLifecycle)
|
|
||||||
ml.On("User").Return(u).Maybe()
|
|
||||||
m.On("Lifecycle").Return(ml).Maybe()
|
|
||||||
ms := new(mockSlug)
|
|
||||||
ms.On("Set", mock.Anything).Maybe()
|
|
||||||
m.On("Slug").Return(ms).Maybe()
|
|
||||||
m.On("Interaction").Return(nil).Maybe()
|
|
||||||
m.On("Forwarder").Return(nil).Maybe()
|
|
||||||
m.On("Detail").Return(nil).Maybe()
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewRegistry(t *testing.T) {
|
|
||||||
r := NewRegistry()
|
|
||||||
require.NotNil(t, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistry_Get(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setupFunc func(r *registry)
|
|
||||||
key types.SessionKey
|
|
||||||
wantErr error
|
|
||||||
wantResult bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "session found",
|
|
||||||
setupFunc: func(r *registry) {
|
|
||||||
user := "user1"
|
|
||||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession(user)
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.byUser[user] = map[types.SessionKey]Session{
|
|
||||||
key: session,
|
|
||||||
}
|
|
||||||
r.slugIndex[key] = user
|
|
||||||
},
|
|
||||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
|
||||||
wantErr: nil,
|
|
||||||
wantResult: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "session not found in slugIndex",
|
|
||||||
setupFunc: func(r *registry) {},
|
|
||||||
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
|
|
||||||
wantErr: ErrSessionNotFound,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "session not found in byUser",
|
|
||||||
setupFunc: func(r *registry) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
|
|
||||||
},
|
|
||||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
|
||||||
wantErr: ErrSessionNotFound,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r := ®istry{
|
|
||||||
byUser: make(map[string]map[types.SessionKey]Session),
|
|
||||||
slugIndex: make(map[types.SessionKey]string),
|
|
||||||
mu: sync.RWMutex{},
|
|
||||||
}
|
|
||||||
tt.setupFunc(r)
|
|
||||||
|
|
||||||
session, err := r.Get(tt.key)
|
|
||||||
|
|
||||||
assert.ErrorIs(t, err, tt.wantErr)
|
|
||||||
assert.Equal(t, tt.wantResult, session != nil)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistry_GetWithUser(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setupFunc func(r *registry)
|
|
||||||
user string
|
|
||||||
key types.SessionKey
|
|
||||||
wantErr error
|
|
||||||
wantResult bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "session found",
|
|
||||||
setupFunc: func(r *registry) {
|
|
||||||
user := "user1"
|
|
||||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession()
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.byUser[user] = map[types.SessionKey]Session{
|
|
||||||
key: session,
|
|
||||||
}
|
|
||||||
r.slugIndex[key] = user
|
|
||||||
},
|
|
||||||
user: "user1",
|
|
||||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
|
||||||
wantErr: nil,
|
|
||||||
wantResult: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "session not found in slugIndex",
|
|
||||||
setupFunc: func(r *registry) {},
|
|
||||||
user: "user1",
|
|
||||||
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
|
|
||||||
wantErr: ErrSessionNotFound,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "session not found in byUser",
|
|
||||||
setupFunc: func(r *registry) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
|
|
||||||
},
|
|
||||||
user: "user1",
|
|
||||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
|
||||||
wantErr: ErrSessionNotFound,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r := ®istry{
|
|
||||||
byUser: make(map[string]map[types.SessionKey]Session),
|
|
||||||
slugIndex: make(map[types.SessionKey]string),
|
|
||||||
mu: sync.RWMutex{},
|
|
||||||
}
|
|
||||||
tt.setupFunc(r)
|
|
||||||
|
|
||||||
session, err := r.GetWithUser(tt.user, tt.key)
|
|
||||||
|
|
||||||
assert.ErrorIs(t, err, tt.wantErr)
|
|
||||||
assert.Equal(t, tt.wantResult, session != nil)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistry_Update(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
user string
|
|
||||||
setupFunc func(r *registry) (oldKey, newKey types.SessionKey)
|
|
||||||
wantErr error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "change slug success",
|
|
||||||
user: "user1",
|
|
||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
|
||||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
|
||||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession("user1")
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
|
||||||
oldKey: session,
|
|
||||||
}
|
|
||||||
r.slugIndex[oldKey] = "user1"
|
|
||||||
|
|
||||||
return oldKey, newKey
|
|
||||||
},
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "change slug to already used slug",
|
|
||||||
user: "user1",
|
|
||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
|
||||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
|
||||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession()
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
|
||||||
oldKey: session,
|
|
||||||
newKey: session,
|
|
||||||
}
|
|
||||||
r.slugIndex[oldKey] = "user1"
|
|
||||||
r.slugIndex[newKey] = "user1"
|
|
||||||
|
|
||||||
return oldKey, newKey
|
|
||||||
},
|
|
||||||
wantErr: ErrSlugInUse,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "change slug to forbidden slug",
|
|
||||||
user: "user1",
|
|
||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
|
||||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
|
||||||
newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession()
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
|
||||||
oldKey: session,
|
|
||||||
}
|
|
||||||
r.slugIndex[oldKey] = "user1"
|
|
||||||
|
|
||||||
return oldKey, newKey
|
|
||||||
},
|
|
||||||
wantErr: ErrForbiddenSlug,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "change slug to invalid slug",
|
|
||||||
user: "user1",
|
|
||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
|
||||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
|
||||||
newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession()
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
|
||||||
oldKey: session,
|
|
||||||
}
|
|
||||||
r.slugIndex[oldKey] = "user1"
|
|
||||||
|
|
||||||
return oldKey, newKey
|
|
||||||
},
|
|
||||||
wantErr: ErrInvalidSlug,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "change slug but session not found",
|
|
||||||
user: "user2",
|
|
||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
|
||||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
|
||||||
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession()
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
|
||||||
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
|
|
||||||
}
|
|
||||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
|
|
||||||
|
|
||||||
return oldKey, newKey
|
|
||||||
},
|
|
||||||
wantErr: ErrSessionNotFound,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "change slug but session is not in the map",
|
|
||||||
user: "user2",
|
|
||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
|
||||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
|
||||||
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession()
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
|
||||||
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
|
|
||||||
}
|
|
||||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
|
|
||||||
|
|
||||||
return oldKey, newKey
|
|
||||||
},
|
|
||||||
wantErr: ErrSessionNotFound,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "change slug with same slug",
|
|
||||||
user: "user1",
|
|
||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
|
||||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
|
||||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
|
||||||
session := createMockSession()
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
|
||||||
oldKey: session,
|
|
||||||
}
|
|
||||||
r.slugIndex[oldKey] = "user1"
|
|
||||||
|
|
||||||
return oldKey, newKey
|
|
||||||
},
|
|
||||||
wantErr: ErrSlugUnchanged,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "tcp tunnel cannot change slug",
|
|
||||||
user: "user1",
|
|
||||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
|
||||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
|
||||||
newKey := oldKey
|
|
||||||
session := createMockSession()
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
|
||||||
oldKey: session,
|
|
||||||
}
|
|
||||||
r.slugIndex[oldKey] = "user1"
|
|
||||||
|
|
||||||
return oldKey, newKey
|
|
||||||
},
|
|
||||||
wantErr: ErrSlugChangeNotAllowed,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
r := ®istry{
|
|
||||||
byUser: make(map[string]map[types.SessionKey]Session),
|
|
||||||
slugIndex: make(map[types.SessionKey]string),
|
|
||||||
mu: sync.RWMutex{},
|
|
||||||
}
|
|
||||||
|
|
||||||
oldKey, newKey := tt.setupFunc(r)
|
|
||||||
|
|
||||||
err := r.Update(tt.user, oldKey, newKey)
|
|
||||||
assert.ErrorIs(t, err, tt.wantErr)
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
_, ok := r.byUser[tt.user][newKey]
|
|
||||||
assert.True(t, ok, "newKey not found in registry")
|
|
||||||
_, ok = r.byUser[tt.user][oldKey]
|
|
||||||
assert.False(t, ok, "oldKey still exists in registry")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistry_Register(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
user string
|
|
||||||
setupFunc func(r *registry) Key
|
|
||||||
wantOK bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "register new key successfully",
|
|
||||||
user: "user1",
|
|
||||||
setupFunc: func(r *registry) Key {
|
|
||||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
|
||||||
return key
|
|
||||||
},
|
|
||||||
wantOK: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "register already existing key fails",
|
|
||||||
user: "user1",
|
|
||||||
setupFunc: func(r *registry) Key {
|
|
||||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession()
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
r.byUser["user1"] = map[Key]Session{key: session}
|
|
||||||
r.slugIndex[key] = "user1"
|
|
||||||
r.mu.Unlock()
|
|
||||||
|
|
||||||
return key
|
|
||||||
},
|
|
||||||
wantOK: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "register multiple keys for same user",
|
|
||||||
user: "user1",
|
|
||||||
setupFunc: func(r *registry) Key {
|
|
||||||
firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession()
|
|
||||||
r.mu.Lock()
|
|
||||||
r.byUser["user1"] = map[Key]Session{firstKey: session}
|
|
||||||
r.slugIndex[firstKey] = "user1"
|
|
||||||
r.mu.Unlock()
|
|
||||||
|
|
||||||
return types.SessionKey{Id: "second", Type: types.TunnelTypeHTTP}
|
|
||||||
},
|
|
||||||
wantOK: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
r := ®istry{
|
|
||||||
byUser: make(map[string]map[Key]Session),
|
|
||||||
slugIndex: make(map[Key]string),
|
|
||||||
mu: sync.RWMutex{},
|
|
||||||
}
|
|
||||||
|
|
||||||
key := tt.setupFunc(r)
|
|
||||||
session := createMockSession()
|
|
||||||
|
|
||||||
ok := r.Register(key, session)
|
|
||||||
assert.Equal(t, tt.wantOK, ok)
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
r.mu.RLock()
|
|
||||||
defer r.mu.RUnlock()
|
|
||||||
assert.Equal(t, session, r.byUser[tt.user][key], "session not stored in byUser")
|
|
||||||
assert.Equal(t, tt.user, r.slugIndex[key], "slugIndex not updated")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistry_GetAllSessionFromUser(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setupFunc func(r *registry) string
|
|
||||||
expectN int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "user has no sessions",
|
|
||||||
setupFunc: func(r *registry) string {
|
|
||||||
return "user1"
|
|
||||||
},
|
|
||||||
expectN: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user has multiple sessions",
|
|
||||||
setupFunc: func(r *registry) string {
|
|
||||||
user := "user1"
|
|
||||||
key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
|
||||||
key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
|
|
||||||
r.mu.Lock()
|
|
||||||
r.byUser[user] = map[Key]Session{
|
|
||||||
key1: createMockSession(),
|
|
||||||
key2: createMockSession(),
|
|
||||||
}
|
|
||||||
r.mu.Unlock()
|
|
||||||
return user
|
|
||||||
},
|
|
||||||
expectN: 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r := ®istry{
|
|
||||||
byUser: make(map[string]map[Key]Session),
|
|
||||||
slugIndex: make(map[Key]string),
|
|
||||||
mu: sync.RWMutex{},
|
|
||||||
}
|
|
||||||
user := tt.setupFunc(r)
|
|
||||||
sessions := r.GetAllSessionFromUser(user)
|
|
||||||
assert.Len(t, sessions, tt.expectN)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRegistry_Remove(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setupFunc func(r *registry) (string, types.SessionKey)
|
|
||||||
key types.SessionKey
|
|
||||||
verify func(*testing.T, *registry, string, types.SessionKey)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "remove existing key",
|
|
||||||
setupFunc: func(r *registry) (string, types.SessionKey) {
|
|
||||||
user := "user1"
|
|
||||||
key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
|
||||||
session := createMockSession()
|
|
||||||
r.mu.Lock()
|
|
||||||
r.byUser[user] = map[Key]Session{key: session}
|
|
||||||
r.slugIndex[key] = user
|
|
||||||
r.mu.Unlock()
|
|
||||||
return user, key
|
|
||||||
},
|
|
||||||
verify: func(t *testing.T, r *registry, user string, key types.SessionKey) {
|
|
||||||
_, ok := r.byUser[user][key]
|
|
||||||
assert.False(t, ok, "expected key to be removed from byUser")
|
|
||||||
_, ok = r.slugIndex[key]
|
|
||||||
assert.False(t, ok, "expected key to be removed from slugIndex")
|
|
||||||
_, ok = r.byUser[user]
|
|
||||||
assert.False(t, ok, "expected user to be removed from byUser map")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "remove non-existing key",
|
|
||||||
setupFunc: func(r *registry) (string, types.SessionKey) {
|
|
||||||
return "", types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r := ®istry{
|
|
||||||
byUser: make(map[string]map[Key]Session),
|
|
||||||
slugIndex: make(map[Key]string),
|
|
||||||
mu: sync.RWMutex{},
|
|
||||||
}
|
|
||||||
user, key := tt.setupFunc(r)
|
|
||||||
if user == "" {
|
|
||||||
key = tt.key
|
|
||||||
}
|
|
||||||
r.Remove(key)
|
|
||||||
if tt.verify != nil {
|
|
||||||
tt.verify(t, r, user, key)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsValidSlug(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
slug string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"abc", true},
|
|
||||||
{"abc-123", true},
|
|
||||||
{"a", false},
|
|
||||||
{"verybigdihsixsevenlabubu", false},
|
|
||||||
{"-iamsigma", false},
|
|
||||||
{"ligma-", false},
|
|
||||||
{"invalid$", false},
|
|
||||||
{"valid-slug1", true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.slug, func(t *testing.T) {
|
|
||||||
got := isValidSlug(tt.slug)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("isValidSlug(%q) = %v; want %v", tt.slug, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsValidSlugChar(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
char byte
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{'a', true},
|
|
||||||
{'z', true},
|
|
||||||
{'0', true},
|
|
||||||
{'9', true},
|
|
||||||
{'-', true},
|
|
||||||
{'A', false},
|
|
||||||
{'$', false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(string(tt.char), func(t *testing.T) {
|
|
||||||
got := isValidSlugChar(tt.char)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("isValidSlugChar(%q) = %v; want %v", tt.char, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsForbiddenSlug(t *testing.T) {
|
|
||||||
forbiddenSlugs = map[string]struct{}{
|
|
||||||
"admin": {},
|
|
||||||
"root": {},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
slug string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"admin", true},
|
|
||||||
{"root", true},
|
|
||||||
{"user", false},
|
|
||||||
{"guest", false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.slug, func(t *testing.T) {
|
|
||||||
got := isForbiddenSlug(tt.slug)
|
|
||||||
if got != tt.want {
|
|
||||||
t.Errorf("isForbiddenSlug(%q) = %v; want %v", tt.slug, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"tunnel_pls/internal/config"
|
|
||||||
"tunnel_pls/internal/registry"
|
|
||||||
)
|
|
||||||
|
|
||||||
type httpServer struct {
|
|
||||||
handler *httpHandler
|
|
||||||
config config.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHTTPServer(config config.Config, sessionRegistry registry.Registry) Transport {
|
|
||||||
return &httpServer{
|
|
||||||
handler: newHTTPHandler(config, sessionRegistry),
|
|
||||||
config: config,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ht *httpServer) Listen() (net.Listener, error) {
|
|
||||||
return net.Listen("tcp", ":"+ht.config.HTTPPort())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ht *httpServer) Serve(listener net.Listener) error {
|
|
||||||
log.Printf("HTTP server is starting on port %s", ht.config.HTTPPort())
|
|
||||||
for {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Printf("Error accepting connection: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
go ht.handler.Handler(conn, false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,135 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewHTTPServer(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("HTTPPort").Return(port)
|
|
||||||
|
|
||||||
srv := NewHTTPServer(mockConfig, msr)
|
|
||||||
assert.NotNil(t, srv)
|
|
||||||
|
|
||||||
httpSrv, ok := srv.(*httpServer)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, msr, httpSrv.handler.sessionRegistry)
|
|
||||||
assert.NotNil(t, srv)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPServer_Listen(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("HTTPPort").Return(port)
|
|
||||||
srv := NewHTTPServer(mockConfig, msr)
|
|
||||||
|
|
||||||
listener, err := srv.Listen()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, listener)
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPServer_Serve(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("HTTPPort").Return(port)
|
|
||||||
srv := NewHTTPServer(mockConfig, msr)
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = srv.Serve(listener)
|
|
||||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPServer_Serve_AcceptError(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("HTTPPort").Return(port)
|
|
||||||
srv := NewHTTPServer(mockConfig, msr)
|
|
||||||
|
|
||||||
ml := new(mockListener)
|
|
||||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
|
||||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
||||||
|
|
||||||
err := srv.Serve(ml)
|
|
||||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
|
||||||
ml.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPServer_Serve_Success(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("HTTPPort").Return(port)
|
|
||||||
mockConfig.On("HeaderSize").Return(4096)
|
|
||||||
mockConfig.On("TLSRedirect").Return(false)
|
|
||||||
srv := NewHTTPServer(mockConfig, msr)
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
listenerport := listener.Addr().(*net.TCPAddr).Port
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_ = srv.Serve(listener)
|
|
||||||
}()
|
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
err = conn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockListener struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockListener) Accept() (net.Conn, error) {
|
|
||||||
args := m.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Error(1)
|
|
||||||
}
|
|
||||||
return args.Get(0).(net.Conn), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockListener) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockListener) Addr() net.Addr {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(net.Addr)
|
|
||||||
}
|
|
||||||
@@ -1,201 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
"tunnel_pls/internal/config"
|
|
||||||
"tunnel_pls/internal/http/header"
|
|
||||||
"tunnel_pls/internal/http/stream"
|
|
||||||
"tunnel_pls/internal/middleware"
|
|
||||||
"tunnel_pls/internal/registry"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type httpHandler struct {
|
|
||||||
config config.Config
|
|
||||||
sessionRegistry registry.Registry
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler {
|
|
||||||
return &httpHandler{
|
|
||||||
config: config,
|
|
||||||
sessionRegistry: sessionRegistry,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
|
|
||||||
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
|
|
||||||
fmt.Sprintf("Location: %s", location) +
|
|
||||||
"Content-Length: 0\r\n" +
|
|
||||||
"Connection: close\r\n" +
|
|
||||||
"\r\n"))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) badRequest(conn net.Conn) error {
|
|
||||||
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
|
|
||||||
defer hh.closeConnection(conn)
|
|
||||||
|
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
|
||||||
buf := make([]byte, hh.config.HeaderSize())
|
|
||||||
n, err := conn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
_ = hh.badRequest(conn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if idx := bytes.Index(buf[:n], []byte("\r\n\r\n")); idx == -1 {
|
|
||||||
_ = hh.badRequest(conn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = conn.SetReadDeadline(time.Time{})
|
|
||||||
|
|
||||||
reqhf, err := header.NewRequest(buf[:n])
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error creating request header: %v", err)
|
|
||||||
_ = hh.badRequest(conn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
slug, err := hh.extractSlug(reqhf)
|
|
||||||
if err != nil {
|
|
||||||
_ = hh.badRequest(conn)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if hh.shouldRedirectToTLS(isTLS) {
|
|
||||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.config.Domain()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if hh.handlePingRequest(slug, conn) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
|
|
||||||
Id: slug,
|
|
||||||
Type: types.TunnelTypeHTTP,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hw := stream.New(conn, conn, conn.RemoteAddr())
|
|
||||||
defer func(hw stream.HTTP) {
|
|
||||||
err = hw.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error closing HTTP stream: %v", err)
|
|
||||||
}
|
|
||||||
}(hw)
|
|
||||||
hh.forwardRequest(hw, reqhf, sshSession)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) closeConnection(conn net.Conn) {
|
|
||||||
err := conn.Close()
|
|
||||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
|
||||||
log.Printf("Error closing connection: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
|
|
||||||
host := strings.Split(reqhf.Value("Host"), ".")
|
|
||||||
if len(host) <= 1 {
|
|
||||||
return "", errors.New("invalid host")
|
|
||||||
}
|
|
||||||
return host[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
|
|
||||||
return !isTLS && hh.config.TLSRedirect()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
|
||||||
if slug != "ping" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := conn.Write([]byte(
|
|
||||||
"HTTP/1.1 200 OK\r\n" +
|
|
||||||
"Content-Length: 0\r\n" +
|
|
||||||
"Connection: close\r\n" +
|
|
||||||
"Access-Control-Allow-Origin: *\r\n" +
|
|
||||||
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
|
|
||||||
"Access-Control-Allow-Headers: *\r\n" +
|
|
||||||
"\r\n",
|
|
||||||
))
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Failed to write 200 OK:", err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer cancel()
|
|
||||||
channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, hw.RemoteAddr())
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err = channel.Close()
|
|
||||||
if err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
log.Printf("Error closing forwarded channel: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
hh.setupMiddlewares(hw)
|
|
||||||
|
|
||||||
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
|
|
||||||
log.Printf("Failed to forward initial request: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
sshSession.Forwarder().HandleConnection(hw, channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
|
|
||||||
fingerprintMiddleware := middleware.NewTunnelFingerprint()
|
|
||||||
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
|
|
||||||
|
|
||||||
hw.UseResponseMiddleware(fingerprintMiddleware)
|
|
||||||
hw.UseRequestMiddleware(forwardedForMiddleware)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error {
|
|
||||||
hw.SetRequestHeader(initialRequest)
|
|
||||||
|
|
||||||
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
|
|
||||||
return fmt.Errorf("error applying request middlewares: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := channel.Write(initialRequest.Finalize()); err != nil {
|
|
||||||
return fmt.Errorf("error writing to channel: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,717 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
"tunnel_pls/internal/registry"
|
|
||||||
"tunnel_pls/session/forwarder"
|
|
||||||
"tunnel_pls/session/interaction"
|
|
||||||
"tunnel_pls/session/lifecycle"
|
|
||||||
"tunnel_pls/session/slug"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockSessionRegistry struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
|
||||||
args := m.Called(key)
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Error(1)
|
|
||||||
}
|
|
||||||
return args.Get(0).(registry.Session), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
|
||||||
args := m.Called(user, key)
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Error(1)
|
|
||||||
}
|
|
||||||
return args.Get(0).(registry.Session), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
|
||||||
args := m.Called(user, oldKey, newKey)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
|
||||||
args := m.Called(key, session)
|
|
||||||
return args.Bool(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
|
||||||
m.Called(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
|
||||||
args := m.Called(user)
|
|
||||||
return args.Get(0).([]registry.Session)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(slug.Slug)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockSession struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(lifecycle.Lifecycle)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Interaction() interaction.Interaction {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(interaction.Interaction)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Forwarder() forwarder.Forwarder {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(forwarder.Forwarder)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Slug() slug.Slug {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(slug.Slug)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSession) Detail() *types.Detail {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(*types.Detail)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockSSHChannel struct {
|
|
||||||
ssh.Channel
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSSHChannel) Write(data []byte) (int, error) {
|
|
||||||
args := m.Called(data)
|
|
||||||
return args.Int(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSSHChannel) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockForwarder struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
|
||||||
m.Called(dst, src)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) TunnelType() types.TunnelType {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(types.TunnelType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) ForwardedPort() uint16 {
|
|
||||||
args := m.Called()
|
|
||||||
return uint16(args.Int(0))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
|
|
||||||
m.Called(tunnelType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) SetForwardedPort(port uint16) {
|
|
||||||
m.Called(port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) SetListener(listener net.Listener) {
|
|
||||||
m.Called(listener)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) Listener() net.Listener {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(net.Listener)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
|
||||||
args := m.Called(ctx, origin)
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
||||||
}
|
|
||||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockConn struct {
|
|
||||||
mock.Mock
|
|
||||||
ReadBuffer *bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockConn) LocalAddr() net.Addr {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(net.Addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
|
||||||
args := m.Called(t)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
|
||||||
args := m.Called(t)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
|
||||||
args := m.Called(t)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
|
||||||
if m.ReadBuffer != nil {
|
|
||||||
return m.ReadBuffer.Read(b)
|
|
||||||
}
|
|
||||||
args := m.Called(b)
|
|
||||||
return args.Int(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
|
||||||
args := m.Called(b)
|
|
||||||
if args.Int(0) == -1 {
|
|
||||||
return len(b), args.Error(1)
|
|
||||||
}
|
|
||||||
return args.Int(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockConn) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockConn) RemoteAddr() net.Addr {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(net.Addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
type wrappedConn struct {
|
|
||||||
net.Conn
|
|
||||||
remoteAddr net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *wrappedConn) RemoteAddr() net.Addr {
|
|
||||||
return c.remoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewHTTPHandler(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockConfig.On("Domain").Return("domain")
|
|
||||||
mockConfig.On("TLSRedirect").Return(false)
|
|
||||||
hh := newHTTPHandler(mockConfig, msr)
|
|
||||||
assert.NotNil(t, hh)
|
|
||||||
assert.Equal(t, msr, hh.sessionRegistry)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandler(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
isTLS bool
|
|
||||||
redirectTLS bool
|
|
||||||
request []byte
|
|
||||||
expected []byte
|
|
||||||
setupMocks func(*MockSessionRegistry)
|
|
||||||
setupConn func() (net.Conn, net.Conn)
|
|
||||||
expectError bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "bad request - invalid host",
|
|
||||||
isTLS: false,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"),
|
|
||||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bad request - missing host",
|
|
||||||
isTLS: false,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\n\r\n"),
|
|
||||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "isTLS true and redirectTLS true - no redirect",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: true,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
|
||||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "redirect to TLS",
|
|
||||||
isTLS: false,
|
|
||||||
redirectTLS: true,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: tunnel.example.com\r\n\r\n"),
|
|
||||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnel.example.com/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "handle ping request",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
|
||||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "session not found",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
|
||||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
msr.On("Get", types.SessionKey{
|
|
||||||
Id: "test",
|
|
||||||
Type: types.TunnelTypeHTTP,
|
|
||||||
}).Return((registry.Session)(nil), fmt.Errorf("session not found"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bad request - invalid http",
|
|
||||||
isTLS: false,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("INVALID\r\n\r\n"),
|
|
||||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bad request - header too large",
|
|
||||||
isTLS: false,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: test.domain\r\n%s\r\n\r\n", strings.Repeat("test", 10000))),
|
|
||||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bad request - no request",
|
|
||||||
isTLS: false,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte(""),
|
|
||||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "forwarding - open channel fails",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
|
||||||
expected: []byte(""),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
mockSession := new(MockSession)
|
|
||||||
mockForwarder := new(MockForwarder)
|
|
||||||
|
|
||||||
msr.On("Get", types.SessionKey{
|
|
||||||
Id: "test",
|
|
||||||
Type: types.TunnelTypeHTTP,
|
|
||||||
}).Return(mockSession, nil)
|
|
||||||
|
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "forwarding - send initial request fails",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
|
||||||
expected: []byte(""),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
mockSession := new(MockSession)
|
|
||||||
mockForwarder := new(MockForwarder)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
|
||||||
|
|
||||||
msr.On("Get", types.SessionKey{
|
|
||||||
Id: "test",
|
|
||||||
Type: types.TunnelTypeHTTP,
|
|
||||||
}).Return(mockSession, nil)
|
|
||||||
|
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for range reqCh {
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "forwarding - success",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
|
||||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
mockSession := new(MockSession)
|
|
||||||
mockForwarder := new(MockForwarder)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
|
||||||
|
|
||||||
msr.On("Get", types.SessionKey{
|
|
||||||
Id: "test",
|
|
||||||
Type: types.TunnelTypeHTTP,
|
|
||||||
}).Return(mockSession, nil)
|
|
||||||
|
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
|
||||||
|
|
||||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
|
|
||||||
w := args.Get(0).(io.ReadWriter)
|
|
||||||
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
|
|
||||||
})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for range reqCh {
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "redirect - write failure",
|
|
||||||
isTLS: false,
|
|
||||||
redirectTLS: true,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
|
|
||||||
expected: []byte(""),
|
|
||||||
setupConn: func() (net.Conn, net.Conn) {
|
|
||||||
mc := new(MockConn)
|
|
||||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
|
|
||||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
|
||||||
mc.On("Write", mock.Anything).Return(-1, fmt.Errorf("write error"))
|
|
||||||
mc.On("Close").Return(nil)
|
|
||||||
return mc, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bad request - write failure",
|
|
||||||
isTLS: false,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\n\r\n"),
|
|
||||||
expected: []byte(""),
|
|
||||||
setupConn: func() (net.Conn, net.Conn) {
|
|
||||||
mc := new(MockConn)
|
|
||||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
|
|
||||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
|
||||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
|
||||||
mc.On("Close").Return(nil)
|
|
||||||
return mc, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "read error - connection failure",
|
|
||||||
isTLS: false,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte(""),
|
|
||||||
expected: []byte(""),
|
|
||||||
setupConn: func() (net.Conn, net.Conn) {
|
|
||||||
mc := new(MockConn)
|
|
||||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
|
||||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
|
||||||
mc.On("Read", mock.Anything).Return(0, fmt.Errorf("connection reset by peer"))
|
|
||||||
mc.On("Close").Return(nil)
|
|
||||||
return mc, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "handle ping request - write failure",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
|
||||||
expected: []byte(""),
|
|
||||||
setupConn: func() (net.Conn, net.Conn) {
|
|
||||||
mc := new(MockConn)
|
|
||||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
|
||||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
|
||||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
|
||||||
mc.On("Close").Return(nil)
|
|
||||||
return mc, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "close connection - error",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
|
||||||
expected: []byte(""),
|
|
||||||
setupConn: func() (net.Conn, net.Conn) {
|
|
||||||
mc := new(MockConn)
|
|
||||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
|
||||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
|
||||||
mc.On("Write", mock.Anything).Return(182, nil)
|
|
||||||
mc.On("Close").Return(fmt.Errorf("close error"))
|
|
||||||
return mc, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "forwarding - stream close error",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
|
||||||
expected: []byte(""),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
mockSession := new(MockSession)
|
|
||||||
mockForwarder := new(MockForwarder)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
|
||||||
|
|
||||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
|
||||||
|
|
||||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return()
|
|
||||||
},
|
|
||||||
setupConn: func() (net.Conn, net.Conn) {
|
|
||||||
mc := new(MockConn)
|
|
||||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
|
||||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
|
||||||
mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
|
|
||||||
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
|
||||||
mc.On("RemoteAddr").Return(addr)
|
|
||||||
return mc, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "forwarding - middleware failure",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
|
||||||
expected: []byte(""),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
mockSession := new(MockSession)
|
|
||||||
mockForwarder := new(MockForwarder)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
|
||||||
|
|
||||||
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
|
|
||||||
return k.Id == "test"
|
|
||||||
})).Return(mockSession, nil)
|
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
|
||||||
|
|
||||||
reqCh := make(chan *ssh.Request)
|
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
|
||||||
mockSSHChannel.On("Close").Return(nil)
|
|
||||||
},
|
|
||||||
setupConn: func() (net.Conn, net.Conn) {
|
|
||||||
mc := new(MockConn)
|
|
||||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
|
||||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
|
||||||
mc.On("Close").Return(nil).Times(2)
|
|
||||||
mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
|
|
||||||
return mc, nil
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "forwarding - channel close error",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
|
||||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
mockSession := new(MockSession)
|
|
||||||
mockForwarder := new(MockForwarder)
|
|
||||||
mockSSHChannel := new(MockSSHChannel)
|
|
||||||
|
|
||||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
|
||||||
reqCh := make(chan *ssh.Request)
|
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
|
||||||
|
|
||||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
|
||||||
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
|
|
||||||
|
|
||||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
|
|
||||||
w := args.Get(0).(io.ReadWriter)
|
|
||||||
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
|
|
||||||
})
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "forwarding - open channel timeout",
|
|
||||||
isTLS: true,
|
|
||||||
redirectTLS: false,
|
|
||||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
|
||||||
expected: []byte(""),
|
|
||||||
setupMocks: func(msr *MockSessionRegistry) {
|
|
||||||
mockSession := new(MockSession)
|
|
||||||
mockForwarder := new(MockForwarder)
|
|
||||||
|
|
||||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
|
||||||
mockSession.On("Forwarder").Return(mockForwarder)
|
|
||||||
|
|
||||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
|
||||||
|
|
||||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
|
||||||
ctx := args.Get(0).(context.Context)
|
|
||||||
<-ctx.Done()
|
|
||||||
}).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mockSessionRegistry := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
mockConfig.On("Domain").Return("example.com")
|
|
||||||
mockConfig.On("HTTPPort").Return(port)
|
|
||||||
mockConfig.On("HeaderSize").Return(4096)
|
|
||||||
mockConfig.On("TLSRedirect").Return(true)
|
|
||||||
hh := &httpHandler{
|
|
||||||
sessionRegistry: mockSessionRegistry,
|
|
||||||
config: mockConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.setupMocks != nil {
|
|
||||||
tt.setupMocks(mockSessionRegistry)
|
|
||||||
}
|
|
||||||
|
|
||||||
var serverConn, clientConn net.Conn
|
|
||||||
if tt.setupConn != nil {
|
|
||||||
serverConn, clientConn = tt.setupConn()
|
|
||||||
} else {
|
|
||||||
serverConn, clientConn = net.Pipe()
|
|
||||||
}
|
|
||||||
|
|
||||||
if clientConn != nil {
|
|
||||||
defer func(clientConn net.Conn) {
|
|
||||||
err := clientConn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}(clientConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
|
||||||
var wrappedServerConn net.Conn
|
|
||||||
if _, ok := serverConn.(*MockConn); ok {
|
|
||||||
wrappedServerConn = serverConn
|
|
||||||
} else {
|
|
||||||
wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr}
|
|
||||||
}
|
|
||||||
|
|
||||||
responseChan := make(chan []byte, 1)
|
|
||||||
doneChan := make(chan struct{})
|
|
||||||
|
|
||||||
if clientConn != nil {
|
|
||||||
go func() {
|
|
||||||
defer close(doneChan)
|
|
||||||
var res []byte
|
|
||||||
for {
|
|
||||||
buf := make([]byte, 4096)
|
|
||||||
n, err := clientConn.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Logf("Error reading response: %v", err)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
res = append(res, buf[:n]...)
|
|
||||||
if len(tt.expected) > 0 && len(res) >= len(tt.expected) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
responseChan <- res
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_, err := clientConn.Write(tt.request)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Error writing request: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
} else {
|
|
||||||
close(responseChan)
|
|
||||||
close(doneChan)
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
hh.Handler(wrappedServerConn, tt.isTLS)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case response := <-responseChan:
|
|
||||||
if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" {
|
|
||||||
resStr := string(response)
|
|
||||||
assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n"))
|
|
||||||
assert.Contains(t, resStr, "Content-Length: 5\r\n")
|
|
||||||
assert.Contains(t, resStr, "Server: Tunnel Please\r\n")
|
|
||||||
assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello"))
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, string(tt.expected), string(response))
|
|
||||||
}
|
|
||||||
case <-time.After(10 * time.Second):
|
|
||||||
if clientConn != nil {
|
|
||||||
t.Fatal("Test timeout - no response received")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
if clientConn != nil {
|
|
||||||
<-doneChan
|
|
||||||
}
|
|
||||||
|
|
||||||
mockSessionRegistry.AssertExpectations(t)
|
|
||||||
if mc, ok := serverConn.(*MockConn); ok {
|
|
||||||
mc.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"tunnel_pls/internal/config"
|
|
||||||
"tunnel_pls/internal/registry"
|
|
||||||
)
|
|
||||||
|
|
||||||
type https struct {
|
|
||||||
config config.Config
|
|
||||||
tlsConfig *tls.Config
|
|
||||||
httpHandler *httpHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHTTPSServer(config config.Config, sessionRegistry registry.Registry, tlsConfig *tls.Config) Transport {
|
|
||||||
return &https{
|
|
||||||
config: config,
|
|
||||||
tlsConfig: tlsConfig,
|
|
||||||
httpHandler: newHTTPHandler(config, sessionRegistry),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ht *https) Listen() (net.Listener, error) {
|
|
||||||
return tls.Listen("tcp", ":"+ht.config.HTTPSPort(), ht.tlsConfig)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ht *https) Serve(listener net.Listener) error {
|
|
||||||
log.Printf("HTTPS server is starting on port %s", ht.config.HTTPSPort())
|
|
||||||
for {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
log.Printf("Error accepting connection: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
go ht.httpHandler.Handler(conn, true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewHTTPSServer(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
tlsConfig := &tls.Config{}
|
|
||||||
mockConfig.On("Domain").Return(mockConfig)
|
|
||||||
mockConfig.On("HTTPSPort").Return(port)
|
|
||||||
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
|
|
||||||
assert.NotNil(t, srv)
|
|
||||||
|
|
||||||
httpsSrv, ok := srv.(*https)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, tlsConfig, httpsSrv.tlsConfig)
|
|
||||||
assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPSServer_Listen(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
mockConfig.On("Domain").Return(mockConfig)
|
|
||||||
mockConfig.On("HTTPSPort").Return(port)
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
return nil, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
|
|
||||||
|
|
||||||
listener, err := srv.Listen()
|
|
||||||
if err != nil {
|
|
||||||
t.Skip("Skipping tls.Listen test as it requires valid certificates/setup:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.NotNil(t, listener)
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPSServer_Serve(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
mockConfig.On("Domain").Return(mockConfig)
|
|
||||||
mockConfig.On("HTTPSPort").Return(port)
|
|
||||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = srv.Serve(listener)
|
|
||||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
mockConfig.On("Domain").Return(mockConfig)
|
|
||||||
mockConfig.On("HTTPSPort").Return(port)
|
|
||||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
|
||||||
|
|
||||||
ml := new(mockListener)
|
|
||||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
|
||||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
||||||
|
|
||||||
err := srv.Serve(ml)
|
|
||||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
|
||||||
ml.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPSServer_Serve_Success(t *testing.T) {
|
|
||||||
msr := new(MockSessionRegistry)
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
port := "0"
|
|
||||||
mockConfig.On("Domain").Return(mockConfig)
|
|
||||||
mockConfig.On("HTTPSPort").Return(port)
|
|
||||||
mockConfig.On("HeaderSize").Return(4096)
|
|
||||||
|
|
||||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
listenerport := listener.Addr().(*net.TCPAddr).Port
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_ = srv.Serve(listener)
|
|
||||||
}()
|
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
err = conn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
@@ -1,67 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type tcp struct {
|
|
||||||
port uint16
|
|
||||||
forwarder Forwarder
|
|
||||||
}
|
|
||||||
|
|
||||||
type Forwarder interface {
|
|
||||||
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
|
|
||||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTCPServer(port uint16, forwarder Forwarder) Transport {
|
|
||||||
return &tcp{
|
|
||||||
port: port,
|
|
||||||
forwarder: forwarder,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *tcp) Listen() (net.Listener, error) {
|
|
||||||
return net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", tt.port))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *tcp) Serve(listener net.Listener) error {
|
|
||||||
for {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
log.Printf("Error accepting connection: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go tt.handleTcp(conn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *tcp) handleTcp(conn net.Conn) {
|
|
||||||
defer func() {
|
|
||||||
err := conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer cancel()
|
|
||||||
channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, conn.RemoteAddr())
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
tt.forwarder.HandleConnection(conn, channel)
|
|
||||||
}
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewTCPServer(t *testing.T) {
|
|
||||||
mf := new(MockForwarder)
|
|
||||||
port := uint16(9000)
|
|
||||||
|
|
||||||
srv := NewTCPServer(port, mf)
|
|
||||||
assert.NotNil(t, srv)
|
|
||||||
|
|
||||||
tcpSrv, ok := srv.(*tcp)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, port, tcpSrv.port)
|
|
||||||
assert.Equal(t, mf, tcpSrv.forwarder)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTCPServer_Listen(t *testing.T) {
|
|
||||||
mf := new(MockForwarder)
|
|
||||||
srv := NewTCPServer(0, mf)
|
|
||||||
|
|
||||||
listener, err := srv.Listen()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, listener)
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTCPServer_Serve(t *testing.T) {
|
|
||||||
mf := new(MockForwarder)
|
|
||||||
srv := NewTCPServer(0, mf)
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
err = srv.Serve(listener)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTCPServer_Serve_AcceptError(t *testing.T) {
|
|
||||||
mf := new(MockForwarder)
|
|
||||||
srv := NewTCPServer(0, mf)
|
|
||||||
|
|
||||||
ml := new(mockListener)
|
|
||||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
|
||||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
||||||
|
|
||||||
err := srv.Serve(ml)
|
|
||||||
assert.Nil(t, err)
|
|
||||||
ml.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTCPServer_Serve_Success(t *testing.T) {
|
|
||||||
mf := new(MockForwarder)
|
|
||||||
srv := NewTCPServer(0, mf)
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
port := listener.Addr().(*net.TCPAddr).Port
|
|
||||||
|
|
||||||
reqs := make(chan *ssh.Request)
|
|
||||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
|
|
||||||
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
_ = srv.Serve(listener)
|
|
||||||
}()
|
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
err = conn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
mf.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTCPServer_handleTcp_Success(t *testing.T) {
|
|
||||||
mf := new(MockForwarder)
|
|
||||||
srv := NewTCPServer(0, mf).(*tcp)
|
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
|
||||||
defer func(clientConn net.Conn) {
|
|
||||||
err := clientConn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}(clientConn)
|
|
||||||
|
|
||||||
reqs := make(chan *ssh.Request)
|
|
||||||
mockChannel := new(MockSSHChannel)
|
|
||||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
|
|
||||||
|
|
||||||
mf.On("HandleConnection", serverConn, mockChannel).Return()
|
|
||||||
|
|
||||||
srv.handleTcp(serverConn)
|
|
||||||
|
|
||||||
mf.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTCPServer_handleTcp_CloseError(t *testing.T) {
|
|
||||||
mf := new(MockForwarder)
|
|
||||||
srv := NewTCPServer(0, mf).(*tcp)
|
|
||||||
|
|
||||||
mc := new(MockConn)
|
|
||||||
mc.On("Close").Return(errors.New("close error"))
|
|
||||||
mc.On("RemoteAddr").Return(&net.TCPAddr{})
|
|
||||||
|
|
||||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
|
||||||
|
|
||||||
srv.handleTcp(mc)
|
|
||||||
mc.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
|
|
||||||
mf := new(MockForwarder)
|
|
||||||
srv := NewTCPServer(0, mf).(*tcp)
|
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
|
||||||
defer func(clientConn net.Conn) {
|
|
||||||
err := clientConn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}(clientConn)
|
|
||||||
|
|
||||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
|
||||||
|
|
||||||
srv.handleTcp(serverConn)
|
|
||||||
|
|
||||||
mf.AssertExpectations(t)
|
|
||||||
}
|
|
||||||
@@ -1,435 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
"tunnel_pls/internal/config"
|
|
||||||
|
|
||||||
"github.com/caddyserver/certmagic"
|
|
||||||
"github.com/libdns/cloudflare"
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewTLSConfig(config config.Config) (*tls.Config, error) {
|
|
||||||
var initErr error
|
|
||||||
|
|
||||||
tlsManagerOnce.Do(func() {
|
|
||||||
tm := createTLSManager(config)
|
|
||||||
initErr = tm.initialize()
|
|
||||||
if initErr == nil {
|
|
||||||
globalTLSManager = tm
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if initErr != nil {
|
|
||||||
return nil, initErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return globalTLSManager.getTLSConfig(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type tlsManager struct {
|
|
||||||
config config.Config
|
|
||||||
|
|
||||||
certPath string
|
|
||||||
keyPath string
|
|
||||||
storagePath string
|
|
||||||
|
|
||||||
userCert *tls.Certificate
|
|
||||||
userCertMu sync.RWMutex
|
|
||||||
|
|
||||||
magic *certmagic.Config
|
|
||||||
|
|
||||||
useCertMagic bool
|
|
||||||
}
|
|
||||||
|
|
||||||
var globalTLSManager *tlsManager
|
|
||||||
var tlsManagerOnce sync.Once
|
|
||||||
|
|
||||||
func createTLSManager(cfg config.Config) *tlsManager {
|
|
||||||
storagePath := cfg.TLSStoragePath()
|
|
||||||
cleanBase := filepath.Clean(storagePath)
|
|
||||||
|
|
||||||
return &tlsManager{
|
|
||||||
config: cfg,
|
|
||||||
certPath: filepath.Join(cleanBase, "cert.pem"),
|
|
||||||
keyPath: filepath.Join(cleanBase, "privkey.pem"),
|
|
||||||
storagePath: filepath.Join(cleanBase, "certmagic"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) initialize() error {
|
|
||||||
if tm.userCertsExistAndValid() {
|
|
||||||
return tm.initializeWithUserCerts()
|
|
||||||
}
|
|
||||||
return tm.initializeWithCertMagic()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) initializeWithUserCerts() error {
|
|
||||||
log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath)
|
|
||||||
|
|
||||||
if err := tm.loadUserCerts(); err != nil {
|
|
||||||
return fmt.Errorf("failed to load user certificates: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tm.useCertMagic = false
|
|
||||||
tm.startCertWatcher()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) initializeWithCertMagic() error {
|
|
||||||
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic",
|
|
||||||
tm.config.Domain(), tm.config.Domain())
|
|
||||||
|
|
||||||
if err := tm.initCertMagic(); err != nil {
|
|
||||||
return fmt.Errorf("failed to initialize CertMagic: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tm.useCertMagic = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) userCertsExistAndValid() bool {
|
|
||||||
if !tm.certFilesExist() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return validateCertDomains(tm.certPath, tm.config.Domain())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) certFilesExist() bool {
|
|
||||||
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
|
||||||
log.Printf("Certificate file not found: %s", tm.certPath)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if _, err := os.Stat(tm.keyPath); os.IsNotExist(err) {
|
|
||||||
log.Printf("Key file not found: %s", tm.keyPath)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) loadUserCerts() error {
|
|
||||||
cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
tm.userCertMu.Lock()
|
|
||||||
tm.userCert = &cert
|
|
||||||
tm.userCertMu.Unlock()
|
|
||||||
|
|
||||||
log.Printf("Loaded user certificates successfully")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) startCertWatcher() {
|
|
||||||
go func() {
|
|
||||||
watcher := newCertWatcher(tm)
|
|
||||||
watcher.watch()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) initCertMagic() error {
|
|
||||||
if err := tm.createStorageDirectory(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if tm.config.CFAPIToken() == "" {
|
|
||||||
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
|
||||||
}
|
|
||||||
|
|
||||||
magic := tm.createCertMagicConfig()
|
|
||||||
tm.magic = magic
|
|
||||||
|
|
||||||
return tm.obtainCertificates(magic)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) createStorageDirectory() error {
|
|
||||||
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
|
||||||
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) createCertMagicConfig() *certmagic.Config {
|
|
||||||
cfProvider := &cloudflare.Provider{
|
|
||||||
APIToken: tm.config.CFAPIToken(),
|
|
||||||
}
|
|
||||||
|
|
||||||
storage := &certmagic.FileStorage{Path: tm.storagePath}
|
|
||||||
|
|
||||||
cache := certmagic.NewCache(certmagic.CacheOptions{
|
|
||||||
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
|
|
||||||
return tm.magic, nil
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
magic := certmagic.New(cache, certmagic.Config{
|
|
||||||
Storage: storage,
|
|
||||||
})
|
|
||||||
|
|
||||||
acmeIssuer := tm.createACMEIssuer(magic, cfProvider)
|
|
||||||
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
|
||||||
|
|
||||||
return magic
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer {
|
|
||||||
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
|
||||||
Email: tm.config.ACMEEmail(),
|
|
||||||
Agreed: true,
|
|
||||||
DNS01Solver: &certmagic.DNS01Solver{
|
|
||||||
DNSManager: certmagic.DNSManager{
|
|
||||||
DNSProvider: cfProvider,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
if tm.config.ACMEStaging() {
|
|
||||||
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
|
|
||||||
log.Printf("Using Let's Encrypt staging server")
|
|
||||||
} else {
|
|
||||||
acmeIssuer.CA = certmagic.LetsEncryptProductionCA
|
|
||||||
log.Printf("Using Let's Encrypt production server")
|
|
||||||
}
|
|
||||||
|
|
||||||
return acmeIssuer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error {
|
|
||||||
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
|
|
||||||
log.Printf("Requesting certificates for: %v", domains)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
if err := magic.ManageSync(ctx, domains); err != nil {
|
|
||||||
return fmt.Errorf("failed to obtain certificates: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Certificates obtained successfully for %v", domains)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) getTLSConfig() *tls.Config {
|
|
||||||
return &tls.Config{
|
|
||||||
GetCertificate: tm.getCertificate,
|
|
||||||
|
|
||||||
MinVersion: tls.VersionTLS13,
|
|
||||||
MaxVersion: tls.VersionTLS13,
|
|
||||||
|
|
||||||
CurvePreferences: []tls.CurveID{
|
|
||||||
tls.X25519,
|
|
||||||
},
|
|
||||||
|
|
||||||
SessionTicketsDisabled: false,
|
|
||||||
ClientAuth: tls.NoClientCert,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
|
||||||
if tm.useCertMagic {
|
|
||||||
return tm.magic.GetCertificate(hello)
|
|
||||||
}
|
|
||||||
|
|
||||||
tm.userCertMu.RLock()
|
|
||||||
defer tm.userCertMu.RUnlock()
|
|
||||||
|
|
||||||
if tm.userCert == nil {
|
|
||||||
return nil, fmt.Errorf("no certificate available")
|
|
||||||
}
|
|
||||||
|
|
||||||
return tm.userCert, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateCertDomains(certPath, domain string) bool {
|
|
||||||
cert, err := loadAndParseCertificate(certPath)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isCertificateValid(cert) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return certCoversRequiredDomains(cert, domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
func loadAndParseCertificate(certPath string) (*x509.Certificate, error) {
|
|
||||||
certPEM, err := os.ReadFile(certPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to read certificate: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
block, _ := pem.Decode(certPEM)
|
|
||||||
if block == nil {
|
|
||||||
log.Printf("Failed to decode PEM block from certificate")
|
|
||||||
return nil, fmt.Errorf("failed to decode PEM block")
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := x509.ParseCertificate(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to parse certificate: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return cert, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isCertificateValid(cert *x509.Certificate) bool {
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
if now.After(cert.NotAfter) {
|
|
||||||
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
thirtyDaysFromNow := now.Add(30 * 24 * time.Hour)
|
|
||||||
if thirtyDaysFromNow.After(cert.NotAfter) {
|
|
||||||
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool {
|
|
||||||
certDomains := extractCertDomains(cert)
|
|
||||||
hasBase, hasWildcard := checkDomainCoverage(certDomains, domain)
|
|
||||||
|
|
||||||
logDomainCoverage(hasBase, hasWildcard, domain)
|
|
||||||
return hasBase && hasWildcard
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractCertDomains(cert *x509.Certificate) []string {
|
|
||||||
var domains []string
|
|
||||||
if cert.Subject.CommonName != "" {
|
|
||||||
domains = append(domains, cert.Subject.CommonName)
|
|
||||||
}
|
|
||||||
domains = append(domains, cert.DNSNames...)
|
|
||||||
return domains
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) {
|
|
||||||
wildcardDomain := "*." + domain
|
|
||||||
|
|
||||||
for _, d := range certDomains {
|
|
||||||
if d == domain {
|
|
||||||
hasBase = true
|
|
||||||
}
|
|
||||||
if d == wildcardDomain {
|
|
||||||
hasWildcard = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return hasBase, hasWildcard
|
|
||||||
}
|
|
||||||
|
|
||||||
func logDomainCoverage(hasBase, hasWildcard bool, domain string) {
|
|
||||||
if !hasBase {
|
|
||||||
log.Printf("Certificate does not cover base domain: %s", domain)
|
|
||||||
}
|
|
||||||
if !hasWildcard {
|
|
||||||
log.Printf("Certificate does not cover wildcard domain: *.%s", domain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type certWatcher struct {
|
|
||||||
tm *tlsManager
|
|
||||||
lastCertMod time.Time
|
|
||||||
lastKeyMod time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
func newCertWatcher(tm *tlsManager) *certWatcher {
|
|
||||||
watcher := &certWatcher{tm: tm}
|
|
||||||
watcher.initializeModTimes()
|
|
||||||
return watcher
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *certWatcher) initializeModTimes() {
|
|
||||||
if info, err := os.Stat(cw.tm.certPath); err == nil {
|
|
||||||
cw.lastCertMod = info.ModTime()
|
|
||||||
}
|
|
||||||
if info, err := os.Stat(cw.tm.keyPath); err == nil {
|
|
||||||
cw.lastKeyMod = info.ModTime()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *certWatcher) watch() {
|
|
||||||
ticker := time.NewTicker(30 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for range ticker.C {
|
|
||||||
if cw.checkAndReloadCerts() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *certWatcher) checkAndReloadCerts() bool {
|
|
||||||
certInfo, keyInfo, err := cw.getFileInfo()
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cw.filesModified(certInfo, keyInfo) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return cw.handleCertificateChange(certInfo, keyInfo)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) {
|
|
||||||
certInfo, certErr := os.Stat(cw.tm.certPath)
|
|
||||||
keyInfo, keyErr := os.Stat(cw.tm.keyPath)
|
|
||||||
|
|
||||||
if certErr != nil || keyErr != nil {
|
|
||||||
return nil, nil, fmt.Errorf("file stat error")
|
|
||||||
}
|
|
||||||
|
|
||||||
return certInfo, keyInfo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool {
|
|
||||||
return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool {
|
|
||||||
log.Printf("Certificate files changed, reloading...")
|
|
||||||
|
|
||||||
if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) {
|
|
||||||
return cw.switchToCertMagic()
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := cw.tm.loadUserCerts(); err != nil {
|
|
||||||
log.Printf("Failed to reload certificates: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
cw.updateModTimes(certInfo, keyInfo)
|
|
||||||
log.Printf("Certificates reloaded successfully")
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *certWatcher) switchToCertMagic() bool {
|
|
||||||
log.Printf("New certificates don't cover required domains")
|
|
||||||
|
|
||||||
if err := cw.tm.initCertMagic(); err != nil {
|
|
||||||
log.Printf("Failed to initialize CertMagic: %v", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
cw.tm.useCertMagic = true
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) {
|
|
||||||
cw.lastCertMod = certInfo.ModTime()
|
|
||||||
cw.lastKeyMod = keyInfo.ModTime()
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +0,0 @@
|
|||||||
package transport
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Transport interface {
|
|
||||||
Listen() (net.Listener, error)
|
|
||||||
Serve(listener net.Listener) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type HTTP interface {
|
|
||||||
Handler(conn net.Conn, isTLS bool)
|
|
||||||
}
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
package version
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestVersionFunctions(t *testing.T) {
|
|
||||||
origVersion := Version
|
|
||||||
origBuildDate := BuildDate
|
|
||||||
origCommit := Commit
|
|
||||||
defer func() {
|
|
||||||
Version = origVersion
|
|
||||||
BuildDate = origBuildDate
|
|
||||||
Commit = origCommit
|
|
||||||
}()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
buildDate string
|
|
||||||
commit string
|
|
||||||
wantFull string
|
|
||||||
wantShort string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Default dev version",
|
|
||||||
version: "dev",
|
|
||||||
buildDate: "unknown",
|
|
||||||
commit: "unknown",
|
|
||||||
wantFull: "tunnel_pls dev (commit: unknown, built: unknown)",
|
|
||||||
wantShort: "dev",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Release version",
|
|
||||||
version: "v1.0.0",
|
|
||||||
buildDate: "2026-01-23",
|
|
||||||
commit: "abcdef123",
|
|
||||||
wantFull: "tunnel_pls v1.0.0 (commit: abcdef123, built: 2026-01-23)",
|
|
||||||
wantShort: "v1.0.0",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Empty values",
|
|
||||||
version: "",
|
|
||||||
buildDate: "",
|
|
||||||
commit: "",
|
|
||||||
wantFull: "tunnel_pls (commit: , built: )",
|
|
||||||
wantShort: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
Version = tt.version
|
|
||||||
BuildDate = tt.buildDate
|
|
||||||
Commit = tt.commit
|
|
||||||
|
|
||||||
gotFull := GetVersion()
|
|
||||||
if gotFull != tt.wantFull {
|
|
||||||
t.Errorf("GetVersion() = %q, want %q", gotFull, tt.wantFull)
|
|
||||||
}
|
|
||||||
|
|
||||||
gotShort := GetShortVersion()
|
|
||||||
if gotShort != tt.wantShort {
|
|
||||||
t.Errorf("GetShortVersion() = %q, want %q", gotShort, tt.wantShort)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetVersion_Format(t *testing.T) {
|
|
||||||
v := "1.2.3"
|
|
||||||
c := "brainrot"
|
|
||||||
d := "now"
|
|
||||||
|
|
||||||
Version = v
|
|
||||||
Commit = c
|
|
||||||
BuildDate = d
|
|
||||||
|
|
||||||
expected := fmt.Sprintf("tunnel_pls %s (commit: %s, built: %s)", v, c, d)
|
|
||||||
if GetVersion() != expected {
|
|
||||||
t.Errorf("GetVersion() formatting mismatch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,13 +1,24 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
"tunnel_pls/internal/bootstrap"
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
"tunnel_pls/internal/port"
|
"tunnel_pls/internal/grpc/client"
|
||||||
"tunnel_pls/internal/version"
|
"tunnel_pls/internal/key"
|
||||||
|
"tunnel_pls/server"
|
||||||
|
"tunnel_pls/session"
|
||||||
|
"tunnel_pls/version"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -18,19 +29,113 @@ func main() {
|
|||||||
|
|
||||||
log.SetOutput(os.Stdout)
|
log.SetOutput(os.Stdout)
|
||||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||||
|
|
||||||
log.Printf("Starting %s", version.GetVersion())
|
log.Printf("Starting %s", version.GetVersion())
|
||||||
|
|
||||||
conf, err := config.MustLoad()
|
mode := strings.ToLower(config.Getenv("MODE", "standalone"))
|
||||||
if err != nil {
|
isNodeMode := mode == "node"
|
||||||
log.Fatalf("Config load error: %v", err)
|
|
||||||
|
pprofEnabled := config.Getenv("PPROF_ENABLED", "false")
|
||||||
|
if pprofEnabled == "true" {
|
||||||
|
pprofPort := config.Getenv("PPROF_PORT", "6060")
|
||||||
|
go func() {
|
||||||
|
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
|
||||||
|
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
|
||||||
|
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||||
|
log.Printf("pprof server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
boot, err := bootstrap.New(conf, port.New())
|
sshConfig := &ssh.ServerConfig{
|
||||||
if err != nil {
|
NoClientAuth: true,
|
||||||
log.Fatalf("Startup error: %v", err)
|
ServerVersion: fmt.Sprintf("SSH-2.0-TunnlPls-%s", version.GetShortVersion()),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = boot.Run(); err != nil {
|
sshKeyPath := "certs/ssh/id_rsa"
|
||||||
log.Fatalf("Application error: %v", err)
|
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
|
||||||
|
log.Fatalf("Failed to generate SSH key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateBytes, err := os.ReadFile(sshKeyPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load private key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to parse private key: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sshConfig.AddHostKey(private)
|
||||||
|
sessionRegistry := session.NewRegistry()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
shutdownChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|
||||||
|
var grpcClient *client.Client
|
||||||
|
if isNodeMode {
|
||||||
|
grpcHost := config.Getenv("GRPC_ADDRESS", "localhost")
|
||||||
|
grpcPort := config.Getenv("GRPC_PORT", "8080")
|
||||||
|
grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort)
|
||||||
|
nodeToken := config.Getenv("NODE_TOKEN", "")
|
||||||
|
if nodeToken == "" {
|
||||||
|
log.Fatalf("NODE_TOKEN is required in node mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := client.New(&client.GrpcConfig{
|
||||||
|
Address: grpcAddr,
|
||||||
|
UseTLS: false,
|
||||||
|
InsecureSkipVerify: false,
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
KeepAlive: true,
|
||||||
|
MaxRetries: 3,
|
||||||
|
}, sessionRegistry)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create grpc client: %v", err)
|
||||||
|
}
|
||||||
|
grpcClient = c
|
||||||
|
|
||||||
|
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
if err := grpcClient.CheckServerHealth(healthCtx); err != nil {
|
||||||
|
healthCancel()
|
||||||
|
log.Fatalf("gRPC health check failed: %v", err)
|
||||||
|
}
|
||||||
|
healthCancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
identity := config.Getenv("DOMAIN", "localhost")
|
||||||
|
if err := grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil {
|
||||||
|
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
app, err := server.NewServer(sshConfig, sessionRegistry, grpcClient)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- fmt.Errorf("failed to start server: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
app.Start()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errChan:
|
||||||
|
log.Printf("error happen : %s", err)
|
||||||
|
case sig := <-shutdownChan:
|
||||||
|
log.Printf("received signal %s, shutting down", sig)
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if grpcClient != nil {
|
||||||
|
if err := grpcClient.Close(); err != nil {
|
||||||
|
log.Printf("failed to close grpc conn : %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,8 @@
|
|||||||
|
module.exports = {
|
||||||
|
"endpoint": "https://git.fossy.my.id/api/v1",
|
||||||
|
"gitAuthor": "Renovate-Clanker <renovate-bot@fossy.my.id>",
|
||||||
|
"platform": "gitea",
|
||||||
|
"onboardingConfigFileName": "renovate.json",
|
||||||
|
"autodiscover": true,
|
||||||
|
"optimizeForDisabled": true,
|
||||||
|
};
|
||||||
@@ -0,0 +1,276 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HeaderManager interface {
|
||||||
|
Get(key string) []byte
|
||||||
|
Set(key string, value []byte)
|
||||||
|
Remove(key string)
|
||||||
|
Finalize() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseHeaderManager interface {
|
||||||
|
Get(key string) string
|
||||||
|
Set(key string, value string)
|
||||||
|
Remove(key string)
|
||||||
|
Finalize() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestHeaderManager interface {
|
||||||
|
Get(key string) string
|
||||||
|
Set(key string, value string)
|
||||||
|
Remove(key string)
|
||||||
|
Finalize() []byte
|
||||||
|
GetMethod() string
|
||||||
|
GetPath() string
|
||||||
|
GetVersion() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseHeaderFactory struct {
|
||||||
|
startLine []byte
|
||||||
|
headers map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestHeaderFactory struct {
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
version string
|
||||||
|
startLine []byte
|
||||||
|
headers map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) {
|
||||||
|
switch v := r.(type) {
|
||||||
|
case []byte:
|
||||||
|
return parseHeadersFromBytes(v)
|
||||||
|
case *bufio.Reader:
|
||||||
|
return parseHeadersFromReader(v)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported type: %T", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) {
|
||||||
|
header := &requestHeaderFactory{
|
||||||
|
headers: make(map[string]string, 16),
|
||||||
|
}
|
||||||
|
|
||||||
|
lineEnd := bytes.IndexByte(headerData, '\n')
|
||||||
|
if lineEnd == -1 {
|
||||||
|
return nil, fmt.Errorf("invalid request: no newline found")
|
||||||
|
}
|
||||||
|
|
||||||
|
startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n")
|
||||||
|
header.startLine = make([]byte, len(startLine))
|
||||||
|
copy(header.startLine, startLine)
|
||||||
|
|
||||||
|
parts := bytes.Split(startLine, []byte{' '})
|
||||||
|
if len(parts) < 3 {
|
||||||
|
return nil, fmt.Errorf("invalid request line")
|
||||||
|
}
|
||||||
|
|
||||||
|
header.method = string(parts[0])
|
||||||
|
header.path = string(parts[1])
|
||||||
|
header.version = string(parts[2])
|
||||||
|
|
||||||
|
remaining := headerData[lineEnd+1:]
|
||||||
|
|
||||||
|
for len(remaining) > 0 {
|
||||||
|
lineEnd = bytes.IndexByte(remaining, '\n')
|
||||||
|
if lineEnd == -1 {
|
||||||
|
lineEnd = len(remaining)
|
||||||
|
}
|
||||||
|
|
||||||
|
line := bytes.TrimRight(remaining[:lineEnd], "\r\n")
|
||||||
|
|
||||||
|
if len(line) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
colonIdx := bytes.IndexByte(line, ':')
|
||||||
|
if colonIdx != -1 {
|
||||||
|
key := bytes.TrimSpace(line[:colonIdx])
|
||||||
|
value := bytes.TrimSpace(line[colonIdx+1:])
|
||||||
|
header.headers[string(key)] = string(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if lineEnd == len(remaining) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
remaining = remaining[lineEnd+1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return header, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) {
|
||||||
|
header := &requestHeaderFactory{
|
||||||
|
headers: make(map[string]string, 16),
|
||||||
|
}
|
||||||
|
|
||||||
|
startLineBytes, err := br.ReadSlice('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == bufio.ErrBufferFull {
|
||||||
|
var startLine string
|
||||||
|
startLine, err = br.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
startLineBytes = []byte(startLine)
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
|
||||||
|
header.startLine = make([]byte, len(startLineBytes))
|
||||||
|
copy(header.startLine, startLineBytes)
|
||||||
|
|
||||||
|
parts := bytes.Split(startLineBytes, []byte{' '})
|
||||||
|
if len(parts) < 3 {
|
||||||
|
return nil, fmt.Errorf("invalid request line")
|
||||||
|
}
|
||||||
|
|
||||||
|
header.method = string(parts[0])
|
||||||
|
header.path = string(parts[1])
|
||||||
|
header.version = string(parts[2])
|
||||||
|
|
||||||
|
for {
|
||||||
|
lineBytes, err := br.ReadSlice('\n')
|
||||||
|
if err != nil {
|
||||||
|
if err == bufio.ErrBufferFull {
|
||||||
|
var line string
|
||||||
|
line, err = br.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
lineBytes = []byte(line)
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
|
||||||
|
|
||||||
|
if len(lineBytes) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
colonIdx := bytes.IndexByte(lineBytes, ':')
|
||||||
|
if colonIdx == -1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
key := bytes.TrimSpace(lineBytes[:colonIdx])
|
||||||
|
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
|
||||||
|
|
||||||
|
header.headers[string(key)] = string(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return header, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager {
|
||||||
|
header := &responseHeaderFactory{
|
||||||
|
startLine: nil,
|
||||||
|
headers: make(map[string]string),
|
||||||
|
}
|
||||||
|
lines := bytes.Split(startLine, []byte("\r\n"))
|
||||||
|
if len(lines) == 0 {
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
header.startLine = lines[0]
|
||||||
|
for _, h := range lines[1:] {
|
||||||
|
if len(h) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := bytes.SplitN(h, []byte(":"), 2)
|
||||||
|
if len(parts) < 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
key := parts[0]
|
||||||
|
val := bytes.TrimSpace(parts[1])
|
||||||
|
header.headers[string(key)] = string(val)
|
||||||
|
}
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (resp *responseHeaderFactory) Get(key string) string {
|
||||||
|
return resp.headers[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (resp *responseHeaderFactory) Set(key string, value string) {
|
||||||
|
resp.headers[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (resp *responseHeaderFactory) Remove(key string) {
|
||||||
|
delete(resp.headers, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (resp *responseHeaderFactory) Finalize() []byte {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
buf.Write(resp.startLine)
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
|
||||||
|
for key, val := range resp.headers {
|
||||||
|
buf.WriteString(key)
|
||||||
|
buf.WriteString(": ")
|
||||||
|
buf.WriteString(val)
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeaderFactory) Get(key string) string {
|
||||||
|
val, ok := req.headers[key]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeaderFactory) Set(key string, value string) {
|
||||||
|
req.headers[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeaderFactory) Remove(key string) {
|
||||||
|
delete(req.headers, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeaderFactory) GetMethod() string {
|
||||||
|
return req.method
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeaderFactory) GetPath() string {
|
||||||
|
return req.path
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeaderFactory) GetVersion() string {
|
||||||
|
return req.version
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeaderFactory) Finalize() []byte {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
buf.Write(req.startLine)
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
|
||||||
|
for key, val := range req.headers {
|
||||||
|
buf.WriteString(key)
|
||||||
|
buf.WriteString(": ")
|
||||||
|
buf.WriteString(val)
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteString("\r\n")
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
+391
@@ -0,0 +1,391 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"tunnel_pls/internal/config"
|
||||||
|
"tunnel_pls/session"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type HTTPWriter interface {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
GetRemoteAddr() net.Addr
|
||||||
|
GetWriter() io.Writer
|
||||||
|
AddResponseMiddleware(mw ResponseMiddleware)
|
||||||
|
AddRequestStartMiddleware(mw RequestMiddleware)
|
||||||
|
SetRequestHeader(header RequestHeaderManager)
|
||||||
|
GetRequestStartMiddleware() []RequestMiddleware
|
||||||
|
}
|
||||||
|
|
||||||
|
type customWriter struct {
|
||||||
|
remoteAddr net.Addr
|
||||||
|
writer io.Writer
|
||||||
|
reader io.Reader
|
||||||
|
headerBuf []byte
|
||||||
|
buf []byte
|
||||||
|
respHeader ResponseHeaderManager
|
||||||
|
reqHeader RequestHeaderManager
|
||||||
|
respMW []ResponseMiddleware
|
||||||
|
reqStartMW []RequestMiddleware
|
||||||
|
reqEndMW []RequestMiddleware
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *customWriter) GetRemoteAddr() net.Addr {
|
||||||
|
return cw.remoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *customWriter) GetWriter() io.Writer {
|
||||||
|
return cw.writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) {
|
||||||
|
cw.respMW = append(cw.respMW, mw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) {
|
||||||
|
cw.reqStartMW = append(cw.reqStartMW, mw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) {
|
||||||
|
cw.reqHeader = header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware {
|
||||||
|
return cw.reqStartMW
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *customWriter) Read(p []byte) (int, error) {
|
||||||
|
tmp := make([]byte, len(p))
|
||||||
|
read, err := cw.reader.Read(tmp)
|
||||||
|
if read == 0 && err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp = tmp[:read]
|
||||||
|
|
||||||
|
idx := bytes.Index(tmp, DELIMITER)
|
||||||
|
if idx == -1 {
|
||||||
|
copy(p, tmp)
|
||||||
|
if err != nil {
|
||||||
|
return read, err
|
||||||
|
}
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
header := tmp[:idx+len(DELIMITER)]
|
||||||
|
body := tmp[idx+len(DELIMITER):]
|
||||||
|
|
||||||
|
if !isHTTPHeader(header) {
|
||||||
|
copy(p, tmp)
|
||||||
|
return read, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range cw.reqEndMW {
|
||||||
|
err = m.HandleRequest(cw.reqHeader)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error when applying request middleware: %v", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reqhf, err := NewRequestHeaderFactory(header)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range cw.reqStartMW {
|
||||||
|
if mwErr := m.HandleRequest(reqhf); mwErr != nil {
|
||||||
|
log.Printf("Error when applying request middleware: %v", mwErr)
|
||||||
|
return 0, mwErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.reqHeader = reqhf
|
||||||
|
finalHeader := reqhf.Finalize()
|
||||||
|
|
||||||
|
combined := append(finalHeader, body...)
|
||||||
|
|
||||||
|
n := copy(p, combined)
|
||||||
|
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
|
||||||
|
return &customWriter{
|
||||||
|
remoteAddr: remoteAddr,
|
||||||
|
writer: writer,
|
||||||
|
reader: reader,
|
||||||
|
buf: make([]byte, 0, 4096),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
|
||||||
|
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
|
||||||
|
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
|
||||||
|
|
||||||
|
func isHTTPHeader(buf []byte) bool {
|
||||||
|
lines := bytes.Split(buf, []byte("\r\n"))
|
||||||
|
|
||||||
|
startLine := string(lines[0])
|
||||||
|
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, line := range lines[1:] {
|
||||||
|
if len(line) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
colonIdx := bytes.IndexByte(line, ':')
|
||||||
|
if colonIdx <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cw *customWriter) Write(p []byte) (int, error) {
|
||||||
|
if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" {
|
||||||
|
cw.respHeader = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cw.respHeader != nil {
|
||||||
|
n, err := cw.writer.Write(p)
|
||||||
|
if err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cw.buf = append(cw.buf, p...)
|
||||||
|
|
||||||
|
idx := bytes.Index(cw.buf, DELIMITER)
|
||||||
|
if idx == -1 {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
header := cw.buf[:idx+len(DELIMITER)]
|
||||||
|
body := cw.buf[idx+len(DELIMITER):]
|
||||||
|
|
||||||
|
if !isHTTPHeader(header) {
|
||||||
|
_, err := cw.writer.Write(cw.buf)
|
||||||
|
cw.buf = nil
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resphf := NewResponseHeaderFactory(header)
|
||||||
|
for _, m := range cw.respMW {
|
||||||
|
err := m.HandleResponse(resphf, body)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Cannot apply middleware: %s\n", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
header = resphf.Finalize()
|
||||||
|
cw.respHeader = resphf
|
||||||
|
|
||||||
|
_, err := cw.writer.Write(header)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if len(body) > 0 {
|
||||||
|
_, err = cw.writer.Write(body)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cw.buf = nil
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var redirectTLS = false
|
||||||
|
|
||||||
|
type HTTPServer interface {
|
||||||
|
ListenAndServe() error
|
||||||
|
ListenAndServeTLS() error
|
||||||
|
handler(conn net.Conn)
|
||||||
|
handlerTLS(conn net.Conn)
|
||||||
|
}
|
||||||
|
type httpServer struct {
|
||||||
|
sessionRegistry session.Registry
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHTTPServer(sessionRegistry session.Registry) HTTPServer {
|
||||||
|
return &httpServer{sessionRegistry: sessionRegistry}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hs *httpServer) ListenAndServe() error {
|
||||||
|
httpPort := config.Getenv("HTTP_PORT", "8080")
|
||||||
|
listener, err := net.Listen("tcp", ":"+httpPort)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("Error listening: " + err.Error())
|
||||||
|
}
|
||||||
|
if config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" {
|
||||||
|
redirectTLS = true
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
var conn net.Conn
|
||||||
|
conn, err = listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Error accepting connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go hs.handler(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hs *httpServer) handler(conn net.Conn) {
|
||||||
|
defer func() {
|
||||||
|
err := conn.Close()
|
||||||
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
log.Printf("Error closing connection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}()
|
||||||
|
|
||||||
|
dstReader := bufio.NewReader(conn)
|
||||||
|
reqhf, err := NewRequestHeaderFactory(dstReader)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error creating request header: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
host := strings.Split(reqhf.Get("Host"), ".")
|
||||||
|
if len(host) < 1 {
|
||||||
|
_, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to write 400 Bad Request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slug := host[0]
|
||||||
|
|
||||||
|
if redirectTLS {
|
||||||
|
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
|
||||||
|
fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")) +
|
||||||
|
"Content-Length: 0\r\n" +
|
||||||
|
"Connection: close\r\n" +
|
||||||
|
"\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to write 301 Moved Permanently:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if slug == "ping" {
|
||||||
|
_, err = conn.Write([]byte(
|
||||||
|
"HTTP/1.1 200 OK\r\n" +
|
||||||
|
"Content-Length: 0\r\n" +
|
||||||
|
"Connection: close\r\n" +
|
||||||
|
"Access-Control-Allow-Origin: *\r\n" +
|
||||||
|
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
|
||||||
|
"Access-Control-Allow-Headers: *\r\n" +
|
||||||
|
"\r\n",
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to write 200 OK:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sshSession, err := hs.sessionRegistry.Get(slug)
|
||||||
|
if err != nil {
|
||||||
|
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
|
||||||
|
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
|
||||||
|
"Content-Length: 0\r\n" +
|
||||||
|
"Connection: close\r\n" +
|
||||||
|
"\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to write 301 Moved Permanently:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
||||||
|
forwardRequest(cw, reqhf, sshSession)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) {
|
||||||
|
payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
|
||||||
|
|
||||||
|
type channelResult struct {
|
||||||
|
channel ssh.Channel
|
||||||
|
reqs <-chan *ssh.Request
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultChan := make(chan channelResult, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||||
|
resultChan <- channelResult{channel, reqs, err}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var channel ssh.Channel
|
||||||
|
var reqs <-chan *ssh.Request
|
||||||
|
|
||||||
|
select {
|
||||||
|
case result := <-resultChan:
|
||||||
|
if result.err != nil {
|
||||||
|
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
|
||||||
|
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
channel = result.channel
|
||||||
|
reqs = result.reqs
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
log.Printf("Timeout opening forwarded-tcpip channel")
|
||||||
|
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
|
fingerprintMiddleware := NewTunnelFingerprint()
|
||||||
|
forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr())
|
||||||
|
|
||||||
|
cw.AddResponseMiddleware(fingerprintMiddleware)
|
||||||
|
cw.AddRequestStartMiddleware(forwardedForMiddleware)
|
||||||
|
cw.SetRequestHeader(initialRequest)
|
||||||
|
|
||||||
|
for _, m := range cw.GetRequestStartMiddleware() {
|
||||||
|
if err := m.HandleRequest(initialRequest); err != nil {
|
||||||
|
log.Printf("Error handling request: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := channel.Write(initialRequest.Finalize())
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to forward request: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
|
||||||
|
return
|
||||||
|
}
|
||||||
+108
@@ -0,0 +1,108 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"tunnel_pls/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (hs *httpServer) ListenAndServeTLS() error {
|
||||||
|
domain := config.Getenv("DOMAIN", "localhost")
|
||||||
|
httpsPort := config.Getenv("HTTPS_PORT", "8443")
|
||||||
|
|
||||||
|
tlsConfig, err := NewTLSConfig(domain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize TLS config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
var conn net.Conn
|
||||||
|
conn, err = ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
log.Println("https server closed")
|
||||||
|
}
|
||||||
|
log.Printf("Error accepting connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go hs.handlerTLS(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hs *httpServer) handlerTLS(conn net.Conn) {
|
||||||
|
defer func() {
|
||||||
|
err := conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error closing connection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}()
|
||||||
|
|
||||||
|
dstReader := bufio.NewReader(conn)
|
||||||
|
reqhf, err := NewRequestHeaderFactory(dstReader)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error creating request header: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
host := strings.Split(reqhf.Get("Host"), ".")
|
||||||
|
if len(host) < 1 {
|
||||||
|
_, err = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to write 400 Bad Request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slug := host[0]
|
||||||
|
|
||||||
|
if slug == "ping" {
|
||||||
|
_, err = conn.Write([]byte(
|
||||||
|
"HTTP/1.1 200 OK\r\n" +
|
||||||
|
"Content-Length: 0\r\n" +
|
||||||
|
"Connection: close\r\n" +
|
||||||
|
"Access-Control-Allow-Origin: *\r\n" +
|
||||||
|
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
|
||||||
|
"Access-Control-Allow-Headers: *\r\n" +
|
||||||
|
"\r\n",
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to write 200 OK:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sshSession, err := hs.sessionRegistry.Get(slug)
|
||||||
|
if err != nil {
|
||||||
|
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
|
||||||
|
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
|
||||||
|
"Content-Length: 0\r\n" +
|
||||||
|
"Connection: close\r\n" +
|
||||||
|
"\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to write 301 Moved Permanently:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
||||||
|
forwardRequest(cw, reqhf, sshSession)
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RequestMiddleware interface {
|
||||||
|
HandleRequest(header RequestHeaderManager) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResponseMiddleware interface {
|
||||||
|
HandleResponse(header ResponseHeaderManager, body []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TunnelFingerprint struct{}
|
||||||
|
|
||||||
|
func NewTunnelFingerprint() *TunnelFingerprint {
|
||||||
|
return &TunnelFingerprint{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error {
|
||||||
|
header.Set("Server", "Tunnel Please")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ForwardedFor struct {
|
||||||
|
addr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewForwardedFor(addr net.Addr) *ForwardedFor {
|
||||||
|
return &ForwardedFor{addr: addr}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error {
|
||||||
|
host, _, err := net.SplitHostPort(ff.addr.String())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
header.Set("X-Forwarded-For", host)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
+45
-63
@@ -2,64 +2,58 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
"tunnel_pls/internal/grpc/client"
|
"tunnel_pls/internal/grpc/client"
|
||||||
"tunnel_pls/internal/port"
|
|
||||||
"tunnel_pls/internal/random"
|
|
||||||
"tunnel_pls/internal/registry"
|
|
||||||
"tunnel_pls/session"
|
"tunnel_pls/session"
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Server interface {
|
type Server struct {
|
||||||
Start()
|
conn *net.Listener
|
||||||
Close() error
|
config *ssh.ServerConfig
|
||||||
}
|
sessionRegistry session.Registry
|
||||||
type server struct {
|
grpcClient *client.Client
|
||||||
randomizer random.Random
|
|
||||||
config config.Config
|
|
||||||
sshPort string
|
|
||||||
sshListener net.Listener
|
|
||||||
sshConfig *ssh.ServerConfig
|
|
||||||
grpcClient client.Client
|
|
||||||
sessionRegistry registry.Registry
|
|
||||||
portRegistry port.Port
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(randomizer random.Random, config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
|
func NewServer(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient *client.Client) (*Server, error) {
|
||||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200")))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Fatalf("failed to listen on port 2200: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &server{
|
HttpServer := NewHTTPServer(sessionRegistry)
|
||||||
randomizer: randomizer,
|
err = HttpServer.ListenAndServe()
|
||||||
config: config,
|
if err != nil {
|
||||||
sshPort: sshPort,
|
log.Fatalf("failed to start http server: %v", err)
|
||||||
sshListener: listener,
|
return nil, err
|
||||||
sshConfig: sshConfig,
|
}
|
||||||
grpcClient: grpcClient,
|
|
||||||
|
if config.Getenv("TLS_ENABLED", "false") == "true" {
|
||||||
|
err = HttpServer.ListenAndServeTLS()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to start https server: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Server{
|
||||||
|
conn: &listener,
|
||||||
|
config: sshConfig,
|
||||||
sessionRegistry: sessionRegistry,
|
sessionRegistry: sessionRegistry,
|
||||||
portRegistry: portRegistry,
|
grpcClient: grpcClient,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) Start() {
|
func (s *Server) Start() {
|
||||||
log.Printf("SSH server is starting on port %s", s.sshPort)
|
log.Println("SSH server is starting on port 2200...")
|
||||||
for {
|
for {
|
||||||
conn, err := s.sshListener.Accept()
|
conn, err := (*s.conn).Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
log.Println("listener closed, stopping server")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("failed to accept connection: %v", err)
|
log.Printf("failed to accept connection: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -68,50 +62,38 @@ func (s *server) Start() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) Close() error {
|
func (s *Server) handleConnection(conn net.Conn) {
|
||||||
return s.sshListener.Close()
|
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
|
||||||
}
|
defer func(sshConn *ssh.ServerConn) {
|
||||||
|
err = sshConn.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close SSH server: %v", err)
|
||||||
|
}
|
||||||
|
}(sshConn)
|
||||||
|
|
||||||
func (s *server) handleConnection(conn net.Conn) {
|
|
||||||
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to establish SSH connection: %v", err)
|
log.Printf("failed to establish SSH connection: %v", err)
|
||||||
err = conn.Close()
|
err := conn.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to close SSH connection: %v", err)
|
log.Printf("failed to close SSH connection: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
ctx := context.Background()
|
||||||
defer func(sshConn *ssh.ServerConn) {
|
log.Println("SSH connection established:", sshConn.User())
|
||||||
err = sshConn.Close()
|
|
||||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
|
|
||||||
log.Printf("failed to close SSH server: %v", err)
|
|
||||||
}
|
|
||||||
}(sshConn)
|
|
||||||
|
|
||||||
user := "UNAUTHORIZED"
|
user := "UNAUTHORIZED"
|
||||||
if s.grpcClient != nil {
|
if s.grpcClient != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
_, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User())
|
_, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User())
|
||||||
user = u
|
user = u
|
||||||
cancel()
|
|
||||||
}
|
}
|
||||||
log.Println("SSH connection established:", sshConn.User())
|
|
||||||
sshSession := session.New(&session.Config{
|
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user)
|
||||||
Randomizer: s.randomizer,
|
|
||||||
Config: s.config,
|
|
||||||
Conn: sshConn,
|
|
||||||
InitialReq: forwardingReqs,
|
|
||||||
SshChan: chans,
|
|
||||||
SessionRegistry: s.sessionRegistry,
|
|
||||||
PortRegistry: s.portRegistry,
|
|
||||||
User: user,
|
|
||||||
})
|
|
||||||
err = sshSession.Start()
|
err = sshSession.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("SSH session ended with error: %s", err.Error())
|
log.Printf("SSH session ended with error: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,880 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/rsa"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
"tunnel_pls/internal/registry"
|
|
||||||
"tunnel_pls/session/slug"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockRandom struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockRandom) String(length int) (string, error) {
|
|
||||||
args := m.Called(length)
|
|
||||||
return args.String(0), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockConfig struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockConfig) Domain() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
|
||||||
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
|
|
||||||
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
|
||||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
|
||||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
|
||||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
|
||||||
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
|
|
||||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
|
||||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) Mode() types.ServerMode {
|
|
||||||
args := m.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
switch v := args.Get(0).(type) {
|
|
||||||
case types.ServerMode:
|
|
||||||
return v
|
|
||||||
case int:
|
|
||||||
return types.ServerMode(v)
|
|
||||||
default:
|
|
||||||
return types.ServerMode(args.Int(0))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
|
||||||
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
|
||||||
|
|
||||||
type MockSessionRegistry struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
|
||||||
args := m.Called(key)
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Error(1)
|
|
||||||
}
|
|
||||||
return args.Get(0).(registry.Session), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
|
||||||
args := m.Called(user, key)
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Error(1)
|
|
||||||
}
|
|
||||||
return args.Get(0).(registry.Session), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
|
||||||
args := m.Called(user, oldKey, newKey)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
|
||||||
args := m.Called(key, session)
|
|
||||||
return args.Bool(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
|
||||||
m.Called(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
|
||||||
args := m.Called(user)
|
|
||||||
return args.Get(0).([]registry.Session)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(slug.Slug)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockGRPCClient struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(*grpc.ClientConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
|
||||||
args := m.Called(ctx, token)
|
|
||||||
return args.Bool(0), args.String(1), args.Error(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
|
|
||||||
args := m.Called(ctx)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
|
|
||||||
args := m.Called(ctx, domain, token)
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockGRPCClient) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockPort struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
|
||||||
return m.Called(startPort, endPort).Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
|
||||||
args := m.Called()
|
|
||||||
return uint16(args.Int(0)), args.Bool(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
|
||||||
return m.Called(port, assigned).Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockPort) Claim(port uint16) bool {
|
|
||||||
return m.Called(port).Bool(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockListener struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockListener) Accept() (net.Conn, error) {
|
|
||||||
args := m.Called()
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, args.Error(1)
|
|
||||||
}
|
|
||||||
return args.Get(0).(net.Conn), args.Error(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockListener) Close() error {
|
|
||||||
return m.Called().Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockListener) Addr() net.Addr {
|
|
||||||
return m.Called().Get(0).(net.Addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
|
|
||||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
signer, _ := ssh.NewSignerFromKey(key)
|
|
||||||
config := &ssh.ServerConfig{
|
|
||||||
NoClientAuth: true,
|
|
||||||
}
|
|
||||||
config.AddHostKey(signer)
|
|
||||||
return config, signer
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
mr := new(MockRandom)
|
|
||||||
mc := new(MockConfig)
|
|
||||||
mreg := new(MockSessionRegistry)
|
|
||||||
mg := new(MockGRPCClient)
|
|
||||||
mp := new(MockPort)
|
|
||||||
sc, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
port string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "success",
|
|
||||||
port: "0",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid port",
|
|
||||||
port: "invalid",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s, err := New(mr, mc, sc, mreg, mg, mp, tt.port)
|
|
||||||
if tt.wantErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, s)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, s)
|
|
||||||
_ = s.Close()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("port already in use", func(t *testing.T) {
|
|
||||||
l, err := net.Listen("tcp", ":0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
port := l.Addr().(*net.TCPAddr).Port
|
|
||||||
defer func(l net.Listener) {
|
|
||||||
err = l.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}(l)
|
|
||||||
|
|
||||||
s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, s)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClose(t *testing.T) {
|
|
||||||
mr := new(MockRandom)
|
|
||||||
mc := new(MockConfig)
|
|
||||||
mreg := new(MockSessionRegistry)
|
|
||||||
mg := new(MockGRPCClient)
|
|
||||||
mp := new(MockPort)
|
|
||||||
sc, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
t.Run("successful close", func(t *testing.T) {
|
|
||||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
|
||||||
err := s.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("close already closed listener", func(t *testing.T) {
|
|
||||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
|
||||||
_ = s.Close()
|
|
||||||
err := s.Close()
|
|
||||||
assert.Error(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("close with nil listener", func(t *testing.T) {
|
|
||||||
s := &server{
|
|
||||||
sshListener: nil,
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
assert.NotNil(t, r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
_ = s.Close()
|
|
||||||
t.Fatal("expected panic for nil listener")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStart(t *testing.T) {
|
|
||||||
mr := new(MockRandom)
|
|
||||||
mc := new(MockConfig)
|
|
||||||
mreg := new(MockSessionRegistry)
|
|
||||||
mg := new(MockGRPCClient)
|
|
||||||
mp := new(MockPort)
|
|
||||||
sc, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
t.Run("normal stop", func(t *testing.T) {
|
|
||||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
|
||||||
go func() {
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
_ = s.Close()
|
|
||||||
}()
|
|
||||||
s.Start()
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("accept error - temporary error continues loop", func(t *testing.T) {
|
|
||||||
ml := new(MockListener)
|
|
||||||
s := &server{
|
|
||||||
sshListener: ml,
|
|
||||||
sshPort: "0",
|
|
||||||
}
|
|
||||||
|
|
||||||
ml.On("Accept").Return(nil, errors.New("temporary error")).Once()
|
|
||||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
||||||
|
|
||||||
s.Start()
|
|
||||||
ml.AssertExpectations(t)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("accept error - immediate close", func(t *testing.T) {
|
|
||||||
ml := new(MockListener)
|
|
||||||
s := &server{
|
|
||||||
sshListener: ml,
|
|
||||||
sshPort: "0",
|
|
||||||
}
|
|
||||||
|
|
||||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
||||||
|
|
||||||
s.Start()
|
|
||||||
ml.AssertExpectations(t)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("accept success - connection fails SSH handshake", func(t *testing.T) {
|
|
||||||
mockRandom := &MockRandom{}
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockGrpcClient := &MockGRPCClient{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
|
|
||||||
sshConfig, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
|
||||||
|
|
||||||
mockListener := &MockListener{}
|
|
||||||
mockListener.On("Accept").Return(serverConn, nil).Once()
|
|
||||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
||||||
|
|
||||||
s := &server{
|
|
||||||
randomizer: mockRandom,
|
|
||||||
config: mockConfig,
|
|
||||||
sshPort: "0",
|
|
||||||
sshListener: mockListener,
|
|
||||||
sshConfig: sshConfig,
|
|
||||||
grpcClient: mockGrpcClient,
|
|
||||||
sessionRegistry: mockSessionRegistry,
|
|
||||||
portRegistry: mockPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
go s.Start()
|
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
err := clientConn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
mockListener.AssertExpectations(t)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("accept success - valid SSH connection without auth", func(t *testing.T) {
|
|
||||||
mockRandom := &MockRandom{}
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
|
|
||||||
sshConfig, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
|
||||||
|
|
||||||
mockListener := &MockListener{}
|
|
||||||
mockListener.On("Accept").Return(serverConn, nil).Once()
|
|
||||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
||||||
|
|
||||||
s := &server{
|
|
||||||
randomizer: mockRandom,
|
|
||||||
config: mockConfig,
|
|
||||||
sshPort: "0",
|
|
||||||
sshListener: mockListener,
|
|
||||||
sshConfig: sshConfig,
|
|
||||||
grpcClient: nil,
|
|
||||||
sessionRegistry: mockSessionRegistry,
|
|
||||||
portRegistry: mockPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
go s.Start()
|
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
err := clientConn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
mockListener.AssertExpectations(t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleConnection(t *testing.T) {
|
|
||||||
t.Run("SSH handshake fails - connection closed", func(t *testing.T) {
|
|
||||||
mockRandom := &MockRandom{}
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockGrpcClient := &MockGRPCClient{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
|
|
||||||
sshConfig, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
|
||||||
|
|
||||||
s := &server{
|
|
||||||
randomizer: mockRandom,
|
|
||||||
config: mockConfig,
|
|
||||||
sshPort: "0",
|
|
||||||
sshConfig: sshConfig,
|
|
||||||
grpcClient: mockGrpcClient,
|
|
||||||
sessionRegistry: mockSessionRegistry,
|
|
||||||
portRegistry: mockPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := clientConn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
s.handleConnection(serverConn)
|
|
||||||
})
|
|
||||||
|
|
||||||
// SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS
|
|
||||||
// GONNA IMPLEMENT THIS UNIT TEST LATER
|
|
||||||
|
|
||||||
//t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) {
|
|
||||||
// mockRandom := &MockRandom{}
|
|
||||||
// mockConfig := &MockConfig{}
|
|
||||||
// mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
// mockGrpcClient := &MockGRPCClient{}
|
|
||||||
// mockPort := &MockPort{}
|
|
||||||
//
|
|
||||||
// sshConfig, _ := getTestSSHConfig()
|
|
||||||
//
|
|
||||||
// serverConn, clientConn := net.Pipe()
|
|
||||||
//
|
|
||||||
// s := &server{
|
|
||||||
// randomizer: mockRandom,
|
|
||||||
// config: mockConfig,
|
|
||||||
// sshPort: "0",
|
|
||||||
// sshConfig: sshConfig,
|
|
||||||
// grpcClient: mockGrpcClient,
|
|
||||||
// sessionRegistry: mockSessionRegistry,
|
|
||||||
// portRegistry: mockPort,
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// done := make(chan bool, 1)
|
|
||||||
//
|
|
||||||
// go func() {
|
|
||||||
// s.handleConnection(serverConn)
|
|
||||||
// done <- true
|
|
||||||
// }()
|
|
||||||
//
|
|
||||||
// go func() {
|
|
||||||
// clientConn.Write([]byte("invalid ssh protocol\n"))
|
|
||||||
// clientConn.Close()
|
|
||||||
// }()
|
|
||||||
//
|
|
||||||
// select {
|
|
||||||
// case <-done:
|
|
||||||
// case <-time.After(1 * time.Second):
|
|
||||||
// t.Fatal("handleConnection did not complete in time")
|
|
||||||
// }
|
|
||||||
//})
|
|
||||||
|
|
||||||
t.Run("SSH connection established without gRPC client", func(t *testing.T) {
|
|
||||||
mockRandom := &MockRandom{}
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
|
|
||||||
serverConfig, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
mockConfig.On("Domain").Return("test.com")
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
|
||||||
mockConfig.On("SSHPort").Return("2200")
|
|
||||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
|
||||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
||||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer func(listener net.Listener) {
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}(listener)
|
|
||||||
|
|
||||||
serverAddr := listener.Addr().String()
|
|
||||||
|
|
||||||
s := &server{
|
|
||||||
randomizer: mockRandom,
|
|
||||||
config: mockConfig,
|
|
||||||
sshPort: "0",
|
|
||||||
sshConfig: serverConfig,
|
|
||||||
grpcClient: nil,
|
|
||||||
sessionRegistry: mockSessionRegistry,
|
|
||||||
portRegistry: mockPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan bool, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.handleConnection(conn)
|
|
||||||
done <- true
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
clientConfig := &ssh.ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
||||||
Timeout: 2 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Client dial failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func(client *ssh.Client) {
|
|
||||||
err = client.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}(client)
|
|
||||||
|
|
||||||
type forwardPayload struct {
|
|
||||||
BindAddr string
|
|
||||||
BindPort uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := ssh.Marshal(forwardPayload{
|
|
||||||
BindAddr: "localhost",
|
|
||||||
BindPort: 80,
|
|
||||||
})
|
|
||||||
|
|
||||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Forward request failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
t.Log("handleConnection completed")
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
t.Fatal("handleConnection did not complete in time")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("SSH connection established with gRPC authorization", func(t *testing.T) {
|
|
||||||
mockRandom := &MockRandom{}
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockGrpcClient := &MockGRPCClient{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
|
|
||||||
serverConfig, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
|
|
||||||
mockConfig.On("Domain").Return("test.com")
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
|
||||||
mockConfig.On("SSHPort").Return("2200")
|
|
||||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
|
||||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
||||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer func(listener net.Listener) {
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}(listener)
|
|
||||||
|
|
||||||
serverAddr := listener.Addr().String()
|
|
||||||
|
|
||||||
s := &server{
|
|
||||||
randomizer: mockRandom,
|
|
||||||
config: mockConfig,
|
|
||||||
sshPort: "0",
|
|
||||||
sshConfig: serverConfig,
|
|
||||||
grpcClient: mockGrpcClient,
|
|
||||||
sessionRegistry: mockSessionRegistry,
|
|
||||||
portRegistry: mockPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan bool, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.handleConnection(conn)
|
|
||||||
done <- true
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
clientConfig := &ssh.ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
||||||
Timeout: 2 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Client dial failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func(client *ssh.Client) {
|
|
||||||
err = client.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}(client)
|
|
||||||
|
|
||||||
type forwardPayload struct {
|
|
||||||
BindAddr string
|
|
||||||
BindPort uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := ssh.Marshal(forwardPayload{
|
|
||||||
BindAddr: "localhost",
|
|
||||||
BindPort: 80,
|
|
||||||
})
|
|
||||||
|
|
||||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Forward request failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
mockGrpcClient.AssertExpectations(t)
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
t.Fatal("handleConnection did not complete in time")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("SSH connection with gRPC authorization error", func(t *testing.T) {
|
|
||||||
mockRandom := &MockRandom{}
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockGrpcClient := &MockGRPCClient{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
|
|
||||||
serverConfig, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
|
|
||||||
mockConfig.On("Domain").Return("test.com")
|
|
||||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
|
||||||
mockConfig.On("SSHPort").Return("2200")
|
|
||||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
|
||||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
||||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer func(listener net.Listener) {
|
|
||||||
err = listener.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}(listener)
|
|
||||||
|
|
||||||
serverAddr := listener.Addr().String()
|
|
||||||
|
|
||||||
s := &server{
|
|
||||||
randomizer: mockRandom,
|
|
||||||
config: mockConfig,
|
|
||||||
sshPort: "0",
|
|
||||||
sshConfig: serverConfig,
|
|
||||||
grpcClient: mockGrpcClient,
|
|
||||||
sessionRegistry: mockSessionRegistry,
|
|
||||||
portRegistry: mockPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan bool, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.handleConnection(conn)
|
|
||||||
done <- true
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
|
|
||||||
clientConfig := &ssh.ClientConfig{
|
|
||||||
User: "testuser",
|
|
||||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
||||||
Timeout: 2 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Client dial failed: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func(client *ssh.Client) {
|
|
||||||
_ = client.Close()
|
|
||||||
}(client)
|
|
||||||
|
|
||||||
type forwardPayload struct {
|
|
||||||
BindAddr string
|
|
||||||
BindPort uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := ssh.Marshal(forwardPayload{
|
|
||||||
BindAddr: "localhost",
|
|
||||||
BindPort: 8080,
|
|
||||||
})
|
|
||||||
|
|
||||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Forward request failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
mockGrpcClient.AssertExpectations(t)
|
|
||||||
case <-time.After(5 * time.Second):
|
|
||||||
t.Fatal("handleConnection did not complete in time")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("connection cleanup on close", func(t *testing.T) {
|
|
||||||
mockRandom := &MockRandom{}
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
|
|
||||||
serverConfig, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
|
||||||
|
|
||||||
s := &server{
|
|
||||||
randomizer: mockRandom,
|
|
||||||
config: mockConfig,
|
|
||||||
sshPort: "0",
|
|
||||||
sshConfig: serverConfig,
|
|
||||||
grpcClient: nil,
|
|
||||||
sessionRegistry: mockSessionRegistry,
|
|
||||||
portRegistry: mockPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
done := make(chan bool, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
s.handleConnection(serverConn)
|
|
||||||
done <- true
|
|
||||||
}()
|
|
||||||
|
|
||||||
err := clientConn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-time.After(1 * time.Second):
|
|
||||||
t.Fatal("handleConnection did not complete in time")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntegration(t *testing.T) {
|
|
||||||
t.Run("full server lifecycle", func(t *testing.T) {
|
|
||||||
mr := new(MockRandom)
|
|
||||||
mc := new(MockConfig)
|
|
||||||
mreg := new(MockSessionRegistry)
|
|
||||||
mg := new(MockGRPCClient)
|
|
||||||
mp := new(MockPort)
|
|
||||||
sc, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
s, err := New(mr, mc, sc, mreg, mg, mp, "0")
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, s)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
err := s.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}()
|
|
||||||
|
|
||||||
s.Start()
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("multiple connections", func(t *testing.T) {
|
|
||||||
mockRandom := &MockRandom{}
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
|
|
||||||
sshConfig, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
conn1Server, conn1Client := net.Pipe()
|
|
||||||
conn2Server, conn2Client := net.Pipe()
|
|
||||||
|
|
||||||
mockListener := &MockListener{}
|
|
||||||
mockListener.On("Accept").Return(conn1Server, nil).Once()
|
|
||||||
mockListener.On("Accept").Return(conn2Server, nil).Once()
|
|
||||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
||||||
|
|
||||||
s := &server{
|
|
||||||
randomizer: mockRandom,
|
|
||||||
config: mockConfig,
|
|
||||||
sshPort: "0",
|
|
||||||
sshListener: mockListener,
|
|
||||||
sshConfig: sshConfig,
|
|
||||||
grpcClient: nil,
|
|
||||||
sessionRegistry: mockSessionRegistry,
|
|
||||||
portRegistry: mockPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
go s.Start()
|
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
_ = conn1Client.Close()
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
_ = conn2Client.Close()
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
|
|
||||||
mockListener.AssertExpectations(t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrorHandling(t *testing.T) {
|
|
||||||
t.Run("write error during SSH handshake", func(t *testing.T) {
|
|
||||||
mockRandom := &MockRandom{}
|
|
||||||
mockConfig := &MockConfig{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
|
|
||||||
sshConfig, _ := getTestSSHConfig()
|
|
||||||
|
|
||||||
serverConn, clientConn := net.Pipe()
|
|
||||||
err := clientConn.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
s := &server{
|
|
||||||
randomizer: mockRandom,
|
|
||||||
config: mockConfig,
|
|
||||||
sshPort: "0",
|
|
||||||
sshConfig: sshConfig,
|
|
||||||
grpcClient: nil,
|
|
||||||
sessionRegistry: mockSessionRegistry,
|
|
||||||
portRegistry: mockPort,
|
|
||||||
}
|
|
||||||
|
|
||||||
s.handleConnection(serverConn)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
+336
@@ -0,0 +1,336 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
"tunnel_pls/internal/config"
|
||||||
|
|
||||||
|
"github.com/caddyserver/certmagic"
|
||||||
|
"github.com/libdns/cloudflare"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TLSManager interface {
|
||||||
|
userCertsExistAndValid() bool
|
||||||
|
loadUserCerts() error
|
||||||
|
startCertWatcher()
|
||||||
|
initCertMagic() error
|
||||||
|
getTLSConfig() *tls.Config
|
||||||
|
getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type tlsManager struct {
|
||||||
|
domain string
|
||||||
|
certPath string
|
||||||
|
keyPath string
|
||||||
|
storagePath string
|
||||||
|
|
||||||
|
userCert *tls.Certificate
|
||||||
|
userCertMu sync.RWMutex
|
||||||
|
|
||||||
|
magic *certmagic.Config
|
||||||
|
|
||||||
|
useCertMagic bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var globalTLSManager TLSManager
|
||||||
|
var tlsManagerOnce sync.Once
|
||||||
|
|
||||||
|
func NewTLSConfig(domain string) (*tls.Config, error) {
|
||||||
|
var initErr error
|
||||||
|
|
||||||
|
tlsManagerOnce.Do(func() {
|
||||||
|
certPath := "certs/tls/cert.pem"
|
||||||
|
keyPath := "certs/tls/privkey.pem"
|
||||||
|
storagePath := "certs/tls/certmagic"
|
||||||
|
|
||||||
|
tm := &tlsManager{
|
||||||
|
domain: domain,
|
||||||
|
certPath: certPath,
|
||||||
|
keyPath: keyPath,
|
||||||
|
storagePath: storagePath,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tm.userCertsExistAndValid() {
|
||||||
|
log.Printf("Using user-provided certificates from %s and %s", certPath, keyPath)
|
||||||
|
if err := tm.loadUserCerts(); err != nil {
|
||||||
|
initErr = fmt.Errorf("failed to load user certificates: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tm.useCertMagic = false
|
||||||
|
tm.startCertWatcher()
|
||||||
|
} else {
|
||||||
|
if !isACMEConfigComplete() {
|
||||||
|
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
|
||||||
|
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
|
||||||
|
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
|
||||||
|
if err := tm.initCertMagic(); err != nil {
|
||||||
|
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tm.useCertMagic = true
|
||||||
|
}
|
||||||
|
|
||||||
|
globalTLSManager = tm
|
||||||
|
})
|
||||||
|
|
||||||
|
if initErr != nil {
|
||||||
|
return nil, initErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return globalTLSManager.getTLSConfig(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isACMEConfigComplete() bool {
|
||||||
|
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
|
||||||
|
return cfAPIToken != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) userCertsExistAndValid() bool {
|
||||||
|
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
||||||
|
log.Printf("Certificate file not found: %s", tm.certPath)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(tm.keyPath); os.IsNotExist(err) {
|
||||||
|
log.Printf("Key file not found: %s", tm.keyPath)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return ValidateCertDomains(tm.certPath, tm.domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateCertDomains(certPath, domain string) bool {
|
||||||
|
certPEM, err := os.ReadFile(certPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to read certificate: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(certPEM)
|
||||||
|
if block == nil {
|
||||||
|
log.Printf("Failed to decode PEM block from certificate")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := x509.ParseCertificate(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to parse certificate: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(cert.NotAfter) {
|
||||||
|
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) {
|
||||||
|
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var certDomains []string
|
||||||
|
if cert.Subject.CommonName != "" {
|
||||||
|
certDomains = append(certDomains, cert.Subject.CommonName)
|
||||||
|
}
|
||||||
|
certDomains = append(certDomains, cert.DNSNames...)
|
||||||
|
|
||||||
|
hasBase := false
|
||||||
|
hasWildcard := false
|
||||||
|
wildcardDomain := "*." + domain
|
||||||
|
|
||||||
|
for _, d := range certDomains {
|
||||||
|
if d == domain {
|
||||||
|
hasBase = true
|
||||||
|
}
|
||||||
|
if d == wildcardDomain {
|
||||||
|
hasWildcard = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasBase {
|
||||||
|
log.Printf("Certificate does not cover base domain: %s", domain)
|
||||||
|
}
|
||||||
|
if !hasWildcard {
|
||||||
|
log.Printf("Certificate does not cover wildcard domain: %s", wildcardDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hasBase && hasWildcard
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) loadUserCerts() error {
|
||||||
|
cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tm.userCertMu.Lock()
|
||||||
|
tm.userCert = &cert
|
||||||
|
tm.userCertMu.Unlock()
|
||||||
|
|
||||||
|
log.Printf("Loaded user certificates successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) startCertWatcher() {
|
||||||
|
go func() {
|
||||||
|
var lastCertMod, lastKeyMod time.Time
|
||||||
|
|
||||||
|
if info, err := os.Stat(tm.certPath); err == nil {
|
||||||
|
lastCertMod = info.ModTime()
|
||||||
|
}
|
||||||
|
if info, err := os.Stat(tm.keyPath); err == nil {
|
||||||
|
lastKeyMod = info.ModTime()
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
certInfo, certErr := os.Stat(tm.certPath)
|
||||||
|
keyInfo, keyErr := os.Stat(tm.keyPath)
|
||||||
|
|
||||||
|
if certErr != nil || keyErr != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
|
||||||
|
log.Printf("Certificate files changed, reloading...")
|
||||||
|
|
||||||
|
if !ValidateCertDomains(tm.certPath, tm.domain) {
|
||||||
|
log.Printf("New certificates don't cover required domains")
|
||||||
|
|
||||||
|
if !isACMEConfigComplete() {
|
||||||
|
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Switching to CertMagic for automatic certificate management")
|
||||||
|
if err := tm.initCertMagic(); err != nil {
|
||||||
|
log.Printf("Failed to initialize CertMagic: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tm.useCertMagic = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tm.loadUserCerts(); err != nil {
|
||||||
|
log.Printf("Failed to reload certificates: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lastCertMod = certInfo.ModTime()
|
||||||
|
lastKeyMod = keyInfo.ModTime()
|
||||||
|
log.Printf("Certificates reloaded successfully")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) initCertMagic() error {
|
||||||
|
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain)
|
||||||
|
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
|
||||||
|
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
|
||||||
|
|
||||||
|
if cfAPIToken == "" {
|
||||||
|
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfProvider := &cloudflare.Provider{
|
||||||
|
APIToken: cfAPIToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
storage := &certmagic.FileStorage{Path: tm.storagePath}
|
||||||
|
|
||||||
|
cache := certmagic.NewCache(certmagic.CacheOptions{
|
||||||
|
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
|
||||||
|
return tm.magic, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
magic := certmagic.New(cache, certmagic.Config{
|
||||||
|
Storage: storage,
|
||||||
|
})
|
||||||
|
|
||||||
|
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
||||||
|
Email: acmeEmail,
|
||||||
|
Agreed: true,
|
||||||
|
DNS01Solver: &certmagic.DNS01Solver{
|
||||||
|
DNSManager: certmagic.DNSManager{
|
||||||
|
DNSProvider: cfProvider,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if acmeStaging {
|
||||||
|
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
|
||||||
|
log.Printf("Using Let's Encrypt staging server")
|
||||||
|
} else {
|
||||||
|
acmeIssuer.CA = certmagic.LetsEncryptProductionCA
|
||||||
|
log.Printf("Using Let's Encrypt production server")
|
||||||
|
}
|
||||||
|
|
||||||
|
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
||||||
|
tm.magic = magic
|
||||||
|
|
||||||
|
domains := []string{tm.domain, "*." + tm.domain}
|
||||||
|
log.Printf("Requesting certificates for: %v", domains)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := magic.ManageSync(ctx, domains); err != nil {
|
||||||
|
return fmt.Errorf("failed to obtain certificates: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Certificates obtained successfully for %v", domains)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) getTLSConfig() *tls.Config {
|
||||||
|
return &tls.Config{
|
||||||
|
GetCertificate: tm.getCertificate,
|
||||||
|
MinVersion: tls.VersionTLS13,
|
||||||
|
MaxVersion: tls.VersionTLS13,
|
||||||
|
|
||||||
|
SessionTicketsDisabled: false,
|
||||||
|
|
||||||
|
CipherSuites: []uint16{
|
||||||
|
tls.TLS_AES_128_GCM_SHA256,
|
||||||
|
tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||||
|
},
|
||||||
|
|
||||||
|
CurvePreferences: []tls.CurveID{
|
||||||
|
tls.X25519,
|
||||||
|
},
|
||||||
|
|
||||||
|
ClientAuth: tls.NoClientCert,
|
||||||
|
NextProtos: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
if tm.useCertMagic {
|
||||||
|
return tm.magic.GetCertificate(hello)
|
||||||
|
}
|
||||||
|
|
||||||
|
tm.userCertMu.RLock()
|
||||||
|
defer tm.userCertMu.RUnlock()
|
||||||
|
|
||||||
|
if tm.userCert == nil {
|
||||||
|
return nil, fmt.Errorf("no certificate available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return tm.userCert, nil
|
||||||
|
}
|
||||||
+176
-112
@@ -1,14 +1,15 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
"tunnel_pls/session/slug"
|
"tunnel_pls/session/slug"
|
||||||
"tunnel_pls/types"
|
"tunnel_pls/types"
|
||||||
@@ -16,176 +17,239 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Forwarder interface {
|
var bufferPool = sync.Pool{
|
||||||
SetType(tunnelType types.TunnelType)
|
New: func() interface{} {
|
||||||
SetForwardedPort(port uint16)
|
bufSize := config.GetBufferSize()
|
||||||
SetListener(listener net.Listener)
|
return make([]byte, bufSize)
|
||||||
Listener() net.Listener
|
},
|
||||||
TunnelType() types.TunnelType
|
|
||||||
ForwardedPort() uint16
|
|
||||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
|
||||||
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
|
|
||||||
Close() error
|
|
||||||
}
|
}
|
||||||
type forwarder struct {
|
|
||||||
|
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||||
|
buf := bufferPool.Get().([]byte)
|
||||||
|
defer bufferPool.Put(buf)
|
||||||
|
return io.CopyBuffer(dst, src, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Forwarder struct {
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
tunnelType types.TunnelType
|
tunnelType types.TunnelType
|
||||||
forwardedPort uint16
|
forwardedPort uint16
|
||||||
slug slug.Slug
|
slugManager slug.Manager
|
||||||
conn ssh.Conn
|
lifecycle Lifecycle
|
||||||
bufferPool sync.Pool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
|
func NewForwarder(slugManager slug.Manager) *Forwarder {
|
||||||
return &forwarder{
|
return &Forwarder{
|
||||||
listener: nil,
|
listener: nil,
|
||||||
tunnelType: types.TunnelTypeUNKNOWN,
|
tunnelType: "",
|
||||||
forwardedPort: 0,
|
forwardedPort: 0,
|
||||||
slug: slug,
|
slugManager: slugManager,
|
||||||
conn: conn,
|
lifecycle: nil,
|
||||||
bufferPool: sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
bufSize := config.BufferSize()
|
|
||||||
buf := make([]byte, bufSize)
|
|
||||||
return &buf
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
type Lifecycle interface {
|
||||||
buf := f.bufferPool.Get().(*[]byte)
|
GetConnection() ssh.Conn
|
||||||
defer f.bufferPool.Put(buf)
|
|
||||||
return io.CopyBuffer(dst, src, *buf)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
type ForwardingController interface {
|
||||||
payload := createForwardedTCPIPPayload(origin, f.forwardedPort)
|
AcceptTCPConnections()
|
||||||
type channelResult struct {
|
SetType(tunnelType types.TunnelType)
|
||||||
channel ssh.Channel
|
GetTunnelType() types.TunnelType
|
||||||
reqs <-chan *ssh.Request
|
GetForwardedPort() uint16
|
||||||
err error
|
SetForwardedPort(port uint16)
|
||||||
}
|
SetListener(listener net.Listener)
|
||||||
resultChan := make(chan channelResult, 1)
|
GetListener() net.Listener
|
||||||
|
Close() error
|
||||||
|
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
|
||||||
|
SetLifecycle(lifecycle Lifecycle)
|
||||||
|
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||||
|
WriteBadGatewayResponse(dst io.Writer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
|
||||||
|
f.lifecycle = lifecycle
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) AcceptTCPConnections() {
|
||||||
|
for {
|
||||||
|
conn, err := f.GetListener().Accept()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Error accepting connection: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
|
log.Printf("Failed to set connection deadline: %v", err)
|
||||||
|
if closeErr := conn.Close(); closeErr != nil {
|
||||||
|
log.Printf("Failed to close connection: %v", closeErr)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||||
|
|
||||||
|
type channelResult struct {
|
||||||
|
channel ssh.Channel
|
||||||
|
reqs <-chan *ssh.Request
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
resultChan := make(chan channelResult, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||||
|
resultChan <- channelResult{channel, reqs, err}
|
||||||
|
}()
|
||||||
|
|
||||||
go func() {
|
|
||||||
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
|
||||||
select {
|
select {
|
||||||
case resultChan <- channelResult{channel, reqs, err}:
|
case result := <-resultChan:
|
||||||
case <-ctx.Done():
|
if result.err != nil {
|
||||||
if channel != nil {
|
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
|
||||||
_ = channel.Close()
|
if closeErr := conn.Close(); closeErr != nil {
|
||||||
go ssh.DiscardRequests(reqs)
|
log.Printf("Failed to close connection: %v", closeErr)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
log.Printf("Failed to clear connection deadline: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go ssh.DiscardRequests(result.reqs)
|
||||||
|
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
|
||||||
|
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
log.Printf("Timeout opening forwarded-tcpip channel")
|
||||||
|
if closeErr := conn.Close(); closeErr != nil {
|
||||||
|
log.Printf("Failed to close connection: %v", closeErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||||
|
defer func() {
|
||||||
|
_, err := io.Copy(io.Discard, src)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to discard connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = src.Close()
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
log.Printf("Error closing source channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if closer, ok := dst.(io.Closer); ok {
|
||||||
|
err = closer.Close()
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
log.Printf("Error closing destination connection: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
log.Printf("Handling new forwarded connection from %s", remoteAddr)
|
||||||
case result := <-resultChan:
|
|
||||||
return result.channel, result.reqs, result.err
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeWriter(w io.Writer) error {
|
|
||||||
if cw, ok := w.(interface{ CloseWrite() error }); ok {
|
|
||||||
return cw.CloseWrite()
|
|
||||||
}
|
|
||||||
if closer, ok := w.(io.Closer); ok {
|
|
||||||
return closer.Close()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
|
|
||||||
var errs []error
|
|
||||||
_, err := f.copyWithBuffer(dst, src)
|
|
||||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
|
||||||
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
|
|
||||||
errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err))
|
|
||||||
}
|
|
||||||
return errors.Join(errs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
|
||||||
defer func() {
|
|
||||||
_, _ = io.Copy(io.Discard, src)
|
|
||||||
}()
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err := f.copyAndClose(dst, src, "src to dst")
|
_, err := copyWithBuffer(dst, src)
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||||
log.Println("Error during copy: ", err)
|
log.Printf("Error copying src→dst: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err := f.copyAndClose(src, dst, "dst to src")
|
_, err := copyWithBuffer(src, dst)
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||||
log.Println("Error during copy: ", err)
|
log.Printf("Error copying dst→src: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) SetType(tunnelType types.TunnelType) {
|
func (f *Forwarder) SetType(tunnelType types.TunnelType) {
|
||||||
f.tunnelType = tunnelType
|
f.tunnelType = tunnelType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) TunnelType() types.TunnelType {
|
func (f *Forwarder) GetTunnelType() types.TunnelType {
|
||||||
return f.tunnelType
|
return f.tunnelType
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) ForwardedPort() uint16 {
|
func (f *Forwarder) GetForwardedPort() uint16 {
|
||||||
return f.forwardedPort
|
return f.forwardedPort
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) SetForwardedPort(port uint16) {
|
func (f *Forwarder) SetForwardedPort(port uint16) {
|
||||||
f.forwardedPort = port
|
f.forwardedPort = port
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) SetListener(listener net.Listener) {
|
func (f *Forwarder) SetListener(listener net.Listener) {
|
||||||
f.listener = listener
|
f.listener = listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) Listener() net.Listener {
|
func (f *Forwarder) GetListener() net.Listener {
|
||||||
return f.listener
|
return f.listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *forwarder) Close() error {
|
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||||
if f.Listener() != nil {
|
_, err := dst.Write(types.BadGatewayResponse)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to write Bad Gateway response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) Close() error {
|
||||||
|
if f.GetListener() != nil {
|
||||||
return f.listener.Close()
|
return f.listener.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createForwardedTCPIPPayload(origin net.Addr, destPort uint16) []byte {
|
func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||||
host, portStr, _ := net.SplitHostPort(origin.String())
|
var buf bytes.Buffer
|
||||||
port, _ := strconv.Atoi(portStr)
|
|
||||||
|
|
||||||
forwardPayload := struct {
|
host, originPort := parseAddr(origin.String())
|
||||||
DestAddr string
|
|
||||||
DestPort uint32
|
writeSSHString(&buf, "localhost")
|
||||||
OriginAddr string
|
err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort()))
|
||||||
OriginPort uint32
|
if err != nil {
|
||||||
}{
|
log.Printf("Failed to write string to buffer: %v", err)
|
||||||
DestAddr: "localhost",
|
return nil
|
||||||
DestPort: uint32(destPort),
|
|
||||||
OriginAddr: host,
|
|
||||||
OriginPort: uint32(port),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ssh.Marshal(forwardPayload)
|
writeSSHString(&buf, host)
|
||||||
|
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to write string to buffer: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseAddr(addr string) (string, uint16) {
|
||||||
|
host, portStr, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
||||||
|
return "0.0.0.0", uint16(0)
|
||||||
|
}
|
||||||
|
port, _ := strconv.Atoi(portStr)
|
||||||
|
return host, uint16(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeSSHString(buffer *bytes.Buffer, str string) {
|
||||||
|
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to write string to buffer: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buffer.WriteString(str)
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,291 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
portUtil "tunnel_pls/internal/port"
|
||||||
|
"tunnel_pls/internal/random"
|
||||||
|
"tunnel_pls/types"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||||
|
|
||||||
|
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||||
|
for req := range GlobalRequest {
|
||||||
|
switch req.Type {
|
||||||
|
case "shell", "pty-req":
|
||||||
|
err := req.Reply(true, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "window-change":
|
||||||
|
p := req.Payload
|
||||||
|
if len(p) < 16 {
|
||||||
|
log.Println("invalid window-change payload")
|
||||||
|
err := req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cols := binary.BigEndian.Uint32(p[0:4])
|
||||||
|
rows := binary.BigEndian.Uint32(p[4:8])
|
||||||
|
|
||||||
|
s.interaction.SetWH(int(cols), int(rows))
|
||||||
|
|
||||||
|
err := req.Reply(true, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Println("Unknown request type:", req.Type)
|
||||||
|
err := req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
|
||||||
|
log.Println("Port forwarding request detected")
|
||||||
|
|
||||||
|
reader := bytes.NewReader(req.Payload)
|
||||||
|
|
||||||
|
addr, err := readSSHString(reader)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to read address from payload:", err)
|
||||||
|
err := req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = s.lifecycle.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var rawPortToBind uint32
|
||||||
|
if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
|
||||||
|
log.Println("Failed to read port from payload:", err)
|
||||||
|
err := req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = s.lifecycle.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if rawPortToBind > 65535 {
|
||||||
|
log.Printf("Port %d is larger than allowed port of 65535", rawPortToBind)
|
||||||
|
err := req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = s.lifecycle.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
portToBind := uint16(rawPortToBind)
|
||||||
|
if isBlockedPort(portToBind) {
|
||||||
|
log.Printf("Port %d is blocked or restricted", portToBind)
|
||||||
|
err := req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = s.lifecycle.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if portToBind == 80 || portToBind == 443 {
|
||||||
|
s.HandleHTTPForward(req, portToBind)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if portToBind == 0 {
|
||||||
|
unassign, success := portUtil.Default.GetUnassignedPort()
|
||||||
|
portToBind = unassign
|
||||||
|
if !success {
|
||||||
|
log.Println("No available port")
|
||||||
|
err := req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = s.lifecycle.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
|
||||||
|
log.Printf("Port %d is already in use or restricted", portToBind)
|
||||||
|
err := req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = s.lifecycle.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = portUtil.Default.SetPortStatus(portToBind, true)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to set port status:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.HandleTCPForward(req, addr, portToBind)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
|
||||||
|
slug := random.GenerateRandomString(20)
|
||||||
|
|
||||||
|
if !s.registry.Register(slug, s) {
|
||||||
|
log.Printf("Failed to register client with slug: %s", slug)
|
||||||
|
err := req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
err := binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to write port to buffer:", err)
|
||||||
|
s.registry.Remove(slug)
|
||||||
|
err = req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("HTTP forwarding approved on port: %d", portToBind)
|
||||||
|
|
||||||
|
err = req.Reply(true, buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
s.registry.Remove(slug)
|
||||||
|
err = req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.forwarder.SetType(types.HTTP)
|
||||||
|
s.forwarder.SetForwardedPort(portToBind)
|
||||||
|
s.slugManager.Set(slug)
|
||||||
|
s.lifecycle.SetStatus(types.RUNNING)
|
||||||
|
s.interaction.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
|
||||||
|
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
|
||||||
|
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Port %d is already in use or restricted", portToBind)
|
||||||
|
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
|
||||||
|
log.Printf("Failed to reset port status: %v", setErr)
|
||||||
|
}
|
||||||
|
err = req.Reply(false, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = s.lifecycle.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to write port to buffer:", err)
|
||||||
|
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
|
||||||
|
log.Printf("Failed to reset port status: %v", setErr)
|
||||||
|
}
|
||||||
|
err = listener.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to close listener: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("TCP forwarding approved on port: %d", portToBind)
|
||||||
|
err = req.Reply(true, buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Failed to reply to request:", err)
|
||||||
|
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
|
||||||
|
log.Printf("Failed to reset port status: %v", setErr)
|
||||||
|
}
|
||||||
|
err = listener.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to close listener: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.forwarder.SetType(types.TCP)
|
||||||
|
s.forwarder.SetListener(listener)
|
||||||
|
s.forwarder.SetForwardedPort(portToBind)
|
||||||
|
s.lifecycle.SetStatus(types.RUNNING)
|
||||||
|
go s.forwarder.AcceptTCPConnections()
|
||||||
|
s.interaction.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSSHString(reader *bytes.Reader) (string, error) {
|
||||||
|
var length uint32
|
||||||
|
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
strBytes := make([]byte, length)
|
||||||
|
if _, err := reader.Read(strBytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(strBytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBlockedPort(port uint16) bool {
|
||||||
|
if port == 80 || port == 443 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if port < 1024 && port != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, p := range blockedReservedPorts {
|
||||||
|
if p == port {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
package interaction
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/charmbracelet/bubbles/textinput"
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
|
||||||
"github.com/charmbracelet/lipgloss"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m *model) comingSoonUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
|
||||||
m.showingComingSoon = false
|
|
||||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) comingSoonView() string {
|
|
||||||
isCompact := shouldUseCompactLayout(m.width, 60)
|
|
||||||
|
|
||||||
var boxPadding int
|
|
||||||
var boxMargin int
|
|
||||||
if isCompact {
|
|
||||||
boxPadding = 1
|
|
||||||
boxMargin = 1
|
|
||||||
} else {
|
|
||||||
boxPadding = 3
|
|
||||||
boxMargin = 2
|
|
||||||
}
|
|
||||||
|
|
||||||
titleStyle := lipgloss.NewStyle().
|
|
||||||
Bold(true).
|
|
||||||
Foreground(lipgloss.Color("#7D56F4")).
|
|
||||||
PaddingTop(1).
|
|
||||||
PaddingBottom(1)
|
|
||||||
|
|
||||||
messageBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
|
||||||
messageBoxStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color("#FAFAFA")).
|
|
||||||
Background(lipgloss.Color("#1A1A2E")).
|
|
||||||
Bold(true).
|
|
||||||
Border(lipgloss.RoundedBorder()).
|
|
||||||
BorderForeground(lipgloss.Color("#7D56F4")).
|
|
||||||
Padding(1, boxPadding).
|
|
||||||
MarginTop(boxMargin).
|
|
||||||
MarginBottom(boxMargin).
|
|
||||||
Width(messageBoxWidth).
|
|
||||||
Align(lipgloss.Center)
|
|
||||||
|
|
||||||
helpStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color("#666666")).
|
|
||||||
Italic(true).
|
|
||||||
MarginTop(1)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString("\n\n")
|
|
||||||
|
|
||||||
var title string
|
|
||||||
if shouldUseCompactLayout(m.width, 40) {
|
|
||||||
title = "Coming Soon"
|
|
||||||
} else {
|
|
||||||
title = "⏳ Coming Soon"
|
|
||||||
}
|
|
||||||
b.WriteString(titleStyle.Render(title))
|
|
||||||
b.WriteString("\n\n")
|
|
||||||
|
|
||||||
var message string
|
|
||||||
if shouldUseCompactLayout(m.width, 50) {
|
|
||||||
message = "Coming soon!\nStay tuned."
|
|
||||||
} else {
|
|
||||||
message = "🚀 This feature is coming very soon!\n Stay tuned for updates."
|
|
||||||
}
|
|
||||||
b.WriteString(messageBoxStyle.Render(message))
|
|
||||||
b.WriteString("\n\n")
|
|
||||||
|
|
||||||
var helpText string
|
|
||||||
if shouldUseCompactLayout(m.width, 60) {
|
|
||||||
helpText = "Press any key..."
|
|
||||||
} else {
|
|
||||||
helpText = "This message will disappear in 5 seconds or press any key..."
|
|
||||||
}
|
|
||||||
b.WriteString(helpStyle.Render(helpText))
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
@@ -1,86 +0,0 @@
|
|||||||
package interaction
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/charmbracelet/bubbles/key"
|
|
||||||
"github.com/charmbracelet/bubbles/textinput"
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
|
||||||
"github.com/charmbracelet/lipgloss"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m *model) handleCommandSelection(item commandItem) (tea.Model, tea.Cmd) {
|
|
||||||
switch item.name {
|
|
||||||
case "slug":
|
|
||||||
m.showingCommands = false
|
|
||||||
m.editingSlug = true
|
|
||||||
m.slugInput.SetValue(m.interaction.slug.String())
|
|
||||||
m.slugInput.Focus()
|
|
||||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
|
||||||
case "tunnel-type":
|
|
||||||
m.showingCommands = false
|
|
||||||
m.showingComingSoon = true
|
|
||||||
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
|
|
||||||
default:
|
|
||||||
m.showingCommands = false
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) commandsUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
|
||||||
var cmd tea.Cmd
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case key.Matches(msg, m.keymap.quit), msg.String() == "esc":
|
|
||||||
m.showingCommands = false
|
|
||||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
|
||||||
case msg.String() == "enter":
|
|
||||||
selectedItem := m.commandList.SelectedItem()
|
|
||||||
if selectedItem != nil {
|
|
||||||
item := selectedItem.(commandItem)
|
|
||||||
return m.handleCommandSelection(item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.commandList, cmd = m.commandList.Update(msg)
|
|
||||||
return m, cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) commandsView() string {
|
|
||||||
isCompact := shouldUseCompactLayout(m.width, 60)
|
|
||||||
|
|
||||||
titleStyle := lipgloss.NewStyle().
|
|
||||||
Bold(true).
|
|
||||||
Foreground(lipgloss.Color("#7D56F4")).
|
|
||||||
PaddingTop(1).
|
|
||||||
PaddingBottom(1)
|
|
||||||
|
|
||||||
helpStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color("#666666")).
|
|
||||||
Italic(true).
|
|
||||||
MarginTop(1)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString("\n")
|
|
||||||
|
|
||||||
var title string
|
|
||||||
if shouldUseCompactLayout(m.width, 40) {
|
|
||||||
title = "Commands"
|
|
||||||
} else {
|
|
||||||
title = "⚡ Commands"
|
|
||||||
}
|
|
||||||
b.WriteString(titleStyle.Render(title))
|
|
||||||
b.WriteString("\n\n")
|
|
||||||
b.WriteString(m.commandList.View())
|
|
||||||
b.WriteString("\n")
|
|
||||||
|
|
||||||
var helpText string
|
|
||||||
if isCompact {
|
|
||||||
helpText = "↑/↓ Nav • Enter Select • Esc Cancel"
|
|
||||||
} else {
|
|
||||||
helpText = "↑/↓ Navigate • Enter Select • Esc Cancel"
|
|
||||||
}
|
|
||||||
b.WriteString(helpStyle.Render(helpText))
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
@@ -1,216 +0,0 @@
|
|||||||
package interaction
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/charmbracelet/bubbles/key"
|
|
||||||
"github.com/charmbracelet/bubbles/textinput"
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
|
||||||
"github.com/charmbracelet/lipgloss"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m *model) dashboardUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
|
||||||
switch {
|
|
||||||
case key.Matches(msg, m.keymap.quit):
|
|
||||||
m.quitting = true
|
|
||||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink, tea.Quit)
|
|
||||||
case key.Matches(msg, m.keymap.command):
|
|
||||||
m.showingCommands = true
|
|
||||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
|
||||||
}
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) dashboardView() string {
|
|
||||||
isCompact := shouldUseCompactLayout(m.width, BreakpointLarge)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString(m.renderHeader(isCompact))
|
|
||||||
b.WriteString(m.renderUserInfo(isCompact))
|
|
||||||
b.WriteString(m.renderQuickActions(isCompact))
|
|
||||||
b.WriteString(m.renderFooter(isCompact))
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderHeader(isCompact bool) string {
|
|
||||||
var b strings.Builder
|
|
||||||
|
|
||||||
asciiArtMargin := getMarginValue(isCompact, 0, 1)
|
|
||||||
asciiArtStyle := lipgloss.NewStyle().
|
|
||||||
Bold(true).
|
|
||||||
Foreground(lipgloss.Color(ColorPrimary)).
|
|
||||||
MarginBottom(asciiArtMargin)
|
|
||||||
|
|
||||||
b.WriteString(asciiArtStyle.Render(m.getASCIIArt()))
|
|
||||||
b.WriteString("\n")
|
|
||||||
|
|
||||||
if !shouldUseCompactLayout(m.width, BreakpointSmall) {
|
|
||||||
b.WriteString(m.renderSubtitle())
|
|
||||||
} else {
|
|
||||||
b.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) getASCIIArt() string {
|
|
||||||
if shouldUseCompactLayout(m.width, BreakpointTiny) {
|
|
||||||
return "TUNNEL PLS"
|
|
||||||
}
|
|
||||||
|
|
||||||
if shouldUseCompactLayout(m.width, BreakpointLarge) {
|
|
||||||
return `
|
|
||||||
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
|
|
||||||
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
|
|
||||||
}
|
|
||||||
|
|
||||||
return `
|
|
||||||
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
|
|
||||||
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
|
|
||||||
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
|
|
||||||
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
|
|
||||||
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
|
|
||||||
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderSubtitle() string {
|
|
||||||
subtitleStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorGray)).
|
|
||||||
Italic(true)
|
|
||||||
|
|
||||||
urlStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorPrimary)).
|
|
||||||
Underline(true).
|
|
||||||
Italic(true)
|
|
||||||
|
|
||||||
return subtitleStyle.Render("Secure tunnel service by Bagas • ") +
|
|
||||||
urlStyle.Render("https://fossy.my.id") + "\n\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderUserInfo(isCompact bool) string {
|
|
||||||
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
|
|
||||||
boxPadding := getMarginValue(isCompact, 1, 2)
|
|
||||||
boxMargin := getMarginValue(isCompact, 1, 2)
|
|
||||||
|
|
||||||
responsiveInfoBox := lipgloss.NewStyle().
|
|
||||||
Border(lipgloss.RoundedBorder()).
|
|
||||||
BorderForeground(lipgloss.Color(ColorPrimary)).
|
|
||||||
Padding(1, boxPadding).
|
|
||||||
MarginTop(boxMargin).
|
|
||||||
MarginBottom(boxMargin).
|
|
||||||
Width(boxMaxWidth)
|
|
||||||
|
|
||||||
infoContent := m.getUserInfoContent(isCompact)
|
|
||||||
return responsiveInfoBox.Render(infoContent) + "\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) getUserInfoContent(isCompact bool) string {
|
|
||||||
userInfoStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorWhite)).
|
|
||||||
Bold(true)
|
|
||||||
|
|
||||||
sectionHeaderStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorGray)).
|
|
||||||
Bold(true)
|
|
||||||
|
|
||||||
addressStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorWhite))
|
|
||||||
|
|
||||||
urlBoxStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorSecondary)).
|
|
||||||
Bold(true).
|
|
||||||
Italic(true)
|
|
||||||
|
|
||||||
authenticatedUser := m.interaction.user
|
|
||||||
tunnelURL := urlBoxStyle.Render(m.getTunnelURL())
|
|
||||||
|
|
||||||
if isCompact {
|
|
||||||
return fmt.Sprintf("👤 %s\n\n%s\n%s",
|
|
||||||
userInfoStyle.Render(authenticatedUser),
|
|
||||||
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
|
|
||||||
addressStyle.Render(fmt.Sprintf(" %s", tunnelURL)))
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
|
|
||||||
userInfoStyle.Render(authenticatedUser),
|
|
||||||
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
|
|
||||||
addressStyle.Render(tunnelURL))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderQuickActions(isCompact bool) string {
|
|
||||||
var b strings.Builder
|
|
||||||
|
|
||||||
titleStyle := lipgloss.NewStyle().
|
|
||||||
Bold(true).
|
|
||||||
Foreground(lipgloss.Color(ColorPrimary)).
|
|
||||||
PaddingTop(1)
|
|
||||||
|
|
||||||
b.WriteString(titleStyle.Render(m.getQuickActionsTitle()))
|
|
||||||
b.WriteString("\n")
|
|
||||||
|
|
||||||
featureMargin := getMarginValue(isCompact, 1, 2)
|
|
||||||
featureStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorWhite)).
|
|
||||||
MarginLeft(featureMargin)
|
|
||||||
|
|
||||||
keyHintStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorPrimary)).
|
|
||||||
Bold(true)
|
|
||||||
|
|
||||||
commands := m.getActionCommands(keyHintStyle)
|
|
||||||
b.WriteString(featureStyle.Render(commands.commandsText))
|
|
||||||
b.WriteString("\n")
|
|
||||||
b.WriteString(featureStyle.Render(commands.quitText))
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) getQuickActionsTitle() string {
|
|
||||||
if shouldUseCompactLayout(m.width, BreakpointTiny) {
|
|
||||||
return "Actions"
|
|
||||||
}
|
|
||||||
if shouldUseCompactLayout(m.width, BreakpointLarge) {
|
|
||||||
return "Quick Actions"
|
|
||||||
}
|
|
||||||
return "✨ Quick Actions"
|
|
||||||
}
|
|
||||||
|
|
||||||
type actionCommands struct {
|
|
||||||
commandsText string
|
|
||||||
quitText string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) getActionCommands(keyHintStyle lipgloss.Style) actionCommands {
|
|
||||||
if shouldUseCompactLayout(m.width, BreakpointSmall) {
|
|
||||||
return actionCommands{
|
|
||||||
commandsText: fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]")),
|
|
||||||
quitText: fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return actionCommands{
|
|
||||||
commandsText: fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]")),
|
|
||||||
quitText: fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]")),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderFooter(isCompact bool) string {
|
|
||||||
if isCompact {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
footerStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorDarkGray)).
|
|
||||||
Italic(true)
|
|
||||||
|
|
||||||
return "\n\n" + footerStyle.Render("Press 'C' to customize your tunnel settings")
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMarginValue(isCompact bool, compactValue, normalValue int) int {
|
|
||||||
if isCompact {
|
|
||||||
return compactValue
|
|
||||||
}
|
|
||||||
return normalValue
|
|
||||||
}
|
|
||||||
@@ -2,8 +2,10 @@ package interaction
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
"strings"
|
||||||
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
"tunnel_pls/internal/random"
|
"tunnel_pls/internal/random"
|
||||||
"tunnel_pls/session/slug"
|
"tunnel_pls/session/slug"
|
||||||
@@ -19,59 +21,37 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Interaction interface {
|
type Lifecycle interface {
|
||||||
Mode() types.InteractiveMode
|
Close() error
|
||||||
SetChannel(channel ssh.Channel)
|
|
||||||
SetMode(m types.InteractiveMode)
|
|
||||||
SetWH(w, h int)
|
|
||||||
Start()
|
|
||||||
Redraw()
|
|
||||||
Send(message string) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionRegistry interface {
|
type Controller interface {
|
||||||
Update(user string, oldKey, newKey types.SessionKey) error
|
SetChannel(channel ssh.Channel)
|
||||||
|
SetLifecycle(lifecycle Lifecycle)
|
||||||
|
SetSlugModificator(func(oldSlug, newSlug string) error)
|
||||||
|
Start()
|
||||||
|
SetWH(w, h int)
|
||||||
|
Redraw()
|
||||||
}
|
}
|
||||||
|
|
||||||
type Forwarder interface {
|
type Forwarder interface {
|
||||||
Close() error
|
Close() error
|
||||||
TunnelType() types.TunnelType
|
GetTunnelType() types.TunnelType
|
||||||
ForwardedPort() uint16
|
GetForwardedPort() uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
type CloseFunc func() error
|
type Interaction struct {
|
||||||
type interaction struct {
|
channel ssh.Channel
|
||||||
randomizer random.Random
|
slugManager slug.Manager
|
||||||
config config.Config
|
forwarder Forwarder
|
||||||
channel ssh.Channel
|
lifecycle Lifecycle
|
||||||
slug slug.Slug
|
updateClientSlug func(oldSlug, newSlug string) error
|
||||||
forwarder Forwarder
|
program *tea.Program
|
||||||
closeFunc CloseFunc
|
ctx context.Context
|
||||||
user string
|
cancel context.CancelFunc
|
||||||
sessionRegistry SessionRegistry
|
|
||||||
program *tea.Program
|
|
||||||
ctx context.Context
|
|
||||||
cancel context.CancelFunc
|
|
||||||
mode types.InteractiveMode
|
|
||||||
programMu sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *interaction) SetMode(m types.InteractiveMode) {
|
func (i *Interaction) SetWH(w, h int) {
|
||||||
i.mode = m
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *interaction) Mode() types.InteractiveMode {
|
|
||||||
return i.mode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *interaction) Send(message string) error {
|
|
||||||
if i.channel != nil {
|
|
||||||
_, err := i.channel.Write([]byte(message))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (i *interaction) SetWH(w, h int) {
|
|
||||||
if i.program != nil {
|
if i.program != nil {
|
||||||
i.program.Send(tea.WindowSizeMsg{
|
i.program.Send(tea.WindowSizeMsg{
|
||||||
Width: w,
|
Width: w,
|
||||||
@@ -80,42 +60,122 @@ func (i *interaction) SetWH(w, h int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(randomizer random.Random, config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
|
type commandItem struct {
|
||||||
|
name string
|
||||||
|
desc string
|
||||||
|
}
|
||||||
|
|
||||||
|
type model struct {
|
||||||
|
domain string
|
||||||
|
protocol string
|
||||||
|
tunnelType types.TunnelType
|
||||||
|
port uint16
|
||||||
|
keymap keymap
|
||||||
|
help help.Model
|
||||||
|
quitting bool
|
||||||
|
showingCommands bool
|
||||||
|
editingSlug bool
|
||||||
|
showingComingSoon bool
|
||||||
|
commandList list.Model
|
||||||
|
slugInput textinput.Model
|
||||||
|
slugError string
|
||||||
|
interaction *Interaction
|
||||||
|
width int
|
||||||
|
height int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) getTunnelURL() string {
|
||||||
|
if m.tunnelType == types.HTTP {
|
||||||
|
return buildURL(m.protocol, m.interaction.slugManager.Get(), m.domain)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
|
||||||
|
}
|
||||||
|
|
||||||
|
type keymap struct {
|
||||||
|
quit key.Binding
|
||||||
|
command key.Binding
|
||||||
|
random key.Binding
|
||||||
|
}
|
||||||
|
|
||||||
|
type tickMsg time.Time
|
||||||
|
|
||||||
|
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &interaction{
|
return &Interaction{
|
||||||
randomizer: randomizer,
|
channel: nil,
|
||||||
config: config,
|
slugManager: slugManager,
|
||||||
channel: nil,
|
forwarder: forwarder,
|
||||||
slug: slug,
|
lifecycle: nil,
|
||||||
forwarder: forwarder,
|
updateClientSlug: nil,
|
||||||
closeFunc: closeFunc,
|
program: nil,
|
||||||
user: user,
|
ctx: ctx,
|
||||||
sessionRegistry: sessionRegistry,
|
cancel: cancel,
|
||||||
program: nil,
|
|
||||||
ctx: ctx,
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *interaction) SetChannel(channel ssh.Channel) {
|
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
|
||||||
|
i.lifecycle = lifecycle
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interaction) SetChannel(channel ssh.Channel) {
|
||||||
i.channel = channel
|
i.channel = channel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *interaction) Stop() {
|
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) error) {
|
||||||
|
i.updateClientSlug = modificator
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Interaction) Stop() {
|
||||||
if i.cancel != nil {
|
if i.cancel != nil {
|
||||||
i.cancel()
|
i.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
i.programMu.Lock()
|
|
||||||
defer i.programMu.Unlock()
|
|
||||||
|
|
||||||
if i.program != nil {
|
if i.program != nil {
|
||||||
i.program.Kill()
|
i.program.Kill()
|
||||||
i.program = nil
|
i.program = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getResponsiveWidth(screenWidth, padding, minWidth, maxWidth int) int {
|
||||||
|
width := screenWidth - padding
|
||||||
|
if width > maxWidth {
|
||||||
|
width = maxWidth
|
||||||
|
}
|
||||||
|
if width < minWidth {
|
||||||
|
width = minWidth
|
||||||
|
}
|
||||||
|
return width
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldUseCompactLayout(width int, threshold int) bool {
|
||||||
|
return width < threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateString(s string, maxLength int) string {
|
||||||
|
if len(s) <= maxLength {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
if maxLength < 4 {
|
||||||
|
return s[:maxLength]
|
||||||
|
}
|
||||||
|
return s[:maxLength-3] + "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i commandItem) FilterValue() string { return i.name }
|
||||||
|
func (i commandItem) Title() string { return i.name }
|
||||||
|
func (i commandItem) Description() string { return i.desc }
|
||||||
|
|
||||||
|
func tickCmd(d time.Duration) tea.Cmd {
|
||||||
|
return tea.Tick(d, func(t time.Time) tea.Msg {
|
||||||
|
return tickMsg(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *model) Init() tea.Cmd {
|
||||||
|
return tea.Batch(textinput.Blink, tea.WindowSize())
|
||||||
|
}
|
||||||
|
|
||||||
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||||
|
var cmd tea.Cmd
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case tickMsg:
|
case tickMsg:
|
||||||
@@ -141,62 +201,543 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
|
|
||||||
case tea.KeyMsg:
|
case tea.KeyMsg:
|
||||||
if m.showingComingSoon {
|
if m.showingComingSoon {
|
||||||
return m.comingSoonUpdate(msg)
|
m.showingComingSoon = false
|
||||||
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.editingSlug {
|
if m.editingSlug {
|
||||||
return m.slugUpdate(msg)
|
if m.tunnelType != types.HTTP {
|
||||||
|
m.editingSlug = false
|
||||||
|
m.slugError = ""
|
||||||
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
|
}
|
||||||
|
switch msg.String() {
|
||||||
|
case "esc":
|
||||||
|
m.editingSlug = false
|
||||||
|
m.slugError = ""
|
||||||
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
|
case "enter":
|
||||||
|
inputValue := m.slugInput.Value()
|
||||||
|
if err := m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue); err != nil {
|
||||||
|
m.slugError = err.Error()
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
m.editingSlug = false
|
||||||
|
m.slugError = ""
|
||||||
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
|
case "ctrl+c":
|
||||||
|
m.editingSlug = false
|
||||||
|
m.slugError = ""
|
||||||
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
|
default:
|
||||||
|
if key.Matches(msg, m.keymap.random) {
|
||||||
|
newSubdomain := generateRandomSubdomain()
|
||||||
|
m.slugInput.SetValue(newSubdomain)
|
||||||
|
m.slugError = ""
|
||||||
|
m.slugInput, cmd = m.slugInput.Update(msg)
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
|
m.slugError = ""
|
||||||
|
m.slugInput, cmd = m.slugInput.Update(msg)
|
||||||
|
return m, cmd
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.showingCommands {
|
if m.showingCommands {
|
||||||
return m.commandsUpdate(msg)
|
switch {
|
||||||
|
case key.Matches(msg, m.keymap.quit):
|
||||||
|
m.showingCommands = false
|
||||||
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
|
case msg.String() == "enter":
|
||||||
|
selectedItem := m.commandList.SelectedItem()
|
||||||
|
if selectedItem != nil {
|
||||||
|
item := selectedItem.(commandItem)
|
||||||
|
if item.name == "slug" {
|
||||||
|
m.showingCommands = false
|
||||||
|
m.editingSlug = true
|
||||||
|
m.slugInput.SetValue(m.interaction.slugManager.Get())
|
||||||
|
m.slugInput.Focus()
|
||||||
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
|
} else if item.name == "tunnel-type" {
|
||||||
|
m.showingCommands = false
|
||||||
|
m.showingComingSoon = true
|
||||||
|
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
|
||||||
|
}
|
||||||
|
m.showingCommands = false
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
case msg.String() == "esc":
|
||||||
|
m.showingCommands = false
|
||||||
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
|
}
|
||||||
|
m.commandList, cmd = m.commandList.Update(msg)
|
||||||
|
return m, cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.dashboardUpdate(msg)
|
switch {
|
||||||
|
case key.Matches(msg, m.keymap.quit):
|
||||||
|
m.quitting = true
|
||||||
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink, tea.Quit)
|
||||||
|
case key.Matches(msg, m.keymap.command):
|
||||||
|
m.showingCommands = true
|
||||||
|
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *interaction) Redraw() {
|
func (i *Interaction) Redraw() {
|
||||||
if i.program != nil {
|
if i.program != nil {
|
||||||
i.program.Send(tea.ClearScreen())
|
i.program.Send(tea.ClearScreen())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *model) helpView() string {
|
||||||
|
return "\n" + m.help.ShortHelpView([]key.Binding{
|
||||||
|
m.keymap.command,
|
||||||
|
m.keymap.quit,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (m *model) View() string {
|
func (m *model) View() string {
|
||||||
if m.quitting {
|
if m.quitting {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.showingComingSoon {
|
if m.showingComingSoon {
|
||||||
return m.comingSoonView()
|
isCompact := shouldUseCompactLayout(m.width, 60)
|
||||||
|
|
||||||
|
var boxPadding int
|
||||||
|
var boxMargin int
|
||||||
|
if isCompact {
|
||||||
|
boxPadding = 1
|
||||||
|
boxMargin = 1
|
||||||
|
} else {
|
||||||
|
boxPadding = 3
|
||||||
|
boxMargin = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
titleStyle := lipgloss.NewStyle().
|
||||||
|
Bold(true).
|
||||||
|
Foreground(lipgloss.Color("#7D56F4")).
|
||||||
|
PaddingTop(1).
|
||||||
|
PaddingBottom(1)
|
||||||
|
|
||||||
|
messageBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||||
|
messageBoxStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#FAFAFA")).
|
||||||
|
Background(lipgloss.Color("#1A1A2E")).
|
||||||
|
Bold(true).
|
||||||
|
Border(lipgloss.RoundedBorder()).
|
||||||
|
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||||
|
Padding(1, boxPadding).
|
||||||
|
MarginTop(boxMargin).
|
||||||
|
MarginBottom(boxMargin).
|
||||||
|
Width(messageBoxWidth).
|
||||||
|
Align(lipgloss.Center)
|
||||||
|
|
||||||
|
helpStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#666666")).
|
||||||
|
Italic(true).
|
||||||
|
MarginTop(1)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
var title string
|
||||||
|
if shouldUseCompactLayout(m.width, 40) {
|
||||||
|
title = "Coming Soon"
|
||||||
|
} else {
|
||||||
|
title = "⏳ Coming Soon"
|
||||||
|
}
|
||||||
|
b.WriteString(titleStyle.Render(title))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
var message string
|
||||||
|
if shouldUseCompactLayout(m.width, 50) {
|
||||||
|
message = "Coming soon!\nStay tuned."
|
||||||
|
} else {
|
||||||
|
message = "🚀 This feature is coming very soon!\n Stay tuned for updates."
|
||||||
|
}
|
||||||
|
b.WriteString(messageBoxStyle.Render(message))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
var helpText string
|
||||||
|
if shouldUseCompactLayout(m.width, 60) {
|
||||||
|
helpText = "Press any key..."
|
||||||
|
} else {
|
||||||
|
helpText = "This message will disappear in 5 seconds or press any key..."
|
||||||
|
}
|
||||||
|
b.WriteString(helpStyle.Render(helpText))
|
||||||
|
|
||||||
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.editingSlug {
|
if m.editingSlug {
|
||||||
return m.slugView()
|
isCompact := shouldUseCompactLayout(m.width, 70)
|
||||||
|
isVeryCompact := shouldUseCompactLayout(m.width, 50)
|
||||||
|
|
||||||
|
var boxPadding int
|
||||||
|
var boxMargin int
|
||||||
|
if isVeryCompact {
|
||||||
|
boxPadding = 1
|
||||||
|
boxMargin = 1
|
||||||
|
} else if isCompact {
|
||||||
|
boxPadding = 1
|
||||||
|
boxMargin = 1
|
||||||
|
} else {
|
||||||
|
boxPadding = 2
|
||||||
|
boxMargin = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
titleStyle := lipgloss.NewStyle().
|
||||||
|
Bold(true).
|
||||||
|
Foreground(lipgloss.Color("#7D56F4")).
|
||||||
|
PaddingTop(1).
|
||||||
|
PaddingBottom(1)
|
||||||
|
|
||||||
|
instructionStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#FAFAFA")).
|
||||||
|
MarginTop(1)
|
||||||
|
|
||||||
|
inputBoxStyle := lipgloss.NewStyle().
|
||||||
|
Border(lipgloss.RoundedBorder()).
|
||||||
|
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||||
|
Padding(1, boxPadding).
|
||||||
|
MarginTop(boxMargin).
|
||||||
|
MarginBottom(boxMargin)
|
||||||
|
|
||||||
|
helpStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#666666")).
|
||||||
|
Italic(true).
|
||||||
|
MarginTop(1)
|
||||||
|
|
||||||
|
errorBoxStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#FF0000")).
|
||||||
|
Background(lipgloss.Color("#3D0000")).
|
||||||
|
Bold(true).
|
||||||
|
Border(lipgloss.RoundedBorder()).
|
||||||
|
BorderForeground(lipgloss.Color("#FF0000")).
|
||||||
|
Padding(0, boxPadding).
|
||||||
|
MarginTop(1).
|
||||||
|
MarginBottom(1)
|
||||||
|
|
||||||
|
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||||
|
rulesBoxStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#FAFAFA")).
|
||||||
|
Border(lipgloss.RoundedBorder()).
|
||||||
|
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||||
|
Padding(0, boxPadding).
|
||||||
|
MarginTop(1).
|
||||||
|
MarginBottom(1).
|
||||||
|
Width(rulesBoxWidth)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
var title string
|
||||||
|
if isVeryCompact {
|
||||||
|
title = "Edit Subdomain"
|
||||||
|
} else {
|
||||||
|
title = "🔧 Edit Subdomain"
|
||||||
|
}
|
||||||
|
b.WriteString(titleStyle.Render(title))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
if m.tunnelType != types.HTTP {
|
||||||
|
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||||
|
warningBoxStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#FFA500")).
|
||||||
|
Background(lipgloss.Color("#3D2000")).
|
||||||
|
Bold(true).
|
||||||
|
Border(lipgloss.RoundedBorder()).
|
||||||
|
BorderForeground(lipgloss.Color("#FFA500")).
|
||||||
|
Padding(1, boxPadding).
|
||||||
|
MarginTop(boxMargin).
|
||||||
|
MarginBottom(boxMargin).
|
||||||
|
Width(warningBoxWidth)
|
||||||
|
|
||||||
|
var warningText string
|
||||||
|
if isVeryCompact {
|
||||||
|
warningText = "⚠️ TCP tunnels don't support custom subdomains."
|
||||||
|
} else {
|
||||||
|
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
|
||||||
|
}
|
||||||
|
b.WriteString(warningBoxStyle.Render(warningText))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
var helpText string
|
||||||
|
if isVeryCompact {
|
||||||
|
helpText = "Press any key to go back"
|
||||||
|
} else {
|
||||||
|
helpText = "Press Enter or Esc to go back"
|
||||||
|
}
|
||||||
|
b.WriteString(helpStyle.Render(helpText))
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
var rulesContent string
|
||||||
|
if isVeryCompact {
|
||||||
|
rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
|
||||||
|
} else if isCompact {
|
||||||
|
rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
|
||||||
|
} else {
|
||||||
|
rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
|
||||||
|
}
|
||||||
|
b.WriteString(rulesBoxStyle.Render(rulesContent))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
var instruction string
|
||||||
|
if isVeryCompact {
|
||||||
|
instruction = "Custom subdomain:"
|
||||||
|
} else {
|
||||||
|
instruction = "Enter your custom subdomain:"
|
||||||
|
}
|
||||||
|
b.WriteString(instructionStyle.Render(instruction))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
if m.slugError != "" {
|
||||||
|
errorInputBoxStyle := lipgloss.NewStyle().
|
||||||
|
Border(lipgloss.RoundedBorder()).
|
||||||
|
BorderForeground(lipgloss.Color("#FF0000")).
|
||||||
|
Padding(1, boxPadding).
|
||||||
|
MarginTop(boxMargin).
|
||||||
|
MarginBottom(1)
|
||||||
|
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
|
||||||
|
b.WriteString("\n")
|
||||||
|
} else {
|
||||||
|
b.WriteString(inputBoxStyle.Render(m.slugInput.View()))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
|
||||||
|
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
|
||||||
|
|
||||||
|
if len(previewURL) > previewWidth-10 {
|
||||||
|
previewURL = truncateString(previewURL, previewWidth-10)
|
||||||
|
}
|
||||||
|
|
||||||
|
previewStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#04B575")).
|
||||||
|
Italic(true).
|
||||||
|
Width(previewWidth)
|
||||||
|
b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
var helpText string
|
||||||
|
if isVeryCompact {
|
||||||
|
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
|
||||||
|
} else {
|
||||||
|
helpText = "Press Enter to save • CTRL+R for random • Esc to cancel"
|
||||||
|
}
|
||||||
|
b.WriteString(helpStyle.Render(helpText))
|
||||||
|
|
||||||
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.showingCommands {
|
if m.showingCommands {
|
||||||
return m.commandsView()
|
isCompact := shouldUseCompactLayout(m.width, 60)
|
||||||
|
|
||||||
|
titleStyle := lipgloss.NewStyle().
|
||||||
|
Bold(true).
|
||||||
|
Foreground(lipgloss.Color("#7D56F4")).
|
||||||
|
PaddingTop(1).
|
||||||
|
PaddingBottom(1)
|
||||||
|
|
||||||
|
helpStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#666666")).
|
||||||
|
Italic(true).
|
||||||
|
MarginTop(1)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
var title string
|
||||||
|
if shouldUseCompactLayout(m.width, 40) {
|
||||||
|
title = "Commands"
|
||||||
|
} else {
|
||||||
|
title = "⚡ Commands"
|
||||||
|
}
|
||||||
|
b.WriteString(titleStyle.Render(title))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(m.commandList.View())
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
var helpText string
|
||||||
|
if isCompact {
|
||||||
|
helpText = "↑/↓ Nav • Enter Select • Esc Cancel"
|
||||||
|
} else {
|
||||||
|
helpText = "↑/↓ Navigate • Enter Select • Esc Cancel"
|
||||||
|
}
|
||||||
|
b.WriteString(helpStyle.Render(helpText))
|
||||||
|
|
||||||
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.dashboardView()
|
titleStyle := lipgloss.NewStyle().
|
||||||
|
Bold(true).
|
||||||
|
Foreground(lipgloss.Color("#7D56F4")).
|
||||||
|
PaddingTop(1)
|
||||||
|
|
||||||
|
subtitleStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#888888")).
|
||||||
|
Italic(true)
|
||||||
|
|
||||||
|
urlStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#7D56F4")).
|
||||||
|
Underline(true).
|
||||||
|
Italic(true)
|
||||||
|
|
||||||
|
urlBoxStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#04B575")).
|
||||||
|
Bold(true).
|
||||||
|
Italic(true)
|
||||||
|
|
||||||
|
keyHintStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#7D56F4")).
|
||||||
|
Bold(true)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
isCompact := shouldUseCompactLayout(m.width, 85)
|
||||||
|
|
||||||
|
var asciiArtMargin int
|
||||||
|
if isCompact {
|
||||||
|
asciiArtMargin = 0
|
||||||
|
} else {
|
||||||
|
asciiArtMargin = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
asciiArtStyle := lipgloss.NewStyle().
|
||||||
|
Bold(true).
|
||||||
|
Foreground(lipgloss.Color("#7D56F4")).
|
||||||
|
MarginBottom(asciiArtMargin)
|
||||||
|
|
||||||
|
var asciiArt string
|
||||||
|
if shouldUseCompactLayout(m.width, 50) {
|
||||||
|
asciiArt = "TUNNEL PLS"
|
||||||
|
} else if isCompact {
|
||||||
|
asciiArt = `
|
||||||
|
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
|
||||||
|
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
|
||||||
|
} else {
|
||||||
|
asciiArt = `
|
||||||
|
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
|
||||||
|
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
|
||||||
|
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
|
||||||
|
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
|
||||||
|
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
|
||||||
|
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(asciiArtStyle.Render(asciiArt))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
if !shouldUseCompactLayout(m.width, 60) {
|
||||||
|
b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • "))
|
||||||
|
b.WriteString(urlStyle.Render("https://fossy.my.id"))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
} else {
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
|
||||||
|
var boxPadding int
|
||||||
|
var boxMargin int
|
||||||
|
if isCompact {
|
||||||
|
boxPadding = 1
|
||||||
|
boxMargin = 1
|
||||||
|
} else {
|
||||||
|
boxPadding = 2
|
||||||
|
boxMargin = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
responsiveInfoBox := lipgloss.NewStyle().
|
||||||
|
Border(lipgloss.RoundedBorder()).
|
||||||
|
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||||
|
Padding(1, boxPadding).
|
||||||
|
MarginTop(boxMargin).
|
||||||
|
MarginBottom(boxMargin).
|
||||||
|
Width(boxMaxWidth)
|
||||||
|
|
||||||
|
urlDisplay := m.getTunnelURL()
|
||||||
|
if shouldUseCompactLayout(m.width, 80) && len(urlDisplay) > m.width-20 {
|
||||||
|
maxLen := m.width - 25
|
||||||
|
if maxLen > 10 {
|
||||||
|
urlDisplay = truncateString(urlDisplay, maxLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var infoContent string
|
||||||
|
if shouldUseCompactLayout(m.width, 70) {
|
||||||
|
infoContent = fmt.Sprintf("🌐 %s", urlBoxStyle.Render(urlDisplay))
|
||||||
|
} else if isCompact {
|
||||||
|
infoContent = fmt.Sprintf("🌐 Forwarding to:\n\n %s", urlBoxStyle.Render(urlDisplay))
|
||||||
|
} else {
|
||||||
|
infoContent = fmt.Sprintf("🌐 F O R W A R D I N G T O:\n\n %s", urlBoxStyle.Render(urlDisplay))
|
||||||
|
}
|
||||||
|
b.WriteString(responsiveInfoBox.Render(infoContent))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
var quickActionsTitle string
|
||||||
|
if shouldUseCompactLayout(m.width, 50) {
|
||||||
|
quickActionsTitle = "Actions"
|
||||||
|
} else if isCompact {
|
||||||
|
quickActionsTitle = "Quick Actions"
|
||||||
|
} else {
|
||||||
|
quickActionsTitle = "✨ Quick Actions"
|
||||||
|
}
|
||||||
|
b.WriteString(titleStyle.Render(quickActionsTitle))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
var featureMargin int
|
||||||
|
if isCompact {
|
||||||
|
featureMargin = 1
|
||||||
|
} else {
|
||||||
|
featureMargin = 2
|
||||||
|
}
|
||||||
|
|
||||||
|
compactFeatureStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#FAFAFA")).
|
||||||
|
MarginLeft(featureMargin)
|
||||||
|
|
||||||
|
var commandsText string
|
||||||
|
var quitText string
|
||||||
|
if shouldUseCompactLayout(m.width, 60) {
|
||||||
|
commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]"))
|
||||||
|
quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]"))
|
||||||
|
} else {
|
||||||
|
commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]"))
|
||||||
|
quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]"))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(compactFeatureStyle.Render(commandsText))
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString(compactFeatureStyle.Render(quitText))
|
||||||
|
|
||||||
|
if !shouldUseCompactLayout(m.width, 70) {
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
footerStyle := lipgloss.NewStyle().
|
||||||
|
Foreground(lipgloss.Color("#666666")).
|
||||||
|
Italic(true)
|
||||||
|
b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *interaction) Start() {
|
func (i *Interaction) Start() {
|
||||||
if i.mode == types.InteractiveModeHEADLESS {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
lipgloss.SetColorProfile(termenv.TrueColor)
|
lipgloss.SetColorProfile(termenv.TrueColor)
|
||||||
|
|
||||||
|
domain := config.Getenv("DOMAIN", "localhost")
|
||||||
protocol := "http"
|
protocol := "http"
|
||||||
if i.config.TLSEnabled() {
|
if config.Getenv("TLS_ENABLED", "false") == "true" {
|
||||||
protocol = "https"
|
protocol = "https"
|
||||||
}
|
}
|
||||||
|
|
||||||
tunnelType := i.forwarder.TunnelType()
|
tunnelType := i.forwarder.GetTunnelType()
|
||||||
port := i.forwarder.ForwardedPort()
|
port := i.forwarder.GetForwardedPort()
|
||||||
|
|
||||||
items := []list.Item{
|
items := []list.Item{
|
||||||
commandItem{name: "slug", desc: "Set custom subdomain"},
|
commandItem{name: "slug", desc: "Set custom subdomain"},
|
||||||
@@ -219,8 +760,7 @@ func (i *interaction) Start() {
|
|||||||
ti.Width = 50
|
ti.Width = 50
|
||||||
|
|
||||||
m := &model{
|
m := &model{
|
||||||
randomizer: i.randomizer,
|
domain: domain,
|
||||||
domain: i.config.Domain(),
|
|
||||||
protocol: protocol,
|
protocol: protocol,
|
||||||
tunnelType: tunnelType,
|
tunnelType: tunnelType,
|
||||||
port: port,
|
port: port,
|
||||||
@@ -244,7 +784,6 @@ func (i *interaction) Start() {
|
|||||||
help: help.New(),
|
help: help.New(),
|
||||||
}
|
}
|
||||||
|
|
||||||
i.programMu.Lock()
|
|
||||||
i.program = tea.NewProgram(
|
i.program = tea.NewProgram(
|
||||||
m,
|
m,
|
||||||
tea.WithInput(i.channel),
|
tea.WithInput(i.channel),
|
||||||
@@ -255,21 +794,22 @@ func (i *interaction) Start() {
|
|||||||
tea.WithoutSignalHandler(),
|
tea.WithoutSignalHandler(),
|
||||||
tea.WithFPS(30),
|
tea.WithFPS(30),
|
||||||
)
|
)
|
||||||
i.programMu.Unlock()
|
|
||||||
|
|
||||||
_, err := i.program.Run()
|
_, err := i.program.Run()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Cannot close tea: %s \n", err)
|
log.Printf("Cannot close tea: %s \n", err)
|
||||||
}
|
}
|
||||||
|
i.program.Kill()
|
||||||
i.programMu.Lock()
|
i.program = nil
|
||||||
if i.program != nil {
|
if err := m.interaction.lifecycle.Close(); err != nil {
|
||||||
i.program.Kill()
|
log.Printf("Cannot close session: %s \n", err)
|
||||||
i.program = nil
|
|
||||||
}
|
|
||||||
i.programMu.Unlock()
|
|
||||||
|
|
||||||
if i.closeFunc != nil {
|
|
||||||
_ = i.closeFunc()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildURL(protocol, subdomain, domain string) string {
|
||||||
|
return fmt.Sprintf("%s://%s.%s", protocol, subdomain, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateRandomSubdomain() string {
|
||||||
|
return random.GenerateRandomString(20)
|
||||||
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,116 +0,0 @@
|
|||||||
package interaction
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
"tunnel_pls/internal/random"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"github.com/charmbracelet/bubbles/help"
|
|
||||||
"github.com/charmbracelet/bubbles/key"
|
|
||||||
"github.com/charmbracelet/bubbles/list"
|
|
||||||
"github.com/charmbracelet/bubbles/textinput"
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
|
||||||
)
|
|
||||||
|
|
||||||
type commandItem struct {
|
|
||||||
name string
|
|
||||||
desc string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i commandItem) FilterValue() string { return i.name }
|
|
||||||
func (i commandItem) Title() string { return i.name }
|
|
||||||
func (i commandItem) Description() string { return i.desc }
|
|
||||||
|
|
||||||
type model struct {
|
|
||||||
randomizer random.Random
|
|
||||||
domain string
|
|
||||||
protocol string
|
|
||||||
tunnelType types.TunnelType
|
|
||||||
port uint16
|
|
||||||
keymap keymap
|
|
||||||
help help.Model
|
|
||||||
quitting bool
|
|
||||||
showingCommands bool
|
|
||||||
editingSlug bool
|
|
||||||
showingComingSoon bool
|
|
||||||
commandList list.Model
|
|
||||||
slugInput textinput.Model
|
|
||||||
slugError string
|
|
||||||
interaction *interaction
|
|
||||||
width int
|
|
||||||
height int
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
ColorPrimary = "#7D56F4"
|
|
||||||
ColorSecondary = "#04B575"
|
|
||||||
ColorGray = "#888888"
|
|
||||||
ColorDarkGray = "#666666"
|
|
||||||
ColorWhite = "#FAFAFA"
|
|
||||||
ColorError = "#FF0000"
|
|
||||||
ColorErrorBg = "#3D0000"
|
|
||||||
ColorWarning = "#FFA500"
|
|
||||||
ColorWarningBg = "#3D2000"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
BreakpointTiny = 50
|
|
||||||
BreakpointSmall = 60
|
|
||||||
BreakpointMedium = 70
|
|
||||||
BreakpointLarge = 85
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m *model) getTunnelURL() string {
|
|
||||||
if m.tunnelType == types.TunnelTypeHTTP {
|
|
||||||
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
|
|
||||||
}
|
|
||||||
|
|
||||||
type keymap struct {
|
|
||||||
quit key.Binding
|
|
||||||
command key.Binding
|
|
||||||
random key.Binding
|
|
||||||
}
|
|
||||||
|
|
||||||
type tickMsg time.Time
|
|
||||||
|
|
||||||
func (m *model) Init() tea.Cmd {
|
|
||||||
return tea.Batch(textinput.Blink, tea.WindowSize())
|
|
||||||
}
|
|
||||||
|
|
||||||
func getResponsiveWidth(screenWidth, padding, minWidth, maxWidth int) int {
|
|
||||||
width := screenWidth - padding
|
|
||||||
if width > maxWidth {
|
|
||||||
width = maxWidth
|
|
||||||
}
|
|
||||||
if width < minWidth {
|
|
||||||
width = minWidth
|
|
||||||
}
|
|
||||||
return width
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldUseCompactLayout(width int, threshold int) bool {
|
|
||||||
return width < threshold
|
|
||||||
}
|
|
||||||
|
|
||||||
func truncateString(s string, maxLength int) string {
|
|
||||||
if len(s) <= maxLength {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
if maxLength < 4 {
|
|
||||||
return s[:maxLength]
|
|
||||||
}
|
|
||||||
return s[:maxLength-3] + "..."
|
|
||||||
}
|
|
||||||
|
|
||||||
func tickCmd(d time.Duration) tea.Cmd {
|
|
||||||
return tea.Tick(d, func(t time.Time) tea.Msg {
|
|
||||||
return tickMsg(t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildURL(protocol, subdomain, domain string) string {
|
|
||||||
return fmt.Sprintf("%s://%s.%s", protocol, subdomain, domain)
|
|
||||||
}
|
|
||||||
@@ -1,265 +0,0 @@
|
|||||||
package interaction
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"github.com/charmbracelet/bubbles/key"
|
|
||||||
"github.com/charmbracelet/bubbles/textinput"
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
|
||||||
"github.com/charmbracelet/lipgloss"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
|
||||||
var cmd tea.Cmd
|
|
||||||
|
|
||||||
if m.tunnelType != types.TunnelTypeHTTP {
|
|
||||||
m.editingSlug = false
|
|
||||||
m.slugError = ""
|
|
||||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch msg.String() {
|
|
||||||
case "esc", "ctrl+c":
|
|
||||||
m.editingSlug = false
|
|
||||||
m.slugError = ""
|
|
||||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
|
||||||
case "enter":
|
|
||||||
inputValue := m.slugInput.Value()
|
|
||||||
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
|
|
||||||
Id: m.interaction.slug.String(),
|
|
||||||
Type: types.TunnelTypeHTTP,
|
|
||||||
}, types.SessionKey{
|
|
||||||
Id: inputValue,
|
|
||||||
Type: types.TunnelTypeHTTP,
|
|
||||||
}); err != nil {
|
|
||||||
m.slugError = err.Error()
|
|
||||||
return m, nil
|
|
||||||
}
|
|
||||||
m.editingSlug = false
|
|
||||||
m.slugError = ""
|
|
||||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
|
||||||
default:
|
|
||||||
if key.Matches(msg, m.keymap.random) {
|
|
||||||
newSubdomain, err := m.randomizer.String(20)
|
|
||||||
if err != nil {
|
|
||||||
return m, cmd
|
|
||||||
}
|
|
||||||
m.slugInput.SetValue(newSubdomain)
|
|
||||||
}
|
|
||||||
m.slugError = ""
|
|
||||||
m.slugInput, cmd = m.slugInput.Update(msg)
|
|
||||||
return m, cmd
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) slugView() string {
|
|
||||||
isCompact := shouldUseCompactLayout(m.width, BreakpointMedium)
|
|
||||||
isVeryCompact := shouldUseCompactLayout(m.width, BreakpointTiny)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString(m.renderSlugTitle(isVeryCompact))
|
|
||||||
|
|
||||||
if m.tunnelType != types.TunnelTypeHTTP {
|
|
||||||
b.WriteString(m.renderTCPWarning(isVeryCompact, isCompact))
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
b.WriteString(m.renderSlugRules(isVeryCompact, isCompact))
|
|
||||||
b.WriteString(m.renderSlugInstruction(isVeryCompact))
|
|
||||||
b.WriteString(m.renderSlugInput(isVeryCompact, isCompact))
|
|
||||||
b.WriteString(m.renderSlugPreview(isVeryCompact))
|
|
||||||
b.WriteString(m.renderSlugHelp(isVeryCompact))
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderSlugTitle(isVeryCompact bool) string {
|
|
||||||
titleStyle := lipgloss.NewStyle().
|
|
||||||
Bold(true).
|
|
||||||
Foreground(lipgloss.Color(ColorPrimary)).
|
|
||||||
PaddingTop(1).
|
|
||||||
PaddingBottom(1)
|
|
||||||
|
|
||||||
title := "🔧 Edit Subdomain"
|
|
||||||
if isVeryCompact {
|
|
||||||
title = "Edit Subdomain"
|
|
||||||
}
|
|
||||||
|
|
||||||
return titleStyle.Render(title) + "\n\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderTCPWarning(isVeryCompact, isCompact bool) string {
|
|
||||||
boxPadding := getPaddingValue(isVeryCompact, isCompact)
|
|
||||||
boxMargin := getMarginValue(isCompact, 1, 2)
|
|
||||||
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
|
||||||
|
|
||||||
warningBoxStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorWarning)).
|
|
||||||
Background(lipgloss.Color(ColorWarningBg)).
|
|
||||||
Bold(true).
|
|
||||||
Border(lipgloss.RoundedBorder()).
|
|
||||||
BorderForeground(lipgloss.Color(ColorWarning)).
|
|
||||||
Padding(1, boxPadding).
|
|
||||||
MarginTop(boxMargin).
|
|
||||||
MarginBottom(boxMargin).
|
|
||||||
Width(warningBoxWidth)
|
|
||||||
|
|
||||||
helpStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorDarkGray)).
|
|
||||||
Italic(true).
|
|
||||||
MarginTop(1)
|
|
||||||
|
|
||||||
warningText := m.getTCPWarningText(isVeryCompact)
|
|
||||||
helpText := m.getTCPHelpText(isVeryCompact)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString(warningBoxStyle.Render(warningText))
|
|
||||||
b.WriteString("\n\n")
|
|
||||||
b.WriteString(helpStyle.Render(helpText))
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) getTCPWarningText(isVeryCompact bool) string {
|
|
||||||
if isVeryCompact {
|
|
||||||
return "⚠️ TCP tunnels don't support custom subdomains."
|
|
||||||
}
|
|
||||||
return "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) getTCPHelpText(isVeryCompact bool) string {
|
|
||||||
if isVeryCompact {
|
|
||||||
return "Press any key to go back"
|
|
||||||
}
|
|
||||||
return "Press Enter or Esc to go back"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderSlugRules(isVeryCompact, isCompact bool) string {
|
|
||||||
boxPadding := getPaddingValue(isVeryCompact, isCompact)
|
|
||||||
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
|
||||||
|
|
||||||
rulesBoxStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorWhite)).
|
|
||||||
Border(lipgloss.RoundedBorder()).
|
|
||||||
BorderForeground(lipgloss.Color(ColorPrimary)).
|
|
||||||
Padding(0, boxPadding).
|
|
||||||
MarginTop(1).
|
|
||||||
MarginBottom(1).
|
|
||||||
Width(rulesBoxWidth)
|
|
||||||
|
|
||||||
rulesContent := m.getRulesContent(isVeryCompact, isCompact)
|
|
||||||
return rulesBoxStyle.Render(rulesContent) + "\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) getRulesContent(isVeryCompact, isCompact bool) string {
|
|
||||||
if isVeryCompact {
|
|
||||||
return "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
|
|
||||||
}
|
|
||||||
|
|
||||||
if isCompact {
|
|
||||||
return "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
|
|
||||||
}
|
|
||||||
|
|
||||||
return "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderSlugInstruction(isVeryCompact bool) string {
|
|
||||||
instructionStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorWhite)).
|
|
||||||
MarginTop(1)
|
|
||||||
|
|
||||||
instruction := "Enter your custom subdomain:"
|
|
||||||
if isVeryCompact {
|
|
||||||
instruction = "Custom subdomain:"
|
|
||||||
}
|
|
||||||
|
|
||||||
return instructionStyle.Render(instruction) + "\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderSlugInput(isVeryCompact, isCompact bool) string {
|
|
||||||
boxPadding := getPaddingValue(isVeryCompact, isCompact)
|
|
||||||
boxMargin := getMarginValue(isCompact, 1, 2)
|
|
||||||
|
|
||||||
if m.slugError != "" {
|
|
||||||
return m.renderErrorInput(boxPadding, boxMargin)
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.renderNormalInput(boxPadding, boxMargin)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderErrorInput(boxPadding, boxMargin int) string {
|
|
||||||
errorInputBoxStyle := lipgloss.NewStyle().
|
|
||||||
Border(lipgloss.RoundedBorder()).
|
|
||||||
BorderForeground(lipgloss.Color(ColorError)).
|
|
||||||
Padding(1, boxPadding).
|
|
||||||
MarginTop(boxMargin).
|
|
||||||
MarginBottom(1)
|
|
||||||
|
|
||||||
errorBoxStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorError)).
|
|
||||||
Background(lipgloss.Color(ColorErrorBg)).
|
|
||||||
Bold(true).
|
|
||||||
Border(lipgloss.RoundedBorder()).
|
|
||||||
BorderForeground(lipgloss.Color(ColorError)).
|
|
||||||
Padding(0, boxPadding).
|
|
||||||
MarginTop(1).
|
|
||||||
MarginBottom(1)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
|
|
||||||
b.WriteString("\n")
|
|
||||||
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
|
|
||||||
b.WriteString("\n")
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderNormalInput(boxPadding, boxMargin int) string {
|
|
||||||
inputBoxStyle := lipgloss.NewStyle().
|
|
||||||
Border(lipgloss.RoundedBorder()).
|
|
||||||
BorderForeground(lipgloss.Color(ColorPrimary)).
|
|
||||||
Padding(1, boxPadding).
|
|
||||||
MarginTop(boxMargin).
|
|
||||||
MarginBottom(boxMargin)
|
|
||||||
|
|
||||||
return inputBoxStyle.Render(m.slugInput.View()) + "\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderSlugPreview(isVeryCompact bool) string {
|
|
||||||
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
|
|
||||||
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
|
|
||||||
|
|
||||||
if isVeryCompact {
|
|
||||||
previewURL = truncateString(previewURL, previewWidth-10)
|
|
||||||
}
|
|
||||||
|
|
||||||
previewStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorSecondary)).
|
|
||||||
Italic(true).
|
|
||||||
Width(previewWidth)
|
|
||||||
|
|
||||||
return previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)) + "\n"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *model) renderSlugHelp(isVeryCompact bool) string {
|
|
||||||
helpStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(lipgloss.Color(ColorDarkGray)).
|
|
||||||
Italic(true).
|
|
||||||
MarginTop(1)
|
|
||||||
|
|
||||||
helpText := "Press Enter to save • CTRL+R for random • Esc to cancel"
|
|
||||||
if isVeryCompact {
|
|
||||||
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
|
|
||||||
}
|
|
||||||
|
|
||||||
return helpStyle.Render(helpText)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPaddingValue(isVeryCompact, isCompact bool) int {
|
|
||||||
if isVeryCompact || isCompact {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
portUtil "tunnel_pls/internal/port"
|
portUtil "tunnel_pls/internal/port"
|
||||||
@@ -16,132 +15,103 @@ import (
|
|||||||
|
|
||||||
type Forwarder interface {
|
type Forwarder interface {
|
||||||
Close() error
|
Close() error
|
||||||
TunnelType() types.TunnelType
|
GetTunnelType() types.TunnelType
|
||||||
ForwardedPort() uint16
|
GetForwardedPort() uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionRegistry interface {
|
type Lifecycle struct {
|
||||||
Remove(key types.SessionKey)
|
status types.Status
|
||||||
|
conn ssh.Conn
|
||||||
|
channel ssh.Channel
|
||||||
|
forwarder Forwarder
|
||||||
|
slugManager slug.Manager
|
||||||
|
unregisterClient func(slug string)
|
||||||
|
startedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type lifecycle struct {
|
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
|
||||||
mu sync.Mutex
|
return &Lifecycle{
|
||||||
status types.SessionStatus
|
status: types.INITIALIZING,
|
||||||
closeErr error
|
conn: conn,
|
||||||
conn ssh.Conn
|
channel: nil,
|
||||||
channel ssh.Channel
|
forwarder: forwarder,
|
||||||
forwarder Forwarder
|
slugManager: slugManager,
|
||||||
slug slug.Slug
|
unregisterClient: nil,
|
||||||
startedAt time.Time
|
startedAt: time.Now(),
|
||||||
sessionRegistry SessionRegistry
|
|
||||||
portRegistry portUtil.Port
|
|
||||||
user string
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
|
|
||||||
return &lifecycle{
|
|
||||||
status: types.SessionStatusINITIALIZING,
|
|
||||||
conn: conn,
|
|
||||||
channel: nil,
|
|
||||||
forwarder: forwarder,
|
|
||||||
slug: slugManager,
|
|
||||||
startedAt: time.Now(),
|
|
||||||
sessionRegistry: sessionRegistry,
|
|
||||||
portRegistry: port,
|
|
||||||
user: user,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Lifecycle interface {
|
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
|
||||||
Connection() ssh.Conn
|
l.unregisterClient = unregisterClient
|
||||||
Channel() ssh.Channel
|
}
|
||||||
PortRegistry() portUtil.Port
|
|
||||||
User() string
|
type SessionLifecycle interface {
|
||||||
|
Close() error
|
||||||
|
SetStatus(status types.Status)
|
||||||
|
GetConnection() ssh.Conn
|
||||||
|
GetChannel() ssh.Channel
|
||||||
SetChannel(channel ssh.Channel)
|
SetChannel(channel ssh.Channel)
|
||||||
SetStatus(status types.SessionStatus)
|
SetUnregisterClient(unregisterClient func(slug string))
|
||||||
IsActive() bool
|
IsActive() bool
|
||||||
StartedAt() time.Time
|
StartedAt() time.Time
|
||||||
Close() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *lifecycle) PortRegistry() portUtil.Port {
|
func (l *Lifecycle) GetChannel() ssh.Channel {
|
||||||
return l.portRegistry
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *lifecycle) User() string {
|
|
||||||
return l.user
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *lifecycle) SetChannel(channel ssh.Channel) {
|
|
||||||
l.channel = channel
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *lifecycle) Channel() ssh.Channel {
|
|
||||||
return l.channel
|
return l.channel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *lifecycle) Connection() ssh.Conn {
|
func (l *Lifecycle) SetChannel(channel ssh.Channel) {
|
||||||
|
l.channel = channel
|
||||||
|
}
|
||||||
|
func (l *Lifecycle) GetConnection() ssh.Conn {
|
||||||
return l.conn
|
return l.conn
|
||||||
}
|
}
|
||||||
|
func (l *Lifecycle) SetStatus(status types.Status) {
|
||||||
func (l *lifecycle) SetStatus(status types.SessionStatus) {
|
|
||||||
l.mu.Lock()
|
|
||||||
defer l.mu.Unlock()
|
|
||||||
l.status = status
|
l.status = status
|
||||||
}
|
if status == types.RUNNING && l.startedAt.IsZero() {
|
||||||
|
l.startedAt = time.Now()
|
||||||
func (l *lifecycle) IsActive() bool {
|
|
||||||
l.mu.Lock()
|
|
||||||
defer l.mu.Unlock()
|
|
||||||
return l.status == types.SessionStatusRUNNING
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *lifecycle) Close() error {
|
|
||||||
l.mu.Lock()
|
|
||||||
defer l.mu.Unlock()
|
|
||||||
if l.status == types.SessionStatusCLOSED {
|
|
||||||
return l.closeErr
|
|
||||||
}
|
}
|
||||||
l.status = types.SessionStatusCLOSED
|
}
|
||||||
|
|
||||||
var errs []error
|
func (l *Lifecycle) Close() error {
|
||||||
tunnelType := l.forwarder.TunnelType()
|
err := l.forwarder.Close()
|
||||||
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if l.channel != nil {
|
if l.channel != nil {
|
||||||
if err := l.channel.Close(); err != nil && !isClosedError(err) {
|
err := l.channel.Close()
|
||||||
errs = append(errs, err)
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if l.conn != nil {
|
if l.conn != nil {
|
||||||
if err := l.conn.Close(); err != nil && !isClosedError(err) {
|
err := l.conn.Close()
|
||||||
errs = append(errs, err)
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
clientSlug := l.slug.String()
|
clientSlug := l.slugManager.Get()
|
||||||
key := types.SessionKey{
|
if clientSlug != "" {
|
||||||
Id: clientSlug,
|
l.unregisterClient(clientSlug)
|
||||||
Type: tunnelType,
|
|
||||||
}
|
|
||||||
l.sessionRegistry.Remove(key)
|
|
||||||
|
|
||||||
if tunnelType == types.TunnelTypeTCP {
|
|
||||||
errs = append(errs, l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false))
|
|
||||||
errs = append(errs, l.forwarder.Close())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
l.closeErr = errors.Join(errs...)
|
if l.forwarder.GetTunnelType() == types.TCP {
|
||||||
return l.closeErr
|
err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func isClosedError(err error) bool {
|
func (l *Lifecycle) IsActive() bool {
|
||||||
if err == nil {
|
return l.status == types.RUNNING
|
||||||
return false
|
|
||||||
}
|
|
||||||
return errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || err.Error() == "EOF"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *lifecycle) StartedAt() time.Time {
|
func (l *Lifecycle) StartedAt() time.Time {
|
||||||
return l.startedAt
|
return l.startedAt
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,303 +0,0 @@
|
|||||||
package lifecycle
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/mock"
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MockSessionRegistry struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSessionRegistry) Remove(key types.SessionKey) {
|
|
||||||
m.Called(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockForwarder struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
|
||||||
args := m.Called(origin)
|
|
||||||
return args.Get(0).([]byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
|
||||||
m.Called(dst, src)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) TunnelType() types.TunnelType {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(types.TunnelType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) ForwardedPort() uint16 {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(uint16)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
|
|
||||||
m.Called(tunnelType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) SetForwardedPort(port uint16) {
|
|
||||||
m.Called(port)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) SetListener(listener net.Listener) {
|
|
||||||
m.Called(listener)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) Listener() net.Listener {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Get(0).(net.Listener)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
|
||||||
args := m.Called(ctx, origin)
|
|
||||||
if args.Get(0) == nil {
|
|
||||||
return nil, nil, args.Error(2)
|
|
||||||
}
|
|
||||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockPort struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
|
||||||
return m.Called(startPort, endPort).Error(0)
|
|
||||||
}
|
|
||||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
|
||||||
args := m.Called()
|
|
||||||
var port uint16
|
|
||||||
if args.Get(0) != nil {
|
|
||||||
switch v := args.Get(0).(type) {
|
|
||||||
case int:
|
|
||||||
port = uint16(v)
|
|
||||||
case uint16:
|
|
||||||
port = v
|
|
||||||
case uint32:
|
|
||||||
port = uint16(v)
|
|
||||||
case int32:
|
|
||||||
port = uint16(v)
|
|
||||||
case float64:
|
|
||||||
port = uint16(v)
|
|
||||||
default:
|
|
||||||
port = uint16(args.Int(0))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return port, args.Bool(1)
|
|
||||||
}
|
|
||||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
|
||||||
return m.Called(port, assigned).Error(0)
|
|
||||||
}
|
|
||||||
func (m *MockPort) Claim(port uint16) bool {
|
|
||||||
return m.Called(port).Bool(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockSlug struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ms *MockSlug) Set(slug string) {
|
|
||||||
ms.Called(slug)
|
|
||||||
}
|
|
||||||
func (ms *MockSlug) String() string {
|
|
||||||
return ms.Called().String(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockSSHConn struct {
|
|
||||||
ssh.Conn
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSSHConn) Close() error {
|
|
||||||
args := m.Called()
|
|
||||||
return args.Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
type MockSSHChannel struct {
|
|
||||||
ssh.Channel
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MockSSHChannel) Close() error {
|
|
||||||
return m.Called().Error(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockForwarder := &MockForwarder{}
|
|
||||||
mockSlug := &MockSlug{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
|
|
||||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
|
||||||
|
|
||||||
assert.NotNil(t, mockLifecycle.Connection())
|
|
||||||
assert.NotNil(t, mockLifecycle.User())
|
|
||||||
assert.NotNil(t, mockLifecycle.PortRegistry())
|
|
||||||
assert.NotNil(t, mockLifecycle.StartedAt())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLifecycle_User(t *testing.T) {
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockForwarder := &MockForwarder{}
|
|
||||||
mockSlug := &MockSlug{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
|
|
||||||
user := "mas-fuad"
|
|
||||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, user)
|
|
||||||
assert.Equal(t, user, mockLifecycle.User())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLifecycle_SetChannel(t *testing.T) {
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockForwarder := &MockForwarder{}
|
|
||||||
mockSlug := &MockSlug{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
|
|
||||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
|
||||||
|
|
||||||
mockSSHChannel := &MockSSHChannel{}
|
|
||||||
|
|
||||||
mockLifecycle.SetChannel(mockSSHChannel)
|
|
||||||
|
|
||||||
assert.Equal(t, mockSSHChannel, mockLifecycle.Channel())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLifecycle_SetStatus(t *testing.T) {
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockForwarder := &MockForwarder{}
|
|
||||||
mockSlug := &MockSlug{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
|
|
||||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
|
||||||
|
|
||||||
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
|
||||||
assert.True(t, mockLifecycle.IsActive())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLifecycle_IsActive(t *testing.T) {
|
|
||||||
mockSSHConn := new(MockSSHConn)
|
|
||||||
mockForwarder := &MockForwarder{}
|
|
||||||
mockSlug := &MockSlug{}
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
|
|
||||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
|
||||||
|
|
||||||
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
|
||||||
assert.True(t, mockLifecycle.IsActive())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLifecycle_Close(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
tunnelType types.TunnelType
|
|
||||||
connCloseErr error
|
|
||||||
channelCloseErr error
|
|
||||||
expectErr bool
|
|
||||||
alreadyClosed bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Close HTTP forwarding success",
|
|
||||||
tunnelType: types.TunnelTypeHTTP,
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Close TCP forwarding success",
|
|
||||||
tunnelType: types.TunnelTypeTCP,
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Close with conn close error",
|
|
||||||
tunnelType: types.TunnelTypeHTTP,
|
|
||||||
connCloseErr: errors.New("conn close error"),
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Close with channel close error",
|
|
||||||
tunnelType: types.TunnelTypeHTTP,
|
|
||||||
channelCloseErr: errors.New("channel close error"),
|
|
||||||
expectErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Close when already closed",
|
|
||||||
tunnelType: types.TunnelTypeHTTP,
|
|
||||||
alreadyClosed: true,
|
|
||||||
expectErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mockSSHConn := &MockSSHConn{}
|
|
||||||
mockSSHConn.On("Close").Return(tt.connCloseErr)
|
|
||||||
|
|
||||||
mockForwarder := &MockForwarder{}
|
|
||||||
mockForwarder.On("TunnelType").Return(tt.tunnelType)
|
|
||||||
if tt.tunnelType == types.TunnelTypeTCP {
|
|
||||||
mockForwarder.On("ForwardedPort").Return(uint16(8080))
|
|
||||||
mockForwarder.On("Close").Return(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
mockSlug := &MockSlug{}
|
|
||||||
mockSlug.On("String").Return("test-slug")
|
|
||||||
|
|
||||||
mockPort := &MockPort{}
|
|
||||||
if tt.tunnelType == types.TunnelTypeTCP {
|
|
||||||
mockPort.On("SetStatus", uint16(8080), false).Return(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
mockSessionRegistry := &MockSessionRegistry{}
|
|
||||||
mockSessionRegistry.On("Remove", mock.Anything).Return()
|
|
||||||
|
|
||||||
mockSSHChannel := &MockSSHChannel{}
|
|
||||||
mockSSHChannel.On("Close").Return(tt.channelCloseErr)
|
|
||||||
|
|
||||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
|
||||||
|
|
||||||
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
|
||||||
mockLifecycle.SetChannel(mockSSHChannel)
|
|
||||||
|
|
||||||
if tt.alreadyClosed {
|
|
||||||
err := mockLifecycle.Close()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := mockLifecycle.Close()
|
|
||||||
|
|
||||||
if tt.expectErr {
|
|
||||||
assert.Error(t, err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.False(t, mockLifecycle.IsActive())
|
|
||||||
|
|
||||||
mockSSHConn.AssertExpectations(t)
|
|
||||||
mockForwarder.AssertExpectations(t)
|
|
||||||
mockSlug.AssertExpectations(t)
|
|
||||||
mockPort.AssertExpectations(t)
|
|
||||||
mockSessionRegistry.AssertExpectations(t)
|
|
||||||
mockSSHChannel.AssertExpectations(t)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,169 +1,131 @@
|
|||||||
package registry
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"tunnel_pls/session/forwarder"
|
|
||||||
"tunnel_pls/session/interaction"
|
|
||||||
"tunnel_pls/session/lifecycle"
|
|
||||||
"tunnel_pls/session/slug"
|
|
||||||
"tunnel_pls/types"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Key = types.SessionKey
|
|
||||||
|
|
||||||
type Session interface {
|
|
||||||
Lifecycle() lifecycle.Lifecycle
|
|
||||||
Interaction() interaction.Interaction
|
|
||||||
Forwarder() forwarder.Forwarder
|
|
||||||
Slug() slug.Slug
|
|
||||||
Detail() *types.Detail
|
|
||||||
}
|
|
||||||
|
|
||||||
type Registry interface {
|
type Registry interface {
|
||||||
Get(key Key) (session Session, err error)
|
Get(slug string) (session *SSHSession, err error)
|
||||||
GetWithUser(user string, key Key) (session Session, err error)
|
Update(oldSlug, newSlug string) error
|
||||||
Update(user string, oldKey, newKey Key) error
|
Register(slug string, session *SSHSession) (success bool)
|
||||||
Register(key Key, session Session) (success bool)
|
Remove(slug string)
|
||||||
Remove(key Key)
|
GetAllSessionFromUser(user string) []*SSHSession
|
||||||
GetAllSessionFromUser(user string) []Session
|
|
||||||
}
|
}
|
||||||
type registry struct {
|
type registry struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
byUser map[string]map[Key]Session
|
byUser map[string]map[string]*SSHSession
|
||||||
slugIndex map[Key]string
|
slugIndex map[string]string
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
ErrSessionNotFound = fmt.Errorf("session not found")
|
|
||||||
ErrSlugInUse = fmt.Errorf("slug already in use")
|
|
||||||
ErrInvalidSlug = fmt.Errorf("invalid slug")
|
|
||||||
ErrForbiddenSlug = fmt.Errorf("forbidden slug")
|
|
||||||
ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type")
|
|
||||||
ErrSlugUnchanged = fmt.Errorf("slug is unchanged")
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewRegistry() Registry {
|
func NewRegistry() Registry {
|
||||||
return ®istry{
|
return ®istry{
|
||||||
byUser: make(map[string]map[Key]Session),
|
byUser: make(map[string]map[string]*SSHSession),
|
||||||
slugIndex: make(map[Key]string),
|
slugIndex: make(map[string]string),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registry) Get(key Key) (session Session, err error) {
|
func (r *registry) Get(slug string) (session *SSHSession, err error) {
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
defer r.mu.RUnlock()
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
userID, ok := r.slugIndex[key]
|
userID, ok := r.slugIndex[slug]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrSessionNotFound
|
return nil, fmt.Errorf("session not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
client, ok := r.byUser[userID][key]
|
client, ok := r.byUser[userID][slug]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrSessionNotFound
|
return nil, fmt.Errorf("session not found")
|
||||||
}
|
}
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registry) GetWithUser(user string, key Key) (session Session, err error) {
|
func (r *registry) Update(oldSlug, newSlug string) error {
|
||||||
r.mu.RLock()
|
if isForbiddenSlug(newSlug) {
|
||||||
defer r.mu.RUnlock()
|
return fmt.Errorf("this subdomain is reserved. Please choose a different one")
|
||||||
|
} else if !isValidSlug(newSlug) {
|
||||||
client, ok := r.byUser[user][key]
|
return fmt.Errorf("invalid subdomain. Follow the rules")
|
||||||
if !ok {
|
|
||||||
return nil, ErrSessionNotFound
|
|
||||||
}
|
|
||||||
return client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *registry) Update(user string, oldKey, newKey Key) error {
|
|
||||||
if oldKey.Type != newKey.Type {
|
|
||||||
return ErrSlugUnchanged
|
|
||||||
}
|
|
||||||
|
|
||||||
if newKey.Type != types.TunnelTypeHTTP {
|
|
||||||
return ErrSlugChangeNotAllowed
|
|
||||||
}
|
|
||||||
|
|
||||||
if isForbiddenSlug(newKey.Id) {
|
|
||||||
return ErrForbiddenSlug
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isValidSlug(newKey.Id) {
|
|
||||||
return ErrInvalidSlug
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
|
||||||
return ErrSlugInUse
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
client, ok := r.byUser[user][oldKey]
|
userID, ok := r.slugIndex[oldSlug]
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrSessionNotFound
|
return fmt.Errorf("session not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(r.byUser[user], oldKey)
|
if _, exists := r.slugIndex[newSlug]; exists && newSlug != oldSlug {
|
||||||
delete(r.slugIndex, oldKey)
|
return fmt.Errorf("someone already uses this subdomain")
|
||||||
|
}
|
||||||
|
|
||||||
client.Slug().Set(newKey.Id)
|
client, ok := r.byUser[userID][oldSlug]
|
||||||
r.slugIndex[newKey] = user
|
if !ok {
|
||||||
|
return fmt.Errorf("session not found")
|
||||||
|
}
|
||||||
|
|
||||||
r.byUser[user][newKey] = client
|
delete(r.byUser[userID], oldSlug)
|
||||||
|
delete(r.slugIndex, oldSlug)
|
||||||
|
|
||||||
|
client.slugManager.Set(newSlug)
|
||||||
|
r.slugIndex[newSlug] = userID
|
||||||
|
|
||||||
|
if r.byUser[userID] == nil {
|
||||||
|
r.byUser[userID] = make(map[string]*SSHSession)
|
||||||
|
}
|
||||||
|
r.byUser[userID][newSlug] = client
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registry) Register(key Key, userSession Session) (success bool) {
|
func (r *registry) Register(slug string, session *SSHSession) (success bool) {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
if _, exists := r.slugIndex[key]; exists {
|
if _, exists := r.slugIndex[slug]; exists {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
userID := userSession.Lifecycle().User()
|
userID := session.userID
|
||||||
if r.byUser[userID] == nil {
|
if r.byUser[userID] == nil {
|
||||||
r.byUser[userID] = make(map[Key]Session)
|
r.byUser[userID] = make(map[string]*SSHSession)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.byUser[userID][key] = userSession
|
r.byUser[userID][slug] = session
|
||||||
r.slugIndex[key] = userID
|
r.slugIndex[slug] = userID
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registry) GetAllSessionFromUser(user string) []Session {
|
func (r *registry) GetAllSessionFromUser(user string) []*SSHSession {
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
defer r.mu.RUnlock()
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
m := r.byUser[user]
|
m := r.byUser[user]
|
||||||
if len(m) == 0 {
|
if len(m) == 0 {
|
||||||
return []Session{}
|
return []*SSHSession{}
|
||||||
}
|
}
|
||||||
|
|
||||||
sessions := make([]Session, 0, len(m))
|
sessions := make([]*SSHSession, 0, len(m))
|
||||||
for _, s := range m {
|
for _, s := range m {
|
||||||
sessions = append(sessions, s)
|
sessions = append(sessions, s)
|
||||||
}
|
}
|
||||||
return sessions
|
return sessions
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registry) Remove(key Key) {
|
func (r *registry) Remove(slug string) {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
userID, ok := r.slugIndex[key]
|
userID, ok := r.slugIndex[slug]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(r.byUser[userID], key)
|
delete(r.byUser[userID], slug)
|
||||||
if len(r.byUser[userID]) == 0 {
|
if len(r.byUser[userID]) == 0 {
|
||||||
delete(r.byUser, userID)
|
delete(r.byUser, userID)
|
||||||
}
|
}
|
||||||
delete(r.slugIndex, key)
|
delete(r.slugIndex, slug)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isValidSlug(slug string) bool {
|
func isValidSlug(slug string) bool {
|
||||||
+68
-327
@@ -1,203 +1,127 @@
|
|||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
|
||||||
"time"
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
portUtil "tunnel_pls/internal/port"
|
|
||||||
"tunnel_pls/internal/random"
|
|
||||||
"tunnel_pls/internal/registry"
|
|
||||||
"tunnel_pls/internal/transport"
|
|
||||||
"tunnel_pls/session/forwarder"
|
"tunnel_pls/session/forwarder"
|
||||||
"tunnel_pls/session/interaction"
|
"tunnel_pls/session/interaction"
|
||||||
"tunnel_pls/session/lifecycle"
|
"tunnel_pls/session/lifecycle"
|
||||||
"tunnel_pls/session/slug"
|
"tunnel_pls/session/slug"
|
||||||
"tunnel_pls/types"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Session interface {
|
type Session interface {
|
||||||
HandleGlobalRequest(ch <-chan *ssh.Request) error
|
HandleGlobalRequest(ch <-chan *ssh.Request)
|
||||||
HandleTCPIPForward(req *ssh.Request) error
|
HandleTCPIPForward(req *ssh.Request)
|
||||||
HandleHTTPForward(req *ssh.Request, port uint16) error
|
HandleHTTPForward(req *ssh.Request, port uint16)
|
||||||
HandleTCPForward(req *ssh.Request, addr string, port uint16) error
|
HandleTCPForward(req *ssh.Request, addr string, port uint16)
|
||||||
Lifecycle() lifecycle.Lifecycle
|
|
||||||
Interaction() interaction.Interaction
|
|
||||||
Forwarder() forwarder.Forwarder
|
|
||||||
Slug() slug.Slug
|
|
||||||
Detail() *types.Detail
|
|
||||||
Start() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type session struct {
|
type SSHSession struct {
|
||||||
randomizer random.Random
|
initialReq <-chan *ssh.Request
|
||||||
config config.Config
|
sshReqChannel <-chan ssh.NewChannel
|
||||||
initialReq <-chan *ssh.Request
|
lifecycle lifecycle.SessionLifecycle
|
||||||
sshChan <-chan ssh.NewChannel
|
interaction interaction.Controller
|
||||||
lifecycle lifecycle.Lifecycle
|
forwarder forwarder.ForwardingController
|
||||||
interaction interaction.Interaction
|
slugManager slug.Manager
|
||||||
forwarder forwarder.Forwarder
|
registry Registry
|
||||||
slug slug.Slug
|
userID string
|
||||||
registry registry.Registry
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
|
||||||
Randomizer random.Random
|
|
||||||
Config config.Config
|
|
||||||
Conn *ssh.ServerConn
|
|
||||||
InitialReq <-chan *ssh.Request
|
|
||||||
SshChan <-chan ssh.NewChannel
|
|
||||||
SessionRegistry registry.Registry
|
|
||||||
PortRegistry portUtil.Port
|
|
||||||
User string
|
|
||||||
}
|
|
||||||
|
|
||||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
|
||||||
|
|
||||||
func New(conf *Config) Session {
|
|
||||||
slugManager := slug.New()
|
|
||||||
forwarderManager := forwarder.New(conf.Config, slugManager, conf.Conn)
|
|
||||||
lifecycleManager := lifecycle.New(conf.Conn, forwarderManager, slugManager, conf.PortRegistry, conf.SessionRegistry, conf.User)
|
|
||||||
interactionManager := interaction.New(conf.Randomizer, conf.Config, slugManager, forwarderManager, conf.SessionRegistry, conf.User, lifecycleManager.Close)
|
|
||||||
|
|
||||||
return &session{
|
|
||||||
randomizer: conf.Randomizer,
|
|
||||||
config: conf.Config,
|
|
||||||
initialReq: conf.InitialReq,
|
|
||||||
sshChan: conf.SshChan,
|
|
||||||
lifecycle: lifecycleManager,
|
|
||||||
interaction: interactionManager,
|
|
||||||
forwarder: forwarderManager,
|
|
||||||
slug: slugManager,
|
|
||||||
registry: conf.SessionRegistry,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) Lifecycle() lifecycle.Lifecycle {
|
|
||||||
return s.lifecycle
|
return s.lifecycle
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) Interaction() interaction.Interaction {
|
func (s *SSHSession) GetInteraction() interaction.Controller {
|
||||||
return s.interaction
|
return s.interaction
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) Forwarder() forwarder.Forwarder {
|
func (s *SSHSession) GetForwarder() forwarder.ForwardingController {
|
||||||
return s.forwarder
|
return s.forwarder
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) Slug() slug.Slug {
|
func (s *SSHSession) GetSlugManager() slug.Manager {
|
||||||
return s.slug
|
return s.slugManager
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) Detail() *types.Detail {
|
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, userID string) *SSHSession {
|
||||||
tunnelTypeMap := map[types.TunnelType]string{
|
slugManager := slug.NewManager()
|
||||||
types.TunnelTypeHTTP: "HTTP",
|
forwarderManager := forwarder.NewForwarder(slugManager)
|
||||||
types.TunnelTypeTCP: "TCP",
|
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
|
||||||
}
|
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager)
|
||||||
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
|
|
||||||
if !ok {
|
|
||||||
tunnelType = "UNKNOWN"
|
|
||||||
}
|
|
||||||
|
|
||||||
return &types.Detail{
|
interactionManager.SetLifecycle(lifecycleManager)
|
||||||
ForwardingType: tunnelType,
|
interactionManager.SetSlugModificator(sessionRegistry.Update)
|
||||||
Slug: s.slug.String(),
|
forwarderManager.SetLifecycle(lifecycleManager)
|
||||||
UserID: s.lifecycle.User(),
|
lifecycleManager.SetUnregisterClient(sessionRegistry.Remove)
|
||||||
|
|
||||||
|
return &SSHSession{
|
||||||
|
initialReq: forwardingReq,
|
||||||
|
sshReqChannel: sshChan,
|
||||||
|
lifecycle: lifecycleManager,
|
||||||
|
interaction: interactionManager,
|
||||||
|
forwarder: forwarderManager,
|
||||||
|
slugManager: slugManager,
|
||||||
|
registry: sessionRegistry,
|
||||||
|
userID: userID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Detail struct {
|
||||||
|
ForwardingType string `json:"forwarding_type,omitempty"`
|
||||||
|
Slug string `json:"slug,omitempty"`
|
||||||
|
UserID string `json:"user_id,omitempty"`
|
||||||
|
Active bool `json:"active,omitempty"`
|
||||||
|
StartedAt time.Time `json:"started_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SSHSession) Detail() Detail {
|
||||||
|
return Detail{
|
||||||
|
ForwardingType: string(s.forwarder.GetTunnelType()),
|
||||||
|
Slug: s.slugManager.Get(),
|
||||||
|
UserID: s.userID,
|
||||||
Active: s.lifecycle.IsActive(),
|
Active: s.lifecycle.IsActive(),
|
||||||
StartedAt: s.lifecycle.StartedAt(),
|
StartedAt: s.lifecycle.StartedAt(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) Start() error {
|
func (s *SSHSession) Start() error {
|
||||||
if err := s.setupSessionMode(); err != nil {
|
channel := <-s.sshReqChannel
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpipReq := s.waitForTCPIPForward()
|
|
||||||
if tcpipReq == nil {
|
|
||||||
return s.handleMissingForwardRequest()
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.shouldRejectUnauthorized() {
|
|
||||||
return s.denyForwardingRequest(tcpipReq, nil, nil, "headless forwarding only allowed on node mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.HandleTCPIPForward(tcpipReq); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.interaction.Start()
|
|
||||||
|
|
||||||
return s.waitForSessionEnd()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) setupSessionMode() error {
|
|
||||||
select {
|
|
||||||
case channel, ok := <-s.sshChan:
|
|
||||||
if !ok {
|
|
||||||
log.Println("Forwarding request channel closed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return s.setupInteractiveMode(channel)
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
s.interaction.SetMode(types.InteractiveModeHEADLESS)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
|
|
||||||
ch, reqs, err := channel.Accept()
|
ch, reqs, err := channel.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("failed to accept channel: %v", err)
|
log.Printf("failed to accept channel: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
go s.HandleGlobalRequest(reqs)
|
||||||
|
|
||||||
go func() {
|
tcpipReq := s.waitForTCPIPForward()
|
||||||
err = s.HandleGlobalRequest(reqs)
|
if tcpipReq == nil {
|
||||||
|
_, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("global request handler error: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
}()
|
if err := s.lifecycle.Close(); err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("no forwarding Request")
|
||||||
|
}
|
||||||
|
|
||||||
s.lifecycle.SetChannel(ch)
|
s.lifecycle.SetChannel(ch)
|
||||||
s.interaction.SetChannel(ch)
|
s.interaction.SetChannel(ch)
|
||||||
s.interaction.SetMode(types.InteractiveModeINTERACTIVE)
|
|
||||||
|
|
||||||
return nil
|
s.HandleTCPIPForward(tcpipReq)
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) handleMissingForwardRequest() error {
|
|
||||||
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("no forwarding Request")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) shouldRejectUnauthorized() bool {
|
|
||||||
return s.interaction.Mode() == types.InteractiveModeHEADLESS &&
|
|
||||||
s.config.Mode() == types.ServerModeSTANDALONE &&
|
|
||||||
s.lifecycle.User() == "UNAUTHORIZED"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) waitForSessionEnd() error {
|
|
||||||
if err := s.lifecycle.Connection().Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
|
||||||
log.Printf("ssh connection closed with error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := s.lifecycle.Close(); err != nil {
|
if err := s.lifecycle.Close(); err != nil {
|
||||||
|
log.Printf("failed to close session: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) waitForTCPIPForward() *ssh.Request {
|
func (s *SSHSession) waitForTCPIPForward() *ssh.Request {
|
||||||
select {
|
select {
|
||||||
case req, ok := <-s.initialReq:
|
case req, ok := <-s.initialReq:
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -217,186 +141,3 @@ func (s *session) waitForTCPIPForward() *ssh.Request {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) handleWindowChange(req *ssh.Request) error {
|
|
||||||
p := req.Payload
|
|
||||||
if len(p) < 16 {
|
|
||||||
log.Println("invalid window-change payload")
|
|
||||||
return req.Reply(false, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
cols := binary.BigEndian.Uint32(p[0:4])
|
|
||||||
rows := binary.BigEndian.Uint32(p[4:8])
|
|
||||||
|
|
||||||
s.interaction.SetWH(int(cols), int(rows))
|
|
||||||
return req.Reply(true, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
|
||||||
for req := range GlobalRequest {
|
|
||||||
switch req.Type {
|
|
||||||
case "shell", "pty-req":
|
|
||||||
if err := req.Reply(true, nil); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case "window-change":
|
|
||||||
if err := s.handleWindowChange(req); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
log.Println("Unknown request type:", req.Type)
|
|
||||||
if err := req.Reply(false, nil); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) {
|
|
||||||
var forwardPayload struct {
|
|
||||||
BindAddr string
|
|
||||||
BindPort uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = ssh.Unmarshal(payload, &forwardPayload); err != nil {
|
|
||||||
return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if forwardPayload.BindPort > 65535 {
|
|
||||||
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
|
|
||||||
}
|
|
||||||
|
|
||||||
port = uint16(forwardPayload.BindPort)
|
|
||||||
|
|
||||||
if isBlockedPort(port) {
|
|
||||||
return "", 0, fmt.Errorf("port is blocked")
|
|
||||||
}
|
|
||||||
|
|
||||||
if port == 0 {
|
|
||||||
unassigned, ok := s.lifecycle.PortRegistry().Unassigned()
|
|
||||||
if !ok {
|
|
||||||
return "", 0, fmt.Errorf("no available port")
|
|
||||||
}
|
|
||||||
return forwardPayload.BindAddr, unassigned, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return forwardPayload.BindAddr, port, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
|
|
||||||
var errs []error
|
|
||||||
if key != nil {
|
|
||||||
s.registry.Remove(*key)
|
|
||||||
}
|
|
||||||
|
|
||||||
if listener != nil {
|
|
||||||
errs = append(errs, listener.Close())
|
|
||||||
}
|
|
||||||
|
|
||||||
errs = append(errs, req.Reply(false, nil))
|
|
||||||
errs = append(errs, s.lifecycle.Close())
|
|
||||||
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
|
|
||||||
return errors.Join(errs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
|
|
||||||
replyPayload := struct {
|
|
||||||
BoundPort uint32
|
|
||||||
}{
|
|
||||||
BoundPort: uint32(portToBind),
|
|
||||||
}
|
|
||||||
err := req.Reply(true, ssh.Marshal(replyPayload))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
s.forwarder.SetType(tunnelType)
|
|
||||||
s.forwarder.SetForwardedPort(portToBind)
|
|
||||||
s.slug.Set(slug)
|
|
||||||
s.lifecycle.SetStatus(types.SessionStatusRUNNING)
|
|
||||||
|
|
||||||
if listener != nil {
|
|
||||||
s.forwarder.SetListener(listener)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
|
|
||||||
address, port, err := s.parseForwardPayload(req.Payload)
|
|
||||||
if err != nil {
|
|
||||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
|
|
||||||
}
|
|
||||||
|
|
||||||
switch port {
|
|
||||||
case 80, 443:
|
|
||||||
return s.HandleHTTPForward(req, port)
|
|
||||||
default:
|
|
||||||
return s.HandleTCPForward(req, address, port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
|
|
||||||
randomString, err := s.randomizer.String(20)
|
|
||||||
if err != nil {
|
|
||||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
|
|
||||||
}
|
|
||||||
key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP}
|
|
||||||
if !s.registry.Register(key, s) {
|
|
||||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id)
|
|
||||||
if err != nil {
|
|
||||||
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
|
|
||||||
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
|
|
||||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
|
|
||||||
listener, err := tcpServer.Listen()
|
|
||||||
if err != nil {
|
|
||||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
|
||||||
}
|
|
||||||
|
|
||||||
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
|
|
||||||
if !s.registry.Register(key, s) {
|
|
||||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id))
|
|
||||||
}
|
|
||||||
|
|
||||||
err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id)
|
|
||||||
if err != nil {
|
|
||||||
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
err = tcpServer.Serve(listener)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed serving tcp server: %s\n", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isBlockedPort(port uint16) bool {
|
|
||||||
if port == 80 || port == 443 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if port < 1024 && port != 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for _, p := range blockedReservedPorts {
|
|
||||||
if p == port {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,24 +1,24 @@
|
|||||||
package slug
|
package slug
|
||||||
|
|
||||||
type Slug interface {
|
type Manager interface {
|
||||||
String() string
|
Get() string
|
||||||
Set(slug string)
|
Set(slug string)
|
||||||
}
|
}
|
||||||
|
|
||||||
type slug struct {
|
type manager struct {
|
||||||
slug string
|
slug string
|
||||||
}
|
}
|
||||||
|
|
||||||
func New() Slug {
|
func NewManager() Manager {
|
||||||
return &slug{
|
return &manager{
|
||||||
slug: "",
|
slug: "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *slug) String() string {
|
func (s *manager) Get() string {
|
||||||
return s.slug
|
return s.slug
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *slug) Set(slug string) {
|
func (s *manager) Set(slug string) {
|
||||||
s.slug = slug
|
s.slug = slug
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,99 +0,0 @@
|
|||||||
package slug
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/suite"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SlugTestSuite struct {
|
|
||||||
suite.Suite
|
|
||||||
slug Slug
|
|
||||||
}
|
|
||||||
|
|
||||||
func (suite *SlugTestSuite) SetupTest() {
|
|
||||||
suite.slug = New()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
s := New()
|
|
||||||
|
|
||||||
assert.NotNil(t, s, "New() should return a non-nil Slug")
|
|
||||||
assert.Implements(t, (*Slug)(nil), s, "New() should return a type that implements Slug interface")
|
|
||||||
assert.Equal(t, "", s.String(), "New() should initialize with empty string")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (suite *SlugTestSuite) TestString() {
|
|
||||||
assert.Equal(suite.T(), "", suite.slug.String(), "String() should return empty string initially")
|
|
||||||
|
|
||||||
suite.slug.Set("test-slug")
|
|
||||||
assert.Equal(suite.T(), "test-slug", suite.slug.String(), "String() should return the set value")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (suite *SlugTestSuite) TestSet() {
|
|
||||||
testCases := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "simple slug",
|
|
||||||
input: "hello-world",
|
|
||||||
expected: "hello-world",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty string",
|
|
||||||
input: "",
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "slug with numbers",
|
|
||||||
input: "test-123",
|
|
||||||
expected: "test-123",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "slug with special characters",
|
|
||||||
input: "hello_world-123",
|
|
||||||
expected: "hello_world-123",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "overwrite existing slug",
|
|
||||||
input: "new-slug",
|
|
||||||
expected: "new-slug",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range testCases {
|
|
||||||
suite.Run(tc.name, func() {
|
|
||||||
suite.slug.Set(tc.input)
|
|
||||||
assert.Equal(suite.T(), tc.expected, suite.slug.String())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (suite *SlugTestSuite) TestMultipleSet() {
|
|
||||||
suite.slug.Set("first-slug")
|
|
||||||
assert.Equal(suite.T(), "first-slug", suite.slug.String())
|
|
||||||
|
|
||||||
suite.slug.Set("second-slug")
|
|
||||||
assert.Equal(suite.T(), "second-slug", suite.slug.String())
|
|
||||||
|
|
||||||
suite.slug.Set("")
|
|
||||||
assert.Equal(suite.T(), "", suite.slug.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSlugIsolation(t *testing.T) {
|
|
||||||
slug1 := New()
|
|
||||||
slug2 := New()
|
|
||||||
|
|
||||||
slug1.Set("slug-one")
|
|
||||||
slug2.Set("slug-two")
|
|
||||||
|
|
||||||
assert.Equal(t, "slug-one", slug1.String(), "First slug should maintain its value")
|
|
||||||
assert.Equal(t, "slug-two", slug2.String(), "Second slug should maintain its value")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSlugTestSuite(t *testing.T) {
|
|
||||||
suite.Run(t, new(SlugTestSuite))
|
|
||||||
}
|
|
||||||
+7
-37
@@ -1,50 +1,20 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import "time"
|
type Status string
|
||||||
|
|
||||||
type SessionStatus int
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SessionStatusINITIALIZING SessionStatus = iota
|
INITIALIZING Status = "INITIALIZING"
|
||||||
SessionStatusRUNNING
|
RUNNING Status = "RUNNING"
|
||||||
SessionStatusCLOSED
|
SETUP Status = "SETUP"
|
||||||
)
|
)
|
||||||
|
|
||||||
type InteractiveMode int
|
type TunnelType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
InteractiveModeINTERACTIVE InteractiveMode = iota + 1
|
HTTP TunnelType = "HTTP"
|
||||||
InteractiveModeHEADLESS
|
TCP TunnelType = "TCP"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunnelType int
|
|
||||||
|
|
||||||
const (
|
|
||||||
TunnelTypeUNKNOWN TunnelType = iota
|
|
||||||
TunnelTypeHTTP
|
|
||||||
TunnelTypeTCP
|
|
||||||
)
|
|
||||||
|
|
||||||
type ServerMode int
|
|
||||||
|
|
||||||
const (
|
|
||||||
ServerModeSTANDALONE = iota + 1
|
|
||||||
ServerModeNODE
|
|
||||||
)
|
|
||||||
|
|
||||||
type SessionKey struct {
|
|
||||||
Id string
|
|
||||||
Type TunnelType
|
|
||||||
}
|
|
||||||
|
|
||||||
type Detail struct {
|
|
||||||
ForwardingType string `json:"forwarding_type,omitempty"`
|
|
||||||
Slug string `json:"slug,omitempty"`
|
|
||||||
UserID string `json:"user_id,omitempty"`
|
|
||||||
Active bool `json:"active,omitempty"`
|
|
||||||
StartedAt time.Time `json:"started_at,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
|
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
|
||||||
"Content-Length: 11\r\n" +
|
"Content-Length: 11\r\n" +
|
||||||
"Content-Type: text/plain\r\n\r\n" +
|
"Content-Type: text/plain\r\n\r\n" +
|
||||||
|
|||||||
Reference in New Issue
Block a user