52 Commits

Author SHA1 Message Date
78b7b894d9 chore(deps): update actions/checkout action to v6
SonarQube Scan / SonarQube Trigger (push) Successful in 49s
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 48s
2026-01-22 16:01:02 +00:00
d91eecb2a0 chore: Refactor and optimize project architecture
Docker Build and Push / build-and-push-tags (push) Has been skipped
SonarQube Scan / SonarQube Trigger (push) Successful in 54s
Docker Build and Push / build-and-push-branches (push) Successful in 12m17s
- Fix: Resolve goroutine deadlock on early connection close
- Refactor: Simplify Start() method, unify forwarding logic, and enhance HTTP handler modularity
- Improve: Connection handling, header parsing, and resource management
- Refactor: Centralize environment loading, enforce typed access, and cleanup config structure
- Enhance: SonarQube scan integration for CI
- Chore: Reorganize project layout and simplify lifecycle management
- Define reusable constants for registry errors

Reviewed-on: #74
2026-01-22 22:16:33 +07:00
961a905542 chore(restructure): refactor architecture, config, and lifecycle management
Docker Build and Push / build-and-push-tags (push) Has been skipped
SonarQube Scan / SonarQube Trigger (push) Successful in 44s
Docker Build and Push / build-and-push-branches (push) Successful in 11m16s
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 46s
- Reorganized internal packages and overall project structure
- Moved HTTP/HTTPS/TCP servers into the transport layer
- Decoupled server initialization from HTTP/HTTPS/TCP startup logic
- Separated HTTP parsing, streaming, middleware, and session registry concerns
- Refactored session and forwarder responsibilities for clearer ownership
- Centralized environment loading with validated, typed config access
- Made config immutable after initialization and normalized enum naming
- Improved resource lifecycle handling and error aggregation on shutdown
- Introduced reusable, package-level registry errors
- Added SonarQube scanning to CI pipeline

Reviewed-on: #73
2026-01-22 00:48:40 +07:00
634c8321ef refactor(registry): define reusable constant errors
SonarQube Scan / SonarQube Trigger (push) Successful in 52s
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 46s
- Introduced package-level error variables in registry to replace repeated fmt.Errorf calls
- Added errors like ErrSessionNotFound, ErrSlugInUse, ErrInvalidSlug, ErrForbiddenSlug, ErrSlugChangeNotAllowed, and ErrSlugUnchanged
2026-01-22 00:39:28 +07:00
9f4c24a3f3 refactor(lifecycle): reorder resource closing and simplify Close()
SonarQube Scan / SonarQube Trigger (push) Successful in 53s
- Close channel and connection first, then remove session
- Close forwarded port and forwarder at the end for TCP tunnels
- Aggregate all errors using errors.Join instead of failing early
2026-01-21 21:59:59 +07:00
1408b80917 ci: add sonarqube scan
SonarQube Scan / SonarQube Trigger (push) Successful in 48s
2026-01-21 21:24:57 +07:00
2bc20dd991 refactor(config): centralize env loading and enforce typed access
- Centralize environment variable loading in config.MustLoad
- Parse and validate all env vars once at initialization
- Make config fields private and read-only
- Remove public Getenv usage in favor of typed accessors
- Improve validation and initialization order
- Normalize enum naming to be idiomatic and avoid constant collisions
2026-01-21 19:43:19 +07:00
1e12373359 chore(restructure): reorganize project layout
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 13m1s
- Reorganize internal packages and overall project structure
- Update imports and wiring to match the new layout
- Separate HTTP parsing and streaming from the server package
- Separate middleware from the server package
- Separate session registry from the session package
- Move HTTP, HTTPS, and TCP servers to the transport package
- Session package no longer starts the TCP server directly
- Server package no longer starts HTTP/HTTPS servers on initialization
- Forwarder no longer handles accepting TCP requests
- Move session details to the types package
- HTTP/HTTPS initialization is now the responsibility of main
2026-01-21 14:06:46 +07:00
9a4539cc02 refactor(httpheader): extract header parsing into dedicated package
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m19s
Moved HTTP header parsing and building logic from server package to internal/httpheader
2026-01-20 21:15:34 +07:00
e3ead4d52f refactor: optimize header parsing and remove factory naming
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m20s
- Remove factory naming
- Use direct byte indexing instead of bytes.TrimRight
- Extract parseStartLine and setRemainingHeaders helpers
2026-01-20 20:56:08 +07:00
aa1a465178 refactor(forwarder): improve connection handling and cleanup
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Has been cancelled
- Extract copyAndClose method for bidirectional data transfe
- Add closeWriter helper for graceful connection shutdown
- Add handleIncomingConnection helper
- Add openForwardedChannel helper
2026-01-20 19:01:15 +07:00
27f49879af refactor(server): enhance HTTP handler modularity and fix resource leak
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m43s
- Rename customWriter struct to httpWriter for clarity
- Add closeWriter field to properly close write side of connections
- Update all cw variable references to hw
- Merge handlerTLS into handler function to reduce code duplication
- Extract handler into smaller, focused methods
- Split Read/Write/forwardRequest into composable functions

Fixes resource leak where connections weren't properly closed on the
write side, matching the forwarder's CloseWrite() pattern.
2026-01-19 22:41:04 +07:00
adb0264bb5 refactor(session): simplify Start() and unify forwarding logic
- Extract helper functions from Start() for better code organization
- Eliminate duplication with finalizeForwarding() method
- Consolidate denial logic into denyForwardingRequest()
- Update all handler methods to return errors instead of logging internally
- Improve error handling consistency across all operations
2026-01-19 15:53:16 +07:00
8fb19af5a6 fix: resolve copy goroutine deadlock on early connection close
- Add proper CloseWrite handling to signal EOF to other goroutine
- Ensure both copy goroutines terminate when either side closes
- Prevent goroutine leaks for SSH forwarded-tcpip channels:
    - Use select with default when sending result to resultChan
    - Close unused SSH channels and discard requests if main goroutine has already timed out
2026-01-19 00:20:28 +07:00
41fdb5639c Merge pull request 'refactor: explicit initialization and dependency injection' (#70) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 9m49s
Reviewed-on: #70
2026-01-18 21:46:59 +07:00
44d224f491 refactor: explicit initialization and dependency injection
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 10m10s
- Replace init() with config.Load() function when loading env variables
- Inject portRegistry into session, server, and lifecycle structs
- Inject sessionRegistry directly into interaction and lifecycle
- Remove SetSessionRegistry function and global port variables
- Pass ssh.Conn directly to forwarder constructor instead of lifecycle interface
- Pass user and closeFunc callback to interaction constructor instead of lifecycle interface
- Eliminate circular dependencies between lifecycle, forwarder, and interaction
- Remove setter methods (SetLifecycle) from forwarder and interaction interfaces
2026-01-18 21:20:05 +07:00
9be0328e24 Merge pull request 'staging' (#69) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 9m30s
Reviewed-on: #69
2026-01-17 19:15:40 +07:00
2b9bca65d5 refactor(interaction): separate view and update logic into modular files
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 11m44s
- Extract slug editing logic to slug.go (slugView/slugUpdate)
- Extract commands menu logic to commands.go (commandsView/commandsUpdate)
- Extract coming soon modal to coming_soon.go (comingSoonView/comingSoonUpdate)
- Extract main dashboard logic to dashboard.go (dashboardView/dashboardUpdate)
- Create model.go for shared model struct and helper functions
- Replace math/rand with crypto/rand for random subdomain generation
- Remove legacy TLS cipher suite configuration
2026-01-17 17:33:10 +07:00
6587dc0f39 refactor(interaction): separate view and update logic into modular files
- Extract slug editing logic to slug.go (slugView/slugUpdate)
- Extract commands menu logic to commands.go (commandsView/commandsUpdate)
- Extract coming soon modal to coming_soon.go (comingSoonView/comingSoonUpdate)
- Extract main dashboard logic to dashboard.go (dashboardView/dashboardUpdate)
- Create model.go for shared model struct and helper functions
- Replace math/rand with crypto/rand for random subdomain generation
- Remove legacy TLS cipher suite configuration
2026-01-17 17:30:21 +07:00
f421781f44 Merge pull request 'refactor: convert structs to interfaces and rename accessors' (#68) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 10m34s
Reviewed-on: #68
2026-01-16 16:41:22 +07:00
6969d6823a Merge branch 'main' into staging
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 11m38s
2026-01-16 16:35:36 +07:00
1a04af8873 Merge branch 'main' into staging
Docker Build and Push / build-and-push-branches (push) Successful in 11m35s
Docker Build and Push / build-and-push-tags (push) Has been skipped
2026-01-16 16:28:39 +07:00
19135ceb42 refactor: convert structs to interfaces and rename accessors
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Has been cancelled
- Convert struct types to interfaces
- Rename getter and setter methods
- Add Close method to server interface
- Merge handler functionality into session file
- Handle lifecycle.Connection().Wait()
- fix panic on nil connection in SSH server
2026-01-16 15:25:31 +07:00
edb11dbc51 Merge pull request 'chore(deps): update golang docker tag to v1.25.6' (#67) from renovate/golang-1.x into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m34s
2026-01-16 05:01:06 +07:00
819f044275 chore(deps): update golang docker tag to v1.25.6 2026-01-15 22:01:02 +00:00
a7ebf2c5db Merge pull request 'fix(deps): update module golang.org/x/crypto to v0.47.0' (#66) from renovate/golang.org-x-crypto-0.x into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 10m34s
Reviewed-on: #66
2026-01-14 10:42:52 +00:00
64c1038f4b fix(deps): update module golang.org/x/crypto to v0.47.0 2026-01-14 10:41:47 +00:00
aafea49975 feat: integrate gRPC, session refactor, SSH headless support, and bug fixes
Docker Build and Push / build-and-push-tags (push) Successful in 11m34s
Docker Build and Push / build-and-push-branches (push) Has been skipped
- gRPC integration: slug edit handling, get sessions by user, and session requests from gRPC server
- Refactor gRPC client: simplify processEventStream and handle authenticated user info
- Session management improvements: use session key for registry, forwarder session termination, inject SessionRegistry interface
- SSH enhancements: add headless mode support for SSH -N connections
- Bug fixes:
  - prevent subdomain changes to already-in-use subdomains
  - fix startup order and environment variable keys
  - atomic ClaimPort() to prevent race conditions
- Refactors:
  - consolidate error handling
  - replace Get/Set patterns with idiomatic Go interfaces
  - change enums from string to int
- CI cleanup: remove renovate bot

Reviewed-on: #65
2026-01-14 10:16:43 +00:00
dbdf8094fa refactor: replace Get/Set patterns with idiomatic Go interfaces
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 13m4s
- rename constructors to New
- remove Get/Set-style accessors
- replace string-based enums with iota-backed types
2026-01-14 16:54:10 +07:00
ae3ed52d16 fix(port): add atomic ClaimPort() to prevent race condition
- Replace GetPortStatus/SetPortStatus calls with atomic ClaimPort() operation.
- Fixed a logic error when handling headless tunneling.
2026-01-14 16:51:50 +07:00
fb638636bf refactor: consolidate error handling with fail() function in session handlers
- Replace repetitive error handling code with fail() function in HandleGlobalRequest
- Standardize error response pattern across all handler methods
- Improve code maintainability and reduce duplication
2026-01-14 16:51:50 +07:00
da29df85b7 feat: add headless mode support for SSH -N connections
- use s.lifecycle.GetConnection().Wait() to block until SSH connection closes
- Prevent premature session closure in headless mode

In headless mode (ssh -N), there's no channel interaction to block on,
so the session would immediately return and close. Now blocking on
conn.Wait() keeps the session alive until the client disconnects.
2026-01-14 16:51:50 +07:00
8b0e08c629 fix(deps): update module github.com/caddyserver/certmagic to v0.25.1 2026-01-14 16:51:50 +07:00
f0804d6946 ci: remove renovate 2026-01-14 16:51:50 +07:00
09e526cd1e feat: add authenticated user info and restructure handleConnection
- Display authenticated username in welcome page information box
- Refactor handleConnection function for better structure and clarity
2026-01-14 16:51:50 +07:00
887ebf78b1 refactor(grpc/client): simplify processEventStream with per-event handlers
- Extract eventHandlers dispatch table
- Add per-event handlers: handleSlugChange, handleGetSessions, handleTerminateSession
- Introduce sendNode helper to centralize send/error handling and preserve connection-error propagation
- Add protoToTunnelType for tunnel-type validation
- Map unknown proto.TunnelType to types.UNKNOWN in protoToTunnelType and return a descriptive error
- Reduce boilerplate and improve readability of processEventStream
2026-01-14 16:51:50 +07:00
bef7a49f88 feat: implement forwarder session termination 2026-01-14 16:51:50 +07:00
17633b4e3c refactor: inject SessionRegistry interface instead of individual functions 2026-01-14 16:51:50 +07:00
f25d61d1d1 update: proto file to v1.3.0 2026-01-14 16:51:50 +07:00
8782b77b74 feat(session): use session key for registry 2026-01-14 16:51:50 +07:00
fc3cd886db fix: use correct environment variable key 2026-01-14 16:51:50 +07:00
b0da57db0d fix: startup order 2026-01-14 16:51:50 +07:00
0bd6eeadf3 feat: implement sessions request from grpc server 2026-01-14 16:51:50 +07:00
449f546e04 feat: implement sessions request from grpc server 2026-01-14 16:51:50 +07:00
4644420eee feat: implement get sessions by user 2026-01-14 16:51:50 +07:00
c9bf9e62bd feat(grpc): integrate slug edit handling 2026-01-14 16:51:50 +07:00
57d2136377 WIP: gRPC integration, initial implementation 2026-01-14 16:51:47 +07:00
8a34aaba80 WIP: gRPC integration, initial implementation 2026-01-14 16:51:35 +07:00
ff995a929e revert 01ddc76f7e
revert Merge pull request 'fix(deps): update module github.com/caddyserver/certmagic to v0.25.1' (#58) from renovate/github.com-caddyserver-certmagic-0.x into main
2026-01-14 16:51:35 +07:00
32ac9c1749 fix(deps): update module github.com/caddyserver/certmagic to v0.25.1
# Conflicts:
#	go.mod
2026-01-14 16:51:30 +07:00
e051a5b742 Merge pull request 'fix(deps): update module golang.org/x/crypto to v0.47.0' (#64) from renovate/golang.org-x-crypto-0.x into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 9m51s
renovate / renovate (push) Successful in 55s
2026-01-12 18:20:57 +00:00
d35228759c fix(deps): update module golang.org/x/crypto to v0.47.0 2026-01-12 18:20:53 +00:00
49 changed files with 2815 additions and 2448 deletions
-21
View File
@@ -1,21 +0,0 @@
name: renovate
on:
schedule:
- cron: "0 0 * * *"
push:
branches:
- staging
jobs:
renovate:
runs-on: ubuntu-latest
container: git.fossy.my.id/renovate-clanker/renovate:latest
steps:
- uses: actions/checkout@v6
- run: renovate
env:
RENOVATE_CONFIG_FILE: ${{ gitea.workspace }}/renovate-config.js
LOG_LEVEL: "debug"
RENOVATE_TOKEN: ${{ secrets.RENOVATE_TOKEN }}
GITHUB_COM_TOKEN: ${{ secrets.COM_TOKEN }}
+20
View File
@@ -0,0 +1,20 @@
on:
push:
pull_request:
types: [opened, synchronize, reopened]
name: SonarQube Scan
jobs:
sonarqube:
name: SonarQube Trigger
runs-on: ubuntu-latest
steps:
- name: Checking out
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: SonarQube Scan
uses: SonarSource/sonarqube-scan-action@v7.0.0
env:
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
+1 -1
View File
@@ -1,4 +1,4 @@
FROM golang:1.25.5-alpine AS go_builder FROM golang:1.25.6-alpine AS go_builder
ARG VERSION=dev ARG VERSION=dev
ARG BUILD_DATE=unknown ARG BUILD_DATE=unknown
+6 -6
View File
@@ -4,14 +4,14 @@ go 1.25.5
require ( require (
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0
github.com/caddyserver/certmagic v0.25.0 github.com/caddyserver/certmagic v0.25.1
github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbles v0.21.0
github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0 github.com/charmbracelet/lipgloss v1.1.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/libdns/cloudflare v0.2.2 github.com/libdns/cloudflare v0.2.2
github.com/muesli/termenv v0.16.0 github.com/muesli/termenv v0.16.0
golang.org/x/crypto v0.46.0 golang.org/x/crypto v0.47.0
google.golang.org/grpc v1.78.0 google.golang.org/grpc v1.78.0
google.golang.org/protobuf v1.36.11 google.golang.org/protobuf v1.36.11
) )
@@ -19,7 +19,7 @@ require (
require ( require (
github.com/atotto/clipboard v0.1.4 // indirect github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect github.com/caddyserver/zerossl v0.1.4 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.3 // indirect github.com/charmbracelet/x/ansi v0.11.3 // indirect
github.com/charmbracelet/x/cellbuf v0.0.14 // indirect github.com/charmbracelet/x/cellbuf v0.0.14 // indirect
@@ -48,8 +48,8 @@ require (
golang.org/x/mod v0.31.0 // indirect golang.org/x/mod v0.31.0 // indirect
golang.org/x/net v0.48.0 // indirect golang.org/x/net v0.48.0 // indirect
golang.org/x/sync v0.19.0 // indirect golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.32.0 // indirect golang.org/x/text v0.33.0 // indirect
golang.org/x/tools v0.40.0 // indirect golang.org/x/tools v0.40.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
) )
+12 -16
View File
@@ -1,7 +1,3 @@
git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0 h1:RhcBKUG41/om4jgN+iF/vlY/RojTeX1QhBa4p4428ec=
git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY=
git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0 h1:tpJSKjaSmV+vxxbVx6qnStjxFVXjj2M0rygWXxLb99o=
git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY=
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 h1:3xszIhck4wo9CoeRq9vnkar4PhY7kz9QrR30qj2XszA= git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 h1:3xszIhck4wo9CoeRq9vnkar4PhY7kz9QrR30qj2XszA=
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0/go.mod h1:Weh6ZujgWmT8XxD3Qba7sJ6r5eyUMB9XSWynqdyOoLo= git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0/go.mod h1:Weh6ZujgWmT8XxD3Qba7sJ6r5eyUMB9XSWynqdyOoLo=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
@@ -10,10 +6,10 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8= github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA= github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/caddyserver/certmagic v0.25.0 h1:VMleO/XA48gEWes5l+Fh6tRWo9bHkhwAEhx63i+F5ic= github.com/caddyserver/certmagic v0.25.1 h1:4sIKKbOt5pg6+sL7tEwymE1x2bj6CHr80da1CRRIPbY=
github.com/caddyserver/certmagic v0.25.0/go.mod h1:m9yB7Mud24OQbPHOiipAoyKPn9pKHhpSJxXR1jydBxA= github.com/caddyserver/certmagic v0.25.1/go.mod h1:VhyvndxtVton/Fo/wKhRoC46Rbw1fmjvQ3GjHYSQTEY=
github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA= github.com/caddyserver/zerossl v0.1.4 h1:CVJOE3MZeFisCERZjkxIcsqIH4fnFdlYWnPYeFtBHRw=
github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/caddyserver/zerossl v0.1.4/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
@@ -114,8 +110,8 @@ go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U= go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U=
go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ= go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
@@ -126,12 +122,12 @@ golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
+53 -25
View File
@@ -1,35 +1,63 @@
package config package config
import ( import "tunnel_pls/types"
"log"
"os"
"strconv"
"github.com/joho/godotenv" type Config interface {
) Domain() string
SSHPort() string
func init() { HTTPPort() string
if _, err := os.Stat(".env"); err == nil { HTTPSPort() string
if err := godotenv.Load(".env"); err != nil {
log.Printf("Warning: Failed to load .env file: %s", err) TLSEnabled() bool
} TLSRedirect() bool
}
ACMEEmail() string
CFAPIToken() string
ACMEStaging() bool
AllowedPortsStart() uint16
AllowedPortsEnd() uint16
BufferSize() int
PprofEnabled() bool
PprofPort() string
Mode() types.ServerMode
GRPCAddress() string
GRPCPort() string
NodeToken() string
} }
func Getenv(key, defaultValue string) string { func MustLoad() (Config, error) {
val := os.Getenv(key) if err := loadEnvFile(); err != nil {
if val == "" { return nil, err
val = defaultValue
} }
return val cfg, err := parse()
if err != nil {
return nil, err
}
return cfg, nil
} }
func GetBufferSize() int { func (c *config) Domain() string { return c.domain }
sizeStr := Getenv("BUFFER_SIZE", "32768") func (c *config) SSHPort() string { return c.sshPort }
size, err := strconv.Atoi(sizeStr) func (c *config) HTTPPort() string { return c.httpPort }
if err != nil || size < 4096 || size > 1048576 { func (c *config) HTTPSPort() string { return c.httpsPort }
return 32768 func (c *config) TLSEnabled() bool { return c.tlsEnabled }
} func (c *config) TLSRedirect() bool { return c.tlsRedirect }
return size func (c *config) ACMEEmail() string { return c.acmeEmail }
} func (c *config) CFAPIToken() string { return c.cfAPIToken }
func (c *config) ACMEStaging() bool { return c.acmeStaging }
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
func (c *config) BufferSize() int { return c.bufferSize }
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
func (c *config) PprofPort() string { return c.pprofPort }
func (c *config) Mode() types.ServerMode { return c.mode }
func (c *config) GRPCAddress() string { return c.grpcAddress }
func (c *config) GRPCPort() string { return c.grpcPort }
func (c *config) NodeToken() string { return c.nodeToken }
+170
View File
@@ -0,0 +1,170 @@
package config
import (
"fmt"
"log"
"os"
"strconv"
"strings"
"tunnel_pls/types"
"github.com/joho/godotenv"
)
type config struct {
domain string
sshPort string
httpPort string
httpsPort string
tlsEnabled bool
tlsRedirect bool
acmeEmail string
cfAPIToken string
acmeStaging bool
allowedPortsStart uint16
allowedPortsEnd uint16
bufferSize int
pprofEnabled bool
pprofPort string
mode types.ServerMode
grpcAddress string
grpcPort string
nodeToken string
}
func parse() (*config, error) {
mode, err := parseMode()
if err != nil {
return nil, err
}
domain := getenv("DOMAIN", "localhost")
sshPort := getenv("PORT", "2200")
httpPort := getenv("HTTP_PORT", "8080")
httpsPort := getenv("HTTPS_PORT", "8443")
tlsEnabled := getenvBool("TLS_ENABLED", false)
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
acmeStaging := getenvBool("ACME_STAGING", false)
cfToken := getenv("CF_API_TOKEN", "")
if tlsEnabled && cfToken == "" {
return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled")
}
start, end, err := parseAllowedPorts()
if err != nil {
return nil, err
}
bufferSize := parseBufferSize()
pprofEnabled := getenvBool("PPROF_ENABLED", false)
pprofPort := getenv("PPROF_PORT", "6060")
grpcHost := getenv("GRPC_ADDRESS", "localhost")
grpcPort := getenv("GRPC_PORT", "8080")
nodeToken := getenv("NODE_TOKEN", "")
if mode == types.ServerModeNODE && nodeToken == "" {
return nil, fmt.Errorf("NODE_TOKEN is required in node mode")
}
return &config{
domain: domain,
sshPort: sshPort,
httpPort: httpPort,
httpsPort: httpsPort,
tlsEnabled: tlsEnabled,
tlsRedirect: tlsRedirect,
acmeEmail: acmeEmail,
cfAPIToken: cfToken,
acmeStaging: acmeStaging,
allowedPortsStart: start,
allowedPortsEnd: end,
bufferSize: bufferSize,
pprofEnabled: pprofEnabled,
pprofPort: pprofPort,
mode: mode,
grpcAddress: grpcHost,
grpcPort: grpcPort,
nodeToken: nodeToken,
}, nil
}
func loadEnvFile() error {
if _, err := os.Stat(".env"); err == nil {
return godotenv.Load(".env")
}
return nil
}
func parseMode() (types.ServerMode, error) {
switch strings.ToLower(getenv("MODE", "standalone")) {
case "standalone":
return types.ServerModeSTANDALONE, nil
case "node":
return types.ServerModeNODE, nil
default:
return 0, fmt.Errorf("invalid MODE value")
}
}
func parseAllowedPorts() (uint16, uint16, error) {
raw := getenv("ALLOWED_PORTS", "")
if raw == "" {
return 0, 0, nil
}
parts := strings.Split(raw, "-")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format")
}
start, err := strconv.ParseUint(parts[0], 10, 16)
if err != nil {
return 0, 0, err
}
end, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return 0, 0, err
}
return uint16(start), uint16(end), nil
}
func parseBufferSize() int {
raw := getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(raw)
if err != nil || size < 4096 || size > 1048576 {
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
return 4096
}
return size
}
func getenv(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}
func getenvBool(key string, def bool) bool {
val := os.Getenv(key)
if val == "" {
return def
}
return val == "true"
}
+52 -106
View File
@@ -2,21 +2,18 @@ package client
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/registry"
"tunnel_pls/types" "tunnel_pls/types"
"tunnel_pls/session"
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen" proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
@@ -24,83 +21,35 @@ import (
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
) )
type GrpcConfig struct { type Client interface {
Address string SubscribeEvents(ctx context.Context, identity, authToken string) error
UseTLS bool ClientConn() *grpc.ClientConn
InsecureSkipVerify bool AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error)
Timeout time.Duration Close() error
KeepAlive bool CheckServerHealth(ctx context.Context) error
MaxRetries int
KeepAliveTime time.Duration
KeepAliveTimeout time.Duration
PermitWithoutStream bool
} }
type client struct {
type Client struct { config config.Config
conn *grpc.ClientConn conn *grpc.ClientConn
config *GrpcConfig address string
sessionRegistry session.Registry sessionRegistry registry.Registry
eventService proto.EventServiceClient eventService proto.EventServiceClient
authorizeConnectionService proto.UserServiceClient authorizeConnectionService proto.UserServiceClient
closing bool closing bool
} }
func DefaultConfig() *GrpcConfig { func New(config config.Config, address string, sessionRegistry registry.Registry) (Client, error) {
return &GrpcConfig{
Address: "localhost:50051",
UseTLS: false,
InsecureSkipVerify: false,
Timeout: 10 * time.Second,
KeepAlive: true,
MaxRetries: 3,
KeepAliveTime: 2 * time.Minute,
KeepAliveTimeout: 10 * time.Second,
PermitWithoutStream: false,
}
}
func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error) {
if config == nil {
config = DefaultConfig()
} else {
defaults := DefaultConfig()
if config.Address == "" {
config.Address = defaults.Address
}
if config.Timeout == 0 {
config.Timeout = defaults.Timeout
}
if config.MaxRetries == 0 {
config.MaxRetries = defaults.MaxRetries
}
if config.KeepAliveTime == 0 {
config.KeepAliveTime = defaults.KeepAliveTime
}
if config.KeepAliveTimeout == 0 {
config.KeepAliveTimeout = defaults.KeepAliveTimeout
}
}
var opts []grpc.DialOption var opts []grpc.DialOption
if config.UseTLS { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
tlsConfig := &tls.Config{
InsecureSkipVerify: config.InsecureSkipVerify, kaParams := keepalive.ClientParameters{
} Time: 2 * time.Minute,
creds := credentials.NewTLS(tlsConfig) Timeout: 10 * time.Second,
opts = append(opts, grpc.WithTransportCredentials(creds)) PermitWithoutStream: false,
} else {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} }
if config.KeepAlive { opts = append(opts, grpc.WithKeepaliveParams(kaParams))
kaParams := keepalive.ClientParameters{
Time: config.KeepAliveTime,
Timeout: config.KeepAliveTimeout,
PermitWithoutStream: config.PermitWithoutStream,
}
opts = append(opts, grpc.WithKeepaliveParams(kaParams))
}
opts = append(opts, opts = append(opts,
grpc.WithDefaultCallOptions( grpc.WithDefaultCallOptions(
@@ -109,24 +58,25 @@ func New(config *GrpcConfig, sessionRegistry session.Registry) (*Client, error)
), ),
) )
conn, err := grpc.NewClient(config.Address, opts...) conn, err := grpc.NewClient(address, opts...)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", config.Address, err) return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err)
} }
eventService := proto.NewEventServiceClient(conn) eventService := proto.NewEventServiceClient(conn)
authorizeConnectionService := proto.NewUserServiceClient(conn) authorizeConnectionService := proto.NewUserServiceClient(conn)
return &Client{ return &client{
conn: conn,
config: config, config: config,
conn: conn,
address: address,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
eventService: eventService, eventService: eventService,
authorizeConnectionService: authorizeConnectionService, authorizeConnectionService: authorizeConnectionService,
}, nil }, nil
} }
func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string) error { func (c *client) SubscribeEvents(ctx context.Context, identity, authToken string) error {
const ( const (
baseBackoff = time.Second baseBackoff = time.Second
maxBackoff = 30 * time.Second maxBackoff = 30 * time.Second
@@ -209,7 +159,7 @@ func (c *Client) SubscribeEvents(ctx context.Context, identity, authToken string
} }
} }
func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error { func (c *client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error {
handlers := c.eventHandlers(subscribe) handlers := c.eventHandlers(subscribe)
for { for {
@@ -230,7 +180,7 @@ func (c *Client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Nod
} }
} }
func (c *Client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) map[proto.EventType]func(*proto.Events) error { func (c *client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) map[proto.EventType]func(*proto.Events) error {
return map[proto.EventType]func(*proto.Events) error{ return map[proto.EventType]func(*proto.Events) error{
proto.EventType_SLUG_CHANGE: func(evt *proto.Events) error { return c.handleSlugChange(subscribe, evt) }, proto.EventType_SLUG_CHANGE: func(evt *proto.Events) error { return c.handleSlugChange(subscribe, evt) },
proto.EventType_GET_SESSIONS: func(evt *proto.Events) error { return c.handleGetSessions(subscribe, evt) }, proto.EventType_GET_SESSIONS: func(evt *proto.Events) error { return c.handleGetSessions(subscribe, evt) },
@@ -238,13 +188,13 @@ func (c *Client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, pr
} }
} }
func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
slugEvent := evt.GetSlugEvent() slugEvent := evt.GetSlugEvent()
user := slugEvent.GetUser() user := slugEvent.GetUser()
oldSlug := slugEvent.GetOld() oldSlug := slugEvent.GetOld()
newSlug := slugEvent.GetNew() newSlug := slugEvent.GetNew()
userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.HTTP}) userSession, err := c.sessionRegistry.Get(types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP})
if err != nil { if err != nil {
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
@@ -254,7 +204,7 @@ func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
}, "slug change failure response") }, "slug change failure response")
} }
if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.HTTP}, types.SessionKey{Id: newSlug, Type: types.HTTP}); err != nil { if err = c.sessionRegistry.Update(user, types.SessionKey{Id: oldSlug, Type: types.TunnelTypeHTTP}, types.SessionKey{Id: newSlug, Type: types.TunnelTypeHTTP}); err != nil {
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{ Payload: &proto.Node_SlugEventResponse{
@@ -263,7 +213,7 @@ func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
}, "slug change failure response") }, "slug change failure response")
} }
userSession.GetInteraction().Redraw() userSession.Interaction().Redraw()
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE, Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{ Payload: &proto.Node_SlugEventResponse{
@@ -272,14 +222,14 @@ func (c *Client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node,
}, "slug change success response") }, "slug change success response")
} }
func (c *Client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
sessions := c.sessionRegistry.GetAllSessionFromUser(evt.GetGetSessionsEvent().GetIdentity()) sessions := c.sessionRegistry.GetAllSessionFromUser(evt.GetGetSessionsEvent().GetIdentity())
var details []*proto.Detail var details []*proto.Detail
for _, ses := range sessions { for _, ses := range sessions {
detail := ses.Detail() detail := ses.Detail()
details = append(details, &proto.Detail{ details = append(details, &proto.Detail{
Node: config.Getenv("DOMAIN", "localhost"), Node: c.config.Domain(),
ForwardingType: detail.ForwardingType, ForwardingType: detail.ForwardingType,
Slug: detail.Slug, Slug: detail.Slug,
UserId: detail.UserID, UserId: detail.UserID,
@@ -296,7 +246,7 @@ func (c *Client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node
}, "send get sessions response") }, "send get sessions response")
} }
func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error { func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
terminate := evt.GetTerminateSessionEvent() terminate := evt.GetTerminateSessionEvent()
user := terminate.GetUser() user := terminate.GetUser()
slug := terminate.GetSlug() slug := terminate.GetSlug()
@@ -321,7 +271,7 @@ func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto
}, "terminate session fetch failed") }, "terminate session fetch failed")
} }
if err = userSession.GetLifecycle().Close(); err != nil { if err = userSession.Lifecycle().Close(); err != nil {
return c.sendNode(subscribe, &proto.Node{ return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION, Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{ Payload: &proto.Node_TerminateSessionEventResponse{
@@ -338,7 +288,7 @@ func (c *Client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto
}, "terminate session success response") }, "terminate session success response")
} }
func (c *Client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error { func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error {
if err := subscribe.Send(node); err != nil { if err := subscribe.Send(node); err != nil {
if c.isConnectionError(err) { if c.isConnectionError(err) {
return err return err
@@ -348,22 +298,22 @@ func (c *Client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.E
return nil return nil
} }
func (c *Client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) { func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
switch t { switch t {
case proto.TunnelType_HTTP: case proto.TunnelType_HTTP:
return types.HTTP, nil return types.TunnelTypeHTTP, nil
case proto.TunnelType_TCP: case proto.TunnelType_TCP:
return types.TCP, nil return types.TunnelTypeTCP, nil
default: default:
return types.UNKNOWN, fmt.Errorf("unknown tunnel type received") return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received")
} }
} }
func (c *Client) GetConnection() *grpc.ClientConn { func (c *client) ClientConn() *grpc.ClientConn {
return c.conn return c.conn
} }
func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) { func (c *client) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
check, err := c.authorizeConnectionService.Check(ctx, &proto.CheckRequest{AuthToken: token}) check, err := c.authorizeConnectionService.Check(ctx, &proto.CheckRequest{AuthToken: token})
if err != nil { if err != nil {
return false, "UNAUTHORIZED", err return false, "UNAUTHORIZED", err
@@ -375,17 +325,8 @@ func (c *Client) AuthorizeConn(ctx context.Context, token string) (authorized bo
return true, check.GetUser(), nil return true, check.GetUser(), nil
} }
func (c *Client) Close() error { func (c *client) CheckServerHealth(ctx context.Context) error {
if c.conn != nil { healthClient := grpc_health_v1.NewHealthClient(c.ClientConn())
log.Printf("Closing gRPC connection to %s", c.config.Address)
c.closing = true
return c.conn.Close()
}
return nil
}
func (c *Client) CheckServerHealth(ctx context.Context) error {
healthClient := grpc_health_v1.NewHealthClient(c.GetConnection())
resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{ resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{
Service: "", Service: "",
}) })
@@ -398,11 +339,16 @@ func (c *Client) CheckServerHealth(ctx context.Context) error {
return nil return nil
} }
func (c *Client) GetConfig() *GrpcConfig { func (c *client) Close() error {
return c.config if c.conn != nil {
log.Printf("Closing gRPC connection to %s", c.address)
c.closing = true
return c.conn.Close()
}
return nil
} }
func (c *Client) isConnectionError(err error) bool { func (c *client) isConnectionError(err error) bool {
if c.closing { if c.closing {
return false return false
} }
+30
View File
@@ -0,0 +1,30 @@
package header
type ResponseHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type responseHeader struct {
startLine []byte
headers map[string]string
}
type RequestHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
Method() string
Path() string
Version() string
}
type requestHeader struct {
method string
path string
version string
startLine []byte
headers map[string]string
}
+148
View File
@@ -0,0 +1,148 @@
package header
import (
"bufio"
"bytes"
"fmt"
)
func setRemainingHeaders(remaining []byte, header interface {
Set(key string, value string)
}) {
for len(remaining) > 0 {
lineEnd := bytes.Index(remaining, []byte("\r\n"))
if lineEnd == -1 {
lineEnd = len(remaining)
}
line := remaining[:lineEnd]
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx != -1 {
key := bytes.TrimSpace(line[:colonIdx])
value := bytes.TrimSpace(line[colonIdx+1:])
header.Set(string(key), string(value))
}
if lineEnd == len(remaining) {
break
}
remaining = remaining[lineEnd+2:]
}
}
func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) {
header := &requestHeader{
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
}
startLine := headerData[:lineEnd]
header.startLine = startLine
var err error
header.method, header.path, header.version, err = parseStartLine(startLine)
if err != nil {
return nil, err
}
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func parseStartLine(startLine []byte) (method, path, version string, err error) {
firstSpace := bytes.IndexByte(startLine, ' ')
if firstSpace == -1 {
return "", "", "", fmt.Errorf("invalid start line: missing method")
}
secondSpace := bytes.IndexByte(startLine[firstSpace+1:], ' ')
if secondSpace == -1 {
return "", "", "", fmt.Errorf("invalid start line: missing version")
}
secondSpace += firstSpace + 1
method = string(startLine[:firstSpace])
path = string(startLine[firstSpace+1 : secondSpace])
version = string(startLine[secondSpace+1:])
return method, path, version, nil
}
func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) {
header := &requestHeader{
headers: make(map[string]string, 16),
}
startLineBytes, err := br.ReadSlice('\n')
if err != nil {
return nil, err
}
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
header.startLine = make([]byte, len(startLineBytes))
copy(header.startLine, startLineBytes)
header.method, header.path, header.version, err = parseStartLine(header.startLine)
if err != nil {
return nil, err
}
for {
lineBytes, err := br.ReadSlice('\n')
if err != nil {
return nil, err
}
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
if len(lineBytes) == 0 {
break
}
colonIdx := bytes.IndexByte(lineBytes, ':')
if colonIdx == -1 {
continue
}
key := bytes.TrimSpace(lineBytes[:colonIdx])
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
header.headers[string(key)] = string(value)
}
return header, nil
}
func finalize(startLine []byte, headers map[string]string) []byte {
size := len(startLine) + 2
for key, val := range headers {
size += len(key) + 2 + len(val) + 2
}
size += 2
buf := make([]byte, 0, size)
buf = append(buf, startLine...)
buf = append(buf, '\r', '\n')
for key, val := range headers {
buf = append(buf, key...)
buf = append(buf, ':', ' ')
buf = append(buf, val...)
buf = append(buf, '\r', '\n')
}
buf = append(buf, '\r', '\n')
return buf
}
+49
View File
@@ -0,0 +1,49 @@
package header
import (
"bufio"
"fmt"
)
func NewRequest(r interface{}) (RequestHeader, error) {
switch v := r.(type) {
case []byte:
return parseHeadersFromBytes(v)
case *bufio.Reader:
return parseHeadersFromReader(v)
default:
return nil, fmt.Errorf("unsupported type: %T", r)
}
}
func (req *requestHeader) Value(key string) string {
val, ok := req.headers[key]
if !ok {
return ""
}
return val
}
func (req *requestHeader) Set(key string, value string) {
req.headers[key] = value
}
func (req *requestHeader) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeader) Method() string {
return req.method
}
func (req *requestHeader) Path() string {
return req.path
}
func (req *requestHeader) Version() string {
return req.version
}
func (req *requestHeader) Finalize() []byte {
return finalize(req.startLine, req.headers)
}
+40
View File
@@ -0,0 +1,40 @@
package header
import (
"bytes"
"fmt"
)
func NewResponse(headerData []byte) (ResponseHeader, error) {
header := &responseHeader{
startLine: nil,
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid response: no CRLF found in start line")
}
header.startLine = headerData[:lineEnd]
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func (resp *responseHeader) Value(key string) string {
return resp.headers[key]
}
func (resp *responseHeader) Set(key string, value string) {
resp.headers[key] = value
}
func (resp *responseHeader) Remove(key string) {
delete(resp.headers, key)
}
func (resp *responseHeader) Finalize() []byte {
return finalize(resp.startLine, resp.headers)
}
+29
View File
@@ -0,0 +1,29 @@
package stream
import "bytes"
func splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) {
headerByte := data[:delimiterIdx+len(DELIMITER)]
body := data[delimiterIdx+len(DELIMITER):]
return headerByte, body
}
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}
+50
View File
@@ -0,0 +1,50 @@
package stream
import (
"bytes"
"tunnel_pls/internal/http/header"
)
func (hs *http) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := hs.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
headerEndIdx := bytes.Index(tmp, DELIMITER)
if headerEndIdx == -1 {
return handleNoDelimiter(p, tmp, err)
}
headerByte, bodyByte := splitHeaderAndBody(tmp, headerEndIdx)
if !isHTTPHeader(headerByte) {
copy(p, tmp)
return read, nil
}
return hs.processHTTPRequest(p, headerByte, bodyByte)
}
func (hs *http) processHTTPRequest(p, headerByte, bodyByte []byte) (int, error) {
reqhf, err := header.NewRequest(headerByte)
if err != nil {
return 0, err
}
if err = hs.ApplyRequestMiddlewares(reqhf); err != nil {
return 0, err
}
hs.reqHeader = reqhf
combined := append(reqhf.Finalize(), bodyByte...)
return copy(p, combined), nil
}
func handleNoDelimiter(p, tmp []byte, err error) (int, error) {
copy(p, tmp)
return len(tmp), err
}
+103
View File
@@ -0,0 +1,103 @@
package stream
import (
"io"
"log"
"net"
"regexp"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/middleware"
)
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
type HTTP interface {
io.ReadWriteCloser
CloseWrite() error
RemoteAddr() net.Addr
UseResponseMiddleware(mw middleware.ResponseMiddleware)
UseRequestMiddleware(mw middleware.RequestMiddleware)
SetRequestHeader(header header.RequestHeader)
RequestMiddlewares() []middleware.RequestMiddleware
ResponseMiddlewares() []middleware.ResponseMiddleware
ApplyResponseMiddlewares(resphf header.ResponseHeader, body []byte) error
ApplyRequestMiddlewares(reqhf header.RequestHeader) error
}
type http struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
headerBuf []byte
buf []byte
respHeader header.ResponseHeader
reqHeader header.RequestHeader
respMW []middleware.ResponseMiddleware
reqMW []middleware.RequestMiddleware
}
func New(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTP {
return &http{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
func (hs *http) RemoteAddr() net.Addr {
return hs.remoteAddr
}
func (hs *http) UseResponseMiddleware(mw middleware.ResponseMiddleware) {
hs.respMW = append(hs.respMW, mw)
}
func (hs *http) UseRequestMiddleware(mw middleware.RequestMiddleware) {
hs.reqMW = append(hs.reqMW, mw)
}
func (hs *http) SetRequestHeader(header header.RequestHeader) {
hs.reqHeader = header
}
func (hs *http) RequestMiddlewares() []middleware.RequestMiddleware {
return hs.reqMW
}
func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
return hs.respMW
}
func (hs *http) Close() error {
return hs.writer.(io.Closer).Close()
}
func (hs *http) CloseWrite() error {
if closer, ok := hs.writer.(interface{ CloseWrite() error }); ok {
return closer.CloseWrite()
}
return hs.Close()
}
func (hs *http) ApplyRequestMiddlewares(reqhf header.RequestHeader) error {
for _, m := range hs.RequestMiddlewares() {
if err := m.HandleRequest(reqhf); err != nil {
log.Printf("Error when applying request middleware: %v", err)
return err
}
}
return nil
}
func (hs *http) ApplyResponseMiddlewares(resphf header.ResponseHeader, bodyByte []byte) error {
for _, m := range hs.ResponseMiddlewares() {
if err := m.HandleResponse(resphf, bodyByte); err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return err
}
}
return nil
}
+88
View File
@@ -0,0 +1,88 @@
package stream
import (
"bytes"
"tunnel_pls/internal/http/header"
)
func (hs *http) Write(p []byte) (int, error) {
if hs.shouldBypassBuffering(p) {
hs.respHeader = nil
}
if hs.respHeader != nil {
return hs.writer.Write(p)
}
hs.buf = append(hs.buf, p...)
headerEndIdx := bytes.Index(hs.buf, DELIMITER)
if headerEndIdx == -1 {
return len(p), nil
}
return hs.processBufferedResponse(p, headerEndIdx)
}
func (hs *http) shouldBypassBuffering(p []byte) bool {
return hs.respHeader != nil && len(hs.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/"
}
func (hs *http) processBufferedResponse(p []byte, delimiterIdx int) (int, error) {
headerByte, bodyByte := splitHeaderAndBody(hs.buf, delimiterIdx)
if !isHTTPHeader(headerByte) {
return hs.writeRawBuffer()
}
if err := hs.processHTTPResponse(headerByte, bodyByte); err != nil {
return 0, err
}
hs.buf = nil
return len(p), nil
}
func (hs *http) writeRawBuffer() (int, error) {
_, err := hs.writer.Write(hs.buf)
length := len(hs.buf)
hs.buf = nil
if err != nil {
return 0, err
}
return length, nil
}
func (hs *http) processHTTPResponse(headerByte, bodyByte []byte) error {
resphf, err := header.NewResponse(headerByte)
if err != nil {
return err
}
if err = hs.ApplyResponseMiddlewares(resphf, bodyByte); err != nil {
return err
}
hs.respHeader = resphf
finalHeader := resphf.Finalize()
if err = hs.writeHeaderAndBody(finalHeader, bodyByte); err != nil {
return err
}
return nil
}
func (hs *http) writeHeaderAndBody(header, bodyByte []byte) error {
if _, err := hs.writer.Write(header); err != nil {
return err
}
if len(bodyByte) > 0 {
if _, err := hs.writer.Write(bodyByte); err != nil {
return err
}
}
return nil
}
+23
View File
@@ -0,0 +1,23 @@
package middleware
import (
"net"
"tunnel_pls/internal/http/header"
)
type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header header.RequestHeader) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
+13
View File
@@ -0,0 +1,13 @@
package middleware
import (
"tunnel_pls/internal/http/header"
)
type RequestMiddleware interface {
HandleRequest(header header.RequestHeader) error
}
type ResponseMiddleware interface {
HandleResponse(header header.ResponseHeader, body []byte) error
}
+16
View File
@@ -0,0 +1,16 @@
package middleware
import (
"tunnel_pls/internal/http/header"
)
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleResponse(header header.ResponseHeader, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
+36 -49
View File
@@ -3,63 +3,40 @@ package port
import ( import (
"fmt" "fmt"
"sort" "sort"
"strconv"
"strings"
"sync" "sync"
"tunnel_pls/internal/config"
) )
type Manager interface { type Port interface {
AddPortRange(startPort, endPort uint16) error AddRange(startPort, endPort uint16) error
GetUnassignedPort() (uint16, bool) Unassigned() (uint16, bool)
SetPortStatus(port uint16, assigned bool) error SetStatus(port uint16, assigned bool) error
GetPortStatus(port uint16) (bool, bool) Claim(port uint16) (claimed bool)
} }
type manager struct { type port struct {
mu sync.RWMutex mu sync.RWMutex
ports map[uint16]bool ports map[uint16]bool
sortedPorts []uint16 sortedPorts []uint16
} }
var Default Manager = &manager{ func New() Port {
ports: make(map[uint16]bool), return &port{
sortedPorts: []uint16{}, ports: make(map[uint16]bool),
sortedPorts: []uint16{},
}
} }
func init() { func (pm *port) AddRange(startPort, endPort uint16) error {
rawRange := config.Getenv("ALLOWED_PORTS", "")
if rawRange == "" {
return
}
splitRange := strings.Split(rawRange, "-")
if len(splitRange) != 2 {
return
}
start, err := strconv.ParseUint(splitRange[0], 10, 16)
if err != nil {
return
}
end, err := strconv.ParseUint(splitRange[1], 10, 16)
if err != nil {
return
}
_ = Default.AddPortRange(uint16(start), uint16(end))
}
func (pm *manager) AddPortRange(startPort, endPort uint16) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
if startPort > endPort { if startPort > endPort {
return fmt.Errorf("start port cannot be greater than end port") return fmt.Errorf("start port cannot be greater than end port")
} }
for port := startPort; port <= endPort; port++ { for index := startPort; index <= endPort; index++ {
if _, exists := pm.ports[port]; !exists { if _, exists := pm.ports[index]; !exists {
pm.ports[port] = false pm.ports[index] = false
pm.sortedPorts = append(pm.sortedPorts, port) pm.sortedPorts = append(pm.sortedPorts, index)
} }
} }
sort.Slice(pm.sortedPorts, func(i, j int) bool { sort.Slice(pm.sortedPorts, func(i, j int) bool {
@@ -68,20 +45,19 @@ func (pm *manager) AddPortRange(startPort, endPort uint16) error {
return nil return nil
} }
func (pm *manager) GetUnassignedPort() (uint16, bool) { func (pm *port) Unassigned() (uint16, bool) {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
for _, port := range pm.sortedPorts { for _, index := range pm.sortedPorts {
if !pm.ports[port] { if !pm.ports[index] {
pm.ports[port] = true return index, true
return port, true
} }
} }
return 0, false return 0, false
} }
func (pm *manager) SetPortStatus(port uint16, assigned bool) error { func (pm *port) SetStatus(port uint16, assigned bool) error {
pm.mu.Lock() pm.mu.Lock()
defer pm.mu.Unlock() defer pm.mu.Unlock()
@@ -89,10 +65,21 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
return nil return nil
} }
func (pm *manager) GetPortStatus(port uint16) (bool, bool) { func (pm *port) Claim(port uint16) (claimed bool) {
pm.mu.RLock() pm.mu.Lock()
defer pm.mu.RUnlock() defer pm.mu.Unlock()
status, exists := pm.ports[port] status, exists := pm.ports[port]
return status, exists
if exists && status {
return false
}
if !exists {
pm.ports[port] = true
return true
}
pm.ports[port] = true
return true
} }
+13 -13
View File
@@ -1,18 +1,18 @@
package random package random
import ( import "crypto/rand"
mathrand "math/rand"
"strings"
"time"
)
func GenerateRandomString(length int) string { func GenerateRandomString(length int) (string, error) {
const charset = "abcdefghijklmnopqrstuvwxyz" const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
seededRand := mathrand.New(mathrand.NewSource(time.Now().UnixNano() + int64(mathrand.Intn(9999)))) b := make([]byte, length)
var result strings.Builder
for i := 0; i < length; i++ { if _, err := rand.Read(b); err != nil {
randomIndex := seededRand.Intn(len(charset)) return "", err
result.WriteString(string(charset[randomIndex]))
} }
return result.String()
for i := range b {
b[i] = charset[int(b[i])%len(charset)]
}
return string(b), nil
} }
@@ -1,103 +1,124 @@
package session package registry
import ( import (
"fmt" "fmt"
"sync" "sync"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
) )
type Key = types.SessionKey type Key = types.SessionKey
type Session interface {
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *types.Detail
}
type Registry interface { type Registry interface {
Get(key Key) (session *SSHSession, err error) Get(key Key) (session Session, err error)
GetWithUser(user string, key Key) (session *SSHSession, err error) GetWithUser(user string, key Key) (session Session, err error)
Update(user string, oldKey, newKey Key) error Update(user string, oldKey, newKey Key) error
Register(key Key, session *SSHSession) (success bool) Register(key Key, session Session) (success bool)
Remove(key Key) Remove(key Key)
GetAllSessionFromUser(user string) []*SSHSession GetAllSessionFromUser(user string) []Session
} }
type registry struct { type registry struct {
mu sync.RWMutex mu sync.RWMutex
byUser map[string]map[Key]*SSHSession byUser map[string]map[Key]Session
slugIndex map[Key]string slugIndex map[Key]string
} }
var (
ErrSessionNotFound = fmt.Errorf("session not found")
ErrSlugInUse = fmt.Errorf("slug already in use")
ErrInvalidSlug = fmt.Errorf("invalid slug")
ErrForbiddenSlug = fmt.Errorf("forbidden slug")
ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type")
ErrSlugUnchanged = fmt.Errorf("slug is unchanged")
)
func NewRegistry() Registry { func NewRegistry() Registry {
return &registry{ return &registry{
byUser: make(map[string]map[Key]*SSHSession), byUser: make(map[string]map[Key]Session),
slugIndex: make(map[Key]string), slugIndex: make(map[Key]string),
} }
} }
func (r *registry) Get(key Key) (session *SSHSession, err error) { func (r *registry) Get(key Key) (session Session, err error) {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
userID, ok := r.slugIndex[key] userID, ok := r.slugIndex[key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, ErrSessionNotFound
} }
client, ok := r.byUser[userID][key] client, ok := r.byUser[userID][key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, ErrSessionNotFound
} }
return client, nil return client, nil
} }
func (r *registry) GetWithUser(user string, key Key) (session *SSHSession, err error) { func (r *registry) GetWithUser(user string, key Key) (session Session, err error) {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
client, ok := r.byUser[user][key] client, ok := r.byUser[user][key]
if !ok { if !ok {
return nil, fmt.Errorf("session not found") return nil, ErrSessionNotFound
} }
return client, nil return client, nil
} }
func (r *registry) Update(user string, oldKey, newKey Key) error { func (r *registry) Update(user string, oldKey, newKey Key) error {
if oldKey.Type != newKey.Type { if oldKey.Type != newKey.Type {
return fmt.Errorf("tunnel type cannot change") return ErrSlugUnchanged
} }
if newKey.Type != types.HTTP { if newKey.Type != types.TunnelTypeHTTP {
return fmt.Errorf("non http tunnel cannot change slug") return ErrSlugChangeNotAllowed
} }
if isForbiddenSlug(newKey.Id) { if isForbiddenSlug(newKey.Id) {
return fmt.Errorf("this subdomain is reserved. Please choose a different one") return ErrForbiddenSlug
} }
if !isValidSlug(newKey.Id) { if !isValidSlug(newKey.Id) {
return fmt.Errorf("invalid subdomain. Follow the rules") return ErrInvalidSlug
} }
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
return fmt.Errorf("someone already uses this subdomain") return ErrSlugInUse
} }
client, ok := r.byUser[user][oldKey] client, ok := r.byUser[user][oldKey]
if !ok { if !ok {
return fmt.Errorf("session not found") return ErrSessionNotFound
} }
delete(r.byUser[user], oldKey) delete(r.byUser[user], oldKey)
delete(r.slugIndex, oldKey) delete(r.slugIndex, oldKey)
client.slugManager.Set(newKey.Id) client.Slug().Set(newKey.Id)
r.slugIndex[newKey] = user r.slugIndex[newKey] = user
if r.byUser[user] == nil { if r.byUser[user] == nil {
r.byUser[user] = make(map[Key]*SSHSession) r.byUser[user] = make(map[Key]Session)
} }
r.byUser[user][newKey] = client r.byUser[user][newKey] = client
return nil return nil
} }
func (r *registry) Register(key Key, session *SSHSession) (success bool) { func (r *registry) Register(key Key, userSession Session) (success bool) {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
@@ -105,26 +126,26 @@ func (r *registry) Register(key Key, session *SSHSession) (success bool) {
return false return false
} }
userID := session.lifecycle.GetUser() userID := userSession.Lifecycle().User()
if r.byUser[userID] == nil { if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]*SSHSession) r.byUser[userID] = make(map[Key]Session)
} }
r.byUser[userID][key] = session r.byUser[userID][key] = userSession
r.slugIndex[key] = userID r.slugIndex[key] = userID
return true return true
} }
func (r *registry) GetAllSessionFromUser(user string) []*SSHSession { func (r *registry) GetAllSessionFromUser(user string) []Session {
r.mu.RLock() r.mu.RLock()
defer r.mu.RUnlock() defer r.mu.RUnlock()
m := r.byUser[user] m := r.byUser[user]
if len(m) == 0 { if len(m) == 0 {
return []*SSHSession{} return []Session{}
} }
sessions := make([]*SSHSession, 0, len(m)) sessions := make([]Session, 0, len(m))
for _, s := range m { for _, s := range m {
sessions = append(sessions, s) sessions = append(sessions, s)
} }
+40
View File
@@ -0,0 +1,40 @@
package transport
import (
"errors"
"log"
"net"
"tunnel_pls/internal/registry"
)
type httpServer struct {
handler *httpHandler
port string
}
func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &httpServer{
handler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
port: port,
}
}
func (ht *httpServer) Listen() (net.Listener, error) {
return net.Listen("tcp", ":"+ht.port)
}
func (ht *httpServer) Serve(listener net.Listener) error {
log.Printf("HTTP server is starting on port %s", ht.port)
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
}
log.Printf("Error accepting connection: %v", err)
continue
}
go ht.handler.handler(conn, false)
}
}
+231
View File
@@ -0,0 +1,231 @@
package transport
import (
"bufio"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strings"
"time"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware"
"tunnel_pls/internal/registry"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type httpHandler struct {
domain string
sessionRegistry registry.Registry
redirectTLS bool
}
func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
return &httpHandler{
domain: domain,
sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS,
}
}
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
fmt.Sprintf("Location: %s", location) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
return err
}
return nil
}
func (hh *httpHandler) badRequest(conn net.Conn) error {
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
return err
}
return nil
}
func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
defer hh.closeConnection(conn)
dstReader := bufio.NewReader(conn)
reqhf, err := header.NewRequest(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
slug, err := hh.extractSlug(reqhf)
if err != nil {
_ = hh.badRequest(conn)
return
}
if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
return
}
if hh.handlePingRequest(slug, conn) {
return
}
sshSession, err := hh.getSession(slug)
if err != nil {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
return
}
hw := stream.New(conn, dstReader, conn.RemoteAddr())
defer func(hw stream.HTTP) {
err = hw.Close()
if err != nil {
log.Printf("Error closing HTTP stream: %v", err)
}
}(hw)
hh.forwardRequest(hw, reqhf, sshSession)
}
func (hh *httpHandler) closeConnection(conn net.Conn) {
err := conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err)
}
}
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
host := strings.Split(reqhf.Value("Host"), ".")
if len(host) < 1 {
return "", errors.New("invalid host")
}
return host[0], nil
}
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
return !isTLS && hh.redirectTLS
}
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
if slug != "ping" {
return false
}
_, err := conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
}
return true
}
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.TunnelTypeHTTP,
})
if err != nil {
return nil, err
}
return sshSession, nil
}
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
channel, err := hh.openForwardedChannel(hw, sshSession)
if err != nil {
log.Printf("Failed to establish channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw)
return
}
defer func() {
err = channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
hh.setupMiddlewares(hw)
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
log.Printf("Failed to forward initial request: %v", err)
return
}
sshSession.Forwarder().HandleConnection(hw, channel)
}
func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
select {
case resultChan <- channelResult{channel, reqs, err}:
default:
hh.cleanupUnusedChannel(channel, reqs)
}
}()
select {
case result := <-resultChan:
if result.err != nil {
return nil, result.err
}
go ssh.DiscardRequests(result.reqs)
return result.channel, nil
case <-time.After(5 * time.Second):
return nil, errors.New("timeout opening forwarded-tcpip channel")
}
}
func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
if channel != nil {
if err := channel.Close(); err != nil {
log.Printf("Failed to close unused channel: %v", err)
}
go ssh.DiscardRequests(reqs)
}
}
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
fingerprintMiddleware := middleware.NewTunnelFingerprint()
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
hw.UseResponseMiddleware(fingerprintMiddleware)
hw.UseRequestMiddleware(forwardedForMiddleware)
}
func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error {
hw.SetRequestHeader(initialRequest)
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
return fmt.Errorf("error applying request middlewares: %w", err)
}
if _, err := channel.Write(initialRequest.Finalize()); err != nil {
return fmt.Errorf("error writing to channel: %w", err)
}
return nil
}
+45
View File
@@ -0,0 +1,45 @@
package transport
import (
"crypto/tls"
"errors"
"log"
"net"
"tunnel_pls/internal/registry"
)
type https struct {
tlsConfig *tls.Config
httpHandler *httpHandler
domain string
port string
}
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport {
return &https{
tlsConfig: tlsConfig,
httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
domain: domain,
port: port,
}
}
func (ht *https) Listen() (net.Listener, error) {
return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig)
}
func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.port)
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
}
log.Printf("Error accepting connection: %v", err)
continue
}
go ht.httpHandler.handler(conn, true)
}
}
+66
View File
@@ -0,0 +1,66 @@
package transport
import (
"errors"
"fmt"
"io"
"log"
"net"
"golang.org/x/crypto/ssh"
)
type tcp struct {
port uint16
forwarder forwarder
}
type forwarder interface {
CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
HandleConnection(dst io.ReadWriter, src ssh.Channel)
}
func NewTCPServer(port uint16, forwarder forwarder) Transport {
return &tcp{
port: port,
forwarder: forwarder,
}
}
func (tt *tcp) Listen() (net.Listener, error) {
return net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", tt.port))
}
func (tt *tcp) Serve(listener net.Listener) error {
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
log.Printf("Error accepting connection: %v", err)
continue
}
go tt.handleTcp(conn)
}
}
func (tt *tcp) handleTcp(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil {
log.Printf("Failed to close connection: %v", err)
}
}()
payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
tt.forwarder.HandleConnection(conn, channel)
}
+17 -44
View File
@@ -1,4 +1,4 @@
package server package transport
import ( import (
"context" "context"
@@ -26,7 +26,8 @@ type TLSManager interface {
} }
type tlsManager struct { type tlsManager struct {
domain string config config.Config
certPath string certPath string
keyPath string keyPath string
storagePath string storagePath string
@@ -42,7 +43,7 @@ type tlsManager struct {
var globalTLSManager TLSManager var globalTLSManager TLSManager
var tlsManagerOnce sync.Once var tlsManagerOnce sync.Once
func NewTLSConfig(domain string) (*tls.Config, error) { func NewTLSConfig(config config.Config) (*tls.Config, error) {
var initErr error var initErr error
tlsManagerOnce.Do(func() { tlsManagerOnce.Do(func() {
@@ -51,7 +52,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
storagePath := "certs/tls/certmagic" storagePath := "certs/tls/certmagic"
tm := &tlsManager{ tm := &tlsManager{
domain: domain, config: config,
certPath: certPath, certPath: certPath,
keyPath: keyPath, keyPath: keyPath,
storagePath: storagePath, storagePath: storagePath,
@@ -66,14 +67,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
tm.useCertMagic = false tm.useCertMagic = false
tm.startCertWatcher() tm.startCertWatcher()
} else { } else {
if !isACMEConfigComplete() { log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", config.Domain(), config.Domain())
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
return
}
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
if err := tm.initCertMagic(); err != nil { if err := tm.initCertMagic(); err != nil {
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err) initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
return return
@@ -91,11 +85,6 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
return globalTLSManager.getTLSConfig(), nil return globalTLSManager.getTLSConfig(), nil
} }
func isACMEConfigComplete() bool {
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
return cfAPIToken != ""
}
func (tm *tlsManager) userCertsExistAndValid() bool { func (tm *tlsManager) userCertsExistAndValid() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath) log.Printf("Certificate file not found: %s", tm.certPath)
@@ -106,7 +95,7 @@ func (tm *tlsManager) userCertsExistAndValid() bool {
return false return false
} }
return ValidateCertDomains(tm.certPath, tm.domain) return ValidateCertDomains(tm.certPath, tm.config.Domain())
} }
func ValidateCertDomains(certPath, domain string) bool { func ValidateCertDomains(certPath, domain string) bool {
@@ -206,15 +195,9 @@ func (tm *tlsManager) startCertWatcher() {
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) { if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
log.Printf("Certificate files changed, reloading...") log.Printf("Certificate files changed, reloading...")
if !ValidateCertDomains(tm.certPath, tm.domain) { if !ValidateCertDomains(tm.certPath, tm.config.Domain()) {
log.Printf("New certificates don't cover required domains") log.Printf("New certificates don't cover required domains")
if !isACMEConfigComplete() {
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
continue
}
log.Printf("Switching to CertMagic for automatic certificate management")
if err := tm.initCertMagic(); err != nil { if err := tm.initCertMagic(); err != nil {
log.Printf("Failed to initialize CertMagic: %v", err) log.Printf("Failed to initialize CertMagic: %v", err)
continue continue
@@ -241,16 +224,12 @@ func (tm *tlsManager) initCertMagic() error {
return fmt.Errorf("failed to create cert storage directory: %w", err) return fmt.Errorf("failed to create cert storage directory: %w", err)
} }
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain) if tm.config.CFAPIToken() == "" {
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
if cfAPIToken == "" {
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation") return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
} }
cfProvider := &cloudflare.Provider{ cfProvider := &cloudflare.Provider{
APIToken: cfAPIToken, APIToken: tm.config.CFAPIToken(),
} }
storage := &certmagic.FileStorage{Path: tm.storagePath} storage := &certmagic.FileStorage{Path: tm.storagePath}
@@ -266,7 +245,7 @@ func (tm *tlsManager) initCertMagic() error {
}) })
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{ acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
Email: acmeEmail, Email: tm.config.ACMEEmail(),
Agreed: true, Agreed: true,
DNS01Solver: &certmagic.DNS01Solver{ DNS01Solver: &certmagic.DNS01Solver{
DNSManager: certmagic.DNSManager{ DNSManager: certmagic.DNSManager{
@@ -275,7 +254,7 @@ func (tm *tlsManager) initCertMagic() error {
}, },
}) })
if acmeStaging { if tm.config.ACMEStaging() {
acmeIssuer.CA = certmagic.LetsEncryptStagingCA acmeIssuer.CA = certmagic.LetsEncryptStagingCA
log.Printf("Using Let's Encrypt staging server") log.Printf("Using Let's Encrypt staging server")
} else { } else {
@@ -286,7 +265,7 @@ func (tm *tlsManager) initCertMagic() error {
magic.Issuers = []certmagic.Issuer{acmeIssuer} magic.Issuers = []certmagic.Issuer{acmeIssuer}
tm.magic = magic tm.magic = magic
domains := []string{tm.domain, "*." + tm.domain} domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
log.Printf("Requesting certificates for: %v", domains) log.Printf("Requesting certificates for: %v", domains)
ctx := context.Background() ctx := context.Background()
@@ -301,22 +280,16 @@ func (tm *tlsManager) initCertMagic() error {
func (tm *tlsManager) getTLSConfig() *tls.Config { func (tm *tlsManager) getTLSConfig() *tls.Config {
return &tls.Config{ return &tls.Config{
GetCertificate: tm.getCertificate, GetCertificate: tm.getCertificate,
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
SessionTicketsDisabled: false, MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
CipherSuites: []uint16{
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_CHACHA20_POLY1305_SHA256,
},
CurvePreferences: []tls.CurveID{ CurvePreferences: []tls.CurveID{
tls.X25519, tls.X25519,
}, },
ClientAuth: tls.NoClientCert, SessionTicketsDisabled: false,
NextProtos: nil, ClientAuth: tls.NoClientCert,
} }
} }
+10
View File
@@ -0,0 +1,10 @@
package transport
import (
"net"
)
type Transport interface {
Listen() (net.Listener, error)
Serve(listener net.Listener) error
}
+81 -43
View File
@@ -4,19 +4,22 @@ import (
"context" "context"
"fmt" "fmt"
"log" "log"
"net"
"net/http" "net/http"
_ "net/http/pprof" _ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
"strings"
"syscall" "syscall"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/key" "tunnel_pls/internal/key"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/internal/version"
"tunnel_pls/server" "tunnel_pls/server"
"tunnel_pls/session" "tunnel_pls/types"
"tunnel_pls/version"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -32,28 +35,19 @@ func main() {
log.Printf("Starting %s", version.GetVersion()) log.Printf("Starting %s", version.GetVersion())
mode := strings.ToLower(config.Getenv("MODE", "standalone")) conf, err := config.MustLoad()
isNodeMode := mode == "node" if err != nil {
log.Fatalf("Failed to load configuration: %s", err)
pprofEnabled := config.Getenv("PPROF_ENABLED", "false") return
if pprofEnabled == "true" {
pprofPort := config.Getenv("PPROF_PORT", "6060")
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
} }
sshConfig := &ssh.ServerConfig{ sshConfig := &ssh.ServerConfig{
NoClientAuth: true, NoClientAuth: true,
ServerVersion: fmt.Sprintf("SSH-2.0-TunnlPls-%s", version.GetShortVersion()), ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
} }
sshKeyPath := "certs/ssh/id_rsa" sshKeyPath := "certs/ssh/id_rsa"
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil { if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
log.Fatalf("Failed to generate SSH key: %s", err) log.Fatalf("Failed to generate SSH key: %s", err)
} }
@@ -68,7 +62,7 @@ func main() {
} }
sshConfig.AddHostKey(private) sshConfig.AddHostKey(private)
sessionRegistry := session.NewRegistry() sessionRegistry := registry.NewRegistry()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@@ -77,55 +71,93 @@ func main() {
shutdownChan := make(chan os.Signal, 1) shutdownChan := make(chan os.Signal, 1)
signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM) signal.Notify(shutdownChan, os.Interrupt, syscall.SIGTERM)
var grpcClient *client.Client var grpcClient client.Client
if isNodeMode {
grpcHost := config.Getenv("GRPC_ADDRESS", "localhost")
grpcPort := config.Getenv("GRPC_PORT", "8080")
grpcAddr := fmt.Sprintf("%s:%s", grpcHost, grpcPort)
nodeToken := config.Getenv("NODE_TOKEN", "")
if nodeToken == "" {
log.Fatalf("NODE_TOKEN is required in node mode")
}
c, err := client.New(&client.GrpcConfig{ if conf.Mode() == types.ServerModeNODE {
Address: grpcAddr, grpcAddr := fmt.Sprintf("%s:%s", conf.GRPCAddress(), conf.GRPCPort())
UseTLS: false,
InsecureSkipVerify: false, grpcClient, err = client.New(conf, grpcAddr, sessionRegistry)
Timeout: 10 * time.Second,
KeepAlive: true,
MaxRetries: 3,
}, sessionRegistry)
if err != nil { if err != nil {
log.Fatalf("failed to create grpc client: %v", err) log.Fatalf("failed to create grpc client: %v", err)
} }
grpcClient = c
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second) healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
if err := grpcClient.CheckServerHealth(healthCtx); err != nil { if err = grpcClient.CheckServerHealth(healthCtx); err != nil {
healthCancel() healthCancel()
log.Fatalf("gRPC health check failed: %v", err) log.Fatalf("gRPC health check failed: %v", err)
} }
healthCancel() healthCancel()
go func() { go func() {
identity := config.Getenv("DOMAIN", "localhost") if err = grpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
if err := grpcClient.SubscribeEvents(ctx, identity, nodeToken); err != nil {
errChan <- fmt.Errorf("failed to subscribe to events: %w", err) errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
} }
}() }()
} }
go func() { go func() {
app, err := server.NewServer(sshConfig, sessionRegistry, grpcClient) var httpListener net.Listener
httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), sessionRegistry, conf.TLSRedirect())
httpListener, err = httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpserver.Serve(httpListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
if conf.TLSEnabled() {
go func() {
var httpsListener net.Listener
tlsConfig, _ := transport.NewTLSConfig(conf)
httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), sessionRegistry, conf.TLSRedirect(), tlsConfig)
httpsListener, err = httpsServer.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpsServer.Serve(httpsListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
}
portManager := port.New()
err = portManager.AddRange(conf.AllowedPortsStart(), conf.AllowedPortsEnd())
if err != nil {
log.Fatalf("Failed to initialize port manager: %s", err)
return
}
var app server.Server
go func() {
app, err = server.New(conf, sshConfig, sessionRegistry, grpcClient, portManager, conf.SSHPort())
if err != nil { if err != nil {
errChan <- fmt.Errorf("failed to start server: %s", err) errChan <- fmt.Errorf("failed to start server: %s", err)
return return
} }
app.Start() app.Start()
}() }()
if conf.PprofEnabled() {
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", conf.PprofPort())
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
}
select { select {
case err := <-errChan: case err = <-errChan:
log.Printf("error happen : %s", err) log.Printf("error happen : %s", err)
case sig := <-shutdownChan: case sig := <-shutdownChan:
log.Printf("received signal %s, shutting down", sig) log.Printf("received signal %s, shutting down", sig)
@@ -133,8 +165,14 @@ func main() {
cancel() cancel()
if app != nil {
if err = app.Close(); err != nil {
log.Printf("failed to close server : %s", err)
}
}
if grpcClient != nil { if grpcClient != nil {
if err := grpcClient.Close(); err != nil { if err = grpcClient.Close(); err != nil {
log.Printf("failed to close grpc conn : %s", err) log.Printf("failed to close grpc conn : %s", err)
} }
} }
-8
View File
@@ -1,8 +0,0 @@
module.exports = {
"endpoint": "https://git.fossy.my.id/api/v1",
"gitAuthor": "Renovate-Clanker <renovate-bot@fossy.my.id>",
"platform": "gitea",
"onboardingConfigFileName": "renovate.json",
"autodiscover": true,
"optimizeForDisabled": true,
};
-276
View File
@@ -1,276 +0,0 @@
package server
import (
"bufio"
"bytes"
"fmt"
)
type HeaderManager interface {
Get(key string) []byte
Set(key string, value []byte)
Remove(key string)
Finalize() []byte
}
type ResponseHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type RequestHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
}
type responseHeaderFactory struct {
startLine []byte
headers map[string]string
}
type requestHeaderFactory struct {
method string
path string
version string
startLine []byte
headers map[string]string
}
func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) {
switch v := r.(type) {
case []byte:
return parseHeadersFromBytes(v)
case *bufio.Reader:
return parseHeadersFromReader(v)
default:
return nil, fmt.Errorf("unsupported type: %T", r)
}
}
func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) {
header := &requestHeaderFactory{
headers: make(map[string]string, 16),
}
lineEnd := bytes.IndexByte(headerData, '\n')
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no newline found")
}
startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n")
header.startLine = make([]byte, len(startLine))
copy(header.startLine, startLine)
parts := bytes.Split(startLine, []byte{' '})
if len(parts) < 3 {
return nil, fmt.Errorf("invalid request line")
}
header.method = string(parts[0])
header.path = string(parts[1])
header.version = string(parts[2])
remaining := headerData[lineEnd+1:]
for len(remaining) > 0 {
lineEnd = bytes.IndexByte(remaining, '\n')
if lineEnd == -1 {
lineEnd = len(remaining)
}
line := bytes.TrimRight(remaining[:lineEnd], "\r\n")
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx != -1 {
key := bytes.TrimSpace(line[:colonIdx])
value := bytes.TrimSpace(line[colonIdx+1:])
header.headers[string(key)] = string(value)
}
if lineEnd == len(remaining) {
break
}
remaining = remaining[lineEnd+1:]
}
return header, nil
}
func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) {
header := &requestHeaderFactory{
headers: make(map[string]string, 16),
}
startLineBytes, err := br.ReadSlice('\n')
if err != nil {
if err == bufio.ErrBufferFull {
var startLine string
startLine, err = br.ReadString('\n')
if err != nil {
return nil, err
}
startLineBytes = []byte(startLine)
} else {
return nil, err
}
}
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
header.startLine = make([]byte, len(startLineBytes))
copy(header.startLine, startLineBytes)
parts := bytes.Split(startLineBytes, []byte{' '})
if len(parts) < 3 {
return nil, fmt.Errorf("invalid request line")
}
header.method = string(parts[0])
header.path = string(parts[1])
header.version = string(parts[2])
for {
lineBytes, err := br.ReadSlice('\n')
if err != nil {
if err == bufio.ErrBufferFull {
var line string
line, err = br.ReadString('\n')
if err != nil {
return nil, err
}
lineBytes = []byte(line)
} else {
return nil, err
}
}
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
if len(lineBytes) == 0 {
break
}
colonIdx := bytes.IndexByte(lineBytes, ':')
if colonIdx == -1 {
continue
}
key := bytes.TrimSpace(lineBytes[:colonIdx])
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
header.headers[string(key)] = string(value)
}
return header, nil
}
func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager {
header := &responseHeaderFactory{
startLine: nil,
headers: make(map[string]string),
}
lines := bytes.Split(startLine, []byte("\r\n"))
if len(lines) == 0 {
return header
}
header.startLine = lines[0]
for _, h := range lines[1:] {
if len(h) == 0 {
continue
}
parts := bytes.SplitN(h, []byte(":"), 2)
if len(parts) < 2 {
continue
}
key := parts[0]
val := bytes.TrimSpace(parts[1])
header.headers[string(key)] = string(val)
}
return header
}
func (resp *responseHeaderFactory) Get(key string) string {
return resp.headers[key]
}
func (resp *responseHeaderFactory) Set(key string, value string) {
resp.headers[key] = value
}
func (resp *responseHeaderFactory) Remove(key string) {
delete(resp.headers, key)
}
func (resp *responseHeaderFactory) Finalize() []byte {
var buf bytes.Buffer
buf.Write(resp.startLine)
buf.WriteString("\r\n")
for key, val := range resp.headers {
buf.WriteString(key)
buf.WriteString(": ")
buf.WriteString(val)
buf.WriteString("\r\n")
}
buf.WriteString("\r\n")
return buf.Bytes()
}
func (req *requestHeaderFactory) Get(key string) string {
val, ok := req.headers[key]
if !ok {
return ""
}
return val
}
func (req *requestHeaderFactory) Set(key string, value string) {
req.headers[key] = value
}
func (req *requestHeaderFactory) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeaderFactory) GetMethod() string {
return req.method
}
func (req *requestHeaderFactory) GetPath() string {
return req.path
}
func (req *requestHeaderFactory) GetVersion() string {
return req.version
}
func (req *requestHeaderFactory) Finalize() []byte {
var buf bytes.Buffer
buf.Write(req.startLine)
buf.WriteString("\r\n")
for key, val := range req.headers {
buf.WriteString(key)
buf.WriteString(": ")
buf.WriteString(val)
buf.WriteString("\r\n")
}
buf.WriteString("\r\n")
return buf.Bytes()
}
-395
View File
@@ -1,395 +0,0 @@
package server
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"net"
"regexp"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/session"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type HTTPWriter interface {
io.Reader
io.Writer
GetRemoteAddr() net.Addr
GetWriter() io.Writer
AddResponseMiddleware(mw ResponseMiddleware)
AddRequestStartMiddleware(mw RequestMiddleware)
SetRequestHeader(header RequestHeaderManager)
GetRequestStartMiddleware() []RequestMiddleware
}
type customWriter struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
headerBuf []byte
buf []byte
respHeader ResponseHeaderManager
reqHeader RequestHeaderManager
respMW []ResponseMiddleware
reqStartMW []RequestMiddleware
reqEndMW []RequestMiddleware
}
func (cw *customWriter) GetRemoteAddr() net.Addr {
return cw.remoteAddr
}
func (cw *customWriter) GetWriter() io.Writer {
return cw.writer
}
func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) {
cw.respMW = append(cw.respMW, mw)
}
func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) {
cw.reqStartMW = append(cw.reqStartMW, mw)
}
func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) {
cw.reqHeader = header
}
func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware {
return cw.reqStartMW
}
func (cw *customWriter) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := cw.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
idx := bytes.Index(tmp, DELIMITER)
if idx == -1 {
copy(p, tmp)
if err != nil {
return read, err
}
return read, nil
}
header := tmp[:idx+len(DELIMITER)]
body := tmp[idx+len(DELIMITER):]
if !isHTTPHeader(header) {
copy(p, tmp)
return read, nil
}
for _, m := range cw.reqEndMW {
err = m.HandleRequest(cw.reqHeader)
if err != nil {
log.Printf("Error when applying request middleware: %v", err)
return 0, err
}
}
reqhf, err := NewRequestHeaderFactory(header)
if err != nil {
return 0, err
}
for _, m := range cw.reqStartMW {
if mwErr := m.HandleRequest(reqhf); mwErr != nil {
log.Printf("Error when applying request middleware: %v", mwErr)
return 0, mwErr
}
}
cw.reqHeader = reqhf
finalHeader := reqhf.Finalize()
combined := append(finalHeader, body...)
n := copy(p, combined)
return n, nil
}
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
return &customWriter{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}
func (cw *customWriter) Write(p []byte) (int, error) {
if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" {
cw.respHeader = nil
}
if cw.respHeader != nil {
n, err := cw.writer.Write(p)
if err != nil {
return n, err
}
return n, nil
}
cw.buf = append(cw.buf, p...)
idx := bytes.Index(cw.buf, DELIMITER)
if idx == -1 {
return len(p), nil
}
header := cw.buf[:idx+len(DELIMITER)]
body := cw.buf[idx+len(DELIMITER):]
if !isHTTPHeader(header) {
_, err := cw.writer.Write(cw.buf)
cw.buf = nil
if err != nil {
return 0, err
}
return len(p), nil
}
resphf := NewResponseHeaderFactory(header)
for _, m := range cw.respMW {
err := m.HandleResponse(resphf, body)
if err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return 0, err
}
}
header = resphf.Finalize()
cw.respHeader = resphf
_, err := cw.writer.Write(header)
if err != nil {
return 0, err
}
if len(body) > 0 {
_, err = cw.writer.Write(body)
if err != nil {
return 0, err
}
}
cw.buf = nil
return len(p), nil
}
var redirectTLS = false
type HTTPServer interface {
ListenAndServe() error
ListenAndServeTLS() error
handler(conn net.Conn)
handlerTLS(conn net.Conn)
}
type httpServer struct {
sessionRegistry session.Registry
}
func NewHTTPServer(sessionRegistry session.Registry) HTTPServer {
return &httpServer{sessionRegistry: sessionRegistry}
}
func (hs *httpServer) ListenAndServe() error {
httpPort := config.Getenv("HTTP_PORT", "8080")
listener, err := net.Listen("tcp", ":"+httpPort)
if err != nil {
return errors.New("Error listening: " + err.Error())
}
if config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" {
redirectTLS = true
}
go func() {
for {
var conn net.Conn
conn, err = listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
go hs.handler(conn)
}
}()
return nil
}
func (hs *httpServer) handler(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err)
return
}
return
}()
dstReader := bufio.NewReader(conn)
reqhf, err := NewRequestHeaderFactory(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
host := strings.Split(reqhf.Get("Host"), ".")
if len(host) < 1 {
_, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
if err != nil {
log.Println("Failed to write 400 Bad Request:", err)
return
}
return
}
slug := host[0]
if redirectTLS {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
if slug == "ping" {
_, err = conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
return
}
return
}
sshSession, err := hs.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
forwardRequest(cw, reqhf, sshSession)
return
}
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) {
payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
var channel ssh.Channel
var reqs <-chan *ssh.Request
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
return
}
channel = result.channel
reqs = result.reqs
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
return
}
go ssh.DiscardRequests(reqs)
fingerprintMiddleware := NewTunnelFingerprint()
forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr())
cw.AddResponseMiddleware(fingerprintMiddleware)
cw.AddRequestStartMiddleware(forwardedForMiddleware)
cw.SetRequestHeader(initialRequest)
for _, m := range cw.GetRequestStartMiddleware() {
if err := m.HandleRequest(initialRequest); err != nil {
log.Printf("Error handling request: %v", err)
return
}
}
_, err := channel.Write(initialRequest.Finalize())
if err != nil {
log.Printf("Failed to forward request: %v", err)
return
}
sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
return
}
-112
View File
@@ -1,112 +0,0 @@
package server
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"strings"
"tunnel_pls/internal/config"
"tunnel_pls/types"
)
func (hs *httpServer) ListenAndServeTLS() error {
domain := config.Getenv("DOMAIN", "localhost")
httpsPort := config.Getenv("HTTPS_PORT", "8443")
tlsConfig, err := NewTLSConfig(domain)
if err != nil {
return fmt.Errorf("failed to initialize TLS config: %w", err)
}
ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig)
if err != nil {
return err
}
go func() {
for {
var conn net.Conn
conn, err = ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Println("https server closed")
}
log.Printf("Error accepting connection: %v", err)
continue
}
go hs.handlerTLS(conn)
}
}()
return nil
}
func (hs *httpServer) handlerTLS(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil {
log.Printf("Error closing connection: %v", err)
return
}
return
}()
dstReader := bufio.NewReader(conn)
reqhf, err := NewRequestHeaderFactory(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
host := strings.Split(reqhf.Get("Host"), ".")
if len(host) < 1 {
_, err = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
if err != nil {
log.Println("Failed to write 400 Bad Request:", err)
return
}
return
}
slug := host[0]
if slug == "ping" {
_, err = conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
return
}
return
}
sshSession, err := hs.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
forwardRequest(cw, reqhf, sshSession)
return
}
-41
View File
@@ -1,41 +0,0 @@
package server
import (
"net"
)
type RequestMiddleware interface {
HandleRequest(header RequestHeaderManager) error
}
type ResponseMiddleware interface {
HandleResponse(header ResponseHeaderManager, body []byte) error
}
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
+37 -34
View File
@@ -9,53 +9,53 @@ import (
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client" "tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session" "tunnel_pls/session"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Server struct { type Server interface {
conn *net.Listener Start()
config *ssh.ServerConfig Close() error
sessionRegistry session.Registry }
grpcClient *client.Client type server struct {
config config.Config
sshPort string
sshListener net.Listener
sshConfig *ssh.ServerConfig
grpcClient client.Client
sessionRegistry registry.Registry
portRegistry port.Port
} }
func NewServer(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient *client.Client) (*Server, error) { func New(config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200"))) listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
if err != nil { if err != nil {
log.Fatalf("failed to listen on port 2200: %v", err)
return nil, err return nil, err
} }
HttpServer := NewHTTPServer(sessionRegistry) return &server{
err = HttpServer.ListenAndServe() config: config,
if err != nil { sshPort: sshPort,
log.Fatalf("failed to start http server: %v", err) sshListener: listener,
return nil, err sshConfig: sshConfig,
}
if config.Getenv("TLS_ENABLED", "false") == "true" {
err = HttpServer.ListenAndServeTLS()
if err != nil {
log.Fatalf("failed to start https server: %v", err)
return nil, err
}
}
return &Server{
conn: &listener,
config: sshConfig,
sessionRegistry: sessionRegistry,
grpcClient: grpcClient, grpcClient: grpcClient,
sessionRegistry: sessionRegistry,
portRegistry: portRegistry,
}, nil }, nil
} }
func (s *Server) Start() { func (s *server) Start() {
log.Println("SSH server is starting on port 2200...") log.Printf("SSH server is starting on port %s", s.sshPort)
for { for {
conn, err := (*s.conn).Accept() conn, err := s.sshListener.Accept()
if err != nil { if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Println("listener closed, stopping server")
return
}
log.Printf("failed to accept connection: %v", err) log.Printf("failed to accept connection: %v", err)
continue continue
} }
@@ -64,8 +64,12 @@ func (s *Server) Start() {
} }
} }
func (s *Server) handleConnection(conn net.Conn) { func (s *server) Close() error {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config) return s.sshListener.Close()
}
func (s *server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil { if err != nil {
log.Printf("failed to establish SSH connection: %v", err) log.Printf("failed to establish SSH connection: %v", err)
err = conn.Close() err = conn.Close()
@@ -90,9 +94,8 @@ func (s *Server) handleConnection(conn net.Conn) {
user = u user = u
cancel() cancel()
} }
log.Println("SSH connection established:", sshConn.User()) log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user) sshSession := session.New(s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
err = sshSession.Start() err = sshSession.Start()
if err != nil { if err != nil {
log.Printf("SSH session ended with error: %v", err) log.Printf("SSH session ended with error: %v", err)
+99 -124
View File
@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"log" "log"
"net" "net"
@@ -17,188 +18,162 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
var bufferPool = sync.Pool{ type Forwarder interface {
New: func() interface{} { SetType(tunnelType types.TunnelType)
bufSize := config.GetBufferSize() SetForwardedPort(port uint16)
return make([]byte, bufSize) SetListener(listener net.Listener)
}, Listener() net.Listener
TunnelType() types.TunnelType
ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel)
CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
WriteBadGatewayResponse(dst io.Writer)
Close() error
} }
type forwarder struct {
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
type Forwarder struct {
listener net.Listener listener net.Listener
tunnelType types.TunnelType tunnelType types.TunnelType
forwardedPort uint16 forwardedPort uint16
slugManager slug.Manager slug slug.Slug
lifecycle Lifecycle conn ssh.Conn
bufferPool sync.Pool
} }
func NewForwarder(slugManager slug.Manager) *Forwarder { func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
return &Forwarder{ return &forwarder{
listener: nil, listener: nil,
tunnelType: "", tunnelType: types.TunnelTypeUNKNOWN,
forwardedPort: 0, forwardedPort: 0,
slugManager: slugManager, slug: slug,
lifecycle: nil, conn: conn,
bufferPool: sync.Pool{
New: func() interface{} {
bufSize := config.BufferSize()
return make([]byte, bufSize)
},
},
} }
} }
type Lifecycle interface { func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
GetConnection() ssh.Conn buf := f.bufferPool.Get().([]byte)
defer f.bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
} }
type ForwardingController interface { func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
AcceptTCPConnections() type channelResult struct {
SetType(tunnelType types.TunnelType) channel ssh.Channel
GetTunnelType() types.TunnelType reqs <-chan *ssh.Request
GetForwardedPort() uint16 err error
SetForwardedPort(port uint16) }
SetListener(listener net.Listener) resultChan := make(chan channelResult, 1)
GetListener() net.Listener
Close() error
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
SetLifecycle(lifecycle Lifecycle)
CreateForwardedTCPIPPayload(origin net.Addr) []byte
WriteBadGatewayResponse(dst io.Writer)
}
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
f.lifecycle = lifecycle
}
func (f *Forwarder) AcceptTCPConnections() {
for {
conn, err := f.GetListener().Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Printf("Failed to set connection deadline: %v", err)
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
continue
}
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
go func() {
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
select { select {
case result := <-resultChan: case resultChan <- channelResult{channel, reqs, err}:
if result.err != nil { default:
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) if channel != nil {
if closeErr := conn.Close(); closeErr != nil { err = channel.Close()
log.Printf("Failed to close connection: %v", closeErr) if err != nil {
log.Printf("Failed to close unused channel: %v", err)
return
} }
continue go ssh.DiscardRequests(reqs)
}
if err := conn.SetDeadline(time.Time{}); err != nil {
log.Printf("Failed to clear connection deadline: %v", err)
}
go ssh.DiscardRequests(result.reqs)
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
} }
} }
}()
select {
case result := <-resultChan:
return result.channel, result.reqs, result.err
case <-time.After(5 * time.Second):
return nil, nil, errors.New("timeout opening forwarded-tcpip channel")
} }
} }
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) { func closeWriter(w io.Writer) error {
if cw, ok := w.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
if closer, ok := w.(io.Closer); ok {
return closer.Close()
}
return nil
}
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
var errs []error
_, err := f.copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
}
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err))
}
return errors.Join(errs...)
}
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
defer func() { defer func() {
_, err := io.Copy(io.Discard, src) _, err := io.Copy(io.Discard, src)
if err != nil { if err != nil {
log.Printf("Failed to discard connection: %v", err) log.Printf("Failed to discard connection: %v", err)
} }
err = src.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing source channel: %v", err)
}
if closer, ok := dst.(io.Closer); ok {
err = closer.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing destination connection: %v", err)
}
}
}() }()
log.Printf("Handling new forwarded connection from %s", remoteAddr)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
go func() { go func() {
defer wg.Done() defer wg.Done()
_, err := copyWithBuffer(dst, src) err := f.copyAndClose(dst, src, "src to dst")
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { if err != nil {
log.Printf("Error copying src→dst: %v", err) log.Println("Error during copy: ", err)
return
} }
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
_, err := copyWithBuffer(src, dst) err := f.copyAndClose(src, dst, "dst to src")
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { if err != nil {
log.Printf("Error copying dst→src: %v", err) log.Println("Error during copy: ", err)
return
} }
}() }()
wg.Wait() wg.Wait()
} }
func (f *Forwarder) SetType(tunnelType types.TunnelType) { func (f *forwarder) SetType(tunnelType types.TunnelType) {
f.tunnelType = tunnelType f.tunnelType = tunnelType
} }
func (f *Forwarder) GetTunnelType() types.TunnelType { func (f *forwarder) TunnelType() types.TunnelType {
return f.tunnelType return f.tunnelType
} }
func (f *Forwarder) GetForwardedPort() uint16 { func (f *forwarder) ForwardedPort() uint16 {
return f.forwardedPort return f.forwardedPort
} }
func (f *Forwarder) SetForwardedPort(port uint16) { func (f *forwarder) SetForwardedPort(port uint16) {
f.forwardedPort = port f.forwardedPort = port
} }
func (f *Forwarder) SetListener(listener net.Listener) { func (f *forwarder) SetListener(listener net.Listener) {
f.listener = listener f.listener = listener
} }
func (f *Forwarder) GetListener() net.Listener { func (f *forwarder) Listener() net.Listener {
return f.listener return f.listener
} }
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
_, err := dst.Write(types.BadGatewayResponse) _, err := dst.Write(types.BadGatewayResponse)
if err != nil { if err != nil {
log.Printf("failed to write Bad Gateway response: %v", err) log.Printf("failed to write Bad Gateway response: %v", err)
@@ -206,20 +181,20 @@ func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
} }
} }
func (f *Forwarder) Close() error { func (f *forwarder) Close() error {
if f.GetListener() != nil { if f.Listener() != nil {
return f.listener.Close() return f.listener.Close()
} }
return nil return nil
} }
func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
var buf bytes.Buffer var buf bytes.Buffer
host, originPort := parseAddr(origin.String()) host, originPort := parseAddr(origin.String())
writeSSHString(&buf, "localhost") writeSSHString(&buf, "localhost")
err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort())) err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
if err != nil { if err != nil {
log.Printf("Failed to write string to buffer: %v", err) log.Printf("Failed to write string to buffer: %v", err)
return nil return nil
-313
View File
@@ -1,313 +0,0 @@
package session
import (
"bytes"
"encoding/binary"
"fmt"
"log"
"net"
portUtil "tunnel_pls/internal/port"
"tunnel_pls/internal/random"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest {
switch req.Type {
case "shell", "pty-req":
err := req.Reply(true, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
case "window-change":
p := req.Payload
if len(p) < 16 {
log.Println("invalid window-change payload")
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
return
}
cols := binary.BigEndian.Uint32(p[0:4])
rows := binary.BigEndian.Uint32(p[4:8])
s.interaction.SetWH(int(cols), int(rows))
err := req.Reply(true, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
default:
log.Println("Unknown request type:", req.Type)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
}
}
}
func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected")
reader := bytes.NewReader(req.Payload)
addr, err := readSSHString(reader)
if err != nil {
log.Println("Failed to read address from payload:", err)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
var rawPortToBind uint32
if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
log.Println("Failed to read port from payload:", err)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
if rawPortToBind > 65535 {
log.Printf("Port %d is larger than allowed port of 65535", rawPortToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
portToBind := uint16(rawPortToBind)
if isBlockedPort(portToBind) {
log.Printf("Port %d is blocked or restricted", portToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
if portToBind == 80 || portToBind == 443 {
s.HandleHTTPForward(req, portToBind)
return
}
if portToBind == 0 {
unassign, success := portUtil.Default.GetUnassignedPort()
portToBind = unassign
if !success {
log.Println("No available port")
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
log.Printf("Port %d is already in use or restricted", portToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
err = portUtil.Default.SetPortStatus(portToBind, true)
if err != nil {
log.Println("Failed to set port status:", err)
return
}
s.HandleTCPForward(req, addr, portToBind)
}
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
slug := random.GenerateRandomString(20)
key := types.SessionKey{Id: slug, Type: types.HTTP}
if !s.registry.Register(key, s) {
log.Printf("Failed to register client with slug: %s", slug)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return
}
buf := new(bytes.Buffer)
err := binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil {
log.Println("Failed to write port to buffer:", err)
s.registry.Remove(key)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return
}
log.Printf("HTTP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes())
if err != nil {
log.Println("Failed to reply to request:", err)
s.registry.Remove(key)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return
}
s.forwarder.SetType(types.HTTP)
s.forwarder.SetForwardedPort(portToBind)
s.slugManager.Set(slug)
s.lifecycle.SetStatus(types.RUNNING)
s.interaction.Start()
}
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil {
log.Printf("Port %d is already in use or restricted", portToBind)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP}
if !s.registry.Register(key, s) {
log.Printf("Failed to register TCP client with id: %s", key.Id)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
if closeErr := listener.Close(); closeErr != nil {
log.Printf("Failed to close listener: %s", closeErr)
}
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
_ = s.lifecycle.Close()
return
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil {
log.Println("Failed to write port to buffer:", err)
s.registry.Remove(key)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
return
}
log.Printf("TCP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes())
if err != nil {
log.Println("Failed to reply to request:", err)
s.registry.Remove(key)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
return
}
s.forwarder.SetType(types.TCP)
s.forwarder.SetListener(listener)
s.forwarder.SetForwardedPort(portToBind)
s.slugManager.Set(key.Id)
s.lifecycle.SetStatus(types.RUNNING)
go s.forwarder.AcceptTCPConnections()
s.interaction.Start()
}
func readSSHString(reader *bytes.Reader) (string, error) {
var length uint32
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
return "", err
}
strBytes := make([]byte, length)
if _, err := reader.Read(strBytes); err != nil {
return "", err
}
return string(strBytes), nil
}
func isBlockedPort(port uint16) bool {
if port == 80 || port == 443 {
return false
}
if port < 1024 && port != 0 {
return true
}
for _, p := range blockedReservedPorts {
if p == port {
return true
}
}
return false
}
+83
View File
@@ -0,0 +1,83 @@
package interaction
import (
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) comingSoonUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.showingComingSoon = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
func (m *model) comingSoonView() string {
isCompact := shouldUseCompactLayout(m.width, 60)
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 3
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
messageBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
messageBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Background(lipgloss.Color("#1A1A2E")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(messageBoxWidth).
Align(lipgloss.Center)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Coming Soon"
} else {
title = "⏳ Coming Soon"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
var message string
if shouldUseCompactLayout(m.width, 50) {
message = "Coming soon!\nStay tuned."
} else {
message = "🚀 This feature is coming very soon!\n Stay tuned for updates."
}
b.WriteString(messageBoxStyle.Render(message))
b.WriteString("\n\n")
var helpText string
if shouldUseCompactLayout(m.width, 60) {
helpText = "Press any key..."
} else {
helpText = "This message will disappear in 5 seconds or press any key..."
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
+83
View File
@@ -0,0 +1,83 @@
package interaction
import (
"strings"
"time"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) commandsUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch {
case key.Matches(msg, m.keymap.quit):
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case msg.String() == "enter":
selectedItem := m.commandList.SelectedItem()
if selectedItem != nil {
item := selectedItem.(commandItem)
if item.name == "slug" {
m.showingCommands = false
m.editingSlug = true
m.slugInput.SetValue(m.interaction.slug.String())
m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} else if item.name == "tunnel-type" {
m.showingCommands = false
m.showingComingSoon = true
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
}
m.showingCommands = false
return m, nil
}
case msg.String() == "esc":
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
m.commandList, cmd = m.commandList.Update(msg)
return m, cmd
}
func (m *model) commandsView() string {
isCompact := shouldUseCompactLayout(m.width, 60)
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Commands"
} else {
title = "⚡ Commands"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
b.WriteString(m.commandList.View())
b.WriteString("\n")
var helpText string
if isCompact {
helpText = "↑/↓ Nav • Enter Select • Esc Cancel"
} else {
helpText = "↑/↓ Navigate • Enter Select • Esc Cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
+186
View File
@@ -0,0 +1,186 @@
package interaction
import (
"fmt"
"strings"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) dashboardUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch {
case key.Matches(msg, m.keymap.quit):
m.quitting = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink, tea.Quit)
case key.Matches(msg, m.keymap.command):
m.showingCommands = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
return m, nil
}
func (m *model) dashboardView() string {
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1)
subtitleStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Italic(true)
urlStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Underline(true).
Italic(true)
urlBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Bold(true).
Italic(true)
keyHintStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Bold(true)
var b strings.Builder
isCompact := shouldUseCompactLayout(m.width, 85)
var asciiArtMargin int
if isCompact {
asciiArtMargin = 0
} else {
asciiArtMargin = 1
}
asciiArtStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
MarginBottom(asciiArtMargin)
var asciiArt string
if shouldUseCompactLayout(m.width, 50) {
asciiArt = "TUNNEL PLS"
} else if isCompact {
asciiArt = `
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
} else {
asciiArt = `
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
}
b.WriteString(asciiArtStyle.Render(asciiArt))
b.WriteString("\n")
if !shouldUseCompactLayout(m.width, 60) {
b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • "))
b.WriteString(urlStyle.Render("https://fossy.my.id"))
b.WriteString("\n\n")
} else {
b.WriteString("\n")
}
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
responsiveInfoBox := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(boxMaxWidth)
authenticatedUser := m.interaction.user
userInfoStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Bold(true)
sectionHeaderStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Bold(true)
addressStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA"))
var infoContent string
if shouldUseCompactLayout(m.width, 70) {
infoContent = fmt.Sprintf("👤 %s\n\n%s\n%s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(fmt.Sprintf(" %s", urlBoxStyle.Render(m.getTunnelURL()))))
} else {
infoContent = fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(urlBoxStyle.Render(m.getTunnelURL())))
}
b.WriteString(responsiveInfoBox.Render(infoContent))
b.WriteString("\n")
var quickActionsTitle string
if shouldUseCompactLayout(m.width, 50) {
quickActionsTitle = "Actions"
} else if isCompact {
quickActionsTitle = "Quick Actions"
} else {
quickActionsTitle = "✨ Quick Actions"
}
b.WriteString(titleStyle.Render(quickActionsTitle))
b.WriteString("\n")
var featureMargin int
if isCompact {
featureMargin = 1
} else {
featureMargin = 2
}
compactFeatureStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginLeft(featureMargin)
var commandsText string
var quitText string
if shouldUseCompactLayout(m.width, 60) {
commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]"))
} else {
commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]"))
}
b.WriteString(compactFeatureStyle.Render(commandsText))
b.WriteString("\n")
b.WriteString(compactFeatureStyle.Render(quitText))
if !shouldUseCompactLayout(m.width, 70) {
b.WriteString("\n\n")
footerStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true)
b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings"))
}
return b.String()
}
+63 -640
View File
@@ -2,12 +2,8 @@ package interaction
import ( import (
"context" "context"
"fmt"
"log" "log"
"strings"
"time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/random"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
@@ -21,42 +17,57 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Lifecycle interface { type Interaction interface {
Close() error Mode() types.InteractiveMode
GetUser() string SetChannel(channel ssh.Channel)
SetMode(m types.InteractiveMode)
SetWH(w, h int)
Start()
Redraw()
Send(message string) error
} }
type SessionRegistry interface { type SessionRegistry interface {
Update(user string, oldKey, newKey types.SessionKey) error Update(user string, oldKey, newKey types.SessionKey) error
} }
type Controller interface {
SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle)
Start()
SetWH(w, h int)
Redraw()
SetSessionRegistry(registry SessionRegistry)
}
type Forwarder interface { type Forwarder interface {
Close() error Close() error
GetTunnelType() types.TunnelType TunnelType() types.TunnelType
GetForwardedPort() uint16 ForwardedPort() uint16
} }
type Interaction struct { type CloseFunc func() error
type interaction struct {
config config.Config
channel ssh.Channel channel ssh.Channel
slugManager slug.Manager slug slug.Slug
forwarder Forwarder forwarder Forwarder
lifecycle Lifecycle closeFunc CloseFunc
user string
sessionRegistry SessionRegistry sessionRegistry SessionRegistry
program *tea.Program program *tea.Program
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
mode types.InteractiveMode
} }
func (i *Interaction) SetWH(w, h int) { func (i *interaction) SetMode(m types.InteractiveMode) {
i.mode = m
}
func (i *interaction) Mode() types.InteractiveMode {
return i.mode
}
func (i *interaction) Send(message string) error {
if i.channel != nil {
_, err := i.channel.Write([]byte(message))
return err
}
return nil
}
func (i *interaction) SetWH(w, h int) {
if i.program != nil { if i.program != nil {
i.program.Send(tea.WindowSizeMsg{ i.program.Send(tea.WindowSizeMsg{
Width: w, Width: w,
@@ -65,72 +76,27 @@ func (i *Interaction) SetWH(w, h int) {
} }
} }
type commandItem struct { func New(config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
name string
desc string
}
type model struct {
domain string
protocol string
tunnelType types.TunnelType
port uint16
keymap keymap
help help.Model
quitting bool
showingCommands bool
editingSlug bool
showingComingSoon bool
commandList list.Model
slugInput textinput.Model
slugError string
interaction *Interaction
width int
height int
}
func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP {
return buildURL(m.protocol, m.interaction.slugManager.Get(), m.domain)
}
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
}
type keymap struct {
quit key.Binding
command key.Binding
random key.Binding
}
type tickMsg time.Time
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Interaction{ return &interaction{
config: config,
channel: nil, channel: nil,
slugManager: slugManager, slug: slug,
forwarder: forwarder, forwarder: forwarder,
lifecycle: nil, closeFunc: closeFunc,
sessionRegistry: nil, user: user,
sessionRegistry: sessionRegistry,
program: nil, program: nil,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
} }
} }
func (i *Interaction) SetSessionRegistry(registry SessionRegistry) { func (i *interaction) SetChannel(channel ssh.Channel) {
i.sessionRegistry = registry
}
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
i.lifecycle = lifecycle
}
func (i *Interaction) SetChannel(channel ssh.Channel) {
i.channel = channel i.channel = channel
} }
func (i *Interaction) Stop() { func (i *interaction) Stop() {
if i.cancel != nil { if i.cancel != nil {
i.cancel() i.cancel()
} }
@@ -140,47 +106,7 @@ func (i *Interaction) Stop() {
} }
} }
func getResponsiveWidth(screenWidth, padding, minWidth, maxWidth int) int {
width := screenWidth - padding
if width > maxWidth {
width = maxWidth
}
if width < minWidth {
width = minWidth
}
return width
}
func shouldUseCompactLayout(width int, threshold int) bool {
return width < threshold
}
func truncateString(s string, maxLength int) string {
if len(s) <= maxLength {
return s
}
if maxLength < 4 {
return s[:maxLength]
}
return s[:maxLength-3] + "..."
}
func (i commandItem) FilterValue() string { return i.name }
func (i commandItem) Title() string { return i.name }
func (i commandItem) Description() string { return i.desc }
func tickCmd(d time.Duration) tea.Cmd {
return tea.Tick(d, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func (m *model) Init() tea.Cmd {
return tea.Batch(textinput.Blink, tea.WindowSize())
}
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) { switch msg := msg.(type) {
case tickMsg: case tickMsg:
@@ -206,559 +132,62 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.KeyMsg: case tea.KeyMsg:
if m.showingComingSoon { if m.showingComingSoon {
m.showingComingSoon = false return m.comingSoonUpdate(msg)
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} }
if m.editingSlug { if m.editingSlug {
if m.tunnelType != types.HTTP { return m.slugUpdate(msg)
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
switch msg.String() {
case "esc":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter":
inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.GetUser(), types.SessionKey{
Id: m.interaction.slugManager.Get(),
Type: types.HTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.HTTP,
}); err != nil {
m.slugError = err.Error()
return m, nil
}
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "ctrl+c":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
default:
if key.Matches(msg, m.keymap.random) {
newSubdomain := generateRandomSubdomain()
m.slugInput.SetValue(newSubdomain)
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
return m, cmd
}
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
return m, cmd
}
} }
if m.showingCommands { if m.showingCommands {
switch { return m.commandsUpdate(msg)
case key.Matches(msg, m.keymap.quit):
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case msg.String() == "enter":
selectedItem := m.commandList.SelectedItem()
if selectedItem != nil {
item := selectedItem.(commandItem)
if item.name == "slug" {
m.showingCommands = false
m.editingSlug = true
m.slugInput.SetValue(m.interaction.slugManager.Get())
m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} else if item.name == "tunnel-type" {
m.showingCommands = false
m.showingComingSoon = true
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
}
m.showingCommands = false
return m, nil
}
case msg.String() == "esc":
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
m.commandList, cmd = m.commandList.Update(msg)
return m, cmd
} }
switch { return m.dashboardUpdate(msg)
case key.Matches(msg, m.keymap.quit):
m.quitting = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink, tea.Quit)
case key.Matches(msg, m.keymap.command):
m.showingCommands = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
} }
return m, nil return m, nil
} }
func (i *Interaction) Redraw() { func (i *interaction) Redraw() {
if i.program != nil { if i.program != nil {
i.program.Send(tea.ClearScreen()) i.program.Send(tea.ClearScreen())
} }
} }
func (m *model) helpView() string {
return "\n" + m.help.ShortHelpView([]key.Binding{
m.keymap.command,
m.keymap.quit,
})
}
func (m *model) View() string { func (m *model) View() string {
if m.quitting { if m.quitting {
return "" return ""
} }
if m.showingComingSoon { if m.showingComingSoon {
isCompact := shouldUseCompactLayout(m.width, 60) return m.comingSoonView()
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 3
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
messageBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
messageBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Background(lipgloss.Color("#1A1A2E")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(messageBoxWidth).
Align(lipgloss.Center)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Coming Soon"
} else {
title = "⏳ Coming Soon"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
var message string
if shouldUseCompactLayout(m.width, 50) {
message = "Coming soon!\nStay tuned."
} else {
message = "🚀 This feature is coming very soon!\n Stay tuned for updates."
}
b.WriteString(messageBoxStyle.Render(message))
b.WriteString("\n\n")
var helpText string
if shouldUseCompactLayout(m.width, 60) {
helpText = "Press any key..."
} else {
helpText = "This message will disappear in 5 seconds or press any key..."
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
} }
if m.editingSlug { if m.editingSlug {
isCompact := shouldUseCompactLayout(m.width, 70) return m.slugView()
isVeryCompact := shouldUseCompactLayout(m.width, 50)
var boxPadding int
var boxMargin int
if isVeryCompact {
boxPadding = 1
boxMargin = 1
} else if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
instructionStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginTop(1)
inputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
errorBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FF0000")).
Background(lipgloss.Color("#3D0000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1)
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
rulesBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1).
Width(rulesBoxWidth)
var b strings.Builder
var title string
if isVeryCompact {
title = "Edit Subdomain"
} else {
title = "🔧 Edit Subdomain"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
if m.tunnelType != types.HTTP {
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFA500")).
Background(lipgloss.Color("#3D2000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FFA500")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(warningBoxWidth)
var warningText string
if isVeryCompact {
warningText = "⚠️ TCP tunnels don't support custom subdomains."
} else {
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
}
b.WriteString(warningBoxStyle.Render(warningText))
b.WriteString("\n\n")
var helpText string
if isVeryCompact {
helpText = "Press any key to go back"
} else {
helpText = "Press Enter or Esc to go back"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
var rulesContent string
if isVeryCompact {
rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
} else if isCompact {
rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
} else {
rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
}
b.WriteString(rulesBoxStyle.Render(rulesContent))
b.WriteString("\n")
var instruction string
if isVeryCompact {
instruction = "Custom subdomain:"
} else {
instruction = "Enter your custom subdomain:"
}
b.WriteString(instructionStyle.Render(instruction))
b.WriteString("\n")
if m.slugError != "" {
errorInputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(1)
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
b.WriteString("\n")
} else {
b.WriteString(inputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
}
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
if len(previewURL) > previewWidth-10 {
previewURL = truncateString(previewURL, previewWidth-10)
}
previewStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Italic(true).
Width(previewWidth)
b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)))
b.WriteString("\n")
var helpText string
if isVeryCompact {
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
} else {
helpText = "Press Enter to save • CTRL+R for random • Esc to cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
} }
if m.showingCommands { if m.showingCommands {
isCompact := shouldUseCompactLayout(m.width, 60) return m.commandsView()
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Commands"
} else {
title = "⚡ Commands"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
b.WriteString(m.commandList.View())
b.WriteString("\n")
var helpText string
if isCompact {
helpText = "↑/↓ Nav • Enter Select • Esc Cancel"
} else {
helpText = "↑/↓ Navigate • Enter Select • Esc Cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
} }
titleStyle := lipgloss.NewStyle(). return m.dashboardView()
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1)
subtitleStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Italic(true)
urlStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Underline(true).
Italic(true)
urlBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Bold(true).
Italic(true)
keyHintStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Bold(true)
var b strings.Builder
isCompact := shouldUseCompactLayout(m.width, 85)
var asciiArtMargin int
if isCompact {
asciiArtMargin = 0
} else {
asciiArtMargin = 1
}
asciiArtStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
MarginBottom(asciiArtMargin)
var asciiArt string
if shouldUseCompactLayout(m.width, 50) {
asciiArt = "TUNNEL PLS"
} else if isCompact {
asciiArt = `
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
} else {
asciiArt = `
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
}
b.WriteString(asciiArtStyle.Render(asciiArt))
b.WriteString("\n")
if !shouldUseCompactLayout(m.width, 60) {
b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • "))
b.WriteString(urlStyle.Render("https://fossy.my.id"))
b.WriteString("\n\n")
} else {
b.WriteString("\n")
}
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
responsiveInfoBox := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(boxMaxWidth)
authenticatedUser := m.interaction.lifecycle.GetUser()
userInfoStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Bold(true)
sectionHeaderStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Bold(true)
addressStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA"))
var infoContent string
if shouldUseCompactLayout(m.width, 70) {
infoContent = fmt.Sprintf("👤 %s\n\n%s\n%s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(fmt.Sprintf(" %s", urlBoxStyle.Render(m.getTunnelURL()))))
} else {
infoContent = fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(urlBoxStyle.Render(m.getTunnelURL())))
}
b.WriteString(responsiveInfoBox.Render(infoContent))
b.WriteString("\n")
var quickActionsTitle string
if shouldUseCompactLayout(m.width, 50) {
quickActionsTitle = "Actions"
} else if isCompact {
quickActionsTitle = "Quick Actions"
} else {
quickActionsTitle = "✨ Quick Actions"
}
b.WriteString(titleStyle.Render(quickActionsTitle))
b.WriteString("\n")
var featureMargin int
if isCompact {
featureMargin = 1
} else {
featureMargin = 2
}
compactFeatureStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginLeft(featureMargin)
var commandsText string
var quitText string
if shouldUseCompactLayout(m.width, 60) {
commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]"))
} else {
commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]"))
}
b.WriteString(compactFeatureStyle.Render(commandsText))
b.WriteString("\n")
b.WriteString(compactFeatureStyle.Render(quitText))
if !shouldUseCompactLayout(m.width, 70) {
b.WriteString("\n\n")
footerStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true)
b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings"))
}
return b.String()
} }
func (i *Interaction) Start() { func (i *interaction) Start() {
if i.mode == types.InteractiveModeHEADLESS {
return
}
lipgloss.SetColorProfile(termenv.TrueColor) lipgloss.SetColorProfile(termenv.TrueColor)
domain := config.Getenv("DOMAIN", "localhost")
protocol := "http" protocol := "http"
if config.Getenv("TLS_ENABLED", "false") == "true" { if i.config.TLSEnabled() {
protocol = "https" protocol = "https"
} }
tunnelType := i.forwarder.GetTunnelType() tunnelType := i.forwarder.TunnelType()
port := i.forwarder.GetForwardedPort() port := i.forwarder.ForwardedPort()
items := []list.Item{ items := []list.Item{
commandItem{name: "slug", desc: "Set custom subdomain"}, commandItem{name: "slug", desc: "Set custom subdomain"},
@@ -781,7 +210,7 @@ func (i *Interaction) Start() {
ti.Width = 50 ti.Width = 50
m := &model{ m := &model{
domain: domain, domain: i.config.Domain(),
protocol: protocol, protocol: protocol,
tunnelType: tunnelType, tunnelType: tunnelType,
port: port, port: port,
@@ -822,15 +251,9 @@ func (i *Interaction) Start() {
} }
i.program.Kill() i.program.Kill()
i.program = nil i.program = nil
if err := m.interaction.lifecycle.Close(); err != nil { if i.closeFunc != nil {
log.Printf("Cannot close session: %s \n", err) if err := i.closeFunc(); err != nil {
log.Printf("Cannot close session: %s \n", err)
}
} }
} }
func buildURL(protocol, subdomain, domain string) string {
return fmt.Sprintf("%s://%s.%s", protocol, subdomain, domain)
}
func generateRandomSubdomain() string {
return random.GenerateRandomString(20)
}
+95
View File
@@ -0,0 +1,95 @@
package interaction
import (
"fmt"
"time"
"tunnel_pls/types"
"github.com/charmbracelet/bubbles/help"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/list"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
)
type commandItem struct {
name string
desc string
}
func (i commandItem) FilterValue() string { return i.name }
func (i commandItem) Title() string { return i.name }
func (i commandItem) Description() string { return i.desc }
type model struct {
domain string
protocol string
tunnelType types.TunnelType
port uint16
keymap keymap
help help.Model
quitting bool
showingCommands bool
editingSlug bool
showingComingSoon bool
commandList list.Model
slugInput textinput.Model
slugError string
interaction *interaction
width int
height int
}
func (m *model) getTunnelURL() string {
if m.tunnelType == types.TunnelTypeHTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
}
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
}
type keymap struct {
quit key.Binding
command key.Binding
random key.Binding
}
type tickMsg time.Time
func (m *model) Init() tea.Cmd {
return tea.Batch(textinput.Blink, tea.WindowSize())
}
func getResponsiveWidth(screenWidth, padding, minWidth, maxWidth int) int {
width := screenWidth - padding
if width > maxWidth {
width = maxWidth
}
if width < minWidth {
width = minWidth
}
return width
}
func shouldUseCompactLayout(width int, threshold int) bool {
return width < threshold
}
func truncateString(s string, maxLength int) string {
if len(s) <= maxLength {
return s
}
if maxLength < 4 {
return s[:maxLength]
}
return s[:maxLength-3] + "..."
}
func tickCmd(d time.Duration) tea.Cmd {
return tea.Tick(d, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func buildURL(protocol, subdomain, domain string) string {
return fmt.Sprintf("%s://%s.%s", protocol, subdomain, domain)
}
+224
View File
@@ -0,0 +1,224 @@
package interaction
import (
"fmt"
"strings"
"tunnel_pls/internal/random"
"tunnel_pls/types"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
if m.tunnelType != types.TunnelTypeHTTP {
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
switch msg.String() {
case "esc":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter":
inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
Id: m.interaction.slug.String(),
Type: types.TunnelTypeHTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.TunnelTypeHTTP,
}); err != nil {
m.slugError = err.Error()
return m, nil
}
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "ctrl+c":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
default:
if key.Matches(msg, m.keymap.random) {
newSubdomain, err := random.GenerateRandomString(20)
if err != nil {
return m, cmd
}
m.slugInput.SetValue(newSubdomain)
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
}
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
return m, cmd
}
}
func (m *model) slugView() string {
isCompact := shouldUseCompactLayout(m.width, 70)
isVeryCompact := shouldUseCompactLayout(m.width, 50)
var boxPadding int
var boxMargin int
if isVeryCompact {
boxPadding = 1
boxMargin = 1
} else if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
instructionStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginTop(1)
inputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
errorBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FF0000")).
Background(lipgloss.Color("#3D0000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1)
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
rulesBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1).
Width(rulesBoxWidth)
var b strings.Builder
var title string
if isVeryCompact {
title = "Edit Subdomain"
} else {
title = "🔧 Edit Subdomain"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
if m.tunnelType != types.TunnelTypeHTTP {
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFA500")).
Background(lipgloss.Color("#3D2000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FFA500")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(warningBoxWidth)
var warningText string
if isVeryCompact {
warningText = "⚠️ TCP tunnels don't support custom subdomains."
} else {
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
}
b.WriteString(warningBoxStyle.Render(warningText))
b.WriteString("\n\n")
var helpText string
if isVeryCompact {
helpText = "Press any key to go back"
} else {
helpText = "Press Enter or Esc to go back"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
var rulesContent string
if isVeryCompact {
rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
} else if isCompact {
rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
} else {
rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
}
b.WriteString(rulesBoxStyle.Render(rulesContent))
b.WriteString("\n")
var instruction string
if isVeryCompact {
instruction = "Custom subdomain:"
} else {
instruction = "Enter your custom subdomain:"
}
b.WriteString(instructionStyle.Render(instruction))
b.WriteString("\n")
if m.slugError != "" {
errorInputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(1)
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
b.WriteString("\n")
} else {
b.WriteString(inputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
}
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
if len(previewURL) > previewWidth-10 {
previewURL = truncateString(previewURL, previewWidth-10)
}
previewStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Italic(true).
Width(previewWidth)
b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)))
b.WriteString("\n")
var helpText string
if isVeryCompact {
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
} else {
helpText = "Press Enter to save • CTRL+R for random • Esc to cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
+65 -68
View File
@@ -2,8 +2,6 @@ package lifecycle
import ( import (
"errors" "errors"
"io"
"net"
"time" "time"
portUtil "tunnel_pls/internal/port" portUtil "tunnel_pls/internal/port"
@@ -15,115 +13,114 @@ import (
type Forwarder interface { type Forwarder interface {
Close() error Close() error
GetTunnelType() types.TunnelType TunnelType() types.TunnelType
GetForwardedPort() uint16 ForwardedPort() uint16
} }
type SessionRegistry interface { type SessionRegistry interface {
Remove(key types.SessionKey) Remove(key types.SessionKey)
} }
type Lifecycle struct { type lifecycle struct {
status types.Status status types.SessionStatus
conn ssh.Conn conn ssh.Conn
channel ssh.Channel channel ssh.Channel
forwarder Forwarder forwarder Forwarder
sessionRegistry SessionRegistry slug slug.Slug
slugManager slug.Manager
startedAt time.Time startedAt time.Time
sessionRegistry SessionRegistry
portRegistry portUtil.Port
user string user string
} }
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager, user string) *Lifecycle { func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
return &Lifecycle{ return &lifecycle{
status: types.INITIALIZING, status: types.SessionStatusINITIALIZING,
conn: conn, conn: conn,
channel: nil, channel: nil,
forwarder: forwarder, forwarder: forwarder,
slugManager: slugManager, slug: slugManager,
sessionRegistry: nil,
startedAt: time.Now(), startedAt: time.Now(),
sessionRegistry: sessionRegistry,
portRegistry: port,
user: user, user: user,
} }
} }
func (l *Lifecycle) SetSessionRegistry(registry SessionRegistry) { type Lifecycle interface {
l.sessionRegistry = registry Connection() ssh.Conn
} PortRegistry() portUtil.Port
User() string
type SessionLifecycle interface {
Close() error
SetStatus(status types.Status)
GetConnection() ssh.Conn
GetChannel() ssh.Channel
GetUser() string
SetChannel(channel ssh.Channel) SetChannel(channel ssh.Channel)
SetSessionRegistry(registry SessionRegistry) SetStatus(status types.SessionStatus)
IsActive() bool IsActive() bool
StartedAt() time.Time StartedAt() time.Time
Close() error
} }
func (l *Lifecycle) GetUser() string { func (l *lifecycle) PortRegistry() portUtil.Port {
return l.portRegistry
}
func (l *lifecycle) User() string {
return l.user return l.user
} }
func (l *Lifecycle) GetChannel() ssh.Channel { func (l *lifecycle) SetChannel(channel ssh.Channel) {
return l.channel
}
func (l *Lifecycle) SetChannel(channel ssh.Channel) {
l.channel = channel l.channel = channel
} }
func (l *Lifecycle) GetConnection() ssh.Conn { func (l *lifecycle) Connection() ssh.Conn {
return l.conn return l.conn
} }
func (l *Lifecycle) SetStatus(status types.Status) { func (l *lifecycle) SetStatus(status types.SessionStatus) {
l.status = status l.status = status
if status == types.RUNNING && l.startedAt.IsZero() { if status == types.SessionStatusRUNNING && l.startedAt.IsZero() {
l.startedAt = time.Now() l.startedAt = time.Now()
} }
} }
func (l *Lifecycle) Close() error { func closeIfNotNil(c interface{ Close() error }) error {
err := l.forwarder.Close() if c != nil {
if err != nil && !errors.Is(err, net.ErrClosed) { return c.Close()
return err
} }
if l.channel != nil {
err := l.channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
return err
}
}
if l.conn != nil {
err := l.conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
}
}
clientSlug := l.slugManager.Get()
if clientSlug != "" && l.sessionRegistry.Remove != nil {
key := types.SessionKey{Id: clientSlug, Type: l.forwarder.GetTunnelType()}
l.sessionRegistry.Remove(key)
}
if l.forwarder.GetTunnelType() == types.TCP {
err = portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
if err != nil {
return err
}
}
return nil return nil
} }
func (l *Lifecycle) IsActive() bool { func (l *lifecycle) Close() error {
return l.status == types.RUNNING var errs []error
tunnelType := l.forwarder.TunnelType()
if err := closeIfNotNil(l.channel); err != nil {
errs = append(errs, err)
}
if err := closeIfNotNil(l.conn); err != nil {
errs = append(errs, err)
}
clientSlug := l.slug.String()
key := types.SessionKey{
Id: clientSlug,
Type: tunnelType,
}
l.sessionRegistry.Remove(key)
if tunnelType == types.TunnelTypeTCP {
if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil {
errs = append(errs, err)
}
if err := l.forwarder.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
} }
func (l *Lifecycle) StartedAt() time.Time { func (l *lifecycle) IsActive() bool {
return l.status == types.SessionStatusRUNNING
}
func (l *lifecycle) StartedAt() time.Time {
return l.startedAt return l.startedAt
} }
+350 -70
View File
@@ -1,116 +1,185 @@
package session package session
import ( import (
"bytes"
"encoding/binary"
"errors"
"fmt" "fmt"
"io"
"log" "log"
"net"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
portUtil "tunnel_pls/internal/port"
"tunnel_pls/internal/random"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/session/forwarder" "tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction" "tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle" "tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
type Session interface { type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request) HandleGlobalRequest(ch <-chan *ssh.Request) error
HandleTCPIPForward(req *ssh.Request) HandleTCPIPForward(req *ssh.Request) error
HandleHTTPForward(req *ssh.Request, port uint16) HandleHTTPForward(req *ssh.Request, port uint16) error
HandleTCPForward(req *ssh.Request, addr string, port uint16) HandleTCPForward(req *ssh.Request, addr string, port uint16) error
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *types.Detail
Start() error
} }
type SSHSession struct { type session struct {
initialReq <-chan *ssh.Request config config.Config
sshReqChannel <-chan ssh.NewChannel initialReq <-chan *ssh.Request
lifecycle lifecycle.SessionLifecycle sshChan <-chan ssh.NewChannel
interaction interaction.Controller lifecycle lifecycle.Lifecycle
forwarder forwarder.ForwardingController interaction interaction.Interaction
slugManager slug.Manager forwarder forwarder.Forwarder
registry Registry slug slug.Slug
registry registry.Registry
} }
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle { var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
return s.lifecycle
}
func (s *SSHSession) GetInteraction() interaction.Controller { func New(config config.Config, conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
return s.interaction slugManager := slug.New()
} forwarderManager := forwarder.New(config, slugManager, conn)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
interactionManager := interaction.New(config, slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
func (s *SSHSession) GetForwarder() forwarder.ForwardingController { return &session{
return s.forwarder config: config,
} initialReq: initialReq,
sshChan: sshChan,
func (s *SSHSession) GetSlugManager() slug.Manager { lifecycle: lifecycleManager,
return s.slugManager interaction: interactionManager,
} forwarder: forwarderManager,
slug: slugManager,
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) *SSHSession { registry: sessionRegistry,
slugManager := slug.NewManager()
forwarderManager := forwarder.NewForwarder(slugManager)
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager, user)
interactionManager.SetLifecycle(lifecycleManager)
forwarderManager.SetLifecycle(lifecycleManager)
interactionManager.SetSessionRegistry(sessionRegistry)
lifecycleManager.SetSessionRegistry(sessionRegistry)
return &SSHSession{
initialReq: forwardingReq,
sshReqChannel: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slugManager: slugManager,
registry: sessionRegistry,
} }
} }
type Detail struct { func (s *session) Lifecycle() lifecycle.Lifecycle {
ForwardingType string `json:"forwarding_type,omitempty"` return s.lifecycle
Slug string `json:"slug,omitempty"`
UserID string `json:"user_id,omitempty"`
Active bool `json:"active,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
} }
func (s *SSHSession) Detail() Detail { func (s *session) Interaction() interaction.Interaction {
return Detail{ return s.interaction
ForwardingType: string(s.forwarder.GetTunnelType()), }
Slug: s.slugManager.Get(),
UserID: s.lifecycle.GetUser(), func (s *session) Forwarder() forwarder.Forwarder {
return s.forwarder
}
func (s *session) Slug() slug.Slug {
return s.slug
}
func (s *session) Detail() *types.Detail {
tunnelTypeMap := map[types.TunnelType]string{
types.TunnelTypeHTTP: "TunnelTypeHTTP",
types.TunnelTypeTCP: "TunnelTypeTCP",
}
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
if !ok {
tunnelType = "TunnelTypeUNKNOWN"
}
return &types.Detail{
ForwardingType: tunnelType,
Slug: s.slug.String(),
UserID: s.lifecycle.User(),
Active: s.lifecycle.IsActive(), Active: s.lifecycle.IsActive(),
StartedAt: s.lifecycle.StartedAt(), StartedAt: s.lifecycle.StartedAt(),
} }
} }
func (s *SSHSession) Start() error { func (s *session) Start() error {
channel := <-s.sshReqChannel if err := s.setupSessionMode(); err != nil {
return err
}
tcpipReq := s.waitForTCPIPForward()
if tcpipReq == nil {
return s.handleMissingForwardRequest()
}
if s.shouldRejectUnauthorized() {
return s.denyForwardingRequest(tcpipReq, nil, nil, fmt.Sprintf("headless forwarding only allowed on node mode"))
}
if err := s.HandleTCPIPForward(tcpipReq); err != nil {
return err
}
s.interaction.Start()
return s.waitForSessionEnd()
}
func (s *session) setupSessionMode() error {
select {
case channel, ok := <-s.sshChan:
if !ok {
log.Println("Forwarding request channel closed")
return nil
}
return s.setupInteractiveMode(channel)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.InteractiveModeHEADLESS)
return nil
}
}
func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
ch, reqs, err := channel.Accept() ch, reqs, err := channel.Accept()
if err != nil { if err != nil {
log.Printf("failed to accept channel: %v", err) log.Printf("failed to accept channel: %v", err)
return err return err
} }
go s.HandleGlobalRequest(reqs)
tcpipReq := s.waitForTCPIPForward() go func() {
if tcpipReq == nil { err = s.HandleGlobalRequest(reqs)
_, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))))
if err != nil { if err != nil {
return err log.Printf("global request handler error: %v", err)
} }
if err := s.lifecycle.Close(); err != nil { }()
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("no forwarding Request")
}
s.lifecycle.SetChannel(ch) s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch) s.interaction.SetChannel(ch)
s.interaction.SetMode(types.InteractiveModeINTERACTIVE)
s.HandleTCPIPForward(tcpipReq) return nil
}
func (s *session) handleMissingForwardRequest() error {
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
if err != nil {
return err
}
if err = s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("no forwarding Request")
}
func (s *session) shouldRejectUnauthorized() bool {
return s.interaction.Mode() == types.InteractiveModeHEADLESS &&
s.config.Mode() == types.ServerModeSTANDALONE &&
s.lifecycle.User() == "UNAUTHORIZED"
}
func (s *session) waitForSessionEnd() error {
if err := s.lifecycle.Connection().Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("ssh connection closed with error: %v", err)
}
if err := s.lifecycle.Close(); err != nil { if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
@@ -119,7 +188,7 @@ func (s *SSHSession) Start() error {
return nil return nil
} }
func (s *SSHSession) waitForTCPIPForward() *ssh.Request { func (s *session) waitForTCPIPForward() *ssh.Request {
select { select {
case req, ok := <-s.initialReq: case req, ok := <-s.initialReq:
if !ok { if !ok {
@@ -139,3 +208,214 @@ func (s *SSHSession) waitForTCPIPForward() *ssh.Request {
return nil return nil
} }
} }
func (s *session) handleWindowChange(req *ssh.Request) error {
p := req.Payload
if len(p) < 16 {
log.Println("invalid window-change payload")
return req.Reply(false, nil)
}
cols := binary.BigEndian.Uint32(p[0:4])
rows := binary.BigEndian.Uint32(p[4:8])
s.interaction.SetWH(int(cols), int(rows))
return req.Reply(true, nil)
}
func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
for req := range GlobalRequest {
switch req.Type {
case "shell", "pty-req":
err := req.Reply(true, nil)
if err != nil {
return err
}
case "window-change":
if err := s.handleWindowChange(req); err != nil {
return err
}
default:
log.Println("Unknown request type:", req.Type)
err := req.Reply(false, nil)
if err != nil {
return err
}
}
}
return nil
}
func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) {
address, err = readSSHString(payloadReader)
if err != nil {
return "", 0, err
}
var rawPortToBind uint32
if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil {
return "", 0, err
}
if rawPortToBind > 65535 {
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
}
port = uint16(rawPortToBind)
if isBlockedPort(port) {
return "", 0, fmt.Errorf("port is block")
}
if port == 0 {
unassigned, ok := s.lifecycle.PortRegistry().Unassigned()
if !ok {
return "", 0, fmt.Errorf("no available port")
}
return address, unassigned, err
}
return address, port, err
}
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
var errs []error
if key != nil {
s.registry.Remove(*key)
}
if listener != nil {
if err := listener.Close(); err != nil {
errs = append(errs, fmt.Errorf("close listener: %w", err))
}
}
if err := req.Reply(false, nil); err != nil {
errs = append(errs, fmt.Errorf("reply request: %w", err))
}
if err := s.lifecycle.Close(); err != nil {
errs = append(errs, fmt.Errorf("close session: %w", err))
}
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
return errors.Join(errs...)
}
func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) {
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(port))
if err != nil {
return err
}
err = req.Reply(true, buf.Bytes())
if err != nil {
return err
}
return nil
}
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
err := s.approveForwardingRequest(req, portToBind)
if err != nil {
return err
}
s.forwarder.SetType(tunnelType)
s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(slug)
s.lifecycle.SetStatus(types.SessionStatusRUNNING)
if listener != nil {
s.forwarder.SetListener(listener)
}
return nil
}
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
reader := bytes.NewReader(req.Payload)
address, port, err := s.parseForwardPayload(reader)
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
}
switch port {
case 80, 443:
return s.HandleHTTPForward(req, port)
default:
return s.HandleTCPForward(req, address, port)
}
}
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
randomString, err := random.GenerateRandomString(20)
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
}
key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP}
if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
}
err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id)
if err != nil {
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
}
return nil
}
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
}
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
listener, err := tcpServer.Listen()
if err != nil {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id))
}
err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id)
if err != nil {
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
}
go func() {
err = tcpServer.Serve(listener)
if err != nil {
log.Printf("Failed serving tcp server: %s\n", err)
}
}()
return nil
}
func readSSHString(reader io.Reader) (string, error) {
var length uint32
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
return "", err
}
strBytes := make([]byte, length)
if _, err := reader.Read(strBytes); err != nil {
return "", err
}
return string(strBytes), nil
}
func isBlockedPort(port uint16) bool {
if port == 80 || port == 443 {
return false
}
if port < 1024 && port != 0 {
return true
}
for _, p := range blockedReservedPorts {
if p == port {
return true
}
}
return false
}
+7 -7
View File
@@ -1,24 +1,24 @@
package slug package slug
type Manager interface { type Slug interface {
Get() string String() string
Set(slug string) Set(slug string)
} }
type manager struct { type slug struct {
slug string slug string
} }
func NewManager() Manager { func New() Slug {
return &manager{ return &slug{
slug: "", slug: "",
} }
} }
func (s *manager) Get() string { func (s *slug) String() string {
return s.slug return s.slug
} }
func (s *manager) Set(slug string) { func (s *slug) Set(slug string) {
s.slug = slug s.slug = slug
} }
+1
View File
@@ -0,0 +1 @@
sonar.projectKey=tunnel-please
+31 -8
View File
@@ -1,19 +1,34 @@
package types package types
type Status string import "time"
type SessionStatus int
const ( const (
INITIALIZING Status = "INITIALIZING" SessionStatusINITIALIZING SessionStatus = iota
RUNNING Status = "RUNNING" SessionStatusRUNNING
SETUP Status = "SETUP"
) )
type TunnelType string type InteractiveMode int
const ( const (
UNKNOWN TunnelType = "UNKNOWN" InteractiveModeINTERACTIVE InteractiveMode = iota + 1
HTTP TunnelType = "HTTP" InteractiveModeHEADLESS
TCP TunnelType = "TCP" )
type TunnelType int
const (
TunnelTypeUNKNOWN TunnelType = iota
TunnelTypeHTTP
TunnelTypeTCP
)
type ServerMode int
const (
ServerModeSTANDALONE = iota + 1
ServerModeNODE
) )
type SessionKey struct { type SessionKey struct {
@@ -21,6 +36,14 @@ type SessionKey struct {
Type TunnelType Type TunnelType
} }
type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"`
UserID string `json:"user_id,omitempty"`
Active bool `json:"active,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
}
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" + var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
"Content-Length: 11\r\n" + "Content-Length: 11\r\n" +
"Content-Type: text/plain\r\n\r\n" + "Content-Type: text/plain\r\n\r\n" +