Compare commits
145 Commits
v1.1.0-rc.1
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 8229879db8 | |||
| 7015e7f4de | |||
| 03c6b44fa2 | |||
| 3af3fdbc9c | |||
| 6dc4bb58ea | |||
| bd2b843e5d | |||
| 5b05723e93 | |||
| 22ad935299 | |||
| ebd915e18e | |||
| 728691d119 | |||
| 1344afd1b2 | |||
| 4cbee5079c | |||
| 0b071dfde7 | |||
| 6062c2e11d | |||
| 2a2d484e91 | |||
| 9377233515 | |||
| fab625e13a | |||
| 1ed845bf2d | |||
| 67378aabda | |||
| a26d1672d9 | |||
| 7f44cc7bc0 | |||
| a3f6baa6ae | |||
| 6def82a095 | |||
| 354da27424 | |||
| ee1dc3c3cd | |||
| 65df01fee5 | |||
| 79fd292a77 | |||
| 4041681be6 | |||
| 2ee24c8d51 | |||
| 384bb98f48 | |||
| 9785a97973 | |||
| b8c6359820 | |||
| 8fee8bf92e | |||
| 04c9ddbc13 | |||
| 211745dc26 | |||
| 09aa92a0ae | |||
| 1ed9f3631f | |||
| bd826d6d06 | |||
| 2f5c44ff01 | |||
| d0e052524c | |||
| 24b9872aa4 | |||
| 8b84373036 | |||
| e796ab5328 | |||
| efdfc4ce95 | |||
| 1dc929cc25 | |||
| 14abac6579 | |||
| 21179da4b5 | |||
| 32f8be2891 | |||
| 5af7af3139 | |||
| f4848e9754 | |||
| d2e508c8ef | |||
| 5499b7d08a | |||
| 58f1fdabe1 | |||
| c1fb588cf4 | |||
| 3029996773 | |||
| 3fd179d32b | |||
| a598a10e94 | |||
| 29cabe42d3 | |||
| e534972abc | |||
| a55ff5f6ab | |||
| 50b4127cb3 | |||
| 7e635721fb | |||
| 016df9caee | |||
| d91eecb2a0 | |||
| 961a905542 | |||
| 634c8321ef | |||
| 9f4c24a3f3 | |||
| 1408b80917 | |||
| 2bc20dd991 | |||
| 1e12373359 | |||
| 9a4539cc02 | |||
| e3ead4d52f | |||
| aa1a465178 | |||
| 27f49879af | |||
| adb0264bb5 | |||
| 8fb19af5a6 | |||
| 41fdb5639c | |||
| 44d224f491 | |||
| 9be0328e24 | |||
| 2b9bca65d5 | |||
| 6587dc0f39 | |||
| f421781f44 | |||
| 6969d6823a | |||
| 1a04af8873 | |||
| 19135ceb42 | |||
| edb11dbc51 | |||
| 819f044275 | |||
| a7ebf2c5db | |||
| 64c1038f4b | |||
| aafea49975 | |||
| dbdf8094fa | |||
| ae3ed52d16 | |||
| fb638636bf | |||
| da29df85b7 | |||
| 8b0e08c629 | |||
| f0804d6946 | |||
| 09e526cd1e | |||
| 887ebf78b1 | |||
| bef7a49f88 | |||
| 17633b4e3c | |||
| f25d61d1d1 | |||
| 8782b77b74 | |||
| fc3cd886db | |||
| b0da57db0d | |||
| 0bd6eeadf3 | |||
| 449f546e04 | |||
| 4644420eee | |||
| c9bf9e62bd | |||
| 57d2136377 | |||
| 8a34aaba80 | |||
| ff995a929e | |||
| 32ac9c1749 | |||
| 07d9f3afe6 | |||
| e051a5b742 | |||
| d35228759c | |||
| abd103b5ab | |||
| 560c98b869 | |||
| e1f5d73e03 | |||
| 19fd6d59d2 | |||
| e3988b339f | |||
| 336948a397 | |||
| 50ae422de8 | |||
| 8467ed555e | |||
| 01ddc76f7e | |||
| ffb3565ff5 | |||
| 6d700ef6dd | |||
| b8acb6da4c | |||
| 6b4127f0ef | |||
| 16d48ff906 | |||
| 6213ff8a30 | |||
| 4ffaec9d9a | |||
| 6de0a618ee | |||
| 8cc70fa45e | |||
| d666ae5545 | |||
| 5edb3c8086 | |||
| 5b603d8317 | |||
| 5ceade81db | |||
| 8fd9f8b567 | |||
| 30e84ac3b7 | |||
| fd6ffc2500 | |||
| e1cd4ed981 | |||
| 96d2b88f95 | |||
| 8a456d2cde | |||
| 8841230653 | |||
| 4d0a7deaf2 |
+40
-80
@@ -2,24 +2,38 @@ name: Docker Build and Push
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- staging
|
||||
tags:
|
||||
- 'v*'
|
||||
paths:
|
||||
- '**.go'
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- 'Dockerfile'
|
||||
- 'Dockerfile.*'
|
||||
- '.dockerignore'
|
||||
- '.gitea/workflows/build.yml'
|
||||
|
||||
jobs:
|
||||
build-and-push-branches:
|
||||
test:
|
||||
name: Run Tests
|
||||
runs-on: ubuntu-latest
|
||||
if: github.ref_type == 'branch'
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 'stable'
|
||||
cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Run tests
|
||||
run: go test -v -p 4 ./...
|
||||
|
||||
|
||||
build-and-push:
|
||||
name: Build and Push Docker Image
|
||||
runs-on: ubuntu-latest
|
||||
needs: test
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
@@ -28,64 +42,7 @@ jobs:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: git.fossy.my.id
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Set version variables
|
||||
id: vars
|
||||
run: |
|
||||
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
|
||||
echo "VERSION=dev-main" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "VERSION=dev-staging" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
echo "BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT
|
||||
echo "COMMIT=${{ github.sha }}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Build and push Docker image for main
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: |
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:latest
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
VERSION=${{ steps.vars.outputs.VERSION }}
|
||||
BUILD_DATE=${{ steps.vars.outputs.BUILD_DATE }}
|
||||
COMMIT=${{ steps.vars.outputs.COMMIT }}
|
||||
if: github.ref == 'refs/heads/main'
|
||||
|
||||
- name: Build and push Docker image for staging
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: |
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:staging
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
VERSION=${{ steps.vars.outputs.VERSION }}
|
||||
BUILD_DATE=${{ steps.vars.outputs.BUILD_DATE }}
|
||||
COMMIT=${{ steps.vars.outputs.COMMIT }}
|
||||
if: github.ref == 'refs/heads/staging'
|
||||
|
||||
build-and-push-tags:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.ref_type == 'tag' && startsWith(github.ref, 'refs/tags/v')
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
- name: Log in to Docker Registry
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: git.fossy.my.id
|
||||
@@ -103,32 +60,35 @@ jobs:
|
||||
if echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.]+)?$'; then
|
||||
MAJOR=$(echo "$VERSION" | cut -d. -f1)
|
||||
MINOR=$(echo "$VERSION" | cut -d. -f2)
|
||||
|
||||
PATCH=$(echo "$VERSION" | cut -d. -f3 | cut -d- -f1)
|
||||
|
||||
echo "MAJOR=$MAJOR" >> $GITHUB_OUTPUT
|
||||
echo "MINOR=$MINOR" >> $GITHUB_OUTPUT
|
||||
|
||||
echo "PATCH=$PATCH" >> $GITHUB_OUTPUT
|
||||
|
||||
if echo "$VERSION" | grep -q '-'; then
|
||||
PRERELEASE_TAG=$(echo "$VERSION" | cut -d- -f2 | cut -d. -f1)
|
||||
echo "IS_PRERELEASE=true" >> $GITHUB_OUTPUT
|
||||
echo "ADDITIONAL_TAG=staging" >> $GITHUB_OUTPUT
|
||||
echo "PRERELEASE_TAG=$PRERELEASE_TAG" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "IS_PRERELEASE=false" >> $GITHUB_OUTPUT
|
||||
echo "ADDITIONAL_TAG=latest" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
else
|
||||
echo "Invalid version format: $VERSION"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Build and push Docker image for release
|
||||
- name: Build and push Docker image (release)
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: |
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.VERSION }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:release
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.MAJOR }}.${{ steps.version.outputs.MINOR }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.MAJOR }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:${{ steps.version.outputs.ADDITIONAL_TAG }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:latest
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
VERSION=${{ steps.version.outputs.VERSION }}
|
||||
@@ -136,17 +96,17 @@ jobs:
|
||||
COMMIT=${{ steps.version.outputs.COMMIT }}
|
||||
if: steps.version.outputs.IS_PRERELEASE == 'false'
|
||||
|
||||
- name: Build and push Docker image for pre-release
|
||||
- name: Build and push Docker image (pre-release)
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
tags: |
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.VERSION }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:${{ steps.version.outputs.ADDITIONAL_TAG }}
|
||||
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:staging
|
||||
platforms: linux/amd64,linux/arm64
|
||||
build-args: |
|
||||
VERSION=${{ steps.version.outputs.VERSION }}
|
||||
BUILD_DATE=${{ steps.version.outputs.BUILD_DATE }}
|
||||
COMMIT=${{ steps.version.outputs.COMMIT }}
|
||||
if: steps.version.outputs.IS_PRERELEASE == 'true'
|
||||
if: steps.version.outputs.IS_PRERELEASE == 'true'
|
||||
@@ -1,21 +0,0 @@
|
||||
name: renovate
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0 * * *"
|
||||
push:
|
||||
branches:
|
||||
- staging
|
||||
|
||||
jobs:
|
||||
renovate:
|
||||
runs-on: ubuntu-latest
|
||||
container: git.fossy.my.id/renovate-clanker/renovate:latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- run: renovate
|
||||
env:
|
||||
RENOVATE_CONFIG_FILE: ${{ gitea.workspace }}/renovate-config.js
|
||||
LOG_LEVEL: "debug"
|
||||
RENOVATE_TOKEN: ${{ secrets.RENOVATE_TOKEN }}
|
||||
GITHUB_COM_TOKEN: ${{ secrets.COM_TOKEN }}
|
||||
@@ -0,0 +1,60 @@
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- staging
|
||||
- 'feat/**'
|
||||
|
||||
name: SonarQube Scan
|
||||
jobs:
|
||||
sonarqube:
|
||||
name: SonarQube Trigger
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checking out
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 'stable'
|
||||
cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: go mod tidy
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./... 2>&1 | tee vet-results.txt
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: |
|
||||
go test ./... -v -p 4 -coverprofile=coverage
|
||||
|
||||
- name: Run GolangCI-Lint Analysis
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
skip-cache: true
|
||||
version: v2.6
|
||||
args: >
|
||||
--issues-exit-code=0
|
||||
--output.text.path=stdout
|
||||
--output.checkstyle.path=golangci-lint-report.xml
|
||||
|
||||
- name: SonarQube Scan
|
||||
uses: SonarSource/sonarqube-scan-action@v7.0.0
|
||||
env:
|
||||
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
|
||||
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
|
||||
with:
|
||||
args: >
|
||||
-Dsonar.projectKey=tunnel-please
|
||||
-Dsonar.go.coverage.reportPaths=coverage
|
||||
-Dsonar.test.inclusions=**/*_test.go
|
||||
-Dsonar.test.exclusions=**/vendor/**
|
||||
-Dsonar.exclusions=**/*_test.go,**/vendor/**,**/golangci-lint-report.xml
|
||||
-Dsonar.go.govet.reportPaths=vet-results.txt
|
||||
-Dsonar.go.golangci-lint.reportPaths=golangci-lint-report.xml
|
||||
-Dsonar.sources=./
|
||||
-Dsonar.tests=./
|
||||
@@ -0,0 +1,36 @@
|
||||
name: Tests
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run Tests
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event_name == 'pull_request' ||
|
||||
(github.event_name == 'issue_comment' &&
|
||||
github.event.issue.pull_request != null &&
|
||||
contains(github.event.comment.body, '/retest'))
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: 'stable'
|
||||
cache: false
|
||||
|
||||
- name: Install dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Run tests
|
||||
run: go test -v -p 4 ./...
|
||||
Vendored
+3
-1
@@ -4,4 +4,6 @@ id_rsa*
|
||||
.env
|
||||
tmp
|
||||
certs
|
||||
app
|
||||
app
|
||||
coverage
|
||||
test-results.json
|
||||
+5
-2
@@ -1,4 +1,4 @@
|
||||
FROM golang:1.25.5-alpine AS go_builder
|
||||
FROM golang:1.25.7-alpine AS go_builder
|
||||
|
||||
ARG VERSION=dev
|
||||
ARG BUILD_DATE=unknown
|
||||
@@ -22,7 +22,10 @@ RUN --mount=type=cache,target=/go/pkg/mod \
|
||||
--mount=type=cache,target=/root/.cache/go-build \
|
||||
CGO_ENABLED=0 GOOS=linux \
|
||||
go build -trimpath \
|
||||
-ldflags="-w -s -X tunnel_pls/version.Version=${VERSION} -X tunnel_pls/version.BuildDate=${BUILD_DATE} -X tunnel_pls/version.Commit=${COMMIT}" \
|
||||
-ldflags="-w -s \
|
||||
-X tunnel_pls/internal/version.Version=${VERSION} \
|
||||
-X tunnel_pls/internal/version.BuildDate=${BUILD_DATE} \
|
||||
-X tunnel_pls/internal/version.Commit=${COMMIT}" \
|
||||
-o /app/tunnel_pls \
|
||||
.
|
||||
|
||||
|
||||
@@ -1,6 +1,22 @@
|
||||
<div align="center">
|
||||
|
||||
<img alt="gopher" title="gopher" src="./docs/images/gopher.png" width="325" />
|
||||
|
||||
# Tunnel Please
|
||||
|
||||
A lightweight SSH-based tunnel server written in Go that enables secure TCP and HTTP forwarding with an interactive terminal interface for managing connections and custom subdomains.
|
||||
A lightweight SSH-based tunnel server
|
||||
|
||||
<br/><br/>
|
||||
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
[](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
|
||||
|
||||
</div>
|
||||
|
||||
## Features
|
||||
|
||||
@@ -17,104 +33,32 @@ A lightweight SSH-based tunnel server written in Go that enables secure TCP and
|
||||
|
||||
The following environment variables can be configured in the `.env` file:
|
||||
|
||||
| Variable | Description | Default | Required |
|
||||
|----------|-------------|---------|----------|
|
||||
| `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
|
||||
| `PORT` | SSH server port | `2200` | No |
|
||||
| `HTTP_PORT` | HTTP server port | `8080` | No |
|
||||
| `HTTPS_PORT` | HTTPS server port | `8443` | No |
|
||||
| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
|
||||
| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
|
||||
| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@<DOMAIN>` | No |
|
||||
| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | - | Yes (if auto-cert) |
|
||||
| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
|
||||
| `CORS_LIST` | Comma-separated list of allowed CORS origins | - | No |
|
||||
| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
|
||||
| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
|
||||
| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
|
||||
| `PPROF_PORT` | Port for pprof server | `6060` | No |
|
||||
| Variable | Description | Default | Required |
|
||||
|---------------------|-----------------------------------------------------------------------------|-------------------------|---------------------|
|
||||
| `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
|
||||
| `PORT` | SSH server port | `2200` | No |
|
||||
| `HTTP_PORT` | HTTP server port | `8080` | No |
|
||||
| `HTTPS_PORT` | HTTPS server port | `8443` | No |
|
||||
| `KEY_LOC` | Path to the private key file | `certs/privkey.pem` | No |
|
||||
| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
|
||||
| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
|
||||
| `TLS_STORAGE_PATH` | Path to store TLS certificates | `certs/tls/` | No |
|
||||
| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@<DOMAIN>` | No |
|
||||
| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | `-` | Yes (if auto-cert) |
|
||||
| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
|
||||
| `CORS_LIST` | Comma-separated list of allowed CORS origins | `-` | No |
|
||||
| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
|
||||
| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
|
||||
| `MAX_HEADER_SIZE` | Maximum size of HTTP headers in bytes (4096-131072) | `4096` | No |
|
||||
| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
|
||||
| `PPROF_PORT` | Port for pprof server | `6060` | No |
|
||||
| `MODE` | Runtime mode: `standalone` or `node` | `standalone` | No |
|
||||
| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No |
|
||||
| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No |
|
||||
| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | `-` | Yes (node mode) |
|
||||
|
||||
**Note:** All environment variables now use UPPERCASE naming. The application includes sensible defaults for all variables, so you can run it without a `.env` file for basic functionality.
|
||||
|
||||
### Automatic TLS Certificate Management
|
||||
|
||||
The server supports automatic TLS certificate generation and renewal using [CertMagic](https://github.com/caddyserver/certmagic) with Cloudflare DNS-01 challenge. This is required for wildcard certificate support (`*.yourdomain.com`).
|
||||
|
||||
**Certificate Storage:**
|
||||
- TLS certificates are stored in `certs/tls/` (relative to application directory)
|
||||
- User-provided certificates: `certs/tls/cert.pem` and `certs/tls/privkey.pem`
|
||||
- CertMagic automatic certificates: `certs/tls/certmagic/`
|
||||
- SSH keys are stored separately in `certs/ssh/`
|
||||
|
||||
**How it works:**
|
||||
1. If user-provided certificates exist at `certs/tls/cert.pem` and `certs/tls/privkey.pem` and cover both `DOMAIN` and `*.DOMAIN`, they will be used
|
||||
2. If certificates are missing, expired, expiring within 30 days, or don't cover the required domains, CertMagic will automatically obtain new certificates from Let's Encrypt
|
||||
3. Certificates are automatically renewed before expiration
|
||||
4. User-provided certificates support hot-reload (changes detected every 30 seconds)
|
||||
|
||||
**Cloudflare API Token Setup:**
|
||||
|
||||
To use automatic certificate generation, you need a Cloudflare API token with the following permissions:
|
||||
|
||||
1. Go to [Cloudflare Dashboard](https://dash.cloudflare.com/profile/api-tokens)
|
||||
2. Click "Create Token"
|
||||
3. Use "Create Custom Token" with these permissions:
|
||||
- **Zone → Zone → Read** (for all zones or specific zone)
|
||||
- **Zone → DNS → Edit** (for all zones or specific zone)
|
||||
4. Copy the token and set it as `CF_API_TOKEN` environment variable
|
||||
|
||||
**Example configuration for automatic certificates:**
|
||||
```env
|
||||
DOMAIN=example.com
|
||||
TLS_ENABLED=true
|
||||
CF_API_TOKEN=your_cloudflare_api_token_here
|
||||
ACME_EMAIL=admin@example.com
|
||||
# ACME_STAGING=true # Uncomment for testing to avoid rate limits
|
||||
```
|
||||
|
||||
### SSH Key Auto-Generation
|
||||
|
||||
The application will automatically generate a new 4096-bit RSA key pair at `certs/ssh/id_rsa` if it doesn't exist. This makes it easier to get started without manually creating SSH keys. SSH keys are stored separately from TLS certificates.
|
||||
|
||||
### Memory Optimization
|
||||
|
||||
The application uses a buffer pool with controlled buffer sizes to prevent excessive memory usage under high concurrent loads. The `BUFFER_SIZE` environment variable controls the size of buffers used for io.Copy operations:
|
||||
|
||||
- **Default:** 32768 bytes (32 KB) - Good balance for most scenarios
|
||||
- **Minimum:** 4096 bytes (4 KB) - Lower memory usage, more CPU overhead
|
||||
- **Maximum:** 1048576 bytes (1 MB) - Higher throughput, more memory usage
|
||||
|
||||
**Recommended settings based on load:**
|
||||
- **Low traffic (<100 concurrent):** `BUFFER_SIZE=32768` (default)
|
||||
- **High traffic (>100 concurrent):** `BUFFER_SIZE=16384` or `BUFFER_SIZE=8192`
|
||||
- **Very high traffic (>1000 concurrent):** `BUFFER_SIZE=8192` or `BUFFER_SIZE=4096`
|
||||
|
||||
The buffer pool reuses buffers across connections, preventing memory fragmentation and reducing garbage collection pressure.
|
||||
|
||||
### Profiling with pprof
|
||||
|
||||
To enable profiling for performance analysis:
|
||||
|
||||
1. Set `PPROF_ENABLED=true` in your `.env` file
|
||||
2. Optionally set `PPROF_PORT` to your desired port (default: 6060)
|
||||
3. Access profiling data at `http://localhost:6060/debug/pprof/`
|
||||
|
||||
Common pprof endpoints:
|
||||
- `/debug/pprof/` - Index page with available profiles
|
||||
- `/debug/pprof/heap` - Memory allocation profile
|
||||
- `/debug/pprof/goroutine` - Stack traces of all current goroutines
|
||||
- `/debug/pprof/profile` - CPU profile (30-second sample by default)
|
||||
- `/debug/pprof/trace` - Execution trace
|
||||
|
||||
Example usage with `go tool pprof`:
|
||||
```bash
|
||||
# Analyze CPU profile
|
||||
go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
|
||||
|
||||
# Analyze memory heap
|
||||
go tool pprof http://localhost:6060/debug/pprof/heap
|
||||
```
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`.
|
||||
@@ -193,22 +137,6 @@ docker-compose -f docker-compose.tcp.yml up -d
|
||||
docker-compose -f docker-compose.root.yml down
|
||||
```
|
||||
|
||||
### Volume Management
|
||||
|
||||
All configurations use a named volume `certs` for persistent storage:
|
||||
- SSH keys: `/app/certs/ssh/`
|
||||
- TLS certificates: `/app/certs/tls/`
|
||||
|
||||
To backup certificates:
|
||||
```bash
|
||||
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar czf /backup/certs-backup.tar.gz -C /data .
|
||||
```
|
||||
|
||||
To restore certificates:
|
||||
```bash
|
||||
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar xzf /backup/certs-backup.tar.gz -C /data
|
||||
```
|
||||
|
||||
### Recommendation
|
||||
|
||||
**Use `docker-compose.root.yml`** for production deployments if you need:
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 2.0 MiB |
@@ -1,48 +1,60 @@
|
||||
module tunnel_pls
|
||||
|
||||
go 1.24.4
|
||||
go 1.25.5
|
||||
|
||||
require (
|
||||
github.com/caddyserver/certmagic v0.25.0
|
||||
github.com/charmbracelet/bubbles v0.21.0
|
||||
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0
|
||||
github.com/caddyserver/certmagic v0.25.1
|
||||
github.com/charmbracelet/bubbles v0.21.1
|
||||
github.com/charmbracelet/bubbletea v1.3.10
|
||||
github.com/charmbracelet/lipgloss v1.1.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/libdns/cloudflare v0.2.2
|
||||
golang.org/x/crypto v0.46.0
|
||||
github.com/muesli/termenv v0.16.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/crypto v0.47.0
|
||||
google.golang.org/grpc v1.78.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/atotto/clipboard v0.1.4 // indirect
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
|
||||
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||
github.com/charmbracelet/lipgloss v1.1.0 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.10.1 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
|
||||
github.com/charmbracelet/x/term v0.2.1 // indirect
|
||||
github.com/caddyserver/zerossl v0.1.4 // indirect
|
||||
github.com/charmbracelet/colorprofile v0.4.1 // indirect
|
||||
github.com/charmbracelet/x/ansi v0.11.5 // indirect
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
|
||||
github.com/charmbracelet/x/term v0.2.2 // indirect
|
||||
github.com/clipperhouse/displaywidth v0.9.0 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/libdns/libdns v1.1.1 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-localereader v0.0.1 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/mholt/acmez/v3 v3.1.3 // indirect
|
||||
github.com/miekg/dns v1.1.68 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/mholt/acmez/v3 v3.1.4 // indirect
|
||||
github.com/miekg/dns v1.1.69 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/muesli/termenv v0.16.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/sahilm/fuzzy v0.1.1 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/zeebo/blake3 v0.2.4 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go.uber.org/zap v1.27.0 // indirect
|
||||
go.uber.org/zap v1.27.1 // indirect
|
||||
go.uber.org/zap/exp v0.3.0 // indirect
|
||||
golang.org/x/mod v0.30.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
golang.org/x/tools v0.39.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
golang.org/x/tools v0.40.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,35 +1,62 @@
|
||||
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 h1:3xszIhck4wo9CoeRq9vnkar4PhY7kz9QrR30qj2XszA=
|
||||
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0/go.mod h1:Weh6ZujgWmT8XxD3Qba7sJ6r5eyUMB9XSWynqdyOoLo=
|
||||
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
|
||||
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
|
||||
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
|
||||
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
|
||||
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
|
||||
github.com/caddyserver/certmagic v0.25.0 h1:VMleO/XA48gEWes5l+Fh6tRWo9bHkhwAEhx63i+F5ic=
|
||||
github.com/caddyserver/certmagic v0.25.0/go.mod h1:m9yB7Mud24OQbPHOiipAoyKPn9pKHhpSJxXR1jydBxA=
|
||||
github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA=
|
||||
github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
|
||||
github.com/caddyserver/certmagic v0.25.1 h1:4sIKKbOt5pg6+sL7tEwymE1x2bj6CHr80da1CRRIPbY=
|
||||
github.com/caddyserver/certmagic v0.25.1/go.mod h1:VhyvndxtVton/Fo/wKhRoC46Rbw1fmjvQ3GjHYSQTEY=
|
||||
github.com/caddyserver/zerossl v0.1.4 h1:CVJOE3MZeFisCERZjkxIcsqIH4fnFdlYWnPYeFtBHRw=
|
||||
github.com/caddyserver/zerossl v0.1.4/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
|
||||
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
|
||||
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
|
||||
github.com/charmbracelet/bubbles v0.21.1 h1:nj0decPiixaZeL9diI4uzzQTkkz1kYY8+jgzCZXSmW0=
|
||||
github.com/charmbracelet/bubbles v0.21.1/go.mod h1:HHvIYRCpbkCJw2yo0vNX1O5loCwSr9/mWS8GYSg50Sk=
|
||||
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
|
||||
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
|
||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
|
||||
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
|
||||
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
|
||||
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
|
||||
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
|
||||
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
|
||||
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
|
||||
github.com/charmbracelet/x/ansi v0.11.3 h1:6DcVaqWI82BBVM/atTyq6yBoRLZFBsnoDoX9GCu2YOI=
|
||||
github.com/charmbracelet/x/ansi v0.11.3/go.mod h1:yI7Zslym9tCJcedxz5+WBq+eUGMJT0bM06Fqy1/Y4dI=
|
||||
github.com/charmbracelet/x/ansi v0.11.5 h1:NBWeBpj/lJPE3Q5l+Lusa4+mH6v7487OP8K0r1IhRg4=
|
||||
github.com/charmbracelet/x/ansi v0.11.5/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.14 h1:iUEMryGyFTelKW3THW4+FfPgi4fkmKnnaLOXuc+/Kj4=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.14/go.mod h1:P447lJl49ywBbil/KjCk2HexGh4tEY9LH0/1QrZZ9rA=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
|
||||
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ=
|
||||
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
|
||||
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
|
||||
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
|
||||
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
|
||||
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
|
||||
github.com/clipperhouse/displaywidth v0.6.2 h1:ZDpTkFfpHOKte4RG5O/BOyf3ysnvFswpyYrV7z2uAKo=
|
||||
github.com/clipperhouse/displaywidth v0.6.2/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o=
|
||||
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
|
||||
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
|
||||
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
|
||||
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
@@ -40,18 +67,18 @@ github.com/libdns/cloudflare v0.2.2 h1:XWHv+C1dDcApqazlh08Q6pjytYLgR2a+Y3xrXFu0v
|
||||
github.com/libdns/cloudflare v0.2.2/go.mod h1:w9uTmRCDlAoafAsTPnn2nJ0XHK/eaUMh86DUk8BWi60=
|
||||
github.com/libdns/libdns v1.1.1 h1:wPrHrXILoSHKWJKGd0EiAVmiJbFShguILTg9leS/P/U=
|
||||
github.com/libdns/libdns v1.1.1/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
|
||||
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
|
||||
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
|
||||
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
|
||||
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/mholt/acmez/v3 v3.1.3 h1:gUl789rjbJSuM5hYzOFnNaGgWPV1xVfnOs59o0dZEcc=
|
||||
github.com/mholt/acmez/v3 v3.1.3/go.mod h1:L1wOU06KKvq7tswuMDwKdcHeKpFFgkppZy/y0DFxagQ=
|
||||
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
|
||||
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
|
||||
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/mholt/acmez/v3 v3.1.4 h1:DyzZe/RnAzT3rpZj/2Ii5xZpiEvvYk3cQEN/RmqxwFQ=
|
||||
github.com/mholt/acmez/v3 v3.1.4/go.mod h1:L1wOU06KKvq7tswuMDwKdcHeKpFFgkppZy/y0DFxagQ=
|
||||
github.com/miekg/dns v1.1.69 h1:Kb7Y/1Jo+SG+a2GtfoFUfDkG//csdRPwRLkCsxDG9Sc=
|
||||
github.com/miekg/dns v1.1.69/go.mod h1:7OyjD9nEba5OkqQ/hB4fy3PIoxafSZJtducccIelz3g=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
|
||||
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
|
||||
@@ -60,13 +87,22 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
|
||||
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
|
||||
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
|
||||
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
|
||||
@@ -75,33 +111,56 @@ github.com/zeebo/blake3 v0.2.4 h1:KYQPkhpRtcqh0ssGYcKLG1JYvddkEA8QwCM/yBqhaZI=
|
||||
github.com/zeebo/blake3 v0.2.4/go.mod h1:7eeQ6d2iXWRGF6npfaxl2CU+xy2Fjo2gxeyZGCRUjcE=
|
||||
github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
|
||||
github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
|
||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U=
|
||||
go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
|
||||
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
|
||||
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
|
||||
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
|
||||
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
|
||||
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
|
||||
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b h1:Mv8VFug0MP9e5vUxfBcE3vUkV6CImK3cMNMIDFjmzxU=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
|
||||
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
|
||||
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/grpc/client"
|
||||
"tunnel_pls/internal/key"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/internal/transport"
|
||||
"tunnel_pls/internal/version"
|
||||
"tunnel_pls/server"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Bootstrap struct {
|
||||
Randomizer random.Random
|
||||
Config config.Config
|
||||
SessionRegistry registry.Registry
|
||||
Port port.Port
|
||||
GrpcClient client.Client
|
||||
ErrChan chan error
|
||||
SignalChan chan os.Signal
|
||||
}
|
||||
|
||||
func New(config config.Config, port port.Port) (*Bootstrap, error) {
|
||||
randomizer := random.New()
|
||||
sessionRegistry := registry.NewRegistry()
|
||||
|
||||
if err := port.AddRange(config.AllowedPortsStart(), config.AllowedPortsEnd()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
grpcClient, err := client.New(config, sessionRegistry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errChan := make(chan error, 5)
|
||||
signalChan := make(chan os.Signal, 1)
|
||||
|
||||
return &Bootstrap{
|
||||
Randomizer: randomizer,
|
||||
Config: config,
|
||||
SessionRegistry: sessionRegistry,
|
||||
Port: port,
|
||||
GrpcClient: grpcClient,
|
||||
ErrChan: errChan,
|
||||
SignalChan: signalChan,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) {
|
||||
sshCfg := &ssh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
|
||||
}
|
||||
|
||||
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
|
||||
return nil, fmt.Errorf("generate ssh key: %w", err)
|
||||
}
|
||||
privateBytes, err := os.ReadFile(sshKeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read private key: %w", err)
|
||||
}
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
sshCfg.AddHostKey(private)
|
||||
return sshCfg, nil
|
||||
}
|
||||
|
||||
func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, errChan chan<- error) error {
|
||||
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer healthCancel()
|
||||
if err := b.GrpcClient.CheckServerHealth(healthCtx); err != nil {
|
||||
return fmt.Errorf("gRPC health check failed: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := b.GrpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
|
||||
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
|
||||
httpserver := transport.NewHTTPServer(conf, registry)
|
||||
ln, err := httpserver.Listen()
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to start http server: %w", err)
|
||||
return
|
||||
}
|
||||
if err = httpserver.Serve(ln); err != nil {
|
||||
errChan <- fmt.Errorf("error when serving http server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
func startHTTPSServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
|
||||
tlsCfg, err := transport.NewTLSConfig(conf)
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
|
||||
return
|
||||
}
|
||||
httpsServer := transport.NewHTTPSServer(conf, registry, tlsCfg)
|
||||
ln, err := httpsServer.Listen()
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
|
||||
return
|
||||
}
|
||||
if err = httpsServer.Serve(ln); err != nil {
|
||||
errChan <- fmt.Errorf("error when serving https server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, errChan chan<- error) {
|
||||
sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, conf.SSHPort())
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
sshServer.Start()
|
||||
|
||||
errChan <- sshServer.Close()
|
||||
}
|
||||
|
||||
func startPprof(pprofPort string, errChan chan<- error) {
|
||||
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
|
||||
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
|
||||
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||
errChan <- fmt.Errorf("pprof server error: %v", err)
|
||||
}
|
||||
}
|
||||
func (b *Bootstrap) Run() error {
|
||||
sshConfig, err := newSSHConfig(b.Config.KeyLoc())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create SSH config: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
signal.Notify(b.SignalChan, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
if b.Config.Mode() == types.ServerModeNODE {
|
||||
err = b.startGRPCClient(ctx, b.Config, b.ErrChan)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start gRPC client: %w", err)
|
||||
}
|
||||
defer func(grpcClient client.Client) {
|
||||
err = grpcClient.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close gRPC client")
|
||||
}
|
||||
}(b.GrpcClient)
|
||||
}
|
||||
|
||||
go startHTTPServer(b.Config, b.SessionRegistry, b.ErrChan)
|
||||
|
||||
if b.Config.TLSEnabled() {
|
||||
go startHTTPSServer(b.Config, b.SessionRegistry, b.ErrChan)
|
||||
}
|
||||
|
||||
go func() {
|
||||
startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, b.GrpcClient, b.Port, b.ErrChan)
|
||||
}()
|
||||
|
||||
if b.Config.PprofEnabled() {
|
||||
go startPprof(b.Config.PprofPort(), b.ErrChan)
|
||||
}
|
||||
|
||||
log.Println("All services started successfully")
|
||||
|
||||
select {
|
||||
case err = <-b.ErrChan:
|
||||
return fmt.Errorf("service error: %w", err)
|
||||
case sig := <-b.SignalChan:
|
||||
log.Printf("Received signal %s, initiating graceful shutdown", sig)
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,558 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||
args := m.Called(user, oldKey, newKey)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||
args := m.Called(key, session)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
args := m.Called(user)
|
||||
return args.Get(0).([]registry.Session)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
type MockRandom struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRandom) String(length int) (string, error) {
|
||||
args := m.Called(length)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
type MockConfig struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockConfig) Domain() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) Mode() types.ServerMode {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return 0
|
||||
}
|
||||
switch v := args.Get(0).(type) {
|
||||
case types.ServerMode:
|
||||
return v
|
||||
case int:
|
||||
return types.ServerMode(v)
|
||||
default:
|
||||
return types.ServerMode(args.Int(0))
|
||||
}
|
||||
}
|
||||
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
||||
|
||||
type MockPort struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
||||
return m.Called(startPort, endPort).Error(0)
|
||||
}
|
||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
||||
args := m.Called()
|
||||
var mPort uint16
|
||||
if args.Get(0) != nil {
|
||||
switch v := args.Get(0).(type) {
|
||||
case int:
|
||||
mPort = uint16(v)
|
||||
case uint16:
|
||||
mPort = v
|
||||
case uint32:
|
||||
mPort = uint16(v)
|
||||
case int32:
|
||||
mPort = uint16(v)
|
||||
case float64:
|
||||
mPort = uint16(v)
|
||||
default:
|
||||
mPort = uint16(args.Int(0))
|
||||
}
|
||||
}
|
||||
return mPort, args.Bool(1)
|
||||
}
|
||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
||||
return m.Called(port, assigned).Error(0)
|
||||
}
|
||||
func (m *MockPort) Claim(port uint16) bool {
|
||||
return m.Called(port).Bool(0)
|
||||
}
|
||||
|
||||
type MockGRPCClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*grpc.ClientConn)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
||||
args := m.Called(ctx, token)
|
||||
return args.Bool(0), args.String(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
|
||||
args := m.Called(ctx, domain, token)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupConfig func() config.Config
|
||||
setupPort func() port.Port
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Success New with default value",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Error when AddRange fails",
|
||||
setupPort: func() port.Port {
|
||||
mockPort := &MockPort{}
|
||||
mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range"))
|
||||
return mockPort
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "invalid port range",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var mockPort port.Port
|
||||
if tt.setupPort != nil {
|
||||
mockPort = tt.setupPort()
|
||||
} else {
|
||||
mockPort = port.New()
|
||||
}
|
||||
|
||||
var mockConfig config.Config
|
||||
if tt.setupConfig != nil {
|
||||
mockConfig = tt.setupConfig()
|
||||
} else {
|
||||
var err error
|
||||
mockConfig, err = config.MustLoad()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
bootstrap, err := New(mockConfig, mockPort)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, bootstrap)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, bootstrap)
|
||||
assert.NotNil(t, bootstrap.Randomizer)
|
||||
assert.NotNil(t, bootstrap.SessionRegistry)
|
||||
assert.NotNil(t, bootstrap.Config)
|
||||
assert.NotNil(t, bootstrap.Port)
|
||||
assert.NotNil(t, bootstrap.ErrChan)
|
||||
assert.NotNil(t, bootstrap.SignalChan)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func randomAvailablePort() (string, error) {
|
||||
listener, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func(listener net.Listener) {
|
||||
_ = listener.Close()
|
||||
}(listener)
|
||||
|
||||
mPort := listener.Addr().(*net.TCPAddr).Port
|
||||
return strconv.Itoa(mPort), nil
|
||||
}
|
||||
|
||||
func TestRun(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockErrChan := make(chan error, 1)
|
||||
mockSignalChan := make(chan os.Signal, 1)
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
keyLoc := filepath.Join(tmpDir, "key.key")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupConfig func() *MockConfig
|
||||
setupGrpcClient func() *MockGRPCClient
|
||||
needCerts bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "successful run and termination",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "error from SSH server invalid port",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("invalid")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "error from HTTP server invalid port",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("invalid")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "error from HTTPS server invalid port",
|
||||
setupConfig: func() *MockConfig {
|
||||
tempDir := os.TempDir()
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("invalid")
|
||||
mockConfig.On("TLSEnabled").Return(true)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("TLSStoragePath").Return(tempDir)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "grpc health check failed",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("invalid")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
setupGrpcClient: func() *MockGRPCClient {
|
||||
mockGRPCClient := &MockGRPCClient{}
|
||||
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed"))
|
||||
return mockGRPCClient
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "successful run with pprof enabled",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
pprofPort, _ := randomAvailablePort()
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(true)
|
||||
mockConfig.On("PprofPort").Return(pprofPort)
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
expectError: false,
|
||||
}, {
|
||||
name: "successful run in NODE mode with signal",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
setupGrpcClient: func() *MockGRPCClient {
|
||||
mockGRPCClient := &MockGRPCClient{}
|
||||
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
|
||||
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
mockGRPCClient.On("Close").Return(nil)
|
||||
return mockGRPCClient
|
||||
},
|
||||
expectError: false,
|
||||
}, {
|
||||
name: "successful run in NODE mode with signal buf error when closing",
|
||||
setupConfig: func() *MockConfig {
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("KeyLoc").Return(keyLoc)
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("SSHPort").Return("0")
|
||||
mockConfig.On("HTTPPort").Return("0")
|
||||
mockConfig.On("HTTPSPort").Return("0")
|
||||
mockConfig.On("TLSEnabled").Return(false)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
mockConfig.On("ACMEEmail").Return("test@example.com")
|
||||
mockConfig.On("CFAPIToken").Return("fake-token")
|
||||
mockConfig.On("ACMEStaging").Return(true)
|
||||
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
|
||||
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
|
||||
mockConfig.On("BufferSize").Return(4096)
|
||||
mockConfig.On("PprofEnabled").Return(false)
|
||||
mockConfig.On("PprofPort").Return("0")
|
||||
mockConfig.On("GRPCAddress").Return("localhost")
|
||||
mockConfig.On("GRPCPort").Return("0")
|
||||
mockConfig.On("NodeToken").Return("fake-node-token")
|
||||
return mockConfig
|
||||
},
|
||||
setupGrpcClient: func() *MockGRPCClient {
|
||||
mockGRPCClient := &MockGRPCClient{}
|
||||
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
|
||||
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
|
||||
mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy"))
|
||||
return mockGRPCClient
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockConfig := tt.setupConfig()
|
||||
mockGRPCClient := &MockGRPCClient{}
|
||||
bootstrap := &Bootstrap{
|
||||
Randomizer: mockRandom,
|
||||
Config: mockConfig,
|
||||
SessionRegistry: mockSessionRegistry,
|
||||
Port: mockPort,
|
||||
ErrChan: mockErrChan,
|
||||
SignalChan: mockSignalChan,
|
||||
GrpcClient: mockGRPCClient,
|
||||
}
|
||||
|
||||
if tt.setupGrpcClient != nil {
|
||||
bootstrap.GrpcClient = tt.setupGrpcClient()
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- bootstrap.Run()
|
||||
}()
|
||||
|
||||
if tt.expectError {
|
||||
err := <-done
|
||||
assert.Error(t, err)
|
||||
} else if tt.name == "successful run with pprof enabled" {
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
fmt.Println(mockConfig.PprofPort())
|
||||
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
err = resp.Body.Close()
|
||||
assert.NoError(t, err)
|
||||
mockSignalChan <- os.Interrupt
|
||||
err = <-done
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
time.Sleep(time.Second)
|
||||
mockSignalChan <- os.Interrupt
|
||||
err := <-done
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+60
-25
@@ -1,35 +1,70 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
import "tunnel_pls/types"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
type Config interface {
|
||||
Domain() string
|
||||
SSHPort() string
|
||||
|
||||
func init() {
|
||||
if _, err := os.Stat(".env"); err == nil {
|
||||
if err := godotenv.Load(".env"); err != nil {
|
||||
log.Printf("Warning: Failed to load .env file: %s", err)
|
||||
}
|
||||
}
|
||||
HTTPPort() string
|
||||
HTTPSPort() string
|
||||
|
||||
KeyLoc() string
|
||||
|
||||
TLSEnabled() bool
|
||||
TLSRedirect() bool
|
||||
TLSStoragePath() string
|
||||
|
||||
ACMEEmail() string
|
||||
CFAPIToken() string
|
||||
ACMEStaging() bool
|
||||
|
||||
AllowedPortsStart() uint16
|
||||
AllowedPortsEnd() uint16
|
||||
|
||||
BufferSize() int
|
||||
HeaderSize() int
|
||||
|
||||
PprofEnabled() bool
|
||||
PprofPort() string
|
||||
|
||||
Mode() types.ServerMode
|
||||
GRPCAddress() string
|
||||
GRPCPort() string
|
||||
NodeToken() string
|
||||
}
|
||||
|
||||
func Getenv(key, defaultValue string) string {
|
||||
val := os.Getenv(key)
|
||||
if val == "" {
|
||||
val = defaultValue
|
||||
func MustLoad() (Config, error) {
|
||||
if err := loadEnvFile(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return val
|
||||
cfg, err := parse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func GetBufferSize() int {
|
||||
sizeStr := Getenv("BUFFER_SIZE", "32768")
|
||||
size, err := strconv.Atoi(sizeStr)
|
||||
if err != nil || size < 4096 || size > 1048576 {
|
||||
return 32768
|
||||
}
|
||||
return size
|
||||
}
|
||||
func (c *config) Domain() string { return c.domain }
|
||||
func (c *config) SSHPort() string { return c.sshPort }
|
||||
func (c *config) HTTPPort() string { return c.httpPort }
|
||||
func (c *config) HTTPSPort() string { return c.httpsPort }
|
||||
func (c *config) KeyLoc() string { return c.keyLoc }
|
||||
func (c *config) TLSEnabled() bool { return c.tlsEnabled }
|
||||
func (c *config) TLSRedirect() bool { return c.tlsRedirect }
|
||||
func (c *config) TLSStoragePath() string { return c.tlsStoragePath }
|
||||
func (c *config) ACMEEmail() string { return c.acmeEmail }
|
||||
func (c *config) CFAPIToken() string { return c.cfAPIToken }
|
||||
func (c *config) ACMEStaging() bool { return c.acmeStaging }
|
||||
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
|
||||
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
|
||||
func (c *config) BufferSize() int { return c.bufferSize }
|
||||
func (c *config) HeaderSize() int { return c.headerSize }
|
||||
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
|
||||
func (c *config) PprofPort() string { return c.pprofPort }
|
||||
func (c *config) Mode() types.ServerMode { return c.mode }
|
||||
func (c *config) GRPCAddress() string { return c.grpcAddress }
|
||||
func (c *config) GRPCPort() string { return c.grpcPort }
|
||||
func (c *config) NodeToken() string { return c.nodeToken }
|
||||
|
||||
@@ -0,0 +1,405 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetenv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
val string
|
||||
def string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "returns existing env",
|
||||
key: "TEST_ENV_EXIST",
|
||||
val: "value",
|
||||
def: "default",
|
||||
expected: "value",
|
||||
},
|
||||
{
|
||||
name: "returns default when env missing",
|
||||
key: "TEST_ENV_MISSING",
|
||||
val: "",
|
||||
def: "default",
|
||||
expected: "default",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv(tt.key, tt.val)
|
||||
} else {
|
||||
err := os.Unsetenv(tt.key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetenvBool(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
key string
|
||||
val string
|
||||
def bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "returns true when env is true",
|
||||
key: "TEST_BOOL_TRUE",
|
||||
val: "true",
|
||||
def: false,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "returns false when env is false",
|
||||
key: "TEST_BOOL_FALSE",
|
||||
val: "false",
|
||||
def: true,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "returns default when env missing",
|
||||
key: "TEST_BOOL_MISSING",
|
||||
val: "",
|
||||
def: true,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "returns false when env is not true",
|
||||
key: "TEST_BOOL_INVALID",
|
||||
val: "yes",
|
||||
def: true,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv(tt.key, tt.val)
|
||||
} else {
|
||||
err := os.Unsetenv(tt.key)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode string
|
||||
expect types.ServerMode
|
||||
expectErr bool
|
||||
}{
|
||||
{"standalone", "standalone", types.ServerModeSTANDALONE, false},
|
||||
{"node", "node", types.ServerModeNODE, false},
|
||||
{"uppercase", "STANDALONE", types.ServerModeSTANDALONE, false},
|
||||
{"invalid", "invalid", 0, true},
|
||||
{"empty (default)", "", types.ServerModeSTANDALONE, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.mode != "" {
|
||||
t.Setenv("MODE", tt.mode)
|
||||
} else {
|
||||
err := os.Unsetenv("MODE")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
mode, err := parseMode()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expect, mode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAllowedPorts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
start uint16
|
||||
end uint16
|
||||
expectErr bool
|
||||
}{
|
||||
{"valid range", "1000-2000", 1000, 2000, false},
|
||||
{"empty", "", 0, 0, false},
|
||||
{"invalid format - no dash", "1000", 0, 0, true},
|
||||
{"invalid format - too many dashes", "1000-2000-3000", 0, 0, true},
|
||||
{"invalid start port", "abc-2000", 0, 0, true},
|
||||
{"invalid end port", "1000-abc", 0, 0, true},
|
||||
{"out of range start", "70000-80000", 0, 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv("ALLOWED_PORTS", tt.val)
|
||||
} else {
|
||||
err := os.Unsetenv("ALLOWED_PORTS")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
start, end, err := parseAllowedPorts()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.start, start)
|
||||
assert.Equal(t, tt.end, end)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBufferSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
expect int
|
||||
}{
|
||||
{"valid size", "8192", 8192},
|
||||
{"default size", "", 32768},
|
||||
{"too small", "1024", 4096},
|
||||
{"too large", "2000000", 4096},
|
||||
{"invalid format", "abc", 4096},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv("BUFFER_SIZE", tt.val)
|
||||
} else {
|
||||
err := os.Unsetenv("BUFFER_SIZE")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
size := parseBufferSize()
|
||||
assert.Equal(t, tt.expect, size)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHeaderSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
val string
|
||||
expect int
|
||||
}{
|
||||
{"valid size", "8192", 8192},
|
||||
{"default size", "", 4096},
|
||||
{"too small", "1024", 4096},
|
||||
{"too large", "2000000", 4096},
|
||||
{"invalid format", "abc", 4096},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.val != "" {
|
||||
t.Setenv("MAX_HEADER_SIZE", tt.val)
|
||||
} else {
|
||||
err := os.Unsetenv("MAX_HEADER_SIZE")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
size := parseHeaderSize()
|
||||
assert.Equal(t, tt.expect, size)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envs map[string]string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "minimal valid config",
|
||||
envs: map[string]string{
|
||||
"DOMAIN": "example.com",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "TLS enabled without token",
|
||||
envs: map[string]string{
|
||||
"TLS_ENABLED": "true",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "TLS enabled with token",
|
||||
envs: map[string]string{
|
||||
"TLS_ENABLED": "true",
|
||||
"CF_API_TOKEN": "secret",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Node mode without token",
|
||||
envs: map[string]string{
|
||||
"MODE": "node",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Node mode with token",
|
||||
envs: map[string]string{
|
||||
"MODE": "node",
|
||||
"NODE_TOKEN": "token",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid mode",
|
||||
envs: map[string]string{
|
||||
"MODE": "invalid",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid allowed ports",
|
||||
envs: map[string]string{
|
||||
"ALLOWED_PORTS": "1000",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
os.Clearenv()
|
||||
for k, v := range tt.envs {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
cfg, err := parse()
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, cfg)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetters(t *testing.T) {
|
||||
envs := map[string]string{
|
||||
"DOMAIN": "example.com",
|
||||
"PORT": "2222",
|
||||
"HTTP_PORT": "80",
|
||||
"HTTPS_PORT": "443",
|
||||
"KEY_LOC": "certs/ssh/id_rsa",
|
||||
"TLS_ENABLED": "true",
|
||||
"TLS_REDIRECT": "true",
|
||||
"TLS_STORAGE_PATH": "certs/tls/",
|
||||
"ACME_EMAIL": "test@example.com",
|
||||
"CF_API_TOKEN": "token",
|
||||
"ACME_STAGING": "true",
|
||||
"ALLOWED_PORTS": "1000-2000",
|
||||
"BUFFER_SIZE": "16384",
|
||||
"MAX_HEADER_SIZE": "4096",
|
||||
"PPROF_ENABLED": "true",
|
||||
"PPROF_PORT": "7070",
|
||||
"MODE": "standalone",
|
||||
"GRPC_ADDRESS": "127.0.0.1",
|
||||
"GRPC_PORT": "9090",
|
||||
"NODE_TOKEN": "ntoken",
|
||||
}
|
||||
|
||||
os.Clearenv()
|
||||
for k, v := range envs {
|
||||
t.Setenv(k, v)
|
||||
}
|
||||
|
||||
cfg, err := parse()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "example.com", cfg.Domain())
|
||||
assert.Equal(t, "2222", cfg.SSHPort())
|
||||
assert.Equal(t, "80", cfg.HTTPPort())
|
||||
assert.Equal(t, "443", cfg.HTTPSPort())
|
||||
assert.Equal(t, "certs/ssh/id_rsa", cfg.KeyLoc())
|
||||
assert.Equal(t, true, cfg.TLSEnabled())
|
||||
assert.Equal(t, true, cfg.TLSRedirect())
|
||||
assert.Equal(t, "certs/tls/", cfg.TLSStoragePath())
|
||||
assert.Equal(t, "test@example.com", cfg.ACMEEmail())
|
||||
assert.Equal(t, "token", cfg.CFAPIToken())
|
||||
assert.Equal(t, true, cfg.ACMEStaging())
|
||||
assert.Equal(t, uint16(1000), cfg.AllowedPortsStart())
|
||||
assert.Equal(t, uint16(2000), cfg.AllowedPortsEnd())
|
||||
assert.Equal(t, 16384, cfg.BufferSize())
|
||||
assert.Equal(t, 4096, cfg.HeaderSize())
|
||||
assert.Equal(t, true, cfg.PprofEnabled())
|
||||
assert.Equal(t, "7070", cfg.PprofPort())
|
||||
assert.Equal(t, types.ServerMode(types.ServerModeSTANDALONE), cfg.Mode())
|
||||
assert.Equal(t, "127.0.0.1", cfg.GRPCAddress())
|
||||
assert.Equal(t, "9090", cfg.GRPCPort())
|
||||
assert.Equal(t, "ntoken", cfg.NodeToken())
|
||||
}
|
||||
|
||||
func TestMustLoad(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
os.Clearenv()
|
||||
t.Setenv("DOMAIN", "example.com")
|
||||
cfg, err := MustLoad()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cfg)
|
||||
})
|
||||
|
||||
t.Run("loadEnvFile error", func(t *testing.T) {
|
||||
err := os.Mkdir(".env", 0755)
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
err = os.Remove(".env")
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
cfg, err := MustLoad()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, cfg)
|
||||
})
|
||||
|
||||
t.Run("parse error", func(t *testing.T) {
|
||||
os.Clearenv()
|
||||
t.Setenv("MODE", "invalid")
|
||||
cfg, err := MustLoad()
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, cfg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadEnvFile(t *testing.T) {
|
||||
t.Run("file exists", func(t *testing.T) {
|
||||
err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
|
||||
assert.NoError(t, err)
|
||||
defer func() {
|
||||
err = os.Remove(".env")
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
err = loadEnvFile()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "true", os.Getenv("TEST_ENV_FILE"))
|
||||
})
|
||||
|
||||
t.Run("file missing", func(t *testing.T) {
|
||||
_ = os.Remove(".env")
|
||||
err := loadEnvFile()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
domain string
|
||||
sshPort string
|
||||
|
||||
httpPort string
|
||||
httpsPort string
|
||||
|
||||
keyLoc string
|
||||
|
||||
tlsEnabled bool
|
||||
tlsRedirect bool
|
||||
tlsStoragePath string
|
||||
acmeEmail string
|
||||
cfAPIToken string
|
||||
acmeStaging bool
|
||||
|
||||
allowedPortsStart uint16
|
||||
allowedPortsEnd uint16
|
||||
|
||||
bufferSize int
|
||||
headerSize int
|
||||
|
||||
pprofEnabled bool
|
||||
pprofPort string
|
||||
|
||||
mode types.ServerMode
|
||||
grpcAddress string
|
||||
grpcPort string
|
||||
nodeToken string
|
||||
}
|
||||
|
||||
func parse() (*config, error) {
|
||||
mode, err := parseMode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
domain := getenv("DOMAIN", "localhost")
|
||||
sshPort := getenv("PORT", "2200")
|
||||
|
||||
httpPort := getenv("HTTP_PORT", "8080")
|
||||
httpsPort := getenv("HTTPS_PORT", "8443")
|
||||
|
||||
keyLoc := getenv("KEY_LOC", "certs/privkey.pem")
|
||||
|
||||
tlsEnabled := getenvBool("TLS_ENABLED", false)
|
||||
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
|
||||
tlsStoragePath := getenv("TLS_STORAGE_PATH", "certs/tls/")
|
||||
|
||||
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
|
||||
acmeStaging := getenvBool("ACME_STAGING", false)
|
||||
|
||||
cfToken := getenv("CF_API_TOKEN", "")
|
||||
if tlsEnabled && cfToken == "" {
|
||||
return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled")
|
||||
}
|
||||
|
||||
start, end, err := parseAllowedPorts()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bufferSize := parseBufferSize()
|
||||
headerSize := parseHeaderSize()
|
||||
|
||||
pprofEnabled := getenvBool("PPROF_ENABLED", false)
|
||||
pprofPort := getenv("PPROF_PORT", "6060")
|
||||
|
||||
grpcHost := getenv("GRPC_ADDRESS", "localhost")
|
||||
grpcPort := getenv("GRPC_PORT", "8080")
|
||||
|
||||
nodeToken := getenv("NODE_TOKEN", "")
|
||||
if mode == types.ServerModeNODE && nodeToken == "" {
|
||||
return nil, fmt.Errorf("NODE_TOKEN is required in node mode")
|
||||
}
|
||||
|
||||
return &config{
|
||||
domain: domain,
|
||||
sshPort: sshPort,
|
||||
httpPort: httpPort,
|
||||
httpsPort: httpsPort,
|
||||
keyLoc: keyLoc,
|
||||
tlsEnabled: tlsEnabled,
|
||||
tlsRedirect: tlsRedirect,
|
||||
tlsStoragePath: tlsStoragePath,
|
||||
acmeEmail: acmeEmail,
|
||||
cfAPIToken: cfToken,
|
||||
acmeStaging: acmeStaging,
|
||||
allowedPortsStart: start,
|
||||
allowedPortsEnd: end,
|
||||
bufferSize: bufferSize,
|
||||
headerSize: headerSize,
|
||||
pprofEnabled: pprofEnabled,
|
||||
pprofPort: pprofPort,
|
||||
mode: mode,
|
||||
grpcAddress: grpcHost,
|
||||
grpcPort: grpcPort,
|
||||
nodeToken: nodeToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func loadEnvFile() error {
|
||||
if _, err := os.Stat(".env"); err == nil {
|
||||
return godotenv.Load(".env")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseMode() (types.ServerMode, error) {
|
||||
switch strings.ToLower(getenv("MODE", "standalone")) {
|
||||
case "standalone":
|
||||
return types.ServerModeSTANDALONE, nil
|
||||
case "node":
|
||||
return types.ServerModeNODE, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid MODE value")
|
||||
}
|
||||
}
|
||||
|
||||
func parseAllowedPorts() (uint16, uint16, error) {
|
||||
raw := getenv("ALLOWED_PORTS", "")
|
||||
if raw == "" {
|
||||
return 0, 0, nil
|
||||
}
|
||||
|
||||
parts := strings.Split(raw, "-")
|
||||
if len(parts) != 2 {
|
||||
return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format")
|
||||
}
|
||||
|
||||
start, err := strconv.ParseUint(parts[0], 10, 16)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
end, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return uint16(start), uint16(end), nil
|
||||
}
|
||||
|
||||
func parseBufferSize() int {
|
||||
raw := getenv("BUFFER_SIZE", "32768")
|
||||
size, err := strconv.Atoi(raw)
|
||||
if err != nil || size < 4096 || size > 1048576 {
|
||||
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
|
||||
return 4096
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func parseHeaderSize() int {
|
||||
raw := getenv("MAX_HEADER_SIZE", "4096")
|
||||
size, err := strconv.Atoi(raw)
|
||||
if err != nil || size < 4096 || size > 131072 {
|
||||
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
|
||||
return 4096
|
||||
}
|
||||
return size
|
||||
}
|
||||
|
||||
func getenv(key, def string) string {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
return v
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func getenvBool(key string, def bool) bool {
|
||||
val := os.Getenv(key)
|
||||
if val == "" {
|
||||
return def
|
||||
}
|
||||
return val == "true"
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/types"
|
||||
|
||||
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
SubscribeEvents(ctx context.Context, identity, authToken string) error
|
||||
ClientConn() *grpc.ClientConn
|
||||
AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error)
|
||||
Close() error
|
||||
CheckServerHealth(ctx context.Context) error
|
||||
}
|
||||
type client struct {
|
||||
config config.Config
|
||||
conn *grpc.ClientConn
|
||||
address string
|
||||
sessionRegistry registry.Registry
|
||||
eventService proto.EventServiceClient
|
||||
authorizeConnectionService proto.UserServiceClient
|
||||
closing bool
|
||||
}
|
||||
|
||||
var (
|
||||
grpcNewClient = grpc.NewClient
|
||||
healthNewHealthClient = grpc_health_v1.NewHealthClient
|
||||
initialBackoff = time.Second
|
||||
)
|
||||
|
||||
func New(config config.Config, sessionRegistry registry.Registry) (Client, error) {
|
||||
address := fmt.Sprintf("%s:%s", config.GRPCAddress(), config.GRPCPort())
|
||||
|
||||
var opts []grpc.DialOption
|
||||
|
||||
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
|
||||
kaParams := keepalive.ClientParameters{
|
||||
Time: 2 * time.Minute,
|
||||
Timeout: 10 * time.Second,
|
||||
PermitWithoutStream: false,
|
||||
}
|
||||
|
||||
opts = append(opts, grpc.WithKeepaliveParams(kaParams))
|
||||
|
||||
opts = append(opts,
|
||||
grpc.WithDefaultCallOptions(
|
||||
grpc.MaxCallRecvMsgSize(4*1024*1024),
|
||||
grpc.MaxCallSendMsgSize(4*1024*1024),
|
||||
),
|
||||
)
|
||||
|
||||
conn, err := grpcNewClient(address, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err)
|
||||
}
|
||||
|
||||
eventService := proto.NewEventServiceClient(conn)
|
||||
authorizeConnectionService := proto.NewUserServiceClient(conn)
|
||||
|
||||
return &client{
|
||||
config: config,
|
||||
conn: conn,
|
||||
address: address,
|
||||
sessionRegistry: sessionRegistry,
|
||||
eventService: eventService,
|
||||
authorizeConnectionService: authorizeConnectionService,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *client) SubscribeEvents(ctx context.Context, identity, authToken string) error {
|
||||
backoff := initialBackoff
|
||||
|
||||
for {
|
||||
if err := c.subscribeAndProcess(ctx, identity, authToken, &backoff); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) subscribeAndProcess(ctx context.Context, identity, authToken string, backoff *time.Duration) error {
|
||||
subscribe, err := c.eventService.Subscribe(ctx)
|
||||
if err != nil {
|
||||
return c.handleSubscribeError(ctx, err, backoff)
|
||||
}
|
||||
|
||||
err = subscribe.Send(&proto.Node{
|
||||
Type: proto.EventType_AUTHENTICATION,
|
||||
Payload: &proto.Node_AuthEvent{
|
||||
AuthEvent: &proto.Authentication{
|
||||
Identity: identity,
|
||||
AuthToken: authToken,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return c.handleAuthError(ctx, err, backoff)
|
||||
}
|
||||
|
||||
log.Println("Authentication Successfully sent to gRPC server")
|
||||
*backoff = time.Second
|
||||
|
||||
return c.handleStreamError(ctx, c.processEventStream(subscribe), backoff)
|
||||
}
|
||||
|
||||
func (c *client) handleSubscribeError(ctx context.Context, err error, backoff *time.Duration) error {
|
||||
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
||||
return err
|
||||
}
|
||||
if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated {
|
||||
return err
|
||||
}
|
||||
if err = c.wait(ctx, *backoff); err != nil {
|
||||
return err
|
||||
}
|
||||
c.growBackoff(backoff)
|
||||
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleAuthError(ctx context.Context, err error, backoff *time.Duration) error {
|
||||
log.Println("Authentication failed to send to gRPC server:", err)
|
||||
if !c.isConnectionError(err) {
|
||||
return err
|
||||
}
|
||||
if err := c.wait(ctx, *backoff); err != nil {
|
||||
return err
|
||||
}
|
||||
c.growBackoff(backoff)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) handleStreamError(ctx context.Context, err error, backoff *time.Duration) error {
|
||||
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
|
||||
return err
|
||||
}
|
||||
if !c.isConnectionError(err) {
|
||||
return err
|
||||
}
|
||||
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
|
||||
if err := c.wait(ctx, *backoff); err != nil {
|
||||
return err
|
||||
}
|
||||
c.growBackoff(backoff)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) wait(ctx context.Context, duration time.Duration) error {
|
||||
if duration <= 0 {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-time.After(duration):
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) growBackoff(backoff *time.Duration) {
|
||||
const maxBackoff = 30 * time.Second
|
||||
*backoff *= 2
|
||||
if *backoff > maxBackoff {
|
||||
*backoff = maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error {
|
||||
handlers := c.eventHandlers(subscribe)
|
||||
|
||||
for {
|
||||
recv, err := subscribe.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
handler, ok := handlers[recv.GetType()]
|
||||
if !ok {
|
||||
log.Printf("Unknown event type received: %v", recv.GetType())
|
||||
continue
|
||||
}
|
||||
|
||||
if err = handler(recv); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) map[proto.EventType]func(*proto.Events) error {
|
||||
return map[proto.EventType]func(*proto.Events) error{
|
||||
proto.EventType_SLUG_CHANGE: func(evt *proto.Events) error { return c.handleSlugChange(subscribe, evt) },
|
||||
proto.EventType_GET_SESSIONS: func(evt *proto.Events) error { return c.handleGetSessions(subscribe, evt) },
|
||||
proto.EventType_TERMINATE_SESSION: func(evt *proto.Events) error { return c.handleTerminateSession(subscribe, evt) },
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
|
||||
slugEvent := evt.GetSlugEvent()
|
||||
user := slugEvent.GetUser()
|
||||
oldKey := types.SessionKey{Id: slugEvent.GetOld(), Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: slugEvent.GetNew(), Type: types.TunnelTypeHTTP}
|
||||
|
||||
userSession, err := c.sessionRegistry.Get(oldKey)
|
||||
if err != nil {
|
||||
return c.sendSlugChangeResponse(subscribe, false, err.Error())
|
||||
}
|
||||
|
||||
if err = c.sessionRegistry.Update(user, oldKey, newKey); err != nil {
|
||||
return c.sendSlugChangeResponse(subscribe, false, err.Error())
|
||||
}
|
||||
|
||||
userSession.Interaction().Redraw()
|
||||
return c.sendSlugChangeResponse(subscribe, true, "")
|
||||
}
|
||||
|
||||
func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
|
||||
sessions := c.sessionRegistry.GetAllSessionFromUser(evt.GetGetSessionsEvent().GetIdentity())
|
||||
|
||||
var details []*proto.Detail
|
||||
for _, ses := range sessions {
|
||||
detail := ses.Detail()
|
||||
details = append(details, &proto.Detail{
|
||||
Node: c.config.Domain(),
|
||||
ForwardingType: detail.ForwardingType,
|
||||
Slug: detail.Slug,
|
||||
UserId: detail.UserID,
|
||||
Active: detail.Active,
|
||||
StartedAt: timestamppb.New(detail.StartedAt),
|
||||
})
|
||||
}
|
||||
|
||||
return c.sendGetSessionsResponse(subscribe, details)
|
||||
}
|
||||
|
||||
func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
|
||||
terminate := evt.GetTerminateSessionEvent()
|
||||
user := terminate.GetUser()
|
||||
slug := terminate.GetSlug()
|
||||
|
||||
tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType())
|
||||
if err != nil {
|
||||
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
|
||||
}
|
||||
|
||||
userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType})
|
||||
if err != nil {
|
||||
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
|
||||
}
|
||||
|
||||
if err = userSession.Lifecycle().Close(); err != nil {
|
||||
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
|
||||
}
|
||||
|
||||
return c.sendTerminateSessionResponse(subscribe, true, "")
|
||||
}
|
||||
|
||||
func (c *client) sendSlugChangeResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
|
||||
Payload: &proto.Node_SlugEventResponse{
|
||||
SlugEventResponse: &proto.SlugChangeEventResponse{Success: success, Message: message},
|
||||
},
|
||||
}, "slug change response")
|
||||
}
|
||||
|
||||
func (c *client) sendGetSessionsResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], details []*proto.Detail) error {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_GET_SESSIONS,
|
||||
Payload: &proto.Node_GetSessionsEvent{
|
||||
GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
|
||||
},
|
||||
}, "send get sessions response")
|
||||
}
|
||||
|
||||
func (c *client) sendTerminateSessionResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
|
||||
return c.sendNode(subscribe, &proto.Node{
|
||||
Type: proto.EventType_TERMINATE_SESSION,
|
||||
Payload: &proto.Node_TerminateSessionEventResponse{
|
||||
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: success, Message: message},
|
||||
},
|
||||
}, "terminate session response")
|
||||
}
|
||||
|
||||
func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error {
|
||||
if err := subscribe.Send(node); err != nil {
|
||||
if c.isConnectionError(err) {
|
||||
return err
|
||||
}
|
||||
log.Printf("%s: %v", context, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
|
||||
switch t {
|
||||
case proto.TunnelType_HTTP:
|
||||
return types.TunnelTypeHTTP, nil
|
||||
case proto.TunnelType_TCP:
|
||||
return types.TunnelTypeTCP, nil
|
||||
default:
|
||||
return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) ClientConn() *grpc.ClientConn {
|
||||
return c.conn
|
||||
}
|
||||
|
||||
func (c *client) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
||||
check, err := c.authorizeConnectionService.Check(ctx, &proto.CheckRequest{AuthToken: token})
|
||||
if err != nil {
|
||||
return false, "UNAUTHORIZED", err
|
||||
}
|
||||
|
||||
if check.GetResponse() == proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED {
|
||||
return false, "UNAUTHORIZED", nil
|
||||
}
|
||||
return true, check.GetUser(), nil
|
||||
}
|
||||
|
||||
func (c *client) CheckServerHealth(ctx context.Context) error {
|
||||
healthClient := healthNewHealthClient(c.ClientConn())
|
||||
resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{
|
||||
Service: "",
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("health check failed: %w", err)
|
||||
}
|
||||
if resp.Status != grpc_health_v1.HealthCheckResponse_SERVING {
|
||||
return fmt.Errorf("server not serving: %v", resp.Status)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) Close() error {
|
||||
if c.conn != nil {
|
||||
log.Printf("Closing gRPC connection to %s", c.address)
|
||||
c.closing = true
|
||||
return c.conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) isConnectionError(err error) bool {
|
||||
if c.closing {
|
||||
return false
|
||||
}
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
switch status.Code(err) {
|
||||
case codes.Unavailable, codes.Canceled, codes.DeadlineExceeded:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,30 @@
|
||||
package header
|
||||
|
||||
type ResponseHeader interface {
|
||||
Value(key string) string
|
||||
Set(key string, value string)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
}
|
||||
|
||||
type responseHeader struct {
|
||||
startLine []byte
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
type RequestHeader interface {
|
||||
Value(key string) string
|
||||
Set(key string, value string)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
Method() string
|
||||
Path() string
|
||||
Version() string
|
||||
}
|
||||
type requestHeader struct {
|
||||
method string
|
||||
path string
|
||||
version string
|
||||
startLine []byte
|
||||
headers map[string]string
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package header
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expectErr bool
|
||||
errContains string
|
||||
expectMethod string
|
||||
expectPath string
|
||||
expectVersion string
|
||||
expectHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
data: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\nX-Custom: value\r\n\r\n"),
|
||||
expectErr: false,
|
||||
expectMethod: "GET",
|
||||
expectPath: "/path",
|
||||
expectVersion: "HTTP/1.1",
|
||||
expectHeaders: map[string]string{
|
||||
"Host": "example.com",
|
||||
"X-Custom": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no CRLF in start line",
|
||||
data: []byte("GET /path HTTP/1.1"),
|
||||
expectErr: true,
|
||||
errContains: "no CRLF found in start line",
|
||||
},
|
||||
{
|
||||
name: "invalid start line - missing method",
|
||||
data: []byte("INVALID\r\n\r\n"),
|
||||
expectErr: true,
|
||||
errContains: "invalid start line: missing method",
|
||||
},
|
||||
{
|
||||
name: "invalid start line - missing version",
|
||||
data: []byte("GET /path\r\n\r\n"),
|
||||
expectErr: true,
|
||||
errContains: "invalid start line: missing version",
|
||||
},
|
||||
{
|
||||
name: "invalid start line - multiple spaces",
|
||||
data: []byte("GET /path HTTP/1.1\r\n\r\n"),
|
||||
expectErr: false,
|
||||
expectMethod: "GET",
|
||||
expectPath: "",
|
||||
expectVersion: "/path HTTP/1.1",
|
||||
expectHeaders: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "start line with trailing space",
|
||||
data: []byte("GET / HTTP/1.1 \r\n\r\n"),
|
||||
expectErr: false,
|
||||
expectMethod: "GET",
|
||||
expectPath: "/",
|
||||
expectVersion: "HTTP/1.1 ",
|
||||
expectHeaders: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := NewRequest(tt.data)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, req)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, req)
|
||||
assert.Equal(t, tt.expectMethod, req.Method())
|
||||
assert.Equal(t, tt.expectPath, req.Path())
|
||||
assert.Equal(t, tt.expectVersion, req.Version())
|
||||
for k, v := range tt.expectHeaders {
|
||||
assert.Equal(t, v, req.Value(k))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestHeaderMethods(t *testing.T) {
|
||||
data := []byte("GET / HTTP/1.1\r\nHost: original\r\n\r\n")
|
||||
req, _ := NewRequest(data)
|
||||
|
||||
req.Set("Host", "updated")
|
||||
req.Set("X-New", "new-value")
|
||||
assert.Equal(t, "updated", req.Value("Host"))
|
||||
assert.Equal(t, "new-value", req.Value("X-New"))
|
||||
|
||||
assert.Equal(t, "", req.Value("Non-Existent"))
|
||||
|
||||
req.Remove("X-New")
|
||||
assert.Equal(t, "", req.Value("X-New"))
|
||||
|
||||
final := req.Finalize()
|
||||
assert.Contains(t, string(final), "GET / HTTP/1.1\r\n")
|
||||
assert.Contains(t, string(final), "Host: updated\r\n")
|
||||
assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
|
||||
}
|
||||
|
||||
func TestNewResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expectErr bool
|
||||
errContains string
|
||||
expectHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
data: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"),
|
||||
expectErr: false,
|
||||
expectHeaders: map[string]string{
|
||||
"Content-Length": "0",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid response - no CRLF",
|
||||
data: []byte("HTTP/1.1 200 OK"),
|
||||
expectErr: true,
|
||||
errContains: "no CRLF found in start line",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, err := NewResponse(tt.data)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tt.errContains)
|
||||
}
|
||||
assert.Nil(t, resp)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, resp)
|
||||
for k, v := range tt.expectHeaders {
|
||||
assert.Equal(t, v, resp.Value(k))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseHeaderMethods(t *testing.T) {
|
||||
data := []byte("HTTP/1.1 200 OK\r\nServer: old\r\n\r\n")
|
||||
resp, _ := NewResponse(data)
|
||||
|
||||
resp.Set("Server", "new")
|
||||
resp.Set("X-Res", "val")
|
||||
assert.Equal(t, "new", resp.Value("Server"))
|
||||
assert.Equal(t, "val", resp.Value("X-Res"))
|
||||
|
||||
resp.Remove("X-Res")
|
||||
assert.Equal(t, "", resp.Value("X-Res"))
|
||||
|
||||
final := resp.Finalize()
|
||||
assert.Contains(t, string(final), "HTTP/1.1 200 OK\r\n")
|
||||
assert.Contains(t, string(final), "Server: new\r\n")
|
||||
assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
|
||||
}
|
||||
|
||||
func TestSetRemainingHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
initialHeaders map[string]string
|
||||
expectHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "various header formats",
|
||||
data: []byte("K1: V1\r\nK2:V2\r\n K3 : V3 \r\nNoColon\r\n\r\n"),
|
||||
expectHeaders: map[string]string{
|
||||
"K1": "V1",
|
||||
"K2": "V2",
|
||||
"K3": "V3",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no trailing CRLF",
|
||||
data: []byte("K1: V1"),
|
||||
expectHeaders: map[string]string{
|
||||
"K1": "V1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty lines",
|
||||
data: []byte("\r\nK1: V1"),
|
||||
expectHeaders: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "headers with only colon",
|
||||
data: []byte(": value\r\nkey:\r\n"),
|
||||
expectHeaders: map[string]string{
|
||||
"": "value",
|
||||
"key": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := &requestHeader{headers: make(map[string]string)}
|
||||
if tt.initialHeaders != nil {
|
||||
req.headers = tt.initialHeaders
|
||||
}
|
||||
setRemainingHeaders(tt.data, req)
|
||||
assert.Equal(t, len(tt.expectHeaders), len(req.headers))
|
||||
for k, v := range tt.expectHeaders {
|
||||
assert.Equal(t, v, req.headers[k])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package header
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func setRemainingHeaders(remaining []byte, header interface {
|
||||
Set(key string, value string)
|
||||
}) {
|
||||
for len(remaining) > 0 {
|
||||
lineEnd := bytes.Index(remaining, []byte("\r\n"))
|
||||
if lineEnd == -1 {
|
||||
lineEnd = len(remaining)
|
||||
}
|
||||
|
||||
line := remaining[:lineEnd]
|
||||
|
||||
if len(line) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
colonIdx := bytes.IndexByte(line, ':')
|
||||
if colonIdx != -1 {
|
||||
key := bytes.TrimSpace(line[:colonIdx])
|
||||
value := bytes.TrimSpace(line[colonIdx+1:])
|
||||
header.Set(string(key), string(value))
|
||||
}
|
||||
|
||||
if lineEnd == len(remaining) {
|
||||
break
|
||||
}
|
||||
|
||||
remaining = remaining[lineEnd+2:]
|
||||
}
|
||||
}
|
||||
|
||||
func parseStartLine(startLine []byte) (method, path, version string, err error) {
|
||||
firstSpace := bytes.IndexByte(startLine, ' ')
|
||||
if firstSpace == -1 {
|
||||
return "", "", "", fmt.Errorf("invalid start line: missing method")
|
||||
}
|
||||
|
||||
secondSpace := bytes.IndexByte(startLine[firstSpace+1:], ' ')
|
||||
if secondSpace == -1 {
|
||||
return "", "", "", fmt.Errorf("invalid start line: missing version")
|
||||
}
|
||||
secondSpace += firstSpace + 1
|
||||
|
||||
method = string(startLine[:firstSpace])
|
||||
path = string(startLine[firstSpace+1 : secondSpace])
|
||||
version = string(startLine[secondSpace+1:])
|
||||
|
||||
return method, path, version, nil
|
||||
}
|
||||
|
||||
func finalize(startLine []byte, headers map[string]string) []byte {
|
||||
size := len(startLine) + 2
|
||||
for key, val := range headers {
|
||||
size += len(key) + 2 + len(val) + 2
|
||||
}
|
||||
size += 2
|
||||
|
||||
buf := make([]byte, 0, size)
|
||||
buf = append(buf, startLine...)
|
||||
buf = append(buf, '\r', '\n')
|
||||
|
||||
for key, val := range headers {
|
||||
buf = append(buf, key...)
|
||||
buf = append(buf, ':', ' ')
|
||||
buf = append(buf, val...)
|
||||
buf = append(buf, '\r', '\n')
|
||||
}
|
||||
|
||||
buf = append(buf, '\r', '\n')
|
||||
return buf
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package header
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func NewRequest(headerData []byte) (RequestHeader, error) {
|
||||
header := &requestHeader{
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
|
||||
lineEnd := bytes.Index(headerData, []byte("\r\n"))
|
||||
if lineEnd == -1 {
|
||||
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
|
||||
}
|
||||
|
||||
startLine := headerData[:lineEnd]
|
||||
header.startLine = startLine
|
||||
var err error
|
||||
header.method, header.path, header.version, err = parseStartLine(startLine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
remaining := headerData[lineEnd+2:]
|
||||
|
||||
setRemainingHeaders(remaining, header)
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func (req *requestHeader) Value(key string) string {
|
||||
val, ok := req.headers[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func (req *requestHeader) Set(key string, value string) {
|
||||
req.headers[key] = value
|
||||
}
|
||||
|
||||
func (req *requestHeader) Remove(key string) {
|
||||
delete(req.headers, key)
|
||||
}
|
||||
|
||||
func (req *requestHeader) Method() string {
|
||||
return req.method
|
||||
}
|
||||
|
||||
func (req *requestHeader) Path() string {
|
||||
return req.path
|
||||
}
|
||||
|
||||
func (req *requestHeader) Version() string {
|
||||
return req.version
|
||||
}
|
||||
|
||||
func (req *requestHeader) Finalize() []byte {
|
||||
return finalize(req.startLine, req.headers)
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package header
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func NewResponse(headerData []byte) (ResponseHeader, error) {
|
||||
header := &responseHeader{
|
||||
startLine: nil,
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
|
||||
lineEnd := bytes.Index(headerData, []byte("\r\n"))
|
||||
if lineEnd == -1 {
|
||||
return nil, fmt.Errorf("invalid response: no CRLF found in start line")
|
||||
}
|
||||
|
||||
header.startLine = headerData[:lineEnd]
|
||||
remaining := headerData[lineEnd+2:]
|
||||
setRemainingHeaders(remaining, header)
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func (resp *responseHeader) Value(key string) string {
|
||||
return resp.headers[key]
|
||||
}
|
||||
|
||||
func (resp *responseHeader) Set(key string, value string) {
|
||||
resp.headers[key] = value
|
||||
}
|
||||
|
||||
func (resp *responseHeader) Remove(key string) {
|
||||
delete(resp.headers, key)
|
||||
}
|
||||
|
||||
func (resp *responseHeader) Finalize() []byte {
|
||||
return finalize(resp.startLine, resp.headers)
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package stream
|
||||
|
||||
import "bytes"
|
||||
|
||||
func splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) {
|
||||
headerByte := data[:delimiterIdx+len(DELIMITER)]
|
||||
body := data[delimiterIdx+len(DELIMITER):]
|
||||
return headerByte, body
|
||||
}
|
||||
|
||||
func isHTTPHeader(buf []byte) bool {
|
||||
lines := bytes.Split(buf, []byte("\r\n"))
|
||||
|
||||
startLine := string(lines[0])
|
||||
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, line := range lines[1:] {
|
||||
if len(line) == 0 {
|
||||
break
|
||||
}
|
||||
colonIdx := bytes.IndexByte(line, ':')
|
||||
if colonIdx <= 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"tunnel_pls/internal/http/header"
|
||||
)
|
||||
|
||||
func (hs *http) Read(p []byte) (int, error) {
|
||||
tmp := make([]byte, len(p))
|
||||
read, err := hs.reader.Read(tmp)
|
||||
if read == 0 && err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
tmp = tmp[:read]
|
||||
|
||||
headerEndIdx := bytes.Index(tmp, DELIMITER)
|
||||
if headerEndIdx == -1 {
|
||||
return handleNoDelimiter(p, tmp, err)
|
||||
}
|
||||
|
||||
headerByte, bodyByte := splitHeaderAndBody(tmp, headerEndIdx)
|
||||
|
||||
if !isHTTPHeader(headerByte) {
|
||||
copy(p, tmp)
|
||||
return read, nil
|
||||
}
|
||||
|
||||
return hs.processHTTPRequest(p, headerByte, bodyByte)
|
||||
}
|
||||
|
||||
func (hs *http) processHTTPRequest(p, headerByte, bodyByte []byte) (int, error) {
|
||||
reqhf, err := header.NewRequest(headerByte)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err = hs.ApplyRequestMiddlewares(reqhf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
hs.reqHeader = reqhf
|
||||
combined := append(reqhf.Finalize(), bodyByte...)
|
||||
return copy(p, combined), nil
|
||||
}
|
||||
|
||||
func handleNoDelimiter(p, tmp []byte, err error) (int, error) {
|
||||
copy(p, tmp)
|
||||
return len(tmp), err
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"regexp"
|
||||
"tunnel_pls/internal/http/header"
|
||||
"tunnel_pls/internal/middleware"
|
||||
)
|
||||
|
||||
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
|
||||
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
|
||||
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
|
||||
|
||||
type HTTP interface {
|
||||
io.ReadWriteCloser
|
||||
CloseWrite() error
|
||||
RemoteAddr() net.Addr
|
||||
UseResponseMiddleware(mw middleware.ResponseMiddleware)
|
||||
UseRequestMiddleware(mw middleware.RequestMiddleware)
|
||||
SetRequestHeader(header header.RequestHeader)
|
||||
RequestMiddlewares() []middleware.RequestMiddleware
|
||||
ResponseMiddlewares() []middleware.ResponseMiddleware
|
||||
ApplyResponseMiddlewares(resphf header.ResponseHeader, body []byte) error
|
||||
ApplyRequestMiddlewares(reqhf header.RequestHeader) error
|
||||
}
|
||||
|
||||
type http struct {
|
||||
remoteAddr net.Addr
|
||||
writer io.Writer
|
||||
reader io.Reader
|
||||
buf []byte
|
||||
respHeader header.ResponseHeader
|
||||
reqHeader header.RequestHeader
|
||||
respMW []middleware.ResponseMiddleware
|
||||
reqMW []middleware.RequestMiddleware
|
||||
}
|
||||
|
||||
func New(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTP {
|
||||
return &http{
|
||||
remoteAddr: remoteAddr,
|
||||
writer: writer,
|
||||
reader: reader,
|
||||
buf: make([]byte, 0, 4096),
|
||||
}
|
||||
}
|
||||
|
||||
func (hs *http) RemoteAddr() net.Addr {
|
||||
return hs.remoteAddr
|
||||
}
|
||||
|
||||
func (hs *http) UseResponseMiddleware(mw middleware.ResponseMiddleware) {
|
||||
hs.respMW = append(hs.respMW, mw)
|
||||
}
|
||||
|
||||
func (hs *http) UseRequestMiddleware(mw middleware.RequestMiddleware) {
|
||||
hs.reqMW = append(hs.reqMW, mw)
|
||||
}
|
||||
|
||||
func (hs *http) SetRequestHeader(header header.RequestHeader) {
|
||||
hs.reqHeader = header
|
||||
}
|
||||
|
||||
func (hs *http) RequestMiddlewares() []middleware.RequestMiddleware {
|
||||
return hs.reqMW
|
||||
}
|
||||
|
||||
func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
|
||||
return hs.respMW
|
||||
}
|
||||
|
||||
func (hs *http) Close() error {
|
||||
if closer, ok := hs.writer.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *http) CloseWrite() error {
|
||||
if closer, ok := hs.writer.(interface{ CloseWrite() error }); ok {
|
||||
return closer.CloseWrite()
|
||||
}
|
||||
return hs.Close()
|
||||
}
|
||||
|
||||
func (hs *http) ApplyRequestMiddlewares(reqhf header.RequestHeader) error {
|
||||
for _, m := range hs.RequestMiddlewares() {
|
||||
if err := m.HandleRequest(reqhf); err != nil {
|
||||
log.Printf("Error when applying request middleware: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *http) ApplyResponseMiddlewares(resphf header.ResponseHeader, bodyByte []byte) error {
|
||||
for _, m := range hs.ResponseMiddlewares() {
|
||||
if err := m.HandleResponse(resphf, bodyByte); err != nil {
|
||||
log.Printf("Cannot apply middleware: %s\n", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,765 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tunnel_pls/internal/http/header"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockAddr struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockAddr) String() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
func (m *MockAddr) Network() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
type MockRequestMiddleware struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRequestMiddleware) HandleRequest(h header.RequestHeader) error {
|
||||
args := m.Called(h)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockResponseMiddleware struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error {
|
||||
args := m.Called(h, body)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockReadWriter struct {
|
||||
mock.Mock
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockReadWriter) Read(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockReadWriter) Write(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockReadWriter) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockReadWriter) CloseWrite() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockReadWriterOnlyCloser struct {
|
||||
mock.Mock
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockReadWriterOnlyCloser) Read(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockReadWriterOnlyCloser) Write(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockReadWriterOnlyCloser) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockWriterOnly struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockWriterOnly) Write(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockWriterOnly) Read(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
type MockReader struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockReader) Read(p []byte) (int, error) {
|
||||
args := m.Called(p)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
type MockWriter struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockWriter) Write(p []byte) (int, error) {
|
||||
ret := m.Called(p)
|
||||
|
||||
var n int
|
||||
var err error
|
||||
|
||||
switch v := ret.Get(0).(type) {
|
||||
case func([]byte) int:
|
||||
n = v(p)
|
||||
case int:
|
||||
n = v
|
||||
default:
|
||||
n = len(p)
|
||||
}
|
||||
|
||||
switch v := ret.Get(1).(type) {
|
||||
case func([]byte) error:
|
||||
err = v(p)
|
||||
case error:
|
||||
err = v
|
||||
default:
|
||||
err = nil
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (m *MockWriter) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func TestHTTPMethods(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
rw := new(MockReadWriter)
|
||||
hs := New(rw, rw, addr)
|
||||
|
||||
assert.Equal(t, addr, hs.RemoteAddr())
|
||||
|
||||
reqMW := new(MockRequestMiddleware)
|
||||
hs.UseRequestMiddleware(reqMW)
|
||||
assert.Equal(t, 1, len(hs.RequestMiddlewares()))
|
||||
assert.Equal(t, reqMW, hs.RequestMiddlewares()[0])
|
||||
|
||||
respMW := new(MockResponseMiddleware)
|
||||
hs.UseResponseMiddleware(respMW)
|
||||
assert.Equal(t, 1, len(hs.ResponseMiddlewares()))
|
||||
assert.Equal(t, respMW, hs.ResponseMiddlewares()[0])
|
||||
|
||||
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
hs.SetRequestHeader(reqH)
|
||||
}
|
||||
|
||||
func TestApplyMiddlewares(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(HTTP, *MockRequestMiddleware, *MockResponseMiddleware)
|
||||
apply func(HTTP, header.RequestHeader, header.ResponseHeader) error
|
||||
verify func(*testing.T, header.RequestHeader, header.ResponseHeader)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "apply request middleware success",
|
||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
|
||||
h := args.Get(0).(header.RequestHeader)
|
||||
h.Set("X-Middleware", "true")
|
||||
}).Return(nil)
|
||||
hs.UseRequestMiddleware(reqMW)
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyRequestMiddlewares(reqH)
|
||||
},
|
||||
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
|
||||
assert.Equal(t, "true", reqH.Value("X-Middleware"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "apply response middleware success",
|
||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
h := args.Get(0).(header.ResponseHeader)
|
||||
h.Set("X-Resp-Middleware", "true")
|
||||
}).Return(nil)
|
||||
hs.UseResponseMiddleware(respMW)
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
||||
},
|
||||
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
|
||||
assert.Equal(t, "true", respH.Value("X-Resp-Middleware"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "apply request middleware error",
|
||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||
reqMW.On("HandleRequest", mock.Anything).Return(assert.AnError)
|
||||
hs.UseRequestMiddleware(reqMW)
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyRequestMiddlewares(reqH)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "apply response middleware error",
|
||||
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(assert.AnError)
|
||||
hs.UseResponseMiddleware(respMW)
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
||||
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
rw := new(MockReadWriter)
|
||||
hs := New(rw, rw, addr)
|
||||
|
||||
reqMW := new(MockRequestMiddleware)
|
||||
respMW := new(MockResponseMiddleware)
|
||||
tt.setup(hs, reqMW, respMW)
|
||||
|
||||
err := tt.apply(hs, reqH, respH)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, reqH, respH)
|
||||
}
|
||||
}
|
||||
|
||||
reqMW.AssertExpectations(t)
|
||||
respMW.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseMethods(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() (io.Writer, io.Reader)
|
||||
op func(HTTP) error
|
||||
verify func(*testing.T, io.Writer)
|
||||
}{
|
||||
{
|
||||
name: "Close success",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
rw := new(MockReadWriter)
|
||||
rw.On("Close").Return(nil)
|
||||
return rw, rw
|
||||
},
|
||||
op: func(hs HTTP) error { return hs.Close() },
|
||||
verify: func(t *testing.T, w io.Writer) {
|
||||
w.(*MockReadWriter).AssertCalled(t, "Close")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CloseWrite with CloseWrite implementation",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
rw := new(MockReadWriter)
|
||||
rw.On("CloseWrite").Return(nil)
|
||||
return rw, rw
|
||||
},
|
||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||
verify: func(t *testing.T, w io.Writer) {
|
||||
w.(*MockReadWriter).AssertCalled(t, "CloseWrite")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CloseWrite fallback to Close",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
rw := new(MockReadWriterOnlyCloser)
|
||||
rw.On("Close").Return(nil)
|
||||
return rw, rw
|
||||
},
|
||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||
verify: func(t *testing.T, w io.Writer) {
|
||||
w.(*MockReadWriterOnlyCloser).AssertCalled(t, "Close")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Close with No Closer",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
w := new(MockWriterOnly)
|
||||
r := new(MockReader)
|
||||
return w, r
|
||||
},
|
||||
op: func(hs HTTP) error { return hs.Close() },
|
||||
},
|
||||
{
|
||||
name: "CloseWrite with No CloseWrite and No Closer",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
w := new(MockWriterOnly)
|
||||
r := new(MockReader)
|
||||
return w, r
|
||||
},
|
||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
w, r := tt.setup()
|
||||
hs := New(w, r, addr)
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
err := tt.op(hs)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, w)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitHeaderAndBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
delimiterIdx int
|
||||
expectHeader []byte
|
||||
expectBody []byte
|
||||
}{
|
||||
{
|
||||
name: "standard",
|
||||
data: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nBodyContent"),
|
||||
delimiterIdx: 31,
|
||||
expectHeader: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"),
|
||||
expectBody: []byte("BodyContent"),
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
delimiterIdx: 15,
|
||||
expectHeader: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
expectBody: []byte(""),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, b := splitHeaderAndBody(tt.data, tt.delimiterIdx)
|
||||
assert.Equal(t, tt.expectHeader, h)
|
||||
assert.Equal(t, tt.expectBody, b)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsHTTPHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
buf []byte
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
buf: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "valid response",
|
||||
buf: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "invalid start line",
|
||||
buf: []byte("NOT_HTTP /path\r\nHost: example.com\r\n\r\n"),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header line (no colon)",
|
||||
buf: []byte("GET / HTTP/1.1\r\nInvalidHeaderLine\r\n\r\n"),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header line (colon at 0)",
|
||||
buf: []byte("GET / HTTP/1.1\r\n: value\r\n\r\n"),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "empty header section",
|
||||
buf: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "multiple headers",
|
||||
buf: []byte("GET / HTTP/1.1\r\nH1: V1\r\nH2: V2\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isHTTPHeader(tt.buf)
|
||||
assert.Equal(t, tt.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
readLen int
|
||||
expectContent string
|
||||
expectRead int
|
||||
expectErr bool
|
||||
middlewareErr error
|
||||
isHTTP bool
|
||||
}{
|
||||
{
|
||||
name: "valid http request",
|
||||
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\nBody"),
|
||||
readLen: 100,
|
||||
expectContent: "Body",
|
||||
expectRead: 54,
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
name: "non-http data",
|
||||
input: []byte("Some random data\r\n\r\nMore data"),
|
||||
readLen: 100,
|
||||
expectContent: "Some random data\r\n\r\nMore data",
|
||||
expectRead: 29,
|
||||
isHTTP: false,
|
||||
},
|
||||
{
|
||||
name: "no delimiter",
|
||||
input: []byte("Partial data without delimiter"),
|
||||
readLen: 100,
|
||||
expectContent: "Partial data without delimiter",
|
||||
expectRead: 30,
|
||||
isHTTP: false,
|
||||
},
|
||||
{
|
||||
name: "middleware error",
|
||||
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\n"),
|
||||
readLen: 100,
|
||||
middlewareErr: assert.AnError,
|
||||
expectErr: true,
|
||||
isHTTP: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
reader := new(MockReader)
|
||||
writer := new(MockWriterOnly)
|
||||
|
||||
if tt.expectErr || tt.name == "valid http request" {
|
||||
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
||||
p := args.Get(0).([]byte)
|
||||
copy(p, tt.input)
|
||||
}).Return(len(tt.input), io.EOF).Once()
|
||||
} else {
|
||||
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
||||
p := args.Get(0).([]byte)
|
||||
copy(p, tt.input)
|
||||
}).Return(len(tt.input), nil).Once()
|
||||
}
|
||||
|
||||
hs := New(writer, reader, addr)
|
||||
|
||||
reqMW := new(MockRequestMiddleware)
|
||||
if tt.isHTTP {
|
||||
if tt.middlewareErr != nil {
|
||||
reqMW.On("HandleRequest", mock.Anything).Return(tt.middlewareErr)
|
||||
} else {
|
||||
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
|
||||
h := args.Get(0).(header.RequestHeader)
|
||||
h.Set("X-Middleware", "true")
|
||||
}).Return(nil)
|
||||
}
|
||||
}
|
||||
hs.UseRequestMiddleware(reqMW)
|
||||
|
||||
p := make([]byte, tt.readLen)
|
||||
n, err := hs.Read(p)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectRead, n)
|
||||
if tt.name == "valid http request" {
|
||||
content := string(p[:n])
|
||||
assert.Contains(t, content, "GET / HTTP/1.1\r\n")
|
||||
assert.Contains(t, content, "Host: test\r\n")
|
||||
assert.Contains(t, content, "X-Middleware: true\r\n")
|
||||
assert.True(t, bytes.HasSuffix(p[:n], []byte("\r\n\r\nBody")))
|
||||
} else {
|
||||
assert.Equal(t, tt.expectContent, string(p[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
if tt.isHTTP {
|
||||
reqMW.AssertExpectations(t)
|
||||
}
|
||||
reader.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
writes [][]byte
|
||||
expectWritten string
|
||||
expectErr bool
|
||||
middlewareErr error
|
||||
isHTTP bool
|
||||
}{
|
||||
{
|
||||
name: "valid http response in one write",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"),
|
||||
},
|
||||
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
name: "valid http response in multiple writes",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\n"),
|
||||
[]byte("Content-Length: 4\r\n\r\n"),
|
||||
[]byte("Body"),
|
||||
},
|
||||
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
name: "non-http data",
|
||||
writes: [][]byte{
|
||||
[]byte("Random data with delimiter\r\n\r\nFlush"),
|
||||
},
|
||||
expectWritten: "Random data with delimiter\r\n\r\nFlush",
|
||||
isHTTP: false,
|
||||
},
|
||||
{
|
||||
name: "bypass buffering",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
},
|
||||
expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" +
|
||||
"HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n",
|
||||
isHTTP: true,
|
||||
},
|
||||
{
|
||||
name: "middleware error",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
},
|
||||
middlewareErr: assert.AnError,
|
||||
expectErr: true,
|
||||
isHTTP: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
var writtenData bytes.Buffer
|
||||
writer := new(MockWriter)
|
||||
|
||||
writer.On("Write", mock.Anything).Run(func(args mock.Arguments) {
|
||||
p := args.Get(0).([]byte)
|
||||
writtenData.Write(p)
|
||||
}).Return(func(p []byte) int {
|
||||
return len(p)
|
||||
}, nil)
|
||||
|
||||
reader := new(MockReader)
|
||||
hs := New(writer, reader, addr)
|
||||
|
||||
respMW := new(MockResponseMiddleware)
|
||||
if tt.isHTTP {
|
||||
if tt.middlewareErr != nil {
|
||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(tt.middlewareErr)
|
||||
} else {
|
||||
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
h := args.Get(0).(header.ResponseHeader)
|
||||
h.Set("X-Resp-Middleware", "true")
|
||||
}).Return(nil)
|
||||
}
|
||||
}
|
||||
hs.UseResponseMiddleware(respMW)
|
||||
|
||||
var totalN int
|
||||
var err error
|
||||
for _, w := range tt.writes {
|
||||
var n int
|
||||
n, err = hs.Write(w)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
totalN += n
|
||||
}
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
written := writtenData.String()
|
||||
if strings.HasPrefix(tt.expectWritten, "HTTP/") {
|
||||
assert.Contains(t, written, "HTTP/1.1 200 OK\r\n")
|
||||
assert.Contains(t, written, "X-Resp-Middleware: true\r\n")
|
||||
if strings.Contains(tt.expectWritten, "Content-Length: 4") {
|
||||
assert.Contains(t, written, "Content-Length: 4\r\n")
|
||||
}
|
||||
assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n"))
|
||||
} else {
|
||||
assert.Equal(t, tt.expectWritten, written)
|
||||
}
|
||||
}
|
||||
|
||||
if tt.isHTTP {
|
||||
respMW.AssertExpectations(t)
|
||||
}
|
||||
if tt.middlewareErr == nil {
|
||||
writer.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteErrors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() (io.Writer, io.Reader)
|
||||
data []byte
|
||||
}{
|
||||
{
|
||||
name: "write error in writeHeaderAndBody",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
writer := new(MockWriter)
|
||||
writer.On("Write", mock.Anything).Return(0, assert.AnError)
|
||||
reader := new(MockReader)
|
||||
return writer, reader
|
||||
},
|
||||
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
},
|
||||
{
|
||||
name: "write error in writeHeaderAndBody second write",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
writer := new(MockWriter)
|
||||
writer.On("Write", mock.Anything).Return(len([]byte("HTTP/1.1 200 OK\r\n\r\n")), nil).Once()
|
||||
writer.On("Write", mock.Anything).Return(0, assert.AnError).Once()
|
||||
reader := new(MockReader)
|
||||
return writer, reader
|
||||
},
|
||||
data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"),
|
||||
},
|
||||
{
|
||||
name: "write error in writeRawBuffer",
|
||||
setup: func() (io.Writer, io.Reader) {
|
||||
writer := new(MockWriter)
|
||||
writer.On("Write", mock.Anything).Return(0, assert.AnError)
|
||||
reader := new(MockReader)
|
||||
return writer, reader
|
||||
},
|
||||
data: []byte("Not HTTP\r\n\r\nFlush"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
w, r := tt.setup()
|
||||
hs := New(w, r, addr)
|
||||
|
||||
_, err := hs.Write(tt.data)
|
||||
assert.Error(t, err)
|
||||
|
||||
w.(*MockWriter).AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadEOF(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() io.Reader
|
||||
expectN int
|
||||
expectErr error
|
||||
expectContent string
|
||||
}{
|
||||
{
|
||||
name: "read eof",
|
||||
setup: func() io.Reader {
|
||||
reader := new(MockReader)
|
||||
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
||||
p := args.Get(0).([]byte)
|
||||
copy(p, "data")
|
||||
}).Return(4, io.EOF)
|
||||
return reader
|
||||
},
|
||||
expectN: 4,
|
||||
expectErr: io.EOF,
|
||||
expectContent: "data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := new(MockAddr)
|
||||
addr.On("String").Return("1.2.3.4:1234")
|
||||
|
||||
reader := tt.setup()
|
||||
hs := New(nil, reader, addr)
|
||||
|
||||
p := make([]byte, 100)
|
||||
n, err := hs.Read(p)
|
||||
|
||||
assert.Equal(t, tt.expectN, n)
|
||||
assert.Equal(t, tt.expectErr, err)
|
||||
assert.Equal(t, tt.expectContent, string(p[:n]))
|
||||
|
||||
reader.(*MockReader).AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"tunnel_pls/internal/http/header"
|
||||
)
|
||||
|
||||
func (hs *http) Write(p []byte) (int, error) {
|
||||
if hs.shouldBypassBuffering(p) {
|
||||
hs.respHeader = nil
|
||||
}
|
||||
|
||||
if hs.respHeader != nil {
|
||||
return hs.writer.Write(p)
|
||||
}
|
||||
|
||||
hs.buf = append(hs.buf, p...)
|
||||
|
||||
headerEndIdx := bytes.Index(hs.buf, DELIMITER)
|
||||
if headerEndIdx == -1 {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
return hs.processBufferedResponse(p, headerEndIdx)
|
||||
}
|
||||
|
||||
func (hs *http) shouldBypassBuffering(p []byte) bool {
|
||||
return hs.respHeader != nil && len(hs.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/"
|
||||
}
|
||||
|
||||
func (hs *http) processBufferedResponse(p []byte, delimiterIdx int) (int, error) {
|
||||
headerByte, bodyByte := splitHeaderAndBody(hs.buf, delimiterIdx)
|
||||
|
||||
if !isHTTPHeader(headerByte) {
|
||||
return hs.writeRawBuffer()
|
||||
}
|
||||
|
||||
if err := hs.processHTTPResponse(headerByte, bodyByte); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
hs.buf = nil
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (hs *http) writeRawBuffer() (int, error) {
|
||||
_, err := hs.writer.Write(hs.buf)
|
||||
length := len(hs.buf)
|
||||
hs.buf = nil
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return length, nil
|
||||
}
|
||||
|
||||
func (hs *http) processHTTPResponse(headerByte, bodyByte []byte) error {
|
||||
resphf, err := header.NewResponse(headerByte)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = hs.ApplyResponseMiddlewares(resphf, bodyByte); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hs.respHeader = resphf
|
||||
finalHeader := resphf.Finalize()
|
||||
|
||||
if err = hs.writeHeaderAndBody(finalHeader, bodyByte); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *http) writeHeaderAndBody(header, bodyByte []byte) error {
|
||||
if _, err := hs.writer.Write(header); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(bodyByte) > 0 {
|
||||
if _, err := hs.writer.Write(bodyByte); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+28
-9
@@ -5,6 +5,8 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -12,7 +14,20 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
rsaGenerateKey = rsa.GenerateKey
|
||||
pemEncode = pem.Encode
|
||||
sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
|
||||
return ssh.NewPublicKey(key)
|
||||
}
|
||||
pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
|
||||
return w.Write(data)
|
||||
}
|
||||
osOpenFile = os.OpenFile
|
||||
)
|
||||
|
||||
func GenerateSSHKeyIfNotExist(keyPath string) error {
|
||||
var errGroup = make([]error, 0)
|
||||
if _, err := os.Stat(keyPath); err == nil {
|
||||
log.Printf("SSH key already exists at %s", keyPath)
|
||||
return nil
|
||||
@@ -20,7 +35,7 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
|
||||
|
||||
log.Printf("SSH key not found at %s, generating new key pair...", keyPath)
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
privateKey, err := rsaGenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -35,33 +50,37 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
privateKeyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
privateKeyFile, err := osOpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer privateKeyFile.Close()
|
||||
defer func(privateKeyFile *os.File) {
|
||||
errGroup = append(errGroup, privateKeyFile.Close())
|
||||
}(privateKeyFile)
|
||||
|
||||
if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil {
|
||||
if err := pemEncode(privateKeyFile, privateKeyPEM); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
|
||||
publicKey, err := sshNewPublicKey(&privateKey.PublicKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pubKeyPath := keyPath + ".pub"
|
||||
pubKeyFile, err := os.OpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
pubKeyFile, err := osOpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer pubKeyFile.Close()
|
||||
defer func(pubKeyFile *os.File) {
|
||||
errGroup = append(errGroup, pubKeyFile.Close())
|
||||
}(pubKeyFile)
|
||||
|
||||
_, err = pubKeyFile.Write(ssh.MarshalAuthorizedKey(publicKey))
|
||||
_, err = pubKeyWrite(pubKeyFile, ssh.MarshalAuthorizedKey(publicKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("SSH key pair generated successfully at %s and %s", keyPath, pubKeyPath)
|
||||
return nil
|
||||
return errors.Join(errGroup...)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestGenerateSSHKeyIfNotExist(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(t *testing.T, tempDir string) string
|
||||
mockSetup func() func()
|
||||
wantErr bool
|
||||
errStr string
|
||||
verify func(t *testing.T, keyPath string)
|
||||
}{
|
||||
{
|
||||
name: "GenerateNewKey",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "id_rsa")
|
||||
},
|
||||
verify: func(t *testing.T, keyPath string) {
|
||||
pubKeyPath := keyPath + ".pub"
|
||||
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
||||
t.Errorf("Private key file not created")
|
||||
}
|
||||
if _, err := os.Stat(pubKeyPath); os.IsNotExist(err) {
|
||||
t.Errorf("Public key file not created")
|
||||
}
|
||||
privateKeyBytes, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read private key: %v", err)
|
||||
}
|
||||
if _, err = ssh.ParseRawPrivateKey(privateKeyBytes); err != nil {
|
||||
t.Errorf("Failed to parse private key: %v", err)
|
||||
}
|
||||
publicKeyBytes, err := os.ReadFile(pubKeyPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read public key: %v", err)
|
||||
}
|
||||
if _, _, _, _, err = ssh.ParseAuthorizedKey(publicKeyBytes); err != nil {
|
||||
t.Errorf("Failed to parse public key: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "DoNotOverwriteExistingKey",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
keyPath := filepath.Join(tempDir, "existing_id_rsa")
|
||||
dummyPrivate := "dummy private"
|
||||
dummyPublic := "dummy public"
|
||||
if err := os.WriteFile(keyPath, []byte(dummyPrivate), 0600); err != nil {
|
||||
t.Fatalf("Failed to create dummy private key: %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath+".pub", []byte(dummyPublic), 0644); err != nil {
|
||||
t.Fatalf("Failed to create dummy public key: %v", err)
|
||||
}
|
||||
return keyPath
|
||||
},
|
||||
verify: func(t *testing.T, keyPath string) {
|
||||
gotPrivate, _ := os.ReadFile(keyPath)
|
||||
if string(gotPrivate) != "dummy private" {
|
||||
t.Errorf("Private key was overwritten")
|
||||
}
|
||||
gotPublic, _ := os.ReadFile(keyPath + ".pub")
|
||||
if string(gotPublic) != "dummy public" {
|
||||
t.Errorf("Public key was overwritten")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CreateNestedDirectories",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "nested", "dir", "id_rsa")
|
||||
},
|
||||
verify: func(t *testing.T, keyPath string) {
|
||||
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
|
||||
t.Errorf("Private key file not created in nested directory")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "FailureMkdirAll",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
dirPath := filepath.Join(tempDir, "file_as_dir")
|
||||
if err := os.WriteFile(dirPath, []byte("not a dir"), 0644); err != nil {
|
||||
t.Fatalf("Failed to create file: %v", err)
|
||||
}
|
||||
return filepath.Join(dirPath, "id_rsa")
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "PrivateExistsPublicMissing",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
keyPath := filepath.Join(tempDir, "partial_id_rsa")
|
||||
if err := os.WriteFile(keyPath, []byte("private"), 0600); err != nil {
|
||||
t.Fatalf("Failed to create private key: %v", err)
|
||||
}
|
||||
return keyPath
|
||||
},
|
||||
verify: func(t *testing.T, keyPath string) {
|
||||
if _, err := os.Stat(keyPath + ".pub"); !os.IsNotExist(err) {
|
||||
t.Errorf("Public key should NOT have been created if private key existed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "FailureRSAGenerateKey",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_rsa")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := rsaGenerateKey
|
||||
rsaGenerateKey = func(random io.Reader, bits int) (*rsa.PrivateKey, error) {
|
||||
return nil, errors.New("rsa error")
|
||||
}
|
||||
return func() { rsaGenerateKey = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "rsa error",
|
||||
},
|
||||
{
|
||||
name: "FailureOpenFilePrivate",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_open_private")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := osOpenFile
|
||||
osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
|
||||
return nil, errors.New("open error")
|
||||
}
|
||||
return func() { osOpenFile = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "open error",
|
||||
},
|
||||
{
|
||||
name: "FailurePemEncode",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_pem")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := pemEncode
|
||||
pemEncode = func(out io.Writer, b *pem.Block) error {
|
||||
return errors.New("pem error")
|
||||
}
|
||||
return func() { pemEncode = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "pem error",
|
||||
},
|
||||
{
|
||||
name: "FailureSSHNewPublicKey",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_ssh")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := sshNewPublicKey
|
||||
sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
|
||||
return nil, errors.New("ssh error")
|
||||
}
|
||||
return func() { sshNewPublicKey = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "ssh error",
|
||||
},
|
||||
{
|
||||
name: "FailureOpenFilePublic",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_open_public")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := osOpenFile
|
||||
osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
|
||||
if filepath.Ext(name) == ".pub" {
|
||||
return nil, errors.New("open pub error")
|
||||
}
|
||||
return os.OpenFile(name, flag, perm)
|
||||
}
|
||||
return func() { osOpenFile = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "open pub error",
|
||||
},
|
||||
{
|
||||
name: "FailurePubKeyWrite",
|
||||
setup: func(t *testing.T, tempDir string) string {
|
||||
return filepath.Join(tempDir, "fail_write")
|
||||
},
|
||||
mockSetup: func() func() {
|
||||
old := pubKeyWrite
|
||||
pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
|
||||
return 0, errors.New("write error")
|
||||
}
|
||||
return func() { pubKeyWrite = old }
|
||||
},
|
||||
wantErr: true,
|
||||
errStr: "write error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
keyPath := tt.setup(t, tempDir)
|
||||
if tt.mockSetup != nil {
|
||||
cleanup := tt.mockSetup()
|
||||
defer cleanup()
|
||||
}
|
||||
|
||||
err := GenerateSSHKeyIfNotExist(keyPath)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.wantErr && tt.errStr != "" && err != nil && err.Error() != tt.errStr {
|
||||
t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErrStr %v", err, tt.errStr)
|
||||
}
|
||||
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, keyPath)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"tunnel_pls/internal/http/header"
|
||||
)
|
||||
|
||||
type ForwardedFor struct {
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func NewForwardedFor(addr net.Addr) *ForwardedFor {
|
||||
return &ForwardedFor{addr: addr}
|
||||
}
|
||||
|
||||
func (ff *ForwardedFor) HandleRequest(header header.RequestHeader) error {
|
||||
host, _, err := net.SplitHostPort(ff.addr.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Set("X-Forwarded-For", host)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type mockRequestHeader struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Value(key string) string {
|
||||
return m.Called(key).String(0)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Set(key string, value string) {
|
||||
m.Called(key, value)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Remove(key string) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Finalize() []byte {
|
||||
return m.Called().Get(0).([]byte)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Method() string {
|
||||
return m.Called().String(0)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Path() string {
|
||||
return m.Called().String(0)
|
||||
}
|
||||
|
||||
func (m *mockRequestHeader) Version() string {
|
||||
return m.Called().String(0)
|
||||
}
|
||||
|
||||
func TestForwardedFor_HandleRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
expectedHost string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid IPv4 address",
|
||||
addr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 8080},
|
||||
expectedHost: "192.168.1.100",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid IPv6 address",
|
||||
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 8080},
|
||||
expectedHost: "2001:db8::ff00:42:8329",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid address format",
|
||||
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
||||
expectedHost: "",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "valid IPv4 address with port",
|
||||
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
|
||||
expectedHost: "127.0.0.1",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ff := NewForwardedFor(tc.addr)
|
||||
reqHeader := new(mockRequestHeader)
|
||||
|
||||
if !tc.expectError {
|
||||
reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return()
|
||||
}
|
||||
|
||||
err := ff.HandleRequest(reqHeader)
|
||||
|
||||
if tc.expectError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
reqHeader.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForwardedFor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
expectAddr net.Addr
|
||||
}{
|
||||
{
|
||||
name: "IPv4 address",
|
||||
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
|
||||
expectAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
|
||||
expectAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
|
||||
},
|
||||
{
|
||||
name: "Unix address",
|
||||
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
||||
expectAddr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ff := NewForwardedFor(tc.addr)
|
||||
assert.Equal(t, tc.expectAddr.String(), ff.addr.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"tunnel_pls/internal/http/header"
|
||||
)
|
||||
|
||||
type RequestMiddleware interface {
|
||||
HandleRequest(header header.RequestHeader) error
|
||||
}
|
||||
|
||||
type ResponseMiddleware interface {
|
||||
HandleResponse(header header.ResponseHeader, body []byte) error
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"tunnel_pls/internal/http/header"
|
||||
)
|
||||
|
||||
type TunnelFingerprint struct{}
|
||||
|
||||
func NewTunnelFingerprint() *TunnelFingerprint {
|
||||
return &TunnelFingerprint{}
|
||||
}
|
||||
|
||||
func (h *TunnelFingerprint) HandleResponse(header header.ResponseHeader, body []byte) error {
|
||||
header.Set("Server", "Tunnel Please")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type mockResponseHeader struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockResponseHeader) Value(key string) string {
|
||||
return m.Called(key).String(0)
|
||||
}
|
||||
|
||||
func (m *mockResponseHeader) Set(key string, value string) {
|
||||
m.Called(key, value)
|
||||
}
|
||||
|
||||
func (m *mockResponseHeader) Remove(key string) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *mockResponseHeader) Finalize() []byte {
|
||||
return m.Called().Get(0).([]byte)
|
||||
}
|
||||
|
||||
func TestTunnelFingerprintHandleResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expected map[string]string
|
||||
body []byte
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "Sets Server Header",
|
||||
expected: map[string]string{"Server": "Tunnel Please"},
|
||||
body: []byte("Sample body"),
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Overwrites Server Header",
|
||||
expected: map[string]string{"Server": "Tunnel Please"},
|
||||
body: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockHeader := new(mockResponseHeader)
|
||||
for k, v := range tt.expected {
|
||||
mockHeader.On("Set", k, v).Return()
|
||||
}
|
||||
|
||||
tunnelFingerprint := NewTunnelFingerprint()
|
||||
|
||||
err := tunnelFingerprint.HandleResponse(mockHeader, tt.body)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
mockHeader.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTunnelFingerprint(t *testing.T) {
|
||||
instance := NewTunnelFingerprint()
|
||||
assert.NotNil(t, instance)
|
||||
}
|
||||
+36
-49
@@ -3,63 +3,40 @@ package port
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"tunnel_pls/internal/config"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
AddPortRange(startPort, endPort uint16) error
|
||||
GetUnassignedPort() (uint16, bool)
|
||||
SetPortStatus(port uint16, assigned bool) error
|
||||
GetPortStatus(port uint16) (bool, bool)
|
||||
type Port interface {
|
||||
AddRange(startPort, endPort uint16) error
|
||||
Unassigned() (uint16, bool)
|
||||
SetStatus(port uint16, assigned bool) error
|
||||
Claim(port uint16) (claimed bool)
|
||||
}
|
||||
|
||||
type manager struct {
|
||||
type port struct {
|
||||
mu sync.RWMutex
|
||||
ports map[uint16]bool
|
||||
sortedPorts []uint16
|
||||
}
|
||||
|
||||
var Default Manager = &manager{
|
||||
ports: make(map[uint16]bool),
|
||||
sortedPorts: []uint16{},
|
||||
func New() Port {
|
||||
return &port{
|
||||
ports: make(map[uint16]bool),
|
||||
sortedPorts: []uint16{},
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
rawRange := config.Getenv("ALLOWED_PORTS", "")
|
||||
if rawRange == "" {
|
||||
return
|
||||
}
|
||||
|
||||
splitRange := strings.Split(rawRange, "-")
|
||||
if len(splitRange) != 2 {
|
||||
return
|
||||
}
|
||||
|
||||
start, err := strconv.ParseUint(splitRange[0], 10, 16)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
end, err := strconv.ParseUint(splitRange[1], 10, 16)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = Default.AddPortRange(uint16(start), uint16(end))
|
||||
}
|
||||
|
||||
func (pm *manager) AddPortRange(startPort, endPort uint16) error {
|
||||
func (pm *port) AddRange(startPort, endPort uint16) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if startPort > endPort {
|
||||
return fmt.Errorf("start port cannot be greater than end port")
|
||||
}
|
||||
for port := startPort; port <= endPort; port++ {
|
||||
if _, exists := pm.ports[port]; !exists {
|
||||
pm.ports[port] = false
|
||||
pm.sortedPorts = append(pm.sortedPorts, port)
|
||||
for index := startPort; index <= endPort; index++ {
|
||||
if _, exists := pm.ports[index]; !exists {
|
||||
pm.ports[index] = false
|
||||
pm.sortedPorts = append(pm.sortedPorts, index)
|
||||
}
|
||||
}
|
||||
sort.Slice(pm.sortedPorts, func(i, j int) bool {
|
||||
@@ -68,20 +45,19 @@ func (pm *manager) AddPortRange(startPort, endPort uint16) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *manager) GetUnassignedPort() (uint16, bool) {
|
||||
func (pm *port) Unassigned() (uint16, bool) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
for _, port := range pm.sortedPorts {
|
||||
if !pm.ports[port] {
|
||||
pm.ports[port] = true
|
||||
return port, true
|
||||
for _, index := range pm.sortedPorts {
|
||||
if !pm.ports[index] {
|
||||
return index, true
|
||||
}
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
|
||||
func (pm *port) SetStatus(port uint16, assigned bool) error {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
@@ -89,10 +65,21 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *manager) GetPortStatus(port uint16) (bool, bool) {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
func (pm *port) Claim(port uint16) (claimed bool) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
status, exists := pm.ports[port]
|
||||
return status, exists
|
||||
|
||||
if exists && status {
|
||||
return false
|
||||
}
|
||||
|
||||
if !exists {
|
||||
pm.ports[port] = true
|
||||
return true
|
||||
}
|
||||
|
||||
pm.ports[port] = true
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
package port
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAddRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
startPort uint16
|
||||
endPort uint16
|
||||
wantErr bool
|
||||
}{
|
||||
{"normal range", 1000, 1002, false},
|
||||
{"invalid range", 2000, 1999, true},
|
||||
{"single port range", 3000, 3000, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pm := New()
|
||||
err := pm.AddRange(tt.startPort, tt.endPort)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnassigned(t *testing.T) {
|
||||
pm := New()
|
||||
_ = pm.AddRange(1000, 1002)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
status map[uint16]bool
|
||||
want uint16
|
||||
wantOk bool
|
||||
}{
|
||||
{"all unassigned", map[uint16]bool{1000: false, 1001: false, 1002: false}, 1000, true},
|
||||
{"some assigned", map[uint16]bool{1000: true, 1001: false, 1002: true}, 1001, true},
|
||||
{"all assigned", map[uint16]bool{1000: true, 1001: true, 1002: true}, 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for k, v := range tt.status {
|
||||
_ = pm.SetStatus(k, v)
|
||||
}
|
||||
got, gotOk := pm.Unassigned()
|
||||
assert.Equal(t, tt.want, got)
|
||||
assert.Equal(t, tt.wantOk, gotOk)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetStatus(t *testing.T) {
|
||||
pm := New()
|
||||
_ = pm.AddRange(1000, 1002)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port uint16
|
||||
assigned bool
|
||||
}{
|
||||
{"assign port 1000", 1000, true},
|
||||
{"unassign port 1001", 1001, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := pm.SetStatus(tt.port, tt.assigned)
|
||||
assert.NoError(t, err)
|
||||
|
||||
status, ok := pm.(*port).ports[tt.port]
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tt.assigned, status)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaim(t *testing.T) {
|
||||
pm := New()
|
||||
_ = pm.AddRange(1000, 1002)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port uint16
|
||||
status bool
|
||||
want bool
|
||||
}{
|
||||
{"claim unassigned port", 1000, false, true},
|
||||
{"claim already assigned port", 1001, true, false},
|
||||
{"claim non-existent port", 5000, false, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if _, exists := pm.(*port).ports[tt.port]; exists {
|
||||
_ = pm.SetStatus(tt.port, tt.status)
|
||||
}
|
||||
|
||||
got := pm.Claim(tt.port)
|
||||
assert.Equal(t, tt.want, got)
|
||||
|
||||
finalState := pm.(*port).ports[tt.port]
|
||||
assert.True(t, finalState)
|
||||
})
|
||||
}
|
||||
}
|
||||
+35
-12
@@ -1,18 +1,41 @@
|
||||
package random
|
||||
|
||||
import (
|
||||
mathrand "math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
func GenerateRandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz"
|
||||
seededRand := mathrand.New(mathrand.NewSource(time.Now().UnixNano() + int64(mathrand.Intn(9999))))
|
||||
var result strings.Builder
|
||||
for i := 0; i < length; i++ {
|
||||
randomIndex := seededRand.Intn(len(charset))
|
||||
result.WriteString(string(charset[randomIndex]))
|
||||
}
|
||||
return result.String()
|
||||
var (
|
||||
ErrInvalidLength = fmt.Errorf("invalid length")
|
||||
)
|
||||
|
||||
type Random interface {
|
||||
String(length int) (string, error)
|
||||
}
|
||||
|
||||
type random struct {
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func New() Random {
|
||||
return &random{reader: rand.Reader}
|
||||
}
|
||||
|
||||
func (ran *random) String(length int) (string, error) {
|
||||
if length < 0 {
|
||||
return "", ErrInvalidLength
|
||||
}
|
||||
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
b := make([]byte, length)
|
||||
|
||||
if _, err := ran.reader.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for i := range b {
|
||||
b[i] = charset[int(b[i])%len(charset)]
|
||||
}
|
||||
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
package random
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRandom_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
length int
|
||||
wantErr bool
|
||||
}{
|
||||
{"ValidLengthZero", 0, false},
|
||||
{"ValidPositiveLength", 10, false},
|
||||
{"NegativeLength", -1, true},
|
||||
{"VeryLargeLength", 1_000_000, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
randomizer := New()
|
||||
|
||||
result, err := randomizer.String(tt.length)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, result, tt.length)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandomWithFailingReader_String(t *testing.T) {
|
||||
errBrainrot := assert.AnError
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
reader io.Reader
|
||||
expectErr error
|
||||
}{
|
||||
{
|
||||
name: "failing reader",
|
||||
reader: func() io.Reader {
|
||||
return &failingReader{err: errBrainrot}
|
||||
}(),
|
||||
expectErr: errBrainrot,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
randomizer := &random{reader: tt.reader}
|
||||
result, err := randomizer.String(20)
|
||||
assert.ErrorIs(t, err, tt.expectErr)
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type failingReader struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *failingReader) Read(p []byte) (int, error) {
|
||||
return 0, f.err
|
||||
}
|
||||
@@ -0,0 +1,328 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
)
|
||||
|
||||
type Key = types.SessionKey
|
||||
|
||||
type Session interface {
|
||||
Lifecycle() lifecycle.Lifecycle
|
||||
Interaction() interaction.Interaction
|
||||
Forwarder() forwarder.Forwarder
|
||||
Slug() slug.Slug
|
||||
Detail() *types.Detail
|
||||
}
|
||||
|
||||
type Registry interface {
|
||||
Get(key Key) (session Session, err error)
|
||||
GetWithUser(user string, key Key) (session Session, err error)
|
||||
Update(user string, oldKey, newKey Key) error
|
||||
Register(key Key, session Session) (success bool)
|
||||
Remove(key Key)
|
||||
GetAllSessionFromUser(user string) []Session
|
||||
}
|
||||
type registry struct {
|
||||
mu sync.RWMutex
|
||||
byUser map[string]map[Key]Session
|
||||
slugIndex map[Key]string
|
||||
}
|
||||
|
||||
var (
|
||||
ErrSessionNotFound = fmt.Errorf("session not found")
|
||||
ErrSlugInUse = fmt.Errorf("slug already in use")
|
||||
ErrInvalidSlug = fmt.Errorf("invalid slug")
|
||||
ErrForbiddenSlug = fmt.Errorf("forbidden slug")
|
||||
ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type")
|
||||
ErrSlugUnchanged = fmt.Errorf("slug is unchanged")
|
||||
)
|
||||
|
||||
func NewRegistry() Registry {
|
||||
return ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registry) Get(key Key) (session Session, err error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
userID, ok := r.slugIndex[key]
|
||||
if !ok {
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
|
||||
client, ok := r.byUser[userID][key]
|
||||
if !ok {
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (r *registry) GetWithUser(user string, key Key) (session Session, err error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
client, ok := r.byUser[user][key]
|
||||
if !ok {
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (r *registry) Update(user string, oldKey, newKey Key) error {
|
||||
if oldKey.Type != newKey.Type {
|
||||
return ErrSlugUnchanged
|
||||
}
|
||||
|
||||
if newKey.Type != types.TunnelTypeHTTP {
|
||||
return ErrSlugChangeNotAllowed
|
||||
}
|
||||
|
||||
if isForbiddenSlug(newKey.Id) {
|
||||
return ErrForbiddenSlug
|
||||
}
|
||||
|
||||
if !isValidSlug(newKey.Id) {
|
||||
return ErrInvalidSlug
|
||||
}
|
||||
|
||||
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
|
||||
return ErrSlugInUse
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
client, ok := r.byUser[user][oldKey]
|
||||
if !ok {
|
||||
return ErrSessionNotFound
|
||||
}
|
||||
|
||||
delete(r.byUser[user], oldKey)
|
||||
delete(r.slugIndex, oldKey)
|
||||
|
||||
client.Slug().Set(newKey.Id)
|
||||
r.slugIndex[newKey] = user
|
||||
|
||||
r.byUser[user][newKey] = client
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registry) Register(key Key, userSession Session) (success bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.slugIndex[key]; exists {
|
||||
return false
|
||||
}
|
||||
|
||||
userID := userSession.Lifecycle().User()
|
||||
if r.byUser[userID] == nil {
|
||||
r.byUser[userID] = make(map[Key]Session)
|
||||
}
|
||||
|
||||
r.byUser[userID][key] = userSession
|
||||
r.slugIndex[key] = userID
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *registry) GetAllSessionFromUser(user string) []Session {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
m := r.byUser[user]
|
||||
if len(m) == 0 {
|
||||
return []Session{}
|
||||
}
|
||||
|
||||
sessions := make([]Session, 0, len(m))
|
||||
for _, s := range m {
|
||||
sessions = append(sessions, s)
|
||||
}
|
||||
return sessions
|
||||
}
|
||||
|
||||
func (r *registry) Remove(key Key) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
userID, ok := r.slugIndex[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(r.byUser[userID], key)
|
||||
if len(r.byUser[userID]) == 0 {
|
||||
delete(r.byUser, userID)
|
||||
}
|
||||
delete(r.slugIndex, key)
|
||||
}
|
||||
|
||||
func isValidSlug(slug string) bool {
|
||||
if len(slug) < minSlugLength || len(slug) > maxSlugLength {
|
||||
return false
|
||||
}
|
||||
|
||||
if slug[0] == '-' || slug[len(slug)-1] == '-' {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, c := range slug {
|
||||
if !isValidSlugChar(byte(c)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidSlugChar(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-'
|
||||
}
|
||||
|
||||
func isForbiddenSlug(slug string) bool {
|
||||
_, ok := forbiddenSlugs[slug]
|
||||
return ok
|
||||
}
|
||||
|
||||
var forbiddenSlugs = map[string]struct{}{
|
||||
"ping": {},
|
||||
"staging": {},
|
||||
"admin": {},
|
||||
"root": {},
|
||||
"api": {},
|
||||
"www": {},
|
||||
"support": {},
|
||||
"help": {},
|
||||
"status": {},
|
||||
"health": {},
|
||||
"login": {},
|
||||
"logout": {},
|
||||
"signup": {},
|
||||
"register": {},
|
||||
"settings": {},
|
||||
"config": {},
|
||||
"null": {},
|
||||
"undefined": {},
|
||||
"example": {},
|
||||
"test": {},
|
||||
"dev": {},
|
||||
"system": {},
|
||||
"administrator": {},
|
||||
"dashboard": {},
|
||||
"account": {},
|
||||
"profile": {},
|
||||
"user": {},
|
||||
"users": {},
|
||||
"auth": {},
|
||||
"oauth": {},
|
||||
"callback": {},
|
||||
"webhook": {},
|
||||
"webhooks": {},
|
||||
"static": {},
|
||||
"assets": {},
|
||||
"cdn": {},
|
||||
"mail": {},
|
||||
"email": {},
|
||||
"ftp": {},
|
||||
"ssh": {},
|
||||
"git": {},
|
||||
"svn": {},
|
||||
"blog": {},
|
||||
"news": {},
|
||||
"about": {},
|
||||
"contact": {},
|
||||
"terms": {},
|
||||
"privacy": {},
|
||||
"legal": {},
|
||||
"billing": {},
|
||||
"payment": {},
|
||||
"checkout": {},
|
||||
"cart": {},
|
||||
"shop": {},
|
||||
"store": {},
|
||||
"download": {},
|
||||
"uploads": {},
|
||||
"images": {},
|
||||
"img": {},
|
||||
"css": {},
|
||||
"js": {},
|
||||
"fonts": {},
|
||||
"public": {},
|
||||
"private": {},
|
||||
"internal": {},
|
||||
"external": {},
|
||||
"proxy": {},
|
||||
"cache": {},
|
||||
"debug": {},
|
||||
"metrics": {},
|
||||
"monitoring": {},
|
||||
"graphql": {},
|
||||
"rest": {},
|
||||
"rpc": {},
|
||||
"socket": {},
|
||||
"ws": {},
|
||||
"wss": {},
|
||||
"app": {},
|
||||
"apps": {},
|
||||
"mobile": {},
|
||||
"desktop": {},
|
||||
"embed": {},
|
||||
"widget": {},
|
||||
"docs": {},
|
||||
"documentation": {},
|
||||
"wiki": {},
|
||||
"forum": {},
|
||||
"community": {},
|
||||
"feedback": {},
|
||||
"report": {},
|
||||
"abuse": {},
|
||||
"spam": {},
|
||||
"security": {},
|
||||
"verify": {},
|
||||
"confirm": {},
|
||||
"reset": {},
|
||||
"password": {},
|
||||
"recovery": {},
|
||||
"unsubscribe": {},
|
||||
"subscribe": {},
|
||||
"notifications": {},
|
||||
"alerts": {},
|
||||
"messages": {},
|
||||
"inbox": {},
|
||||
"outbox": {},
|
||||
"sent": {},
|
||||
"draft": {},
|
||||
"trash": {},
|
||||
"archive": {},
|
||||
"search": {},
|
||||
"explore": {},
|
||||
"discover": {},
|
||||
"trending": {},
|
||||
"popular": {},
|
||||
"featured": {},
|
||||
"new": {},
|
||||
"latest": {},
|
||||
"top": {},
|
||||
"best": {},
|
||||
"hot": {},
|
||||
"random": {},
|
||||
"all": {},
|
||||
"any": {},
|
||||
"none": {},
|
||||
"true": {},
|
||||
"false": {},
|
||||
}
|
||||
|
||||
var (
|
||||
minSlugLength = 3
|
||||
maxSlugLength = 20
|
||||
)
|
||||
@@ -0,0 +1,695 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type mockSession struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockSession) Lifecycle() lifecycle.Lifecycle {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(lifecycle.Lifecycle)
|
||||
}
|
||||
func (m *mockSession) Interaction() interaction.Interaction {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(interaction.Interaction)
|
||||
}
|
||||
func (m *mockSession) Forwarder() forwarder.Forwarder {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(forwarder.Forwarder)
|
||||
}
|
||||
func (m *mockSession) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
func (m *mockSession) Detail() *types.Detail {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(*types.Detail)
|
||||
}
|
||||
|
||||
type mockLifecycle struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (ml *mockLifecycle) Channel() ssh.Channel {
|
||||
args := ml.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(ssh.Channel)
|
||||
}
|
||||
|
||||
func (ml *mockLifecycle) Connection() ssh.Conn {
|
||||
args := ml.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(ssh.Conn)
|
||||
}
|
||||
|
||||
func (ml *mockLifecycle) PortRegistry() port.Port {
|
||||
args := ml.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(port.Port)
|
||||
}
|
||||
|
||||
func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { ml.Called(channel) }
|
||||
func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { ml.Called(status) }
|
||||
func (ml *mockLifecycle) IsActive() bool { return ml.Called().Bool(0) }
|
||||
func (ml *mockLifecycle) StartedAt() time.Time { return ml.Called().Get(0).(time.Time) }
|
||||
func (ml *mockLifecycle) Close() error { return ml.Called().Error(0) }
|
||||
func (ml *mockLifecycle) User() string { return ml.Called().String(0) }
|
||||
|
||||
type mockSlug struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (ms *mockSlug) Set(slug string) { ms.Called(slug) }
|
||||
func (ms *mockSlug) String() string { return ms.Called().String(0) }
|
||||
|
||||
func createMockSession(user ...string) *mockSession {
|
||||
u := "user1"
|
||||
if len(user) > 0 {
|
||||
u = user[0]
|
||||
}
|
||||
m := new(mockSession)
|
||||
ml := new(mockLifecycle)
|
||||
ml.On("User").Return(u).Maybe()
|
||||
m.On("Lifecycle").Return(ml).Maybe()
|
||||
ms := new(mockSlug)
|
||||
ms.On("Set", mock.Anything).Maybe()
|
||||
m.On("Slug").Return(ms).Maybe()
|
||||
m.On("Interaction").Return(nil).Maybe()
|
||||
m.On("Forwarder").Return(nil).Maybe()
|
||||
m.On("Detail").Return(nil).Maybe()
|
||||
return m
|
||||
}
|
||||
|
||||
func TestNewRegistry(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
require.NotNil(t, r)
|
||||
}
|
||||
|
||||
func TestRegistry_Get(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(r *registry)
|
||||
key types.SessionKey
|
||||
wantErr error
|
||||
wantResult bool
|
||||
}{
|
||||
{
|
||||
name: "session found",
|
||||
setupFunc: func(r *registry) {
|
||||
user := "user1"
|
||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession(user)
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser[user] = map[types.SessionKey]Session{
|
||||
key: session,
|
||||
}
|
||||
r.slugIndex[key] = user
|
||||
},
|
||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||
wantErr: nil,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "session not found in slugIndex",
|
||||
setupFunc: func(r *registry) {},
|
||||
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
{
|
||||
name: "session not found in byUser",
|
||||
setupFunc: func(r *registry) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
|
||||
},
|
||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[types.SessionKey]Session),
|
||||
slugIndex: make(map[types.SessionKey]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
tt.setupFunc(r)
|
||||
|
||||
session, err := r.Get(tt.key)
|
||||
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantResult, session != nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_GetWithUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(r *registry)
|
||||
user string
|
||||
key types.SessionKey
|
||||
wantErr error
|
||||
wantResult bool
|
||||
}{
|
||||
{
|
||||
name: "session found",
|
||||
setupFunc: func(r *registry) {
|
||||
user := "user1"
|
||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser[user] = map[types.SessionKey]Session{
|
||||
key: session,
|
||||
}
|
||||
r.slugIndex[key] = user
|
||||
},
|
||||
user: "user1",
|
||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||
wantErr: nil,
|
||||
wantResult: true,
|
||||
},
|
||||
{
|
||||
name: "session not found in slugIndex",
|
||||
setupFunc: func(r *registry) {},
|
||||
user: "user1",
|
||||
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
{
|
||||
name: "session not found in byUser",
|
||||
setupFunc: func(r *registry) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
|
||||
},
|
||||
user: "user1",
|
||||
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[types.SessionKey]Session),
|
||||
slugIndex: make(map[types.SessionKey]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
tt.setupFunc(r)
|
||||
|
||||
session, err := r.GetWithUser(tt.user, tt.key)
|
||||
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.wantResult, session != nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Update(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user string
|
||||
setupFunc func(r *registry) (oldKey, newKey types.SessionKey)
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "change slug success",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession("user1")
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "change slug to already used slug",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
newKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
r.slugIndex[newKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrSlugInUse,
|
||||
},
|
||||
{
|
||||
name: "change slug to forbidden slug",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrForbiddenSlug,
|
||||
},
|
||||
{
|
||||
name: "change slug to invalid slug",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrInvalidSlug,
|
||||
},
|
||||
{
|
||||
name: "change slug but session not found",
|
||||
user: "user2",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
|
||||
}
|
||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
{
|
||||
name: "change slug but session is not in the map",
|
||||
user: "user2",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
|
||||
}
|
||||
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrSessionNotFound,
|
||||
},
|
||||
{
|
||||
name: "change slug with same slug",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrSlugUnchanged,
|
||||
},
|
||||
{
|
||||
name: "tcp tunnel cannot change slug",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
|
||||
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
|
||||
newKey := oldKey
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.byUser["user1"] = map[types.SessionKey]Session{
|
||||
oldKey: session,
|
||||
}
|
||||
r.slugIndex[oldKey] = "user1"
|
||||
|
||||
return oldKey, newKey
|
||||
},
|
||||
wantErr: ErrSlugChangeNotAllowed,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[types.SessionKey]Session),
|
||||
slugIndex: make(map[types.SessionKey]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
|
||||
oldKey, newKey := tt.setupFunc(r)
|
||||
|
||||
err := r.Update(tt.user, oldKey, newKey)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
|
||||
if err == nil {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
_, ok := r.byUser[tt.user][newKey]
|
||||
assert.True(t, ok, "newKey not found in registry")
|
||||
_, ok = r.byUser[tt.user][oldKey]
|
||||
assert.False(t, ok, "oldKey still exists in registry")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Register(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user string
|
||||
setupFunc func(r *registry) Key
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "register new key successfully",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) Key {
|
||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
return key
|
||||
},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "register already existing key fails",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) Key {
|
||||
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
|
||||
r.mu.Lock()
|
||||
r.byUser["user1"] = map[Key]Session{key: session}
|
||||
r.slugIndex[key] = "user1"
|
||||
r.mu.Unlock()
|
||||
|
||||
return key
|
||||
},
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "register multiple keys for same user",
|
||||
user: "user1",
|
||||
setupFunc: func(r *registry) Key {
|
||||
firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
r.mu.Lock()
|
||||
r.byUser["user1"] = map[Key]Session{firstKey: session}
|
||||
r.slugIndex[firstKey] = "user1"
|
||||
r.mu.Unlock()
|
||||
|
||||
return types.SessionKey{Id: "second", Type: types.TunnelTypeHTTP}
|
||||
},
|
||||
wantOK: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
|
||||
key := tt.setupFunc(r)
|
||||
session := createMockSession()
|
||||
|
||||
ok := r.Register(key, session)
|
||||
assert.Equal(t, tt.wantOK, ok)
|
||||
|
||||
if ok {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
assert.Equal(t, session, r.byUser[tt.user][key], "session not stored in byUser")
|
||||
assert.Equal(t, tt.user, r.slugIndex[key], "slugIndex not updated")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_GetAllSessionFromUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(r *registry) string
|
||||
expectN int
|
||||
}{
|
||||
{
|
||||
name: "user has no sessions",
|
||||
setupFunc: func(r *registry) string {
|
||||
return "user1"
|
||||
},
|
||||
expectN: 0,
|
||||
},
|
||||
{
|
||||
name: "user has multiple sessions",
|
||||
setupFunc: func(r *registry) string {
|
||||
user := "user1"
|
||||
key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||
key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
|
||||
r.mu.Lock()
|
||||
r.byUser[user] = map[Key]Session{
|
||||
key1: createMockSession(),
|
||||
key2: createMockSession(),
|
||||
}
|
||||
r.mu.Unlock()
|
||||
return user
|
||||
},
|
||||
expectN: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
user := tt.setupFunc(r)
|
||||
sessions := r.GetAllSessionFromUser(user)
|
||||
assert.Len(t, sessions, tt.expectN)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Remove(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(r *registry) (string, types.SessionKey)
|
||||
key types.SessionKey
|
||||
verify func(*testing.T, *registry, string, types.SessionKey)
|
||||
}{
|
||||
{
|
||||
name: "remove existing key",
|
||||
setupFunc: func(r *registry) (string, types.SessionKey) {
|
||||
user := "user1"
|
||||
key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||
session := createMockSession()
|
||||
r.mu.Lock()
|
||||
r.byUser[user] = map[Key]Session{key: session}
|
||||
r.slugIndex[key] = user
|
||||
r.mu.Unlock()
|
||||
return user, key
|
||||
},
|
||||
verify: func(t *testing.T, r *registry, user string, key types.SessionKey) {
|
||||
_, ok := r.byUser[user][key]
|
||||
assert.False(t, ok, "expected key to be removed from byUser")
|
||||
_, ok = r.slugIndex[key]
|
||||
assert.False(t, ok, "expected key to be removed from slugIndex")
|
||||
_, ok = r.byUser[user]
|
||||
assert.False(t, ok, "expected user to be removed from byUser map")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove non-existing key",
|
||||
setupFunc: func(r *registry) (string, types.SessionKey) {
|
||||
return "", types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
user, key := tt.setupFunc(r)
|
||||
if user == "" {
|
||||
key = tt.key
|
||||
}
|
||||
r.Remove(key)
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, r, user, key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidSlug(t *testing.T) {
|
||||
tests := []struct {
|
||||
slug string
|
||||
want bool
|
||||
}{
|
||||
{"abc", true},
|
||||
{"abc-123", true},
|
||||
{"a", false},
|
||||
{"verybigdihsixsevenlabubu", false},
|
||||
{"-iamsigma", false},
|
||||
{"ligma-", false},
|
||||
{"invalid$", false},
|
||||
{"valid-slug1", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.slug, func(t *testing.T) {
|
||||
got := isValidSlug(tt.slug)
|
||||
if got != tt.want {
|
||||
t.Errorf("isValidSlug(%q) = %v; want %v", tt.slug, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidSlugChar(t *testing.T) {
|
||||
tests := []struct {
|
||||
char byte
|
||||
want bool
|
||||
}{
|
||||
{'a', true},
|
||||
{'z', true},
|
||||
{'0', true},
|
||||
{'9', true},
|
||||
{'-', true},
|
||||
{'A', false},
|
||||
{'$', false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(string(tt.char), func(t *testing.T) {
|
||||
got := isValidSlugChar(tt.char)
|
||||
if got != tt.want {
|
||||
t.Errorf("isValidSlugChar(%q) = %v; want %v", tt.char, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsForbiddenSlug(t *testing.T) {
|
||||
forbiddenSlugs = map[string]struct{}{
|
||||
"admin": {},
|
||||
"root": {},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
slug string
|
||||
want bool
|
||||
}{
|
||||
{"admin", true},
|
||||
{"root", true},
|
||||
{"user", false},
|
||||
{"guest", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.slug, func(t *testing.T) {
|
||||
got := isForbiddenSlug(tt.slug)
|
||||
if got != tt.want {
|
||||
t.Errorf("isForbiddenSlug(%q) = %v; want %v", tt.slug, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/registry"
|
||||
)
|
||||
|
||||
type httpServer struct {
|
||||
handler *httpHandler
|
||||
config config.Config
|
||||
}
|
||||
|
||||
func NewHTTPServer(config config.Config, sessionRegistry registry.Registry) Transport {
|
||||
return &httpServer{
|
||||
handler: newHTTPHandler(config, sessionRegistry),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (ht *httpServer) Listen() (net.Listener, error) {
|
||||
return net.Listen("tcp", ":"+ht.config.HTTPPort())
|
||||
}
|
||||
|
||||
func (ht *httpServer) Serve(listener net.Listener) error {
|
||||
log.Printf("HTTP server is starting on port %s", ht.config.HTTPPort())
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return err
|
||||
}
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go ht.handler.Handler(conn, false)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestNewHTTPServer(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
httpSrv, ok := srv.(*httpServer)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, msr, httpSrv.handler.sessionRegistry)
|
||||
assert.NotNil(t, srv)
|
||||
}
|
||||
|
||||
func TestHTTPServer_Listen(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, listener)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHTTPServer_Serve(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
err = srv.Serve(listener)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
}
|
||||
|
||||
func TestHTTPServer_Serve_AcceptError(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
err := srv.Serve(ml)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
ml.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestHTTPServer_Serve_Success(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
mockConfig.On("HeaderSize").Return(4096)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
listenerport := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
}
|
||||
|
||||
type mockListener struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockListener) Accept() (net.Conn, error) {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(net.Conn), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockListener) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockListener) Addr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/http/header"
|
||||
"tunnel_pls/internal/http/stream"
|
||||
"tunnel_pls/internal/middleware"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type httpHandler struct {
|
||||
config config.Config
|
||||
sessionRegistry registry.Registry
|
||||
}
|
||||
|
||||
func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler {
|
||||
return &httpHandler{
|
||||
config: config,
|
||||
sessionRegistry: sessionRegistry,
|
||||
}
|
||||
}
|
||||
|
||||
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
|
||||
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
|
||||
fmt.Sprintf("Location: %s", location) +
|
||||
"Content-Length: 0\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"\r\n"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hh *httpHandler) badRequest(conn net.Conn) error {
|
||||
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
|
||||
defer hh.closeConnection(conn)
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
buf := make([]byte, hh.config.HeaderSize())
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
_ = hh.badRequest(conn)
|
||||
return
|
||||
}
|
||||
|
||||
if idx := bytes.Index(buf[:n], []byte("\r\n\r\n")); idx == -1 {
|
||||
_ = hh.badRequest(conn)
|
||||
return
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
|
||||
reqhf, err := header.NewRequest(buf[:n])
|
||||
if err != nil {
|
||||
log.Printf("Error creating request header: %v", err)
|
||||
_ = hh.badRequest(conn)
|
||||
return
|
||||
}
|
||||
|
||||
slug, err := hh.extractSlug(reqhf)
|
||||
if err != nil {
|
||||
_ = hh.badRequest(conn)
|
||||
return
|
||||
}
|
||||
|
||||
if hh.shouldRedirectToTLS(isTLS) {
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.config.Domain()))
|
||||
return
|
||||
}
|
||||
|
||||
if hh.handlePingRequest(slug, conn) {
|
||||
return
|
||||
}
|
||||
|
||||
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
|
||||
Id: slug,
|
||||
Type: types.TunnelTypeHTTP,
|
||||
})
|
||||
if err != nil {
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
|
||||
return
|
||||
}
|
||||
|
||||
hw := stream.New(conn, conn, conn.RemoteAddr())
|
||||
defer func(hw stream.HTTP) {
|
||||
err = hw.Close()
|
||||
if err != nil {
|
||||
log.Printf("Error closing HTTP stream: %v", err)
|
||||
}
|
||||
}(hw)
|
||||
hh.forwardRequest(hw, reqhf, sshSession)
|
||||
}
|
||||
|
||||
func (hh *httpHandler) closeConnection(conn net.Conn) {
|
||||
err := conn.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error closing connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
|
||||
host := strings.Split(reqhf.Value("Host"), ".")
|
||||
if len(host) <= 1 {
|
||||
return "", errors.New("invalid host")
|
||||
}
|
||||
return host[0], nil
|
||||
}
|
||||
|
||||
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
|
||||
return !isTLS && hh.config.TLSRedirect()
|
||||
}
|
||||
|
||||
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
||||
if slug != "ping" {
|
||||
return false
|
||||
}
|
||||
|
||||
_, err := conn.Write([]byte(
|
||||
"HTTP/1.1 200 OK\r\n" +
|
||||
"Content-Length: 0\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"Access-Control-Allow-Origin: *\r\n" +
|
||||
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
|
||||
"Access-Control-Allow-Headers: *\r\n" +
|
||||
"\r\n",
|
||||
))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 200 OK:", err)
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, hw.RemoteAddr())
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
defer func() {
|
||||
err = channel.Close()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error closing forwarded channel: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
hh.setupMiddlewares(hw)
|
||||
|
||||
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
|
||||
log.Printf("Failed to forward initial request: %v", err)
|
||||
return
|
||||
}
|
||||
sshSession.Forwarder().HandleConnection(hw, channel)
|
||||
}
|
||||
|
||||
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
|
||||
fingerprintMiddleware := middleware.NewTunnelFingerprint()
|
||||
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
|
||||
|
||||
hw.UseResponseMiddleware(fingerprintMiddleware)
|
||||
hw.UseRequestMiddleware(forwardedForMiddleware)
|
||||
}
|
||||
|
||||
func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error {
|
||||
hw.SetRequestHeader(initialRequest)
|
||||
|
||||
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
|
||||
return fmt.Errorf("error applying request middlewares: %w", err)
|
||||
}
|
||||
|
||||
if _, err := channel.Write(initialRequest.Finalize()); err != nil {
|
||||
return fmt.Errorf("error writing to channel: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,717 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||
args := m.Called(user, oldKey, newKey)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||
args := m.Called(key, session)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
args := m.Called(user)
|
||||
return args.Get(0).([]registry.Session)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
|
||||
args := m.Called()
|
||||
return args.Get(0).(lifecycle.Lifecycle)
|
||||
}
|
||||
|
||||
func (m *MockSession) Interaction() interaction.Interaction {
|
||||
args := m.Called()
|
||||
return args.Get(0).(interaction.Interaction)
|
||||
}
|
||||
|
||||
func (m *MockSession) Forwarder() forwarder.Forwarder {
|
||||
args := m.Called()
|
||||
return args.Get(0).(forwarder.Forwarder)
|
||||
}
|
||||
|
||||
func (m *MockSession) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
func (m *MockSession) Detail() *types.Detail {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*types.Detail)
|
||||
}
|
||||
|
||||
type MockSSHChannel struct {
|
||||
ssh.Channel
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSSHChannel) Write(data []byte) (int, error) {
|
||||
args := m.Called(data)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSSHChannel) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockForwarder struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||
m.Called(dst, src)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) TunnelType() types.TunnelType {
|
||||
args := m.Called()
|
||||
return args.Get(0).(types.TunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) ForwardedPort() uint16 {
|
||||
args := m.Called()
|
||||
return uint16(args.Int(0))
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
|
||||
m.Called(tunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetForwardedPort(port uint16) {
|
||||
m.Called(port)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetListener(listener net.Listener) {
|
||||
m.Called(listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Listener() net.Listener {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
args := m.Called(ctx, origin)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
|
||||
type MockConn struct {
|
||||
mock.Mock
|
||||
ReadBuffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
if m.ReadBuffer != nil {
|
||||
return m.ReadBuffer.Read(b)
|
||||
}
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
if args.Int(0) == -1 {
|
||||
return len(b), args.Error(1)
|
||||
}
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
type wrappedConn struct {
|
||||
net.Conn
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (c *wrappedConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func TestNewHTTPHandler(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("Domain").Return("domain")
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
hh := newHTTPHandler(mockConfig, msr)
|
||||
assert.NotNil(t, hh)
|
||||
assert.Equal(t, msr, hh.sessionRegistry)
|
||||
}
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isTLS bool
|
||||
redirectTLS bool
|
||||
request []byte
|
||||
expected []byte
|
||||
setupMocks func(*MockSessionRegistry)
|
||||
setupConn func() (net.Conn, net.Conn)
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "bad request - invalid host",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - missing host",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "isTLS true and redirectTLS true - no redirect",
|
||||
isTLS: true,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redirect to TLS",
|
||||
isTLS: false,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: tunnel.example.com\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnel.example.com/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle ping request",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "session not found",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return((registry.Session)(nil), fmt.Errorf("session not found"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - invalid http",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("INVALID\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - header too large",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: test.domain\r\n%s\r\n\r\n", strings.Repeat("test", 10000))),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - no request",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte(""),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - open channel fails",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - send initial request fails",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
go func() {
|
||||
for range reqCh {
|
||||
}
|
||||
}()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - success",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.ReadWriter)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
|
||||
})
|
||||
|
||||
go func() {
|
||||
for range reqCh {
|
||||
}
|
||||
}()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redirect - write failure",
|
||||
isTLS: false,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(-1, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - write failure",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read error - connection failure",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte(""),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Read", mock.Anything).Return(0, fmt.Errorf("connection reset by peer"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle ping request - write failure",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "close connection - error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(182, nil)
|
||||
mc.On("Close").Return(fmt.Errorf("close error"))
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - stream close error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return()
|
||||
},
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
|
||||
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
mc.On("RemoteAddr").Return(addr)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - middleware failure",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
|
||||
return k.Id == "test"
|
||||
})).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
},
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Close").Return(nil).Times(2)
|
||||
mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - channel close error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.ReadWriter)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
|
||||
})
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - open channel timeout",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
ctx := args.Get(0).(context.Context)
|
||||
<-ctx.Done()
|
||||
}).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSessionRegistry := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
mockConfig.On("HeaderSize").Return(4096)
|
||||
mockConfig.On("TLSRedirect").Return(true)
|
||||
hh := &httpHandler{
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
config: mockConfig,
|
||||
}
|
||||
|
||||
if tt.setupMocks != nil {
|
||||
tt.setupMocks(mockSessionRegistry)
|
||||
}
|
||||
|
||||
var serverConn, clientConn net.Conn
|
||||
if tt.setupConn != nil {
|
||||
serverConn, clientConn = tt.setupConn()
|
||||
} else {
|
||||
serverConn, clientConn = net.Pipe()
|
||||
}
|
||||
|
||||
if clientConn != nil {
|
||||
defer func(clientConn net.Conn) {
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
}(clientConn)
|
||||
}
|
||||
|
||||
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
var wrappedServerConn net.Conn
|
||||
if _, ok := serverConn.(*MockConn); ok {
|
||||
wrappedServerConn = serverConn
|
||||
} else {
|
||||
wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr}
|
||||
}
|
||||
|
||||
responseChan := make(chan []byte, 1)
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
if clientConn != nil {
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
var res []byte
|
||||
for {
|
||||
buf := make([]byte, 4096)
|
||||
n, err := clientConn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
t.Logf("Error reading response: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
res = append(res, buf[:n]...)
|
||||
if len(tt.expected) > 0 && len(res) >= len(tt.expected) {
|
||||
break
|
||||
}
|
||||
}
|
||||
responseChan <- res
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := clientConn.Write(tt.request)
|
||||
if err != nil {
|
||||
t.Logf("Error writing request: %v", err)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
close(responseChan)
|
||||
close(doneChan)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
hh.Handler(wrappedServerConn, tt.isTLS)
|
||||
}()
|
||||
|
||||
select {
|
||||
case response := <-responseChan:
|
||||
if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" {
|
||||
resStr := string(response)
|
||||
assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n"))
|
||||
assert.Contains(t, resStr, "Content-Length: 5\r\n")
|
||||
assert.Contains(t, resStr, "Server: Tunnel Please\r\n")
|
||||
assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello"))
|
||||
} else {
|
||||
assert.Equal(t, string(tt.expected), string(response))
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
if clientConn != nil {
|
||||
t.Fatal("Test timeout - no response received")
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if clientConn != nil {
|
||||
<-doneChan
|
||||
}
|
||||
|
||||
mockSessionRegistry.AssertExpectations(t)
|
||||
if mc, ok := serverConn.(*MockConn); ok {
|
||||
mc.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/registry"
|
||||
)
|
||||
|
||||
type https struct {
|
||||
config config.Config
|
||||
tlsConfig *tls.Config
|
||||
httpHandler *httpHandler
|
||||
}
|
||||
|
||||
func NewHTTPSServer(config config.Config, sessionRegistry registry.Registry, tlsConfig *tls.Config) Transport {
|
||||
return &https{
|
||||
config: config,
|
||||
tlsConfig: tlsConfig,
|
||||
httpHandler: newHTTPHandler(config, sessionRegistry),
|
||||
}
|
||||
}
|
||||
|
||||
func (ht *https) Listen() (net.Listener, error) {
|
||||
return tls.Listen("tcp", ":"+ht.config.HTTPSPort(), ht.tlsConfig)
|
||||
}
|
||||
|
||||
func (ht *https) Serve(listener net.Listener) error {
|
||||
log.Printf("HTTPS server is starting on port %s", ht.config.HTTPSPort())
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return err
|
||||
}
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go ht.httpHandler.Handler(conn, true)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewHTTPSServer(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
tlsConfig := &tls.Config{}
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
httpsSrv, ok := srv.(*https)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, tlsConfig, httpsSrv.tlsConfig)
|
||||
assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry)
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Listen(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
if err != nil {
|
||||
t.Skip("Skipping tls.Listen test as it requires valid certificates/setup:", err)
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, listener)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Serve(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
err = srv.Serve(listener)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
err := srv.Serve(ml)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
ml.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Serve_Success(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
mockConfig.On("HeaderSize").Return(4096)
|
||||
|
||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
listenerport := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type tcp struct {
|
||||
port uint16
|
||||
forwarder Forwarder
|
||||
}
|
||||
|
||||
type Forwarder interface {
|
||||
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
||||
}
|
||||
|
||||
func NewTCPServer(port uint16, forwarder Forwarder) Transport {
|
||||
return &tcp{
|
||||
port: port,
|
||||
forwarder: forwarder,
|
||||
}
|
||||
}
|
||||
|
||||
func (tt *tcp) Listen() (net.Listener, error) {
|
||||
return net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", tt.port))
|
||||
}
|
||||
|
||||
func (tt *tcp) Serve(listener net.Listener) error {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
go tt.handleTcp(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (tt *tcp) handleTcp(conn net.Conn) {
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Printf("Failed to close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, conn.RemoteAddr())
|
||||
if err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
tt.forwarder.HandleConnection(conn, channel)
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestNewTCPServer(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
port := uint16(9000)
|
||||
|
||||
srv := NewTCPServer(port, mf)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
tcpSrv, ok := srv.(*tcp)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, port, tcpSrv.port)
|
||||
assert.Equal(t, mf, tcpSrv.forwarder)
|
||||
}
|
||||
|
||||
func TestTCPServer_Listen(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, listener)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestTCPServer_Serve(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
err = srv.Serve(listener)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestTCPServer_Serve_AcceptError(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
err := srv.Serve(ml)
|
||||
assert.Nil(t, err)
|
||||
ml.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_Serve_Success(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
reqs := make(chan *ssh.Request)
|
||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
|
||||
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err = conn.Close()
|
||||
assert.NoError(t, err)
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
mf.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_handleTcp_Success(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf).(*tcp)
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer func(clientConn net.Conn) {
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
}(clientConn)
|
||||
|
||||
reqs := make(chan *ssh.Request)
|
||||
mockChannel := new(MockSSHChannel)
|
||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
|
||||
|
||||
mf.On("HandleConnection", serverConn, mockChannel).Return()
|
||||
|
||||
srv.handleTcp(serverConn)
|
||||
|
||||
mf.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_handleTcp_CloseError(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf).(*tcp)
|
||||
|
||||
mc := new(MockConn)
|
||||
mc.On("Close").Return(errors.New("close error"))
|
||||
mc.On("RemoteAddr").Return(&net.TCPAddr{})
|
||||
|
||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||
|
||||
srv.handleTcp(mc)
|
||||
mc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf).(*tcp)
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer func(clientConn net.Conn) {
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
}(clientConn)
|
||||
|
||||
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||
|
||||
srv.handleTcp(serverConn)
|
||||
|
||||
mf.AssertExpectations(t)
|
||||
}
|
||||
@@ -0,0 +1,435 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
"github.com/libdns/cloudflare"
|
||||
)
|
||||
|
||||
func NewTLSConfig(config config.Config) (*tls.Config, error) {
|
||||
var initErr error
|
||||
|
||||
tlsManagerOnce.Do(func() {
|
||||
tm := createTLSManager(config)
|
||||
initErr = tm.initialize()
|
||||
if initErr == nil {
|
||||
globalTLSManager = tm
|
||||
}
|
||||
})
|
||||
|
||||
if initErr != nil {
|
||||
return nil, initErr
|
||||
}
|
||||
|
||||
return globalTLSManager.getTLSConfig(), nil
|
||||
}
|
||||
|
||||
type tlsManager struct {
|
||||
config config.Config
|
||||
|
||||
certPath string
|
||||
keyPath string
|
||||
storagePath string
|
||||
|
||||
userCert *tls.Certificate
|
||||
userCertMu sync.RWMutex
|
||||
|
||||
magic *certmagic.Config
|
||||
|
||||
useCertMagic bool
|
||||
}
|
||||
|
||||
var globalTLSManager *tlsManager
|
||||
var tlsManagerOnce sync.Once
|
||||
|
||||
func createTLSManager(cfg config.Config) *tlsManager {
|
||||
storagePath := cfg.TLSStoragePath()
|
||||
cleanBase := filepath.Clean(storagePath)
|
||||
|
||||
return &tlsManager{
|
||||
config: cfg,
|
||||
certPath: filepath.Join(cleanBase, "cert.pem"),
|
||||
keyPath: filepath.Join(cleanBase, "privkey.pem"),
|
||||
storagePath: filepath.Join(cleanBase, "certmagic"),
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *tlsManager) initialize() error {
|
||||
if tm.userCertsExistAndValid() {
|
||||
return tm.initializeWithUserCerts()
|
||||
}
|
||||
return tm.initializeWithCertMagic()
|
||||
}
|
||||
|
||||
func (tm *tlsManager) initializeWithUserCerts() error {
|
||||
log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath)
|
||||
|
||||
if err := tm.loadUserCerts(); err != nil {
|
||||
return fmt.Errorf("failed to load user certificates: %w", err)
|
||||
}
|
||||
|
||||
tm.useCertMagic = false
|
||||
tm.startCertWatcher()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) initializeWithCertMagic() error {
|
||||
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic",
|
||||
tm.config.Domain(), tm.config.Domain())
|
||||
|
||||
if err := tm.initCertMagic(); err != nil {
|
||||
return fmt.Errorf("failed to initialize CertMagic: %w", err)
|
||||
}
|
||||
|
||||
tm.useCertMagic = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) userCertsExistAndValid() bool {
|
||||
if !tm.certFilesExist() {
|
||||
return false
|
||||
}
|
||||
return validateCertDomains(tm.certPath, tm.config.Domain())
|
||||
}
|
||||
|
||||
func (tm *tlsManager) certFilesExist() bool {
|
||||
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
||||
log.Printf("Certificate file not found: %s", tm.certPath)
|
||||
return false
|
||||
}
|
||||
if _, err := os.Stat(tm.keyPath); os.IsNotExist(err) {
|
||||
log.Printf("Key file not found: %s", tm.keyPath)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (tm *tlsManager) loadUserCerts() error {
|
||||
cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tm.userCertMu.Lock()
|
||||
tm.userCert = &cert
|
||||
tm.userCertMu.Unlock()
|
||||
|
||||
log.Printf("Loaded user certificates successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) startCertWatcher() {
|
||||
go func() {
|
||||
watcher := newCertWatcher(tm)
|
||||
watcher.watch()
|
||||
}()
|
||||
}
|
||||
|
||||
func (tm *tlsManager) initCertMagic() error {
|
||||
if err := tm.createStorageDirectory(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if tm.config.CFAPIToken() == "" {
|
||||
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
||||
}
|
||||
|
||||
magic := tm.createCertMagicConfig()
|
||||
tm.magic = magic
|
||||
|
||||
return tm.obtainCertificates(magic)
|
||||
}
|
||||
|
||||
func (tm *tlsManager) createStorageDirectory() error {
|
||||
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) createCertMagicConfig() *certmagic.Config {
|
||||
cfProvider := &cloudflare.Provider{
|
||||
APIToken: tm.config.CFAPIToken(),
|
||||
}
|
||||
|
||||
storage := &certmagic.FileStorage{Path: tm.storagePath}
|
||||
|
||||
cache := certmagic.NewCache(certmagic.CacheOptions{
|
||||
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
|
||||
return tm.magic, nil
|
||||
},
|
||||
})
|
||||
|
||||
magic := certmagic.New(cache, certmagic.Config{
|
||||
Storage: storage,
|
||||
})
|
||||
|
||||
acmeIssuer := tm.createACMEIssuer(magic, cfProvider)
|
||||
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
||||
|
||||
return magic
|
||||
}
|
||||
|
||||
func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer {
|
||||
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
||||
Email: tm.config.ACMEEmail(),
|
||||
Agreed: true,
|
||||
DNS01Solver: &certmagic.DNS01Solver{
|
||||
DNSManager: certmagic.DNSManager{
|
||||
DNSProvider: cfProvider,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if tm.config.ACMEStaging() {
|
||||
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
|
||||
log.Printf("Using Let's Encrypt staging server")
|
||||
} else {
|
||||
acmeIssuer.CA = certmagic.LetsEncryptProductionCA
|
||||
log.Printf("Using Let's Encrypt production server")
|
||||
}
|
||||
|
||||
return acmeIssuer
|
||||
}
|
||||
|
||||
func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error {
|
||||
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
|
||||
log.Printf("Requesting certificates for: %v", domains)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := magic.ManageSync(ctx, domains); err != nil {
|
||||
return fmt.Errorf("failed to obtain certificates: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Certificates obtained successfully for %v", domains)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) getTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: tm.getCertificate,
|
||||
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
|
||||
CurvePreferences: []tls.CurveID{
|
||||
tls.X25519,
|
||||
},
|
||||
|
||||
SessionTicketsDisabled: false,
|
||||
ClientAuth: tls.NoClientCert,
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if tm.useCertMagic {
|
||||
return tm.magic.GetCertificate(hello)
|
||||
}
|
||||
|
||||
tm.userCertMu.RLock()
|
||||
defer tm.userCertMu.RUnlock()
|
||||
|
||||
if tm.userCert == nil {
|
||||
return nil, fmt.Errorf("no certificate available")
|
||||
}
|
||||
|
||||
return tm.userCert, nil
|
||||
}
|
||||
|
||||
func validateCertDomains(certPath, domain string) bool {
|
||||
cert, err := loadAndParseCertificate(certPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !isCertificateValid(cert) {
|
||||
return false
|
||||
}
|
||||
|
||||
return certCoversRequiredDomains(cert, domain)
|
||||
}
|
||||
|
||||
func loadAndParseCertificate(certPath string) (*x509.Certificate, error) {
|
||||
certPEM, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read certificate: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
log.Printf("Failed to decode PEM block from certificate")
|
||||
return nil, fmt.Errorf("failed to decode PEM block")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse certificate: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
func isCertificateValid(cert *x509.Certificate) bool {
|
||||
now := time.Now()
|
||||
|
||||
if now.After(cert.NotAfter) {
|
||||
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
thirtyDaysFromNow := now.Add(30 * 24 * time.Hour)
|
||||
if thirtyDaysFromNow.After(cert.NotAfter) {
|
||||
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool {
|
||||
certDomains := extractCertDomains(cert)
|
||||
hasBase, hasWildcard := checkDomainCoverage(certDomains, domain)
|
||||
|
||||
logDomainCoverage(hasBase, hasWildcard, domain)
|
||||
return hasBase && hasWildcard
|
||||
}
|
||||
|
||||
func extractCertDomains(cert *x509.Certificate) []string {
|
||||
var domains []string
|
||||
if cert.Subject.CommonName != "" {
|
||||
domains = append(domains, cert.Subject.CommonName)
|
||||
}
|
||||
domains = append(domains, cert.DNSNames...)
|
||||
return domains
|
||||
}
|
||||
|
||||
func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) {
|
||||
wildcardDomain := "*." + domain
|
||||
|
||||
for _, d := range certDomains {
|
||||
if d == domain {
|
||||
hasBase = true
|
||||
}
|
||||
if d == wildcardDomain {
|
||||
hasWildcard = true
|
||||
}
|
||||
}
|
||||
|
||||
return hasBase, hasWildcard
|
||||
}
|
||||
|
||||
func logDomainCoverage(hasBase, hasWildcard bool, domain string) {
|
||||
if !hasBase {
|
||||
log.Printf("Certificate does not cover base domain: %s", domain)
|
||||
}
|
||||
if !hasWildcard {
|
||||
log.Printf("Certificate does not cover wildcard domain: *.%s", domain)
|
||||
}
|
||||
}
|
||||
|
||||
type certWatcher struct {
|
||||
tm *tlsManager
|
||||
lastCertMod time.Time
|
||||
lastKeyMod time.Time
|
||||
}
|
||||
|
||||
func newCertWatcher(tm *tlsManager) *certWatcher {
|
||||
watcher := &certWatcher{tm: tm}
|
||||
watcher.initializeModTimes()
|
||||
return watcher
|
||||
}
|
||||
|
||||
func (cw *certWatcher) initializeModTimes() {
|
||||
if info, err := os.Stat(cw.tm.certPath); err == nil {
|
||||
cw.lastCertMod = info.ModTime()
|
||||
}
|
||||
if info, err := os.Stat(cw.tm.keyPath); err == nil {
|
||||
cw.lastKeyMod = info.ModTime()
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *certWatcher) watch() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if cw.checkAndReloadCerts() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cw *certWatcher) checkAndReloadCerts() bool {
|
||||
certInfo, keyInfo, err := cw.getFileInfo()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !cw.filesModified(certInfo, keyInfo) {
|
||||
return false
|
||||
}
|
||||
|
||||
return cw.handleCertificateChange(certInfo, keyInfo)
|
||||
}
|
||||
|
||||
func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) {
|
||||
certInfo, certErr := os.Stat(cw.tm.certPath)
|
||||
keyInfo, keyErr := os.Stat(cw.tm.keyPath)
|
||||
|
||||
if certErr != nil || keyErr != nil {
|
||||
return nil, nil, fmt.Errorf("file stat error")
|
||||
}
|
||||
|
||||
return certInfo, keyInfo, nil
|
||||
}
|
||||
|
||||
func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool {
|
||||
return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod)
|
||||
}
|
||||
|
||||
func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool {
|
||||
log.Printf("Certificate files changed, reloading...")
|
||||
|
||||
if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) {
|
||||
return cw.switchToCertMagic()
|
||||
}
|
||||
|
||||
if err := cw.tm.loadUserCerts(); err != nil {
|
||||
log.Printf("Failed to reload certificates: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
cw.updateModTimes(certInfo, keyInfo)
|
||||
log.Printf("Certificates reloaded successfully")
|
||||
return false
|
||||
}
|
||||
|
||||
func (cw *certWatcher) switchToCertMagic() bool {
|
||||
log.Printf("New certificates don't cover required domains")
|
||||
|
||||
if err := cw.tm.initCertMagic(); err != nil {
|
||||
log.Printf("Failed to initialize CertMagic: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
cw.tm.useCertMagic = true
|
||||
return true
|
||||
}
|
||||
|
||||
func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) {
|
||||
cw.lastCertMod = certInfo.ModTime()
|
||||
cw.lastKeyMod = keyInfo.ModTime()
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,14 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type Transport interface {
|
||||
Listen() (net.Listener, error)
|
||||
Serve(listener net.Listener) error
|
||||
}
|
||||
|
||||
type HTTP interface {
|
||||
Handler(conn net.Conn, isTLS bool)
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVersionFunctions(t *testing.T) {
|
||||
origVersion := Version
|
||||
origBuildDate := BuildDate
|
||||
origCommit := Commit
|
||||
defer func() {
|
||||
Version = origVersion
|
||||
BuildDate = origBuildDate
|
||||
Commit = origCommit
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
version string
|
||||
buildDate string
|
||||
commit string
|
||||
wantFull string
|
||||
wantShort string
|
||||
}{
|
||||
{
|
||||
name: "Default dev version",
|
||||
version: "dev",
|
||||
buildDate: "unknown",
|
||||
commit: "unknown",
|
||||
wantFull: "tunnel_pls dev (commit: unknown, built: unknown)",
|
||||
wantShort: "dev",
|
||||
},
|
||||
{
|
||||
name: "Release version",
|
||||
version: "v1.0.0",
|
||||
buildDate: "2026-01-23",
|
||||
commit: "abcdef123",
|
||||
wantFull: "tunnel_pls v1.0.0 (commit: abcdef123, built: 2026-01-23)",
|
||||
wantShort: "v1.0.0",
|
||||
},
|
||||
{
|
||||
name: "Empty values",
|
||||
version: "",
|
||||
buildDate: "",
|
||||
commit: "",
|
||||
wantFull: "tunnel_pls (commit: , built: )",
|
||||
wantShort: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
Version = tt.version
|
||||
BuildDate = tt.buildDate
|
||||
Commit = tt.commit
|
||||
|
||||
gotFull := GetVersion()
|
||||
if gotFull != tt.wantFull {
|
||||
t.Errorf("GetVersion() = %q, want %q", gotFull, tt.wantFull)
|
||||
}
|
||||
|
||||
gotShort := GetShortVersion()
|
||||
if gotShort != tt.wantShort {
|
||||
t.Errorf("GetShortVersion() = %q, want %q", gotShort, tt.wantShort)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetVersion_Format(t *testing.T) {
|
||||
v := "1.2.3"
|
||||
c := "brainrot"
|
||||
d := "now"
|
||||
|
||||
Version = v
|
||||
Commit = c
|
||||
BuildDate = d
|
||||
|
||||
expected := fmt.Sprintf("tunnel_pls %s (commit: %s, built: %s)", v, c, d)
|
||||
if GetVersion() != expected {
|
||||
t.Errorf("GetVersion() formatting mismatch")
|
||||
}
|
||||
}
|
||||
@@ -3,16 +3,11 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"tunnel_pls/internal/bootstrap"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/key"
|
||||
"tunnel_pls/server"
|
||||
"tunnel_pls/session"
|
||||
"tunnel_pls/version"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/version"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -23,47 +18,19 @@ func main() {
|
||||
|
||||
log.SetOutput(os.Stdout)
|
||||
log.SetFlags(log.LstdFlags | log.Lshortfile)
|
||||
|
||||
log.Printf("Starting %s", version.GetVersion())
|
||||
|
||||
pprofEnabled := config.Getenv("PPROF_ENABLED", "false")
|
||||
if pprofEnabled == "true" {
|
||||
pprofPort := config.Getenv("PPROF_PORT", "6060")
|
||||
go func() {
|
||||
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
|
||||
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
|
||||
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
|
||||
log.Printf("pprof server error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
sshConfig := &ssh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
ServerVersion: fmt.Sprintf("SSH-2.0-TunnlPls-%s", version.GetShortVersion()),
|
||||
}
|
||||
|
||||
sshKeyPath := "certs/ssh/id_rsa"
|
||||
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
|
||||
log.Fatalf("Failed to generate SSH key: %s", err)
|
||||
}
|
||||
|
||||
privateBytes, err := os.ReadFile(sshKeyPath)
|
||||
conf, err := config.MustLoad()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load private key: %s", err)
|
||||
log.Fatalf("Config load error: %v", err)
|
||||
}
|
||||
|
||||
private, err := ssh.ParsePrivateKey(privateBytes)
|
||||
boot, err := bootstrap.New(conf, port.New())
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse private key: %s", err)
|
||||
log.Fatalf("Startup error: %v", err)
|
||||
}
|
||||
|
||||
sshConfig.AddHostKey(private)
|
||||
sessionRegistry := session.NewRegistry()
|
||||
|
||||
app, err := server.NewServer(sshConfig, sessionRegistry)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to start server: %s", err)
|
||||
if err = boot.Run(); err != nil {
|
||||
log.Fatalf("Application error: %v", err)
|
||||
}
|
||||
app.Start()
|
||||
}
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
module.exports = {
|
||||
"endpoint": "https://git.fossy.my.id/api/v1",
|
||||
"gitAuthor": "Renovate-Clanker <renovate-bot@fossy.my.id>",
|
||||
"platform": "gitea",
|
||||
"onboardingConfigFileName": "renovate.json",
|
||||
"autodiscover": true,
|
||||
"optimizeForDisabled": true,
|
||||
};
|
||||
@@ -1,276 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type HeaderManager interface {
|
||||
Get(key string) []byte
|
||||
Set(key string, value []byte)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
}
|
||||
|
||||
type ResponseHeaderManager interface {
|
||||
Get(key string) string
|
||||
Set(key string, value string)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
}
|
||||
|
||||
type RequestHeaderManager interface {
|
||||
Get(key string) string
|
||||
Set(key string, value string)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
GetMethod() string
|
||||
GetPath() string
|
||||
GetVersion() string
|
||||
}
|
||||
|
||||
type responseHeaderFactory struct {
|
||||
startLine []byte
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
type requestHeaderFactory struct {
|
||||
method string
|
||||
path string
|
||||
version string
|
||||
startLine []byte
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) {
|
||||
switch v := r.(type) {
|
||||
case []byte:
|
||||
return parseHeadersFromBytes(v)
|
||||
case *bufio.Reader:
|
||||
return parseHeadersFromReader(v)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported type: %T", r)
|
||||
}
|
||||
}
|
||||
|
||||
func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) {
|
||||
header := &requestHeaderFactory{
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
|
||||
lineEnd := bytes.IndexByte(headerData, '\n')
|
||||
if lineEnd == -1 {
|
||||
return nil, fmt.Errorf("invalid request: no newline found")
|
||||
}
|
||||
|
||||
startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n")
|
||||
header.startLine = make([]byte, len(startLine))
|
||||
copy(header.startLine, startLine)
|
||||
|
||||
parts := bytes.Split(startLine, []byte{' '})
|
||||
if len(parts) < 3 {
|
||||
return nil, fmt.Errorf("invalid request line")
|
||||
}
|
||||
|
||||
header.method = string(parts[0])
|
||||
header.path = string(parts[1])
|
||||
header.version = string(parts[2])
|
||||
|
||||
remaining := headerData[lineEnd+1:]
|
||||
|
||||
for len(remaining) > 0 {
|
||||
lineEnd = bytes.IndexByte(remaining, '\n')
|
||||
if lineEnd == -1 {
|
||||
lineEnd = len(remaining)
|
||||
}
|
||||
|
||||
line := bytes.TrimRight(remaining[:lineEnd], "\r\n")
|
||||
|
||||
if len(line) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
colonIdx := bytes.IndexByte(line, ':')
|
||||
if colonIdx != -1 {
|
||||
key := bytes.TrimSpace(line[:colonIdx])
|
||||
value := bytes.TrimSpace(line[colonIdx+1:])
|
||||
header.headers[string(key)] = string(value)
|
||||
}
|
||||
|
||||
if lineEnd == len(remaining) {
|
||||
break
|
||||
}
|
||||
remaining = remaining[lineEnd+1:]
|
||||
}
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) {
|
||||
header := &requestHeaderFactory{
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
|
||||
startLineBytes, err := br.ReadSlice('\n')
|
||||
if err != nil {
|
||||
if err == bufio.ErrBufferFull {
|
||||
var startLine string
|
||||
startLine, err = br.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
startLineBytes = []byte(startLine)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
|
||||
header.startLine = make([]byte, len(startLineBytes))
|
||||
copy(header.startLine, startLineBytes)
|
||||
|
||||
parts := bytes.Split(startLineBytes, []byte{' '})
|
||||
if len(parts) < 3 {
|
||||
return nil, fmt.Errorf("invalid request line")
|
||||
}
|
||||
|
||||
header.method = string(parts[0])
|
||||
header.path = string(parts[1])
|
||||
header.version = string(parts[2])
|
||||
|
||||
for {
|
||||
lineBytes, err := br.ReadSlice('\n')
|
||||
if err != nil {
|
||||
if err == bufio.ErrBufferFull {
|
||||
var line string
|
||||
line, err = br.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lineBytes = []byte(line)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
|
||||
|
||||
if len(lineBytes) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
colonIdx := bytes.IndexByte(lineBytes, ':')
|
||||
if colonIdx == -1 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := bytes.TrimSpace(lineBytes[:colonIdx])
|
||||
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
|
||||
|
||||
header.headers[string(key)] = string(value)
|
||||
}
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager {
|
||||
header := &responseHeaderFactory{
|
||||
startLine: nil,
|
||||
headers: make(map[string]string),
|
||||
}
|
||||
lines := bytes.Split(startLine, []byte("\r\n"))
|
||||
if len(lines) == 0 {
|
||||
return header
|
||||
}
|
||||
header.startLine = lines[0]
|
||||
for _, h := range lines[1:] {
|
||||
if len(h) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := bytes.SplitN(h, []byte(":"), 2)
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := parts[0]
|
||||
val := bytes.TrimSpace(parts[1])
|
||||
header.headers[string(key)] = string(val)
|
||||
}
|
||||
return header
|
||||
}
|
||||
|
||||
func (resp *responseHeaderFactory) Get(key string) string {
|
||||
return resp.headers[key]
|
||||
}
|
||||
|
||||
func (resp *responseHeaderFactory) Set(key string, value string) {
|
||||
resp.headers[key] = value
|
||||
}
|
||||
|
||||
func (resp *responseHeaderFactory) Remove(key string) {
|
||||
delete(resp.headers, key)
|
||||
}
|
||||
|
||||
func (resp *responseHeaderFactory) Finalize() []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
buf.Write(resp.startLine)
|
||||
buf.WriteString("\r\n")
|
||||
|
||||
for key, val := range resp.headers {
|
||||
buf.WriteString(key)
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(val)
|
||||
buf.WriteString("\r\n")
|
||||
}
|
||||
|
||||
buf.WriteString("\r\n")
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) Get(key string) string {
|
||||
val, ok := req.headers[key]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) Set(key string, value string) {
|
||||
req.headers[key] = value
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) Remove(key string) {
|
||||
delete(req.headers, key)
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) GetMethod() string {
|
||||
return req.method
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) GetPath() string {
|
||||
return req.path
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) GetVersion() string {
|
||||
return req.version
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) Finalize() []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
buf.Write(req.startLine)
|
||||
buf.WriteString("\r\n")
|
||||
|
||||
for key, val := range req.headers {
|
||||
buf.WriteString(key)
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(val)
|
||||
buf.WriteString("\r\n")
|
||||
}
|
||||
|
||||
buf.WriteString("\r\n")
|
||||
return buf.Bytes()
|
||||
}
|
||||
-391
@@ -1,391 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/session"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type HTTPWriter interface {
|
||||
io.Reader
|
||||
io.Writer
|
||||
GetRemoteAddr() net.Addr
|
||||
GetWriter() io.Writer
|
||||
AddResponseMiddleware(mw ResponseMiddleware)
|
||||
AddRequestStartMiddleware(mw RequestMiddleware)
|
||||
SetRequestHeader(header RequestHeaderManager)
|
||||
GetRequestStartMiddleware() []RequestMiddleware
|
||||
}
|
||||
|
||||
type customWriter struct {
|
||||
remoteAddr net.Addr
|
||||
writer io.Writer
|
||||
reader io.Reader
|
||||
headerBuf []byte
|
||||
buf []byte
|
||||
respHeader ResponseHeaderManager
|
||||
reqHeader RequestHeaderManager
|
||||
respMW []ResponseMiddleware
|
||||
reqStartMW []RequestMiddleware
|
||||
reqEndMW []RequestMiddleware
|
||||
}
|
||||
|
||||
func (cw *customWriter) GetRemoteAddr() net.Addr {
|
||||
return cw.remoteAddr
|
||||
}
|
||||
|
||||
func (cw *customWriter) GetWriter() io.Writer {
|
||||
return cw.writer
|
||||
}
|
||||
|
||||
func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) {
|
||||
cw.respMW = append(cw.respMW, mw)
|
||||
}
|
||||
|
||||
func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) {
|
||||
cw.reqStartMW = append(cw.reqStartMW, mw)
|
||||
}
|
||||
|
||||
func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) {
|
||||
cw.reqHeader = header
|
||||
}
|
||||
|
||||
func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware {
|
||||
return cw.reqStartMW
|
||||
}
|
||||
|
||||
func (cw *customWriter) Read(p []byte) (int, error) {
|
||||
tmp := make([]byte, len(p))
|
||||
read, err := cw.reader.Read(tmp)
|
||||
if read == 0 && err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
tmp = tmp[:read]
|
||||
|
||||
idx := bytes.Index(tmp, DELIMITER)
|
||||
if idx == -1 {
|
||||
copy(p, tmp)
|
||||
if err != nil {
|
||||
return read, err
|
||||
}
|
||||
return read, nil
|
||||
}
|
||||
|
||||
header := tmp[:idx+len(DELIMITER)]
|
||||
body := tmp[idx+len(DELIMITER):]
|
||||
|
||||
if !isHTTPHeader(header) {
|
||||
copy(p, tmp)
|
||||
return read, nil
|
||||
}
|
||||
|
||||
for _, m := range cw.reqEndMW {
|
||||
err = m.HandleRequest(cw.reqHeader)
|
||||
if err != nil {
|
||||
log.Printf("Error when applying request middleware: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
reqhf, err := NewRequestHeaderFactory(header)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for _, m := range cw.reqStartMW {
|
||||
if mwErr := m.HandleRequest(reqhf); mwErr != nil {
|
||||
log.Printf("Error when applying request middleware: %v", mwErr)
|
||||
return 0, mwErr
|
||||
}
|
||||
}
|
||||
|
||||
cw.reqHeader = reqhf
|
||||
finalHeader := reqhf.Finalize()
|
||||
|
||||
combined := append(finalHeader, body...)
|
||||
|
||||
n := copy(p, combined)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
|
||||
return &customWriter{
|
||||
remoteAddr: remoteAddr,
|
||||
writer: writer,
|
||||
reader: reader,
|
||||
buf: make([]byte, 0, 4096),
|
||||
}
|
||||
}
|
||||
|
||||
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
|
||||
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
|
||||
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
|
||||
|
||||
func isHTTPHeader(buf []byte) bool {
|
||||
lines := bytes.Split(buf, []byte("\r\n"))
|
||||
|
||||
startLine := string(lines[0])
|
||||
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, line := range lines[1:] {
|
||||
if len(line) == 0 {
|
||||
break
|
||||
}
|
||||
colonIdx := bytes.IndexByte(line, ':')
|
||||
if colonIdx <= 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (cw *customWriter) Write(p []byte) (int, error) {
|
||||
if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" {
|
||||
cw.respHeader = nil
|
||||
}
|
||||
|
||||
if cw.respHeader != nil {
|
||||
n, err := cw.writer.Write(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
cw.buf = append(cw.buf, p...)
|
||||
|
||||
idx := bytes.Index(cw.buf, DELIMITER)
|
||||
if idx == -1 {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
header := cw.buf[:idx+len(DELIMITER)]
|
||||
body := cw.buf[idx+len(DELIMITER):]
|
||||
|
||||
if !isHTTPHeader(header) {
|
||||
_, err := cw.writer.Write(cw.buf)
|
||||
cw.buf = nil
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
resphf := NewResponseHeaderFactory(header)
|
||||
for _, m := range cw.respMW {
|
||||
err := m.HandleResponse(resphf, body)
|
||||
if err != nil {
|
||||
log.Printf("Cannot apply middleware: %s\n", err)
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
header = resphf.Finalize()
|
||||
cw.respHeader = resphf
|
||||
|
||||
_, err := cw.writer.Write(header)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(body) > 0 {
|
||||
_, err = cw.writer.Write(body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
cw.buf = nil
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
var redirectTLS = false
|
||||
|
||||
type HTTPServer interface {
|
||||
ListenAndServe() error
|
||||
ListenAndServeTLS() error
|
||||
handler(conn net.Conn)
|
||||
handlerTLS(conn net.Conn)
|
||||
}
|
||||
type httpServer struct {
|
||||
sessionRegistry session.Registry
|
||||
}
|
||||
|
||||
func NewHTTPServer(sessionRegistry session.Registry) HTTPServer {
|
||||
return &httpServer{sessionRegistry: sessionRegistry}
|
||||
}
|
||||
|
||||
func (hs *httpServer) ListenAndServe() error {
|
||||
httpPort := config.Getenv("HTTP_PORT", "8080")
|
||||
listener, err := net.Listen("tcp", ":"+httpPort)
|
||||
if err != nil {
|
||||
return errors.New("Error listening: " + err.Error())
|
||||
}
|
||||
if config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" {
|
||||
redirectTLS = true
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
var conn net.Conn
|
||||
conn, err = listener.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go hs.handler(conn)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *httpServer) handler(conn net.Conn) {
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error closing connection: %v", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}()
|
||||
|
||||
dstReader := bufio.NewReader(conn)
|
||||
reqhf, err := NewRequestHeaderFactory(dstReader)
|
||||
if err != nil {
|
||||
log.Printf("Error creating request header: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
host := strings.Split(reqhf.Get("Host"), ".")
|
||||
if len(host) < 1 {
|
||||
_, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 400 Bad Request:", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slug := host[0]
|
||||
|
||||
if redirectTLS {
|
||||
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
|
||||
fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")) +
|
||||
"Content-Length: 0\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"\r\n"))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 301 Moved Permanently:", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if slug == "ping" {
|
||||
_, err = conn.Write([]byte(
|
||||
"HTTP/1.1 200 OK\r\n" +
|
||||
"Content-Length: 0\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"Access-Control-Allow-Origin: *\r\n" +
|
||||
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
|
||||
"Access-Control-Allow-Headers: *\r\n" +
|
||||
"\r\n",
|
||||
))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 200 OK:", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
sshSession, exist := hs.sessionRegistry.Get(slug)
|
||||
if !exist {
|
||||
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
|
||||
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
|
||||
"Content-Length: 0\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"\r\n"))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 301 Moved Permanently:", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
||||
forwardRequest(cw, reqhf, sshSession)
|
||||
return
|
||||
}
|
||||
|
||||
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) {
|
||||
payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
|
||||
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
reqs <-chan *ssh.Request
|
||||
err error
|
||||
}
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
}()
|
||||
|
||||
var channel ssh.Channel
|
||||
var reqs <-chan *ssh.Request
|
||||
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result.err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
|
||||
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
|
||||
return
|
||||
}
|
||||
channel = result.channel
|
||||
reqs = result.reqs
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("Timeout opening forwarded-tcpip channel")
|
||||
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
|
||||
return
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
fingerprintMiddleware := NewTunnelFingerprint()
|
||||
forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr())
|
||||
|
||||
cw.AddResponseMiddleware(fingerprintMiddleware)
|
||||
cw.AddRequestStartMiddleware(forwardedForMiddleware)
|
||||
cw.SetRequestHeader(initialRequest)
|
||||
|
||||
for _, m := range cw.GetRequestStartMiddleware() {
|
||||
if err := m.HandleRequest(initialRequest); err != nil {
|
||||
log.Printf("Error handling request: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
_, err := channel.Write(initialRequest.Finalize())
|
||||
if err != nil {
|
||||
log.Printf("Failed to forward request: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
|
||||
return
|
||||
}
|
||||
-108
@@ -1,108 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"tunnel_pls/internal/config"
|
||||
)
|
||||
|
||||
func (hs *httpServer) ListenAndServeTLS() error {
|
||||
domain := config.Getenv("DOMAIN", "localhost")
|
||||
httpsPort := config.Getenv("HTTPS_PORT", "8443")
|
||||
|
||||
tlsConfig, err := NewTLSConfig(domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize TLS config: %w", err)
|
||||
}
|
||||
|
||||
ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
var conn net.Conn
|
||||
conn, err = ln.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
log.Println("https server closed")
|
||||
}
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go hs.handlerTLS(conn)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *httpServer) handlerTLS(conn net.Conn) {
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
if err != nil {
|
||||
log.Printf("Error closing connection: %v", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}()
|
||||
|
||||
dstReader := bufio.NewReader(conn)
|
||||
reqhf, err := NewRequestHeaderFactory(dstReader)
|
||||
if err != nil {
|
||||
log.Printf("Error creating request header: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
host := strings.Split(reqhf.Get("Host"), ".")
|
||||
if len(host) < 1 {
|
||||
_, err = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 400 Bad Request:", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
slug := host[0]
|
||||
|
||||
if slug == "ping" {
|
||||
_, err = conn.Write([]byte(
|
||||
"HTTP/1.1 200 OK\r\n" +
|
||||
"Content-Length: 0\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"Access-Control-Allow-Origin: *\r\n" +
|
||||
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
|
||||
"Access-Control-Allow-Headers: *\r\n" +
|
||||
"\r\n",
|
||||
))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 200 OK:", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
sshSession, exist := hs.sessionRegistry.Get(slug)
|
||||
if !exist {
|
||||
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
|
||||
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
|
||||
"Content-Length: 0\r\n" +
|
||||
"Connection: close\r\n" +
|
||||
"\r\n"))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 301 Moved Permanently:", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
|
||||
forwardRequest(cw, reqhf, sshSession)
|
||||
return
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type RequestMiddleware interface {
|
||||
HandleRequest(header RequestHeaderManager) error
|
||||
}
|
||||
|
||||
type ResponseMiddleware interface {
|
||||
HandleResponse(header ResponseHeaderManager, body []byte) error
|
||||
}
|
||||
|
||||
type TunnelFingerprint struct{}
|
||||
|
||||
func NewTunnelFingerprint() *TunnelFingerprint {
|
||||
return &TunnelFingerprint{}
|
||||
}
|
||||
|
||||
func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error {
|
||||
header.Set("Server", "Tunnel Please")
|
||||
return nil
|
||||
}
|
||||
|
||||
type ForwardedFor struct {
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
func NewForwardedFor(addr net.Addr) *ForwardedFor {
|
||||
return &ForwardedFor{addr: addr}
|
||||
}
|
||||
|
||||
func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error {
|
||||
host, _, err := net.SplitHostPort(ff.addr.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
header.Set("X-Forwarded-For", host)
|
||||
return nil
|
||||
}
|
||||
+70
-35
@@ -1,55 +1,65 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/grpc/client"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
conn *net.Listener
|
||||
config *ssh.ServerConfig
|
||||
sessionRegistry session.Registry
|
||||
type Server interface {
|
||||
Start()
|
||||
Close() error
|
||||
}
|
||||
type server struct {
|
||||
randomizer random.Random
|
||||
config config.Config
|
||||
sshPort string
|
||||
sshListener net.Listener
|
||||
sshConfig *ssh.ServerConfig
|
||||
grpcClient client.Client
|
||||
sessionRegistry registry.Registry
|
||||
portRegistry port.Port
|
||||
}
|
||||
|
||||
func NewServer(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry) (*Server, error) {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200")))
|
||||
func New(randomizer random.Random, config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
|
||||
if err != nil {
|
||||
log.Fatalf("failed to listen on port 2200: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
HttpServer := NewHTTPServer(sessionRegistry)
|
||||
err = HttpServer.ListenAndServe()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to start http server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.Getenv("TLS_ENABLED", "false") == "true" {
|
||||
err = HttpServer.ListenAndServeTLS()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to start https server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Server{
|
||||
conn: &listener,
|
||||
config: sshConfig,
|
||||
return &server{
|
||||
randomizer: randomizer,
|
||||
config: config,
|
||||
sshPort: sshPort,
|
||||
sshListener: listener,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: grpcClient,
|
||||
sessionRegistry: sessionRegistry,
|
||||
portRegistry: portRegistry,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Start() {
|
||||
log.Println("SSH server is starting on port 2200...")
|
||||
func (s *server) Start() {
|
||||
log.Printf("SSH server is starting on port %s", s.sshPort)
|
||||
for {
|
||||
conn, err := (*s.conn).Accept()
|
||||
conn, err := s.sshListener.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
log.Println("listener closed, stopping server")
|
||||
return
|
||||
}
|
||||
log.Printf("failed to accept connection: %v", err)
|
||||
continue
|
||||
}
|
||||
@@ -58,11 +68,15 @@ func (s *Server) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
|
||||
func (s *server) Close() error {
|
||||
return s.sshListener.Close()
|
||||
}
|
||||
|
||||
func (s *server) handleConnection(conn net.Conn) {
|
||||
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig)
|
||||
if err != nil {
|
||||
log.Printf("failed to establish SSH connection: %v", err)
|
||||
err := conn.Close()
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close SSH connection: %v", err)
|
||||
return
|
||||
@@ -70,13 +84,34 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("SSH connection established:", sshConn.User())
|
||||
defer func(sshConn *ssh.ServerConn) {
|
||||
err = sshConn.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
|
||||
log.Printf("failed to close SSH server: %v", err)
|
||||
}
|
||||
}(sshConn)
|
||||
|
||||
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry)
|
||||
user := "UNAUTHORIZED"
|
||||
if s.grpcClient != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
_, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User())
|
||||
user = u
|
||||
cancel()
|
||||
}
|
||||
log.Println("SSH connection established:", sshConn.User())
|
||||
sshSession := session.New(&session.Config{
|
||||
Randomizer: s.randomizer,
|
||||
Config: s.config,
|
||||
Conn: sshConn,
|
||||
InitialReq: forwardingReqs,
|
||||
SshChan: chans,
|
||||
SessionRegistry: s.sessionRegistry,
|
||||
PortRegistry: s.portRegistry,
|
||||
User: user,
|
||||
})
|
||||
err = sshSession.Start()
|
||||
if err != nil {
|
||||
log.Printf("SSH session ended with error: %v", err)
|
||||
log.Printf("SSH session ended with error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -0,0 +1,880 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type MockRandom struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRandom) String(length int) (string, error) {
|
||||
args := m.Called(length)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
type MockConfig struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockConfig) Domain() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) Mode() types.ServerMode {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return 0
|
||||
}
|
||||
switch v := args.Get(0).(type) {
|
||||
case types.ServerMode:
|
||||
return v
|
||||
case int:
|
||||
return types.ServerMode(v)
|
||||
default:
|
||||
return types.ServerMode(args.Int(0))
|
||||
}
|
||||
}
|
||||
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
||||
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||
args := m.Called(user, oldKey, newKey)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||
args := m.Called(key, session)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
args := m.Called(user)
|
||||
return args.Get(0).([]registry.Session)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
type MockGRPCClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*grpc.ClientConn)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
||||
args := m.Called(ctx, token)
|
||||
return args.Bool(0), args.String(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
|
||||
args := m.Called(ctx)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
|
||||
args := m.Called(ctx, domain, token)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockGRPCClient) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockPort struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
||||
return m.Called(startPort, endPort).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
||||
args := m.Called()
|
||||
return uint16(args.Int(0)), args.Bool(1)
|
||||
}
|
||||
|
||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
||||
return m.Called(port, assigned).Error(0)
|
||||
}
|
||||
|
||||
func (m *MockPort) Claim(port uint16) bool {
|
||||
return m.Called(port).Bool(0)
|
||||
}
|
||||
|
||||
type MockListener struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockListener) Accept() (net.Conn, error) {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(net.Conn), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockListener) Close() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
func (m *MockListener) Addr() net.Addr {
|
||||
return m.Called().Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
|
||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
signer, _ := ssh.NewSignerFromKey(key)
|
||||
config := &ssh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
}
|
||||
config.AddHostKey(signer)
|
||||
return config, signer
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
port: "0",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid port",
|
||||
port: "invalid",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, err := New(mr, mc, sc, mreg, mg, mp, tt.port)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
_ = s.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("port already in use", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := l.Addr().(*net.TCPAddr).Port
|
||||
defer func(l net.Listener) {
|
||||
err = l.Close()
|
||||
assert.NoError(t, err)
|
||||
}(l)
|
||||
|
||||
s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
t.Run("successful close", func(t *testing.T) {
|
||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
err := s.Close()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("close already closed listener", func(t *testing.T) {
|
||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
_ = s.Close()
|
||||
err := s.Close()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("close with nil listener", func(t *testing.T) {
|
||||
s := &server{
|
||||
sshListener: nil,
|
||||
}
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
assert.NotNil(t, r)
|
||||
}
|
||||
}()
|
||||
_ = s.Close()
|
||||
t.Fatal("expected panic for nil listener")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
t.Run("normal stop", func(t *testing.T) {
|
||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_ = s.Close()
|
||||
}()
|
||||
s.Start()
|
||||
})
|
||||
|
||||
t.Run("accept error - temporary error continues loop", func(t *testing.T) {
|
||||
ml := new(MockListener)
|
||||
s := &server{
|
||||
sshListener: ml,
|
||||
sshPort: "0",
|
||||
}
|
||||
|
||||
ml.On("Accept").Return(nil, errors.New("temporary error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s.Start()
|
||||
ml.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("accept error - immediate close", func(t *testing.T) {
|
||||
ml := new(MockListener)
|
||||
s := &server{
|
||||
sshListener: ml,
|
||||
sshPort: "0",
|
||||
}
|
||||
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s.Start()
|
||||
ml.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("accept success - connection fails SSH handshake", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
mockListener := &MockListener{}
|
||||
mockListener.On("Accept").Return(serverConn, nil).Once()
|
||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshListener: mockListener,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
})
|
||||
|
||||
t.Run("accept success - valid SSH connection without auth", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
mockListener := &MockListener{}
|
||||
mockListener.On("Accept").Return(serverConn, nil).Once()
|
||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshListener: mockListener,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleConnection(t *testing.T) {
|
||||
t.Run("SSH handshake fails - connection closed", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
s.handleConnection(serverConn)
|
||||
})
|
||||
|
||||
// SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS
|
||||
// GONNA IMPLEMENT THIS UNIT TEST LATER
|
||||
|
||||
//t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) {
|
||||
// mockRandom := &MockRandom{}
|
||||
// mockConfig := &MockConfig{}
|
||||
// mockSessionRegistry := &MockSessionRegistry{}
|
||||
// mockGrpcClient := &MockGRPCClient{}
|
||||
// mockPort := &MockPort{}
|
||||
//
|
||||
// sshConfig, _ := getTestSSHConfig()
|
||||
//
|
||||
// serverConn, clientConn := net.Pipe()
|
||||
//
|
||||
// s := &server{
|
||||
// randomizer: mockRandom,
|
||||
// config: mockConfig,
|
||||
// sshPort: "0",
|
||||
// sshConfig: sshConfig,
|
||||
// grpcClient: mockGrpcClient,
|
||||
// sessionRegistry: mockSessionRegistry,
|
||||
// portRegistry: mockPort,
|
||||
// }
|
||||
//
|
||||
// done := make(chan bool, 1)
|
||||
//
|
||||
// go func() {
|
||||
// s.handleConnection(serverConn)
|
||||
// done <- true
|
||||
// }()
|
||||
//
|
||||
// go func() {
|
||||
// clientConn.Write([]byte("invalid ssh protocol\n"))
|
||||
// clientConn.Close()
|
||||
// }()
|
||||
//
|
||||
// select {
|
||||
// case <-done:
|
||||
// case <-time.After(1 * time.Second):
|
||||
// t.Fatal("handleConnection did not complete in time")
|
||||
// }
|
||||
//})
|
||||
|
||||
t.Run("SSH connection established without gRPC client", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
mockConfig.On("Domain").Return("test.com")
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("SSHPort").Return("2200")
|
||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func(listener net.Listener) {
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}(listener)
|
||||
|
||||
serverAddr := listener.Addr().String()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleConnection(conn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
||||
if err != nil {
|
||||
t.Logf("Client dial failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(client *ssh.Client) {
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}(client)
|
||||
|
||||
type forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
payload := ssh.Marshal(forwardPayload{
|
||||
BindAddr: "localhost",
|
||||
BindPort: 80,
|
||||
})
|
||||
|
||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
||||
if err != nil {
|
||||
t.Logf("Forward request failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Log("handleConnection completed")
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SSH connection established with gRPC authorization", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
|
||||
mockConfig.On("Domain").Return("test.com")
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("SSHPort").Return("2200")
|
||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func(listener net.Listener) {
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}(listener)
|
||||
|
||||
serverAddr := listener.Addr().String()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleConnection(conn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
||||
if err != nil {
|
||||
t.Logf("Client dial failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(client *ssh.Client) {
|
||||
err = client.Close()
|
||||
assert.NoError(t, err)
|
||||
}(client)
|
||||
|
||||
type forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
payload := ssh.Marshal(forwardPayload{
|
||||
BindAddr: "localhost",
|
||||
BindPort: 80,
|
||||
})
|
||||
|
||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
||||
if err != nil {
|
||||
t.Logf("Forward request failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
mockGrpcClient.AssertExpectations(t)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("SSH connection with gRPC authorization error", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockGrpcClient := &MockGRPCClient{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
|
||||
mockConfig.On("Domain").Return("test.com")
|
||||
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
||||
mockConfig.On("SSHPort").Return("2200")
|
||||
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
||||
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func(listener net.Listener) {
|
||||
err = listener.Close()
|
||||
assert.NoError(t, err)
|
||||
}(listener)
|
||||
|
||||
serverAddr := listener.Addr().String()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: mockGrpcClient,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.handleConnection(conn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
clientConfig := &ssh.ClientConfig{
|
||||
User: "testuser",
|
||||
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
go func() {
|
||||
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
||||
if err != nil {
|
||||
t.Logf("Client dial failed: %v", err)
|
||||
return
|
||||
}
|
||||
defer func(client *ssh.Client) {
|
||||
_ = client.Close()
|
||||
}(client)
|
||||
|
||||
type forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
payload := ssh.Marshal(forwardPayload{
|
||||
BindAddr: "localhost",
|
||||
BindPort: 8080,
|
||||
})
|
||||
|
||||
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
||||
if err != nil {
|
||||
t.Logf("Forward request failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
mockGrpcClient.AssertExpectations(t)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("connection cleanup on close", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
serverConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: serverConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
done := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
s.handleConnection(serverConn)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("handleConnection did not complete in time")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
t.Run("full server lifecycle", func(t *testing.T) {
|
||||
mr := new(MockRandom)
|
||||
mc := new(MockConfig)
|
||||
mreg := new(MockSessionRegistry)
|
||||
mg := new(MockGRPCClient)
|
||||
mp := new(MockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
s, err := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err := s.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
s.Start()
|
||||
})
|
||||
|
||||
t.Run("multiple connections", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
conn1Server, conn1Client := net.Pipe()
|
||||
conn2Server, conn2Client := net.Pipe()
|
||||
|
||||
mockListener := &MockListener{}
|
||||
mockListener.On("Accept").Return(conn1Server, nil).Once()
|
||||
mockListener.On("Accept").Return(conn2Server, nil).Once()
|
||||
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshListener: mockListener,
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
go s.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = conn1Client.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
_ = conn2Client.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mockListener.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorHandling(t *testing.T) {
|
||||
t.Run("write error during SSH handshake", func(t *testing.T) {
|
||||
mockRandom := &MockRandom{}
|
||||
mockConfig := &MockConfig{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockPort := &MockPort{}
|
||||
|
||||
sshConfig, _ := getTestSSHConfig()
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
err := clientConn.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
s := &server{
|
||||
randomizer: mockRandom,
|
||||
config: mockConfig,
|
||||
sshPort: "0",
|
||||
sshConfig: sshConfig,
|
||||
grpcClient: nil,
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
portRegistry: mockPort,
|
||||
}
|
||||
|
||||
s.handleConnection(serverConn)
|
||||
})
|
||||
}
|
||||
-336
@@ -1,336 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
"github.com/libdns/cloudflare"
|
||||
)
|
||||
|
||||
type TLSManager interface {
|
||||
userCertsExistAndValid() bool
|
||||
loadUserCerts() error
|
||||
startCertWatcher()
|
||||
initCertMagic() error
|
||||
getTLSConfig() *tls.Config
|
||||
getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
|
||||
}
|
||||
|
||||
type tlsManager struct {
|
||||
domain string
|
||||
certPath string
|
||||
keyPath string
|
||||
storagePath string
|
||||
|
||||
userCert *tls.Certificate
|
||||
userCertMu sync.RWMutex
|
||||
|
||||
magic *certmagic.Config
|
||||
|
||||
useCertMagic bool
|
||||
}
|
||||
|
||||
var globalTLSManager TLSManager
|
||||
var tlsManagerOnce sync.Once
|
||||
|
||||
func NewTLSConfig(domain string) (*tls.Config, error) {
|
||||
var initErr error
|
||||
|
||||
tlsManagerOnce.Do(func() {
|
||||
certPath := "certs/tls/cert.pem"
|
||||
keyPath := "certs/tls/privkey.pem"
|
||||
storagePath := "certs/tls/certmagic"
|
||||
|
||||
tm := &tlsManager{
|
||||
domain: domain,
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
storagePath: storagePath,
|
||||
}
|
||||
|
||||
if tm.userCertsExistAndValid() {
|
||||
log.Printf("Using user-provided certificates from %s and %s", certPath, keyPath)
|
||||
if err := tm.loadUserCerts(); err != nil {
|
||||
initErr = fmt.Errorf("failed to load user certificates: %w", err)
|
||||
return
|
||||
}
|
||||
tm.useCertMagic = false
|
||||
tm.startCertWatcher()
|
||||
} else {
|
||||
if !isACMEConfigComplete() {
|
||||
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
|
||||
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
|
||||
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
|
||||
if err := tm.initCertMagic(); err != nil {
|
||||
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
|
||||
return
|
||||
}
|
||||
tm.useCertMagic = true
|
||||
}
|
||||
|
||||
globalTLSManager = tm
|
||||
})
|
||||
|
||||
if initErr != nil {
|
||||
return nil, initErr
|
||||
}
|
||||
|
||||
return globalTLSManager.getTLSConfig(), nil
|
||||
}
|
||||
|
||||
func isACMEConfigComplete() bool {
|
||||
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
|
||||
return cfAPIToken != ""
|
||||
}
|
||||
|
||||
func (tm *tlsManager) userCertsExistAndValid() bool {
|
||||
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
|
||||
log.Printf("Certificate file not found: %s", tm.certPath)
|
||||
return false
|
||||
}
|
||||
if _, err := os.Stat(tm.keyPath); os.IsNotExist(err) {
|
||||
log.Printf("Key file not found: %s", tm.keyPath)
|
||||
return false
|
||||
}
|
||||
|
||||
return ValidateCertDomains(tm.certPath, tm.domain)
|
||||
}
|
||||
|
||||
func ValidateCertDomains(certPath, domain string) bool {
|
||||
certPEM, err := os.ReadFile(certPath)
|
||||
if err != nil {
|
||||
log.Printf("Failed to read certificate: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(certPEM)
|
||||
if block == nil {
|
||||
log.Printf("Failed to decode PEM block from certificate")
|
||||
return false
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse certificate: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().After(cert.NotAfter) {
|
||||
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) {
|
||||
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
|
||||
return false
|
||||
}
|
||||
|
||||
var certDomains []string
|
||||
if cert.Subject.CommonName != "" {
|
||||
certDomains = append(certDomains, cert.Subject.CommonName)
|
||||
}
|
||||
certDomains = append(certDomains, cert.DNSNames...)
|
||||
|
||||
hasBase := false
|
||||
hasWildcard := false
|
||||
wildcardDomain := "*." + domain
|
||||
|
||||
for _, d := range certDomains {
|
||||
if d == domain {
|
||||
hasBase = true
|
||||
}
|
||||
if d == wildcardDomain {
|
||||
hasWildcard = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasBase {
|
||||
log.Printf("Certificate does not cover base domain: %s", domain)
|
||||
}
|
||||
if !hasWildcard {
|
||||
log.Printf("Certificate does not cover wildcard domain: %s", wildcardDomain)
|
||||
}
|
||||
|
||||
return hasBase && hasWildcard
|
||||
}
|
||||
|
||||
func (tm *tlsManager) loadUserCerts() error {
|
||||
cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tm.userCertMu.Lock()
|
||||
tm.userCert = &cert
|
||||
tm.userCertMu.Unlock()
|
||||
|
||||
log.Printf("Loaded user certificates successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) startCertWatcher() {
|
||||
go func() {
|
||||
var lastCertMod, lastKeyMod time.Time
|
||||
|
||||
if info, err := os.Stat(tm.certPath); err == nil {
|
||||
lastCertMod = info.ModTime()
|
||||
}
|
||||
if info, err := os.Stat(tm.keyPath); err == nil {
|
||||
lastKeyMod = info.ModTime()
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
certInfo, certErr := os.Stat(tm.certPath)
|
||||
keyInfo, keyErr := os.Stat(tm.keyPath)
|
||||
|
||||
if certErr != nil || keyErr != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
|
||||
log.Printf("Certificate files changed, reloading...")
|
||||
|
||||
if !ValidateCertDomains(tm.certPath, tm.domain) {
|
||||
log.Printf("New certificates don't cover required domains")
|
||||
|
||||
if !isACMEConfigComplete() {
|
||||
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("Switching to CertMagic for automatic certificate management")
|
||||
if err := tm.initCertMagic(); err != nil {
|
||||
log.Printf("Failed to initialize CertMagic: %v", err)
|
||||
continue
|
||||
}
|
||||
tm.useCertMagic = true
|
||||
return
|
||||
}
|
||||
|
||||
if err := tm.loadUserCerts(); err != nil {
|
||||
log.Printf("Failed to reload certificates: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
lastCertMod = certInfo.ModTime()
|
||||
lastKeyMod = keyInfo.ModTime()
|
||||
log.Printf("Certificates reloaded successfully")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (tm *tlsManager) initCertMagic() error {
|
||||
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create cert storage directory: %w", err)
|
||||
}
|
||||
|
||||
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain)
|
||||
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
|
||||
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
|
||||
|
||||
if cfAPIToken == "" {
|
||||
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
|
||||
}
|
||||
|
||||
cfProvider := &cloudflare.Provider{
|
||||
APIToken: cfAPIToken,
|
||||
}
|
||||
|
||||
storage := &certmagic.FileStorage{Path: tm.storagePath}
|
||||
|
||||
cache := certmagic.NewCache(certmagic.CacheOptions{
|
||||
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
|
||||
return tm.magic, nil
|
||||
},
|
||||
})
|
||||
|
||||
magic := certmagic.New(cache, certmagic.Config{
|
||||
Storage: storage,
|
||||
})
|
||||
|
||||
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
|
||||
Email: acmeEmail,
|
||||
Agreed: true,
|
||||
DNS01Solver: &certmagic.DNS01Solver{
|
||||
DNSManager: certmagic.DNSManager{
|
||||
DNSProvider: cfProvider,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if acmeStaging {
|
||||
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
|
||||
log.Printf("Using Let's Encrypt staging server")
|
||||
} else {
|
||||
acmeIssuer.CA = certmagic.LetsEncryptProductionCA
|
||||
log.Printf("Using Let's Encrypt production server")
|
||||
}
|
||||
|
||||
magic.Issuers = []certmagic.Issuer{acmeIssuer}
|
||||
tm.magic = magic
|
||||
|
||||
domains := []string{tm.domain, "*." + tm.domain}
|
||||
log.Printf("Requesting certificates for: %v", domains)
|
||||
|
||||
ctx := context.Background()
|
||||
if err := magic.ManageSync(ctx, domains); err != nil {
|
||||
return fmt.Errorf("failed to obtain certificates: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Certificates obtained successfully for %v", domains)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *tlsManager) getTLSConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
GetCertificate: tm.getCertificate,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
|
||||
SessionTicketsDisabled: false,
|
||||
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
},
|
||||
|
||||
CurvePreferences: []tls.CurveID{
|
||||
tls.X25519,
|
||||
},
|
||||
|
||||
ClientAuth: tls.NoClientCert,
|
||||
NextProtos: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if tm.useCertMagic {
|
||||
return tm.magic.GetCertificate(hello)
|
||||
}
|
||||
|
||||
tm.userCertMu.RLock()
|
||||
defer tm.userCertMu.RUnlock()
|
||||
|
||||
if tm.userCert == nil {
|
||||
return nil, fmt.Errorf("no certificate available")
|
||||
}
|
||||
|
||||
return tm.userCert, nil
|
||||
}
|
||||
+112
-176
@@ -1,15 +1,14 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
@@ -17,239 +16,176 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var bufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
bufSize := config.GetBufferSize()
|
||||
return make([]byte, bufSize)
|
||||
},
|
||||
type Forwarder interface {
|
||||
SetType(tunnelType types.TunnelType)
|
||||
SetForwardedPort(port uint16)
|
||||
SetListener(listener net.Listener)
|
||||
Listener() net.Listener
|
||||
TunnelType() types.TunnelType
|
||||
ForwardedPort() uint16
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
||||
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||
buf := bufferPool.Get().([]byte)
|
||||
defer bufferPool.Put(buf)
|
||||
return io.CopyBuffer(dst, src, buf)
|
||||
}
|
||||
|
||||
type Forwarder struct {
|
||||
type forwarder struct {
|
||||
listener net.Listener
|
||||
tunnelType types.TunnelType
|
||||
forwardedPort uint16
|
||||
slugManager slug.Manager
|
||||
lifecycle Lifecycle
|
||||
slug slug.Slug
|
||||
conn ssh.Conn
|
||||
bufferPool sync.Pool
|
||||
}
|
||||
|
||||
func NewForwarder(slugManager slug.Manager) *Forwarder {
|
||||
return &Forwarder{
|
||||
func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
|
||||
return &forwarder{
|
||||
listener: nil,
|
||||
tunnelType: "",
|
||||
tunnelType: types.TunnelTypeUNKNOWN,
|
||||
forwardedPort: 0,
|
||||
slugManager: slugManager,
|
||||
lifecycle: nil,
|
||||
slug: slug,
|
||||
conn: conn,
|
||||
bufferPool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
bufSize := config.BufferSize()
|
||||
buf := make([]byte, bufSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type Lifecycle interface {
|
||||
GetConnection() ssh.Conn
|
||||
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||
buf := f.bufferPool.Get().(*[]byte)
|
||||
defer f.bufferPool.Put(buf)
|
||||
return io.CopyBuffer(dst, src, *buf)
|
||||
}
|
||||
|
||||
type ForwardingController interface {
|
||||
AcceptTCPConnections()
|
||||
SetType(tunnelType types.TunnelType)
|
||||
GetTunnelType() types.TunnelType
|
||||
GetForwardedPort() uint16
|
||||
SetForwardedPort(port uint16)
|
||||
SetListener(listener net.Listener)
|
||||
GetListener() net.Listener
|
||||
Close() error
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||
WriteBadGatewayResponse(dst io.Writer)
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
|
||||
f.lifecycle = lifecycle
|
||||
}
|
||||
|
||||
func (f *Forwarder) AcceptTCPConnections() {
|
||||
for {
|
||||
conn, err := f.GetListener().Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
log.Printf("Error accepting connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
log.Printf("Failed to set connection deadline: %v", err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("Failed to close connection: %v", closeErr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
||||
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
reqs <-chan *ssh.Request
|
||||
err error
|
||||
}
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
}()
|
||||
func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
payload := createForwardedTCPIPPayload(origin, f.forwardedPort)
|
||||
type channelResult struct {
|
||||
channel ssh.Channel
|
||||
reqs <-chan *ssh.Request
|
||||
err error
|
||||
}
|
||||
resultChan := make(chan channelResult, 1)
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
if result.err != nil {
|
||||
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("Failed to close connection: %v", closeErr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||||
log.Printf("Failed to clear connection deadline: %v", err)
|
||||
}
|
||||
|
||||
go ssh.DiscardRequests(result.reqs)
|
||||
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
|
||||
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Printf("Timeout opening forwarded-tcpip channel")
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Printf("Failed to close connection: %v", closeErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
||||
defer func() {
|
||||
_, err := io.Copy(io.Discard, src)
|
||||
if err != nil {
|
||||
log.Printf("Failed to discard connection: %v", err)
|
||||
}
|
||||
|
||||
err = src.Close()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error closing source channel: %v", err)
|
||||
}
|
||||
|
||||
if closer, ok := dst.(io.Closer); ok {
|
||||
err = closer.Close()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Printf("Error closing destination connection: %v", err)
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
case <-ctx.Done():
|
||||
if channel != nil {
|
||||
_ = channel.Close()
|
||||
go ssh.DiscardRequests(reqs)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
log.Printf("Handling new forwarded connection from %s", remoteAddr)
|
||||
select {
|
||||
case result := <-resultChan:
|
||||
return result.channel, result.reqs, result.err
|
||||
case <-ctx.Done():
|
||||
return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func closeWriter(w io.Writer) error {
|
||||
if cw, ok := w.(interface{ CloseWrite() error }); ok {
|
||||
return cw.CloseWrite()
|
||||
}
|
||||
if closer, ok := w.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
|
||||
var errs []error
|
||||
_, err := f.copyWithBuffer(dst, src)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
|
||||
}
|
||||
|
||||
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
|
||||
errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err))
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||
defer func() {
|
||||
_, _ = io.Copy(io.Discard, src)
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := copyWithBuffer(dst, src)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error copying src→dst: %v", err)
|
||||
err := f.copyAndClose(dst, src, "src to dst")
|
||||
if err != nil {
|
||||
log.Println("Error during copy: ", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := copyWithBuffer(src, dst)
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("Error copying dst→src: %v", err)
|
||||
err := f.copyAndClose(src, dst, "dst to src")
|
||||
if err != nil {
|
||||
log.Println("Error during copy: ", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetType(tunnelType types.TunnelType) {
|
||||
func (f *forwarder) SetType(tunnelType types.TunnelType) {
|
||||
f.tunnelType = tunnelType
|
||||
}
|
||||
|
||||
func (f *Forwarder) GetTunnelType() types.TunnelType {
|
||||
func (f *forwarder) TunnelType() types.TunnelType {
|
||||
return f.tunnelType
|
||||
}
|
||||
|
||||
func (f *Forwarder) GetForwardedPort() uint16 {
|
||||
func (f *forwarder) ForwardedPort() uint16 {
|
||||
return f.forwardedPort
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetForwardedPort(port uint16) {
|
||||
func (f *forwarder) SetForwardedPort(port uint16) {
|
||||
f.forwardedPort = port
|
||||
}
|
||||
|
||||
func (f *Forwarder) SetListener(listener net.Listener) {
|
||||
func (f *forwarder) SetListener(listener net.Listener) {
|
||||
f.listener = listener
|
||||
}
|
||||
|
||||
func (f *Forwarder) GetListener() net.Listener {
|
||||
func (f *forwarder) Listener() net.Listener {
|
||||
return f.listener
|
||||
}
|
||||
|
||||
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||
_, err := dst.Write(types.BadGatewayResponse)
|
||||
if err != nil {
|
||||
log.Printf("failed to write Bad Gateway response: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Forwarder) Close() error {
|
||||
if f.GetListener() != nil {
|
||||
func (f *forwarder) Close() error {
|
||||
if f.Listener() != nil {
|
||||
return f.listener.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
host, originPort := parseAddr(origin.String())
|
||||
|
||||
writeSSHString(&buf, "localhost")
|
||||
err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort()))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
writeSSHString(&buf, host)
|
||||
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func parseAddr(addr string) (string, uint16) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
|
||||
return "0.0.0.0", uint16(0)
|
||||
}
|
||||
func createForwardedTCPIPPayload(origin net.Addr, destPort uint16) []byte {
|
||||
host, portStr, _ := net.SplitHostPort(origin.String())
|
||||
port, _ := strconv.Atoi(portStr)
|
||||
return host, uint16(port)
|
||||
}
|
||||
|
||||
func writeSSHString(buffer *bytes.Buffer, str string) {
|
||||
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
|
||||
if err != nil {
|
||||
log.Printf("Failed to write string to buffer: %v", err)
|
||||
return
|
||||
forwardPayload := struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}{
|
||||
DestAddr: "localhost",
|
||||
DestPort: uint32(destPort),
|
||||
OriginAddr: host,
|
||||
OriginPort: uint32(port),
|
||||
}
|
||||
buffer.WriteString(str)
|
||||
|
||||
return ssh.Marshal(forwardPayload)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,291 +0,0 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
|
||||
for req := range GlobalRequest {
|
||||
switch req.Type {
|
||||
case "shell", "pty-req":
|
||||
err := req.Reply(true, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
case "window-change":
|
||||
p := req.Payload
|
||||
if len(p) < 16 {
|
||||
log.Println("invalid window-change payload")
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
cols := binary.BigEndian.Uint32(p[0:4])
|
||||
rows := binary.BigEndian.Uint32(p[4:8])
|
||||
|
||||
s.interaction.SetWH(int(cols), int(rows))
|
||||
|
||||
err := req.Reply(true, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
log.Println("Unknown request type:", req.Type)
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
|
||||
log.Println("Port forwarding request detected")
|
||||
|
||||
reader := bytes.NewReader(req.Payload)
|
||||
|
||||
addr, err := readSSHString(reader)
|
||||
if err != nil {
|
||||
log.Println("Failed to read address from payload:", err)
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
err = s.lifecycle.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var rawPortToBind uint32
|
||||
if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
|
||||
log.Println("Failed to read port from payload:", err)
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
err = s.lifecycle.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if rawPortToBind > 65535 {
|
||||
log.Printf("Port %d is larger than allowed port of 65535", rawPortToBind)
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
err = s.lifecycle.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
portToBind := uint16(rawPortToBind)
|
||||
if isBlockedPort(portToBind) {
|
||||
log.Printf("Port %d is blocked or restricted", portToBind)
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
err = s.lifecycle.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if portToBind == 80 || portToBind == 443 {
|
||||
s.HandleHTTPForward(req, portToBind)
|
||||
return
|
||||
}
|
||||
if portToBind == 0 {
|
||||
unassign, success := portUtil.Default.GetUnassignedPort()
|
||||
portToBind = unassign
|
||||
if !success {
|
||||
log.Println("No available port")
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
err = s.lifecycle.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
|
||||
log.Printf("Port %d is already in use or restricted", portToBind)
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
err = s.lifecycle.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
err = portUtil.Default.SetPortStatus(portToBind, true)
|
||||
if err != nil {
|
||||
log.Println("Failed to set port status:", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.HandleTCPForward(req, addr, portToBind)
|
||||
}
|
||||
|
||||
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
|
||||
slug := random.GenerateRandomString(20)
|
||||
|
||||
if !s.registry.Register(slug, s) {
|
||||
log.Printf("Failed to register client with slug: %s", slug)
|
||||
err := req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
err := binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
||||
if err != nil {
|
||||
log.Println("Failed to write port to buffer:", err)
|
||||
s.registry.Remove(slug)
|
||||
err = req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
log.Printf("HTTP forwarding approved on port: %d", portToBind)
|
||||
|
||||
err = req.Reply(true, buf.Bytes())
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
s.registry.Remove(slug)
|
||||
err = req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.forwarder.SetType(types.HTTP)
|
||||
s.forwarder.SetForwardedPort(portToBind)
|
||||
s.slugManager.Set(slug)
|
||||
s.lifecycle.SetStatus(types.RUNNING)
|
||||
s.interaction.Start()
|
||||
}
|
||||
|
||||
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
|
||||
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
|
||||
if err != nil {
|
||||
log.Printf("Port %d is already in use or restricted", portToBind)
|
||||
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
|
||||
log.Printf("Failed to reset port status: %v", setErr)
|
||||
}
|
||||
err = req.Reply(false, nil)
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
return
|
||||
}
|
||||
err = s.lifecycle.Close()
|
||||
if err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
|
||||
if err != nil {
|
||||
log.Println("Failed to write port to buffer:", err)
|
||||
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
|
||||
log.Printf("Failed to reset port status: %v", setErr)
|
||||
}
|
||||
err = listener.Close()
|
||||
if err != nil {
|
||||
log.Printf("Failed to close listener: %s", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("TCP forwarding approved on port: %d", portToBind)
|
||||
err = req.Reply(true, buf.Bytes())
|
||||
if err != nil {
|
||||
log.Println("Failed to reply to request:", err)
|
||||
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
|
||||
log.Printf("Failed to reset port status: %v", setErr)
|
||||
}
|
||||
err = listener.Close()
|
||||
if err != nil {
|
||||
log.Printf("Failed to close listener: %s", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.forwarder.SetType(types.TCP)
|
||||
s.forwarder.SetListener(listener)
|
||||
s.forwarder.SetForwardedPort(portToBind)
|
||||
s.lifecycle.SetStatus(types.RUNNING)
|
||||
go s.forwarder.AcceptTCPConnections()
|
||||
s.interaction.Start()
|
||||
}
|
||||
|
||||
func readSSHString(reader *bytes.Reader) (string, error) {
|
||||
var length uint32
|
||||
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
|
||||
return "", err
|
||||
}
|
||||
strBytes := make([]byte, length)
|
||||
if _, err := reader.Read(strBytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(strBytes), nil
|
||||
}
|
||||
|
||||
func isBlockedPort(port uint16) bool {
|
||||
if port == 80 || port == 443 {
|
||||
return false
|
||||
}
|
||||
if port < 1024 && port != 0 {
|
||||
return true
|
||||
}
|
||||
for _, p := range blockedReservedPorts {
|
||||
if p == port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package interaction
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
func (m *model) comingSoonUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
m.showingComingSoon = false
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
}
|
||||
|
||||
func (m *model) comingSoonView() string {
|
||||
isCompact := shouldUseCompactLayout(m.width, 60)
|
||||
|
||||
var boxPadding int
|
||||
var boxMargin int
|
||||
if isCompact {
|
||||
boxPadding = 1
|
||||
boxMargin = 1
|
||||
} else {
|
||||
boxPadding = 3
|
||||
boxMargin = 2
|
||||
}
|
||||
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
PaddingTop(1).
|
||||
PaddingBottom(1)
|
||||
|
||||
messageBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||
messageBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
Background(lipgloss.Color("#1A1A2E")).
|
||||
Bold(true).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin).
|
||||
Width(messageBoxWidth).
|
||||
Align(lipgloss.Center)
|
||||
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#666666")).
|
||||
Italic(true).
|
||||
MarginTop(1)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("\n\n")
|
||||
|
||||
var title string
|
||||
if shouldUseCompactLayout(m.width, 40) {
|
||||
title = "Coming Soon"
|
||||
} else {
|
||||
title = "⏳ Coming Soon"
|
||||
}
|
||||
b.WriteString(titleStyle.Render(title))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
var message string
|
||||
if shouldUseCompactLayout(m.width, 50) {
|
||||
message = "Coming soon!\nStay tuned."
|
||||
} else {
|
||||
message = "🚀 This feature is coming very soon!\n Stay tuned for updates."
|
||||
}
|
||||
b.WriteString(messageBoxStyle.Render(message))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
var helpText string
|
||||
if shouldUseCompactLayout(m.width, 60) {
|
||||
helpText = "Press any key..."
|
||||
} else {
|
||||
helpText = "This message will disappear in 5 seconds or press any key..."
|
||||
}
|
||||
b.WriteString(helpStyle.Render(helpText))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package interaction
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
func (m *model) handleCommandSelection(item commandItem) (tea.Model, tea.Cmd) {
|
||||
switch item.name {
|
||||
case "slug":
|
||||
m.showingCommands = false
|
||||
m.editingSlug = true
|
||||
m.slugInput.SetValue(m.interaction.slug.String())
|
||||
m.slugInput.Focus()
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case "tunnel-type":
|
||||
m.showingCommands = false
|
||||
m.showingComingSoon = true
|
||||
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
|
||||
default:
|
||||
m.showingCommands = false
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) commandsUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
|
||||
switch {
|
||||
case key.Matches(msg, m.keymap.quit), msg.String() == "esc":
|
||||
m.showingCommands = false
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case msg.String() == "enter":
|
||||
selectedItem := m.commandList.SelectedItem()
|
||||
if selectedItem != nil {
|
||||
item := selectedItem.(commandItem)
|
||||
return m.handleCommandSelection(item)
|
||||
}
|
||||
}
|
||||
m.commandList, cmd = m.commandList.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *model) commandsView() string {
|
||||
isCompact := shouldUseCompactLayout(m.width, 60)
|
||||
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
PaddingTop(1).
|
||||
PaddingBottom(1)
|
||||
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#666666")).
|
||||
Italic(true).
|
||||
MarginTop(1)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
|
||||
var title string
|
||||
if shouldUseCompactLayout(m.width, 40) {
|
||||
title = "Commands"
|
||||
} else {
|
||||
title = "⚡ Commands"
|
||||
}
|
||||
b.WriteString(titleStyle.Render(title))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(m.commandList.View())
|
||||
b.WriteString("\n")
|
||||
|
||||
var helpText string
|
||||
if isCompact {
|
||||
helpText = "↑/↓ Nav • Enter Select • Esc Cancel"
|
||||
} else {
|
||||
helpText = "↑/↓ Navigate • Enter Select • Esc Cancel"
|
||||
}
|
||||
b.WriteString(helpStyle.Render(helpText))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
package interaction
|
||||
|
||||
const (
|
||||
backspaceChar = 8
|
||||
deleteChar = 127
|
||||
enterChar = 13
|
||||
escapeChar = 27
|
||||
ctrlC = 3
|
||||
forwardSlash = '/'
|
||||
minPrintableChar = 32
|
||||
maxPrintableChar = 126
|
||||
|
||||
minSlugLength = 3
|
||||
maxSlugLength = 20
|
||||
|
||||
clearScreen = "\033[H\033[2J"
|
||||
clearLine = "\033[K"
|
||||
clearToLineEnd = "\r\033[K"
|
||||
backspaceSeq = "\b \b"
|
||||
|
||||
minBoxWidth = 50
|
||||
paddingRight = 4
|
||||
)
|
||||
|
||||
var forbiddenSlugs = map[string]struct{}{
|
||||
"ping": {},
|
||||
"staging": {},
|
||||
"admin": {},
|
||||
"root": {},
|
||||
"api": {},
|
||||
"www": {},
|
||||
"support": {},
|
||||
"help": {},
|
||||
"status": {},
|
||||
"health": {},
|
||||
"login": {},
|
||||
"logout": {},
|
||||
"signup": {},
|
||||
"register": {},
|
||||
"settings": {},
|
||||
"config": {},
|
||||
"null": {},
|
||||
"undefined": {},
|
||||
"example": {},
|
||||
"test": {},
|
||||
"dev": {},
|
||||
"system": {},
|
||||
"administrator": {},
|
||||
"dashboard": {},
|
||||
"account": {},
|
||||
"profile": {},
|
||||
"user": {},
|
||||
"users": {},
|
||||
"auth": {},
|
||||
"oauth": {},
|
||||
"callback": {},
|
||||
"webhook": {},
|
||||
"webhooks": {},
|
||||
"static": {},
|
||||
"assets": {},
|
||||
"cdn": {},
|
||||
"mail": {},
|
||||
"email": {},
|
||||
"ftp": {},
|
||||
"ssh": {},
|
||||
"git": {},
|
||||
"svn": {},
|
||||
"blog": {},
|
||||
"news": {},
|
||||
"about": {},
|
||||
"contact": {},
|
||||
"terms": {},
|
||||
"privacy": {},
|
||||
"legal": {},
|
||||
"billing": {},
|
||||
"payment": {},
|
||||
"checkout": {},
|
||||
"cart": {},
|
||||
"shop": {},
|
||||
"store": {},
|
||||
"download": {},
|
||||
"uploads": {},
|
||||
"images": {},
|
||||
"img": {},
|
||||
"css": {},
|
||||
"js": {},
|
||||
"fonts": {},
|
||||
"public": {},
|
||||
"private": {},
|
||||
"internal": {},
|
||||
"external": {},
|
||||
"proxy": {},
|
||||
"cache": {},
|
||||
"debug": {},
|
||||
"metrics": {},
|
||||
"monitoring": {},
|
||||
"graphql": {},
|
||||
"rest": {},
|
||||
"rpc": {},
|
||||
"socket": {},
|
||||
"ws": {},
|
||||
"wss": {},
|
||||
"app": {},
|
||||
"apps": {},
|
||||
"mobile": {},
|
||||
"desktop": {},
|
||||
"embed": {},
|
||||
"widget": {},
|
||||
"docs": {},
|
||||
"documentation": {},
|
||||
"wiki": {},
|
||||
"forum": {},
|
||||
"community": {},
|
||||
"feedback": {},
|
||||
"report": {},
|
||||
"abuse": {},
|
||||
"spam": {},
|
||||
"security": {},
|
||||
"verify": {},
|
||||
"confirm": {},
|
||||
"reset": {},
|
||||
"password": {},
|
||||
"recovery": {},
|
||||
"unsubscribe": {},
|
||||
"subscribe": {},
|
||||
"notifications": {},
|
||||
"alerts": {},
|
||||
"messages": {},
|
||||
"inbox": {},
|
||||
"outbox": {},
|
||||
"sent": {},
|
||||
"draft": {},
|
||||
"trash": {},
|
||||
"archive": {},
|
||||
"search": {},
|
||||
"explore": {},
|
||||
"discover": {},
|
||||
"trending": {},
|
||||
"popular": {},
|
||||
"featured": {},
|
||||
"new": {},
|
||||
"latest": {},
|
||||
"top": {},
|
||||
"best": {},
|
||||
"hot": {},
|
||||
"random": {},
|
||||
"all": {},
|
||||
"any": {},
|
||||
"none": {},
|
||||
"true": {},
|
||||
"false": {},
|
||||
}
|
||||
@@ -0,0 +1,216 @@
|
||||
package interaction
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
func (m *model) dashboardUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
switch {
|
||||
case key.Matches(msg, m.keymap.quit):
|
||||
m.quitting = true
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink, tea.Quit)
|
||||
case key.Matches(msg, m.keymap.command):
|
||||
m.showingCommands = true
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *model) dashboardView() string {
|
||||
isCompact := shouldUseCompactLayout(m.width, BreakpointLarge)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(m.renderHeader(isCompact))
|
||||
b.WriteString(m.renderUserInfo(isCompact))
|
||||
b.WriteString(m.renderQuickActions(isCompact))
|
||||
b.WriteString(m.renderFooter(isCompact))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) renderHeader(isCompact bool) string {
|
||||
var b strings.Builder
|
||||
|
||||
asciiArtMargin := getMarginValue(isCompact, 0, 1)
|
||||
asciiArtStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color(ColorPrimary)).
|
||||
MarginBottom(asciiArtMargin)
|
||||
|
||||
b.WriteString(asciiArtStyle.Render(m.getASCIIArt()))
|
||||
b.WriteString("\n")
|
||||
|
||||
if !shouldUseCompactLayout(m.width, BreakpointSmall) {
|
||||
b.WriteString(m.renderSubtitle())
|
||||
} else {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) getASCIIArt() string {
|
||||
if shouldUseCompactLayout(m.width, BreakpointTiny) {
|
||||
return "TUNNEL PLS"
|
||||
}
|
||||
|
||||
if shouldUseCompactLayout(m.width, BreakpointLarge) {
|
||||
return `
|
||||
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
|
||||
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
|
||||
}
|
||||
|
||||
return `
|
||||
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
|
||||
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
|
||||
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
|
||||
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
|
||||
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
|
||||
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
|
||||
}
|
||||
|
||||
func (m *model) renderSubtitle() string {
|
||||
subtitleStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorGray)).
|
||||
Italic(true)
|
||||
|
||||
urlStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorPrimary)).
|
||||
Underline(true).
|
||||
Italic(true)
|
||||
|
||||
return subtitleStyle.Render("Secure tunnel service by Bagas • ") +
|
||||
urlStyle.Render("https://fossy.my.id") + "\n\n"
|
||||
}
|
||||
|
||||
func (m *model) renderUserInfo(isCompact bool) string {
|
||||
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
|
||||
boxPadding := getMarginValue(isCompact, 1, 2)
|
||||
boxMargin := getMarginValue(isCompact, 1, 2)
|
||||
|
||||
responsiveInfoBox := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color(ColorPrimary)).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin).
|
||||
Width(boxMaxWidth)
|
||||
|
||||
infoContent := m.getUserInfoContent(isCompact)
|
||||
return responsiveInfoBox.Render(infoContent) + "\n"
|
||||
}
|
||||
|
||||
func (m *model) getUserInfoContent(isCompact bool) string {
|
||||
userInfoStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorWhite)).
|
||||
Bold(true)
|
||||
|
||||
sectionHeaderStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorGray)).
|
||||
Bold(true)
|
||||
|
||||
addressStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorWhite))
|
||||
|
||||
urlBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorSecondary)).
|
||||
Bold(true).
|
||||
Italic(true)
|
||||
|
||||
authenticatedUser := m.interaction.user
|
||||
tunnelURL := urlBoxStyle.Render(m.getTunnelURL())
|
||||
|
||||
if isCompact {
|
||||
return fmt.Sprintf("👤 %s\n\n%s\n%s",
|
||||
userInfoStyle.Render(authenticatedUser),
|
||||
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
|
||||
addressStyle.Render(fmt.Sprintf(" %s", tunnelURL)))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
|
||||
userInfoStyle.Render(authenticatedUser),
|
||||
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
|
||||
addressStyle.Render(tunnelURL))
|
||||
}
|
||||
|
||||
func (m *model) renderQuickActions(isCompact bool) string {
|
||||
var b strings.Builder
|
||||
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color(ColorPrimary)).
|
||||
PaddingTop(1)
|
||||
|
||||
b.WriteString(titleStyle.Render(m.getQuickActionsTitle()))
|
||||
b.WriteString("\n")
|
||||
|
||||
featureMargin := getMarginValue(isCompact, 1, 2)
|
||||
featureStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorWhite)).
|
||||
MarginLeft(featureMargin)
|
||||
|
||||
keyHintStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorPrimary)).
|
||||
Bold(true)
|
||||
|
||||
commands := m.getActionCommands(keyHintStyle)
|
||||
b.WriteString(featureStyle.Render(commands.commandsText))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(featureStyle.Render(commands.quitText))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) getQuickActionsTitle() string {
|
||||
if shouldUseCompactLayout(m.width, BreakpointTiny) {
|
||||
return "Actions"
|
||||
}
|
||||
if shouldUseCompactLayout(m.width, BreakpointLarge) {
|
||||
return "Quick Actions"
|
||||
}
|
||||
return "✨ Quick Actions"
|
||||
}
|
||||
|
||||
type actionCommands struct {
|
||||
commandsText string
|
||||
quitText string
|
||||
}
|
||||
|
||||
func (m *model) getActionCommands(keyHintStyle lipgloss.Style) actionCommands {
|
||||
if shouldUseCompactLayout(m.width, BreakpointSmall) {
|
||||
return actionCommands{
|
||||
commandsText: fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]")),
|
||||
quitText: fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]")),
|
||||
}
|
||||
}
|
||||
|
||||
return actionCommands{
|
||||
commandsText: fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]")),
|
||||
quitText: fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]")),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) renderFooter(isCompact bool) string {
|
||||
if isCompact {
|
||||
return ""
|
||||
}
|
||||
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorDarkGray)).
|
||||
Italic(true)
|
||||
|
||||
return "\n\n" + footerStyle.Render("Press 'C' to customize your tunnel settings")
|
||||
}
|
||||
|
||||
func getMarginValue(isCompact bool, compactValue, normalValue int) int {
|
||||
if isCompact {
|
||||
return compactValue
|
||||
}
|
||||
return normalValue
|
||||
}
|
||||
@@ -2,10 +2,8 @@ package interaction
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
"sync"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/session/slug"
|
||||
@@ -21,36 +19,59 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Lifecycle interface {
|
||||
Close() error
|
||||
type Interaction interface {
|
||||
Mode() types.InteractiveMode
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetMode(m types.InteractiveMode)
|
||||
SetWH(w, h int)
|
||||
Start()
|
||||
Redraw()
|
||||
Send(message string) error
|
||||
}
|
||||
|
||||
type Controller interface {
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetLifecycle(lifecycle Lifecycle)
|
||||
SetSlugModificator(func(oldSlug, newSlug string) bool)
|
||||
Start()
|
||||
SetWH(w, h int)
|
||||
type SessionRegistry interface {
|
||||
Update(user string, oldKey, newKey types.SessionKey) error
|
||||
}
|
||||
|
||||
type Forwarder interface {
|
||||
Close() error
|
||||
GetTunnelType() types.TunnelType
|
||||
GetForwardedPort() uint16
|
||||
TunnelType() types.TunnelType
|
||||
ForwardedPort() uint16
|
||||
}
|
||||
|
||||
type Interaction struct {
|
||||
channel ssh.Channel
|
||||
slugManager slug.Manager
|
||||
forwarder Forwarder
|
||||
lifecycle Lifecycle
|
||||
updateClientSlug func(oldSlug, newSlug string) bool
|
||||
program *tea.Program
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
type CloseFunc func() error
|
||||
type interaction struct {
|
||||
randomizer random.Random
|
||||
config config.Config
|
||||
channel ssh.Channel
|
||||
slug slug.Slug
|
||||
forwarder Forwarder
|
||||
closeFunc CloseFunc
|
||||
user string
|
||||
sessionRegistry SessionRegistry
|
||||
program *tea.Program
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
mode types.InteractiveMode
|
||||
programMu sync.Mutex
|
||||
}
|
||||
|
||||
func (i *Interaction) SetWH(w, h int) {
|
||||
func (i *interaction) SetMode(m types.InteractiveMode) {
|
||||
i.mode = m
|
||||
}
|
||||
|
||||
func (i *interaction) Mode() types.InteractiveMode {
|
||||
return i.mode
|
||||
}
|
||||
|
||||
func (i *interaction) Send(message string) error {
|
||||
if i.channel != nil {
|
||||
_, err := i.channel.Write([]byte(message))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (i *interaction) SetWH(w, h int) {
|
||||
if i.program != nil {
|
||||
i.program.Send(tea.WindowSizeMsg{
|
||||
Width: w,
|
||||
@@ -59,116 +80,42 @@ func (i *Interaction) SetWH(w, h int) {
|
||||
}
|
||||
}
|
||||
|
||||
type commandItem struct {
|
||||
name string
|
||||
desc string
|
||||
}
|
||||
|
||||
type model struct {
|
||||
tunnelURL string
|
||||
domain string
|
||||
protocol string
|
||||
tunnelType types.TunnelType
|
||||
port uint16
|
||||
keymap keymap
|
||||
help help.Model
|
||||
quitting bool
|
||||
showingCommands bool
|
||||
editingSlug bool
|
||||
showingComingSoon bool
|
||||
commandList list.Model
|
||||
slugInput textinput.Model
|
||||
slugError string
|
||||
interaction *Interaction
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
type keymap struct {
|
||||
quit key.Binding
|
||||
command key.Binding
|
||||
random key.Binding
|
||||
}
|
||||
|
||||
type tickMsg time.Time
|
||||
|
||||
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
|
||||
func New(randomizer random.Random, config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Interaction{
|
||||
channel: nil,
|
||||
slugManager: slugManager,
|
||||
forwarder: forwarder,
|
||||
lifecycle: nil,
|
||||
updateClientSlug: nil,
|
||||
program: nil,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
return &interaction{
|
||||
randomizer: randomizer,
|
||||
config: config,
|
||||
channel: nil,
|
||||
slug: slug,
|
||||
forwarder: forwarder,
|
||||
closeFunc: closeFunc,
|
||||
user: user,
|
||||
sessionRegistry: sessionRegistry,
|
||||
program: nil,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
|
||||
i.lifecycle = lifecycle
|
||||
}
|
||||
|
||||
func (i *Interaction) SetChannel(channel ssh.Channel) {
|
||||
func (i *interaction) SetChannel(channel ssh.Channel) {
|
||||
i.channel = channel
|
||||
}
|
||||
|
||||
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) (success bool)) {
|
||||
i.updateClientSlug = modificator
|
||||
}
|
||||
|
||||
func (i *Interaction) Stop() {
|
||||
func (i *interaction) Stop() {
|
||||
if i.cancel != nil {
|
||||
i.cancel()
|
||||
}
|
||||
|
||||
i.programMu.Lock()
|
||||
defer i.programMu.Unlock()
|
||||
|
||||
if i.program != nil {
|
||||
i.program.Kill()
|
||||
i.program = nil
|
||||
}
|
||||
}
|
||||
|
||||
func getResponsiveWidth(screenWidth, padding, minWidth, maxWidth int) int {
|
||||
width := screenWidth - padding
|
||||
if width > maxWidth {
|
||||
width = maxWidth
|
||||
}
|
||||
if width < minWidth {
|
||||
width = minWidth
|
||||
}
|
||||
return width
|
||||
}
|
||||
|
||||
func shouldUseCompactLayout(width int, threshold int) bool {
|
||||
return width < threshold
|
||||
}
|
||||
|
||||
func truncateString(s string, maxLength int) string {
|
||||
if len(s) <= maxLength {
|
||||
return s
|
||||
}
|
||||
if maxLength < 4 {
|
||||
return s[:maxLength]
|
||||
}
|
||||
return s[:maxLength-3] + "..."
|
||||
}
|
||||
|
||||
func (i commandItem) FilterValue() string { return i.name }
|
||||
func (i commandItem) Title() string { return i.name }
|
||||
func (i commandItem) Description() string { return i.desc }
|
||||
|
||||
func tickCmd(d time.Duration) tea.Cmd {
|
||||
return tea.Tick(d, func(t time.Time) tea.Msg {
|
||||
return tickMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
func (m model) Init() tea.Cmd {
|
||||
return tea.Batch(textinput.Blink, tea.WindowSize())
|
||||
}
|
||||
|
||||
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tickMsg:
|
||||
@@ -194,555 +141,62 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
|
||||
case tea.KeyMsg:
|
||||
if m.showingComingSoon {
|
||||
m.showingComingSoon = false
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
return m.comingSoonUpdate(msg)
|
||||
}
|
||||
|
||||
if m.editingSlug {
|
||||
if m.tunnelType != types.HTTP {
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
}
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case "enter":
|
||||
inputValue := m.slugInput.Value()
|
||||
|
||||
if isForbiddenSlug(inputValue) {
|
||||
m.slugError = "This subdomain is reserved. Please choose a different one."
|
||||
return m, nil
|
||||
} else if !isValidSlug(inputValue) {
|
||||
m.slugError = "Invalid subdomain. Follow the rules."
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if !m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue) {
|
||||
m.slugError = "Someone already uses this subdomain."
|
||||
return m, nil
|
||||
}
|
||||
|
||||
m.tunnelURL = buildURL(m.protocol, inputValue, m.domain)
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case "ctrl+c":
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
default:
|
||||
if key.Matches(msg, m.keymap.random) {
|
||||
newSubdomain := generateRandomSubdomain()
|
||||
m.slugInput.SetValue(newSubdomain)
|
||||
m.slugError = ""
|
||||
m.slugInput, cmd = m.slugInput.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
m.slugError = ""
|
||||
m.slugInput, cmd = m.slugInput.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
return m.slugUpdate(msg)
|
||||
}
|
||||
|
||||
if m.showingCommands {
|
||||
switch {
|
||||
case key.Matches(msg, m.keymap.quit):
|
||||
m.showingCommands = false
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case msg.String() == "enter":
|
||||
selectedItem := m.commandList.SelectedItem()
|
||||
if selectedItem != nil {
|
||||
item := selectedItem.(commandItem)
|
||||
if item.name == "slug" {
|
||||
m.showingCommands = false
|
||||
m.editingSlug = true
|
||||
m.slugInput.SetValue(m.interaction.slugManager.Get())
|
||||
m.slugInput.Focus()
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
} else if item.name == "tunnel-type" {
|
||||
m.showingCommands = false
|
||||
m.showingComingSoon = true
|
||||
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
|
||||
}
|
||||
m.showingCommands = false
|
||||
return m, nil
|
||||
}
|
||||
case msg.String() == "esc":
|
||||
m.showingCommands = false
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
}
|
||||
m.commandList, cmd = m.commandList.Update(msg)
|
||||
return m, cmd
|
||||
return m.commandsUpdate(msg)
|
||||
}
|
||||
|
||||
switch {
|
||||
case key.Matches(msg, m.keymap.quit):
|
||||
m.quitting = true
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink, tea.Quit)
|
||||
case key.Matches(msg, m.keymap.command):
|
||||
m.showingCommands = true
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
}
|
||||
return m.dashboardUpdate(msg)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m model) helpView() string {
|
||||
return "\n" + m.help.ShortHelpView([]key.Binding{
|
||||
m.keymap.command,
|
||||
m.keymap.quit,
|
||||
})
|
||||
func (i *interaction) Redraw() {
|
||||
if i.program != nil {
|
||||
i.program.Send(tea.ClearScreen())
|
||||
}
|
||||
}
|
||||
|
||||
func (m model) View() string {
|
||||
func (m *model) View() string {
|
||||
if m.quitting {
|
||||
return ""
|
||||
}
|
||||
|
||||
if m.showingComingSoon {
|
||||
isCompact := shouldUseCompactLayout(m.width, 60)
|
||||
|
||||
var boxPadding int
|
||||
var boxMargin int
|
||||
if isCompact {
|
||||
boxPadding = 1
|
||||
boxMargin = 1
|
||||
} else {
|
||||
boxPadding = 3
|
||||
boxMargin = 2
|
||||
}
|
||||
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
PaddingTop(1).
|
||||
PaddingBottom(1)
|
||||
|
||||
messageBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||
messageBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
Background(lipgloss.Color("#1A1A2E")).
|
||||
Bold(true).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin).
|
||||
Width(messageBoxWidth).
|
||||
Align(lipgloss.Center)
|
||||
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#666666")).
|
||||
Italic(true).
|
||||
MarginTop(1)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("\n\n")
|
||||
|
||||
var title string
|
||||
if shouldUseCompactLayout(m.width, 40) {
|
||||
title = "Coming Soon"
|
||||
} else {
|
||||
title = "⏳ Coming Soon"
|
||||
}
|
||||
b.WriteString(titleStyle.Render(title))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
var message string
|
||||
if shouldUseCompactLayout(m.width, 50) {
|
||||
message = "Coming soon!\nStay tuned."
|
||||
} else {
|
||||
message = "🚀 This feature is coming very soon!\n Stay tuned for updates."
|
||||
}
|
||||
b.WriteString(messageBoxStyle.Render(message))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
var helpText string
|
||||
if shouldUseCompactLayout(m.width, 60) {
|
||||
helpText = "Press any key..."
|
||||
} else {
|
||||
helpText = "This message will disappear in 5 seconds or press any key..."
|
||||
}
|
||||
b.WriteString(helpStyle.Render(helpText))
|
||||
|
||||
return b.String()
|
||||
return m.comingSoonView()
|
||||
}
|
||||
|
||||
if m.editingSlug {
|
||||
isCompact := shouldUseCompactLayout(m.width, 70)
|
||||
isVeryCompact := shouldUseCompactLayout(m.width, 50)
|
||||
|
||||
var boxPadding int
|
||||
var boxMargin int
|
||||
if isVeryCompact {
|
||||
boxPadding = 1
|
||||
boxMargin = 1
|
||||
} else if isCompact {
|
||||
boxPadding = 1
|
||||
boxMargin = 1
|
||||
} else {
|
||||
boxPadding = 2
|
||||
boxMargin = 2
|
||||
}
|
||||
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
PaddingTop(1).
|
||||
PaddingBottom(1)
|
||||
|
||||
instructionStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
MarginTop(1)
|
||||
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin)
|
||||
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#666666")).
|
||||
Italic(true).
|
||||
MarginTop(1)
|
||||
|
||||
errorBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FF0000")).
|
||||
Background(lipgloss.Color("#3D0000")).
|
||||
Bold(true).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#FF0000")).
|
||||
Padding(0, boxPadding).
|
||||
MarginTop(1).
|
||||
MarginBottom(1)
|
||||
|
||||
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||
rulesBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||
Padding(0, boxPadding).
|
||||
MarginTop(1).
|
||||
MarginBottom(1).
|
||||
Width(rulesBoxWidth)
|
||||
|
||||
var b strings.Builder
|
||||
var title string
|
||||
if isVeryCompact {
|
||||
title = "Edit Subdomain"
|
||||
} else {
|
||||
title = "🔧 Edit Subdomain"
|
||||
}
|
||||
b.WriteString(titleStyle.Render(title))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
if m.tunnelType != types.HTTP {
|
||||
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||
warningBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FFA500")).
|
||||
Background(lipgloss.Color("#3D2000")).
|
||||
Bold(true).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#FFA500")).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin).
|
||||
Width(warningBoxWidth)
|
||||
|
||||
var warningText string
|
||||
if isVeryCompact {
|
||||
warningText = "⚠️ TCP tunnels don't support custom subdomains."
|
||||
} else {
|
||||
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
|
||||
}
|
||||
b.WriteString(warningBoxStyle.Render(warningText))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
var helpText string
|
||||
if isVeryCompact {
|
||||
helpText = "Press any key to go back"
|
||||
} else {
|
||||
helpText = "Press Enter or Esc to go back"
|
||||
}
|
||||
b.WriteString(helpStyle.Render(helpText))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
var rulesContent string
|
||||
if isVeryCompact {
|
||||
rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
|
||||
} else if isCompact {
|
||||
rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
|
||||
} else {
|
||||
rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
|
||||
}
|
||||
b.WriteString(rulesBoxStyle.Render(rulesContent))
|
||||
b.WriteString("\n")
|
||||
|
||||
var instruction string
|
||||
if isVeryCompact {
|
||||
instruction = "Custom subdomain:"
|
||||
} else {
|
||||
instruction = "Enter your custom subdomain:"
|
||||
}
|
||||
b.WriteString(instructionStyle.Render(instruction))
|
||||
b.WriteString("\n")
|
||||
|
||||
if m.slugError != "" {
|
||||
errorInputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#FF0000")).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(1)
|
||||
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
|
||||
b.WriteString("\n")
|
||||
} else {
|
||||
b.WriteString(inputBoxStyle.Render(m.slugInput.View()))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
|
||||
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
|
||||
|
||||
if len(previewURL) > previewWidth-10 {
|
||||
previewURL = truncateString(previewURL, previewWidth-10)
|
||||
}
|
||||
|
||||
previewStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#04B575")).
|
||||
Italic(true).
|
||||
Width(previewWidth)
|
||||
b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)))
|
||||
b.WriteString("\n")
|
||||
|
||||
var helpText string
|
||||
if isVeryCompact {
|
||||
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
|
||||
} else {
|
||||
helpText = "Press Enter to save • CTRL+R for random • Esc to cancel"
|
||||
}
|
||||
b.WriteString(helpStyle.Render(helpText))
|
||||
|
||||
return b.String()
|
||||
return m.slugView()
|
||||
}
|
||||
|
||||
if m.showingCommands {
|
||||
isCompact := shouldUseCompactLayout(m.width, 60)
|
||||
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
PaddingTop(1).
|
||||
PaddingBottom(1)
|
||||
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#666666")).
|
||||
Italic(true).
|
||||
MarginTop(1)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString("\n")
|
||||
|
||||
var title string
|
||||
if shouldUseCompactLayout(m.width, 40) {
|
||||
title = "Commands"
|
||||
} else {
|
||||
title = "⚡ Commands"
|
||||
}
|
||||
b.WriteString(titleStyle.Render(title))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(m.commandList.View())
|
||||
b.WriteString("\n")
|
||||
|
||||
var helpText string
|
||||
if isCompact {
|
||||
helpText = "↑/↓ Nav • Enter Select • Esc Cancel"
|
||||
} else {
|
||||
helpText = "↑/↓ Navigate • Enter Select • Esc Cancel"
|
||||
}
|
||||
b.WriteString(helpStyle.Render(helpText))
|
||||
|
||||
return b.String()
|
||||
return m.commandsView()
|
||||
}
|
||||
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
PaddingTop(1)
|
||||
|
||||
subtitleStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#888888")).
|
||||
Italic(true)
|
||||
|
||||
urlStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
Underline(true).
|
||||
Italic(true)
|
||||
|
||||
urlBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#04B575")).
|
||||
Bold(true).
|
||||
Italic(true)
|
||||
|
||||
keyHintStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
Bold(true)
|
||||
|
||||
var b strings.Builder
|
||||
|
||||
isCompact := shouldUseCompactLayout(m.width, 85)
|
||||
|
||||
var asciiArtMargin int
|
||||
if isCompact {
|
||||
asciiArtMargin = 0
|
||||
} else {
|
||||
asciiArtMargin = 1
|
||||
}
|
||||
|
||||
asciiArtStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color("#7D56F4")).
|
||||
MarginBottom(asciiArtMargin)
|
||||
|
||||
var asciiArt string
|
||||
if shouldUseCompactLayout(m.width, 50) {
|
||||
asciiArt = "TUNNEL PLS"
|
||||
} else if isCompact {
|
||||
asciiArt = `
|
||||
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
|
||||
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
|
||||
} else {
|
||||
asciiArt = `
|
||||
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
|
||||
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
|
||||
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
|
||||
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
|
||||
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
|
||||
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
|
||||
}
|
||||
|
||||
b.WriteString(asciiArtStyle.Render(asciiArt))
|
||||
b.WriteString("\n")
|
||||
|
||||
if !shouldUseCompactLayout(m.width, 60) {
|
||||
b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • "))
|
||||
b.WriteString(urlStyle.Render("https://fossy.my.id"))
|
||||
b.WriteString("\n\n")
|
||||
} else {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
|
||||
var boxPadding int
|
||||
var boxMargin int
|
||||
if isCompact {
|
||||
boxPadding = 1
|
||||
boxMargin = 1
|
||||
} else {
|
||||
boxPadding = 2
|
||||
boxMargin = 2
|
||||
}
|
||||
|
||||
responsiveInfoBox := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("#7D56F4")).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin).
|
||||
Width(boxMaxWidth)
|
||||
|
||||
urlDisplay := m.tunnelURL
|
||||
if shouldUseCompactLayout(m.width, 80) && len(m.tunnelURL) > m.width-20 {
|
||||
maxLen := m.width - 25
|
||||
if maxLen > 10 {
|
||||
urlDisplay = truncateString(m.tunnelURL, maxLen)
|
||||
}
|
||||
}
|
||||
|
||||
var infoContent string
|
||||
if shouldUseCompactLayout(m.width, 70) {
|
||||
infoContent = fmt.Sprintf("🌐 %s", urlBoxStyle.Render(urlDisplay))
|
||||
} else if isCompact {
|
||||
infoContent = fmt.Sprintf("🌐 Forwarding to:\n\n %s", urlBoxStyle.Render(urlDisplay))
|
||||
} else {
|
||||
infoContent = fmt.Sprintf("🌐 F O R W A R D I N G T O:\n\n %s", urlBoxStyle.Render(urlDisplay))
|
||||
}
|
||||
b.WriteString(responsiveInfoBox.Render(infoContent))
|
||||
b.WriteString("\n")
|
||||
|
||||
var quickActionsTitle string
|
||||
if shouldUseCompactLayout(m.width, 50) {
|
||||
quickActionsTitle = "Actions"
|
||||
} else if isCompact {
|
||||
quickActionsTitle = "Quick Actions"
|
||||
} else {
|
||||
quickActionsTitle = "✨ Quick Actions"
|
||||
}
|
||||
b.WriteString(titleStyle.Render(quickActionsTitle))
|
||||
b.WriteString("\n")
|
||||
|
||||
var featureMargin int
|
||||
if isCompact {
|
||||
featureMargin = 1
|
||||
} else {
|
||||
featureMargin = 2
|
||||
}
|
||||
|
||||
compactFeatureStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#FAFAFA")).
|
||||
MarginLeft(featureMargin)
|
||||
|
||||
var commandsText string
|
||||
var quitText string
|
||||
if shouldUseCompactLayout(m.width, 60) {
|
||||
commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]"))
|
||||
quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]"))
|
||||
} else {
|
||||
commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]"))
|
||||
quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]"))
|
||||
}
|
||||
|
||||
b.WriteString(compactFeatureStyle.Render(commandsText))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(compactFeatureStyle.Render(quitText))
|
||||
|
||||
if !shouldUseCompactLayout(m.width, 70) {
|
||||
b.WriteString("\n\n")
|
||||
footerStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color("#666666")).
|
||||
Italic(true)
|
||||
b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings"))
|
||||
}
|
||||
|
||||
return b.String()
|
||||
return m.dashboardView()
|
||||
}
|
||||
|
||||
func (i *Interaction) Start() {
|
||||
func (i *interaction) Start() {
|
||||
if i.mode == types.InteractiveModeHEADLESS {
|
||||
return
|
||||
}
|
||||
lipgloss.SetColorProfile(termenv.TrueColor)
|
||||
|
||||
domain := config.Getenv("DOMAIN", "localhost")
|
||||
protocol := "http"
|
||||
if config.Getenv("TLS_ENABLED", "false") == "true" {
|
||||
if i.config.TLSEnabled() {
|
||||
protocol = "https"
|
||||
}
|
||||
|
||||
tunnelType := i.forwarder.GetTunnelType()
|
||||
port := i.forwarder.GetForwardedPort()
|
||||
|
||||
var tunnelURL string
|
||||
if tunnelType == types.HTTP {
|
||||
tunnelURL = buildURL(protocol, i.slugManager.Get(), domain)
|
||||
} else {
|
||||
tunnelURL = fmt.Sprintf("tcp://%s:%d", domain, port)
|
||||
}
|
||||
tunnelType := i.forwarder.TunnelType()
|
||||
port := i.forwarder.ForwardedPort()
|
||||
|
||||
items := []list.Item{
|
||||
commandItem{name: "slug", desc: "Set custom subdomain"},
|
||||
@@ -764,9 +218,9 @@ func (i *Interaction) Start() {
|
||||
ti.CharLimit = 20
|
||||
ti.Width = 50
|
||||
|
||||
m := model{
|
||||
tunnelURL: tunnelURL,
|
||||
domain: domain,
|
||||
m := &model{
|
||||
randomizer: i.randomizer,
|
||||
domain: i.config.Domain(),
|
||||
protocol: protocol,
|
||||
tunnelType: tunnelType,
|
||||
port: port,
|
||||
@@ -790,6 +244,7 @@ func (i *Interaction) Start() {
|
||||
help: help.New(),
|
||||
}
|
||||
|
||||
i.programMu.Lock()
|
||||
i.program = tea.NewProgram(
|
||||
m,
|
||||
tea.WithInput(i.channel),
|
||||
@@ -800,49 +255,21 @@ func (i *Interaction) Start() {
|
||||
tea.WithoutSignalHandler(),
|
||||
tea.WithFPS(30),
|
||||
)
|
||||
i.programMu.Unlock()
|
||||
|
||||
_, err := i.program.Run()
|
||||
if err != nil {
|
||||
log.Printf("Cannot close tea: %s \n", err)
|
||||
}
|
||||
i.program.Kill()
|
||||
i.program = nil
|
||||
if err := m.interaction.lifecycle.Close(); err != nil {
|
||||
log.Printf("Cannot close session: %s \n", err)
|
||||
|
||||
i.programMu.Lock()
|
||||
if i.program != nil {
|
||||
i.program.Kill()
|
||||
i.program = nil
|
||||
}
|
||||
i.programMu.Unlock()
|
||||
|
||||
if i.closeFunc != nil {
|
||||
_ = i.closeFunc()
|
||||
}
|
||||
}
|
||||
|
||||
func buildURL(protocol, subdomain, domain string) string {
|
||||
return fmt.Sprintf("%s://%s.%s", protocol, subdomain, domain)
|
||||
}
|
||||
|
||||
func generateRandomSubdomain() string {
|
||||
return random.GenerateRandomString(20)
|
||||
}
|
||||
|
||||
func isValidSlug(slug string) bool {
|
||||
if len(slug) < minSlugLength || len(slug) > maxSlugLength {
|
||||
return false
|
||||
}
|
||||
|
||||
if slug[0] == '-' || slug[len(slug)-1] == '-' {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, c := range slug {
|
||||
if !isValidSlugChar(byte(c)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidSlugChar(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-'
|
||||
}
|
||||
|
||||
func isForbiddenSlug(slug string) bool {
|
||||
_, ok := forbiddenSlugs[slug]
|
||||
return ok
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,116 @@
|
||||
package interaction
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/charmbracelet/bubbles/help"
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/list"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
type commandItem struct {
|
||||
name string
|
||||
desc string
|
||||
}
|
||||
|
||||
func (i commandItem) FilterValue() string { return i.name }
|
||||
func (i commandItem) Title() string { return i.name }
|
||||
func (i commandItem) Description() string { return i.desc }
|
||||
|
||||
type model struct {
|
||||
randomizer random.Random
|
||||
domain string
|
||||
protocol string
|
||||
tunnelType types.TunnelType
|
||||
port uint16
|
||||
keymap keymap
|
||||
help help.Model
|
||||
quitting bool
|
||||
showingCommands bool
|
||||
editingSlug bool
|
||||
showingComingSoon bool
|
||||
commandList list.Model
|
||||
slugInput textinput.Model
|
||||
slugError string
|
||||
interaction *interaction
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
const (
|
||||
ColorPrimary = "#7D56F4"
|
||||
ColorSecondary = "#04B575"
|
||||
ColorGray = "#888888"
|
||||
ColorDarkGray = "#666666"
|
||||
ColorWhite = "#FAFAFA"
|
||||
ColorError = "#FF0000"
|
||||
ColorErrorBg = "#3D0000"
|
||||
ColorWarning = "#FFA500"
|
||||
ColorWarningBg = "#3D2000"
|
||||
)
|
||||
|
||||
const (
|
||||
BreakpointTiny = 50
|
||||
BreakpointSmall = 60
|
||||
BreakpointMedium = 70
|
||||
BreakpointLarge = 85
|
||||
)
|
||||
|
||||
func (m *model) getTunnelURL() string {
|
||||
if m.tunnelType == types.TunnelTypeHTTP {
|
||||
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
|
||||
}
|
||||
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
|
||||
}
|
||||
|
||||
type keymap struct {
|
||||
quit key.Binding
|
||||
command key.Binding
|
||||
random key.Binding
|
||||
}
|
||||
|
||||
type tickMsg time.Time
|
||||
|
||||
func (m *model) Init() tea.Cmd {
|
||||
return tea.Batch(textinput.Blink, tea.WindowSize())
|
||||
}
|
||||
|
||||
func getResponsiveWidth(screenWidth, padding, minWidth, maxWidth int) int {
|
||||
width := screenWidth - padding
|
||||
if width > maxWidth {
|
||||
width = maxWidth
|
||||
}
|
||||
if width < minWidth {
|
||||
width = minWidth
|
||||
}
|
||||
return width
|
||||
}
|
||||
|
||||
func shouldUseCompactLayout(width int, threshold int) bool {
|
||||
return width < threshold
|
||||
}
|
||||
|
||||
func truncateString(s string, maxLength int) string {
|
||||
if len(s) <= maxLength {
|
||||
return s
|
||||
}
|
||||
if maxLength < 4 {
|
||||
return s[:maxLength]
|
||||
}
|
||||
return s[:maxLength-3] + "..."
|
||||
}
|
||||
|
||||
func tickCmd(d time.Duration) tea.Cmd {
|
||||
return tea.Tick(d, func(t time.Time) tea.Msg {
|
||||
return tickMsg(t)
|
||||
})
|
||||
}
|
||||
|
||||
func buildURL(protocol, subdomain, domain string) string {
|
||||
return fmt.Sprintf("%s://%s.%s", protocol, subdomain, domain)
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
package interaction
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/charmbracelet/bubbles/key"
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||
var cmd tea.Cmd
|
||||
|
||||
if m.tunnelType != types.TunnelTypeHTTP {
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
}
|
||||
|
||||
switch msg.String() {
|
||||
case "esc", "ctrl+c":
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
case "enter":
|
||||
inputValue := m.slugInput.Value()
|
||||
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
|
||||
Id: m.interaction.slug.String(),
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}, types.SessionKey{
|
||||
Id: inputValue,
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}); err != nil {
|
||||
m.slugError = err.Error()
|
||||
return m, nil
|
||||
}
|
||||
m.editingSlug = false
|
||||
m.slugError = ""
|
||||
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
|
||||
default:
|
||||
if key.Matches(msg, m.keymap.random) {
|
||||
newSubdomain, err := m.randomizer.String(20)
|
||||
if err != nil {
|
||||
return m, cmd
|
||||
}
|
||||
m.slugInput.SetValue(newSubdomain)
|
||||
}
|
||||
m.slugError = ""
|
||||
m.slugInput, cmd = m.slugInput.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) slugView() string {
|
||||
isCompact := shouldUseCompactLayout(m.width, BreakpointMedium)
|
||||
isVeryCompact := shouldUseCompactLayout(m.width, BreakpointTiny)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(m.renderSlugTitle(isVeryCompact))
|
||||
|
||||
if m.tunnelType != types.TunnelTypeHTTP {
|
||||
b.WriteString(m.renderTCPWarning(isVeryCompact, isCompact))
|
||||
return b.String()
|
||||
}
|
||||
|
||||
b.WriteString(m.renderSlugRules(isVeryCompact, isCompact))
|
||||
b.WriteString(m.renderSlugInstruction(isVeryCompact))
|
||||
b.WriteString(m.renderSlugInput(isVeryCompact, isCompact))
|
||||
b.WriteString(m.renderSlugPreview(isVeryCompact))
|
||||
b.WriteString(m.renderSlugHelp(isVeryCompact))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) renderSlugTitle(isVeryCompact bool) string {
|
||||
titleStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(lipgloss.Color(ColorPrimary)).
|
||||
PaddingTop(1).
|
||||
PaddingBottom(1)
|
||||
|
||||
title := "🔧 Edit Subdomain"
|
||||
if isVeryCompact {
|
||||
title = "Edit Subdomain"
|
||||
}
|
||||
|
||||
return titleStyle.Render(title) + "\n\n"
|
||||
}
|
||||
|
||||
func (m *model) renderTCPWarning(isVeryCompact, isCompact bool) string {
|
||||
boxPadding := getPaddingValue(isVeryCompact, isCompact)
|
||||
boxMargin := getMarginValue(isCompact, 1, 2)
|
||||
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||
|
||||
warningBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorWarning)).
|
||||
Background(lipgloss.Color(ColorWarningBg)).
|
||||
Bold(true).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color(ColorWarning)).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin).
|
||||
Width(warningBoxWidth)
|
||||
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorDarkGray)).
|
||||
Italic(true).
|
||||
MarginTop(1)
|
||||
|
||||
warningText := m.getTCPWarningText(isVeryCompact)
|
||||
helpText := m.getTCPHelpText(isVeryCompact)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(warningBoxStyle.Render(warningText))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(helpStyle.Render(helpText))
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) getTCPWarningText(isVeryCompact bool) string {
|
||||
if isVeryCompact {
|
||||
return "⚠️ TCP tunnels don't support custom subdomains."
|
||||
}
|
||||
return "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
|
||||
}
|
||||
|
||||
func (m *model) getTCPHelpText(isVeryCompact bool) string {
|
||||
if isVeryCompact {
|
||||
return "Press any key to go back"
|
||||
}
|
||||
return "Press Enter or Esc to go back"
|
||||
}
|
||||
|
||||
func (m *model) renderSlugRules(isVeryCompact, isCompact bool) string {
|
||||
boxPadding := getPaddingValue(isVeryCompact, isCompact)
|
||||
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
|
||||
|
||||
rulesBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorWhite)).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color(ColorPrimary)).
|
||||
Padding(0, boxPadding).
|
||||
MarginTop(1).
|
||||
MarginBottom(1).
|
||||
Width(rulesBoxWidth)
|
||||
|
||||
rulesContent := m.getRulesContent(isVeryCompact, isCompact)
|
||||
return rulesBoxStyle.Render(rulesContent) + "\n"
|
||||
}
|
||||
|
||||
func (m *model) getRulesContent(isVeryCompact, isCompact bool) string {
|
||||
if isVeryCompact {
|
||||
return "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
|
||||
}
|
||||
|
||||
if isCompact {
|
||||
return "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
|
||||
}
|
||||
|
||||
return "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
|
||||
}
|
||||
|
||||
func (m *model) renderSlugInstruction(isVeryCompact bool) string {
|
||||
instructionStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorWhite)).
|
||||
MarginTop(1)
|
||||
|
||||
instruction := "Enter your custom subdomain:"
|
||||
if isVeryCompact {
|
||||
instruction = "Custom subdomain:"
|
||||
}
|
||||
|
||||
return instructionStyle.Render(instruction) + "\n"
|
||||
}
|
||||
|
||||
func (m *model) renderSlugInput(isVeryCompact, isCompact bool) string {
|
||||
boxPadding := getPaddingValue(isVeryCompact, isCompact)
|
||||
boxMargin := getMarginValue(isCompact, 1, 2)
|
||||
|
||||
if m.slugError != "" {
|
||||
return m.renderErrorInput(boxPadding, boxMargin)
|
||||
}
|
||||
|
||||
return m.renderNormalInput(boxPadding, boxMargin)
|
||||
}
|
||||
|
||||
func (m *model) renderErrorInput(boxPadding, boxMargin int) string {
|
||||
errorInputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color(ColorError)).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(1)
|
||||
|
||||
errorBoxStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorError)).
|
||||
Background(lipgloss.Color(ColorErrorBg)).
|
||||
Bold(true).
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color(ColorError)).
|
||||
Padding(0, boxPadding).
|
||||
MarginTop(1).
|
||||
MarginBottom(1)
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
|
||||
b.WriteString("\n")
|
||||
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
|
||||
b.WriteString("\n")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (m *model) renderNormalInput(boxPadding, boxMargin int) string {
|
||||
inputBoxStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color(ColorPrimary)).
|
||||
Padding(1, boxPadding).
|
||||
MarginTop(boxMargin).
|
||||
MarginBottom(boxMargin)
|
||||
|
||||
return inputBoxStyle.Render(m.slugInput.View()) + "\n"
|
||||
}
|
||||
|
||||
func (m *model) renderSlugPreview(isVeryCompact bool) string {
|
||||
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
|
||||
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
|
||||
|
||||
if isVeryCompact {
|
||||
previewURL = truncateString(previewURL, previewWidth-10)
|
||||
}
|
||||
|
||||
previewStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorSecondary)).
|
||||
Italic(true).
|
||||
Width(previewWidth)
|
||||
|
||||
return previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)) + "\n"
|
||||
}
|
||||
|
||||
func (m *model) renderSlugHelp(isVeryCompact bool) string {
|
||||
helpStyle := lipgloss.NewStyle().
|
||||
Foreground(lipgloss.Color(ColorDarkGray)).
|
||||
Italic(true).
|
||||
MarginTop(1)
|
||||
|
||||
helpText := "Press Enter to save • CTRL+R for random • Esc to cancel"
|
||||
if isVeryCompact {
|
||||
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
|
||||
}
|
||||
|
||||
return helpStyle.Render(helpText)
|
||||
}
|
||||
|
||||
func getPaddingValue(isVeryCompact, isCompact bool) int {
|
||||
if isVeryCompact || isCompact {
|
||||
return 1
|
||||
}
|
||||
return 2
|
||||
}
|
||||
+101
-54
@@ -4,6 +4,9 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
@@ -13,88 +16,132 @@ import (
|
||||
|
||||
type Forwarder interface {
|
||||
Close() error
|
||||
GetTunnelType() types.TunnelType
|
||||
GetForwardedPort() uint16
|
||||
TunnelType() types.TunnelType
|
||||
ForwardedPort() uint16
|
||||
}
|
||||
|
||||
type Lifecycle struct {
|
||||
status types.Status
|
||||
conn ssh.Conn
|
||||
channel ssh.Channel
|
||||
forwarder Forwarder
|
||||
slugManager slug.Manager
|
||||
unregisterClient func(slug string)
|
||||
type SessionRegistry interface {
|
||||
Remove(key types.SessionKey)
|
||||
}
|
||||
|
||||
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
|
||||
return &Lifecycle{
|
||||
status: "",
|
||||
conn: conn,
|
||||
channel: nil,
|
||||
forwarder: forwarder,
|
||||
slugManager: slugManager,
|
||||
unregisterClient: nil,
|
||||
type lifecycle struct {
|
||||
mu sync.Mutex
|
||||
status types.SessionStatus
|
||||
closeErr error
|
||||
conn ssh.Conn
|
||||
channel ssh.Channel
|
||||
forwarder Forwarder
|
||||
slug slug.Slug
|
||||
startedAt time.Time
|
||||
sessionRegistry SessionRegistry
|
||||
portRegistry portUtil.Port
|
||||
user string
|
||||
}
|
||||
|
||||
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
|
||||
return &lifecycle{
|
||||
status: types.SessionStatusINITIALIZING,
|
||||
conn: conn,
|
||||
channel: nil,
|
||||
forwarder: forwarder,
|
||||
slug: slugManager,
|
||||
startedAt: time.Now(),
|
||||
sessionRegistry: sessionRegistry,
|
||||
portRegistry: port,
|
||||
user: user,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
|
||||
l.unregisterClient = unregisterClient
|
||||
}
|
||||
|
||||
type SessionLifecycle interface {
|
||||
Close() error
|
||||
SetStatus(status types.Status)
|
||||
GetConnection() ssh.Conn
|
||||
GetChannel() ssh.Channel
|
||||
type Lifecycle interface {
|
||||
Connection() ssh.Conn
|
||||
Channel() ssh.Channel
|
||||
PortRegistry() portUtil.Port
|
||||
User() string
|
||||
SetChannel(channel ssh.Channel)
|
||||
SetUnregisterClient(unregisterClient func(slug string))
|
||||
SetStatus(status types.SessionStatus)
|
||||
IsActive() bool
|
||||
StartedAt() time.Time
|
||||
Close() error
|
||||
}
|
||||
|
||||
func (l *Lifecycle) GetChannel() ssh.Channel {
|
||||
func (l *lifecycle) PortRegistry() portUtil.Port {
|
||||
return l.portRegistry
|
||||
}
|
||||
|
||||
func (l *lifecycle) User() string {
|
||||
return l.user
|
||||
}
|
||||
|
||||
func (l *lifecycle) SetChannel(channel ssh.Channel) {
|
||||
l.channel = channel
|
||||
}
|
||||
|
||||
func (l *lifecycle) Channel() ssh.Channel {
|
||||
return l.channel
|
||||
}
|
||||
|
||||
func (l *Lifecycle) SetChannel(channel ssh.Channel) {
|
||||
l.channel = channel
|
||||
}
|
||||
func (l *Lifecycle) GetConnection() ssh.Conn {
|
||||
func (l *lifecycle) Connection() ssh.Conn {
|
||||
return l.conn
|
||||
}
|
||||
func (l *Lifecycle) SetStatus(status types.Status) {
|
||||
|
||||
func (l *lifecycle) SetStatus(status types.SessionStatus) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.status = status
|
||||
}
|
||||
|
||||
func (l *Lifecycle) Close() error {
|
||||
err := l.forwarder.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return err
|
||||
func (l *lifecycle) IsActive() bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return l.status == types.SessionStatusRUNNING
|
||||
}
|
||||
|
||||
func (l *lifecycle) Close() error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if l.status == types.SessionStatusCLOSED {
|
||||
return l.closeErr
|
||||
}
|
||||
l.status = types.SessionStatusCLOSED
|
||||
|
||||
var errs []error
|
||||
tunnelType := l.forwarder.TunnelType()
|
||||
|
||||
if l.channel != nil {
|
||||
err := l.channel.Close()
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return err
|
||||
if err := l.channel.Close(); err != nil && !isClosedError(err) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if l.conn != nil {
|
||||
err := l.conn.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
return err
|
||||
if err := l.conn.Close(); err != nil && !isClosedError(err) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
clientSlug := l.slugManager.Get()
|
||||
if clientSlug != "" {
|
||||
l.unregisterClient(clientSlug)
|
||||
clientSlug := l.slug.String()
|
||||
key := types.SessionKey{
|
||||
Id: clientSlug,
|
||||
Type: tunnelType,
|
||||
}
|
||||
l.sessionRegistry.Remove(key)
|
||||
|
||||
if tunnelType == types.TunnelTypeTCP {
|
||||
errs = append(errs, l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false))
|
||||
errs = append(errs, l.forwarder.Close())
|
||||
}
|
||||
|
||||
if l.forwarder.GetTunnelType() == types.TCP {
|
||||
err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
l.closeErr = errors.Join(errs...)
|
||||
return l.closeErr
|
||||
}
|
||||
|
||||
func isClosedError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || err.Error() == "EOF"
|
||||
}
|
||||
|
||||
func (l *lifecycle) StartedAt() time.Time {
|
||||
return l.startedAt
|
||||
}
|
||||
|
||||
@@ -0,0 +1,303 @@
|
||||
package lifecycle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Remove(key types.SessionKey) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
type MockForwarder struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||
args := m.Called(origin)
|
||||
return args.Get(0).([]byte)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||
m.Called(dst, src)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) TunnelType() types.TunnelType {
|
||||
args := m.Called()
|
||||
return args.Get(0).(types.TunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) ForwardedPort() uint16 {
|
||||
args := m.Called()
|
||||
return args.Get(0).(uint16)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
|
||||
m.Called(tunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetForwardedPort(port uint16) {
|
||||
m.Called(port)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetListener(listener net.Listener) {
|
||||
m.Called(listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Listener() net.Listener {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
args := m.Called(ctx, origin)
|
||||
if args.Get(0) == nil {
|
||||
return nil, nil, args.Error(2)
|
||||
}
|
||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
|
||||
type MockPort struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
||||
return m.Called(startPort, endPort).Error(0)
|
||||
}
|
||||
func (m *MockPort) Unassigned() (uint16, bool) {
|
||||
args := m.Called()
|
||||
var port uint16
|
||||
if args.Get(0) != nil {
|
||||
switch v := args.Get(0).(type) {
|
||||
case int:
|
||||
port = uint16(v)
|
||||
case uint16:
|
||||
port = v
|
||||
case uint32:
|
||||
port = uint16(v)
|
||||
case int32:
|
||||
port = uint16(v)
|
||||
case float64:
|
||||
port = uint16(v)
|
||||
default:
|
||||
port = uint16(args.Int(0))
|
||||
}
|
||||
}
|
||||
return port, args.Bool(1)
|
||||
}
|
||||
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
||||
return m.Called(port, assigned).Error(0)
|
||||
}
|
||||
func (m *MockPort) Claim(port uint16) bool {
|
||||
return m.Called(port).Bool(0)
|
||||
}
|
||||
|
||||
type MockSlug struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (ms *MockSlug) Set(slug string) {
|
||||
ms.Called(slug)
|
||||
}
|
||||
func (ms *MockSlug) String() string {
|
||||
return ms.Called().String(0)
|
||||
}
|
||||
|
||||
type MockSSHConn struct {
|
||||
ssh.Conn
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSSHConn) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockSSHChannel struct {
|
||||
ssh.Channel
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSSHChannel) Close() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockSlug := &MockSlug{}
|
||||
mockPort := &MockPort{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||
|
||||
assert.NotNil(t, mockLifecycle.Connection())
|
||||
assert.NotNil(t, mockLifecycle.User())
|
||||
assert.NotNil(t, mockLifecycle.PortRegistry())
|
||||
assert.NotNil(t, mockLifecycle.StartedAt())
|
||||
}
|
||||
|
||||
func TestLifecycle_User(t *testing.T) {
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockSlug := &MockSlug{}
|
||||
mockPort := &MockPort{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
|
||||
user := "mas-fuad"
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, user)
|
||||
assert.Equal(t, user, mockLifecycle.User())
|
||||
}
|
||||
|
||||
func TestLifecycle_SetChannel(t *testing.T) {
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockSlug := &MockSlug{}
|
||||
mockPort := &MockPort{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||
|
||||
mockSSHChannel := &MockSSHChannel{}
|
||||
|
||||
mockLifecycle.SetChannel(mockSSHChannel)
|
||||
|
||||
assert.Equal(t, mockSSHChannel, mockLifecycle.Channel())
|
||||
}
|
||||
|
||||
func TestLifecycle_SetStatus(t *testing.T) {
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockSlug := &MockSlug{}
|
||||
mockPort := &MockPort{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||
|
||||
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||
assert.True(t, mockLifecycle.IsActive())
|
||||
}
|
||||
|
||||
func TestLifecycle_IsActive(t *testing.T) {
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockSlug := &MockSlug{}
|
||||
mockPort := &MockPort{}
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||
|
||||
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||
assert.True(t, mockLifecycle.IsActive())
|
||||
}
|
||||
|
||||
func TestLifecycle_Close(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tunnelType types.TunnelType
|
||||
connCloseErr error
|
||||
channelCloseErr error
|
||||
expectErr bool
|
||||
alreadyClosed bool
|
||||
}{
|
||||
{
|
||||
name: "Close HTTP forwarding success",
|
||||
tunnelType: types.TunnelTypeHTTP,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Close TCP forwarding success",
|
||||
tunnelType: types.TunnelTypeTCP,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "Close with conn close error",
|
||||
tunnelType: types.TunnelTypeHTTP,
|
||||
connCloseErr: errors.New("conn close error"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Close with channel close error",
|
||||
tunnelType: types.TunnelTypeHTTP,
|
||||
channelCloseErr: errors.New("channel close error"),
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "Close when already closed",
|
||||
tunnelType: types.TunnelTypeHTTP,
|
||||
alreadyClosed: true,
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSSHConn := &MockSSHConn{}
|
||||
mockSSHConn.On("Close").Return(tt.connCloseErr)
|
||||
|
||||
mockForwarder := &MockForwarder{}
|
||||
mockForwarder.On("TunnelType").Return(tt.tunnelType)
|
||||
if tt.tunnelType == types.TunnelTypeTCP {
|
||||
mockForwarder.On("ForwardedPort").Return(uint16(8080))
|
||||
mockForwarder.On("Close").Return(nil)
|
||||
}
|
||||
|
||||
mockSlug := &MockSlug{}
|
||||
mockSlug.On("String").Return("test-slug")
|
||||
|
||||
mockPort := &MockPort{}
|
||||
if tt.tunnelType == types.TunnelTypeTCP {
|
||||
mockPort.On("SetStatus", uint16(8080), false).Return(nil)
|
||||
}
|
||||
|
||||
mockSessionRegistry := &MockSessionRegistry{}
|
||||
mockSessionRegistry.On("Remove", mock.Anything).Return()
|
||||
|
||||
mockSSHChannel := &MockSSHChannel{}
|
||||
mockSSHChannel.On("Close").Return(tt.channelCloseErr)
|
||||
|
||||
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
|
||||
|
||||
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||
mockLifecycle.SetChannel(mockSSHChannel)
|
||||
|
||||
if tt.alreadyClosed {
|
||||
err := mockLifecycle.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
err := mockLifecycle.Close()
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.False(t, mockLifecycle.IsActive())
|
||||
|
||||
mockSSHConn.AssertExpectations(t)
|
||||
mockForwarder.AssertExpectations(t)
|
||||
mockSlug.AssertExpectations(t)
|
||||
mockPort.AssertExpectations(t)
|
||||
mockSessionRegistry.AssertExpectations(t)
|
||||
mockSSHChannel.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package session
|
||||
|
||||
import "sync"
|
||||
|
||||
type Registry interface {
|
||||
Get(slug string) (session *SSHSession, exist bool)
|
||||
Update(oldSlug, newSlug string) (success bool)
|
||||
Register(slug string, session *SSHSession) (success bool)
|
||||
Remove(slug string)
|
||||
}
|
||||
type registry struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]*SSHSession
|
||||
}
|
||||
|
||||
func NewRegistry() Registry {
|
||||
return ®istry{
|
||||
clients: make(map[string]*SSHSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registry) Get(slug string) (session *SSHSession, exist bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
session, exist = r.clients[slug]
|
||||
return
|
||||
}
|
||||
|
||||
func (r *registry) Update(oldSlug, newSlug string) (success bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.clients[newSlug]; exists && newSlug != oldSlug {
|
||||
return false
|
||||
}
|
||||
|
||||
client, ok := r.clients[oldSlug]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(r.clients, oldSlug)
|
||||
client.slugManager.Set(newSlug)
|
||||
r.clients[newSlug] = client
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *registry) Register(slug string, session *SSHSession) (success bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if _, exists := r.clients[slug]; exists {
|
||||
return false
|
||||
}
|
||||
|
||||
r.clients[slug] = session
|
||||
return true
|
||||
}
|
||||
|
||||
func (r *registry) Remove(slug string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
delete(r.clients, slug)
|
||||
}
|
||||
+337
-58
@@ -1,107 +1,203 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/random"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/internal/transport"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Session interface {
|
||||
HandleGlobalRequest(ch <-chan *ssh.Request)
|
||||
HandleTCPIPForward(req *ssh.Request)
|
||||
HandleHTTPForward(req *ssh.Request, port uint16)
|
||||
HandleTCPForward(req *ssh.Request, addr string, port uint16)
|
||||
HandleGlobalRequest(ch <-chan *ssh.Request) error
|
||||
HandleTCPIPForward(req *ssh.Request) error
|
||||
HandleHTTPForward(req *ssh.Request, port uint16) error
|
||||
HandleTCPForward(req *ssh.Request, addr string, port uint16) error
|
||||
Lifecycle() lifecycle.Lifecycle
|
||||
Interaction() interaction.Interaction
|
||||
Forwarder() forwarder.Forwarder
|
||||
Slug() slug.Slug
|
||||
Detail() *types.Detail
|
||||
Start() error
|
||||
}
|
||||
|
||||
type SSHSession struct {
|
||||
initialReq <-chan *ssh.Request
|
||||
sshReqChannel <-chan ssh.NewChannel
|
||||
lifecycle lifecycle.SessionLifecycle
|
||||
interaction interaction.Controller
|
||||
forwarder forwarder.ForwardingController
|
||||
slugManager slug.Manager
|
||||
registry Registry
|
||||
type session struct {
|
||||
randomizer random.Random
|
||||
config config.Config
|
||||
initialReq <-chan *ssh.Request
|
||||
sshChan <-chan ssh.NewChannel
|
||||
lifecycle lifecycle.Lifecycle
|
||||
interaction interaction.Interaction
|
||||
forwarder forwarder.Forwarder
|
||||
slug slug.Slug
|
||||
registry registry.Registry
|
||||
}
|
||||
|
||||
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
|
||||
return s.lifecycle
|
||||
type Config struct {
|
||||
Randomizer random.Random
|
||||
Config config.Config
|
||||
Conn *ssh.ServerConn
|
||||
InitialReq <-chan *ssh.Request
|
||||
SshChan <-chan ssh.NewChannel
|
||||
SessionRegistry registry.Registry
|
||||
PortRegistry portUtil.Port
|
||||
User string
|
||||
}
|
||||
|
||||
func (s *SSHSession) GetInteraction() interaction.Controller {
|
||||
return s.interaction
|
||||
}
|
||||
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
|
||||
|
||||
func (s *SSHSession) GetForwarder() forwarder.ForwardingController {
|
||||
return s.forwarder
|
||||
}
|
||||
func New(conf *Config) Session {
|
||||
slugManager := slug.New()
|
||||
forwarderManager := forwarder.New(conf.Config, slugManager, conf.Conn)
|
||||
lifecycleManager := lifecycle.New(conf.Conn, forwarderManager, slugManager, conf.PortRegistry, conf.SessionRegistry, conf.User)
|
||||
interactionManager := interaction.New(conf.Randomizer, conf.Config, slugManager, forwarderManager, conf.SessionRegistry, conf.User, lifecycleManager.Close)
|
||||
|
||||
func (s *SSHSession) GetSlugManager() slug.Manager {
|
||||
return s.slugManager
|
||||
}
|
||||
|
||||
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry) *SSHSession {
|
||||
slugManager := slug.NewManager()
|
||||
forwarderManager := forwarder.NewForwarder(slugManager)
|
||||
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
|
||||
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager)
|
||||
|
||||
interactionManager.SetLifecycle(lifecycleManager)
|
||||
interactionManager.SetSlugModificator(sessionRegistry.Update)
|
||||
forwarderManager.SetLifecycle(lifecycleManager)
|
||||
lifecycleManager.SetUnregisterClient(sessionRegistry.Remove)
|
||||
|
||||
return &SSHSession{
|
||||
initialReq: forwardingReq,
|
||||
sshReqChannel: sshChan,
|
||||
lifecycle: lifecycleManager,
|
||||
interaction: interactionManager,
|
||||
forwarder: forwarderManager,
|
||||
slugManager: slugManager,
|
||||
registry: sessionRegistry,
|
||||
return &session{
|
||||
randomizer: conf.Randomizer,
|
||||
config: conf.Config,
|
||||
initialReq: conf.InitialReq,
|
||||
sshChan: conf.SshChan,
|
||||
lifecycle: lifecycleManager,
|
||||
interaction: interactionManager,
|
||||
forwarder: forwarderManager,
|
||||
slug: slugManager,
|
||||
registry: conf.SessionRegistry,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHSession) Start() error {
|
||||
channel := <-s.sshReqChannel
|
||||
func (s *session) Lifecycle() lifecycle.Lifecycle {
|
||||
return s.lifecycle
|
||||
}
|
||||
|
||||
func (s *session) Interaction() interaction.Interaction {
|
||||
return s.interaction
|
||||
}
|
||||
|
||||
func (s *session) Forwarder() forwarder.Forwarder {
|
||||
return s.forwarder
|
||||
}
|
||||
|
||||
func (s *session) Slug() slug.Slug {
|
||||
return s.slug
|
||||
}
|
||||
|
||||
func (s *session) Detail() *types.Detail {
|
||||
tunnelTypeMap := map[types.TunnelType]string{
|
||||
types.TunnelTypeHTTP: "HTTP",
|
||||
types.TunnelTypeTCP: "TCP",
|
||||
}
|
||||
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
|
||||
if !ok {
|
||||
tunnelType = "UNKNOWN"
|
||||
}
|
||||
|
||||
return &types.Detail{
|
||||
ForwardingType: tunnelType,
|
||||
Slug: s.slug.String(),
|
||||
UserID: s.lifecycle.User(),
|
||||
Active: s.lifecycle.IsActive(),
|
||||
StartedAt: s.lifecycle.StartedAt(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) Start() error {
|
||||
if err := s.setupSessionMode(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tcpipReq := s.waitForTCPIPForward()
|
||||
if tcpipReq == nil {
|
||||
return s.handleMissingForwardRequest()
|
||||
}
|
||||
|
||||
if s.shouldRejectUnauthorized() {
|
||||
return s.denyForwardingRequest(tcpipReq, nil, nil, "headless forwarding only allowed on node mode")
|
||||
}
|
||||
|
||||
if err := s.HandleTCPIPForward(tcpipReq); err != nil {
|
||||
return err
|
||||
}
|
||||
s.interaction.Start()
|
||||
|
||||
return s.waitForSessionEnd()
|
||||
}
|
||||
|
||||
func (s *session) setupSessionMode() error {
|
||||
select {
|
||||
case channel, ok := <-s.sshChan:
|
||||
if !ok {
|
||||
log.Println("Forwarding request channel closed")
|
||||
return nil
|
||||
}
|
||||
return s.setupInteractiveMode(channel)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
s.interaction.SetMode(types.InteractiveModeHEADLESS)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
|
||||
ch, reqs, err := channel.Accept()
|
||||
if err != nil {
|
||||
log.Printf("failed to accept channel: %v", err)
|
||||
return err
|
||||
}
|
||||
go s.HandleGlobalRequest(reqs)
|
||||
|
||||
tcpipReq := s.waitForTCPIPForward()
|
||||
if tcpipReq == nil {
|
||||
_, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))))
|
||||
go func() {
|
||||
err = s.HandleGlobalRequest(reqs)
|
||||
if err != nil {
|
||||
return err
|
||||
log.Printf("global request handler error: %v", err)
|
||||
}
|
||||
if err := s.lifecycle.Close(); err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
}
|
||||
return fmt.Errorf("No forwarding Request")
|
||||
}
|
||||
}()
|
||||
|
||||
s.lifecycle.SetChannel(ch)
|
||||
s.interaction.SetChannel(ch)
|
||||
s.interaction.SetMode(types.InteractiveModeINTERACTIVE)
|
||||
|
||||
s.HandleTCPIPForward(tcpipReq)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) handleMissingForwardRequest() error {
|
||||
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("no forwarding Request")
|
||||
}
|
||||
|
||||
func (s *session) shouldRejectUnauthorized() bool {
|
||||
return s.interaction.Mode() == types.InteractiveModeHEADLESS &&
|
||||
s.config.Mode() == types.ServerModeSTANDALONE &&
|
||||
s.lifecycle.User() == "UNAUTHORIZED"
|
||||
}
|
||||
|
||||
func (s *session) waitForSessionEnd() error {
|
||||
if err := s.lifecycle.Connection().Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
||||
log.Printf("ssh connection closed with error: %v", err)
|
||||
}
|
||||
|
||||
if err := s.lifecycle.Close(); err != nil {
|
||||
log.Printf("failed to close session: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSHSession) waitForTCPIPForward() *ssh.Request {
|
||||
func (s *session) waitForTCPIPForward() *ssh.Request {
|
||||
select {
|
||||
case req, ok := <-s.initialReq:
|
||||
if !ok {
|
||||
@@ -121,3 +217,186 @@ func (s *SSHSession) waitForTCPIPForward() *ssh.Request {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) handleWindowChange(req *ssh.Request) error {
|
||||
p := req.Payload
|
||||
if len(p) < 16 {
|
||||
log.Println("invalid window-change payload")
|
||||
return req.Reply(false, nil)
|
||||
}
|
||||
|
||||
cols := binary.BigEndian.Uint32(p[0:4])
|
||||
rows := binary.BigEndian.Uint32(p[4:8])
|
||||
|
||||
s.interaction.SetWH(int(cols), int(rows))
|
||||
return req.Reply(true, nil)
|
||||
}
|
||||
|
||||
func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
|
||||
for req := range GlobalRequest {
|
||||
switch req.Type {
|
||||
case "shell", "pty-req":
|
||||
if err := req.Reply(true, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
case "window-change":
|
||||
if err := s.handleWindowChange(req); err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
log.Println("Unknown request type:", req.Type)
|
||||
if err := req.Reply(false, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) {
|
||||
var forwardPayload struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
if err = ssh.Unmarshal(payload, &forwardPayload); err != nil {
|
||||
return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err)
|
||||
}
|
||||
|
||||
if forwardPayload.BindPort > 65535 {
|
||||
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
|
||||
}
|
||||
|
||||
port = uint16(forwardPayload.BindPort)
|
||||
|
||||
if isBlockedPort(port) {
|
||||
return "", 0, fmt.Errorf("port is blocked")
|
||||
}
|
||||
|
||||
if port == 0 {
|
||||
unassigned, ok := s.lifecycle.PortRegistry().Unassigned()
|
||||
if !ok {
|
||||
return "", 0, fmt.Errorf("no available port")
|
||||
}
|
||||
return forwardPayload.BindAddr, unassigned, nil
|
||||
}
|
||||
|
||||
return forwardPayload.BindAddr, port, nil
|
||||
}
|
||||
|
||||
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
|
||||
var errs []error
|
||||
if key != nil {
|
||||
s.registry.Remove(*key)
|
||||
}
|
||||
|
||||
if listener != nil {
|
||||
errs = append(errs, listener.Close())
|
||||
}
|
||||
|
||||
errs = append(errs, req.Reply(false, nil))
|
||||
errs = append(errs, s.lifecycle.Close())
|
||||
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
|
||||
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
|
||||
replyPayload := struct {
|
||||
BoundPort uint32
|
||||
}{
|
||||
BoundPort: uint32(portToBind),
|
||||
}
|
||||
err := req.Reply(true, ssh.Marshal(replyPayload))
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.forwarder.SetType(tunnelType)
|
||||
s.forwarder.SetForwardedPort(portToBind)
|
||||
s.slug.Set(slug)
|
||||
s.lifecycle.SetStatus(types.SessionStatusRUNNING)
|
||||
|
||||
if listener != nil {
|
||||
s.forwarder.SetListener(listener)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
|
||||
address, port, err := s.parseForwardPayload(req.Payload)
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
|
||||
}
|
||||
|
||||
switch port {
|
||||
case 80, 443:
|
||||
return s.HandleHTTPForward(req, port)
|
||||
default:
|
||||
return s.HandleTCPForward(req, address, port)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
|
||||
randomString, err := s.randomizer.String(20)
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
|
||||
}
|
||||
key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP}
|
||||
if !s.registry.Register(key, s) {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
|
||||
}
|
||||
|
||||
err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id)
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
|
||||
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
|
||||
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||
}
|
||||
|
||||
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
|
||||
listener, err := tcpServer.Listen()
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
|
||||
}
|
||||
|
||||
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
|
||||
if !s.registry.Register(key, s) {
|
||||
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id))
|
||||
}
|
||||
|
||||
err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id)
|
||||
if err != nil {
|
||||
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
|
||||
}
|
||||
|
||||
go func() {
|
||||
err = tcpServer.Serve(listener)
|
||||
if err != nil {
|
||||
log.Printf("Failed serving tcp server: %s\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isBlockedPort(port uint16) bool {
|
||||
if port == 80 || port == 443 {
|
||||
return false
|
||||
}
|
||||
if port < 1024 && port != 0 {
|
||||
return true
|
||||
}
|
||||
for _, p := range blockedReservedPorts {
|
||||
if p == port {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,24 +1,24 @@
|
||||
package slug
|
||||
|
||||
type Manager interface {
|
||||
Get() string
|
||||
type Slug interface {
|
||||
String() string
|
||||
Set(slug string)
|
||||
}
|
||||
|
||||
type manager struct {
|
||||
type slug struct {
|
||||
slug string
|
||||
}
|
||||
|
||||
func NewManager() Manager {
|
||||
return &manager{
|
||||
func New() Slug {
|
||||
return &slug{
|
||||
slug: "",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *manager) Get() string {
|
||||
func (s *slug) String() string {
|
||||
return s.slug
|
||||
}
|
||||
|
||||
func (s *manager) Set(slug string) {
|
||||
func (s *slug) Set(slug string) {
|
||||
s.slug = slug
|
||||
}
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
package slug
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type SlugTestSuite struct {
|
||||
suite.Suite
|
||||
slug Slug
|
||||
}
|
||||
|
||||
func (suite *SlugTestSuite) SetupTest() {
|
||||
suite.slug = New()
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
s := New()
|
||||
|
||||
assert.NotNil(t, s, "New() should return a non-nil Slug")
|
||||
assert.Implements(t, (*Slug)(nil), s, "New() should return a type that implements Slug interface")
|
||||
assert.Equal(t, "", s.String(), "New() should initialize with empty string")
|
||||
}
|
||||
|
||||
func (suite *SlugTestSuite) TestString() {
|
||||
assert.Equal(suite.T(), "", suite.slug.String(), "String() should return empty string initially")
|
||||
|
||||
suite.slug.Set("test-slug")
|
||||
assert.Equal(suite.T(), "test-slug", suite.slug.String(), "String() should return the set value")
|
||||
}
|
||||
|
||||
func (suite *SlugTestSuite) TestSet() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple slug",
|
||||
input: "hello-world",
|
||||
expected: "hello-world",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "slug with numbers",
|
||||
input: "test-123",
|
||||
expected: "test-123",
|
||||
},
|
||||
{
|
||||
name: "slug with special characters",
|
||||
input: "hello_world-123",
|
||||
expected: "hello_world-123",
|
||||
},
|
||||
{
|
||||
name: "overwrite existing slug",
|
||||
input: "new-slug",
|
||||
expected: "new-slug",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
suite.Run(tc.name, func() {
|
||||
suite.slug.Set(tc.input)
|
||||
assert.Equal(suite.T(), tc.expected, suite.slug.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (suite *SlugTestSuite) TestMultipleSet() {
|
||||
suite.slug.Set("first-slug")
|
||||
assert.Equal(suite.T(), "first-slug", suite.slug.String())
|
||||
|
||||
suite.slug.Set("second-slug")
|
||||
assert.Equal(suite.T(), "second-slug", suite.slug.String())
|
||||
|
||||
suite.slug.Set("")
|
||||
assert.Equal(suite.T(), "", suite.slug.String())
|
||||
}
|
||||
|
||||
func TestSlugIsolation(t *testing.T) {
|
||||
slug1 := New()
|
||||
slug2 := New()
|
||||
|
||||
slug1.Set("slug-one")
|
||||
slug2.Set("slug-two")
|
||||
|
||||
assert.Equal(t, "slug-one", slug1.String(), "First slug should maintain its value")
|
||||
assert.Equal(t, "slug-two", slug2.String(), "Second slug should maintain its value")
|
||||
}
|
||||
|
||||
func TestSlugTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(SlugTestSuite))
|
||||
}
|
||||
+37
-7
@@ -1,20 +1,50 @@
|
||||
package types
|
||||
|
||||
type Status string
|
||||
import "time"
|
||||
|
||||
type SessionStatus int
|
||||
|
||||
const (
|
||||
INITIALIZING Status = "INITIALIZING"
|
||||
RUNNING Status = "RUNNING"
|
||||
SETUP Status = "SETUP"
|
||||
SessionStatusINITIALIZING SessionStatus = iota
|
||||
SessionStatusRUNNING
|
||||
SessionStatusCLOSED
|
||||
)
|
||||
|
||||
type TunnelType string
|
||||
type InteractiveMode int
|
||||
|
||||
const (
|
||||
HTTP TunnelType = "HTTP"
|
||||
TCP TunnelType = "TCP"
|
||||
InteractiveModeINTERACTIVE InteractiveMode = iota + 1
|
||||
InteractiveModeHEADLESS
|
||||
)
|
||||
|
||||
type TunnelType int
|
||||
|
||||
const (
|
||||
TunnelTypeUNKNOWN TunnelType = iota
|
||||
TunnelTypeHTTP
|
||||
TunnelTypeTCP
|
||||
)
|
||||
|
||||
type ServerMode int
|
||||
|
||||
const (
|
||||
ServerModeSTANDALONE = iota + 1
|
||||
ServerModeNODE
|
||||
)
|
||||
|
||||
type SessionKey struct {
|
||||
Id string
|
||||
Type TunnelType
|
||||
}
|
||||
|
||||
type Detail struct {
|
||||
ForwardingType string `json:"forwarding_type,omitempty"`
|
||||
Slug string `json:"slug,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
Active bool `json:"active,omitempty"`
|
||||
StartedAt time.Time `json:"started_at,omitempty"`
|
||||
}
|
||||
|
||||
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
|
||||
"Content-Length: 11\r\n" +
|
||||
"Content-Type: text/plain\r\n\r\n" +
|
||||
|
||||
Reference in New Issue
Block a user