145 Commits

Author SHA1 Message Date
8229879db8 Merge pull request 'chore(deps): update golang docker tag to v1.25.7' (#82) from renovate/golang-1.x into main
SonarQube Scan / SonarQube Trigger (push) Successful in 2m13s
2026-02-05 02:03:49 +07:00
7015e7f4de chore(deps): update golang docker tag to v1.25.7
Tests / Run Tests (pull_request) Successful in 1m11s
2026-02-04 19:03:47 +00:00
03c6b44fa2 Merge pull request 'fix(deps): update module github.com/charmbracelet/bubbles to v0.21.1' (#81) from renovate/github.com-charmbracelet-bubbles-0.x into main
SonarQube Scan / SonarQube Trigger (push) Successful in 3m15s
2026-02-03 15:05:34 +07:00
3af3fdbc9c fix(deps): update module github.com/charmbracelet/bubbles to v0.21.1
Tests / Run Tests (pull_request) Successful in 1m18s
2026-02-03 08:05:29 +00:00
6dc4bb58ea Merge pull request 'chore(deps): update actions/checkout action to v6' (#80) from renovate/actions-checkout-6.x into main
SonarQube Scan / SonarQube Trigger (push) Successful in 4m25s
Reviewed-on: #80
2026-01-28 01:16:08 +07:00
bd2b843e5d chore(deps): update actions/checkout action to v6
Tests / Run Tests (pull_request) Successful in 1m9s
2026-01-27 18:11:54 +00:00
5b05723e93 ci: refactor workflows for SonarQube, tag-only Docker builds, and global testing
SonarQube Scan / SonarQube Trigger (push) Successful in 4m41s
Docker Build and Push / Run Tests (push) Successful in 1m59s
Docker Build and Push / Build and Push Docker Image (push) Successful in 8m22s
- Run SonarQube scans only on main, staging, and feat/* branches
- Build and push Docker images only on semantic version tags
- Add test job that runs on all events
2026-01-28 01:06:29 +07:00
22ad935299 Merge pull request 'chore(deps): update actions/checkout action to v6' (#75) from renovate/actions-checkout-6.x into main
SonarQube Scan / SonarQube Trigger (push) Successful in 6m25s
Reviewed-on: #75
2026-01-27 18:36:31 +07:00
ebd915e18e chore(deps): update actions/checkout action to v6
SonarQube Scan / SonarQube Trigger (pull_request) Has been cancelled
SonarQube Scan / SonarQube Trigger (push) Has been cancelled
2026-01-27 11:35:15 +00:00
728691d119 Update .gitea/workflows/sonarqube.yml
SonarQube Scan / SonarQube Trigger (push) Has been cancelled
2026-01-27 18:31:10 +07:00
1344afd1b2 Merge pull request 'fix(deps): update module github.com/stretchr/testify to v1.11.1' (#79) from renovate/github.com-stretchr-testify-1.x into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
SonarQube Scan / SonarQube Trigger (push) Has been cancelled
Docker Build and Push / build-and-push-branches (push) Has been cancelled
2026-01-27 18:19:53 +07:00
4cbee5079c fix(deps): update module github.com/stretchr/testify to v1.11.1
SonarQube Scan / SonarQube Trigger (pull_request) Has been cancelled
SonarQube Scan / SonarQube Trigger (push) Has been cancelled
2026-01-27 11:19:47 +00:00
0b071dfde7 Merge pull request 'chore(deps): update dependency go to v1.25.6' (#78) from renovate/go-1.x into main
SonarQube Scan / SonarQube Trigger (push) Has been cancelled
2026-01-27 18:19:40 +07:00
6062c2e11d chore(deps): update dependency go to v1.25.6
SonarQube Scan / SonarQube Trigger (pull_request) Has been cancelled
SonarQube Scan / SonarQube Trigger (push) Has been cancelled
2026-01-27 11:19:34 +00:00
2a2d484e91 Merge pull request 'staging' (#77) from staging into main
SonarQube Scan / SonarQube Trigger (push) Successful in 6m4s
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 23m12s
Reviewed-on: #77
2026-01-27 18:08:36 +07:00
9377233515 feat(testing): comprehensive test coverage and quality improvements (#76)
SonarQube Scan / SonarQube Trigger (push) Successful in 3m32s
Docker Build and Push / build-and-push-branches (push) Successful in 48m34s
Docker Build and Push / build-and-push-tags (push) Has been skipped
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 6m12s
- Added unit tests for all core components (interaction, forwarder, stream, lifecycle, session, config, transport, middleware, etc.)
- Migrated to Testify framework for testing
- Integrated SonarQube for code quality monitoring
- Reduced cognitive complexity across multiple modules
- Fixed buffer handling, serialization, and error handling issues
- Set up automated CI/CD pipeline with coverage reporting

Reviewed-on: #76
2026-01-27 16:36:40 +07:00
fab625e13a docs: show CI/CD status badge and mascot in README
SonarQube Scan / SonarQube Trigger (push) Successful in 3m32s
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 3m26s
2026-01-27 16:28:20 +07:00
1ed845bf2d test(interaction): add unit tests for interaction behavior 2026-01-27 16:28:20 +07:00
67378aabda refactor(dockerfile): split long ldflags line 2026-01-27 16:28:20 +07:00
a26d1672d9 refactor(interaction): reduce cognitive complexity and centralize color constants 2026-01-27 16:28:20 +07:00
7f44cc7bc0 fix: ensure proper buffer reuse with pointer handling in sync.Pool 2026-01-27 16:28:20 +07:00
a3f6baa6ae test: check and handle error for testing 2026-01-27 16:28:20 +07:00
6def82a095 ci: add project source and test path for sonarqube 2026-01-27 16:28:20 +07:00
354da27424 test(forwarder): add unit tests for forwarder behavior 2026-01-27 16:28:20 +07:00
ee1dc3c3cd chore(tests): migrate to Testify for mocking and assertions 2026-01-27 16:28:20 +07:00
65df01fee5 refactor(forwarder): remove CreateForwardedTCPIPPayload method
- OpenForwardedChannel now privately calls CreateForwardedTCPIPPayload
- Removed an unused function
2026-01-27 16:28:20 +07:00
79fd292a77 feat(http): add http header size limit for initial request 2026-01-27 16:28:20 +07:00
4041681be6 refactor(header): NewRequest to accept only []byte 2026-01-27 16:28:20 +07:00
2ee24c8d51 test(config): add test for keyloc and header size 2026-01-27 16:28:20 +07:00
384bb98f48 test(stream): migrate mocking to testify 2026-01-27 16:28:20 +07:00
9785a97973 refactor: remove duplicate channel management helpers from HTTP handler 2026-01-27 16:28:20 +07:00
b8c6359820 refactor: remove custom parsing functions and use ssh.Marshal/ssh.Unmarshal for serialization 2026-01-27 16:28:20 +07:00
8fee8bf92e test(server): add unit test for handleConnection 2026-01-27 16:28:20 +07:00
04c9ddbc13 test(lifecycle): add unit tests for lifecycle behavior 2026-01-27 16:28:20 +07:00
211745dc26 test(slug): add unit tests for slug behavior 2026-01-27 16:28:20 +07:00
09aa92a0ae fix: properly initialize tlsStoragePath in config load 2026-01-27 16:28:20 +07:00
1ed9f3631f fix: correct buffer pool usage to avoid type assertion error 2026-01-27 16:28:20 +07:00
bd826d6d06 refactor(transport): reduce cognitive complexity and clean up public API 2026-01-27 16:28:20 +07:00
2f5c44ff01 test(bootstrap): add unit tests for initial bootstrap behavior 2026-01-27 16:28:20 +07:00
d0e052524c refactor: decouple application startup logic from main 2026-01-27 16:28:20 +07:00
24b9872aa4 fix: corrected defer usage to pass buffer pointer 2026-01-27 16:28:20 +07:00
8b84373036 fix: remove unnecessary use of fmt.Sprintf 2026-01-27 16:28:20 +07:00
e796ab5328 fix: handle error return values for privateKeyFile.Close and pubKeyFile.Close 2026-01-27 16:28:20 +07:00
efdfc4ce95 chore: remove unused headerBuf variable 2026-01-27 16:28:20 +07:00
1dc929cc25 ci: sonarqube add linting 2026-01-27 16:28:20 +07:00
14abac6579 test(session): add unit tests for session behavior 2026-01-27 16:28:20 +07:00
21179da4b5 refactor(session): reduce function parameters 2026-01-27 16:28:20 +07:00
32f8be2891 test(version): add unit tests for version behavior 2026-01-27 16:28:20 +07:00
5af7af3139 test(client): add unit tests for grpc client behavior 2026-01-27 16:28:20 +07:00
f4848e9754 fix(client): reduce cognitive complexity and fix typo (go:S3776) 2026-01-27 16:28:20 +07:00
d2e508c8ef test(key): add unit tests for key behavior 2026-01-27 16:28:20 +07:00
5499b7d08a ci: update SonarQube action configuration 2026-01-27 16:28:20 +07:00
58f1fdabe1 test(server): add unit tests for server startup behavior 2026-01-27 16:28:20 +07:00
c1fb588cf4 test(config): add unit tests for config behavior 2026-01-27 16:28:20 +07:00
3029996773 test(stream): add unit tests for stream behavior
- Fix duplicating EOF error when closing SSH connection
- Add new SessionStatusCLOSED type
2026-01-27 16:28:20 +07:00
3fd179d32b test(header): add unit tests for header behavior 2026-01-27 16:28:20 +07:00
a598a10e94 update: exclude local test coverage 2026-01-27 16:28:20 +07:00
29cabe42d3 test(transport): add unit tests for transport behavior using Testify 2026-01-27 16:28:20 +07:00
e534972abc test(random): add unit tests for random behavior
- Added unit tests to cover random string generation and error handling.
- Introduced Random interface and random struct for better abstraction.
- Updated server, session, and interaction packages to require Random interface for dependency injection.
2026-01-27 16:28:20 +07:00
a55ff5f6ab test(port): add unit tests for port behavior 2026-01-27 16:28:20 +07:00
50b4127cb3 test(middleware): add unit tests for middleware behavior
- remove redundant check on registry.Update and check if slug exist before locking the mutex
- Update SonarQube action to not use Go cache when setting up Go
2026-01-27 16:28:20 +07:00
7e635721fb ci: automate Go tests and Sonar coverage reporting 2026-01-27 16:28:20 +07:00
016df9caee test(registry): add unit tests for registry behavior 2026-01-27 16:28:20 +07:00
d91eecb2a0 chore: Refactor and optimize project architecture
Docker Build and Push / build-and-push-tags (push) Has been skipped
SonarQube Scan / SonarQube Trigger (push) Successful in 54s
Docker Build and Push / build-and-push-branches (push) Successful in 12m17s
- Fix: Resolve goroutine deadlock on early connection close
- Refactor: Simplify Start() method, unify forwarding logic, and enhance HTTP handler modularity
- Improve: Connection handling, header parsing, and resource management
- Refactor: Centralize environment loading, enforce typed access, and cleanup config structure
- Enhance: SonarQube scan integration for CI
- Chore: Reorganize project layout and simplify lifecycle management
- Define reusable constants for registry errors

Reviewed-on: #74
2026-01-22 22:16:33 +07:00
961a905542 chore(restructure): refactor architecture, config, and lifecycle management
Docker Build and Push / build-and-push-tags (push) Has been skipped
SonarQube Scan / SonarQube Trigger (push) Successful in 44s
Docker Build and Push / build-and-push-branches (push) Successful in 11m16s
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 46s
- Reorganized internal packages and overall project structure
- Moved HTTP/HTTPS/TCP servers into the transport layer
- Decoupled server initialization from HTTP/HTTPS/TCP startup logic
- Separated HTTP parsing, streaming, middleware, and session registry concerns
- Refactored session and forwarder responsibilities for clearer ownership
- Centralized environment loading with validated, typed config access
- Made config immutable after initialization and normalized enum naming
- Improved resource lifecycle handling and error aggregation on shutdown
- Introduced reusable, package-level registry errors
- Added SonarQube scanning to CI pipeline

Reviewed-on: #73
2026-01-22 00:48:40 +07:00
634c8321ef refactor(registry): define reusable constant errors
SonarQube Scan / SonarQube Trigger (push) Successful in 52s
SonarQube Scan / SonarQube Trigger (pull_request) Successful in 46s
- Introduced package-level error variables in registry to replace repeated fmt.Errorf calls
- Added errors like ErrSessionNotFound, ErrSlugInUse, ErrInvalidSlug, ErrForbiddenSlug, ErrSlugChangeNotAllowed, and ErrSlugUnchanged
2026-01-22 00:39:28 +07:00
9f4c24a3f3 refactor(lifecycle): reorder resource closing and simplify Close()
SonarQube Scan / SonarQube Trigger (push) Successful in 53s
- Close channel and connection first, then remove session
- Close forwarded port and forwarder at the end for TCP tunnels
- Aggregate all errors using errors.Join instead of failing early
2026-01-21 21:59:59 +07:00
1408b80917 ci: add sonarqube scan
SonarQube Scan / SonarQube Trigger (push) Successful in 48s
2026-01-21 21:24:57 +07:00
2bc20dd991 refactor(config): centralize env loading and enforce typed access
- Centralize environment variable loading in config.MustLoad
- Parse and validate all env vars once at initialization
- Make config fields private and read-only
- Remove public Getenv usage in favor of typed accessors
- Improve validation and initialization order
- Normalize enum naming to be idiomatic and avoid constant collisions
2026-01-21 19:43:19 +07:00
1e12373359 chore(restructure): reorganize project layout
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 13m1s
- Reorganize internal packages and overall project structure
- Update imports and wiring to match the new layout
- Separate HTTP parsing and streaming from the server package
- Separate middleware from the server package
- Separate session registry from the session package
- Move HTTP, HTTPS, and TCP servers to the transport package
- Session package no longer starts the TCP server directly
- Server package no longer starts HTTP/HTTPS servers on initialization
- Forwarder no longer handles accepting TCP requests
- Move session details to the types package
- HTTP/HTTPS initialization is now the responsibility of main
2026-01-21 14:06:46 +07:00
9a4539cc02 refactor(httpheader): extract header parsing into dedicated package
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m19s
Moved HTTP header parsing and building logic from server package to internal/httpheader
2026-01-20 21:15:34 +07:00
e3ead4d52f refactor: optimize header parsing and remove factory naming
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m20s
- Remove factory naming
- Use direct byte indexing instead of bytes.TrimRight
- Extract parseStartLine and setRemainingHeaders helpers
2026-01-20 20:56:08 +07:00
aa1a465178 refactor(forwarder): improve connection handling and cleanup
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Has been cancelled
- Extract copyAndClose method for bidirectional data transfe
- Add closeWriter helper for graceful connection shutdown
- Add handleIncomingConnection helper
- Add openForwardedChannel helper
2026-01-20 19:01:15 +07:00
27f49879af refactor(server): enhance HTTP handler modularity and fix resource leak
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m43s
- Rename customWriter struct to httpWriter for clarity
- Add closeWriter field to properly close write side of connections
- Update all cw variable references to hw
- Merge handlerTLS into handler function to reduce code duplication
- Extract handler into smaller, focused methods
- Split Read/Write/forwardRequest into composable functions

Fixes resource leak where connections weren't properly closed on the
write side, matching the forwarder's CloseWrite() pattern.
2026-01-19 22:41:04 +07:00
adb0264bb5 refactor(session): simplify Start() and unify forwarding logic
- Extract helper functions from Start() for better code organization
- Eliminate duplication with finalizeForwarding() method
- Consolidate denial logic into denyForwardingRequest()
- Update all handler methods to return errors instead of logging internally
- Improve error handling consistency across all operations
2026-01-19 15:53:16 +07:00
8fb19af5a6 fix: resolve copy goroutine deadlock on early connection close
- Add proper CloseWrite handling to signal EOF to other goroutine
- Ensure both copy goroutines terminate when either side closes
- Prevent goroutine leaks for SSH forwarded-tcpip channels:
    - Use select with default when sending result to resultChan
    - Close unused SSH channels and discard requests if main goroutine has already timed out
2026-01-19 00:20:28 +07:00
41fdb5639c Merge pull request 'refactor: explicit initialization and dependency injection' (#70) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 9m49s
Reviewed-on: #70
2026-01-18 21:46:59 +07:00
44d224f491 refactor: explicit initialization and dependency injection
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 10m10s
- Replace init() with config.Load() function when loading env variables
- Inject portRegistry into session, server, and lifecycle structs
- Inject sessionRegistry directly into interaction and lifecycle
- Remove SetSessionRegistry function and global port variables
- Pass ssh.Conn directly to forwarder constructor instead of lifecycle interface
- Pass user and closeFunc callback to interaction constructor instead of lifecycle interface
- Eliminate circular dependencies between lifecycle, forwarder, and interaction
- Remove setter methods (SetLifecycle) from forwarder and interaction interfaces
2026-01-18 21:20:05 +07:00
9be0328e24 Merge pull request 'staging' (#69) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 9m30s
Reviewed-on: #69
2026-01-17 19:15:40 +07:00
2b9bca65d5 refactor(interaction): separate view and update logic into modular files
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 11m44s
- Extract slug editing logic to slug.go (slugView/slugUpdate)
- Extract commands menu logic to commands.go (commandsView/commandsUpdate)
- Extract coming soon modal to coming_soon.go (comingSoonView/comingSoonUpdate)
- Extract main dashboard logic to dashboard.go (dashboardView/dashboardUpdate)
- Create model.go for shared model struct and helper functions
- Replace math/rand with crypto/rand for random subdomain generation
- Remove legacy TLS cipher suite configuration
2026-01-17 17:33:10 +07:00
6587dc0f39 refactor(interaction): separate view and update logic into modular files
- Extract slug editing logic to slug.go (slugView/slugUpdate)
- Extract commands menu logic to commands.go (commandsView/commandsUpdate)
- Extract coming soon modal to coming_soon.go (comingSoonView/comingSoonUpdate)
- Extract main dashboard logic to dashboard.go (dashboardView/dashboardUpdate)
- Create model.go for shared model struct and helper functions
- Replace math/rand with crypto/rand for random subdomain generation
- Remove legacy TLS cipher suite configuration
2026-01-17 17:30:21 +07:00
f421781f44 Merge pull request 'refactor: convert structs to interfaces and rename accessors' (#68) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 10m34s
Reviewed-on: #68
2026-01-16 16:41:22 +07:00
6969d6823a Merge branch 'main' into staging
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 11m38s
2026-01-16 16:35:36 +07:00
1a04af8873 Merge branch 'main' into staging
Docker Build and Push / build-and-push-branches (push) Successful in 11m35s
Docker Build and Push / build-and-push-tags (push) Has been skipped
2026-01-16 16:28:39 +07:00
19135ceb42 refactor: convert structs to interfaces and rename accessors
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Has been cancelled
- Convert struct types to interfaces
- Rename getter and setter methods
- Add Close method to server interface
- Merge handler functionality into session file
- Handle lifecycle.Connection().Wait()
- fix panic on nil connection in SSH server
2026-01-16 15:25:31 +07:00
edb11dbc51 Merge pull request 'chore(deps): update golang docker tag to v1.25.6' (#67) from renovate/golang-1.x into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m34s
2026-01-16 05:01:06 +07:00
819f044275 chore(deps): update golang docker tag to v1.25.6 2026-01-15 22:01:02 +00:00
a7ebf2c5db Merge pull request 'fix(deps): update module golang.org/x/crypto to v0.47.0' (#66) from renovate/golang.org-x-crypto-0.x into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 10m34s
Reviewed-on: #66
2026-01-14 10:42:52 +00:00
64c1038f4b fix(deps): update module golang.org/x/crypto to v0.47.0 2026-01-14 10:41:47 +00:00
aafea49975 feat: integrate gRPC, session refactor, SSH headless support, and bug fixes
Docker Build and Push / build-and-push-tags (push) Successful in 11m34s
Docker Build and Push / build-and-push-branches (push) Has been skipped
- gRPC integration: slug edit handling, get sessions by user, and session requests from gRPC server
- Refactor gRPC client: simplify processEventStream and handle authenticated user info
- Session management improvements: use session key for registry, forwarder session termination, inject SessionRegistry interface
- SSH enhancements: add headless mode support for SSH -N connections
- Bug fixes:
  - prevent subdomain changes to already-in-use subdomains
  - fix startup order and environment variable keys
  - atomic ClaimPort() to prevent race conditions
- Refactors:
  - consolidate error handling
  - replace Get/Set patterns with idiomatic Go interfaces
  - change enums from string to int
- CI cleanup: remove renovate bot

Reviewed-on: #65
2026-01-14 10:16:43 +00:00
dbdf8094fa refactor: replace Get/Set patterns with idiomatic Go interfaces
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 13m4s
- rename constructors to New
- remove Get/Set-style accessors
- replace string-based enums with iota-backed types
2026-01-14 16:54:10 +07:00
ae3ed52d16 fix(port): add atomic ClaimPort() to prevent race condition
- Replace GetPortStatus/SetPortStatus calls with atomic ClaimPort() operation.
- Fixed a logic error when handling headless tunneling.
2026-01-14 16:51:50 +07:00
fb638636bf refactor: consolidate error handling with fail() function in session handlers
- Replace repetitive error handling code with fail() function in HandleGlobalRequest
- Standardize error response pattern across all handler methods
- Improve code maintainability and reduce duplication
2026-01-14 16:51:50 +07:00
da29df85b7 feat: add headless mode support for SSH -N connections
- use s.lifecycle.GetConnection().Wait() to block until SSH connection closes
- Prevent premature session closure in headless mode

In headless mode (ssh -N), there's no channel interaction to block on,
so the session would immediately return and close. Now blocking on
conn.Wait() keeps the session alive until the client disconnects.
2026-01-14 16:51:50 +07:00
8b0e08c629 fix(deps): update module github.com/caddyserver/certmagic to v0.25.1 2026-01-14 16:51:50 +07:00
f0804d6946 ci: remove renovate 2026-01-14 16:51:50 +07:00
09e526cd1e feat: add authenticated user info and restructure handleConnection
- Display authenticated username in welcome page information box
- Refactor handleConnection function for better structure and clarity
2026-01-14 16:51:50 +07:00
887ebf78b1 refactor(grpc/client): simplify processEventStream with per-event handlers
- Extract eventHandlers dispatch table
- Add per-event handlers: handleSlugChange, handleGetSessions, handleTerminateSession
- Introduce sendNode helper to centralize send/error handling and preserve connection-error propagation
- Add protoToTunnelType for tunnel-type validation
- Map unknown proto.TunnelType to types.UNKNOWN in protoToTunnelType and return a descriptive error
- Reduce boilerplate and improve readability of processEventStream
2026-01-14 16:51:50 +07:00
bef7a49f88 feat: implement forwarder session termination 2026-01-14 16:51:50 +07:00
17633b4e3c refactor: inject SessionRegistry interface instead of individual functions 2026-01-14 16:51:50 +07:00
f25d61d1d1 update: proto file to v1.3.0 2026-01-14 16:51:50 +07:00
8782b77b74 feat(session): use session key for registry 2026-01-14 16:51:50 +07:00
fc3cd886db fix: use correct environment variable key 2026-01-14 16:51:50 +07:00
b0da57db0d fix: startup order 2026-01-14 16:51:50 +07:00
0bd6eeadf3 feat: implement sessions request from grpc server 2026-01-14 16:51:50 +07:00
449f546e04 feat: implement sessions request from grpc server 2026-01-14 16:51:50 +07:00
4644420eee feat: implement get sessions by user 2026-01-14 16:51:50 +07:00
c9bf9e62bd feat(grpc): integrate slug edit handling 2026-01-14 16:51:50 +07:00
57d2136377 WIP: gRPC integration, initial implementation 2026-01-14 16:51:47 +07:00
8a34aaba80 WIP: gRPC integration, initial implementation 2026-01-14 16:51:35 +07:00
ff995a929e revert 01ddc76f7e
revert Merge pull request 'fix(deps): update module github.com/caddyserver/certmagic to v0.25.1' (#58) from renovate/github.com-caddyserver-certmagic-0.x into main
2026-01-14 16:51:35 +07:00
32ac9c1749 fix(deps): update module github.com/caddyserver/certmagic to v0.25.1
# Conflicts:
#	go.mod
2026-01-14 16:51:30 +07:00
07d9f3afe6 refactor: replace Get/Set patterns with idiomatic Go interfaces
Docker Build and Push / build-and-push-tags (push) Successful in 10m59s
Docker Build and Push / build-and-push-branches (push) Has been skipped
- rename constructors to New
- remove Get/Set-style accessors
- replace string-based enums with iota-backed types
2026-01-14 15:28:17 +07:00
e051a5b742 Merge pull request 'fix(deps): update module golang.org/x/crypto to v0.47.0' (#64) from renovate/golang.org-x-crypto-0.x into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 9m51s
renovate / renovate (push) Successful in 55s
2026-01-12 18:20:57 +00:00
d35228759c fix(deps): update module golang.org/x/crypto to v0.47.0 2026-01-12 18:20:53 +00:00
abd103b5ab fix(port): add atomic ClaimPort() to prevent race condition
Docker Build and Push / build-and-push-tags (push) Successful in 3m23s
Docker Build and Push / build-and-push-branches (push) Has been skipped
- Replace GetPortStatus/SetPortStatus calls with atomic ClaimPort() operation.
- Fixed a logic error when handling headless tunneling.
2026-01-12 18:25:35 +07:00
560c98b869 refactor: consolidate error handling with fail() function in session handlers
Docker Build and Push / build-and-push-tags (push) Successful in 3m21s
Docker Build and Push / build-and-push-branches (push) Has been skipped
- Replace repetitive error handling code with fail() function in HandleGlobalRequest
- Standardize error response pattern across all handler methods
- Improve code maintainability and reduce duplication
2026-01-12 14:42:42 +07:00
e1f5d73e03 feat: add headless mode support for SSH -N connections
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 3m3s
- use s.lifecycle.GetConnection().Wait() to block until SSH connection closes
- Prevent premature session closure in headless mode

In headless mode (ssh -N), there's no channel interaction to block on,
so the session would immediately return and close. Now blocking on
conn.Wait() keeps the session alive until the client disconnects.
2026-01-11 15:21:11 +07:00
19fd6d59d2 Merge pull request 'main' (#62) from main into staging
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 3m32s
Reviewed-on: #62
2026-01-09 12:15:30 +00:00
e3988b339f Merge pull request 'fix(deps): update module github.com/caddyserver/certmagic to v0.25.1' (#61) from renovate/github.com-caddyserver-certmagic-0.x into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 3m21s
Reviewed-on: #61
2026-01-09 12:15:05 +00:00
336948a397 fix(deps): update module github.com/caddyserver/certmagic to v0.25.1 2026-01-09 10:00:35 +00:00
50ae422de8 Merge pull request 'staging' (#60) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 3m20s
Reviewed-on: #60
2026-01-09 09:33:28 +00:00
8467ed555e revert 01ddc76f7e
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Has been cancelled
revert Merge pull request 'fix(deps): update module github.com/caddyserver/certmagic to v0.25.1' (#58) from renovate/github.com-caddyserver-certmagic-0.x into main
2026-01-09 09:33:04 +00:00
01ddc76f7e Merge pull request 'fix(deps): update module github.com/caddyserver/certmagic to v0.25.1' (#58) from renovate/github.com-caddyserver-certmagic-0.x into main
Docker Build and Push / build-and-push-branches (push) Waiting to run
Docker Build and Push / build-and-push-tags (push) Has been skipped
2026-01-09 09:30:23 +00:00
ffb3565ff5 fix(deps): update module github.com/caddyserver/certmagic to v0.25.1 2026-01-09 09:30:18 +00:00
6d700ef6dd Merge pull request 'feat/grpc-integration' (#59) from feat/grpc-integration into staging
Docker Build and Push / build-and-push-branches (push) Successful in 5m25s
Docker Build and Push / build-and-push-tags (push) Has been skipped
Reviewed-on: #59
2026-01-09 09:24:20 +00:00
b8acb6da4c ci: remove renovate
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Has been cancelled
2026-01-08 13:03:02 +07:00
6b4127f0ef feat: add authenticated user info and restructure handleConnection
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 3m21s
- Display authenticated username in welcome page information box
- Refactor handleConnection function for better structure and clarity
2026-01-07 23:07:02 +07:00
16d48ff906 refactor(grpc/client): simplify processEventStream with per-event handlers
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 3m20s
- Extract eventHandlers dispatch table
- Add per-event handlers: handleSlugChange, handleGetSessions, handleTerminateSession
- Introduce sendNode helper to centralize send/error handling and preserve connection-error propagation
- Add protoToTunnelType for tunnel-type validation
- Map unknown proto.TunnelType to types.UNKNOWN in protoToTunnelType and return a descriptive error
- Reduce boilerplate and improve readability of processEventStream
2026-01-06 20:14:56 +07:00
6213ff8a30 feat: implement forwarder session termination
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 3m36s
2026-01-06 18:32:48 +07:00
4ffaec9d9a refactor: inject SessionRegistry interface instead of individual functions
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 4m16s
2026-01-05 16:49:17 +07:00
6de0a618ee update: proto file to v1.3.0
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 4m0s
2026-01-05 00:55:51 +07:00
8cc70fa45e feat(session): use session key for registry 2026-01-05 00:50:42 +07:00
d666ae5545 fix: use correct environment variable key
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 4m1s
2026-01-04 18:21:34 +07:00
5edb3c8086 fix: startup order
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 3m51s
2026-01-04 15:19:03 +07:00
5b603d8317 feat: implement sessions request from grpc server
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 4m7s
2026-01-03 21:17:01 +07:00
5ceade81db Merge pull request 'staging' (#57) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 3m57s
renovate / renovate (push) Failing after 34s
Reviewed-on: #57
2026-01-03 13:07:49 +00:00
8fd9f8b567 feat: implement sessions request from grpc server
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Has been cancelled
2026-01-03 20:06:14 +07:00
30e84ac3b7 feat: implement get sessions by user 2026-01-02 22:58:54 +07:00
fd6ffc2500 feat(grpc): integrate slug edit handling 2026-01-02 18:27:48 +07:00
e1cd4ed981 WIP: gRPC integration, initial implementation 2026-01-01 21:03:17 +07:00
96d2b88f95 WIP: gRPC integration, initial implementation 2026-01-01 21:01:15 +07:00
8a456d2cde Merge pull request 'staging' (#55) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 5m50s
renovate / renovate (push) Successful in 35s
Reviewed-on: #55
2025-12-31 08:51:25 +00:00
8841230653 Merge pull request 'fix: prevent subdomain change to already-in-use subdomains' (#54) from staging into main
Docker Build and Push / build-and-push (push) Successful in 5m20s
renovate / renovate (push) Successful in 38s
Reviewed-on: #54
2025-12-30 12:42:05 +00:00
4d0a7deaf2 Merge pull request 'staging' (#53) from staging into main
Docker Build and Push / build-and-push (push) Successful in 3m33s
renovate / renovate (push) Successful in 22s
Reviewed-on: #53
2025-12-29 17:18:25 +00:00
81 changed files with 17990 additions and 3093 deletions
+40 -80
View File
@@ -2,24 +2,38 @@ name: Docker Build and Push
on:
push:
branches:
- main
- staging
tags:
- 'v*'
paths:
- '**.go'
- 'go.mod'
- 'go.sum'
- 'Dockerfile'
- 'Dockerfile.*'
- '.dockerignore'
- '.gitea/workflows/build.yml'
jobs:
build-and-push-branches:
test:
name: Run Tests
runs-on: ubuntu-latest
if: github.ref_type == 'branch'
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: 'stable'
cache: false
- name: Install dependencies
run: go mod download
- name: Run go vet
run: go vet ./...
- name: Run tests
run: go test -v -p 4 ./...
build-and-push:
name: Build and Push Docker Image
runs-on: ubuntu-latest
needs: test
steps:
- name: Checkout repository
@@ -28,64 +42,7 @@ jobs:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
registry: git.fossy.my.id
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Set version variables
id: vars
run: |
if [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "VERSION=dev-main" >> $GITHUB_OUTPUT
else
echo "VERSION=dev-staging" >> $GITHUB_OUTPUT
fi
echo "BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT
echo "COMMIT=${{ github.sha }}" >> $GITHUB_OUTPUT
- name: Build and push Docker image for main
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: |
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:latest
platforms: linux/amd64,linux/arm64
build-args: |
VERSION=${{ steps.vars.outputs.VERSION }}
BUILD_DATE=${{ steps.vars.outputs.BUILD_DATE }}
COMMIT=${{ steps.vars.outputs.COMMIT }}
if: github.ref == 'refs/heads/main'
- name: Build and push Docker image for staging
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: |
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:staging
platforms: linux/amd64,linux/arm64
build-args: |
VERSION=${{ steps.vars.outputs.VERSION }}
BUILD_DATE=${{ steps.vars.outputs.BUILD_DATE }}
COMMIT=${{ steps.vars.outputs.COMMIT }}
if: github.ref == 'refs/heads/staging'
build-and-push-tags:
runs-on: ubuntu-latest
if: github.ref_type == 'tag' && startsWith(github.ref, 'refs/tags/v')
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to Docker Hub
- name: Log in to Docker Registry
uses: docker/login-action@v3
with:
registry: git.fossy.my.id
@@ -103,32 +60,35 @@ jobs:
if echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.]+)?$'; then
MAJOR=$(echo "$VERSION" | cut -d. -f1)
MINOR=$(echo "$VERSION" | cut -d. -f2)
PATCH=$(echo "$VERSION" | cut -d. -f3 | cut -d- -f1)
echo "MAJOR=$MAJOR" >> $GITHUB_OUTPUT
echo "MINOR=$MINOR" >> $GITHUB_OUTPUT
echo "PATCH=$PATCH" >> $GITHUB_OUTPUT
if echo "$VERSION" | grep -q '-'; then
PRERELEASE_TAG=$(echo "$VERSION" | cut -d- -f2 | cut -d. -f1)
echo "IS_PRERELEASE=true" >> $GITHUB_OUTPUT
echo "ADDITIONAL_TAG=staging" >> $GITHUB_OUTPUT
echo "PRERELEASE_TAG=$PRERELEASE_TAG" >> $GITHUB_OUTPUT
else
echo "IS_PRERELEASE=false" >> $GITHUB_OUTPUT
echo "ADDITIONAL_TAG=latest" >> $GITHUB_OUTPUT
fi
else
echo "Invalid version format: $VERSION"
exit 1
fi
- name: Build and push Docker image for release
- name: Build and push Docker image (release)
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: |
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.VERSION }}
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:release
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.MAJOR }}.${{ steps.version.outputs.MINOR }}
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.MAJOR }}
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:${{ steps.version.outputs.ADDITIONAL_TAG }}
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:latest
platforms: linux/amd64,linux/arm64
build-args: |
VERSION=${{ steps.version.outputs.VERSION }}
@@ -136,17 +96,17 @@ jobs:
COMMIT=${{ steps.version.outputs.COMMIT }}
if: steps.version.outputs.IS_PRERELEASE == 'false'
- name: Build and push Docker image for pre-release
- name: Build and push Docker image (pre-release)
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: |
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:v${{ steps.version.outputs.VERSION }}
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:${{ steps.version.outputs.ADDITIONAL_TAG }}
git.fossy.my.id/${{ secrets.DOCKER_USERNAME }}/tunnel-please:staging
platforms: linux/amd64,linux/arm64
build-args: |
VERSION=${{ steps.version.outputs.VERSION }}
BUILD_DATE=${{ steps.version.outputs.BUILD_DATE }}
COMMIT=${{ steps.version.outputs.COMMIT }}
if: steps.version.outputs.IS_PRERELEASE == 'true'
if: steps.version.outputs.IS_PRERELEASE == 'true'
-21
View File
@@ -1,21 +0,0 @@
name: renovate
on:
schedule:
- cron: "0 0 * * *"
push:
branches:
- staging
jobs:
renovate:
runs-on: ubuntu-latest
container: git.fossy.my.id/renovate-clanker/renovate:latest
steps:
- uses: actions/checkout@v6
- run: renovate
env:
RENOVATE_CONFIG_FILE: ${{ gitea.workspace }}/renovate-config.js
LOG_LEVEL: "debug"
RENOVATE_TOKEN: ${{ secrets.RENOVATE_TOKEN }}
GITHUB_COM_TOKEN: ${{ secrets.COM_TOKEN }}
+60
View File
@@ -0,0 +1,60 @@
on:
push:
branches:
- main
- staging
- 'feat/**'
name: SonarQube Scan
jobs:
sonarqube:
name: SonarQube Trigger
runs-on: ubuntu-latest
steps:
- name: Checking out
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: 'stable'
cache: false
- name: Install dependencies
run: go mod tidy
- name: Run go vet
run: go vet ./... 2>&1 | tee vet-results.txt
- name: Run tests with coverage
run: |
go test ./... -v -p 4 -coverprofile=coverage
- name: Run GolangCI-Lint Analysis
uses: golangci/golangci-lint-action@v9
with:
skip-cache: true
version: v2.6
args: >
--issues-exit-code=0
--output.text.path=stdout
--output.checkstyle.path=golangci-lint-report.xml
- name: SonarQube Scan
uses: SonarSource/sonarqube-scan-action@v7.0.0
env:
SONAR_HOST_URL: ${{ secrets.SONARQUBE_HOST }}
SONAR_TOKEN: ${{ secrets.SONARQUBE_TOKEN }}
with:
args: >
-Dsonar.projectKey=tunnel-please
-Dsonar.go.coverage.reportPaths=coverage
-Dsonar.test.inclusions=**/*_test.go
-Dsonar.test.exclusions=**/vendor/**
-Dsonar.exclusions=**/*_test.go,**/vendor/**,**/golangci-lint-report.xml
-Dsonar.go.govet.reportPaths=vet-results.txt
-Dsonar.go.golangci-lint.reportPaths=golangci-lint-report.xml
-Dsonar.sources=./
-Dsonar.tests=./
+36
View File
@@ -0,0 +1,36 @@
name: Tests
on:
pull_request:
types: [opened, synchronize, reopened]
issue_comment:
types: [created]
jobs:
test:
name: Run Tests
runs-on: ubuntu-latest
if: |
github.event_name == 'pull_request' ||
(github.event_name == 'issue_comment' &&
github.event.issue.pull_request != null &&
contains(github.event.comment.body, '/retest'))
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: 'stable'
cache: false
- name: Install dependencies
run: go mod download
- name: Run go vet
run: go vet ./...
- name: Run tests
run: go test -v -p 4 ./...
Vendored
+3 -1
View File
@@ -4,4 +4,6 @@ id_rsa*
.env
tmp
certs
app
app
coverage
test-results.json
+5 -2
View File
@@ -1,4 +1,4 @@
FROM golang:1.25.5-alpine AS go_builder
FROM golang:1.25.7-alpine AS go_builder
ARG VERSION=dev
ARG BUILD_DATE=unknown
@@ -22,7 +22,10 @@ RUN --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \
CGO_ENABLED=0 GOOS=linux \
go build -trimpath \
-ldflags="-w -s -X tunnel_pls/version.Version=${VERSION} -X tunnel_pls/version.BuildDate=${BUILD_DATE} -X tunnel_pls/version.Commit=${COMMIT}" \
-ldflags="-w -s \
-X tunnel_pls/internal/version.Version=${VERSION} \
-X tunnel_pls/internal/version.BuildDate=${BUILD_DATE} \
-X tunnel_pls/internal/version.Commit=${COMMIT}" \
-o /app/tunnel_pls \
.
+40 -112
View File
@@ -1,6 +1,22 @@
<div align="center">
<img alt="gopher" title="gopher" src="./docs/images/gopher.png" width="325" />
# Tunnel Please
A lightweight SSH-based tunnel server written in Go that enables secure TCP and HTTP forwarding with an interactive terminal interface for managing connections and custom subdomains.
A lightweight SSH-based tunnel server
<br/><br/>
[![Coverage](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=coverage&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Lines of Code](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=ncloc&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Quality Gate Status](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=alert_status&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Security Issues](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_security_issues&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Maintainability Rating](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_maintainability_rating&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Reliability Rating](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_reliability_rating&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
[![Security Rating](https://sonar.fossy.my.id/api/project_badges/measure?project=tunnel-please&metric=software_quality_security_rating&token=sqb_0feaed756b943aa75a79499d45d6610c1620830d)](https://sonar.fossy.my.id/dashboard?id=tunnel-please)
</div>
## Features
@@ -17,104 +33,32 @@ A lightweight SSH-based tunnel server written in Go that enables secure TCP and
The following environment variables can be configured in the `.env` file:
| Variable | Description | Default | Required |
|----------|-------------|---------|----------|
| `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
| `PORT` | SSH server port | `2200` | No |
| `HTTP_PORT` | HTTP server port | `8080` | No |
| `HTTPS_PORT` | HTTPS server port | `8443` | No |
| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@<DOMAIN>` | No |
| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | - | Yes (if auto-cert) |
| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
| `CORS_LIST` | Comma-separated list of allowed CORS origins | - | No |
| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
| `PPROF_PORT` | Port for pprof server | `6060` | No |
| Variable | Description | Default | Required |
|---------------------|-----------------------------------------------------------------------------|-------------------------|---------------------|
| `DOMAIN` | Domain name for subdomain routing | `localhost` | No |
| `PORT` | SSH server port | `2200` | No |
| `HTTP_PORT` | HTTP server port | `8080` | No |
| `HTTPS_PORT` | HTTPS server port | `8443` | No |
| `KEY_LOC` | Path to the private key file | `certs/privkey.pem` | No |
| `TLS_ENABLED` | Enable TLS/HTTPS | `false` | No |
| `TLS_REDIRECT` | Redirect HTTP to HTTPS | `false` | No |
| `TLS_STORAGE_PATH` | Path to store TLS certificates | `certs/tls/` | No |
| `ACME_EMAIL` | Email for Let's Encrypt registration | `admin@<DOMAIN>` | No |
| `CF_API_TOKEN` | Cloudflare API token for DNS-01 challenge | `-` | Yes (if auto-cert) |
| `ACME_STAGING` | Use Let's Encrypt staging server | `false` | No |
| `CORS_LIST` | Comma-separated list of allowed CORS origins | `-` | No |
| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
| `MAX_HEADER_SIZE` | Maximum size of HTTP headers in bytes (4096-131072) | `4096` | No |
| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
| `PPROF_PORT` | Port for pprof server | `6060` | No |
| `MODE` | Runtime mode: `standalone` or `node` | `standalone` | No |
| `GRPC_ADDRESS` | gRPC server address/host used in `node` mode | `localhost` | No |
| `GRPC_PORT` | gRPC server port used in `node` mode | `8080` | No |
| `NODE_TOKEN` | Authentication token sent to controller in `node` mode | `-` | Yes (node mode) |
**Note:** All environment variables now use UPPERCASE naming. The application includes sensible defaults for all variables, so you can run it without a `.env` file for basic functionality.
### Automatic TLS Certificate Management
The server supports automatic TLS certificate generation and renewal using [CertMagic](https://github.com/caddyserver/certmagic) with Cloudflare DNS-01 challenge. This is required for wildcard certificate support (`*.yourdomain.com`).
**Certificate Storage:**
- TLS certificates are stored in `certs/tls/` (relative to application directory)
- User-provided certificates: `certs/tls/cert.pem` and `certs/tls/privkey.pem`
- CertMagic automatic certificates: `certs/tls/certmagic/`
- SSH keys are stored separately in `certs/ssh/`
**How it works:**
1. If user-provided certificates exist at `certs/tls/cert.pem` and `certs/tls/privkey.pem` and cover both `DOMAIN` and `*.DOMAIN`, they will be used
2. If certificates are missing, expired, expiring within 30 days, or don't cover the required domains, CertMagic will automatically obtain new certificates from Let's Encrypt
3. Certificates are automatically renewed before expiration
4. User-provided certificates support hot-reload (changes detected every 30 seconds)
**Cloudflare API Token Setup:**
To use automatic certificate generation, you need a Cloudflare API token with the following permissions:
1. Go to [Cloudflare Dashboard](https://dash.cloudflare.com/profile/api-tokens)
2. Click "Create Token"
3. Use "Create Custom Token" with these permissions:
- **Zone → Zone → Read** (for all zones or specific zone)
- **Zone → DNS → Edit** (for all zones or specific zone)
4. Copy the token and set it as `CF_API_TOKEN` environment variable
**Example configuration for automatic certificates:**
```env
DOMAIN=example.com
TLS_ENABLED=true
CF_API_TOKEN=your_cloudflare_api_token_here
ACME_EMAIL=admin@example.com
# ACME_STAGING=true # Uncomment for testing to avoid rate limits
```
### SSH Key Auto-Generation
The application will automatically generate a new 4096-bit RSA key pair at `certs/ssh/id_rsa` if it doesn't exist. This makes it easier to get started without manually creating SSH keys. SSH keys are stored separately from TLS certificates.
### Memory Optimization
The application uses a buffer pool with controlled buffer sizes to prevent excessive memory usage under high concurrent loads. The `BUFFER_SIZE` environment variable controls the size of buffers used for io.Copy operations:
- **Default:** 32768 bytes (32 KB) - Good balance for most scenarios
- **Minimum:** 4096 bytes (4 KB) - Lower memory usage, more CPU overhead
- **Maximum:** 1048576 bytes (1 MB) - Higher throughput, more memory usage
**Recommended settings based on load:**
- **Low traffic (<100 concurrent):** `BUFFER_SIZE=32768` (default)
- **High traffic (>100 concurrent):** `BUFFER_SIZE=16384` or `BUFFER_SIZE=8192`
- **Very high traffic (>1000 concurrent):** `BUFFER_SIZE=8192` or `BUFFER_SIZE=4096`
The buffer pool reuses buffers across connections, preventing memory fragmentation and reducing garbage collection pressure.
### Profiling with pprof
To enable profiling for performance analysis:
1. Set `PPROF_ENABLED=true` in your `.env` file
2. Optionally set `PPROF_PORT` to your desired port (default: 6060)
3. Access profiling data at `http://localhost:6060/debug/pprof/`
Common pprof endpoints:
- `/debug/pprof/` - Index page with available profiles
- `/debug/pprof/heap` - Memory allocation profile
- `/debug/pprof/goroutine` - Stack traces of all current goroutines
- `/debug/pprof/profile` - CPU profile (30-second sample by default)
- `/debug/pprof/trace` - Execution trace
Example usage with `go tool pprof`:
```bash
# Analyze CPU profile
go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
# Analyze memory heap
go tool pprof http://localhost:6060/debug/pprof/heap
```
## Docker Deployment
Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`.
@@ -193,22 +137,6 @@ docker-compose -f docker-compose.tcp.yml up -d
docker-compose -f docker-compose.root.yml down
```
### Volume Management
All configurations use a named volume `certs` for persistent storage:
- SSH keys: `/app/certs/ssh/`
- TLS certificates: `/app/certs/tls/`
To backup certificates:
```bash
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar czf /backup/certs-backup.tar.gz -C /data .
```
To restore certificates:
```bash
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar xzf /backup/certs-backup.tar.gz -C /data
```
### Recommendation
**Use `docker-compose.root.yml`** for production deployments if you need:
Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 MiB

+33 -21
View File
@@ -1,48 +1,60 @@
module tunnel_pls
go 1.24.4
go 1.25.5
require (
github.com/caddyserver/certmagic v0.25.0
github.com/charmbracelet/bubbles v0.21.0
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0
github.com/caddyserver/certmagic v0.25.1
github.com/charmbracelet/bubbles v0.21.1
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/joho/godotenv v1.5.1
github.com/libdns/cloudflare v0.2.2
golang.org/x/crypto v0.46.0
github.com/muesli/termenv v0.16.0
github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.47.0
google.golang.org/grpc v1.78.0
google.golang.org/protobuf v1.36.11
)
require (
github.com/atotto/clipboard v0.1.4 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
github.com/charmbracelet/lipgloss v1.1.0 // indirect
github.com/charmbracelet/x/ansi v0.10.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/caddyserver/zerossl v0.1.4 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.5 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/libdns/libdns v1.1.1 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/mholt/acmez/v3 v3.1.3 // indirect
github.com/miekg/dns v1.1.68 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/mholt/acmez/v3 v3.1.4 // indirect
github.com/miekg/dns v1.1.69 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/sahilm/fuzzy v0.1.1 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/zeebo/blake3 v0.2.4 // indirect
go.uber.org/multierr v1.11.0 // indirect
go.uber.org/zap v1.27.0 // indirect
go.uber.org/zap v1.27.1 // indirect
go.uber.org/zap/exp v0.3.0 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/tools v0.39.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/text v0.33.0 // indirect
golang.org/x/tools v0.40.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
+100 -41
View File
@@ -1,35 +1,62 @@
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 h1:3xszIhck4wo9CoeRq9vnkar4PhY7kz9QrR30qj2XszA=
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0/go.mod h1:Weh6ZujgWmT8XxD3Qba7sJ6r5eyUMB9XSWynqdyOoLo=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/caddyserver/certmagic v0.25.0 h1:VMleO/XA48gEWes5l+Fh6tRWo9bHkhwAEhx63i+F5ic=
github.com/caddyserver/certmagic v0.25.0/go.mod h1:m9yB7Mud24OQbPHOiipAoyKPn9pKHhpSJxXR1jydBxA=
github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA=
github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/caddyserver/certmagic v0.25.1 h1:4sIKKbOt5pg6+sL7tEwymE1x2bj6CHr80da1CRRIPbY=
github.com/caddyserver/certmagic v0.25.1/go.mod h1:VhyvndxtVton/Fo/wKhRoC46Rbw1fmjvQ3GjHYSQTEY=
github.com/caddyserver/zerossl v0.1.4 h1:CVJOE3MZeFisCERZjkxIcsqIH4fnFdlYWnPYeFtBHRw=
github.com/caddyserver/zerossl v0.1.4/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
github.com/charmbracelet/bubbles v0.21.1 h1:nj0decPiixaZeL9diI4uzzQTkkz1kYY8+jgzCZXSmW0=
github.com/charmbracelet/bubbles v0.21.1/go.mod h1:HHvIYRCpbkCJw2yo0vNX1O5loCwSr9/mWS8GYSg50Sk=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
github.com/charmbracelet/x/ansi v0.11.3 h1:6DcVaqWI82BBVM/atTyq6yBoRLZFBsnoDoX9GCu2YOI=
github.com/charmbracelet/x/ansi v0.11.3/go.mod h1:yI7Zslym9tCJcedxz5+WBq+eUGMJT0bM06Fqy1/Y4dI=
github.com/charmbracelet/x/ansi v0.11.5 h1:NBWeBpj/lJPE3Q5l+Lusa4+mH6v7487OP8K0r1IhRg4=
github.com/charmbracelet/x/ansi v0.11.5/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.14 h1:iUEMryGyFTelKW3THW4+FfPgi4fkmKnnaLOXuc+/Kj4=
github.com/charmbracelet/x/cellbuf v0.0.14/go.mod h1:P447lJl49ywBbil/KjCk2HexGh4tEY9LH0/1QrZZ9rA=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ=
github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.6.2 h1:ZDpTkFfpHOKte4RG5O/BOyf3ysnvFswpyYrV7z2uAKo=
github.com/clipperhouse/displaywidth v0.6.2/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
@@ -40,18 +67,18 @@ github.com/libdns/cloudflare v0.2.2 h1:XWHv+C1dDcApqazlh08Q6pjytYLgR2a+Y3xrXFu0v
github.com/libdns/cloudflare v0.2.2/go.mod h1:w9uTmRCDlAoafAsTPnn2nJ0XHK/eaUMh86DUk8BWi60=
github.com/libdns/libdns v1.1.1 h1:wPrHrXILoSHKWJKGd0EiAVmiJbFShguILTg9leS/P/U=
github.com/libdns/libdns v1.1.1/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mholt/acmez/v3 v3.1.3 h1:gUl789rjbJSuM5hYzOFnNaGgWPV1xVfnOs59o0dZEcc=
github.com/mholt/acmez/v3 v3.1.3/go.mod h1:L1wOU06KKvq7tswuMDwKdcHeKpFFgkppZy/y0DFxagQ=
github.com/miekg/dns v1.1.68 h1:jsSRkNozw7G/mnmXULynzMNIsgY2dHC8LO6U6Ij2JEA=
github.com/miekg/dns v1.1.68/go.mod h1:fujopn7TB3Pu3JM69XaawiU0wqjpL9/8xGop5UrTPps=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mholt/acmez/v3 v3.1.4 h1:DyzZe/RnAzT3rpZj/2Ii5xZpiEvvYk3cQEN/RmqxwFQ=
github.com/mholt/acmez/v3 v3.1.4/go.mod h1:L1wOU06KKvq7tswuMDwKdcHeKpFFgkppZy/y0DFxagQ=
github.com/miekg/dns v1.1.69 h1:Kb7Y/1Jo+SG+a2GtfoFUfDkG//csdRPwRLkCsxDG9Sc=
github.com/miekg/dns v1.1.69/go.mod h1:7OyjD9nEba5OkqQ/hB4fy3PIoxafSZJtducccIelz3g=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
@@ -60,13 +87,22 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
@@ -75,33 +111,56 @@ github.com/zeebo/blake3 v0.2.4 h1:KYQPkhpRtcqh0ssGYcKLG1JYvddkEA8QwCM/yBqhaZI=
github.com/zeebo/blake3 v0.2.4/go.mod h1:7eeQ6d2iXWRGF6npfaxl2CU+xy2Fjo2gxeyZGCRUjcE=
github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U=
go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b h1:Mv8VFug0MP9e5vUxfBcE3vUkV6CImK3cMNMIDFjmzxU=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ=
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+196
View File
@@ -0,0 +1,196 @@
package bootstrap
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/key"
"tunnel_pls/internal/port"
"tunnel_pls/internal/random"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/internal/version"
"tunnel_pls/server"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type Bootstrap struct {
Randomizer random.Random
Config config.Config
SessionRegistry registry.Registry
Port port.Port
GrpcClient client.Client
ErrChan chan error
SignalChan chan os.Signal
}
func New(config config.Config, port port.Port) (*Bootstrap, error) {
randomizer := random.New()
sessionRegistry := registry.NewRegistry()
if err := port.AddRange(config.AllowedPortsStart(), config.AllowedPortsEnd()); err != nil {
return nil, err
}
grpcClient, err := client.New(config, sessionRegistry)
if err != nil {
return nil, err
}
errChan := make(chan error, 5)
signalChan := make(chan os.Signal, 1)
return &Bootstrap{
Randomizer: randomizer,
Config: config,
SessionRegistry: sessionRegistry,
Port: port,
GrpcClient: grpcClient,
ErrChan: errChan,
SignalChan: signalChan,
}, nil
}
func newSSHConfig(sshKeyPath string) (*ssh.ServerConfig, error) {
sshCfg := &ssh.ServerConfig{
NoClientAuth: true,
ServerVersion: fmt.Sprintf("SSH-2.0-TunnelPlease-%s", version.GetShortVersion()),
}
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
return nil, fmt.Errorf("generate ssh key: %w", err)
}
privateBytes, err := os.ReadFile(sshKeyPath)
if err != nil {
return nil, fmt.Errorf("read private key: %w", err)
}
private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
sshCfg.AddHostKey(private)
return sshCfg, nil
}
func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, errChan chan<- error) error {
healthCtx, healthCancel := context.WithTimeout(ctx, 5*time.Second)
defer healthCancel()
if err := b.GrpcClient.CheckServerHealth(healthCtx); err != nil {
return fmt.Errorf("gRPC health check failed: %w", err)
}
go func() {
if err := b.GrpcClient.SubscribeEvents(ctx, conf.Domain(), conf.NodeToken()); err != nil {
errChan <- fmt.Errorf("failed to subscribe to events: %w", err)
}
}()
return nil
}
func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
httpserver := transport.NewHTTPServer(conf, registry)
ln, err := httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
if err = httpserver.Serve(ln); err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
}
}
func startHTTPSServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
tlsCfg, err := transport.NewTLSConfig(conf)
if err != nil {
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
return
}
httpsServer := transport.NewHTTPSServer(conf, registry, tlsCfg)
ln, err := httpsServer.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
return
}
if err = httpsServer.Serve(ln); err != nil {
errChan <- fmt.Errorf("error when serving https server: %w", err)
}
}
func startSSHServer(rand random.Random, conf config.Config, sshCfg *ssh.ServerConfig, registry registry.Registry, grpcClient client.Client, portManager port.Port, errChan chan<- error) {
sshServer, err := server.New(rand, conf, sshCfg, registry, grpcClient, portManager, conf.SSHPort())
if err != nil {
errChan <- err
return
}
sshServer.Start()
errChan <- sshServer.Close()
}
func startPprof(pprofPort string, errChan chan<- error) {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
errChan <- fmt.Errorf("pprof server error: %v", err)
}
}
func (b *Bootstrap) Run() error {
sshConfig, err := newSSHConfig(b.Config.KeyLoc())
if err != nil {
return fmt.Errorf("failed to create SSH config: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
signal.Notify(b.SignalChan, os.Interrupt, syscall.SIGTERM)
if b.Config.Mode() == types.ServerModeNODE {
err = b.startGRPCClient(ctx, b.Config, b.ErrChan)
if err != nil {
return fmt.Errorf("failed to start gRPC client: %w", err)
}
defer func(grpcClient client.Client) {
err = grpcClient.Close()
if err != nil {
log.Printf("failed to close gRPC client")
}
}(b.GrpcClient)
}
go startHTTPServer(b.Config, b.SessionRegistry, b.ErrChan)
if b.Config.TLSEnabled() {
go startHTTPSServer(b.Config, b.SessionRegistry, b.ErrChan)
}
go func() {
startSSHServer(b.Randomizer, b.Config, sshConfig, b.SessionRegistry, b.GrpcClient, b.Port, b.ErrChan)
}()
if b.Config.PprofEnabled() {
go startPprof(b.Config.PprofPort(), b.ErrChan)
}
log.Println("All services started successfully")
select {
case err = <-b.ErrChan:
return fmt.Errorf("service error: %w", err)
case sig := <-b.SignalChan:
log.Printf("Received signal %s, initiating graceful shutdown", sig)
cancel()
return nil
}
}
+558
View File
@@ -0,0 +1,558 @@
package bootstrap
import (
"context"
"fmt"
"net"
"net/http"
_ "net/http/pprof"
"os"
"path/filepath"
"strconv"
"testing"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"google.golang.org/grpc"
)
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
args := m.Called(user, oldKey, newKey)
return args.Error(0)
}
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
args := m.Called(key, session)
return args.Bool(0)
}
func (m *MockSessionRegistry) Remove(key registry.Key) {
m.Called(key)
}
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
args := m.Called(user)
return args.Get(0).([]registry.Session)
}
func (m *MockSessionRegistry) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
type MockRandom struct {
mock.Mock
}
func (m *MockRandom) String(length int) (string, error) {
args := m.Called(length)
return args.String(0), args.Error(1)
}
type MockConfig struct {
mock.Mock
}
func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode {
args := m.Called()
if args.Get(0) == nil {
return 0
}
switch v := args.Get(0).(type) {
case types.ServerMode:
return v
case int:
return types.ServerMode(v)
default:
return types.ServerMode(args.Int(0))
}
}
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type MockPort struct {
mock.Mock
}
func (m *MockPort) AddRange(startPort, endPort uint16) error {
return m.Called(startPort, endPort).Error(0)
}
func (m *MockPort) Unassigned() (uint16, bool) {
args := m.Called()
var mPort uint16
if args.Get(0) != nil {
switch v := args.Get(0).(type) {
case int:
mPort = uint16(v)
case uint16:
mPort = v
case uint32:
mPort = uint16(v)
case int32:
mPort = uint16(v)
case float64:
mPort = uint16(v)
default:
mPort = uint16(args.Int(0))
}
}
return mPort, args.Bool(1)
}
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
}
func (m *MockPort) Claim(port uint16) bool {
return m.Called(port).Bool(0)
}
type MockGRPCClient struct {
mock.Mock
}
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
args := m.Called()
return args.Get(0).(*grpc.ClientConn)
}
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
args := m.Called(ctx, token)
return args.Bool(0), args.String(1), args.Error(2)
}
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
args := m.Called(ctx, domain, token)
return args.Error(0)
}
func (m *MockGRPCClient) Close() error {
args := m.Called()
return args.Error(0)
}
func TestNew(t *testing.T) {
tests := []struct {
name string
setupConfig func() config.Config
setupPort func() port.Port
wantErr bool
errContains string
}{
{
name: "Success New with default value",
wantErr: false,
},
{
name: "Error when AddRange fails",
setupPort: func() port.Port {
mockPort := &MockPort{}
mockPort.On("AddRange", mock.Anything, mock.Anything).Return(fmt.Errorf("invalid port range"))
return mockPort
},
wantErr: true,
errContains: "invalid port range",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var mockPort port.Port
if tt.setupPort != nil {
mockPort = tt.setupPort()
} else {
mockPort = port.New()
}
var mockConfig config.Config
if tt.setupConfig != nil {
mockConfig = tt.setupConfig()
} else {
var err error
mockConfig, err = config.MustLoad()
assert.NoError(t, err)
}
bootstrap, err := New(mockConfig, mockPort)
if tt.wantErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, bootstrap)
} else {
assert.NoError(t, err)
assert.NotNil(t, bootstrap)
assert.NotNil(t, bootstrap.Randomizer)
assert.NotNil(t, bootstrap.SessionRegistry)
assert.NotNil(t, bootstrap.Config)
assert.NotNil(t, bootstrap.Port)
assert.NotNil(t, bootstrap.ErrChan)
assert.NotNil(t, bootstrap.SignalChan)
}
})
}
}
func randomAvailablePort() (string, error) {
listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return "", err
}
defer func(listener net.Listener) {
_ = listener.Close()
}(listener)
mPort := listener.Addr().(*net.TCPAddr).Port
return strconv.Itoa(mPort), nil
}
func TestRun(t *testing.T) {
mockRandom := &MockRandom{}
mockErrChan := make(chan error, 1)
mockSignalChan := make(chan os.Signal, 1)
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
tmpDir := t.TempDir()
keyLoc := filepath.Join(tmpDir, "key.key")
tests := []struct {
name string
setupConfig func() *MockConfig
setupGrpcClient func() *MockGRPCClient
needCerts bool
expectError bool
}{
{
name: "successful run and termination",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: false,
},
{
name: "error from SSH server invalid port",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("invalid")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: true,
},
{
name: "error from HTTP server invalid port",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("invalid")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: true,
},
{
name: "error from HTTPS server invalid port",
setupConfig: func() *MockConfig {
tempDir := os.TempDir()
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("invalid")
mockConfig.On("TLSEnabled").Return(true)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("TLSStoragePath").Return(tempDir)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: true,
},
{
name: "grpc health check failed",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("invalid")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(fmt.Errorf("health check failed"))
return mockGRPCClient
},
expectError: true,
},
{
name: "successful run with pprof enabled",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
pprofPort, _ := randomAvailablePort()
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeSTANDALONE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(true)
mockConfig.On("PprofPort").Return(pprofPort)
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
expectError: false,
}, {
name: "successful run in NODE mode with signal",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockGRPCClient.On("Close").Return(nil)
return mockGRPCClient
},
expectError: false,
}, {
name: "successful run in NODE mode with signal buf error when closing",
setupConfig: func() *MockConfig {
mockConfig := &MockConfig{}
mockConfig.On("KeyLoc").Return(keyLoc)
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("Domain").Return("example.com")
mockConfig.On("SSHPort").Return("0")
mockConfig.On("HTTPPort").Return("0")
mockConfig.On("HTTPSPort").Return("0")
mockConfig.On("TLSEnabled").Return(false)
mockConfig.On("TLSRedirect").Return(false)
mockConfig.On("ACMEEmail").Return("test@example.com")
mockConfig.On("CFAPIToken").Return("fake-token")
mockConfig.On("ACMEStaging").Return(true)
mockConfig.On("AllowedPortsStart").Return(uint16(1024))
mockConfig.On("AllowedPortsEnd").Return(uint16(65535))
mockConfig.On("BufferSize").Return(4096)
mockConfig.On("PprofEnabled").Return(false)
mockConfig.On("PprofPort").Return("0")
mockConfig.On("GRPCAddress").Return("localhost")
mockConfig.On("GRPCPort").Return("0")
mockConfig.On("NodeToken").Return("fake-node-token")
return mockConfig
},
setupGrpcClient: func() *MockGRPCClient {
mockGRPCClient := &MockGRPCClient{}
mockGRPCClient.On("CheckServerHealth", mock.Anything).Return(nil)
mockGRPCClient.On("SubscribeEvents", mock.Anything, mock.Anything, mock.Anything).Return(nil)
mockGRPCClient.On("Close").Return(fmt.Errorf("you fucked up, buddy"))
return mockGRPCClient
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockConfig := tt.setupConfig()
mockGRPCClient := &MockGRPCClient{}
bootstrap := &Bootstrap{
Randomizer: mockRandom,
Config: mockConfig,
SessionRegistry: mockSessionRegistry,
Port: mockPort,
ErrChan: mockErrChan,
SignalChan: mockSignalChan,
GrpcClient: mockGRPCClient,
}
if tt.setupGrpcClient != nil {
bootstrap.GrpcClient = tt.setupGrpcClient()
}
done := make(chan error, 1)
go func() {
done <- bootstrap.Run()
}()
if tt.expectError {
err := <-done
assert.Error(t, err)
} else if tt.name == "successful run with pprof enabled" {
time.Sleep(200 * time.Millisecond)
fmt.Println(mockConfig.PprofPort())
resp, err := http.Get(fmt.Sprintf("http://localhost:%s/debug/pprof/", mockConfig.PprofPort()))
assert.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
err = resp.Body.Close()
assert.NoError(t, err)
mockSignalChan <- os.Interrupt
err = <-done
assert.NoError(t, err)
} else {
time.Sleep(time.Second)
mockSignalChan <- os.Interrupt
err := <-done
assert.NoError(t, err)
}
})
}
}
+60 -25
View File
@@ -1,35 +1,70 @@
package config
import (
"log"
"os"
"strconv"
import "tunnel_pls/types"
"github.com/joho/godotenv"
)
type Config interface {
Domain() string
SSHPort() string
func init() {
if _, err := os.Stat(".env"); err == nil {
if err := godotenv.Load(".env"); err != nil {
log.Printf("Warning: Failed to load .env file: %s", err)
}
}
HTTPPort() string
HTTPSPort() string
KeyLoc() string
TLSEnabled() bool
TLSRedirect() bool
TLSStoragePath() string
ACMEEmail() string
CFAPIToken() string
ACMEStaging() bool
AllowedPortsStart() uint16
AllowedPortsEnd() uint16
BufferSize() int
HeaderSize() int
PprofEnabled() bool
PprofPort() string
Mode() types.ServerMode
GRPCAddress() string
GRPCPort() string
NodeToken() string
}
func Getenv(key, defaultValue string) string {
val := os.Getenv(key)
if val == "" {
val = defaultValue
func MustLoad() (Config, error) {
if err := loadEnvFile(); err != nil {
return nil, err
}
return val
cfg, err := parse()
if err != nil {
return nil, err
}
return cfg, nil
}
func GetBufferSize() int {
sizeStr := Getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(sizeStr)
if err != nil || size < 4096 || size > 1048576 {
return 32768
}
return size
}
func (c *config) Domain() string { return c.domain }
func (c *config) SSHPort() string { return c.sshPort }
func (c *config) HTTPPort() string { return c.httpPort }
func (c *config) HTTPSPort() string { return c.httpsPort }
func (c *config) KeyLoc() string { return c.keyLoc }
func (c *config) TLSEnabled() bool { return c.tlsEnabled }
func (c *config) TLSRedirect() bool { return c.tlsRedirect }
func (c *config) TLSStoragePath() string { return c.tlsStoragePath }
func (c *config) ACMEEmail() string { return c.acmeEmail }
func (c *config) CFAPIToken() string { return c.cfAPIToken }
func (c *config) ACMEStaging() bool { return c.acmeStaging }
func (c *config) AllowedPortsStart() uint16 { return c.allowedPortsStart }
func (c *config) AllowedPortsEnd() uint16 { return c.allowedPortsEnd }
func (c *config) BufferSize() int { return c.bufferSize }
func (c *config) HeaderSize() int { return c.headerSize }
func (c *config) PprofEnabled() bool { return c.pprofEnabled }
func (c *config) PprofPort() string { return c.pprofPort }
func (c *config) Mode() types.ServerMode { return c.mode }
func (c *config) GRPCAddress() string { return c.grpcAddress }
func (c *config) GRPCPort() string { return c.grpcPort }
func (c *config) NodeToken() string { return c.nodeToken }
+405
View File
@@ -0,0 +1,405 @@
package config
import (
"os"
"testing"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
)
func TestGetenv(t *testing.T) {
tests := []struct {
name string
key string
val string
def string
expected string
}{
{
name: "returns existing env",
key: "TEST_ENV_EXIST",
val: "value",
def: "default",
expected: "value",
},
{
name: "returns default when env missing",
key: "TEST_ENV_MISSING",
val: "",
def: "default",
expected: "default",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.val != "" {
t.Setenv(tt.key, tt.val)
} else {
err := os.Unsetenv(tt.key)
assert.NoError(t, err)
}
assert.Equal(t, tt.expected, getenv(tt.key, tt.def))
})
}
}
func TestGetenvBool(t *testing.T) {
tests := []struct {
name string
key string
val string
def bool
expected bool
}{
{
name: "returns true when env is true",
key: "TEST_BOOL_TRUE",
val: "true",
def: false,
expected: true,
},
{
name: "returns false when env is false",
key: "TEST_BOOL_FALSE",
val: "false",
def: true,
expected: false,
},
{
name: "returns default when env missing",
key: "TEST_BOOL_MISSING",
val: "",
def: true,
expected: true,
},
{
name: "returns false when env is not true",
key: "TEST_BOOL_INVALID",
val: "yes",
def: true,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.val != "" {
t.Setenv(tt.key, tt.val)
} else {
err := os.Unsetenv(tt.key)
assert.NoError(t, err)
}
assert.Equal(t, tt.expected, getenvBool(tt.key, tt.def))
})
}
}
func TestParseMode(t *testing.T) {
tests := []struct {
name string
mode string
expect types.ServerMode
expectErr bool
}{
{"standalone", "standalone", types.ServerModeSTANDALONE, false},
{"node", "node", types.ServerModeNODE, false},
{"uppercase", "STANDALONE", types.ServerModeSTANDALONE, false},
{"invalid", "invalid", 0, true},
{"empty (default)", "", types.ServerModeSTANDALONE, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.mode != "" {
t.Setenv("MODE", tt.mode)
} else {
err := os.Unsetenv("MODE")
assert.NoError(t, err)
}
mode, err := parseMode()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expect, mode)
}
})
}
}
func TestParseAllowedPorts(t *testing.T) {
tests := []struct {
name string
val string
start uint16
end uint16
expectErr bool
}{
{"valid range", "1000-2000", 1000, 2000, false},
{"empty", "", 0, 0, false},
{"invalid format - no dash", "1000", 0, 0, true},
{"invalid format - too many dashes", "1000-2000-3000", 0, 0, true},
{"invalid start port", "abc-2000", 0, 0, true},
{"invalid end port", "1000-abc", 0, 0, true},
{"out of range start", "70000-80000", 0, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.val != "" {
t.Setenv("ALLOWED_PORTS", tt.val)
} else {
err := os.Unsetenv("ALLOWED_PORTS")
assert.NoError(t, err)
}
start, end, err := parseAllowedPorts()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.start, start)
assert.Equal(t, tt.end, end)
}
})
}
}
func TestParseBufferSize(t *testing.T) {
tests := []struct {
name string
val string
expect int
}{
{"valid size", "8192", 8192},
{"default size", "", 32768},
{"too small", "1024", 4096},
{"too large", "2000000", 4096},
{"invalid format", "abc", 4096},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.val != "" {
t.Setenv("BUFFER_SIZE", tt.val)
} else {
err := os.Unsetenv("BUFFER_SIZE")
assert.NoError(t, err)
}
size := parseBufferSize()
assert.Equal(t, tt.expect, size)
})
}
}
func TestParseHeaderSize(t *testing.T) {
tests := []struct {
name string
val string
expect int
}{
{"valid size", "8192", 8192},
{"default size", "", 4096},
{"too small", "1024", 4096},
{"too large", "2000000", 4096},
{"invalid format", "abc", 4096},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.val != "" {
t.Setenv("MAX_HEADER_SIZE", tt.val)
} else {
err := os.Unsetenv("MAX_HEADER_SIZE")
assert.NoError(t, err)
}
size := parseHeaderSize()
assert.Equal(t, tt.expect, size)
})
}
}
func TestParse(t *testing.T) {
tests := []struct {
name string
envs map[string]string
expectErr bool
}{
{
name: "minimal valid config",
envs: map[string]string{
"DOMAIN": "example.com",
},
expectErr: false,
},
{
name: "TLS enabled without token",
envs: map[string]string{
"TLS_ENABLED": "true",
},
expectErr: true,
},
{
name: "TLS enabled with token",
envs: map[string]string{
"TLS_ENABLED": "true",
"CF_API_TOKEN": "secret",
},
expectErr: false,
},
{
name: "Node mode without token",
envs: map[string]string{
"MODE": "node",
},
expectErr: true,
},
{
name: "Node mode with token",
envs: map[string]string{
"MODE": "node",
"NODE_TOKEN": "token",
},
expectErr: false,
},
{
name: "invalid mode",
envs: map[string]string{
"MODE": "invalid",
},
expectErr: true,
},
{
name: "invalid allowed ports",
envs: map[string]string{
"ALLOWED_PORTS": "1000",
},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
os.Clearenv()
for k, v := range tt.envs {
t.Setenv(k, v)
}
cfg, err := parse()
if tt.expectErr {
assert.Error(t, err)
assert.Nil(t, cfg)
} else {
assert.NoError(t, err)
assert.NotNil(t, cfg)
}
})
}
}
func TestGetters(t *testing.T) {
envs := map[string]string{
"DOMAIN": "example.com",
"PORT": "2222",
"HTTP_PORT": "80",
"HTTPS_PORT": "443",
"KEY_LOC": "certs/ssh/id_rsa",
"TLS_ENABLED": "true",
"TLS_REDIRECT": "true",
"TLS_STORAGE_PATH": "certs/tls/",
"ACME_EMAIL": "test@example.com",
"CF_API_TOKEN": "token",
"ACME_STAGING": "true",
"ALLOWED_PORTS": "1000-2000",
"BUFFER_SIZE": "16384",
"MAX_HEADER_SIZE": "4096",
"PPROF_ENABLED": "true",
"PPROF_PORT": "7070",
"MODE": "standalone",
"GRPC_ADDRESS": "127.0.0.1",
"GRPC_PORT": "9090",
"NODE_TOKEN": "ntoken",
}
os.Clearenv()
for k, v := range envs {
t.Setenv(k, v)
}
cfg, err := parse()
assert.NoError(t, err)
assert.Equal(t, "example.com", cfg.Domain())
assert.Equal(t, "2222", cfg.SSHPort())
assert.Equal(t, "80", cfg.HTTPPort())
assert.Equal(t, "443", cfg.HTTPSPort())
assert.Equal(t, "certs/ssh/id_rsa", cfg.KeyLoc())
assert.Equal(t, true, cfg.TLSEnabled())
assert.Equal(t, true, cfg.TLSRedirect())
assert.Equal(t, "certs/tls/", cfg.TLSStoragePath())
assert.Equal(t, "test@example.com", cfg.ACMEEmail())
assert.Equal(t, "token", cfg.CFAPIToken())
assert.Equal(t, true, cfg.ACMEStaging())
assert.Equal(t, uint16(1000), cfg.AllowedPortsStart())
assert.Equal(t, uint16(2000), cfg.AllowedPortsEnd())
assert.Equal(t, 16384, cfg.BufferSize())
assert.Equal(t, 4096, cfg.HeaderSize())
assert.Equal(t, true, cfg.PprofEnabled())
assert.Equal(t, "7070", cfg.PprofPort())
assert.Equal(t, types.ServerMode(types.ServerModeSTANDALONE), cfg.Mode())
assert.Equal(t, "127.0.0.1", cfg.GRPCAddress())
assert.Equal(t, "9090", cfg.GRPCPort())
assert.Equal(t, "ntoken", cfg.NodeToken())
}
func TestMustLoad(t *testing.T) {
t.Run("success", func(t *testing.T) {
os.Clearenv()
t.Setenv("DOMAIN", "example.com")
cfg, err := MustLoad()
assert.NoError(t, err)
assert.NotNil(t, cfg)
})
t.Run("loadEnvFile error", func(t *testing.T) {
err := os.Mkdir(".env", 0755)
assert.NoError(t, err)
defer func() {
err = os.Remove(".env")
assert.NoError(t, err)
}()
cfg, err := MustLoad()
assert.Error(t, err)
assert.Nil(t, cfg)
})
t.Run("parse error", func(t *testing.T) {
os.Clearenv()
t.Setenv("MODE", "invalid")
cfg, err := MustLoad()
assert.Error(t, err)
assert.Nil(t, cfg)
})
}
func TestLoadEnvFile(t *testing.T) {
t.Run("file exists", func(t *testing.T) {
err := os.WriteFile(".env", []byte("TEST_ENV_FILE=true"), 0644)
assert.NoError(t, err)
defer func() {
err = os.Remove(".env")
assert.NoError(t, err)
}()
err = loadEnvFile()
assert.NoError(t, err)
assert.Equal(t, "true", os.Getenv("TEST_ENV_FILE"))
})
t.Run("file missing", func(t *testing.T) {
_ = os.Remove(".env")
err := loadEnvFile()
assert.NoError(t, err)
})
}
+190
View File
@@ -0,0 +1,190 @@
package config
import (
"fmt"
"log"
"os"
"strconv"
"strings"
"tunnel_pls/types"
"github.com/joho/godotenv"
)
type config struct {
domain string
sshPort string
httpPort string
httpsPort string
keyLoc string
tlsEnabled bool
tlsRedirect bool
tlsStoragePath string
acmeEmail string
cfAPIToken string
acmeStaging bool
allowedPortsStart uint16
allowedPortsEnd uint16
bufferSize int
headerSize int
pprofEnabled bool
pprofPort string
mode types.ServerMode
grpcAddress string
grpcPort string
nodeToken string
}
func parse() (*config, error) {
mode, err := parseMode()
if err != nil {
return nil, err
}
domain := getenv("DOMAIN", "localhost")
sshPort := getenv("PORT", "2200")
httpPort := getenv("HTTP_PORT", "8080")
httpsPort := getenv("HTTPS_PORT", "8443")
keyLoc := getenv("KEY_LOC", "certs/privkey.pem")
tlsEnabled := getenvBool("TLS_ENABLED", false)
tlsRedirect := tlsEnabled && getenvBool("TLS_REDIRECT", false)
tlsStoragePath := getenv("TLS_STORAGE_PATH", "certs/tls/")
acmeEmail := getenv("ACME_EMAIL", "admin@"+domain)
acmeStaging := getenvBool("ACME_STAGING", false)
cfToken := getenv("CF_API_TOKEN", "")
if tlsEnabled && cfToken == "" {
return nil, fmt.Errorf("CF_API_TOKEN is required when TLS is enabled")
}
start, end, err := parseAllowedPorts()
if err != nil {
return nil, err
}
bufferSize := parseBufferSize()
headerSize := parseHeaderSize()
pprofEnabled := getenvBool("PPROF_ENABLED", false)
pprofPort := getenv("PPROF_PORT", "6060")
grpcHost := getenv("GRPC_ADDRESS", "localhost")
grpcPort := getenv("GRPC_PORT", "8080")
nodeToken := getenv("NODE_TOKEN", "")
if mode == types.ServerModeNODE && nodeToken == "" {
return nil, fmt.Errorf("NODE_TOKEN is required in node mode")
}
return &config{
domain: domain,
sshPort: sshPort,
httpPort: httpPort,
httpsPort: httpsPort,
keyLoc: keyLoc,
tlsEnabled: tlsEnabled,
tlsRedirect: tlsRedirect,
tlsStoragePath: tlsStoragePath,
acmeEmail: acmeEmail,
cfAPIToken: cfToken,
acmeStaging: acmeStaging,
allowedPortsStart: start,
allowedPortsEnd: end,
bufferSize: bufferSize,
headerSize: headerSize,
pprofEnabled: pprofEnabled,
pprofPort: pprofPort,
mode: mode,
grpcAddress: grpcHost,
grpcPort: grpcPort,
nodeToken: nodeToken,
}, nil
}
func loadEnvFile() error {
if _, err := os.Stat(".env"); err == nil {
return godotenv.Load(".env")
}
return nil
}
func parseMode() (types.ServerMode, error) {
switch strings.ToLower(getenv("MODE", "standalone")) {
case "standalone":
return types.ServerModeSTANDALONE, nil
case "node":
return types.ServerModeNODE, nil
default:
return 0, fmt.Errorf("invalid MODE value")
}
}
func parseAllowedPorts() (uint16, uint16, error) {
raw := getenv("ALLOWED_PORTS", "")
if raw == "" {
return 0, 0, nil
}
parts := strings.Split(raw, "-")
if len(parts) != 2 {
return 0, 0, fmt.Errorf("invalid ALLOWED_PORTS format")
}
start, err := strconv.ParseUint(parts[0], 10, 16)
if err != nil {
return 0, 0, err
}
end, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return 0, 0, err
}
return uint16(start), uint16(end), nil
}
func parseBufferSize() int {
raw := getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(raw)
if err != nil || size < 4096 || size > 1048576 {
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
return 4096
}
return size
}
func parseHeaderSize() int {
raw := getenv("MAX_HEADER_SIZE", "4096")
size, err := strconv.Atoi(raw)
if err != nil || size < 4096 || size > 131072 {
log.Println("Invalid BUFFER_SIZE, falling back to 4096")
return 4096
}
return size
}
func getenv(key, def string) string {
if v := os.Getenv(key); v != "" {
return v
}
return def
}
func getenvBool(key string, def bool) bool {
val := os.Getenv(key)
if val == "" {
return def
}
return val == "true"
}
+377
View File
@@ -0,0 +1,377 @@
package client
import (
"context"
"errors"
"fmt"
"io"
"log"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/registry"
"tunnel_pls/types"
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
)
type Client interface {
SubscribeEvents(ctx context.Context, identity, authToken string) error
ClientConn() *grpc.ClientConn
AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error)
Close() error
CheckServerHealth(ctx context.Context) error
}
type client struct {
config config.Config
conn *grpc.ClientConn
address string
sessionRegistry registry.Registry
eventService proto.EventServiceClient
authorizeConnectionService proto.UserServiceClient
closing bool
}
var (
grpcNewClient = grpc.NewClient
healthNewHealthClient = grpc_health_v1.NewHealthClient
initialBackoff = time.Second
)
func New(config config.Config, sessionRegistry registry.Registry) (Client, error) {
address := fmt.Sprintf("%s:%s", config.GRPCAddress(), config.GRPCPort())
var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
kaParams := keepalive.ClientParameters{
Time: 2 * time.Minute,
Timeout: 10 * time.Second,
PermitWithoutStream: false,
}
opts = append(opts, grpc.WithKeepaliveParams(kaParams))
opts = append(opts,
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(4*1024*1024),
grpc.MaxCallSendMsgSize(4*1024*1024),
),
)
conn, err := grpcNewClient(address, opts...)
if err != nil {
return nil, fmt.Errorf("failed to connect to gRPC server at %s: %w", address, err)
}
eventService := proto.NewEventServiceClient(conn)
authorizeConnectionService := proto.NewUserServiceClient(conn)
return &client{
config: config,
conn: conn,
address: address,
sessionRegistry: sessionRegistry,
eventService: eventService,
authorizeConnectionService: authorizeConnectionService,
}, nil
}
func (c *client) SubscribeEvents(ctx context.Context, identity, authToken string) error {
backoff := initialBackoff
for {
if err := c.subscribeAndProcess(ctx, identity, authToken, &backoff); err != nil {
return err
}
}
}
func (c *client) subscribeAndProcess(ctx context.Context, identity, authToken string, backoff *time.Duration) error {
subscribe, err := c.eventService.Subscribe(ctx)
if err != nil {
return c.handleSubscribeError(ctx, err, backoff)
}
err = subscribe.Send(&proto.Node{
Type: proto.EventType_AUTHENTICATION,
Payload: &proto.Node_AuthEvent{
AuthEvent: &proto.Authentication{
Identity: identity,
AuthToken: authToken,
},
},
})
if err != nil {
return c.handleAuthError(ctx, err, backoff)
}
log.Println("Authentication Successfully sent to gRPC server")
*backoff = time.Second
return c.handleStreamError(ctx, c.processEventStream(subscribe), backoff)
}
func (c *client) handleSubscribeError(ctx context.Context, err error, backoff *time.Duration) error {
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
return err
}
if !c.isConnectionError(err) || status.Code(err) == codes.Unauthenticated {
return err
}
if err = c.wait(ctx, *backoff); err != nil {
return err
}
c.growBackoff(backoff)
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
return nil
}
func (c *client) handleAuthError(ctx context.Context, err error, backoff *time.Duration) error {
log.Println("Authentication failed to send to gRPC server:", err)
if !c.isConnectionError(err) {
return err
}
if err := c.wait(ctx, *backoff); err != nil {
return err
}
c.growBackoff(backoff)
return nil
}
func (c *client) handleStreamError(ctx context.Context, err error, backoff *time.Duration) error {
if errors.Is(err, context.Canceled) || status.Code(err) == codes.Canceled || ctx.Err() != nil {
return err
}
if !c.isConnectionError(err) {
return err
}
log.Printf("Reconnect to controller within %v sec", backoff.Seconds())
if err := c.wait(ctx, *backoff); err != nil {
return err
}
c.growBackoff(backoff)
return nil
}
func (c *client) wait(ctx context.Context, duration time.Duration) error {
if duration <= 0 {
return nil
}
select {
case <-time.After(duration):
return nil
case <-ctx.Done():
return ctx.Err()
}
}
func (c *client) growBackoff(backoff *time.Duration) {
const maxBackoff = 30 * time.Second
*backoff *= 2
if *backoff > maxBackoff {
*backoff = maxBackoff
}
}
func (c *client) processEventStream(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) error {
handlers := c.eventHandlers(subscribe)
for {
recv, err := subscribe.Recv()
if err != nil {
return err
}
handler, ok := handlers[recv.GetType()]
if !ok {
log.Printf("Unknown event type received: %v", recv.GetType())
continue
}
if err = handler(recv); err != nil {
return err
}
}
}
func (c *client) eventHandlers(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events]) map[proto.EventType]func(*proto.Events) error {
return map[proto.EventType]func(*proto.Events) error{
proto.EventType_SLUG_CHANGE: func(evt *proto.Events) error { return c.handleSlugChange(subscribe, evt) },
proto.EventType_GET_SESSIONS: func(evt *proto.Events) error { return c.handleGetSessions(subscribe, evt) },
proto.EventType_TERMINATE_SESSION: func(evt *proto.Events) error { return c.handleTerminateSession(subscribe, evt) },
}
}
func (c *client) handleSlugChange(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
slugEvent := evt.GetSlugEvent()
user := slugEvent.GetUser()
oldKey := types.SessionKey{Id: slugEvent.GetOld(), Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: slugEvent.GetNew(), Type: types.TunnelTypeHTTP}
userSession, err := c.sessionRegistry.Get(oldKey)
if err != nil {
return c.sendSlugChangeResponse(subscribe, false, err.Error())
}
if err = c.sessionRegistry.Update(user, oldKey, newKey); err != nil {
return c.sendSlugChangeResponse(subscribe, false, err.Error())
}
userSession.Interaction().Redraw()
return c.sendSlugChangeResponse(subscribe, true, "")
}
func (c *client) handleGetSessions(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
sessions := c.sessionRegistry.GetAllSessionFromUser(evt.GetGetSessionsEvent().GetIdentity())
var details []*proto.Detail
for _, ses := range sessions {
detail := ses.Detail()
details = append(details, &proto.Detail{
Node: c.config.Domain(),
ForwardingType: detail.ForwardingType,
Slug: detail.Slug,
UserId: detail.UserID,
Active: detail.Active,
StartedAt: timestamppb.New(detail.StartedAt),
})
}
return c.sendGetSessionsResponse(subscribe, details)
}
func (c *client) handleTerminateSession(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], evt *proto.Events) error {
terminate := evt.GetTerminateSessionEvent()
user := terminate.GetUser()
slug := terminate.GetSlug()
tunnelType, err := c.protoToTunnelType(terminate.GetTunnelType())
if err != nil {
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
}
userSession, err := c.sessionRegistry.GetWithUser(user, types.SessionKey{Id: slug, Type: tunnelType})
if err != nil {
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
}
if err = userSession.Lifecycle().Close(); err != nil {
return c.sendTerminateSessionResponse(subscribe, false, err.Error())
}
return c.sendTerminateSessionResponse(subscribe, true, "")
}
func (c *client) sendSlugChangeResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_SLUG_CHANGE_RESPONSE,
Payload: &proto.Node_SlugEventResponse{
SlugEventResponse: &proto.SlugChangeEventResponse{Success: success, Message: message},
},
}, "slug change response")
}
func (c *client) sendGetSessionsResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], details []*proto.Detail) error {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_GET_SESSIONS,
Payload: &proto.Node_GetSessionsEvent{
GetSessionsEvent: &proto.GetSessionsResponse{Details: details},
},
}, "send get sessions response")
}
func (c *client) sendTerminateSessionResponse(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], success bool, message string) error {
return c.sendNode(subscribe, &proto.Node{
Type: proto.EventType_TERMINATE_SESSION,
Payload: &proto.Node_TerminateSessionEventResponse{
TerminateSessionEventResponse: &proto.TerminateSessionEventResponse{Success: success, Message: message},
},
}, "terminate session response")
}
func (c *client) sendNode(subscribe grpc.BidiStreamingClient[proto.Node, proto.Events], node *proto.Node, context string) error {
if err := subscribe.Send(node); err != nil {
if c.isConnectionError(err) {
return err
}
log.Printf("%s: %v", context, err)
}
return nil
}
func (c *client) protoToTunnelType(t proto.TunnelType) (types.TunnelType, error) {
switch t {
case proto.TunnelType_HTTP:
return types.TunnelTypeHTTP, nil
case proto.TunnelType_TCP:
return types.TunnelTypeTCP, nil
default:
return types.TunnelTypeUNKNOWN, fmt.Errorf("unknown tunnel type received")
}
}
func (c *client) ClientConn() *grpc.ClientConn {
return c.conn
}
func (c *client) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
check, err := c.authorizeConnectionService.Check(ctx, &proto.CheckRequest{AuthToken: token})
if err != nil {
return false, "UNAUTHORIZED", err
}
if check.GetResponse() == proto.AuthorizationResponse_MESSAGE_TYPE_UNAUTHORIZED {
return false, "UNAUTHORIZED", nil
}
return true, check.GetUser(), nil
}
func (c *client) CheckServerHealth(ctx context.Context) error {
healthClient := healthNewHealthClient(c.ClientConn())
resp, err := healthClient.Check(ctx, &grpc_health_v1.HealthCheckRequest{
Service: "",
})
if err != nil {
return fmt.Errorf("health check failed: %w", err)
}
if resp.Status != grpc_health_v1.HealthCheckResponse_SERVING {
return fmt.Errorf("server not serving: %v", resp.Status)
}
return nil
}
func (c *client) Close() error {
if c.conn != nil {
log.Printf("Closing gRPC connection to %s", c.address)
c.closing = true
return c.conn.Close()
}
return nil
}
func (c *client) isConnectionError(err error) bool {
if c.closing {
return false
}
if err == nil {
return false
}
if errors.Is(err, io.EOF) {
return true
}
switch status.Code(err) {
case codes.Unavailable, codes.Canceled, codes.DeadlineExceeded:
return true
default:
return false
}
}
File diff suppressed because it is too large Load Diff
+30
View File
@@ -0,0 +1,30 @@
package header
type ResponseHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type responseHeader struct {
startLine []byte
headers map[string]string
}
type RequestHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
Method() string
Path() string
Version() string
}
type requestHeader struct {
method string
path string
version string
startLine []byte
headers map[string]string
}
+227
View File
@@ -0,0 +1,227 @@
package header
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewRequest(t *testing.T) {
tests := []struct {
name string
data []byte
expectErr bool
errContains string
expectMethod string
expectPath string
expectVersion string
expectHeaders map[string]string
}{
{
name: "success",
data: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\nX-Custom: value\r\n\r\n"),
expectErr: false,
expectMethod: "GET",
expectPath: "/path",
expectVersion: "HTTP/1.1",
expectHeaders: map[string]string{
"Host": "example.com",
"X-Custom": "value",
},
},
{
name: "no CRLF in start line",
data: []byte("GET /path HTTP/1.1"),
expectErr: true,
errContains: "no CRLF found in start line",
},
{
name: "invalid start line - missing method",
data: []byte("INVALID\r\n\r\n"),
expectErr: true,
errContains: "invalid start line: missing method",
},
{
name: "invalid start line - missing version",
data: []byte("GET /path\r\n\r\n"),
expectErr: true,
errContains: "invalid start line: missing version",
},
{
name: "invalid start line - multiple spaces",
data: []byte("GET /path HTTP/1.1\r\n\r\n"),
expectErr: false,
expectMethod: "GET",
expectPath: "",
expectVersion: "/path HTTP/1.1",
expectHeaders: map[string]string{},
},
{
name: "start line with trailing space",
data: []byte("GET / HTTP/1.1 \r\n\r\n"),
expectErr: false,
expectMethod: "GET",
expectPath: "/",
expectVersion: "HTTP/1.1 ",
expectHeaders: map[string]string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := NewRequest(tt.data)
if tt.expectErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, req)
} else {
assert.NoError(t, err)
assert.NotNil(t, req)
assert.Equal(t, tt.expectMethod, req.Method())
assert.Equal(t, tt.expectPath, req.Path())
assert.Equal(t, tt.expectVersion, req.Version())
for k, v := range tt.expectHeaders {
assert.Equal(t, v, req.Value(k))
}
}
})
}
}
func TestRequestHeaderMethods(t *testing.T) {
data := []byte("GET / HTTP/1.1\r\nHost: original\r\n\r\n")
req, _ := NewRequest(data)
req.Set("Host", "updated")
req.Set("X-New", "new-value")
assert.Equal(t, "updated", req.Value("Host"))
assert.Equal(t, "new-value", req.Value("X-New"))
assert.Equal(t, "", req.Value("Non-Existent"))
req.Remove("X-New")
assert.Equal(t, "", req.Value("X-New"))
final := req.Finalize()
assert.Contains(t, string(final), "GET / HTTP/1.1\r\n")
assert.Contains(t, string(final), "Host: updated\r\n")
assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
}
func TestNewResponse(t *testing.T) {
tests := []struct {
name string
data []byte
expectErr bool
errContains string
expectHeaders map[string]string
}{
{
name: "success",
data: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"),
expectErr: false,
expectHeaders: map[string]string{
"Content-Length": "0",
},
},
{
name: "invalid response - no CRLF",
data: []byte("HTTP/1.1 200 OK"),
expectErr: true,
errContains: "no CRLF found in start line",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := NewResponse(tt.data)
if tt.expectErr {
assert.Error(t, err)
if tt.errContains != "" {
assert.Contains(t, err.Error(), tt.errContains)
}
assert.Nil(t, resp)
} else {
assert.NoError(t, err)
assert.NotNil(t, resp)
for k, v := range tt.expectHeaders {
assert.Equal(t, v, resp.Value(k))
}
}
})
}
}
func TestResponseHeaderMethods(t *testing.T) {
data := []byte("HTTP/1.1 200 OK\r\nServer: old\r\n\r\n")
resp, _ := NewResponse(data)
resp.Set("Server", "new")
resp.Set("X-Res", "val")
assert.Equal(t, "new", resp.Value("Server"))
assert.Equal(t, "val", resp.Value("X-Res"))
resp.Remove("X-Res")
assert.Equal(t, "", resp.Value("X-Res"))
final := resp.Finalize()
assert.Contains(t, string(final), "HTTP/1.1 200 OK\r\n")
assert.Contains(t, string(final), "Server: new\r\n")
assert.True(t, bytes.HasSuffix(final, []byte("\r\n\r\n")))
}
func TestSetRemainingHeaders(t *testing.T) {
tests := []struct {
name string
data []byte
initialHeaders map[string]string
expectHeaders map[string]string
}{
{
name: "various header formats",
data: []byte("K1: V1\r\nK2:V2\r\n K3 : V3 \r\nNoColon\r\n\r\n"),
expectHeaders: map[string]string{
"K1": "V1",
"K2": "V2",
"K3": "V3",
},
},
{
name: "no trailing CRLF",
data: []byte("K1: V1"),
expectHeaders: map[string]string{
"K1": "V1",
},
},
{
name: "empty lines",
data: []byte("\r\nK1: V1"),
expectHeaders: map[string]string{},
},
{
name: "headers with only colon",
data: []byte(": value\r\nkey:\r\n"),
expectHeaders: map[string]string{
"": "value",
"key": "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := &requestHeader{headers: make(map[string]string)}
if tt.initialHeaders != nil {
req.headers = tt.initialHeaders
}
setRemainingHeaders(tt.data, req)
assert.Equal(t, len(tt.expectHeaders), len(req.headers))
for k, v := range tt.expectHeaders {
assert.Equal(t, v, req.headers[k])
}
})
}
}
+77
View File
@@ -0,0 +1,77 @@
package header
import (
"bytes"
"fmt"
)
func setRemainingHeaders(remaining []byte, header interface {
Set(key string, value string)
}) {
for len(remaining) > 0 {
lineEnd := bytes.Index(remaining, []byte("\r\n"))
if lineEnd == -1 {
lineEnd = len(remaining)
}
line := remaining[:lineEnd]
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx != -1 {
key := bytes.TrimSpace(line[:colonIdx])
value := bytes.TrimSpace(line[colonIdx+1:])
header.Set(string(key), string(value))
}
if lineEnd == len(remaining) {
break
}
remaining = remaining[lineEnd+2:]
}
}
func parseStartLine(startLine []byte) (method, path, version string, err error) {
firstSpace := bytes.IndexByte(startLine, ' ')
if firstSpace == -1 {
return "", "", "", fmt.Errorf("invalid start line: missing method")
}
secondSpace := bytes.IndexByte(startLine[firstSpace+1:], ' ')
if secondSpace == -1 {
return "", "", "", fmt.Errorf("invalid start line: missing version")
}
secondSpace += firstSpace + 1
method = string(startLine[:firstSpace])
path = string(startLine[firstSpace+1 : secondSpace])
version = string(startLine[secondSpace+1:])
return method, path, version, nil
}
func finalize(startLine []byte, headers map[string]string) []byte {
size := len(startLine) + 2
for key, val := range headers {
size += len(key) + 2 + len(val) + 2
}
size += 2
buf := make([]byte, 0, size)
buf = append(buf, startLine...)
buf = append(buf, '\r', '\n')
for key, val := range headers {
buf = append(buf, key...)
buf = append(buf, ':', ' ')
buf = append(buf, val...)
buf = append(buf, '\r', '\n')
}
buf = append(buf, '\r', '\n')
return buf
}
+63
View File
@@ -0,0 +1,63 @@
package header
import (
"bytes"
"fmt"
)
func NewRequest(headerData []byte) (RequestHeader, error) {
header := &requestHeader{
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
}
startLine := headerData[:lineEnd]
header.startLine = startLine
var err error
header.method, header.path, header.version, err = parseStartLine(startLine)
if err != nil {
return nil, err
}
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func (req *requestHeader) Value(key string) string {
val, ok := req.headers[key]
if !ok {
return ""
}
return val
}
func (req *requestHeader) Set(key string, value string) {
req.headers[key] = value
}
func (req *requestHeader) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeader) Method() string {
return req.method
}
func (req *requestHeader) Path() string {
return req.path
}
func (req *requestHeader) Version() string {
return req.version
}
func (req *requestHeader) Finalize() []byte {
return finalize(req.startLine, req.headers)
}
+40
View File
@@ -0,0 +1,40 @@
package header
import (
"bytes"
"fmt"
)
func NewResponse(headerData []byte) (ResponseHeader, error) {
header := &responseHeader{
startLine: nil,
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid response: no CRLF found in start line")
}
header.startLine = headerData[:lineEnd]
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func (resp *responseHeader) Value(key string) string {
return resp.headers[key]
}
func (resp *responseHeader) Set(key string, value string) {
resp.headers[key] = value
}
func (resp *responseHeader) Remove(key string) {
delete(resp.headers, key)
}
func (resp *responseHeader) Finalize() []byte {
return finalize(resp.startLine, resp.headers)
}
+29
View File
@@ -0,0 +1,29 @@
package stream
import "bytes"
func splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) {
headerByte := data[:delimiterIdx+len(DELIMITER)]
body := data[delimiterIdx+len(DELIMITER):]
return headerByte, body
}
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}
+50
View File
@@ -0,0 +1,50 @@
package stream
import (
"bytes"
"tunnel_pls/internal/http/header"
)
func (hs *http) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := hs.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
headerEndIdx := bytes.Index(tmp, DELIMITER)
if headerEndIdx == -1 {
return handleNoDelimiter(p, tmp, err)
}
headerByte, bodyByte := splitHeaderAndBody(tmp, headerEndIdx)
if !isHTTPHeader(headerByte) {
copy(p, tmp)
return read, nil
}
return hs.processHTTPRequest(p, headerByte, bodyByte)
}
func (hs *http) processHTTPRequest(p, headerByte, bodyByte []byte) (int, error) {
reqhf, err := header.NewRequest(headerByte)
if err != nil {
return 0, err
}
if err = hs.ApplyRequestMiddlewares(reqhf); err != nil {
return 0, err
}
hs.reqHeader = reqhf
combined := append(reqhf.Finalize(), bodyByte...)
return copy(p, combined), nil
}
func handleNoDelimiter(p, tmp []byte, err error) (int, error) {
copy(p, tmp)
return len(tmp), err
}
+105
View File
@@ -0,0 +1,105 @@
package stream
import (
"io"
"log"
"net"
"regexp"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/middleware"
)
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
type HTTP interface {
io.ReadWriteCloser
CloseWrite() error
RemoteAddr() net.Addr
UseResponseMiddleware(mw middleware.ResponseMiddleware)
UseRequestMiddleware(mw middleware.RequestMiddleware)
SetRequestHeader(header header.RequestHeader)
RequestMiddlewares() []middleware.RequestMiddleware
ResponseMiddlewares() []middleware.ResponseMiddleware
ApplyResponseMiddlewares(resphf header.ResponseHeader, body []byte) error
ApplyRequestMiddlewares(reqhf header.RequestHeader) error
}
type http struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
buf []byte
respHeader header.ResponseHeader
reqHeader header.RequestHeader
respMW []middleware.ResponseMiddleware
reqMW []middleware.RequestMiddleware
}
func New(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTP {
return &http{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
func (hs *http) RemoteAddr() net.Addr {
return hs.remoteAddr
}
func (hs *http) UseResponseMiddleware(mw middleware.ResponseMiddleware) {
hs.respMW = append(hs.respMW, mw)
}
func (hs *http) UseRequestMiddleware(mw middleware.RequestMiddleware) {
hs.reqMW = append(hs.reqMW, mw)
}
func (hs *http) SetRequestHeader(header header.RequestHeader) {
hs.reqHeader = header
}
func (hs *http) RequestMiddlewares() []middleware.RequestMiddleware {
return hs.reqMW
}
func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
return hs.respMW
}
func (hs *http) Close() error {
if closer, ok := hs.writer.(io.Closer); ok {
return closer.Close()
}
return nil
}
func (hs *http) CloseWrite() error {
if closer, ok := hs.writer.(interface{ CloseWrite() error }); ok {
return closer.CloseWrite()
}
return hs.Close()
}
func (hs *http) ApplyRequestMiddlewares(reqhf header.RequestHeader) error {
for _, m := range hs.RequestMiddlewares() {
if err := m.HandleRequest(reqhf); err != nil {
log.Printf("Error when applying request middleware: %v", err)
return err
}
}
return nil
}
func (hs *http) ApplyResponseMiddlewares(resphf header.ResponseHeader, bodyByte []byte) error {
for _, m := range hs.ResponseMiddlewares() {
if err := m.HandleResponse(resphf, bodyByte); err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return err
}
}
return nil
}
+765
View File
@@ -0,0 +1,765 @@
package stream
import (
"bytes"
"io"
"strings"
"testing"
"tunnel_pls/internal/http/header"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockAddr struct {
mock.Mock
}
func (m *MockAddr) String() string {
args := m.Called()
return args.String(0)
}
func (m *MockAddr) Network() string {
args := m.Called()
return args.String(0)
}
type MockRequestMiddleware struct {
mock.Mock
}
func (m *MockRequestMiddleware) HandleRequest(h header.RequestHeader) error {
args := m.Called(h)
return args.Error(0)
}
type MockResponseMiddleware struct {
mock.Mock
}
func (m *MockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error {
args := m.Called(h, body)
return args.Error(0)
}
type MockReadWriter struct {
mock.Mock
bytes.Buffer
}
func (m *MockReadWriter) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockReadWriter) Write(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockReadWriter) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockReadWriter) CloseWrite() error {
args := m.Called()
return args.Error(0)
}
type MockReadWriterOnlyCloser struct {
mock.Mock
bytes.Buffer
}
func (m *MockReadWriterOnlyCloser) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockReadWriterOnlyCloser) Write(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockReadWriterOnlyCloser) Close() error {
args := m.Called()
return args.Error(0)
}
type MockWriterOnly struct {
mock.Mock
}
func (m *MockWriterOnly) Write(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockWriterOnly) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
type MockReader struct {
mock.Mock
}
func (m *MockReader) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
type MockWriter struct {
mock.Mock
}
func (m *MockWriter) Write(p []byte) (int, error) {
ret := m.Called(p)
var n int
var err error
switch v := ret.Get(0).(type) {
case func([]byte) int:
n = v(p)
case int:
n = v
default:
n = len(p)
}
switch v := ret.Get(1).(type) {
case func([]byte) error:
err = v(p)
case error:
err = v
default:
err = nil
}
return n, err
}
func (m *MockWriter) Close() error {
args := m.Called()
return args.Error(0)
}
func TestHTTPMethods(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
rw := new(MockReadWriter)
hs := New(rw, rw, addr)
assert.Equal(t, addr, hs.RemoteAddr())
reqMW := new(MockRequestMiddleware)
hs.UseRequestMiddleware(reqMW)
assert.Equal(t, 1, len(hs.RequestMiddlewares()))
assert.Equal(t, reqMW, hs.RequestMiddlewares()[0])
respMW := new(MockResponseMiddleware)
hs.UseResponseMiddleware(respMW)
assert.Equal(t, 1, len(hs.ResponseMiddlewares()))
assert.Equal(t, respMW, hs.ResponseMiddlewares()[0])
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
hs.SetRequestHeader(reqH)
}
func TestApplyMiddlewares(t *testing.T) {
tests := []struct {
name string
setup func(HTTP, *MockRequestMiddleware, *MockResponseMiddleware)
apply func(HTTP, header.RequestHeader, header.ResponseHeader) error
verify func(*testing.T, header.RequestHeader, header.ResponseHeader)
expectErr bool
}{
{
name: "apply request middleware success",
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.RequestHeader)
h.Set("X-Middleware", "true")
}).Return(nil)
hs.UseRequestMiddleware(reqMW)
},
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyRequestMiddlewares(reqH)
},
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
assert.Equal(t, "true", reqH.Value("X-Middleware"))
},
},
{
name: "apply response middleware success",
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.ResponseHeader)
h.Set("X-Resp-Middleware", "true")
}).Return(nil)
hs.UseResponseMiddleware(respMW)
},
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
},
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
assert.Equal(t, "true", respH.Value("X-Resp-Middleware"))
},
},
{
name: "apply request middleware error",
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
reqMW.On("HandleRequest", mock.Anything).Return(assert.AnError)
hs.UseRequestMiddleware(reqMW)
},
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyRequestMiddlewares(reqH)
},
expectErr: true,
},
{
name: "apply response middleware error",
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(assert.AnError)
hs.UseResponseMiddleware(respMW)
},
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
},
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n"))
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
rw := new(MockReadWriter)
hs := New(rw, rw, addr)
reqMW := new(MockRequestMiddleware)
respMW := new(MockResponseMiddleware)
tt.setup(hs, reqMW, respMW)
err := tt.apply(hs, reqH, respH)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.verify != nil {
tt.verify(t, reqH, respH)
}
}
reqMW.AssertExpectations(t)
respMW.AssertExpectations(t)
})
}
}
func TestCloseMethods(t *testing.T) {
tests := []struct {
name string
setup func() (io.Writer, io.Reader)
op func(HTTP) error
verify func(*testing.T, io.Writer)
}{
{
name: "Close success",
setup: func() (io.Writer, io.Reader) {
rw := new(MockReadWriter)
rw.On("Close").Return(nil)
return rw, rw
},
op: func(hs HTTP) error { return hs.Close() },
verify: func(t *testing.T, w io.Writer) {
w.(*MockReadWriter).AssertCalled(t, "Close")
},
},
{
name: "CloseWrite with CloseWrite implementation",
setup: func() (io.Writer, io.Reader) {
rw := new(MockReadWriter)
rw.On("CloseWrite").Return(nil)
return rw, rw
},
op: func(hs HTTP) error { return hs.CloseWrite() },
verify: func(t *testing.T, w io.Writer) {
w.(*MockReadWriter).AssertCalled(t, "CloseWrite")
},
},
{
name: "CloseWrite fallback to Close",
setup: func() (io.Writer, io.Reader) {
rw := new(MockReadWriterOnlyCloser)
rw.On("Close").Return(nil)
return rw, rw
},
op: func(hs HTTP) error { return hs.CloseWrite() },
verify: func(t *testing.T, w io.Writer) {
w.(*MockReadWriterOnlyCloser).AssertCalled(t, "Close")
},
},
{
name: "Close with No Closer",
setup: func() (io.Writer, io.Reader) {
w := new(MockWriterOnly)
r := new(MockReader)
return w, r
},
op: func(hs HTTP) error { return hs.Close() },
},
{
name: "CloseWrite with No CloseWrite and No Closer",
setup: func() (io.Writer, io.Reader) {
w := new(MockWriterOnly)
r := new(MockReader)
return w, r
},
op: func(hs HTTP) error { return hs.CloseWrite() },
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
w, r := tt.setup()
hs := New(w, r, addr)
assert.NotPanics(t, func() {
err := tt.op(hs)
assert.NoError(t, err)
})
if tt.verify != nil {
tt.verify(t, w)
}
})
}
}
func TestSplitHeaderAndBody(t *testing.T) {
tests := []struct {
name string
data []byte
delimiterIdx int
expectHeader []byte
expectBody []byte
}{
{
name: "standard",
data: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nBodyContent"),
delimiterIdx: 31,
expectHeader: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"),
expectBody: []byte("BodyContent"),
},
{
name: "empty body",
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
delimiterIdx: 15,
expectHeader: []byte("HTTP/1.1 200 OK\r\n\r\n"),
expectBody: []byte(""),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h, b := splitHeaderAndBody(tt.data, tt.delimiterIdx)
assert.Equal(t, tt.expectHeader, h)
assert.Equal(t, tt.expectBody, b)
})
}
}
func TestIsHTTPHeader(t *testing.T) {
tests := []struct {
name string
buf []byte
expect bool
}{
{
name: "valid request",
buf: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n"),
expect: true,
},
{
name: "valid response",
buf: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"),
expect: true,
},
{
name: "invalid start line",
buf: []byte("NOT_HTTP /path\r\nHost: example.com\r\n\r\n"),
expect: false,
},
{
name: "invalid header line (no colon)",
buf: []byte("GET / HTTP/1.1\r\nInvalidHeaderLine\r\n\r\n"),
expect: false,
},
{
name: "invalid header line (colon at 0)",
buf: []byte("GET / HTTP/1.1\r\n: value\r\n\r\n"),
expect: false,
},
{
name: "empty header section",
buf: []byte("GET / HTTP/1.1\r\n\r\n"),
expect: true,
},
{
name: "multiple headers",
buf: []byte("GET / HTTP/1.1\r\nH1: V1\r\nH2: V2\r\n\r\n"),
expect: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isHTTPHeader(tt.buf)
assert.Equal(t, tt.expect, result)
})
}
}
func TestRead(t *testing.T) {
tests := []struct {
name string
input []byte
readLen int
expectContent string
expectRead int
expectErr bool
middlewareErr error
isHTTP bool
}{
{
name: "valid http request",
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\nBody"),
readLen: 100,
expectContent: "Body",
expectRead: 54,
isHTTP: true,
},
{
name: "non-http data",
input: []byte("Some random data\r\n\r\nMore data"),
readLen: 100,
expectContent: "Some random data\r\n\r\nMore data",
expectRead: 29,
isHTTP: false,
},
{
name: "no delimiter",
input: []byte("Partial data without delimiter"),
readLen: 100,
expectContent: "Partial data without delimiter",
expectRead: 30,
isHTTP: false,
},
{
name: "middleware error",
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\n"),
readLen: 100,
middlewareErr: assert.AnError,
expectErr: true,
isHTTP: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
reader := new(MockReader)
writer := new(MockWriterOnly)
if tt.expectErr || tt.name == "valid http request" {
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
copy(p, tt.input)
}).Return(len(tt.input), io.EOF).Once()
} else {
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
copy(p, tt.input)
}).Return(len(tt.input), nil).Once()
}
hs := New(writer, reader, addr)
reqMW := new(MockRequestMiddleware)
if tt.isHTTP {
if tt.middlewareErr != nil {
reqMW.On("HandleRequest", mock.Anything).Return(tt.middlewareErr)
} else {
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.RequestHeader)
h.Set("X-Middleware", "true")
}).Return(nil)
}
}
hs.UseRequestMiddleware(reqMW)
p := make([]byte, tt.readLen)
n, err := hs.Read(p)
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectRead, n)
if tt.name == "valid http request" {
content := string(p[:n])
assert.Contains(t, content, "GET / HTTP/1.1\r\n")
assert.Contains(t, content, "Host: test\r\n")
assert.Contains(t, content, "X-Middleware: true\r\n")
assert.True(t, bytes.HasSuffix(p[:n], []byte("\r\n\r\nBody")))
} else {
assert.Equal(t, tt.expectContent, string(p[:n]))
}
}
if tt.isHTTP {
reqMW.AssertExpectations(t)
}
reader.AssertExpectations(t)
})
}
}
func TestWrite(t *testing.T) {
tests := []struct {
name string
writes [][]byte
expectWritten string
expectErr bool
middlewareErr error
isHTTP bool
}{
{
name: "valid http response in one write",
writes: [][]byte{
[]byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"),
},
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
isHTTP: true,
},
{
name: "valid http response in multiple writes",
writes: [][]byte{
[]byte("HTTP/1.1 200 OK\r\n"),
[]byte("Content-Length: 4\r\n\r\n"),
[]byte("Body"),
},
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
isHTTP: true,
},
{
name: "non-http data",
writes: [][]byte{
[]byte("Random data with delimiter\r\n\r\nFlush"),
},
expectWritten: "Random data with delimiter\r\n\r\nFlush",
isHTTP: false,
},
{
name: "bypass buffering",
writes: [][]byte{
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
},
expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" +
"HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n",
isHTTP: true,
},
{
name: "middleware error",
writes: [][]byte{
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
},
middlewareErr: assert.AnError,
expectErr: true,
isHTTP: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
var writtenData bytes.Buffer
writer := new(MockWriter)
writer.On("Write", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
writtenData.Write(p)
}).Return(func(p []byte) int {
return len(p)
}, nil)
reader := new(MockReader)
hs := New(writer, reader, addr)
respMW := new(MockResponseMiddleware)
if tt.isHTTP {
if tt.middlewareErr != nil {
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(tt.middlewareErr)
} else {
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.ResponseHeader)
h.Set("X-Resp-Middleware", "true")
}).Return(nil)
}
}
hs.UseResponseMiddleware(respMW)
var totalN int
var err error
for _, w := range tt.writes {
var n int
n, err = hs.Write(w)
if err != nil {
break
}
totalN += n
}
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
written := writtenData.String()
if strings.HasPrefix(tt.expectWritten, "HTTP/") {
assert.Contains(t, written, "HTTP/1.1 200 OK\r\n")
assert.Contains(t, written, "X-Resp-Middleware: true\r\n")
if strings.Contains(tt.expectWritten, "Content-Length: 4") {
assert.Contains(t, written, "Content-Length: 4\r\n")
}
assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n"))
} else {
assert.Equal(t, tt.expectWritten, written)
}
}
if tt.isHTTP {
respMW.AssertExpectations(t)
}
if tt.middlewareErr == nil {
writer.AssertExpectations(t)
}
})
}
}
func TestWriteErrors(t *testing.T) {
tests := []struct {
name string
setup func() (io.Writer, io.Reader)
data []byte
}{
{
name: "write error in writeHeaderAndBody",
setup: func() (io.Writer, io.Reader) {
writer := new(MockWriter)
writer.On("Write", mock.Anything).Return(0, assert.AnError)
reader := new(MockReader)
return writer, reader
},
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
},
{
name: "write error in writeHeaderAndBody second write",
setup: func() (io.Writer, io.Reader) {
writer := new(MockWriter)
writer.On("Write", mock.Anything).Return(len([]byte("HTTP/1.1 200 OK\r\n\r\n")), nil).Once()
writer.On("Write", mock.Anything).Return(0, assert.AnError).Once()
reader := new(MockReader)
return writer, reader
},
data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"),
},
{
name: "write error in writeRawBuffer",
setup: func() (io.Writer, io.Reader) {
writer := new(MockWriter)
writer.On("Write", mock.Anything).Return(0, assert.AnError)
reader := new(MockReader)
return writer, reader
},
data: []byte("Not HTTP\r\n\r\nFlush"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
w, r := tt.setup()
hs := New(w, r, addr)
_, err := hs.Write(tt.data)
assert.Error(t, err)
w.(*MockWriter).AssertExpectations(t)
})
}
}
func TestReadEOF(t *testing.T) {
tests := []struct {
name string
setup func() io.Reader
expectN int
expectErr error
expectContent string
}{
{
name: "read eof",
setup: func() io.Reader {
reader := new(MockReader)
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
copy(p, "data")
}).Return(4, io.EOF)
return reader
},
expectN: 4,
expectErr: io.EOF,
expectContent: "data",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
reader := tt.setup()
hs := New(nil, reader, addr)
p := make([]byte, 100)
n, err := hs.Read(p)
assert.Equal(t, tt.expectN, n)
assert.Equal(t, tt.expectErr, err)
assert.Equal(t, tt.expectContent, string(p[:n]))
reader.(*MockReader).AssertExpectations(t)
})
}
}
+88
View File
@@ -0,0 +1,88 @@
package stream
import (
"bytes"
"tunnel_pls/internal/http/header"
)
func (hs *http) Write(p []byte) (int, error) {
if hs.shouldBypassBuffering(p) {
hs.respHeader = nil
}
if hs.respHeader != nil {
return hs.writer.Write(p)
}
hs.buf = append(hs.buf, p...)
headerEndIdx := bytes.Index(hs.buf, DELIMITER)
if headerEndIdx == -1 {
return len(p), nil
}
return hs.processBufferedResponse(p, headerEndIdx)
}
func (hs *http) shouldBypassBuffering(p []byte) bool {
return hs.respHeader != nil && len(hs.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/"
}
func (hs *http) processBufferedResponse(p []byte, delimiterIdx int) (int, error) {
headerByte, bodyByte := splitHeaderAndBody(hs.buf, delimiterIdx)
if !isHTTPHeader(headerByte) {
return hs.writeRawBuffer()
}
if err := hs.processHTTPResponse(headerByte, bodyByte); err != nil {
return 0, err
}
hs.buf = nil
return len(p), nil
}
func (hs *http) writeRawBuffer() (int, error) {
_, err := hs.writer.Write(hs.buf)
length := len(hs.buf)
hs.buf = nil
if err != nil {
return 0, err
}
return length, nil
}
func (hs *http) processHTTPResponse(headerByte, bodyByte []byte) error {
resphf, err := header.NewResponse(headerByte)
if err != nil {
return err
}
if err = hs.ApplyResponseMiddlewares(resphf, bodyByte); err != nil {
return err
}
hs.respHeader = resphf
finalHeader := resphf.Finalize()
if err = hs.writeHeaderAndBody(finalHeader, bodyByte); err != nil {
return err
}
return nil
}
func (hs *http) writeHeaderAndBody(header, bodyByte []byte) error {
if _, err := hs.writer.Write(header); err != nil {
return err
}
if len(bodyByte) > 0 {
if _, err := hs.writer.Write(bodyByte); err != nil {
return err
}
}
return nil
}
+28 -9
View File
@@ -5,6 +5,8 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"io"
"log"
"os"
"path/filepath"
@@ -12,7 +14,20 @@ import (
"golang.org/x/crypto/ssh"
)
var (
rsaGenerateKey = rsa.GenerateKey
pemEncode = pem.Encode
sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
return ssh.NewPublicKey(key)
}
pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
return w.Write(data)
}
osOpenFile = os.OpenFile
)
func GenerateSSHKeyIfNotExist(keyPath string) error {
var errGroup = make([]error, 0)
if _, err := os.Stat(keyPath); err == nil {
log.Printf("SSH key already exists at %s", keyPath)
return nil
@@ -20,7 +35,7 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
log.Printf("SSH key not found at %s, generating new key pair...", keyPath)
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
privateKey, err := rsaGenerateKey(rand.Reader, 4096)
if err != nil {
return err
}
@@ -35,33 +50,37 @@ func GenerateSSHKeyIfNotExist(keyPath string) error {
return err
}
privateKeyFile, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
privateKeyFile, err := osOpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
if err != nil {
return err
}
defer privateKeyFile.Close()
defer func(privateKeyFile *os.File) {
errGroup = append(errGroup, privateKeyFile.Close())
}(privateKeyFile)
if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil {
if err := pemEncode(privateKeyFile, privateKeyPEM); err != nil {
return err
}
publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey)
publicKey, err := sshNewPublicKey(&privateKey.PublicKey)
if err != nil {
return err
}
pubKeyPath := keyPath + ".pub"
pubKeyFile, err := os.OpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
pubKeyFile, err := osOpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return err
}
defer pubKeyFile.Close()
defer func(pubKeyFile *os.File) {
errGroup = append(errGroup, pubKeyFile.Close())
}(pubKeyFile)
_, err = pubKeyFile.Write(ssh.MarshalAuthorizedKey(publicKey))
_, err = pubKeyWrite(pubKeyFile, ssh.MarshalAuthorizedKey(publicKey))
if err != nil {
return err
}
log.Printf("SSH key pair generated successfully at %s and %s", keyPath, pubKeyPath)
return nil
return errors.Join(errGroup...)
}
+235
View File
@@ -0,0 +1,235 @@
package key
import (
"crypto/rsa"
"encoding/pem"
"errors"
"io"
"os"
"path/filepath"
"testing"
"golang.org/x/crypto/ssh"
)
func TestGenerateSSHKeyIfNotExist(t *testing.T) {
tempDir := t.TempDir()
tests := []struct {
name string
setup func(t *testing.T, tempDir string) string
mockSetup func() func()
wantErr bool
errStr string
verify func(t *testing.T, keyPath string)
}{
{
name: "GenerateNewKey",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "id_rsa")
},
verify: func(t *testing.T, keyPath string) {
pubKeyPath := keyPath + ".pub"
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
t.Errorf("Private key file not created")
}
if _, err := os.Stat(pubKeyPath); os.IsNotExist(err) {
t.Errorf("Public key file not created")
}
privateKeyBytes, err := os.ReadFile(keyPath)
if err != nil {
t.Fatalf("Failed to read private key: %v", err)
}
if _, err = ssh.ParseRawPrivateKey(privateKeyBytes); err != nil {
t.Errorf("Failed to parse private key: %v", err)
}
publicKeyBytes, err := os.ReadFile(pubKeyPath)
if err != nil {
t.Fatalf("Failed to read public key: %v", err)
}
if _, _, _, _, err = ssh.ParseAuthorizedKey(publicKeyBytes); err != nil {
t.Errorf("Failed to parse public key: %v", err)
}
},
},
{
name: "DoNotOverwriteExistingKey",
setup: func(t *testing.T, tempDir string) string {
keyPath := filepath.Join(tempDir, "existing_id_rsa")
dummyPrivate := "dummy private"
dummyPublic := "dummy public"
if err := os.WriteFile(keyPath, []byte(dummyPrivate), 0600); err != nil {
t.Fatalf("Failed to create dummy private key: %v", err)
}
if err := os.WriteFile(keyPath+".pub", []byte(dummyPublic), 0644); err != nil {
t.Fatalf("Failed to create dummy public key: %v", err)
}
return keyPath
},
verify: func(t *testing.T, keyPath string) {
gotPrivate, _ := os.ReadFile(keyPath)
if string(gotPrivate) != "dummy private" {
t.Errorf("Private key was overwritten")
}
gotPublic, _ := os.ReadFile(keyPath + ".pub")
if string(gotPublic) != "dummy public" {
t.Errorf("Public key was overwritten")
}
},
},
{
name: "CreateNestedDirectories",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "nested", "dir", "id_rsa")
},
verify: func(t *testing.T, keyPath string) {
if _, err := os.Stat(keyPath); os.IsNotExist(err) {
t.Errorf("Private key file not created in nested directory")
}
},
},
{
name: "FailureMkdirAll",
setup: func(t *testing.T, tempDir string) string {
dirPath := filepath.Join(tempDir, "file_as_dir")
if err := os.WriteFile(dirPath, []byte("not a dir"), 0644); err != nil {
t.Fatalf("Failed to create file: %v", err)
}
return filepath.Join(dirPath, "id_rsa")
},
wantErr: true,
},
{
name: "PrivateExistsPublicMissing",
setup: func(t *testing.T, tempDir string) string {
keyPath := filepath.Join(tempDir, "partial_id_rsa")
if err := os.WriteFile(keyPath, []byte("private"), 0600); err != nil {
t.Fatalf("Failed to create private key: %v", err)
}
return keyPath
},
verify: func(t *testing.T, keyPath string) {
if _, err := os.Stat(keyPath + ".pub"); !os.IsNotExist(err) {
t.Errorf("Public key should NOT have been created if private key existed")
}
},
},
{
name: "FailureRSAGenerateKey",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_rsa")
},
mockSetup: func() func() {
old := rsaGenerateKey
rsaGenerateKey = func(random io.Reader, bits int) (*rsa.PrivateKey, error) {
return nil, errors.New("rsa error")
}
return func() { rsaGenerateKey = old }
},
wantErr: true,
errStr: "rsa error",
},
{
name: "FailureOpenFilePrivate",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_open_private")
},
mockSetup: func() func() {
old := osOpenFile
osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
return nil, errors.New("open error")
}
return func() { osOpenFile = old }
},
wantErr: true,
errStr: "open error",
},
{
name: "FailurePemEncode",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_pem")
},
mockSetup: func() func() {
old := pemEncode
pemEncode = func(out io.Writer, b *pem.Block) error {
return errors.New("pem error")
}
return func() { pemEncode = old }
},
wantErr: true,
errStr: "pem error",
},
{
name: "FailureSSHNewPublicKey",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_ssh")
},
mockSetup: func() func() {
old := sshNewPublicKey
sshNewPublicKey = func(key interface{}) (ssh.PublicKey, error) {
return nil, errors.New("ssh error")
}
return func() { sshNewPublicKey = old }
},
wantErr: true,
errStr: "ssh error",
},
{
name: "FailureOpenFilePublic",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_open_public")
},
mockSetup: func() func() {
old := osOpenFile
osOpenFile = func(name string, flag int, perm os.FileMode) (*os.File, error) {
if filepath.Ext(name) == ".pub" {
return nil, errors.New("open pub error")
}
return os.OpenFile(name, flag, perm)
}
return func() { osOpenFile = old }
},
wantErr: true,
errStr: "open pub error",
},
{
name: "FailurePubKeyWrite",
setup: func(t *testing.T, tempDir string) string {
return filepath.Join(tempDir, "fail_write")
},
mockSetup: func() func() {
old := pubKeyWrite
pubKeyWrite = func(w io.Writer, data []byte) (int, error) {
return 0, errors.New("write error")
}
return func() { pubKeyWrite = old }
},
wantErr: true,
errStr: "write error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
keyPath := tt.setup(t, tempDir)
if tt.mockSetup != nil {
cleanup := tt.mockSetup()
defer cleanup()
}
err := GenerateSSHKeyIfNotExist(keyPath)
if (err != nil) != tt.wantErr {
t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr && tt.errStr != "" && err != nil && err.Error() != tt.errStr {
t.Errorf("GenerateSSHKeyIfNotExist() error = %v, wantErrStr %v", err, tt.errStr)
}
if tt.verify != nil {
tt.verify(t, keyPath)
}
})
}
}
+23
View File
@@ -0,0 +1,23 @@
package middleware
import (
"net"
"tunnel_pls/internal/http/header"
)
type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header header.RequestHeader) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
+126
View File
@@ -0,0 +1,126 @@
package middleware
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type mockRequestHeader struct {
mock.Mock
}
func (m *mockRequestHeader) Value(key string) string {
return m.Called(key).String(0)
}
func (m *mockRequestHeader) Set(key string, value string) {
m.Called(key, value)
}
func (m *mockRequestHeader) Remove(key string) {
m.Called(key)
}
func (m *mockRequestHeader) Finalize() []byte {
return m.Called().Get(0).([]byte)
}
func (m *mockRequestHeader) Method() string {
return m.Called().String(0)
}
func (m *mockRequestHeader) Path() string {
return m.Called().String(0)
}
func (m *mockRequestHeader) Version() string {
return m.Called().String(0)
}
func TestForwardedFor_HandleRequest(t *testing.T) {
tests := []struct {
name string
addr net.Addr
expectedHost string
expectError bool
}{
{
name: "valid IPv4 address",
addr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 8080},
expectedHost: "192.168.1.100",
expectError: false,
},
{
name: "valid IPv6 address",
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 8080},
expectedHost: "2001:db8::ff00:42:8329",
expectError: false,
},
{
name: "invalid address format",
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
expectedHost: "",
expectError: true,
},
{
name: "valid IPv4 address with port",
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
expectedHost: "127.0.0.1",
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ff := NewForwardedFor(tc.addr)
reqHeader := new(mockRequestHeader)
if !tc.expectError {
reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return()
}
err := ff.HandleRequest(reqHeader)
if tc.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
reqHeader.AssertExpectations(t)
}
})
}
}
func TestNewForwardedFor(t *testing.T) {
tests := []struct {
name string
addr net.Addr
expectAddr net.Addr
}{
{
name: "IPv4 address",
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
expectAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
},
{
name: "IPv6 address",
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
expectAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
},
{
name: "Unix address",
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
expectAddr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ff := NewForwardedFor(tc.addr)
assert.Equal(t, tc.expectAddr.String(), ff.addr.String())
})
}
}
+13
View File
@@ -0,0 +1,13 @@
package middleware
import (
"tunnel_pls/internal/http/header"
)
type RequestMiddleware interface {
HandleRequest(header header.RequestHeader) error
}
type ResponseMiddleware interface {
HandleResponse(header header.ResponseHeader, body []byte) error
}
+16
View File
@@ -0,0 +1,16 @@
package middleware
import (
"tunnel_pls/internal/http/header"
)
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleResponse(header header.ResponseHeader, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
@@ -0,0 +1,70 @@
package middleware
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type mockResponseHeader struct {
mock.Mock
}
func (m *mockResponseHeader) Value(key string) string {
return m.Called(key).String(0)
}
func (m *mockResponseHeader) Set(key string, value string) {
m.Called(key, value)
}
func (m *mockResponseHeader) Remove(key string) {
m.Called(key)
}
func (m *mockResponseHeader) Finalize() []byte {
return m.Called().Get(0).([]byte)
}
func TestTunnelFingerprintHandleResponse(t *testing.T) {
tests := []struct {
name string
expected map[string]string
body []byte
wantErr error
}{
{
name: "Sets Server Header",
expected: map[string]string{"Server": "Tunnel Please"},
body: []byte("Sample body"),
wantErr: nil,
},
{
name: "Overwrites Server Header",
expected: map[string]string{"Server": "Tunnel Please"},
body: nil,
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockHeader := new(mockResponseHeader)
for k, v := range tt.expected {
mockHeader.On("Set", k, v).Return()
}
tunnelFingerprint := NewTunnelFingerprint()
err := tunnelFingerprint.HandleResponse(mockHeader, tt.body)
assert.ErrorIs(t, err, tt.wantErr)
mockHeader.AssertExpectations(t)
})
}
}
func TestNewTunnelFingerprint(t *testing.T) {
instance := NewTunnelFingerprint()
assert.NotNil(t, instance)
}
+36 -49
View File
@@ -3,63 +3,40 @@ package port
import (
"fmt"
"sort"
"strconv"
"strings"
"sync"
"tunnel_pls/internal/config"
)
type Manager interface {
AddPortRange(startPort, endPort uint16) error
GetUnassignedPort() (uint16, bool)
SetPortStatus(port uint16, assigned bool) error
GetPortStatus(port uint16) (bool, bool)
type Port interface {
AddRange(startPort, endPort uint16) error
Unassigned() (uint16, bool)
SetStatus(port uint16, assigned bool) error
Claim(port uint16) (claimed bool)
}
type manager struct {
type port struct {
mu sync.RWMutex
ports map[uint16]bool
sortedPorts []uint16
}
var Default Manager = &manager{
ports: make(map[uint16]bool),
sortedPorts: []uint16{},
func New() Port {
return &port{
ports: make(map[uint16]bool),
sortedPorts: []uint16{},
}
}
func init() {
rawRange := config.Getenv("ALLOWED_PORTS", "")
if rawRange == "" {
return
}
splitRange := strings.Split(rawRange, "-")
if len(splitRange) != 2 {
return
}
start, err := strconv.ParseUint(splitRange[0], 10, 16)
if err != nil {
return
}
end, err := strconv.ParseUint(splitRange[1], 10, 16)
if err != nil {
return
}
_ = Default.AddPortRange(uint16(start), uint16(end))
}
func (pm *manager) AddPortRange(startPort, endPort uint16) error {
func (pm *port) AddRange(startPort, endPort uint16) error {
pm.mu.Lock()
defer pm.mu.Unlock()
if startPort > endPort {
return fmt.Errorf("start port cannot be greater than end port")
}
for port := startPort; port <= endPort; port++ {
if _, exists := pm.ports[port]; !exists {
pm.ports[port] = false
pm.sortedPorts = append(pm.sortedPorts, port)
for index := startPort; index <= endPort; index++ {
if _, exists := pm.ports[index]; !exists {
pm.ports[index] = false
pm.sortedPorts = append(pm.sortedPorts, index)
}
}
sort.Slice(pm.sortedPorts, func(i, j int) bool {
@@ -68,20 +45,19 @@ func (pm *manager) AddPortRange(startPort, endPort uint16) error {
return nil
}
func (pm *manager) GetUnassignedPort() (uint16, bool) {
func (pm *port) Unassigned() (uint16, bool) {
pm.mu.Lock()
defer pm.mu.Unlock()
for _, port := range pm.sortedPorts {
if !pm.ports[port] {
pm.ports[port] = true
return port, true
for _, index := range pm.sortedPorts {
if !pm.ports[index] {
return index, true
}
}
return 0, false
}
func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
func (pm *port) SetStatus(port uint16, assigned bool) error {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -89,10 +65,21 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
return nil
}
func (pm *manager) GetPortStatus(port uint16) (bool, bool) {
pm.mu.RLock()
defer pm.mu.RUnlock()
func (pm *port) Claim(port uint16) (claimed bool) {
pm.mu.Lock()
defer pm.mu.Unlock()
status, exists := pm.ports[port]
return status, exists
if exists && status {
return false
}
if !exists {
pm.ports[port] = true
return true
}
pm.ports[port] = true
return true
}
+114
View File
@@ -0,0 +1,114 @@
package port
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAddRange(t *testing.T) {
tests := []struct {
name string
startPort uint16
endPort uint16
wantErr bool
}{
{"normal range", 1000, 1002, false},
{"invalid range", 2000, 1999, true},
{"single port range", 3000, 3000, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm := New()
err := pm.AddRange(tt.startPort, tt.endPort)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestUnassigned(t *testing.T) {
pm := New()
_ = pm.AddRange(1000, 1002)
tests := []struct {
name string
status map[uint16]bool
want uint16
wantOk bool
}{
{"all unassigned", map[uint16]bool{1000: false, 1001: false, 1002: false}, 1000, true},
{"some assigned", map[uint16]bool{1000: true, 1001: false, 1002: true}, 1001, true},
{"all assigned", map[uint16]bool{1000: true, 1001: true, 1002: true}, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for k, v := range tt.status {
_ = pm.SetStatus(k, v)
}
got, gotOk := pm.Unassigned()
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.wantOk, gotOk)
})
}
}
func TestSetStatus(t *testing.T) {
pm := New()
_ = pm.AddRange(1000, 1002)
tests := []struct {
name string
port uint16
assigned bool
}{
{"assign port 1000", 1000, true},
{"unassign port 1001", 1001, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := pm.SetStatus(tt.port, tt.assigned)
assert.NoError(t, err)
status, ok := pm.(*port).ports[tt.port]
assert.True(t, ok)
assert.Equal(t, tt.assigned, status)
})
}
}
func TestClaim(t *testing.T) {
pm := New()
_ = pm.AddRange(1000, 1002)
tests := []struct {
name string
port uint16
status bool
want bool
}{
{"claim unassigned port", 1000, false, true},
{"claim already assigned port", 1001, true, false},
{"claim non-existent port", 5000, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, exists := pm.(*port).ports[tt.port]; exists {
_ = pm.SetStatus(tt.port, tt.status)
}
got := pm.Claim(tt.port)
assert.Equal(t, tt.want, got)
finalState := pm.(*port).ports[tt.port]
assert.True(t, finalState)
})
}
}
+35 -12
View File
@@ -1,18 +1,41 @@
package random
import (
mathrand "math/rand"
"strings"
"time"
"crypto/rand"
"fmt"
"io"
)
func GenerateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz"
seededRand := mathrand.New(mathrand.NewSource(time.Now().UnixNano() + int64(mathrand.Intn(9999))))
var result strings.Builder
for i := 0; i < length; i++ {
randomIndex := seededRand.Intn(len(charset))
result.WriteString(string(charset[randomIndex]))
}
return result.String()
var (
ErrInvalidLength = fmt.Errorf("invalid length")
)
type Random interface {
String(length int) (string, error)
}
type random struct {
reader io.Reader
}
func New() Random {
return &random{reader: rand.Reader}
}
func (ran *random) String(length int) (string, error) {
if length < 0 {
return "", ErrInvalidLength
}
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, length)
if _, err := ran.reader.Read(b); err != nil {
return "", err
}
for i := range b {
b[i] = charset[int(b[i])%len(charset)]
}
return string(b), nil
}
+70
View File
@@ -0,0 +1,70 @@
package random
import (
"io"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRandom_String(t *testing.T) {
tests := []struct {
name string
length int
wantErr bool
}{
{"ValidLengthZero", 0, false},
{"ValidPositiveLength", 10, false},
{"NegativeLength", -1, true},
{"VeryLargeLength", 1_000_000, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
randomizer := New()
result, err := randomizer.String(tt.length)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Len(t, result, tt.length)
}
})
}
}
func TestRandomWithFailingReader_String(t *testing.T) {
errBrainrot := assert.AnError
tests := []struct {
name string
reader io.Reader
expectErr error
}{
{
name: "failing reader",
reader: func() io.Reader {
return &failingReader{err: errBrainrot}
}(),
expectErr: errBrainrot,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
randomizer := &random{reader: tt.reader}
result, err := randomizer.String(20)
assert.ErrorIs(t, err, tt.expectErr)
assert.Empty(t, result)
})
}
}
type failingReader struct {
err error
}
func (f *failingReader) Read(p []byte) (int, error) {
return 0, f.err
}
+328
View File
@@ -0,0 +1,328 @@
package registry
import (
"fmt"
"sync"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
)
type Key = types.SessionKey
type Session interface {
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *types.Detail
}
type Registry interface {
Get(key Key) (session Session, err error)
GetWithUser(user string, key Key) (session Session, err error)
Update(user string, oldKey, newKey Key) error
Register(key Key, session Session) (success bool)
Remove(key Key)
GetAllSessionFromUser(user string) []Session
}
type registry struct {
mu sync.RWMutex
byUser map[string]map[Key]Session
slugIndex map[Key]string
}
var (
ErrSessionNotFound = fmt.Errorf("session not found")
ErrSlugInUse = fmt.Errorf("slug already in use")
ErrInvalidSlug = fmt.Errorf("invalid slug")
ErrForbiddenSlug = fmt.Errorf("forbidden slug")
ErrSlugChangeNotAllowed = fmt.Errorf("slug change not allowed for this tunnel type")
ErrSlugUnchanged = fmt.Errorf("slug is unchanged")
)
func NewRegistry() Registry {
return &registry{
byUser: make(map[string]map[Key]Session),
slugIndex: make(map[Key]string),
}
}
func (r *registry) Get(key Key) (session Session, err error) {
r.mu.RLock()
defer r.mu.RUnlock()
userID, ok := r.slugIndex[key]
if !ok {
return nil, ErrSessionNotFound
}
client, ok := r.byUser[userID][key]
if !ok {
return nil, ErrSessionNotFound
}
return client, nil
}
func (r *registry) GetWithUser(user string, key Key) (session Session, err error) {
r.mu.RLock()
defer r.mu.RUnlock()
client, ok := r.byUser[user][key]
if !ok {
return nil, ErrSessionNotFound
}
return client, nil
}
func (r *registry) Update(user string, oldKey, newKey Key) error {
if oldKey.Type != newKey.Type {
return ErrSlugUnchanged
}
if newKey.Type != types.TunnelTypeHTTP {
return ErrSlugChangeNotAllowed
}
if isForbiddenSlug(newKey.Id) {
return ErrForbiddenSlug
}
if !isValidSlug(newKey.Id) {
return ErrInvalidSlug
}
if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey {
return ErrSlugInUse
}
r.mu.Lock()
defer r.mu.Unlock()
client, ok := r.byUser[user][oldKey]
if !ok {
return ErrSessionNotFound
}
delete(r.byUser[user], oldKey)
delete(r.slugIndex, oldKey)
client.Slug().Set(newKey.Id)
r.slugIndex[newKey] = user
r.byUser[user][newKey] = client
return nil
}
func (r *registry) Register(key Key, userSession Session) (success bool) {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.slugIndex[key]; exists {
return false
}
userID := userSession.Lifecycle().User()
if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]Session)
}
r.byUser[userID][key] = userSession
r.slugIndex[key] = userID
return true
}
func (r *registry) GetAllSessionFromUser(user string) []Session {
r.mu.RLock()
defer r.mu.RUnlock()
m := r.byUser[user]
if len(m) == 0 {
return []Session{}
}
sessions := make([]Session, 0, len(m))
for _, s := range m {
sessions = append(sessions, s)
}
return sessions
}
func (r *registry) Remove(key Key) {
r.mu.Lock()
defer r.mu.Unlock()
userID, ok := r.slugIndex[key]
if !ok {
return
}
delete(r.byUser[userID], key)
if len(r.byUser[userID]) == 0 {
delete(r.byUser, userID)
}
delete(r.slugIndex, key)
}
func isValidSlug(slug string) bool {
if len(slug) < minSlugLength || len(slug) > maxSlugLength {
return false
}
if slug[0] == '-' || slug[len(slug)-1] == '-' {
return false
}
for _, c := range slug {
if !isValidSlugChar(byte(c)) {
return false
}
}
return true
}
func isValidSlugChar(c byte) bool {
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-'
}
func isForbiddenSlug(slug string) bool {
_, ok := forbiddenSlugs[slug]
return ok
}
var forbiddenSlugs = map[string]struct{}{
"ping": {},
"staging": {},
"admin": {},
"root": {},
"api": {},
"www": {},
"support": {},
"help": {},
"status": {},
"health": {},
"login": {},
"logout": {},
"signup": {},
"register": {},
"settings": {},
"config": {},
"null": {},
"undefined": {},
"example": {},
"test": {},
"dev": {},
"system": {},
"administrator": {},
"dashboard": {},
"account": {},
"profile": {},
"user": {},
"users": {},
"auth": {},
"oauth": {},
"callback": {},
"webhook": {},
"webhooks": {},
"static": {},
"assets": {},
"cdn": {},
"mail": {},
"email": {},
"ftp": {},
"ssh": {},
"git": {},
"svn": {},
"blog": {},
"news": {},
"about": {},
"contact": {},
"terms": {},
"privacy": {},
"legal": {},
"billing": {},
"payment": {},
"checkout": {},
"cart": {},
"shop": {},
"store": {},
"download": {},
"uploads": {},
"images": {},
"img": {},
"css": {},
"js": {},
"fonts": {},
"public": {},
"private": {},
"internal": {},
"external": {},
"proxy": {},
"cache": {},
"debug": {},
"metrics": {},
"monitoring": {},
"graphql": {},
"rest": {},
"rpc": {},
"socket": {},
"ws": {},
"wss": {},
"app": {},
"apps": {},
"mobile": {},
"desktop": {},
"embed": {},
"widget": {},
"docs": {},
"documentation": {},
"wiki": {},
"forum": {},
"community": {},
"feedback": {},
"report": {},
"abuse": {},
"spam": {},
"security": {},
"verify": {},
"confirm": {},
"reset": {},
"password": {},
"recovery": {},
"unsubscribe": {},
"subscribe": {},
"notifications": {},
"alerts": {},
"messages": {},
"inbox": {},
"outbox": {},
"sent": {},
"draft": {},
"trash": {},
"archive": {},
"search": {},
"explore": {},
"discover": {},
"trending": {},
"popular": {},
"featured": {},
"new": {},
"latest": {},
"top": {},
"best": {},
"hot": {},
"random": {},
"all": {},
"any": {},
"none": {},
"true": {},
"false": {},
}
var (
minSlugLength = 3
maxSlugLength = 20
)
+695
View File
@@ -0,0 +1,695 @@
package registry
import (
"sync"
"testing"
"time"
"tunnel_pls/internal/port"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
type mockSession struct {
mock.Mock
}
func (m *mockSession) Lifecycle() lifecycle.Lifecycle {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(lifecycle.Lifecycle)
}
func (m *mockSession) Interaction() interaction.Interaction {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(interaction.Interaction)
}
func (m *mockSession) Forwarder() forwarder.Forwarder {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(forwarder.Forwarder)
}
func (m *mockSession) Slug() slug.Slug {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(slug.Slug)
}
func (m *mockSession) Detail() *types.Detail {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(*types.Detail)
}
type mockLifecycle struct {
mock.Mock
}
func (ml *mockLifecycle) Channel() ssh.Channel {
args := ml.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(ssh.Channel)
}
func (ml *mockLifecycle) Connection() ssh.Conn {
args := ml.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(ssh.Conn)
}
func (ml *mockLifecycle) PortRegistry() port.Port {
args := ml.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(port.Port)
}
func (ml *mockLifecycle) SetChannel(channel ssh.Channel) { ml.Called(channel) }
func (ml *mockLifecycle) SetStatus(status types.SessionStatus) { ml.Called(status) }
func (ml *mockLifecycle) IsActive() bool { return ml.Called().Bool(0) }
func (ml *mockLifecycle) StartedAt() time.Time { return ml.Called().Get(0).(time.Time) }
func (ml *mockLifecycle) Close() error { return ml.Called().Error(0) }
func (ml *mockLifecycle) User() string { return ml.Called().String(0) }
type mockSlug struct {
mock.Mock
}
func (ms *mockSlug) Set(slug string) { ms.Called(slug) }
func (ms *mockSlug) String() string { return ms.Called().String(0) }
func createMockSession(user ...string) *mockSession {
u := "user1"
if len(user) > 0 {
u = user[0]
}
m := new(mockSession)
ml := new(mockLifecycle)
ml.On("User").Return(u).Maybe()
m.On("Lifecycle").Return(ml).Maybe()
ms := new(mockSlug)
ms.On("Set", mock.Anything).Maybe()
m.On("Slug").Return(ms).Maybe()
m.On("Interaction").Return(nil).Maybe()
m.On("Forwarder").Return(nil).Maybe()
m.On("Detail").Return(nil).Maybe()
return m
}
func TestNewRegistry(t *testing.T) {
r := NewRegistry()
require.NotNil(t, r)
}
func TestRegistry_Get(t *testing.T) {
tests := []struct {
name string
setupFunc func(r *registry)
key types.SessionKey
wantErr error
wantResult bool
}{
{
name: "session found",
setupFunc: func(r *registry) {
user := "user1"
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
session := createMockSession(user)
r.mu.Lock()
defer r.mu.Unlock()
r.byUser[user] = map[types.SessionKey]Session{
key: session,
}
r.slugIndex[key] = user
},
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
wantErr: nil,
wantResult: true,
},
{
name: "session not found in slugIndex",
setupFunc: func(r *registry) {},
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
wantErr: ErrSessionNotFound,
},
{
name: "session not found in byUser",
setupFunc: func(r *registry) {
r.mu.Lock()
defer r.mu.Unlock()
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
},
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
wantErr: ErrSessionNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
byUser: make(map[string]map[types.SessionKey]Session),
slugIndex: make(map[types.SessionKey]string),
mu: sync.RWMutex{},
}
tt.setupFunc(r)
session, err := r.Get(tt.key)
assert.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantResult, session != nil)
})
}
}
func TestRegistry_GetWithUser(t *testing.T) {
tests := []struct {
name string
setupFunc func(r *registry)
user string
key types.SessionKey
wantErr error
wantResult bool
}{
{
name: "session found",
setupFunc: func(r *registry) {
user := "user1"
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser[user] = map[types.SessionKey]Session{
key: session,
}
r.slugIndex[key] = user
},
user: "user1",
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
wantErr: nil,
wantResult: true,
},
{
name: "session not found in slugIndex",
setupFunc: func(r *registry) {},
user: "user1",
key: types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP},
wantErr: ErrSessionNotFound,
},
{
name: "session not found in byUser",
setupFunc: func(r *registry) {
r.mu.Lock()
defer r.mu.Unlock()
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "invalid_user"
},
user: "user1",
key: types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP},
wantErr: ErrSessionNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
byUser: make(map[string]map[types.SessionKey]Session),
slugIndex: make(map[types.SessionKey]string),
mu: sync.RWMutex{},
}
tt.setupFunc(r)
session, err := r.GetWithUser(tt.user, tt.key)
assert.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantResult, session != nil)
})
}
}
func TestRegistry_Update(t *testing.T) {
tests := []struct {
name string
user string
setupFunc func(r *registry) (oldKey, newKey types.SessionKey)
wantErr error
}{
{
name: "change slug success",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
session := createMockSession("user1")
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
}
r.slugIndex[oldKey] = "user1"
return oldKey, newKey
},
wantErr: nil,
},
{
name: "change slug to already used slug",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
newKey: session,
}
r.slugIndex[oldKey] = "user1"
r.slugIndex[newKey] = "user1"
return oldKey, newKey
},
wantErr: ErrSlugInUse,
},
{
name: "change slug to forbidden slug",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "ping", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
}
r.slugIndex[oldKey] = "user1"
return oldKey, newKey
},
wantErr: ErrForbiddenSlug,
},
{
name: "change slug to invalid slug",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test2-", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
}
r.slugIndex[oldKey] = "user1"
return oldKey, newKey
},
wantErr: ErrInvalidSlug,
},
{
name: "change slug but session not found",
user: "user2",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
}
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
return oldKey, newKey
},
wantErr: ErrSessionNotFound,
},
{
name: "change slug but session is not in the map",
user: "user2",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test4", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}: session,
}
r.slugIndex[types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}] = "user1"
return oldKey, newKey
},
wantErr: ErrSessionNotFound,
},
{
name: "change slug with same slug",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
newKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
}
r.slugIndex[oldKey] = "user1"
return oldKey, newKey
},
wantErr: ErrSlugUnchanged,
},
{
name: "tcp tunnel cannot change slug",
user: "user1",
setupFunc: func(r *registry) (types.SessionKey, types.SessionKey) {
oldKey := types.SessionKey{Id: "test2", Type: types.TunnelTypeTCP}
newKey := oldKey
session := createMockSession()
r.mu.Lock()
defer r.mu.Unlock()
r.byUser["user1"] = map[types.SessionKey]Session{
oldKey: session,
}
r.slugIndex[oldKey] = "user1"
return oldKey, newKey
},
wantErr: ErrSlugChangeNotAllowed,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
r := &registry{
byUser: make(map[string]map[types.SessionKey]Session),
slugIndex: make(map[types.SessionKey]string),
mu: sync.RWMutex{},
}
oldKey, newKey := tt.setupFunc(r)
err := r.Update(tt.user, oldKey, newKey)
assert.ErrorIs(t, err, tt.wantErr)
if err == nil {
r.mu.RLock()
defer r.mu.RUnlock()
_, ok := r.byUser[tt.user][newKey]
assert.True(t, ok, "newKey not found in registry")
_, ok = r.byUser[tt.user][oldKey]
assert.False(t, ok, "oldKey still exists in registry")
}
})
}
}
func TestRegistry_Register(t *testing.T) {
tests := []struct {
name string
user string
setupFunc func(r *registry) Key
wantOK bool
}{
{
name: "register new key successfully",
user: "user1",
setupFunc: func(r *registry) Key {
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
return key
},
wantOK: true,
},
{
name: "register already existing key fails",
user: "user1",
setupFunc: func(r *registry) Key {
key := types.SessionKey{Id: "test1", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
r.byUser["user1"] = map[Key]Session{key: session}
r.slugIndex[key] = "user1"
r.mu.Unlock()
return key
},
wantOK: false,
},
{
name: "register multiple keys for same user",
user: "user1",
setupFunc: func(r *registry) Key {
firstKey := types.SessionKey{Id: "first", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
r.byUser["user1"] = map[Key]Session{firstKey: session}
r.slugIndex[firstKey] = "user1"
r.mu.Unlock()
return types.SessionKey{Id: "second", Type: types.TunnelTypeHTTP}
},
wantOK: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
r := &registry{
byUser: make(map[string]map[Key]Session),
slugIndex: make(map[Key]string),
mu: sync.RWMutex{},
}
key := tt.setupFunc(r)
session := createMockSession()
ok := r.Register(key, session)
assert.Equal(t, tt.wantOK, ok)
if ok {
r.mu.RLock()
defer r.mu.RUnlock()
assert.Equal(t, session, r.byUser[tt.user][key], "session not stored in byUser")
assert.Equal(t, tt.user, r.slugIndex[key], "slugIndex not updated")
}
})
}
}
func TestRegistry_GetAllSessionFromUser(t *testing.T) {
tests := []struct {
name string
setupFunc func(r *registry) string
expectN int
}{
{
name: "user has no sessions",
setupFunc: func(r *registry) string {
return "user1"
},
expectN: 0,
},
{
name: "user has multiple sessions",
setupFunc: func(r *registry) string {
user := "user1"
key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
r.mu.Lock()
r.byUser[user] = map[Key]Session{
key1: createMockSession(),
key2: createMockSession(),
}
r.mu.Unlock()
return user
},
expectN: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
byUser: make(map[string]map[Key]Session),
slugIndex: make(map[Key]string),
mu: sync.RWMutex{},
}
user := tt.setupFunc(r)
sessions := r.GetAllSessionFromUser(user)
assert.Len(t, sessions, tt.expectN)
})
}
}
func TestRegistry_Remove(t *testing.T) {
tests := []struct {
name string
setupFunc func(r *registry) (string, types.SessionKey)
key types.SessionKey
verify func(*testing.T, *registry, string, types.SessionKey)
}{
{
name: "remove existing key",
setupFunc: func(r *registry) (string, types.SessionKey) {
user := "user1"
key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
session := createMockSession()
r.mu.Lock()
r.byUser[user] = map[Key]Session{key: session}
r.slugIndex[key] = user
r.mu.Unlock()
return user, key
},
verify: func(t *testing.T, r *registry, user string, key types.SessionKey) {
_, ok := r.byUser[user][key]
assert.False(t, ok, "expected key to be removed from byUser")
_, ok = r.slugIndex[key]
assert.False(t, ok, "expected key to be removed from slugIndex")
_, ok = r.byUser[user]
assert.False(t, ok, "expected user to be removed from byUser map")
},
},
{
name: "remove non-existing key",
setupFunc: func(r *registry) (string, types.SessionKey) {
return "", types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := &registry{
byUser: make(map[string]map[Key]Session),
slugIndex: make(map[Key]string),
mu: sync.RWMutex{},
}
user, key := tt.setupFunc(r)
if user == "" {
key = tt.key
}
r.Remove(key)
if tt.verify != nil {
tt.verify(t, r, user, key)
}
})
}
}
func TestIsValidSlug(t *testing.T) {
tests := []struct {
slug string
want bool
}{
{"abc", true},
{"abc-123", true},
{"a", false},
{"verybigdihsixsevenlabubu", false},
{"-iamsigma", false},
{"ligma-", false},
{"invalid$", false},
{"valid-slug1", true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.slug, func(t *testing.T) {
got := isValidSlug(tt.slug)
if got != tt.want {
t.Errorf("isValidSlug(%q) = %v; want %v", tt.slug, got, tt.want)
}
})
}
}
func TestIsValidSlugChar(t *testing.T) {
tests := []struct {
char byte
want bool
}{
{'a', true},
{'z', true},
{'0', true},
{'9', true},
{'-', true},
{'A', false},
{'$', false},
}
for _, tt := range tests {
tt := tt
t.Run(string(tt.char), func(t *testing.T) {
got := isValidSlugChar(tt.char)
if got != tt.want {
t.Errorf("isValidSlugChar(%q) = %v; want %v", tt.char, got, tt.want)
}
})
}
}
func TestIsForbiddenSlug(t *testing.T) {
forbiddenSlugs = map[string]struct{}{
"admin": {},
"root": {},
}
tests := []struct {
slug string
want bool
}{
{"admin", true},
{"root", true},
{"user", false},
{"guest", false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.slug, func(t *testing.T) {
got := isForbiddenSlug(tt.slug)
if got != tt.want {
t.Errorf("isForbiddenSlug(%q) = %v; want %v", tt.slug, got, tt.want)
}
})
}
}
+41
View File
@@ -0,0 +1,41 @@
package transport
import (
"errors"
"log"
"net"
"tunnel_pls/internal/config"
"tunnel_pls/internal/registry"
)
type httpServer struct {
handler *httpHandler
config config.Config
}
func NewHTTPServer(config config.Config, sessionRegistry registry.Registry) Transport {
return &httpServer{
handler: newHTTPHandler(config, sessionRegistry),
config: config,
}
}
func (ht *httpServer) Listen() (net.Listener, error) {
return net.Listen("tcp", ":"+ht.config.HTTPPort())
}
func (ht *httpServer) Serve(listener net.Listener) error {
log.Printf("HTTP server is starting on port %s", ht.config.HTTPPort())
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
}
log.Printf("Error accepting connection: %v", err)
continue
}
go ht.handler.Handler(conn, false)
}
}
+135
View File
@@ -0,0 +1,135 @@
package transport
import (
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestNewHTTPServer(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
assert.NotNil(t, srv)
httpSrv, ok := srv.(*httpServer)
assert.True(t, ok)
assert.Equal(t, msr, httpSrv.handler.sessionRegistry)
assert.NotNil(t, srv)
}
func TestHTTPServer_Listen(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
listener, err := srv.Listen()
assert.NoError(t, err)
assert.NotNil(t, listener)
err = listener.Close()
assert.NoError(t, err)
}
func TestHTTPServer_Serve(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
go func() {
time.Sleep(100 * time.Millisecond)
err = listener.Close()
assert.NoError(t, err)
}()
err = srv.Serve(listener)
assert.True(t, errors.Is(err, net.ErrClosed))
}
func TestHTTPServer_Serve_AcceptError(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
err := srv.Serve(ml)
assert.True(t, errors.Is(err, net.ErrClosed))
ml.AssertExpectations(t)
}
func TestHTTPServer_Serve_Success(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
mockConfig.On("HeaderSize").Return(4096)
mockConfig.On("TLSRedirect").Return(false)
srv := NewHTTPServer(mockConfig, msr)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
listenerport := listener.Addr().(*net.TCPAddr).Port
go func() {
_ = srv.Serve(listener)
}()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
assert.NoError(t, err)
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
time.Sleep(100 * time.Millisecond)
err = conn.Close()
assert.NoError(t, err)
err = listener.Close()
assert.NoError(t, err)
}
type mockListener struct {
mock.Mock
}
func (m *mockListener) Accept() (net.Conn, error) {
args := m.Called()
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(net.Conn), args.Error(1)
}
func (m *mockListener) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *mockListener) Addr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
+201
View File
@@ -0,0 +1,201 @@
package transport
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware"
"tunnel_pls/internal/registry"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type httpHandler struct {
config config.Config
sessionRegistry registry.Registry
}
func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler {
return &httpHandler{
config: config,
sessionRegistry: sessionRegistry,
}
}
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
fmt.Sprintf("Location: %s", location) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
return err
}
return nil
}
func (hh *httpHandler) badRequest(conn net.Conn) error {
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
return err
}
return nil
}
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
defer hh.closeConnection(conn)
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
buf := make([]byte, hh.config.HeaderSize())
n, err := conn.Read(buf)
if err != nil {
_ = hh.badRequest(conn)
return
}
if idx := bytes.Index(buf[:n], []byte("\r\n\r\n")); idx == -1 {
_ = hh.badRequest(conn)
return
}
_ = conn.SetReadDeadline(time.Time{})
reqhf, err := header.NewRequest(buf[:n])
if err != nil {
log.Printf("Error creating request header: %v", err)
_ = hh.badRequest(conn)
return
}
slug, err := hh.extractSlug(reqhf)
if err != nil {
_ = hh.badRequest(conn)
return
}
if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.config.Domain()))
return
}
if hh.handlePingRequest(slug, conn) {
return
}
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.TunnelTypeHTTP,
})
if err != nil {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
return
}
hw := stream.New(conn, conn, conn.RemoteAddr())
defer func(hw stream.HTTP) {
err = hw.Close()
if err != nil {
log.Printf("Error closing HTTP stream: %v", err)
}
}(hw)
hh.forwardRequest(hw, reqhf, sshSession)
}
func (hh *httpHandler) closeConnection(conn net.Conn) {
err := conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err)
}
}
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
host := strings.Split(reqhf.Value("Host"), ".")
if len(host) <= 1 {
return "", errors.New("invalid host")
}
return host[0], nil
}
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
return !isTLS && hh.config.TLSRedirect()
}
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
if slug != "ping" {
return false
}
_, err := conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
return true
}
return true
}
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
channel, reqs, err := sshSession.Forwarder().OpenForwardedChannel(ctx, hw.RemoteAddr())
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
defer func() {
err = channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
hh.setupMiddlewares(hw)
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
log.Printf("Failed to forward initial request: %v", err)
return
}
sshSession.Forwarder().HandleConnection(hw, channel)
}
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
fingerprintMiddleware := middleware.NewTunnelFingerprint()
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
hw.UseResponseMiddleware(fingerprintMiddleware)
hw.UseRequestMiddleware(forwardedForMiddleware)
}
func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error {
hw.SetRequestHeader(initialRequest)
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
return fmt.Errorf("error applying request middlewares: %w", err)
}
if _, err := channel.Write(initialRequest.Finalize()); err != nil {
return fmt.Errorf("error writing to channel: %w", err)
}
return nil
}
+717
View File
@@ -0,0 +1,717 @@
package transport
import (
"bytes"
"context"
"fmt"
"io"
"net"
"strings"
"sync"
"testing"
"time"
"tunnel_pls/internal/registry"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
args := m.Called(user, oldKey, newKey)
return args.Error(0)
}
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
args := m.Called(key, session)
return args.Bool(0)
}
func (m *MockSessionRegistry) Remove(key registry.Key) {
m.Called(key)
}
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
args := m.Called(user)
return args.Get(0).([]registry.Session)
}
func (m *MockSessionRegistry) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
type MockSession struct {
mock.Mock
}
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
args := m.Called()
return args.Get(0).(lifecycle.Lifecycle)
}
func (m *MockSession) Interaction() interaction.Interaction {
args := m.Called()
return args.Get(0).(interaction.Interaction)
}
func (m *MockSession) Forwarder() forwarder.Forwarder {
args := m.Called()
return args.Get(0).(forwarder.Forwarder)
}
func (m *MockSession) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
func (m *MockSession) Detail() *types.Detail {
args := m.Called()
return args.Get(0).(*types.Detail)
}
type MockSSHChannel struct {
ssh.Channel
mock.Mock
}
func (m *MockSSHChannel) Write(data []byte) (int, error) {
args := m.Called(data)
return args.Int(0), args.Error(1)
}
func (m *MockSSHChannel) Close() error {
args := m.Called()
return args.Error(0)
}
type MockForwarder struct {
mock.Mock
}
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
m.Called(dst, src)
}
func (m *MockForwarder) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockForwarder) TunnelType() types.TunnelType {
args := m.Called()
return args.Get(0).(types.TunnelType)
}
func (m *MockForwarder) ForwardedPort() uint16 {
args := m.Called()
return uint16(args.Int(0))
}
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
m.Called(tunnelType)
}
func (m *MockForwarder) SetForwardedPort(port uint16) {
m.Called(port)
}
func (m *MockForwarder) SetListener(listener net.Listener) {
m.Called(listener)
}
func (m *MockForwarder) Listener() net.Listener {
args := m.Called()
return args.Get(0).(net.Listener)
}
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
args := m.Called(ctx, origin)
if args.Get(0) == nil {
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
type MockConn struct {
mock.Mock
ReadBuffer *bytes.Buffer
}
func (m *MockConn) LocalAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
func (m *MockConn) SetDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetReadDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetWriteDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) Read(b []byte) (n int, err error) {
if m.ReadBuffer != nil {
return m.ReadBuffer.Read(b)
}
args := m.Called(b)
return args.Int(0), args.Error(1)
}
func (m *MockConn) Write(b []byte) (n int, err error) {
args := m.Called(b)
if args.Int(0) == -1 {
return len(b), args.Error(1)
}
return args.Int(0), args.Error(1)
}
func (m *MockConn) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockConn) RemoteAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
type wrappedConn struct {
net.Conn
remoteAddr net.Addr
}
func (c *wrappedConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func TestNewHTTPHandler(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
mockConfig.On("Domain").Return("domain")
mockConfig.On("TLSRedirect").Return(false)
hh := newHTTPHandler(mockConfig, msr)
assert.NotNil(t, hh)
assert.Equal(t, msr, hh.sessionRegistry)
}
func TestHandler(t *testing.T) {
tests := []struct {
name string
isTLS bool
redirectTLS bool
request []byte
expected []byte
setupMocks func(*MockSessionRegistry)
setupConn func() (net.Conn, net.Conn)
expectError bool
}{
{
name: "bad request - invalid host",
isTLS: false,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "bad request - missing host",
isTLS: false,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\n\r\n"),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "isTLS true and redirectTLS true - no redirect",
isTLS: true,
redirectTLS: true,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "redirect to TLS",
isTLS: false,
redirectTLS: true,
request: []byte("GET / HTTP/1.1\r\nHost: tunnel.example.com\r\n\r\n"),
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnel.example.com/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "handle ping request",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "session not found",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return((registry.Session)(nil), fmt.Errorf("session not found"))
},
},
{
name: "bad request - invalid http",
isTLS: false,
redirectTLS: false,
request: []byte("INVALID\r\n\r\n"),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "bad request - header too large",
isTLS: false,
redirectTLS: false,
request: []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: test.domain\r\n%s\r\n\r\n", strings.Repeat("test", 10000))),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "bad request - no request",
isTLS: false,
redirectTLS: false,
request: []byte(""),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "forwarding - open channel fails",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
},
},
{
name: "forwarding - send initial request fails",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mockSSHChannel.On("Close").Return(nil)
go func() {
for range reqCh {
}
}()
},
},
{
name: "forwarding - success",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(nil)
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
w := args.Get(0).(io.ReadWriter)
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
})
go func() {
for range reqCh {
}
}()
},
},
{
name: "redirect - write failure",
isTLS: false,
redirectTLS: true,
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(-1, fmt.Errorf("write error"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "bad request - write failure",
isTLS: false,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "read error - connection failure",
isTLS: false,
redirectTLS: false,
request: []byte(""),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Read", mock.Anything).Return(0, fmt.Errorf("connection reset by peer"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "handle ping request - write failure",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "close connection - error",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(182, nil)
mc.On("Close").Return(fmt.Errorf("close error"))
return mc, nil
},
},
{
name: "forwarding - stream close error",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(nil)
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return()
},
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
mc.On("RemoteAddr").Return(addr)
return mc, nil
},
},
{
name: "forwarding - middleware failure",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
return k.Id == "test"
})).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Close").Return(nil)
},
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Close").Return(nil).Times(2)
mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
return mc, nil
},
},
{
name: "forwarding - channel close error",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
w := args.Get(0).(io.ReadWriter)
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
})
},
},
{
name: "forwarding - open channel timeout",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockForwarder.On("OpenForwardedChannel", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
ctx := args.Get(0).(context.Context)
<-ctx.Done()
}).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), context.DeadlineExceeded)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockSessionRegistry := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
mockConfig.On("HeaderSize").Return(4096)
mockConfig.On("TLSRedirect").Return(true)
hh := &httpHandler{
sessionRegistry: mockSessionRegistry,
config: mockConfig,
}
if tt.setupMocks != nil {
tt.setupMocks(mockSessionRegistry)
}
var serverConn, clientConn net.Conn
if tt.setupConn != nil {
serverConn, clientConn = tt.setupConn()
} else {
serverConn, clientConn = net.Pipe()
}
if clientConn != nil {
defer func(clientConn net.Conn) {
err := clientConn.Close()
assert.NoError(t, err)
}(clientConn)
}
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
var wrappedServerConn net.Conn
if _, ok := serverConn.(*MockConn); ok {
wrappedServerConn = serverConn
} else {
wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr}
}
responseChan := make(chan []byte, 1)
doneChan := make(chan struct{})
if clientConn != nil {
go func() {
defer close(doneChan)
var res []byte
for {
buf := make([]byte, 4096)
n, err := clientConn.Read(buf)
if err != nil {
if err != io.EOF {
t.Logf("Error reading response: %v", err)
}
break
}
res = append(res, buf[:n]...)
if len(tt.expected) > 0 && len(res) >= len(tt.expected) {
break
}
}
responseChan <- res
}()
go func() {
_, err := clientConn.Write(tt.request)
if err != nil {
t.Logf("Error writing request: %v", err)
}
}()
} else {
close(responseChan)
close(doneChan)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
hh.Handler(wrappedServerConn, tt.isTLS)
}()
select {
case response := <-responseChan:
if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" {
resStr := string(response)
assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n"))
assert.Contains(t, resStr, "Content-Length: 5\r\n")
assert.Contains(t, resStr, "Server: Tunnel Please\r\n")
assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello"))
} else {
assert.Equal(t, string(tt.expected), string(response))
}
case <-time.After(10 * time.Second):
if clientConn != nil {
t.Fatal("Test timeout - no response received")
}
}
wg.Wait()
if clientConn != nil {
<-doneChan
}
mockSessionRegistry.AssertExpectations(t)
if mc, ok := serverConn.(*MockConn); ok {
mc.AssertExpectations(t)
}
})
}
}
+44
View File
@@ -0,0 +1,44 @@
package transport
import (
"crypto/tls"
"errors"
"log"
"net"
"tunnel_pls/internal/config"
"tunnel_pls/internal/registry"
)
type https struct {
config config.Config
tlsConfig *tls.Config
httpHandler *httpHandler
}
func NewHTTPSServer(config config.Config, sessionRegistry registry.Registry, tlsConfig *tls.Config) Transport {
return &https{
config: config,
tlsConfig: tlsConfig,
httpHandler: newHTTPHandler(config, sessionRegistry),
}
}
func (ht *https) Listen() (net.Listener, error) {
return tls.Listen("tcp", ":"+ht.config.HTTPSPort(), ht.tlsConfig)
}
func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.config.HTTPSPort())
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
}
log.Printf("Error accepting connection: %v", err)
continue
}
go ht.httpHandler.Handler(conn, true)
}
}
+120
View File
@@ -0,0 +1,120 @@
package transport
import (
"crypto/tls"
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewHTTPSServer(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
tlsConfig := &tls.Config{}
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
assert.NotNil(t, srv)
httpsSrv, ok := srv.(*https)
assert.True(t, ok)
assert.Equal(t, tlsConfig, httpsSrv.tlsConfig)
assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry)
}
func TestHTTPSServer_Listen(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
tlsConfig := &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, nil
},
}
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
listener, err := srv.Listen()
if err != nil {
t.Skip("Skipping tls.Listen test as it requires valid certificates/setup:", err)
return
}
assert.NotNil(t, listener)
err = listener.Close()
assert.NoError(t, err)
}
func TestHTTPSServer_Serve(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
go func() {
time.Sleep(100 * time.Millisecond)
err = listener.Close()
assert.NoError(t, err)
}()
err = srv.Serve(listener)
assert.True(t, errors.Is(err, net.ErrClosed))
}
func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
err := srv.Serve(ml)
assert.True(t, errors.Is(err, net.ErrClosed))
ml.AssertExpectations(t)
}
func TestHTTPSServer_Serve_Success(t *testing.T) {
msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
mockConfig.On("HeaderSize").Return(4096)
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
listenerport := listener.Addr().(*net.TCPAddr).Port
go func() {
_ = srv.Serve(listener)
}()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
assert.NoError(t, err)
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
time.Sleep(100 * time.Millisecond)
err = conn.Close()
assert.NoError(t, err)
err = listener.Close()
assert.NoError(t, err)
}
+67
View File
@@ -0,0 +1,67 @@
package transport
import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"time"
"golang.org/x/crypto/ssh"
)
type tcp struct {
port uint16
forwarder Forwarder
}
type Forwarder interface {
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
HandleConnection(dst io.ReadWriter, src ssh.Channel)
}
func NewTCPServer(port uint16, forwarder Forwarder) Transport {
return &tcp{
port: port,
forwarder: forwarder,
}
}
func (tt *tcp) Listen() (net.Listener, error) {
return net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", tt.port))
}
func (tt *tcp) Serve(listener net.Listener) error {
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
log.Printf("Error accepting connection: %v", err)
continue
}
go tt.handleTcp(conn)
}
}
func (tt *tcp) handleTcp(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil {
log.Printf("Failed to close connection: %v", err)
}
}()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
channel, reqs, err := tt.forwarder.OpenForwardedChannel(ctx, conn.RemoteAddr())
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
tt.forwarder.HandleConnection(conn, channel)
}
+146
View File
@@ -0,0 +1,146 @@
package transport
import (
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/crypto/ssh"
)
func TestNewTCPServer(t *testing.T) {
mf := new(MockForwarder)
port := uint16(9000)
srv := NewTCPServer(port, mf)
assert.NotNil(t, srv)
tcpSrv, ok := srv.(*tcp)
assert.True(t, ok)
assert.Equal(t, port, tcpSrv.port)
assert.Equal(t, mf, tcpSrv.forwarder)
}
func TestTCPServer_Listen(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
listener, err := srv.Listen()
assert.NoError(t, err)
assert.NotNil(t, listener)
err = listener.Close()
assert.NoError(t, err)
}
func TestTCPServer_Serve(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
go func() {
time.Sleep(100 * time.Millisecond)
err = listener.Close()
assert.NoError(t, err)
}()
err = srv.Serve(listener)
assert.Nil(t, err)
}
func TestTCPServer_Serve_AcceptError(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
err := srv.Serve(ml)
assert.Nil(t, err)
ml.AssertExpectations(t)
}
func TestTCPServer_Serve_Success(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
reqs := make(chan *ssh.Request)
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
go func() {
_ = srv.Serve(listener)
}()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = conn.Close()
assert.NoError(t, err)
err = listener.Close()
assert.NoError(t, err)
mf.AssertExpectations(t)
}
func TestTCPServer_handleTcp_Success(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf).(*tcp)
serverConn, clientConn := net.Pipe()
defer func(clientConn net.Conn) {
err := clientConn.Close()
assert.NoError(t, err)
}(clientConn)
reqs := make(chan *ssh.Request)
mockChannel := new(MockSSHChannel)
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
mf.On("HandleConnection", serverConn, mockChannel).Return()
srv.handleTcp(serverConn)
mf.AssertExpectations(t)
}
func TestTCPServer_handleTcp_CloseError(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf).(*tcp)
mc := new(MockConn)
mc.On("Close").Return(errors.New("close error"))
mc.On("RemoteAddr").Return(&net.TCPAddr{})
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
srv.handleTcp(mc)
mc.AssertExpectations(t)
}
func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf).(*tcp)
serverConn, clientConn := net.Pipe()
defer func(clientConn net.Conn) {
err := clientConn.Close()
assert.NoError(t, err)
}(clientConn)
mf.On("OpenForwardedChannel", mock.Anything, mock.Anything).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
srv.handleTcp(serverConn)
mf.AssertExpectations(t)
}
+435
View File
@@ -0,0 +1,435 @@
package transport
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"log"
"os"
"path/filepath"
"sync"
"time"
"tunnel_pls/internal/config"
"github.com/caddyserver/certmagic"
"github.com/libdns/cloudflare"
)
func NewTLSConfig(config config.Config) (*tls.Config, error) {
var initErr error
tlsManagerOnce.Do(func() {
tm := createTLSManager(config)
initErr = tm.initialize()
if initErr == nil {
globalTLSManager = tm
}
})
if initErr != nil {
return nil, initErr
}
return globalTLSManager.getTLSConfig(), nil
}
type tlsManager struct {
config config.Config
certPath string
keyPath string
storagePath string
userCert *tls.Certificate
userCertMu sync.RWMutex
magic *certmagic.Config
useCertMagic bool
}
var globalTLSManager *tlsManager
var tlsManagerOnce sync.Once
func createTLSManager(cfg config.Config) *tlsManager {
storagePath := cfg.TLSStoragePath()
cleanBase := filepath.Clean(storagePath)
return &tlsManager{
config: cfg,
certPath: filepath.Join(cleanBase, "cert.pem"),
keyPath: filepath.Join(cleanBase, "privkey.pem"),
storagePath: filepath.Join(cleanBase, "certmagic"),
}
}
func (tm *tlsManager) initialize() error {
if tm.userCertsExistAndValid() {
return tm.initializeWithUserCerts()
}
return tm.initializeWithCertMagic()
}
func (tm *tlsManager) initializeWithUserCerts() error {
log.Printf("Using user-provided certificates from %s and %s", tm.certPath, tm.keyPath)
if err := tm.loadUserCerts(); err != nil {
return fmt.Errorf("failed to load user certificates: %w", err)
}
tm.useCertMagic = false
tm.startCertWatcher()
return nil
}
func (tm *tlsManager) initializeWithCertMagic() error {
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic",
tm.config.Domain(), tm.config.Domain())
if err := tm.initCertMagic(); err != nil {
return fmt.Errorf("failed to initialize CertMagic: %w", err)
}
tm.useCertMagic = true
return nil
}
func (tm *tlsManager) userCertsExistAndValid() bool {
if !tm.certFilesExist() {
return false
}
return validateCertDomains(tm.certPath, tm.config.Domain())
}
func (tm *tlsManager) certFilesExist() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath)
return false
}
if _, err := os.Stat(tm.keyPath); os.IsNotExist(err) {
log.Printf("Key file not found: %s", tm.keyPath)
return false
}
return true
}
func (tm *tlsManager) loadUserCerts() error {
cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath)
if err != nil {
return err
}
tm.userCertMu.Lock()
tm.userCert = &cert
tm.userCertMu.Unlock()
log.Printf("Loaded user certificates successfully")
return nil
}
func (tm *tlsManager) startCertWatcher() {
go func() {
watcher := newCertWatcher(tm)
watcher.watch()
}()
}
func (tm *tlsManager) initCertMagic() error {
if err := tm.createStorageDirectory(); err != nil {
return err
}
if tm.config.CFAPIToken() == "" {
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
}
magic := tm.createCertMagicConfig()
tm.magic = magic
return tm.obtainCertificates(magic)
}
func (tm *tlsManager) createStorageDirectory() error {
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
return fmt.Errorf("failed to create cert storage directory: %w", err)
}
return nil
}
func (tm *tlsManager) createCertMagicConfig() *certmagic.Config {
cfProvider := &cloudflare.Provider{
APIToken: tm.config.CFAPIToken(),
}
storage := &certmagic.FileStorage{Path: tm.storagePath}
cache := certmagic.NewCache(certmagic.CacheOptions{
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
return tm.magic, nil
},
})
magic := certmagic.New(cache, certmagic.Config{
Storage: storage,
})
acmeIssuer := tm.createACMEIssuer(magic, cfProvider)
magic.Issuers = []certmagic.Issuer{acmeIssuer}
return magic
}
func (tm *tlsManager) createACMEIssuer(magic *certmagic.Config, cfProvider *cloudflare.Provider) *certmagic.ACMEIssuer {
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
Email: tm.config.ACMEEmail(),
Agreed: true,
DNS01Solver: &certmagic.DNS01Solver{
DNSManager: certmagic.DNSManager{
DNSProvider: cfProvider,
},
},
})
if tm.config.ACMEStaging() {
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
log.Printf("Using Let's Encrypt staging server")
} else {
acmeIssuer.CA = certmagic.LetsEncryptProductionCA
log.Printf("Using Let's Encrypt production server")
}
return acmeIssuer
}
func (tm *tlsManager) obtainCertificates(magic *certmagic.Config) error {
domains := []string{tm.config.Domain(), "*." + tm.config.Domain()}
log.Printf("Requesting certificates for: %v", domains)
ctx := context.Background()
if err := magic.ManageSync(ctx, domains); err != nil {
return fmt.Errorf("failed to obtain certificates: %w", err)
}
log.Printf("Certificates obtained successfully for %v", domains)
return nil
}
func (tm *tlsManager) getTLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: tm.getCertificate,
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
CurvePreferences: []tls.CurveID{
tls.X25519,
},
SessionTicketsDisabled: false,
ClientAuth: tls.NoClientCert,
}
}
func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if tm.useCertMagic {
return tm.magic.GetCertificate(hello)
}
tm.userCertMu.RLock()
defer tm.userCertMu.RUnlock()
if tm.userCert == nil {
return nil, fmt.Errorf("no certificate available")
}
return tm.userCert, nil
}
func validateCertDomains(certPath, domain string) bool {
cert, err := loadAndParseCertificate(certPath)
if err != nil {
return false
}
if !isCertificateValid(cert) {
return false
}
return certCoversRequiredDomains(cert, domain)
}
func loadAndParseCertificate(certPath string) (*x509.Certificate, error) {
certPEM, err := os.ReadFile(certPath)
if err != nil {
log.Printf("Failed to read certificate: %v", err)
return nil, err
}
block, _ := pem.Decode(certPEM)
if block == nil {
log.Printf("Failed to decode PEM block from certificate")
return nil, fmt.Errorf("failed to decode PEM block")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Printf("Failed to parse certificate: %v", err)
return nil, err
}
return cert, nil
}
func isCertificateValid(cert *x509.Certificate) bool {
now := time.Now()
if now.After(cert.NotAfter) {
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
return false
}
thirtyDaysFromNow := now.Add(30 * 24 * time.Hour)
if thirtyDaysFromNow.After(cert.NotAfter) {
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
return false
}
return true
}
func certCoversRequiredDomains(cert *x509.Certificate, domain string) bool {
certDomains := extractCertDomains(cert)
hasBase, hasWildcard := checkDomainCoverage(certDomains, domain)
logDomainCoverage(hasBase, hasWildcard, domain)
return hasBase && hasWildcard
}
func extractCertDomains(cert *x509.Certificate) []string {
var domains []string
if cert.Subject.CommonName != "" {
domains = append(domains, cert.Subject.CommonName)
}
domains = append(domains, cert.DNSNames...)
return domains
}
func checkDomainCoverage(certDomains []string, domain string) (hasBase, hasWildcard bool) {
wildcardDomain := "*." + domain
for _, d := range certDomains {
if d == domain {
hasBase = true
}
if d == wildcardDomain {
hasWildcard = true
}
}
return hasBase, hasWildcard
}
func logDomainCoverage(hasBase, hasWildcard bool, domain string) {
if !hasBase {
log.Printf("Certificate does not cover base domain: %s", domain)
}
if !hasWildcard {
log.Printf("Certificate does not cover wildcard domain: *.%s", domain)
}
}
type certWatcher struct {
tm *tlsManager
lastCertMod time.Time
lastKeyMod time.Time
}
func newCertWatcher(tm *tlsManager) *certWatcher {
watcher := &certWatcher{tm: tm}
watcher.initializeModTimes()
return watcher
}
func (cw *certWatcher) initializeModTimes() {
if info, err := os.Stat(cw.tm.certPath); err == nil {
cw.lastCertMod = info.ModTime()
}
if info, err := os.Stat(cw.tm.keyPath); err == nil {
cw.lastKeyMod = info.ModTime()
}
}
func (cw *certWatcher) watch() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
if cw.checkAndReloadCerts() {
return
}
}
}
func (cw *certWatcher) checkAndReloadCerts() bool {
certInfo, keyInfo, err := cw.getFileInfo()
if err != nil {
return false
}
if !cw.filesModified(certInfo, keyInfo) {
return false
}
return cw.handleCertificateChange(certInfo, keyInfo)
}
func (cw *certWatcher) getFileInfo() (os.FileInfo, os.FileInfo, error) {
certInfo, certErr := os.Stat(cw.tm.certPath)
keyInfo, keyErr := os.Stat(cw.tm.keyPath)
if certErr != nil || keyErr != nil {
return nil, nil, fmt.Errorf("file stat error")
}
return certInfo, keyInfo, nil
}
func (cw *certWatcher) filesModified(certInfo, keyInfo os.FileInfo) bool {
return certInfo.ModTime().After(cw.lastCertMod) || keyInfo.ModTime().After(cw.lastKeyMod)
}
func (cw *certWatcher) handleCertificateChange(certInfo, keyInfo os.FileInfo) bool {
log.Printf("Certificate files changed, reloading...")
if !validateCertDomains(cw.tm.certPath, cw.tm.config.Domain()) {
return cw.switchToCertMagic()
}
if err := cw.tm.loadUserCerts(); err != nil {
log.Printf("Failed to reload certificates: %v", err)
return false
}
cw.updateModTimes(certInfo, keyInfo)
log.Printf("Certificates reloaded successfully")
return false
}
func (cw *certWatcher) switchToCertMagic() bool {
log.Printf("New certificates don't cover required domains")
if err := cw.tm.initCertMagic(); err != nil {
log.Printf("Failed to initialize CertMagic: %v", err)
return false
}
cw.tm.useCertMagic = true
return true
}
func (cw *certWatcher) updateModTimes(certInfo, keyInfo os.FileInfo) {
cw.lastCertMod = certInfo.ModTime()
cw.lastKeyMod = keyInfo.ModTime()
}
File diff suppressed because it is too large Load Diff
+14
View File
@@ -0,0 +1,14 @@
package transport
import (
"net"
)
type Transport interface {
Listen() (net.Listener, error)
Serve(listener net.Listener) error
}
type HTTP interface {
Handler(conn net.Conn, isTLS bool)
}
+84
View File
@@ -0,0 +1,84 @@
package version
import (
"fmt"
"testing"
)
func TestVersionFunctions(t *testing.T) {
origVersion := Version
origBuildDate := BuildDate
origCommit := Commit
defer func() {
Version = origVersion
BuildDate = origBuildDate
Commit = origCommit
}()
tests := []struct {
name string
version string
buildDate string
commit string
wantFull string
wantShort string
}{
{
name: "Default dev version",
version: "dev",
buildDate: "unknown",
commit: "unknown",
wantFull: "tunnel_pls dev (commit: unknown, built: unknown)",
wantShort: "dev",
},
{
name: "Release version",
version: "v1.0.0",
buildDate: "2026-01-23",
commit: "abcdef123",
wantFull: "tunnel_pls v1.0.0 (commit: abcdef123, built: 2026-01-23)",
wantShort: "v1.0.0",
},
{
name: "Empty values",
version: "",
buildDate: "",
commit: "",
wantFull: "tunnel_pls (commit: , built: )",
wantShort: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Version = tt.version
BuildDate = tt.buildDate
Commit = tt.commit
gotFull := GetVersion()
if gotFull != tt.wantFull {
t.Errorf("GetVersion() = %q, want %q", gotFull, tt.wantFull)
}
gotShort := GetShortVersion()
if gotShort != tt.wantShort {
t.Errorf("GetShortVersion() = %q, want %q", gotShort, tt.wantShort)
}
})
}
}
func TestGetVersion_Format(t *testing.T) {
v := "1.2.3"
c := "brainrot"
d := "now"
Version = v
Commit = c
BuildDate = d
expected := fmt.Sprintf("tunnel_pls %s (commit: %s, built: %s)", v, c, d)
if GetVersion() != expected {
t.Errorf("GetVersion() formatting mismatch")
}
}
+9 -42
View File
@@ -3,16 +3,11 @@ package main
import (
"fmt"
"log"
"net/http"
_ "net/http/pprof"
"os"
"tunnel_pls/internal/bootstrap"
"tunnel_pls/internal/config"
"tunnel_pls/internal/key"
"tunnel_pls/server"
"tunnel_pls/session"
"tunnel_pls/version"
"golang.org/x/crypto/ssh"
"tunnel_pls/internal/port"
"tunnel_pls/internal/version"
)
func main() {
@@ -23,47 +18,19 @@ func main() {
log.SetOutput(os.Stdout)
log.SetFlags(log.LstdFlags | log.Lshortfile)
log.Printf("Starting %s", version.GetVersion())
pprofEnabled := config.Getenv("PPROF_ENABLED", "false")
if pprofEnabled == "true" {
pprofPort := config.Getenv("PPROF_PORT", "6060")
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
}
sshConfig := &ssh.ServerConfig{
NoClientAuth: true,
ServerVersion: fmt.Sprintf("SSH-2.0-TunnlPls-%s", version.GetShortVersion()),
}
sshKeyPath := "certs/ssh/id_rsa"
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
log.Fatalf("Failed to generate SSH key: %s", err)
}
privateBytes, err := os.ReadFile(sshKeyPath)
conf, err := config.MustLoad()
if err != nil {
log.Fatalf("Failed to load private key: %s", err)
log.Fatalf("Config load error: %v", err)
}
private, err := ssh.ParsePrivateKey(privateBytes)
boot, err := bootstrap.New(conf, port.New())
if err != nil {
log.Fatalf("Failed to parse private key: %s", err)
log.Fatalf("Startup error: %v", err)
}
sshConfig.AddHostKey(private)
sessionRegistry := session.NewRegistry()
app, err := server.NewServer(sshConfig, sessionRegistry)
if err != nil {
log.Fatalf("Failed to start server: %s", err)
if err = boot.Run(); err != nil {
log.Fatalf("Application error: %v", err)
}
app.Start()
}
-8
View File
@@ -1,8 +0,0 @@
module.exports = {
"endpoint": "https://git.fossy.my.id/api/v1",
"gitAuthor": "Renovate-Clanker <renovate-bot@fossy.my.id>",
"platform": "gitea",
"onboardingConfigFileName": "renovate.json",
"autodiscover": true,
"optimizeForDisabled": true,
};
-276
View File
@@ -1,276 +0,0 @@
package server
import (
"bufio"
"bytes"
"fmt"
)
type HeaderManager interface {
Get(key string) []byte
Set(key string, value []byte)
Remove(key string)
Finalize() []byte
}
type ResponseHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type RequestHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
}
type responseHeaderFactory struct {
startLine []byte
headers map[string]string
}
type requestHeaderFactory struct {
method string
path string
version string
startLine []byte
headers map[string]string
}
func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) {
switch v := r.(type) {
case []byte:
return parseHeadersFromBytes(v)
case *bufio.Reader:
return parseHeadersFromReader(v)
default:
return nil, fmt.Errorf("unsupported type: %T", r)
}
}
func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) {
header := &requestHeaderFactory{
headers: make(map[string]string, 16),
}
lineEnd := bytes.IndexByte(headerData, '\n')
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no newline found")
}
startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n")
header.startLine = make([]byte, len(startLine))
copy(header.startLine, startLine)
parts := bytes.Split(startLine, []byte{' '})
if len(parts) < 3 {
return nil, fmt.Errorf("invalid request line")
}
header.method = string(parts[0])
header.path = string(parts[1])
header.version = string(parts[2])
remaining := headerData[lineEnd+1:]
for len(remaining) > 0 {
lineEnd = bytes.IndexByte(remaining, '\n')
if lineEnd == -1 {
lineEnd = len(remaining)
}
line := bytes.TrimRight(remaining[:lineEnd], "\r\n")
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx != -1 {
key := bytes.TrimSpace(line[:colonIdx])
value := bytes.TrimSpace(line[colonIdx+1:])
header.headers[string(key)] = string(value)
}
if lineEnd == len(remaining) {
break
}
remaining = remaining[lineEnd+1:]
}
return header, nil
}
func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) {
header := &requestHeaderFactory{
headers: make(map[string]string, 16),
}
startLineBytes, err := br.ReadSlice('\n')
if err != nil {
if err == bufio.ErrBufferFull {
var startLine string
startLine, err = br.ReadString('\n')
if err != nil {
return nil, err
}
startLineBytes = []byte(startLine)
} else {
return nil, err
}
}
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
header.startLine = make([]byte, len(startLineBytes))
copy(header.startLine, startLineBytes)
parts := bytes.Split(startLineBytes, []byte{' '})
if len(parts) < 3 {
return nil, fmt.Errorf("invalid request line")
}
header.method = string(parts[0])
header.path = string(parts[1])
header.version = string(parts[2])
for {
lineBytes, err := br.ReadSlice('\n')
if err != nil {
if err == bufio.ErrBufferFull {
var line string
line, err = br.ReadString('\n')
if err != nil {
return nil, err
}
lineBytes = []byte(line)
} else {
return nil, err
}
}
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
if len(lineBytes) == 0 {
break
}
colonIdx := bytes.IndexByte(lineBytes, ':')
if colonIdx == -1 {
continue
}
key := bytes.TrimSpace(lineBytes[:colonIdx])
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
header.headers[string(key)] = string(value)
}
return header, nil
}
func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager {
header := &responseHeaderFactory{
startLine: nil,
headers: make(map[string]string),
}
lines := bytes.Split(startLine, []byte("\r\n"))
if len(lines) == 0 {
return header
}
header.startLine = lines[0]
for _, h := range lines[1:] {
if len(h) == 0 {
continue
}
parts := bytes.SplitN(h, []byte(":"), 2)
if len(parts) < 2 {
continue
}
key := parts[0]
val := bytes.TrimSpace(parts[1])
header.headers[string(key)] = string(val)
}
return header
}
func (resp *responseHeaderFactory) Get(key string) string {
return resp.headers[key]
}
func (resp *responseHeaderFactory) Set(key string, value string) {
resp.headers[key] = value
}
func (resp *responseHeaderFactory) Remove(key string) {
delete(resp.headers, key)
}
func (resp *responseHeaderFactory) Finalize() []byte {
var buf bytes.Buffer
buf.Write(resp.startLine)
buf.WriteString("\r\n")
for key, val := range resp.headers {
buf.WriteString(key)
buf.WriteString(": ")
buf.WriteString(val)
buf.WriteString("\r\n")
}
buf.WriteString("\r\n")
return buf.Bytes()
}
func (req *requestHeaderFactory) Get(key string) string {
val, ok := req.headers[key]
if !ok {
return ""
}
return val
}
func (req *requestHeaderFactory) Set(key string, value string) {
req.headers[key] = value
}
func (req *requestHeaderFactory) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeaderFactory) GetMethod() string {
return req.method
}
func (req *requestHeaderFactory) GetPath() string {
return req.path
}
func (req *requestHeaderFactory) GetVersion() string {
return req.version
}
func (req *requestHeaderFactory) Finalize() []byte {
var buf bytes.Buffer
buf.Write(req.startLine)
buf.WriteString("\r\n")
for key, val := range req.headers {
buf.WriteString(key)
buf.WriteString(": ")
buf.WriteString(val)
buf.WriteString("\r\n")
}
buf.WriteString("\r\n")
return buf.Bytes()
}
-391
View File
@@ -1,391 +0,0 @@
package server
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"net"
"regexp"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/session"
"golang.org/x/crypto/ssh"
)
type HTTPWriter interface {
io.Reader
io.Writer
GetRemoteAddr() net.Addr
GetWriter() io.Writer
AddResponseMiddleware(mw ResponseMiddleware)
AddRequestStartMiddleware(mw RequestMiddleware)
SetRequestHeader(header RequestHeaderManager)
GetRequestStartMiddleware() []RequestMiddleware
}
type customWriter struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
headerBuf []byte
buf []byte
respHeader ResponseHeaderManager
reqHeader RequestHeaderManager
respMW []ResponseMiddleware
reqStartMW []RequestMiddleware
reqEndMW []RequestMiddleware
}
func (cw *customWriter) GetRemoteAddr() net.Addr {
return cw.remoteAddr
}
func (cw *customWriter) GetWriter() io.Writer {
return cw.writer
}
func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) {
cw.respMW = append(cw.respMW, mw)
}
func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) {
cw.reqStartMW = append(cw.reqStartMW, mw)
}
func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) {
cw.reqHeader = header
}
func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware {
return cw.reqStartMW
}
func (cw *customWriter) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := cw.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
idx := bytes.Index(tmp, DELIMITER)
if idx == -1 {
copy(p, tmp)
if err != nil {
return read, err
}
return read, nil
}
header := tmp[:idx+len(DELIMITER)]
body := tmp[idx+len(DELIMITER):]
if !isHTTPHeader(header) {
copy(p, tmp)
return read, nil
}
for _, m := range cw.reqEndMW {
err = m.HandleRequest(cw.reqHeader)
if err != nil {
log.Printf("Error when applying request middleware: %v", err)
return 0, err
}
}
reqhf, err := NewRequestHeaderFactory(header)
if err != nil {
return 0, err
}
for _, m := range cw.reqStartMW {
if mwErr := m.HandleRequest(reqhf); mwErr != nil {
log.Printf("Error when applying request middleware: %v", mwErr)
return 0, mwErr
}
}
cw.reqHeader = reqhf
finalHeader := reqhf.Finalize()
combined := append(finalHeader, body...)
n := copy(p, combined)
return n, nil
}
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
return &customWriter{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}
func (cw *customWriter) Write(p []byte) (int, error) {
if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" {
cw.respHeader = nil
}
if cw.respHeader != nil {
n, err := cw.writer.Write(p)
if err != nil {
return n, err
}
return n, nil
}
cw.buf = append(cw.buf, p...)
idx := bytes.Index(cw.buf, DELIMITER)
if idx == -1 {
return len(p), nil
}
header := cw.buf[:idx+len(DELIMITER)]
body := cw.buf[idx+len(DELIMITER):]
if !isHTTPHeader(header) {
_, err := cw.writer.Write(cw.buf)
cw.buf = nil
if err != nil {
return 0, err
}
return len(p), nil
}
resphf := NewResponseHeaderFactory(header)
for _, m := range cw.respMW {
err := m.HandleResponse(resphf, body)
if err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return 0, err
}
}
header = resphf.Finalize()
cw.respHeader = resphf
_, err := cw.writer.Write(header)
if err != nil {
return 0, err
}
if len(body) > 0 {
_, err = cw.writer.Write(body)
if err != nil {
return 0, err
}
}
cw.buf = nil
return len(p), nil
}
var redirectTLS = false
type HTTPServer interface {
ListenAndServe() error
ListenAndServeTLS() error
handler(conn net.Conn)
handlerTLS(conn net.Conn)
}
type httpServer struct {
sessionRegistry session.Registry
}
func NewHTTPServer(sessionRegistry session.Registry) HTTPServer {
return &httpServer{sessionRegistry: sessionRegistry}
}
func (hs *httpServer) ListenAndServe() error {
httpPort := config.Getenv("HTTP_PORT", "8080")
listener, err := net.Listen("tcp", ":"+httpPort)
if err != nil {
return errors.New("Error listening: " + err.Error())
}
if config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" {
redirectTLS = true
}
go func() {
for {
var conn net.Conn
conn, err = listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
go hs.handler(conn)
}
}()
return nil
}
func (hs *httpServer) handler(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err)
return
}
return
}()
dstReader := bufio.NewReader(conn)
reqhf, err := NewRequestHeaderFactory(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
host := strings.Split(reqhf.Get("Host"), ".")
if len(host) < 1 {
_, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
if err != nil {
log.Println("Failed to write 400 Bad Request:", err)
return
}
return
}
slug := host[0]
if redirectTLS {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
if slug == "ping" {
_, err = conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
return
}
return
}
sshSession, exist := hs.sessionRegistry.Get(slug)
if !exist {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
forwardRequest(cw, reqhf, sshSession)
return
}
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) {
payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
var channel ssh.Channel
var reqs <-chan *ssh.Request
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
return
}
channel = result.channel
reqs = result.reqs
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
return
}
go ssh.DiscardRequests(reqs)
fingerprintMiddleware := NewTunnelFingerprint()
forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr())
cw.AddResponseMiddleware(fingerprintMiddleware)
cw.AddRequestStartMiddleware(forwardedForMiddleware)
cw.SetRequestHeader(initialRequest)
for _, m := range cw.GetRequestStartMiddleware() {
if err := m.HandleRequest(initialRequest); err != nil {
log.Printf("Error handling request: %v", err)
return
}
}
_, err := channel.Write(initialRequest.Finalize())
if err != nil {
log.Printf("Failed to forward request: %v", err)
return
}
sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
return
}
-108
View File
@@ -1,108 +0,0 @@
package server
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"strings"
"tunnel_pls/internal/config"
)
func (hs *httpServer) ListenAndServeTLS() error {
domain := config.Getenv("DOMAIN", "localhost")
httpsPort := config.Getenv("HTTPS_PORT", "8443")
tlsConfig, err := NewTLSConfig(domain)
if err != nil {
return fmt.Errorf("failed to initialize TLS config: %w", err)
}
ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig)
if err != nil {
return err
}
go func() {
for {
var conn net.Conn
conn, err = ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Println("https server closed")
}
log.Printf("Error accepting connection: %v", err)
continue
}
go hs.handlerTLS(conn)
}
}()
return nil
}
func (hs *httpServer) handlerTLS(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil {
log.Printf("Error closing connection: %v", err)
return
}
return
}()
dstReader := bufio.NewReader(conn)
reqhf, err := NewRequestHeaderFactory(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
host := strings.Split(reqhf.Get("Host"), ".")
if len(host) < 1 {
_, err = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
if err != nil {
log.Println("Failed to write 400 Bad Request:", err)
return
}
return
}
slug := host[0]
if slug == "ping" {
_, err = conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
return
}
return
}
sshSession, exist := hs.sessionRegistry.Get(slug)
if !exist {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
forwardRequest(cw, reqhf, sshSession)
return
}
-41
View File
@@ -1,41 +0,0 @@
package server
import (
"net"
)
type RequestMiddleware interface {
HandleRequest(header RequestHeaderManager) error
}
type ResponseMiddleware interface {
HandleResponse(header ResponseHeaderManager, body []byte) error
}
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
+70 -35
View File
@@ -1,55 +1,65 @@
package server
import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/port"
"tunnel_pls/internal/random"
"tunnel_pls/internal/registry"
"tunnel_pls/session"
"golang.org/x/crypto/ssh"
)
type Server struct {
conn *net.Listener
config *ssh.ServerConfig
sessionRegistry session.Registry
type Server interface {
Start()
Close() error
}
type server struct {
randomizer random.Random
config config.Config
sshPort string
sshListener net.Listener
sshConfig *ssh.ServerConfig
grpcClient client.Client
sessionRegistry registry.Registry
portRegistry port.Port
}
func NewServer(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry) (*Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200")))
func New(randomizer random.Random, config config.Config, sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
if err != nil {
log.Fatalf("failed to listen on port 2200: %v", err)
return nil, err
}
HttpServer := NewHTTPServer(sessionRegistry)
err = HttpServer.ListenAndServe()
if err != nil {
log.Fatalf("failed to start http server: %v", err)
return nil, err
}
if config.Getenv("TLS_ENABLED", "false") == "true" {
err = HttpServer.ListenAndServeTLS()
if err != nil {
log.Fatalf("failed to start https server: %v", err)
return nil, err
}
}
return &Server{
conn: &listener,
config: sshConfig,
return &server{
randomizer: randomizer,
config: config,
sshPort: sshPort,
sshListener: listener,
sshConfig: sshConfig,
grpcClient: grpcClient,
sessionRegistry: sessionRegistry,
portRegistry: portRegistry,
}, nil
}
func (s *Server) Start() {
log.Println("SSH server is starting on port 2200...")
func (s *server) Start() {
log.Printf("SSH server is starting on port %s", s.sshPort)
for {
conn, err := (*s.conn).Accept()
conn, err := s.sshListener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Println("listener closed, stopping server")
return
}
log.Printf("failed to accept connection: %v", err)
continue
}
@@ -58,11 +68,15 @@ func (s *Server) Start() {
}
}
func (s *Server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
func (s *server) Close() error {
return s.sshListener.Close()
}
func (s *server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.sshConfig)
if err != nil {
log.Printf("failed to establish SSH connection: %v", err)
err := conn.Close()
err = conn.Close()
if err != nil {
log.Printf("failed to close SSH connection: %v", err)
return
@@ -70,13 +84,34 @@ func (s *Server) handleConnection(conn net.Conn) {
return
}
log.Println("SSH connection established:", sshConn.User())
defer func(sshConn *ssh.ServerConn) {
err = sshConn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
log.Printf("failed to close SSH server: %v", err)
}
}(sshConn)
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry)
user := "UNAUTHORIZED"
if s.grpcClient != nil {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
_, u, _ := s.grpcClient.AuthorizeConn(ctx, sshConn.User())
user = u
cancel()
}
log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(&session.Config{
Randomizer: s.randomizer,
Config: s.config,
Conn: sshConn,
InitialReq: forwardingReqs,
SshChan: chans,
SessionRegistry: s.sessionRegistry,
PortRegistry: s.portRegistry,
User: user,
})
err = sshSession.Start()
if err != nil {
log.Printf("SSH session ended with error: %v", err)
log.Printf("SSH session ended with error: %s", err.Error())
return
}
return
}
+880
View File
@@ -0,0 +1,880 @@
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"net"
"testing"
"time"
"tunnel_pls/internal/registry"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
)
type MockRandom struct {
mock.Mock
}
func (m *MockRandom) String(length int) (string, error) {
args := m.Called(length)
return args.String(0), args.Error(1)
}
type MockConfig struct {
mock.Mock
}
func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode {
args := m.Called()
if args.Get(0) == nil {
return 0
}
switch v := args.Get(0).(type) {
case types.ServerMode:
return v
case int:
return types.ServerMode(v)
default:
return types.ServerMode(args.Int(0))
}
}
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
args := m.Called(user, oldKey, newKey)
return args.Error(0)
}
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
args := m.Called(key, session)
return args.Bool(0)
}
func (m *MockSessionRegistry) Remove(key registry.Key) {
m.Called(key)
}
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
args := m.Called(user)
return args.Get(0).([]registry.Session)
}
func (m *MockSessionRegistry) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
type MockGRPCClient struct {
mock.Mock
}
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
args := m.Called()
return args.Get(0).(*grpc.ClientConn)
}
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
args := m.Called(ctx, token)
return args.Bool(0), args.String(1), args.Error(2)
}
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
args := m.Called(ctx)
return args.Error(0)
}
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
args := m.Called(ctx, domain, token)
return args.Error(0)
}
func (m *MockGRPCClient) Close() error {
args := m.Called()
return args.Error(0)
}
type MockPort struct {
mock.Mock
}
func (m *MockPort) AddRange(startPort, endPort uint16) error {
return m.Called(startPort, endPort).Error(0)
}
func (m *MockPort) Unassigned() (uint16, bool) {
args := m.Called()
return uint16(args.Int(0)), args.Bool(1)
}
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
}
func (m *MockPort) Claim(port uint16) bool {
return m.Called(port).Bool(0)
}
type MockListener struct {
mock.Mock
}
func (m *MockListener) Accept() (net.Conn, error) {
args := m.Called()
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(net.Conn), args.Error(1)
}
func (m *MockListener) Close() error {
return m.Called().Error(0)
}
func (m *MockListener) Addr() net.Addr {
return m.Called().Get(0).(net.Addr)
}
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
key, _ := rsa.GenerateKey(rand.Reader, 2048)
signer, _ := ssh.NewSignerFromKey(key)
config := &ssh.ServerConfig{
NoClientAuth: true,
}
config.AddHostKey(signer)
return config, signer
}
func TestNew(t *testing.T) {
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
tests := []struct {
name string
port string
wantErr bool
}{
{
name: "success",
port: "0",
wantErr: false,
},
{
name: "invalid port",
port: "invalid",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, err := New(mr, mc, sc, mreg, mg, mp, tt.port)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, s)
} else {
assert.NoError(t, err)
assert.NotNil(t, s)
_ = s.Close()
}
})
}
t.Run("port already in use", func(t *testing.T) {
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
port := l.Addr().(*net.TCPAddr).Port
defer func(l net.Listener) {
err = l.Close()
assert.NoError(t, err)
}(l)
s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
assert.Error(t, err)
assert.Nil(t, s)
})
}
func TestClose(t *testing.T) {
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
t.Run("successful close", func(t *testing.T) {
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
err := s.Close()
assert.NoError(t, err)
})
t.Run("close already closed listener", func(t *testing.T) {
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
_ = s.Close()
err := s.Close()
assert.Error(t, err)
})
t.Run("close with nil listener", func(t *testing.T) {
s := &server{
sshListener: nil,
}
defer func() {
if r := recover(); r != nil {
assert.NotNil(t, r)
}
}()
_ = s.Close()
t.Fatal("expected panic for nil listener")
})
}
func TestStart(t *testing.T) {
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
t.Run("normal stop", func(t *testing.T) {
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
go func() {
time.Sleep(100 * time.Millisecond)
_ = s.Close()
}()
s.Start()
})
t.Run("accept error - temporary error continues loop", func(t *testing.T) {
ml := new(MockListener)
s := &server{
sshListener: ml,
sshPort: "0",
}
ml.On("Accept").Return(nil, errors.New("temporary error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
s.Start()
ml.AssertExpectations(t)
})
t.Run("accept error - immediate close", func(t *testing.T) {
ml := new(MockListener)
s := &server{
sshListener: ml,
sshPort: "0",
}
ml.On("Accept").Return(nil, net.ErrClosed).Once()
s.Start()
ml.AssertExpectations(t)
})
t.Run("accept success - connection fails SSH handshake", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
mockListener := &MockListener{}
mockListener.On("Accept").Return(serverConn, nil).Once()
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshListener: mockListener,
sshConfig: sshConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
go s.Start()
time.Sleep(50 * time.Millisecond)
err := clientConn.Close()
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
})
t.Run("accept success - valid SSH connection without auth", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
mockListener := &MockListener{}
mockListener.On("Accept").Return(serverConn, nil).Once()
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshListener: mockListener,
sshConfig: sshConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
go s.Start()
time.Sleep(50 * time.Millisecond)
err := clientConn.Close()
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
})
}
func TestHandleConnection(t *testing.T) {
t.Run("SSH handshake fails - connection closed", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: sshConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
err := clientConn.Close()
assert.NoError(t, err)
s.handleConnection(serverConn)
})
// SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS
// GONNA IMPLEMENT THIS UNIT TEST LATER
//t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) {
// mockRandom := &MockRandom{}
// mockConfig := &MockConfig{}
// mockSessionRegistry := &MockSessionRegistry{}
// mockGrpcClient := &MockGRPCClient{}
// mockPort := &MockPort{}
//
// sshConfig, _ := getTestSSHConfig()
//
// serverConn, clientConn := net.Pipe()
//
// s := &server{
// randomizer: mockRandom,
// config: mockConfig,
// sshPort: "0",
// sshConfig: sshConfig,
// grpcClient: mockGrpcClient,
// sessionRegistry: mockSessionRegistry,
// portRegistry: mockPort,
// }
//
// done := make(chan bool, 1)
//
// go func() {
// s.handleConnection(serverConn)
// done <- true
// }()
//
// go func() {
// clientConn.Write([]byte("invalid ssh protocol\n"))
// clientConn.Close()
// }()
//
// select {
// case <-done:
// case <-time.After(1 * time.Second):
// t.Fatal("handleConnection did not complete in time")
// }
//})
t.Run("SSH connection established without gRPC client", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
mockConfig.On("Domain").Return("test.com")
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("SSHPort").Return("2200")
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer func(listener net.Listener) {
err = listener.Close()
assert.NoError(t, err)
}(listener)
serverAddr := listener.Addr().String()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
s.handleConnection(conn)
done <- true
}()
time.Sleep(50 * time.Millisecond)
clientConfig := &ssh.ClientConfig{
User: "testuser",
Auth: []ssh.AuthMethod{ssh.Password("password")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
go func() {
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
if err != nil {
t.Logf("Client dial failed: %v", err)
return
}
defer func(client *ssh.Client) {
err = client.Close()
assert.NoError(t, err)
}(client)
type forwardPayload struct {
BindAddr string
BindPort uint32
}
payload := ssh.Marshal(forwardPayload{
BindAddr: "localhost",
BindPort: 80,
})
_, _, err = client.SendRequest("tcpip-forward", true, payload)
if err != nil {
t.Logf("Forward request failed: %v", err)
}
time.Sleep(500 * time.Millisecond)
}()
select {
case <-done:
t.Log("handleConnection completed")
case <-time.After(5 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
t.Run("SSH connection established with gRPC authorization", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
mockConfig.On("Domain").Return("test.com")
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("SSHPort").Return("2200")
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer func(listener net.Listener) {
err = listener.Close()
assert.NoError(t, err)
}(listener)
serverAddr := listener.Addr().String()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
s.handleConnection(conn)
done <- true
}()
time.Sleep(50 * time.Millisecond)
clientConfig := &ssh.ClientConfig{
User: "testuser",
Auth: []ssh.AuthMethod{ssh.Password("password")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
go func() {
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
if err != nil {
t.Logf("Client dial failed: %v", err)
return
}
defer func(client *ssh.Client) {
err = client.Close()
assert.NoError(t, err)
}(client)
type forwardPayload struct {
BindAddr string
BindPort uint32
}
payload := ssh.Marshal(forwardPayload{
BindAddr: "localhost",
BindPort: 80,
})
_, _, err = client.SendRequest("tcpip-forward", true, payload)
if err != nil {
t.Logf("Forward request failed: %v", err)
}
time.Sleep(500 * time.Millisecond)
}()
select {
case <-done:
mockGrpcClient.AssertExpectations(t)
case <-time.After(5 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
t.Run("SSH connection with gRPC authorization error", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockGrpcClient := &MockGRPCClient{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
mockConfig.On("Domain").Return("test.com")
mockConfig.On("Mode").Return(types.ServerModeNODE)
mockConfig.On("SSHPort").Return("2200")
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer func(listener net.Listener) {
err = listener.Close()
assert.NoError(t, err)
}(listener)
serverAddr := listener.Addr().String()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: mockGrpcClient,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
return
}
s.handleConnection(conn)
done <- true
}()
time.Sleep(50 * time.Millisecond)
clientConfig := &ssh.ClientConfig{
User: "testuser",
Auth: []ssh.AuthMethod{ssh.Password("password")},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
go func() {
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
if err != nil {
t.Logf("Client dial failed: %v", err)
return
}
defer func(client *ssh.Client) {
_ = client.Close()
}(client)
type forwardPayload struct {
BindAddr string
BindPort uint32
}
payload := ssh.Marshal(forwardPayload{
BindAddr: "localhost",
BindPort: 8080,
})
_, _, err = client.SendRequest("tcpip-forward", true, payload)
if err != nil {
t.Logf("Forward request failed: %v", err)
}
time.Sleep(500 * time.Millisecond)
}()
select {
case <-done:
mockGrpcClient.AssertExpectations(t)
case <-time.After(5 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
t.Run("connection cleanup on close", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
serverConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: serverConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
done := make(chan bool, 1)
go func() {
s.handleConnection(serverConn)
done <- true
}()
err := clientConn.Close()
assert.NoError(t, err)
select {
case <-done:
case <-time.After(1 * time.Second):
t.Fatal("handleConnection did not complete in time")
}
})
}
func TestIntegration(t *testing.T) {
t.Run("full server lifecycle", func(t *testing.T) {
mr := new(MockRandom)
mc := new(MockConfig)
mreg := new(MockSessionRegistry)
mg := new(MockGRPCClient)
mp := new(MockPort)
sc, _ := getTestSSHConfig()
s, err := New(mr, mc, sc, mreg, mg, mp, "0")
assert.NoError(t, err)
assert.NotNil(t, s)
go func() {
time.Sleep(100 * time.Millisecond)
err := s.Close()
assert.NoError(t, err)
}()
s.Start()
})
t.Run("multiple connections", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
conn1Server, conn1Client := net.Pipe()
conn2Server, conn2Client := net.Pipe()
mockListener := &MockListener{}
mockListener.On("Accept").Return(conn1Server, nil).Once()
mockListener.On("Accept").Return(conn2Server, nil).Once()
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshListener: mockListener,
sshConfig: sshConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
go s.Start()
time.Sleep(50 * time.Millisecond)
_ = conn1Client.Close()
time.Sleep(50 * time.Millisecond)
_ = conn2Client.Close()
time.Sleep(100 * time.Millisecond)
mockListener.AssertExpectations(t)
})
}
func TestErrorHandling(t *testing.T) {
t.Run("write error during SSH handshake", func(t *testing.T) {
mockRandom := &MockRandom{}
mockConfig := &MockConfig{}
mockSessionRegistry := &MockSessionRegistry{}
mockPort := &MockPort{}
sshConfig, _ := getTestSSHConfig()
serverConn, clientConn := net.Pipe()
err := clientConn.Close()
assert.NoError(t, err)
s := &server{
randomizer: mockRandom,
config: mockConfig,
sshPort: "0",
sshConfig: sshConfig,
grpcClient: nil,
sessionRegistry: mockSessionRegistry,
portRegistry: mockPort,
}
s.handleConnection(serverConn)
})
}
-336
View File
@@ -1,336 +0,0 @@
package server
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"log"
"os"
"sync"
"time"
"tunnel_pls/internal/config"
"github.com/caddyserver/certmagic"
"github.com/libdns/cloudflare"
)
type TLSManager interface {
userCertsExistAndValid() bool
loadUserCerts() error
startCertWatcher()
initCertMagic() error
getTLSConfig() *tls.Config
getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
}
type tlsManager struct {
domain string
certPath string
keyPath string
storagePath string
userCert *tls.Certificate
userCertMu sync.RWMutex
magic *certmagic.Config
useCertMagic bool
}
var globalTLSManager TLSManager
var tlsManagerOnce sync.Once
func NewTLSConfig(domain string) (*tls.Config, error) {
var initErr error
tlsManagerOnce.Do(func() {
certPath := "certs/tls/cert.pem"
keyPath := "certs/tls/privkey.pem"
storagePath := "certs/tls/certmagic"
tm := &tlsManager{
domain: domain,
certPath: certPath,
keyPath: keyPath,
storagePath: storagePath,
}
if tm.userCertsExistAndValid() {
log.Printf("Using user-provided certificates from %s and %s", certPath, keyPath)
if err := tm.loadUserCerts(); err != nil {
initErr = fmt.Errorf("failed to load user certificates: %w", err)
return
}
tm.useCertMagic = false
tm.startCertWatcher()
} else {
if !isACMEConfigComplete() {
log.Printf("User certificates missing or invalid, and ACME configuration is incomplete")
log.Printf("To enable automatic certificate generation, set CF_API_TOKEN environment variable")
initErr = fmt.Errorf("no valid certificates found and ACME configuration is incomplete (CF_API_TOKEN is required)")
return
}
log.Printf("User certificates missing or don't cover %s and *.%s, using CertMagic", domain, domain)
if err := tm.initCertMagic(); err != nil {
initErr = fmt.Errorf("failed to initialize CertMagic: %w", err)
return
}
tm.useCertMagic = true
}
globalTLSManager = tm
})
if initErr != nil {
return nil, initErr
}
return globalTLSManager.getTLSConfig(), nil
}
func isACMEConfigComplete() bool {
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
return cfAPIToken != ""
}
func (tm *tlsManager) userCertsExistAndValid() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath)
return false
}
if _, err := os.Stat(tm.keyPath); os.IsNotExist(err) {
log.Printf("Key file not found: %s", tm.keyPath)
return false
}
return ValidateCertDomains(tm.certPath, tm.domain)
}
func ValidateCertDomains(certPath, domain string) bool {
certPEM, err := os.ReadFile(certPath)
if err != nil {
log.Printf("Failed to read certificate: %v", err)
return false
}
block, _ := pem.Decode(certPEM)
if block == nil {
log.Printf("Failed to decode PEM block from certificate")
return false
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
log.Printf("Failed to parse certificate: %v", err)
return false
}
if time.Now().After(cert.NotAfter) {
log.Printf("Certificate has expired (NotAfter: %v)", cert.NotAfter)
return false
}
if time.Now().Add(30 * 24 * time.Hour).After(cert.NotAfter) {
log.Printf("Certificate expiring soon (NotAfter: %v), will use CertMagic for renewal", cert.NotAfter)
return false
}
var certDomains []string
if cert.Subject.CommonName != "" {
certDomains = append(certDomains, cert.Subject.CommonName)
}
certDomains = append(certDomains, cert.DNSNames...)
hasBase := false
hasWildcard := false
wildcardDomain := "*." + domain
for _, d := range certDomains {
if d == domain {
hasBase = true
}
if d == wildcardDomain {
hasWildcard = true
}
}
if !hasBase {
log.Printf("Certificate does not cover base domain: %s", domain)
}
if !hasWildcard {
log.Printf("Certificate does not cover wildcard domain: %s", wildcardDomain)
}
return hasBase && hasWildcard
}
func (tm *tlsManager) loadUserCerts() error {
cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath)
if err != nil {
return err
}
tm.userCertMu.Lock()
tm.userCert = &cert
tm.userCertMu.Unlock()
log.Printf("Loaded user certificates successfully")
return nil
}
func (tm *tlsManager) startCertWatcher() {
go func() {
var lastCertMod, lastKeyMod time.Time
if info, err := os.Stat(tm.certPath); err == nil {
lastCertMod = info.ModTime()
}
if info, err := os.Stat(tm.keyPath); err == nil {
lastKeyMod = info.ModTime()
}
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for range ticker.C {
certInfo, certErr := os.Stat(tm.certPath)
keyInfo, keyErr := os.Stat(tm.keyPath)
if certErr != nil || keyErr != nil {
continue
}
if certInfo.ModTime().After(lastCertMod) || keyInfo.ModTime().After(lastKeyMod) {
log.Printf("Certificate files changed, reloading...")
if !ValidateCertDomains(tm.certPath, tm.domain) {
log.Printf("New certificates don't cover required domains")
if !isACMEConfigComplete() {
log.Printf("Cannot switch to CertMagic: ACME configuration is incomplete (CF_API_TOKEN is required)")
continue
}
log.Printf("Switching to CertMagic for automatic certificate management")
if err := tm.initCertMagic(); err != nil {
log.Printf("Failed to initialize CertMagic: %v", err)
continue
}
tm.useCertMagic = true
return
}
if err := tm.loadUserCerts(); err != nil {
log.Printf("Failed to reload certificates: %v", err)
continue
}
lastCertMod = certInfo.ModTime()
lastKeyMod = keyInfo.ModTime()
log.Printf("Certificates reloaded successfully")
}
}
}()
}
func (tm *tlsManager) initCertMagic() error {
if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
return fmt.Errorf("failed to create cert storage directory: %w", err)
}
acmeEmail := config.Getenv("ACME_EMAIL", "admin@"+tm.domain)
cfAPIToken := config.Getenv("CF_API_TOKEN", "")
acmeStaging := config.Getenv("ACME_STAGING", "false") == "true"
if cfAPIToken == "" {
return fmt.Errorf("CF_API_TOKEN environment variable is required for automatic certificate generation")
}
cfProvider := &cloudflare.Provider{
APIToken: cfAPIToken,
}
storage := &certmagic.FileStorage{Path: tm.storagePath}
cache := certmagic.NewCache(certmagic.CacheOptions{
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
return tm.magic, nil
},
})
magic := certmagic.New(cache, certmagic.Config{
Storage: storage,
})
acmeIssuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{
Email: acmeEmail,
Agreed: true,
DNS01Solver: &certmagic.DNS01Solver{
DNSManager: certmagic.DNSManager{
DNSProvider: cfProvider,
},
},
})
if acmeStaging {
acmeIssuer.CA = certmagic.LetsEncryptStagingCA
log.Printf("Using Let's Encrypt staging server")
} else {
acmeIssuer.CA = certmagic.LetsEncryptProductionCA
log.Printf("Using Let's Encrypt production server")
}
magic.Issuers = []certmagic.Issuer{acmeIssuer}
tm.magic = magic
domains := []string{tm.domain, "*." + tm.domain}
log.Printf("Requesting certificates for: %v", domains)
ctx := context.Background()
if err := magic.ManageSync(ctx, domains); err != nil {
return fmt.Errorf("failed to obtain certificates: %w", err)
}
log.Printf("Certificates obtained successfully for %v", domains)
return nil
}
func (tm *tlsManager) getTLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: tm.getCertificate,
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
SessionTicketsDisabled: false,
CipherSuites: []uint16{
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_CHACHA20_POLY1305_SHA256,
},
CurvePreferences: []tls.CurveID{
tls.X25519,
},
ClientAuth: tls.NoClientCert,
NextProtos: nil,
}
}
func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if tm.useCertMagic {
return tm.magic.GetCertificate(hello)
}
tm.userCertMu.RLock()
defer tm.userCertMu.RUnlock()
if tm.userCert == nil {
return nil, fmt.Errorf("no certificate available")
}
return tm.userCert, nil
}
+112 -176
View File
@@ -1,15 +1,14 @@
package forwarder
import (
"bytes"
"encoding/binary"
"context"
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"sync"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/session/slug"
"tunnel_pls/types"
@@ -17,239 +16,176 @@ import (
"golang.org/x/crypto/ssh"
)
var bufferPool = sync.Pool{
New: func() interface{} {
bufSize := config.GetBufferSize()
return make([]byte, bufSize)
},
type Forwarder interface {
SetType(tunnelType types.TunnelType)
SetForwardedPort(port uint16)
SetListener(listener net.Listener)
Listener() net.Listener
TunnelType() types.TunnelType
ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel)
OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
Close() error
}
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
type Forwarder struct {
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slugManager slug.Manager
lifecycle Lifecycle
slug slug.Slug
conn ssh.Conn
bufferPool sync.Pool
}
func NewForwarder(slugManager slug.Manager) *Forwarder {
return &Forwarder{
func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
return &forwarder{
listener: nil,
tunnelType: "",
tunnelType: types.TunnelTypeUNKNOWN,
forwardedPort: 0,
slugManager: slugManager,
lifecycle: nil,
slug: slug,
conn: conn,
bufferPool: sync.Pool{
New: func() interface{} {
bufSize := config.BufferSize()
buf := make([]byte, bufSize)
return &buf
},
},
}
}
type Lifecycle interface {
GetConnection() ssh.Conn
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := f.bufferPool.Get().(*[]byte)
defer f.bufferPool.Put(buf)
return io.CopyBuffer(dst, src, *buf)
}
type ForwardingController interface {
AcceptTCPConnections()
SetType(tunnelType types.TunnelType)
GetTunnelType() types.TunnelType
GetForwardedPort() uint16
SetForwardedPort(port uint16)
SetListener(listener net.Listener)
GetListener() net.Listener
Close() error
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
SetLifecycle(lifecycle Lifecycle)
CreateForwardedTCPIPPayload(origin net.Addr) []byte
WriteBadGatewayResponse(dst io.Writer)
}
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
f.lifecycle = lifecycle
}
func (f *Forwarder) AcceptTCPConnections() {
for {
conn, err := f.GetListener().Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Printf("Failed to set connection deadline: %v", err)
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
continue
}
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
payload := createForwardedTCPIPPayload(origin, f.forwardedPort)
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
continue
}
if err := conn.SetDeadline(time.Time{}); err != nil {
log.Printf("Failed to clear connection deadline: %v", err)
}
go ssh.DiscardRequests(result.reqs)
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
}
}
}
func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
defer func() {
_, err := io.Copy(io.Discard, src)
if err != nil {
log.Printf("Failed to discard connection: %v", err)
}
err = src.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing source channel: %v", err)
}
if closer, ok := dst.(io.Closer); ok {
err = closer.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing destination connection: %v", err)
case resultChan <- channelResult{channel, reqs, err}:
case <-ctx.Done():
if channel != nil {
_ = channel.Close()
go ssh.DiscardRequests(reqs)
}
}
}()
log.Printf("Handling new forwarded connection from %s", remoteAddr)
select {
case result := <-resultChan:
return result.channel, result.reqs, result.err
case <-ctx.Done():
return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err())
}
}
func closeWriter(w io.Writer) error {
if cw, ok := w.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
if closer, ok := w.(io.Closer); ok {
return closer.Close()
}
return nil
}
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
var errs []error
_, err := f.copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
}
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err))
}
return errors.Join(errs...)
}
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
defer func() {
_, _ = io.Copy(io.Discard, src)
}()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, err := copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying src→dst: %v", err)
err := f.copyAndClose(dst, src, "src to dst")
if err != nil {
log.Println("Error during copy: ", err)
return
}
}()
go func() {
defer wg.Done()
_, err := copyWithBuffer(src, dst)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying dst→src: %v", err)
err := f.copyAndClose(src, dst, "dst to src")
if err != nil {
log.Println("Error during copy: ", err)
return
}
}()
wg.Wait()
}
func (f *Forwarder) SetType(tunnelType types.TunnelType) {
func (f *forwarder) SetType(tunnelType types.TunnelType) {
f.tunnelType = tunnelType
}
func (f *Forwarder) GetTunnelType() types.TunnelType {
func (f *forwarder) TunnelType() types.TunnelType {
return f.tunnelType
}
func (f *Forwarder) GetForwardedPort() uint16 {
func (f *forwarder) ForwardedPort() uint16 {
return f.forwardedPort
}
func (f *Forwarder) SetForwardedPort(port uint16) {
func (f *forwarder) SetForwardedPort(port uint16) {
f.forwardedPort = port
}
func (f *Forwarder) SetListener(listener net.Listener) {
func (f *forwarder) SetListener(listener net.Listener) {
f.listener = listener
}
func (f *Forwarder) GetListener() net.Listener {
func (f *forwarder) Listener() net.Listener {
return f.listener
}
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
_, err := dst.Write(types.BadGatewayResponse)
if err != nil {
log.Printf("failed to write Bad Gateway response: %v", err)
return
}
}
func (f *Forwarder) Close() error {
if f.GetListener() != nil {
func (f *forwarder) Close() error {
if f.Listener() != nil {
return f.listener.Close()
}
return nil
}
func (f *Forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
var buf bytes.Buffer
host, originPort := parseAddr(origin.String())
writeSSHString(&buf, "localhost")
err := binary.Write(&buf, binary.BigEndian, uint32(f.GetForwardedPort()))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return nil
}
writeSSHString(&buf, host)
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return nil
}
return buf.Bytes()
}
func parseAddr(addr string) (string, uint16) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
return "0.0.0.0", uint16(0)
}
func createForwardedTCPIPPayload(origin net.Addr, destPort uint16) []byte {
host, portStr, _ := net.SplitHostPort(origin.String())
port, _ := strconv.Atoi(portStr)
return host, uint16(port)
}
func writeSSHString(buffer *bytes.Buffer, str string) {
err := binary.Write(buffer, binary.BigEndian, uint32(len(str)))
if err != nil {
log.Printf("Failed to write string to buffer: %v", err)
return
forwardPayload := struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}{
DestAddr: "localhost",
DestPort: uint32(destPort),
OriginAddr: host,
OriginPort: uint32(port),
}
buffer.WriteString(str)
return ssh.Marshal(forwardPayload)
}
File diff suppressed because it is too large Load Diff
-291
View File
@@ -1,291 +0,0 @@
package session
import (
"bytes"
"encoding/binary"
"fmt"
"log"
"net"
portUtil "tunnel_pls/internal/port"
"tunnel_pls/internal/random"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest {
switch req.Type {
case "shell", "pty-req":
err := req.Reply(true, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
case "window-change":
p := req.Payload
if len(p) < 16 {
log.Println("invalid window-change payload")
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
return
}
cols := binary.BigEndian.Uint32(p[0:4])
rows := binary.BigEndian.Uint32(p[4:8])
s.interaction.SetWH(int(cols), int(rows))
err := req.Reply(true, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
default:
log.Println("Unknown request type:", req.Type)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
}
}
}
func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected")
reader := bytes.NewReader(req.Payload)
addr, err := readSSHString(reader)
if err != nil {
log.Println("Failed to read address from payload:", err)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
var rawPortToBind uint32
if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
log.Println("Failed to read port from payload:", err)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
if rawPortToBind > 65535 {
log.Printf("Port %d is larger than allowed port of 65535", rawPortToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
portToBind := uint16(rawPortToBind)
if isBlockedPort(portToBind) {
log.Printf("Port %d is blocked or restricted", portToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
if portToBind == 80 || portToBind == 443 {
s.HandleHTTPForward(req, portToBind)
return
}
if portToBind == 0 {
unassign, success := portUtil.Default.GetUnassignedPort()
portToBind = unassign
if !success {
log.Println("No available port")
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
log.Printf("Port %d is already in use or restricted", portToBind)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
err = portUtil.Default.SetPortStatus(portToBind, true)
if err != nil {
log.Println("Failed to set port status:", err)
return
}
s.HandleTCPForward(req, addr, portToBind)
}
func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
slug := random.GenerateRandomString(20)
if !s.registry.Register(slug, s) {
log.Printf("Failed to register client with slug: %s", slug)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return
}
buf := new(bytes.Buffer)
err := binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil {
log.Println("Failed to write port to buffer:", err)
s.registry.Remove(slug)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return
}
log.Printf("HTTP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes())
if err != nil {
log.Println("Failed to reply to request:", err)
s.registry.Remove(slug)
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
}
return
}
s.forwarder.SetType(types.HTTP)
s.forwarder.SetForwardedPort(portToBind)
s.slugManager.Set(slug)
s.lifecycle.SetStatus(types.RUNNING)
s.interaction.Start()
}
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil {
log.Printf("Port %d is already in use or restricted", portToBind)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
err = s.lifecycle.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
if err != nil {
log.Println("Failed to write port to buffer:", err)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
return
}
log.Printf("TCP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes())
if err != nil {
log.Println("Failed to reply to request:", err)
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
err = listener.Close()
if err != nil {
log.Printf("Failed to close listener: %s", err)
return
}
return
}
s.forwarder.SetType(types.TCP)
s.forwarder.SetListener(listener)
s.forwarder.SetForwardedPort(portToBind)
s.lifecycle.SetStatus(types.RUNNING)
go s.forwarder.AcceptTCPConnections()
s.interaction.Start()
}
func readSSHString(reader *bytes.Reader) (string, error) {
var length uint32
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
return "", err
}
strBytes := make([]byte, length)
if _, err := reader.Read(strBytes); err != nil {
return "", err
}
return string(strBytes), nil
}
func isBlockedPort(port uint16) bool {
if port == 80 || port == 443 {
return false
}
if port < 1024 && port != 0 {
return true
}
for _, p := range blockedReservedPorts {
if p == port {
return true
}
}
return false
}
+83
View File
@@ -0,0 +1,83 @@
package interaction
import (
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) comingSoonUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.showingComingSoon = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
func (m *model) comingSoonView() string {
isCompact := shouldUseCompactLayout(m.width, 60)
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 3
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
messageBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
messageBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Background(lipgloss.Color("#1A1A2E")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(messageBoxWidth).
Align(lipgloss.Center)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Coming Soon"
} else {
title = "⏳ Coming Soon"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
var message string
if shouldUseCompactLayout(m.width, 50) {
message = "Coming soon!\nStay tuned."
} else {
message = "🚀 This feature is coming very soon!\n Stay tuned for updates."
}
b.WriteString(messageBoxStyle.Render(message))
b.WriteString("\n\n")
var helpText string
if shouldUseCompactLayout(m.width, 60) {
helpText = "Press any key..."
} else {
helpText = "This message will disappear in 5 seconds or press any key..."
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
+86
View File
@@ -0,0 +1,86 @@
package interaction
import (
"strings"
"time"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) handleCommandSelection(item commandItem) (tea.Model, tea.Cmd) {
switch item.name {
case "slug":
m.showingCommands = false
m.editingSlug = true
m.slugInput.SetValue(m.interaction.slug.String())
m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "tunnel-type":
m.showingCommands = false
m.showingComingSoon = true
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
default:
m.showingCommands = false
return m, nil
}
}
func (m *model) commandsUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch {
case key.Matches(msg, m.keymap.quit), msg.String() == "esc":
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case msg.String() == "enter":
selectedItem := m.commandList.SelectedItem()
if selectedItem != nil {
item := selectedItem.(commandItem)
return m.handleCommandSelection(item)
}
}
m.commandList, cmd = m.commandList.Update(msg)
return m, cmd
}
func (m *model) commandsView() string {
isCompact := shouldUseCompactLayout(m.width, 60)
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Commands"
} else {
title = "⚡ Commands"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
b.WriteString(m.commandList.View())
b.WriteString("\n")
var helpText string
if isCompact {
helpText = "↑/↓ Nav • Enter Select • Esc Cancel"
} else {
helpText = "↑/↓ Navigate • Enter Select • Esc Cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
-152
View File
@@ -1,152 +0,0 @@
package interaction
const (
backspaceChar = 8
deleteChar = 127
enterChar = 13
escapeChar = 27
ctrlC = 3
forwardSlash = '/'
minPrintableChar = 32
maxPrintableChar = 126
minSlugLength = 3
maxSlugLength = 20
clearScreen = "\033[H\033[2J"
clearLine = "\033[K"
clearToLineEnd = "\r\033[K"
backspaceSeq = "\b \b"
minBoxWidth = 50
paddingRight = 4
)
var forbiddenSlugs = map[string]struct{}{
"ping": {},
"staging": {},
"admin": {},
"root": {},
"api": {},
"www": {},
"support": {},
"help": {},
"status": {},
"health": {},
"login": {},
"logout": {},
"signup": {},
"register": {},
"settings": {},
"config": {},
"null": {},
"undefined": {},
"example": {},
"test": {},
"dev": {},
"system": {},
"administrator": {},
"dashboard": {},
"account": {},
"profile": {},
"user": {},
"users": {},
"auth": {},
"oauth": {},
"callback": {},
"webhook": {},
"webhooks": {},
"static": {},
"assets": {},
"cdn": {},
"mail": {},
"email": {},
"ftp": {},
"ssh": {},
"git": {},
"svn": {},
"blog": {},
"news": {},
"about": {},
"contact": {},
"terms": {},
"privacy": {},
"legal": {},
"billing": {},
"payment": {},
"checkout": {},
"cart": {},
"shop": {},
"store": {},
"download": {},
"uploads": {},
"images": {},
"img": {},
"css": {},
"js": {},
"fonts": {},
"public": {},
"private": {},
"internal": {},
"external": {},
"proxy": {},
"cache": {},
"debug": {},
"metrics": {},
"monitoring": {},
"graphql": {},
"rest": {},
"rpc": {},
"socket": {},
"ws": {},
"wss": {},
"app": {},
"apps": {},
"mobile": {},
"desktop": {},
"embed": {},
"widget": {},
"docs": {},
"documentation": {},
"wiki": {},
"forum": {},
"community": {},
"feedback": {},
"report": {},
"abuse": {},
"spam": {},
"security": {},
"verify": {},
"confirm": {},
"reset": {},
"password": {},
"recovery": {},
"unsubscribe": {},
"subscribe": {},
"notifications": {},
"alerts": {},
"messages": {},
"inbox": {},
"outbox": {},
"sent": {},
"draft": {},
"trash": {},
"archive": {},
"search": {},
"explore": {},
"discover": {},
"trending": {},
"popular": {},
"featured": {},
"new": {},
"latest": {},
"top": {},
"best": {},
"hot": {},
"random": {},
"all": {},
"any": {},
"none": {},
"true": {},
"false": {},
}
+216
View File
@@ -0,0 +1,216 @@
package interaction
import (
"fmt"
"strings"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) dashboardUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch {
case key.Matches(msg, m.keymap.quit):
m.quitting = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink, tea.Quit)
case key.Matches(msg, m.keymap.command):
m.showingCommands = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
return m, nil
}
func (m *model) dashboardView() string {
isCompact := shouldUseCompactLayout(m.width, BreakpointLarge)
var b strings.Builder
b.WriteString(m.renderHeader(isCompact))
b.WriteString(m.renderUserInfo(isCompact))
b.WriteString(m.renderQuickActions(isCompact))
b.WriteString(m.renderFooter(isCompact))
return b.String()
}
func (m *model) renderHeader(isCompact bool) string {
var b strings.Builder
asciiArtMargin := getMarginValue(isCompact, 0, 1)
asciiArtStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color(ColorPrimary)).
MarginBottom(asciiArtMargin)
b.WriteString(asciiArtStyle.Render(m.getASCIIArt()))
b.WriteString("\n")
if !shouldUseCompactLayout(m.width, BreakpointSmall) {
b.WriteString(m.renderSubtitle())
} else {
b.WriteString("\n")
}
return b.String()
}
func (m *model) getASCIIArt() string {
if shouldUseCompactLayout(m.width, BreakpointTiny) {
return "TUNNEL PLS"
}
if shouldUseCompactLayout(m.width, BreakpointLarge) {
return `
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
}
return `
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
}
func (m *model) renderSubtitle() string {
subtitleStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorGray)).
Italic(true)
urlStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorPrimary)).
Underline(true).
Italic(true)
return subtitleStyle.Render("Secure tunnel service by Bagas • ") +
urlStyle.Render("https://fossy.my.id") + "\n\n"
}
func (m *model) renderUserInfo(isCompact bool) string {
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
boxPadding := getMarginValue(isCompact, 1, 2)
boxMargin := getMarginValue(isCompact, 1, 2)
responsiveInfoBox := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorPrimary)).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(boxMaxWidth)
infoContent := m.getUserInfoContent(isCompact)
return responsiveInfoBox.Render(infoContent) + "\n"
}
func (m *model) getUserInfoContent(isCompact bool) string {
userInfoStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWhite)).
Bold(true)
sectionHeaderStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorGray)).
Bold(true)
addressStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWhite))
urlBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorSecondary)).
Bold(true).
Italic(true)
authenticatedUser := m.interaction.user
tunnelURL := urlBoxStyle.Render(m.getTunnelURL())
if isCompact {
return fmt.Sprintf("👤 %s\n\n%s\n%s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(fmt.Sprintf(" %s", tunnelURL)))
}
return fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(tunnelURL))
}
func (m *model) renderQuickActions(isCompact bool) string {
var b strings.Builder
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color(ColorPrimary)).
PaddingTop(1)
b.WriteString(titleStyle.Render(m.getQuickActionsTitle()))
b.WriteString("\n")
featureMargin := getMarginValue(isCompact, 1, 2)
featureStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWhite)).
MarginLeft(featureMargin)
keyHintStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorPrimary)).
Bold(true)
commands := m.getActionCommands(keyHintStyle)
b.WriteString(featureStyle.Render(commands.commandsText))
b.WriteString("\n")
b.WriteString(featureStyle.Render(commands.quitText))
return b.String()
}
func (m *model) getQuickActionsTitle() string {
if shouldUseCompactLayout(m.width, BreakpointTiny) {
return "Actions"
}
if shouldUseCompactLayout(m.width, BreakpointLarge) {
return "Quick Actions"
}
return "✨ Quick Actions"
}
type actionCommands struct {
commandsText string
quitText string
}
func (m *model) getActionCommands(keyHintStyle lipgloss.Style) actionCommands {
if shouldUseCompactLayout(m.width, BreakpointSmall) {
return actionCommands{
commandsText: fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]")),
quitText: fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]")),
}
}
return actionCommands{
commandsText: fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]")),
quitText: fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]")),
}
}
func (m *model) renderFooter(isCompact bool) string {
if isCompact {
return ""
}
footerStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorDarkGray)).
Italic(true)
return "\n\n" + footerStyle.Render("Press 'C' to customize your tunnel settings")
}
func getMarginValue(isCompact bool, compactValue, normalValue int) int {
if isCompact {
return compactValue
}
return normalValue
}
+99 -672
View File
@@ -2,10 +2,8 @@ package interaction
import (
"context"
"fmt"
"log"
"strings"
"time"
"sync"
"tunnel_pls/internal/config"
"tunnel_pls/internal/random"
"tunnel_pls/session/slug"
@@ -21,36 +19,59 @@ import (
"golang.org/x/crypto/ssh"
)
type Lifecycle interface {
Close() error
type Interaction interface {
Mode() types.InteractiveMode
SetChannel(channel ssh.Channel)
SetMode(m types.InteractiveMode)
SetWH(w, h int)
Start()
Redraw()
Send(message string) error
}
type Controller interface {
SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle)
SetSlugModificator(func(oldSlug, newSlug string) bool)
Start()
SetWH(w, h int)
type SessionRegistry interface {
Update(user string, oldKey, newKey types.SessionKey) error
}
type Forwarder interface {
Close() error
GetTunnelType() types.TunnelType
GetForwardedPort() uint16
TunnelType() types.TunnelType
ForwardedPort() uint16
}
type Interaction struct {
channel ssh.Channel
slugManager slug.Manager
forwarder Forwarder
lifecycle Lifecycle
updateClientSlug func(oldSlug, newSlug string) bool
program *tea.Program
ctx context.Context
cancel context.CancelFunc
type CloseFunc func() error
type interaction struct {
randomizer random.Random
config config.Config
channel ssh.Channel
slug slug.Slug
forwarder Forwarder
closeFunc CloseFunc
user string
sessionRegistry SessionRegistry
program *tea.Program
ctx context.Context
cancel context.CancelFunc
mode types.InteractiveMode
programMu sync.Mutex
}
func (i *Interaction) SetWH(w, h int) {
func (i *interaction) SetMode(m types.InteractiveMode) {
i.mode = m
}
func (i *interaction) Mode() types.InteractiveMode {
return i.mode
}
func (i *interaction) Send(message string) error {
if i.channel != nil {
_, err := i.channel.Write([]byte(message))
return err
}
return nil
}
func (i *interaction) SetWH(w, h int) {
if i.program != nil {
i.program.Send(tea.WindowSizeMsg{
Width: w,
@@ -59,116 +80,42 @@ func (i *Interaction) SetWH(w, h int) {
}
}
type commandItem struct {
name string
desc string
}
type model struct {
tunnelURL string
domain string
protocol string
tunnelType types.TunnelType
port uint16
keymap keymap
help help.Model
quitting bool
showingCommands bool
editingSlug bool
showingComingSoon bool
commandList list.Model
slugInput textinput.Model
slugError string
interaction *Interaction
width int
height int
}
type keymap struct {
quit key.Binding
command key.Binding
random key.Binding
}
type tickMsg time.Time
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
func New(randomizer random.Random, config config.Config, slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
ctx, cancel := context.WithCancel(context.Background())
return &Interaction{
channel: nil,
slugManager: slugManager,
forwarder: forwarder,
lifecycle: nil,
updateClientSlug: nil,
program: nil,
ctx: ctx,
cancel: cancel,
return &interaction{
randomizer: randomizer,
config: config,
channel: nil,
slug: slug,
forwarder: forwarder,
closeFunc: closeFunc,
user: user,
sessionRegistry: sessionRegistry,
program: nil,
ctx: ctx,
cancel: cancel,
}
}
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
i.lifecycle = lifecycle
}
func (i *Interaction) SetChannel(channel ssh.Channel) {
func (i *interaction) SetChannel(channel ssh.Channel) {
i.channel = channel
}
func (i *Interaction) SetSlugModificator(modificator func(oldSlug, newSlug string) (success bool)) {
i.updateClientSlug = modificator
}
func (i *Interaction) Stop() {
func (i *interaction) Stop() {
if i.cancel != nil {
i.cancel()
}
i.programMu.Lock()
defer i.programMu.Unlock()
if i.program != nil {
i.program.Kill()
i.program = nil
}
}
func getResponsiveWidth(screenWidth, padding, minWidth, maxWidth int) int {
width := screenWidth - padding
if width > maxWidth {
width = maxWidth
}
if width < minWidth {
width = minWidth
}
return width
}
func shouldUseCompactLayout(width int, threshold int) bool {
return width < threshold
}
func truncateString(s string, maxLength int) string {
if len(s) <= maxLength {
return s
}
if maxLength < 4 {
return s[:maxLength]
}
return s[:maxLength-3] + "..."
}
func (i commandItem) FilterValue() string { return i.name }
func (i commandItem) Title() string { return i.name }
func (i commandItem) Description() string { return i.desc }
func tickCmd(d time.Duration) tea.Cmd {
return tea.Tick(d, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func (m model) Init() tea.Cmd {
return tea.Batch(textinput.Blink, tea.WindowSize())
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
switch msg := msg.(type) {
case tickMsg:
@@ -194,555 +141,62 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.KeyMsg:
if m.showingComingSoon {
m.showingComingSoon = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
return m.comingSoonUpdate(msg)
}
if m.editingSlug {
if m.tunnelType != types.HTTP {
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
switch msg.String() {
case "esc":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter":
inputValue := m.slugInput.Value()
if isForbiddenSlug(inputValue) {
m.slugError = "This subdomain is reserved. Please choose a different one."
return m, nil
} else if !isValidSlug(inputValue) {
m.slugError = "Invalid subdomain. Follow the rules."
return m, nil
}
if !m.interaction.updateClientSlug(m.interaction.slugManager.Get(), inputValue) {
m.slugError = "Someone already uses this subdomain."
return m, nil
}
m.tunnelURL = buildURL(m.protocol, inputValue, m.domain)
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "ctrl+c":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
default:
if key.Matches(msg, m.keymap.random) {
newSubdomain := generateRandomSubdomain()
m.slugInput.SetValue(newSubdomain)
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
return m, cmd
}
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
return m, cmd
}
return m.slugUpdate(msg)
}
if m.showingCommands {
switch {
case key.Matches(msg, m.keymap.quit):
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case msg.String() == "enter":
selectedItem := m.commandList.SelectedItem()
if selectedItem != nil {
item := selectedItem.(commandItem)
if item.name == "slug" {
m.showingCommands = false
m.editingSlug = true
m.slugInput.SetValue(m.interaction.slugManager.Get())
m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} else if item.name == "tunnel-type" {
m.showingCommands = false
m.showingComingSoon = true
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
}
m.showingCommands = false
return m, nil
}
case msg.String() == "esc":
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
m.commandList, cmd = m.commandList.Update(msg)
return m, cmd
return m.commandsUpdate(msg)
}
switch {
case key.Matches(msg, m.keymap.quit):
m.quitting = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink, tea.Quit)
case key.Matches(msg, m.keymap.command):
m.showingCommands = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
return m.dashboardUpdate(msg)
}
return m, nil
}
func (m model) helpView() string {
return "\n" + m.help.ShortHelpView([]key.Binding{
m.keymap.command,
m.keymap.quit,
})
func (i *interaction) Redraw() {
if i.program != nil {
i.program.Send(tea.ClearScreen())
}
}
func (m model) View() string {
func (m *model) View() string {
if m.quitting {
return ""
}
if m.showingComingSoon {
isCompact := shouldUseCompactLayout(m.width, 60)
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 3
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
messageBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
messageBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Background(lipgloss.Color("#1A1A2E")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(messageBoxWidth).
Align(lipgloss.Center)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Coming Soon"
} else {
title = "⏳ Coming Soon"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
var message string
if shouldUseCompactLayout(m.width, 50) {
message = "Coming soon!\nStay tuned."
} else {
message = "🚀 This feature is coming very soon!\n Stay tuned for updates."
}
b.WriteString(messageBoxStyle.Render(message))
b.WriteString("\n\n")
var helpText string
if shouldUseCompactLayout(m.width, 60) {
helpText = "Press any key..."
} else {
helpText = "This message will disappear in 5 seconds or press any key..."
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
return m.comingSoonView()
}
if m.editingSlug {
isCompact := shouldUseCompactLayout(m.width, 70)
isVeryCompact := shouldUseCompactLayout(m.width, 50)
var boxPadding int
var boxMargin int
if isVeryCompact {
boxPadding = 1
boxMargin = 1
} else if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
instructionStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginTop(1)
inputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
errorBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FF0000")).
Background(lipgloss.Color("#3D0000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1)
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
rulesBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1).
Width(rulesBoxWidth)
var b strings.Builder
var title string
if isVeryCompact {
title = "Edit Subdomain"
} else {
title = "🔧 Edit Subdomain"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
if m.tunnelType != types.HTTP {
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFA500")).
Background(lipgloss.Color("#3D2000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FFA500")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(warningBoxWidth)
var warningText string
if isVeryCompact {
warningText = "⚠️ TCP tunnels don't support custom subdomains."
} else {
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
}
b.WriteString(warningBoxStyle.Render(warningText))
b.WriteString("\n\n")
var helpText string
if isVeryCompact {
helpText = "Press any key to go back"
} else {
helpText = "Press Enter or Esc to go back"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
var rulesContent string
if isVeryCompact {
rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
} else if isCompact {
rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
} else {
rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
}
b.WriteString(rulesBoxStyle.Render(rulesContent))
b.WriteString("\n")
var instruction string
if isVeryCompact {
instruction = "Custom subdomain:"
} else {
instruction = "Enter your custom subdomain:"
}
b.WriteString(instructionStyle.Render(instruction))
b.WriteString("\n")
if m.slugError != "" {
errorInputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(1)
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
b.WriteString("\n")
} else {
b.WriteString(inputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
}
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
if len(previewURL) > previewWidth-10 {
previewURL = truncateString(previewURL, previewWidth-10)
}
previewStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Italic(true).
Width(previewWidth)
b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)))
b.WriteString("\n")
var helpText string
if isVeryCompact {
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
} else {
helpText = "Press Enter to save • CTRL+R for random • Esc to cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
return m.slugView()
}
if m.showingCommands {
isCompact := shouldUseCompactLayout(m.width, 60)
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Commands"
} else {
title = "⚡ Commands"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
b.WriteString(m.commandList.View())
b.WriteString("\n")
var helpText string
if isCompact {
helpText = "↑/↓ Nav • Enter Select • Esc Cancel"
} else {
helpText = "↑/↓ Navigate • Enter Select • Esc Cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
return m.commandsView()
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1)
subtitleStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Italic(true)
urlStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Underline(true).
Italic(true)
urlBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Bold(true).
Italic(true)
keyHintStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Bold(true)
var b strings.Builder
isCompact := shouldUseCompactLayout(m.width, 85)
var asciiArtMargin int
if isCompact {
asciiArtMargin = 0
} else {
asciiArtMargin = 1
}
asciiArtStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
MarginBottom(asciiArtMargin)
var asciiArt string
if shouldUseCompactLayout(m.width, 50) {
asciiArt = "TUNNEL PLS"
} else if isCompact {
asciiArt = `
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
} else {
asciiArt = `
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
}
b.WriteString(asciiArtStyle.Render(asciiArt))
b.WriteString("\n")
if !shouldUseCompactLayout(m.width, 60) {
b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • "))
b.WriteString(urlStyle.Render("https://fossy.my.id"))
b.WriteString("\n\n")
} else {
b.WriteString("\n")
}
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
responsiveInfoBox := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(boxMaxWidth)
urlDisplay := m.tunnelURL
if shouldUseCompactLayout(m.width, 80) && len(m.tunnelURL) > m.width-20 {
maxLen := m.width - 25
if maxLen > 10 {
urlDisplay = truncateString(m.tunnelURL, maxLen)
}
}
var infoContent string
if shouldUseCompactLayout(m.width, 70) {
infoContent = fmt.Sprintf("🌐 %s", urlBoxStyle.Render(urlDisplay))
} else if isCompact {
infoContent = fmt.Sprintf("🌐 Forwarding to:\n\n %s", urlBoxStyle.Render(urlDisplay))
} else {
infoContent = fmt.Sprintf("🌐 F O R W A R D I N G T O:\n\n %s", urlBoxStyle.Render(urlDisplay))
}
b.WriteString(responsiveInfoBox.Render(infoContent))
b.WriteString("\n")
var quickActionsTitle string
if shouldUseCompactLayout(m.width, 50) {
quickActionsTitle = "Actions"
} else if isCompact {
quickActionsTitle = "Quick Actions"
} else {
quickActionsTitle = "✨ Quick Actions"
}
b.WriteString(titleStyle.Render(quickActionsTitle))
b.WriteString("\n")
var featureMargin int
if isCompact {
featureMargin = 1
} else {
featureMargin = 2
}
compactFeatureStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginLeft(featureMargin)
var commandsText string
var quitText string
if shouldUseCompactLayout(m.width, 60) {
commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]"))
} else {
commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]"))
}
b.WriteString(compactFeatureStyle.Render(commandsText))
b.WriteString("\n")
b.WriteString(compactFeatureStyle.Render(quitText))
if !shouldUseCompactLayout(m.width, 70) {
b.WriteString("\n\n")
footerStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true)
b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings"))
}
return b.String()
return m.dashboardView()
}
func (i *Interaction) Start() {
func (i *interaction) Start() {
if i.mode == types.InteractiveModeHEADLESS {
return
}
lipgloss.SetColorProfile(termenv.TrueColor)
domain := config.Getenv("DOMAIN", "localhost")
protocol := "http"
if config.Getenv("TLS_ENABLED", "false") == "true" {
if i.config.TLSEnabled() {
protocol = "https"
}
tunnelType := i.forwarder.GetTunnelType()
port := i.forwarder.GetForwardedPort()
var tunnelURL string
if tunnelType == types.HTTP {
tunnelURL = buildURL(protocol, i.slugManager.Get(), domain)
} else {
tunnelURL = fmt.Sprintf("tcp://%s:%d", domain, port)
}
tunnelType := i.forwarder.TunnelType()
port := i.forwarder.ForwardedPort()
items := []list.Item{
commandItem{name: "slug", desc: "Set custom subdomain"},
@@ -764,9 +218,9 @@ func (i *Interaction) Start() {
ti.CharLimit = 20
ti.Width = 50
m := model{
tunnelURL: tunnelURL,
domain: domain,
m := &model{
randomizer: i.randomizer,
domain: i.config.Domain(),
protocol: protocol,
tunnelType: tunnelType,
port: port,
@@ -790,6 +244,7 @@ func (i *Interaction) Start() {
help: help.New(),
}
i.programMu.Lock()
i.program = tea.NewProgram(
m,
tea.WithInput(i.channel),
@@ -800,49 +255,21 @@ func (i *Interaction) Start() {
tea.WithoutSignalHandler(),
tea.WithFPS(30),
)
i.programMu.Unlock()
_, err := i.program.Run()
if err != nil {
log.Printf("Cannot close tea: %s \n", err)
}
i.program.Kill()
i.program = nil
if err := m.interaction.lifecycle.Close(); err != nil {
log.Printf("Cannot close session: %s \n", err)
i.programMu.Lock()
if i.program != nil {
i.program.Kill()
i.program = nil
}
i.programMu.Unlock()
if i.closeFunc != nil {
_ = i.closeFunc()
}
}
func buildURL(protocol, subdomain, domain string) string {
return fmt.Sprintf("%s://%s.%s", protocol, subdomain, domain)
}
func generateRandomSubdomain() string {
return random.GenerateRandomString(20)
}
func isValidSlug(slug string) bool {
if len(slug) < minSlugLength || len(slug) > maxSlugLength {
return false
}
if slug[0] == '-' || slug[len(slug)-1] == '-' {
return false
}
for _, c := range slug {
if !isValidSlugChar(byte(c)) {
return false
}
}
return true
}
func isValidSlugChar(c byte) bool {
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-'
}
func isForbiddenSlug(slug string) bool {
_, ok := forbiddenSlugs[slug]
return ok
}
File diff suppressed because it is too large Load Diff
+116
View File
@@ -0,0 +1,116 @@
package interaction
import (
"fmt"
"time"
"tunnel_pls/internal/random"
"tunnel_pls/types"
"github.com/charmbracelet/bubbles/help"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/list"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
)
type commandItem struct {
name string
desc string
}
func (i commandItem) FilterValue() string { return i.name }
func (i commandItem) Title() string { return i.name }
func (i commandItem) Description() string { return i.desc }
type model struct {
randomizer random.Random
domain string
protocol string
tunnelType types.TunnelType
port uint16
keymap keymap
help help.Model
quitting bool
showingCommands bool
editingSlug bool
showingComingSoon bool
commandList list.Model
slugInput textinput.Model
slugError string
interaction *interaction
width int
height int
}
const (
ColorPrimary = "#7D56F4"
ColorSecondary = "#04B575"
ColorGray = "#888888"
ColorDarkGray = "#666666"
ColorWhite = "#FAFAFA"
ColorError = "#FF0000"
ColorErrorBg = "#3D0000"
ColorWarning = "#FFA500"
ColorWarningBg = "#3D2000"
)
const (
BreakpointTiny = 50
BreakpointSmall = 60
BreakpointMedium = 70
BreakpointLarge = 85
)
func (m *model) getTunnelURL() string {
if m.tunnelType == types.TunnelTypeHTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
}
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
}
type keymap struct {
quit key.Binding
command key.Binding
random key.Binding
}
type tickMsg time.Time
func (m *model) Init() tea.Cmd {
return tea.Batch(textinput.Blink, tea.WindowSize())
}
func getResponsiveWidth(screenWidth, padding, minWidth, maxWidth int) int {
width := screenWidth - padding
if width > maxWidth {
width = maxWidth
}
if width < minWidth {
width = minWidth
}
return width
}
func shouldUseCompactLayout(width int, threshold int) bool {
return width < threshold
}
func truncateString(s string, maxLength int) string {
if len(s) <= maxLength {
return s
}
if maxLength < 4 {
return s[:maxLength]
}
return s[:maxLength-3] + "..."
}
func tickCmd(d time.Duration) tea.Cmd {
return tea.Tick(d, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func buildURL(protocol, subdomain, domain string) string {
return fmt.Sprintf("%s://%s.%s", protocol, subdomain, domain)
}
+265
View File
@@ -0,0 +1,265 @@
package interaction
import (
"fmt"
"strings"
"tunnel_pls/types"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
if m.tunnelType != types.TunnelTypeHTTP {
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
switch msg.String() {
case "esc", "ctrl+c":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter":
inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
Id: m.interaction.slug.String(),
Type: types.TunnelTypeHTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.TunnelTypeHTTP,
}); err != nil {
m.slugError = err.Error()
return m, nil
}
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
default:
if key.Matches(msg, m.keymap.random) {
newSubdomain, err := m.randomizer.String(20)
if err != nil {
return m, cmd
}
m.slugInput.SetValue(newSubdomain)
}
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
return m, cmd
}
}
func (m *model) slugView() string {
isCompact := shouldUseCompactLayout(m.width, BreakpointMedium)
isVeryCompact := shouldUseCompactLayout(m.width, BreakpointTiny)
var b strings.Builder
b.WriteString(m.renderSlugTitle(isVeryCompact))
if m.tunnelType != types.TunnelTypeHTTP {
b.WriteString(m.renderTCPWarning(isVeryCompact, isCompact))
return b.String()
}
b.WriteString(m.renderSlugRules(isVeryCompact, isCompact))
b.WriteString(m.renderSlugInstruction(isVeryCompact))
b.WriteString(m.renderSlugInput(isVeryCompact, isCompact))
b.WriteString(m.renderSlugPreview(isVeryCompact))
b.WriteString(m.renderSlugHelp(isVeryCompact))
return b.String()
}
func (m *model) renderSlugTitle(isVeryCompact bool) string {
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color(ColorPrimary)).
PaddingTop(1).
PaddingBottom(1)
title := "🔧 Edit Subdomain"
if isVeryCompact {
title = "Edit Subdomain"
}
return titleStyle.Render(title) + "\n\n"
}
func (m *model) renderTCPWarning(isVeryCompact, isCompact bool) string {
boxPadding := getPaddingValue(isVeryCompact, isCompact)
boxMargin := getMarginValue(isCompact, 1, 2)
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWarning)).
Background(lipgloss.Color(ColorWarningBg)).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorWarning)).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(warningBoxWidth)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorDarkGray)).
Italic(true).
MarginTop(1)
warningText := m.getTCPWarningText(isVeryCompact)
helpText := m.getTCPHelpText(isVeryCompact)
var b strings.Builder
b.WriteString(warningBoxStyle.Render(warningText))
b.WriteString("\n\n")
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
func (m *model) getTCPWarningText(isVeryCompact bool) string {
if isVeryCompact {
return "⚠️ TCP tunnels don't support custom subdomains."
}
return "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
}
func (m *model) getTCPHelpText(isVeryCompact bool) string {
if isVeryCompact {
return "Press any key to go back"
}
return "Press Enter or Esc to go back"
}
func (m *model) renderSlugRules(isVeryCompact, isCompact bool) string {
boxPadding := getPaddingValue(isVeryCompact, isCompact)
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
rulesBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWhite)).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorPrimary)).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1).
Width(rulesBoxWidth)
rulesContent := m.getRulesContent(isVeryCompact, isCompact)
return rulesBoxStyle.Render(rulesContent) + "\n"
}
func (m *model) getRulesContent(isVeryCompact, isCompact bool) string {
if isVeryCompact {
return "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
}
if isCompact {
return "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
}
return "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
}
func (m *model) renderSlugInstruction(isVeryCompact bool) string {
instructionStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorWhite)).
MarginTop(1)
instruction := "Enter your custom subdomain:"
if isVeryCompact {
instruction = "Custom subdomain:"
}
return instructionStyle.Render(instruction) + "\n"
}
func (m *model) renderSlugInput(isVeryCompact, isCompact bool) string {
boxPadding := getPaddingValue(isVeryCompact, isCompact)
boxMargin := getMarginValue(isCompact, 1, 2)
if m.slugError != "" {
return m.renderErrorInput(boxPadding, boxMargin)
}
return m.renderNormalInput(boxPadding, boxMargin)
}
func (m *model) renderErrorInput(boxPadding, boxMargin int) string {
errorInputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorError)).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(1)
errorBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorError)).
Background(lipgloss.Color(ColorErrorBg)).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorError)).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1)
var b strings.Builder
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
b.WriteString("\n")
return b.String()
}
func (m *model) renderNormalInput(boxPadding, boxMargin int) string {
inputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color(ColorPrimary)).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin)
return inputBoxStyle.Render(m.slugInput.View()) + "\n"
}
func (m *model) renderSlugPreview(isVeryCompact bool) string {
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
if isVeryCompact {
previewURL = truncateString(previewURL, previewWidth-10)
}
previewStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorSecondary)).
Italic(true).
Width(previewWidth)
return previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)) + "\n"
}
func (m *model) renderSlugHelp(isVeryCompact bool) string {
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color(ColorDarkGray)).
Italic(true).
MarginTop(1)
helpText := "Press Enter to save • CTRL+R for random • Esc to cancel"
if isVeryCompact {
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
}
return helpStyle.Render(helpText)
}
func getPaddingValue(isVeryCompact, isCompact bool) int {
if isVeryCompact || isCompact {
return 1
}
return 2
}
+101 -54
View File
@@ -4,6 +4,9 @@ import (
"errors"
"io"
"net"
"sync"
"time"
portUtil "tunnel_pls/internal/port"
"tunnel_pls/session/slug"
"tunnel_pls/types"
@@ -13,88 +16,132 @@ import (
type Forwarder interface {
Close() error
GetTunnelType() types.TunnelType
GetForwardedPort() uint16
TunnelType() types.TunnelType
ForwardedPort() uint16
}
type Lifecycle struct {
status types.Status
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
slugManager slug.Manager
unregisterClient func(slug string)
type SessionRegistry interface {
Remove(key types.SessionKey)
}
func NewLifecycle(conn ssh.Conn, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
return &Lifecycle{
status: "",
conn: conn,
channel: nil,
forwarder: forwarder,
slugManager: slugManager,
unregisterClient: nil,
type lifecycle struct {
mu sync.Mutex
status types.SessionStatus
closeErr error
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
slug slug.Slug
startedAt time.Time
sessionRegistry SessionRegistry
portRegistry portUtil.Port
user string
}
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
return &lifecycle{
status: types.SessionStatusINITIALIZING,
conn: conn,
channel: nil,
forwarder: forwarder,
slug: slugManager,
startedAt: time.Now(),
sessionRegistry: sessionRegistry,
portRegistry: port,
user: user,
}
}
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
l.unregisterClient = unregisterClient
}
type SessionLifecycle interface {
Close() error
SetStatus(status types.Status)
GetConnection() ssh.Conn
GetChannel() ssh.Channel
type Lifecycle interface {
Connection() ssh.Conn
Channel() ssh.Channel
PortRegistry() portUtil.Port
User() string
SetChannel(channel ssh.Channel)
SetUnregisterClient(unregisterClient func(slug string))
SetStatus(status types.SessionStatus)
IsActive() bool
StartedAt() time.Time
Close() error
}
func (l *Lifecycle) GetChannel() ssh.Channel {
func (l *lifecycle) PortRegistry() portUtil.Port {
return l.portRegistry
}
func (l *lifecycle) User() string {
return l.user
}
func (l *lifecycle) SetChannel(channel ssh.Channel) {
l.channel = channel
}
func (l *lifecycle) Channel() ssh.Channel {
return l.channel
}
func (l *Lifecycle) SetChannel(channel ssh.Channel) {
l.channel = channel
}
func (l *Lifecycle) GetConnection() ssh.Conn {
func (l *lifecycle) Connection() ssh.Conn {
return l.conn
}
func (l *Lifecycle) SetStatus(status types.Status) {
func (l *lifecycle) SetStatus(status types.SessionStatus) {
l.mu.Lock()
defer l.mu.Unlock()
l.status = status
}
func (l *Lifecycle) Close() error {
err := l.forwarder.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
func (l *lifecycle) IsActive() bool {
l.mu.Lock()
defer l.mu.Unlock()
return l.status == types.SessionStatusRUNNING
}
func (l *lifecycle) Close() error {
l.mu.Lock()
defer l.mu.Unlock()
if l.status == types.SessionStatusCLOSED {
return l.closeErr
}
l.status = types.SessionStatusCLOSED
var errs []error
tunnelType := l.forwarder.TunnelType()
if l.channel != nil {
err := l.channel.Close()
if err != nil && !errors.Is(err, io.EOF) {
return err
if err := l.channel.Close(); err != nil && !isClosedError(err) {
errs = append(errs, err)
}
}
if l.conn != nil {
err := l.conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
if err := l.conn.Close(); err != nil && !isClosedError(err) {
errs = append(errs, err)
}
}
clientSlug := l.slugManager.Get()
if clientSlug != "" {
l.unregisterClient(clientSlug)
clientSlug := l.slug.String()
key := types.SessionKey{
Id: clientSlug,
Type: tunnelType,
}
l.sessionRegistry.Remove(key)
if tunnelType == types.TunnelTypeTCP {
errs = append(errs, l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false))
errs = append(errs, l.forwarder.Close())
}
if l.forwarder.GetTunnelType() == types.TCP {
err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
if err != nil {
return err
}
}
return nil
l.closeErr = errors.Join(errs...)
return l.closeErr
}
func isClosedError(err error) bool {
if err == nil {
return false
}
return errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || err.Error() == "EOF"
}
func (l *lifecycle) StartedAt() time.Time {
return l.startedAt
}
+303
View File
@@ -0,0 +1,303 @@
package lifecycle
import (
"context"
"errors"
"io"
"net"
"testing"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/crypto/ssh"
)
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Remove(key types.SessionKey) {
m.Called(key)
}
type MockForwarder struct {
mock.Mock
}
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
args := m.Called(origin)
return args.Get(0).([]byte)
}
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
m.Called(dst, src)
}
func (m *MockForwarder) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockForwarder) TunnelType() types.TunnelType {
args := m.Called()
return args.Get(0).(types.TunnelType)
}
func (m *MockForwarder) ForwardedPort() uint16 {
args := m.Called()
return args.Get(0).(uint16)
}
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
m.Called(tunnelType)
}
func (m *MockForwarder) SetForwardedPort(port uint16) {
m.Called(port)
}
func (m *MockForwarder) SetListener(listener net.Listener) {
m.Called(listener)
}
func (m *MockForwarder) Listener() net.Listener {
args := m.Called()
return args.Get(0).(net.Listener)
}
func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
args := m.Called(ctx, origin)
if args.Get(0) == nil {
return nil, nil, args.Error(2)
}
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
type MockPort struct {
mock.Mock
}
func (m *MockPort) AddRange(startPort, endPort uint16) error {
return m.Called(startPort, endPort).Error(0)
}
func (m *MockPort) Unassigned() (uint16, bool) {
args := m.Called()
var port uint16
if args.Get(0) != nil {
switch v := args.Get(0).(type) {
case int:
port = uint16(v)
case uint16:
port = v
case uint32:
port = uint16(v)
case int32:
port = uint16(v)
case float64:
port = uint16(v)
default:
port = uint16(args.Int(0))
}
}
return port, args.Bool(1)
}
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
}
func (m *MockPort) Claim(port uint16) bool {
return m.Called(port).Bool(0)
}
type MockSlug struct {
mock.Mock
}
func (ms *MockSlug) Set(slug string) {
ms.Called(slug)
}
func (ms *MockSlug) String() string {
return ms.Called().String(0)
}
type MockSSHConn struct {
ssh.Conn
mock.Mock
}
func (m *MockSSHConn) Close() error {
args := m.Called()
return args.Error(0)
}
type MockSSHChannel struct {
ssh.Channel
mock.Mock
}
func (m *MockSSHChannel) Close() error {
return m.Called().Error(0)
}
func TestNew(t *testing.T) {
mockSSHConn := new(MockSSHConn)
mockForwarder := &MockForwarder{}
mockSlug := &MockSlug{}
mockPort := &MockPort{}
mockSessionRegistry := &MockSessionRegistry{}
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
assert.NotNil(t, mockLifecycle.Connection())
assert.NotNil(t, mockLifecycle.User())
assert.NotNil(t, mockLifecycle.PortRegistry())
assert.NotNil(t, mockLifecycle.StartedAt())
}
func TestLifecycle_User(t *testing.T) {
mockSSHConn := new(MockSSHConn)
mockForwarder := &MockForwarder{}
mockSlug := &MockSlug{}
mockPort := &MockPort{}
mockSessionRegistry := &MockSessionRegistry{}
user := "mas-fuad"
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, user)
assert.Equal(t, user, mockLifecycle.User())
}
func TestLifecycle_SetChannel(t *testing.T) {
mockSSHConn := new(MockSSHConn)
mockForwarder := &MockForwarder{}
mockSlug := &MockSlug{}
mockPort := &MockPort{}
mockSessionRegistry := &MockSessionRegistry{}
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
mockSSHChannel := &MockSSHChannel{}
mockLifecycle.SetChannel(mockSSHChannel)
assert.Equal(t, mockSSHChannel, mockLifecycle.Channel())
}
func TestLifecycle_SetStatus(t *testing.T) {
mockSSHConn := new(MockSSHConn)
mockForwarder := &MockForwarder{}
mockSlug := &MockSlug{}
mockPort := &MockPort{}
mockSessionRegistry := &MockSessionRegistry{}
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
assert.True(t, mockLifecycle.IsActive())
}
func TestLifecycle_IsActive(t *testing.T) {
mockSSHConn := new(MockSSHConn)
mockForwarder := &MockForwarder{}
mockSlug := &MockSlug{}
mockPort := &MockPort{}
mockSessionRegistry := &MockSessionRegistry{}
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
assert.True(t, mockLifecycle.IsActive())
}
func TestLifecycle_Close(t *testing.T) {
tests := []struct {
name string
tunnelType types.TunnelType
connCloseErr error
channelCloseErr error
expectErr bool
alreadyClosed bool
}{
{
name: "Close HTTP forwarding success",
tunnelType: types.TunnelTypeHTTP,
expectErr: false,
},
{
name: "Close TCP forwarding success",
tunnelType: types.TunnelTypeTCP,
expectErr: false,
},
{
name: "Close with conn close error",
tunnelType: types.TunnelTypeHTTP,
connCloseErr: errors.New("conn close error"),
expectErr: true,
},
{
name: "Close with channel close error",
tunnelType: types.TunnelTypeHTTP,
channelCloseErr: errors.New("channel close error"),
expectErr: true,
},
{
name: "Close when already closed",
tunnelType: types.TunnelTypeHTTP,
alreadyClosed: true,
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockSSHConn := &MockSSHConn{}
mockSSHConn.On("Close").Return(tt.connCloseErr)
mockForwarder := &MockForwarder{}
mockForwarder.On("TunnelType").Return(tt.tunnelType)
if tt.tunnelType == types.TunnelTypeTCP {
mockForwarder.On("ForwardedPort").Return(uint16(8080))
mockForwarder.On("Close").Return(nil)
}
mockSlug := &MockSlug{}
mockSlug.On("String").Return("test-slug")
mockPort := &MockPort{}
if tt.tunnelType == types.TunnelTypeTCP {
mockPort.On("SetStatus", uint16(8080), false).Return(nil)
}
mockSessionRegistry := &MockSessionRegistry{}
mockSessionRegistry.On("Remove", mock.Anything).Return()
mockSSHChannel := &MockSSHChannel{}
mockSSHChannel.On("Close").Return(tt.channelCloseErr)
mockLifecycle := New(mockSSHConn, mockForwarder, mockSlug, mockPort, mockSessionRegistry, "mas-fuad")
mockLifecycle.SetStatus(types.SessionStatusRUNNING)
mockLifecycle.SetChannel(mockSSHChannel)
if tt.alreadyClosed {
err := mockLifecycle.Close()
assert.NoError(t, err)
}
err := mockLifecycle.Close()
if tt.expectErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.False(t, mockLifecycle.IsActive())
mockSSHConn.AssertExpectations(t)
mockForwarder.AssertExpectations(t)
mockSlug.AssertExpectations(t)
mockPort.AssertExpectations(t)
mockSessionRegistry.AssertExpectations(t)
mockSSHChannel.AssertExpectations(t)
})
}
}
-66
View File
@@ -1,66 +0,0 @@
package session
import "sync"
type Registry interface {
Get(slug string) (session *SSHSession, exist bool)
Update(oldSlug, newSlug string) (success bool)
Register(slug string, session *SSHSession) (success bool)
Remove(slug string)
}
type registry struct {
mu sync.RWMutex
clients map[string]*SSHSession
}
func NewRegistry() Registry {
return &registry{
clients: make(map[string]*SSHSession),
}
}
func (r *registry) Get(slug string) (session *SSHSession, exist bool) {
r.mu.RLock()
defer r.mu.RUnlock()
session, exist = r.clients[slug]
return
}
func (r *registry) Update(oldSlug, newSlug string) (success bool) {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.clients[newSlug]; exists && newSlug != oldSlug {
return false
}
client, ok := r.clients[oldSlug]
if !ok {
return false
}
delete(r.clients, oldSlug)
client.slugManager.Set(newSlug)
r.clients[newSlug] = client
return true
}
func (r *registry) Register(slug string, session *SSHSession) (success bool) {
r.mu.Lock()
defer r.mu.Unlock()
if _, exists := r.clients[slug]; exists {
return false
}
r.clients[slug] = session
return true
}
func (r *registry) Remove(slug string) {
r.mu.Lock()
defer r.mu.Unlock()
delete(r.clients, slug)
}
+337 -58
View File
@@ -1,107 +1,203 @@
package session
import (
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net"
"time"
"tunnel_pls/internal/config"
portUtil "tunnel_pls/internal/port"
"tunnel_pls/internal/random"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request)
HandleHTTPForward(req *ssh.Request, port uint16)
HandleTCPForward(req *ssh.Request, addr string, port uint16)
HandleGlobalRequest(ch <-chan *ssh.Request) error
HandleTCPIPForward(req *ssh.Request) error
HandleHTTPForward(req *ssh.Request, port uint16) error
HandleTCPForward(req *ssh.Request, addr string, port uint16) error
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *types.Detail
Start() error
}
type SSHSession struct {
initialReq <-chan *ssh.Request
sshReqChannel <-chan ssh.NewChannel
lifecycle lifecycle.SessionLifecycle
interaction interaction.Controller
forwarder forwarder.ForwardingController
slugManager slug.Manager
registry Registry
type session struct {
randomizer random.Random
config config.Config
initialReq <-chan *ssh.Request
sshChan <-chan ssh.NewChannel
lifecycle lifecycle.Lifecycle
interaction interaction.Interaction
forwarder forwarder.Forwarder
slug slug.Slug
registry registry.Registry
}
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
return s.lifecycle
type Config struct {
Randomizer random.Random
Config config.Config
Conn *ssh.ServerConn
InitialReq <-chan *ssh.Request
SshChan <-chan ssh.NewChannel
SessionRegistry registry.Registry
PortRegistry portUtil.Port
User string
}
func (s *SSHSession) GetInteraction() interaction.Controller {
return s.interaction
}
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func (s *SSHSession) GetForwarder() forwarder.ForwardingController {
return s.forwarder
}
func New(conf *Config) Session {
slugManager := slug.New()
forwarderManager := forwarder.New(conf.Config, slugManager, conf.Conn)
lifecycleManager := lifecycle.New(conf.Conn, forwarderManager, slugManager, conf.PortRegistry, conf.SessionRegistry, conf.User)
interactionManager := interaction.New(conf.Randomizer, conf.Config, slugManager, forwarderManager, conf.SessionRegistry, conf.User, lifecycleManager.Close)
func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager
}
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry) *SSHSession {
slugManager := slug.NewManager()
forwarderManager := forwarder.NewForwarder(slugManager)
interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
lifecycleManager := lifecycle.NewLifecycle(conn, forwarderManager, slugManager)
interactionManager.SetLifecycle(lifecycleManager)
interactionManager.SetSlugModificator(sessionRegistry.Update)
forwarderManager.SetLifecycle(lifecycleManager)
lifecycleManager.SetUnregisterClient(sessionRegistry.Remove)
return &SSHSession{
initialReq: forwardingReq,
sshReqChannel: sshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slugManager: slugManager,
registry: sessionRegistry,
return &session{
randomizer: conf.Randomizer,
config: conf.Config,
initialReq: conf.InitialReq,
sshChan: conf.SshChan,
lifecycle: lifecycleManager,
interaction: interactionManager,
forwarder: forwarderManager,
slug: slugManager,
registry: conf.SessionRegistry,
}
}
func (s *SSHSession) Start() error {
channel := <-s.sshReqChannel
func (s *session) Lifecycle() lifecycle.Lifecycle {
return s.lifecycle
}
func (s *session) Interaction() interaction.Interaction {
return s.interaction
}
func (s *session) Forwarder() forwarder.Forwarder {
return s.forwarder
}
func (s *session) Slug() slug.Slug {
return s.slug
}
func (s *session) Detail() *types.Detail {
tunnelTypeMap := map[types.TunnelType]string{
types.TunnelTypeHTTP: "HTTP",
types.TunnelTypeTCP: "TCP",
}
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
if !ok {
tunnelType = "UNKNOWN"
}
return &types.Detail{
ForwardingType: tunnelType,
Slug: s.slug.String(),
UserID: s.lifecycle.User(),
Active: s.lifecycle.IsActive(),
StartedAt: s.lifecycle.StartedAt(),
}
}
func (s *session) Start() error {
if err := s.setupSessionMode(); err != nil {
return err
}
tcpipReq := s.waitForTCPIPForward()
if tcpipReq == nil {
return s.handleMissingForwardRequest()
}
if s.shouldRejectUnauthorized() {
return s.denyForwardingRequest(tcpipReq, nil, nil, "headless forwarding only allowed on node mode")
}
if err := s.HandleTCPIPForward(tcpipReq); err != nil {
return err
}
s.interaction.Start()
return s.waitForSessionEnd()
}
func (s *session) setupSessionMode() error {
select {
case channel, ok := <-s.sshChan:
if !ok {
log.Println("Forwarding request channel closed")
return nil
}
return s.setupInteractiveMode(channel)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.InteractiveModeHEADLESS)
return nil
}
}
func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
ch, reqs, err := channel.Accept()
if err != nil {
log.Printf("failed to accept channel: %v", err)
return err
}
go s.HandleGlobalRequest(reqs)
tcpipReq := s.waitForTCPIPForward()
if tcpipReq == nil {
_, err := ch.Write([]byte(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200"))))
go func() {
err = s.HandleGlobalRequest(reqs)
if err != nil {
return err
log.Printf("global request handler error: %v", err)
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("No forwarding Request")
}
}()
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
s.interaction.SetMode(types.InteractiveModeINTERACTIVE)
s.HandleTCPIPForward(tcpipReq)
return nil
}
func (s *session) handleMissingForwardRequest() error {
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", s.config.Domain(), s.config.SSHPort()))
if err != nil {
return err
}
return fmt.Errorf("no forwarding Request")
}
func (s *session) shouldRejectUnauthorized() bool {
return s.interaction.Mode() == types.InteractiveModeHEADLESS &&
s.config.Mode() == types.ServerModeSTANDALONE &&
s.lifecycle.User() == "UNAUTHORIZED"
}
func (s *session) waitForSessionEnd() error {
if err := s.lifecycle.Connection().Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("ssh connection closed with error: %v", err)
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return err
}
return nil
}
func (s *SSHSession) waitForTCPIPForward() *ssh.Request {
func (s *session) waitForTCPIPForward() *ssh.Request {
select {
case req, ok := <-s.initialReq:
if !ok {
@@ -121,3 +217,186 @@ func (s *SSHSession) waitForTCPIPForward() *ssh.Request {
return nil
}
}
func (s *session) handleWindowChange(req *ssh.Request) error {
p := req.Payload
if len(p) < 16 {
log.Println("invalid window-change payload")
return req.Reply(false, nil)
}
cols := binary.BigEndian.Uint32(p[0:4])
rows := binary.BigEndian.Uint32(p[4:8])
s.interaction.SetWH(int(cols), int(rows))
return req.Reply(true, nil)
}
func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
for req := range GlobalRequest {
switch req.Type {
case "shell", "pty-req":
if err := req.Reply(true, nil); err != nil {
return err
}
case "window-change":
if err := s.handleWindowChange(req); err != nil {
return err
}
default:
log.Println("Unknown request type:", req.Type)
if err := req.Reply(false, nil); err != nil {
return err
}
}
}
return nil
}
func (s *session) parseForwardPayload(payload []byte) (address string, port uint16, err error) {
var forwardPayload struct {
BindAddr string
BindPort uint32
}
if err = ssh.Unmarshal(payload, &forwardPayload); err != nil {
return "", 0, fmt.Errorf("failed to unmarshal forward payload: %w", err)
}
if forwardPayload.BindPort > 65535 {
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
}
port = uint16(forwardPayload.BindPort)
if isBlockedPort(port) {
return "", 0, fmt.Errorf("port is blocked")
}
if port == 0 {
unassigned, ok := s.lifecycle.PortRegistry().Unassigned()
if !ok {
return "", 0, fmt.Errorf("no available port")
}
return forwardPayload.BindAddr, unassigned, nil
}
return forwardPayload.BindAddr, port, nil
}
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
var errs []error
if key != nil {
s.registry.Remove(*key)
}
if listener != nil {
errs = append(errs, listener.Close())
}
errs = append(errs, req.Reply(false, nil))
errs = append(errs, s.lifecycle.Close())
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
return errors.Join(errs...)
}
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
replyPayload := struct {
BoundPort uint32
}{
BoundPort: uint32(portToBind),
}
err := req.Reply(true, ssh.Marshal(replyPayload))
if err != nil {
return err
}
s.forwarder.SetType(tunnelType)
s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(slug)
s.lifecycle.SetStatus(types.SessionStatusRUNNING)
if listener != nil {
s.forwarder.SetListener(listener)
}
return nil
}
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
address, port, err := s.parseForwardPayload(req.Payload)
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
}
switch port {
case 80, 443:
return s.HandleHTTPForward(req, port)
default:
return s.HandleTCPForward(req, address, port)
}
}
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
randomString, err := s.randomizer.String(20)
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
}
key := types.SessionKey{Id: randomString, Type: types.TunnelTypeHTTP}
if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
}
err = s.finalizeForwarding(req, portToBind, nil, types.TunnelTypeHTTP, key.Id)
if err != nil {
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
}
return nil
}
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
}
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
listener, err := tcpServer.Listen()
if err != nil {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Port %d is already in use or restricted", portToBind))
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TunnelTypeTCP}
if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TunnelTypeTCP client with id: %s", key.Id))
}
err = s.finalizeForwarding(req, portToBind, listener, types.TunnelTypeTCP, key.Id)
if err != nil {
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
}
go func() {
err = tcpServer.Serve(listener)
if err != nil {
log.Printf("Failed serving tcp server: %s\n", err)
}
}()
return nil
}
func isBlockedPort(port uint16) bool {
if port == 80 || port == 443 {
return false
}
if port < 1024 && port != 0 {
return true
}
for _, p := range blockedReservedPorts {
if p == port {
return true
}
}
return false
}
File diff suppressed because it is too large Load Diff
+7 -7
View File
@@ -1,24 +1,24 @@
package slug
type Manager interface {
Get() string
type Slug interface {
String() string
Set(slug string)
}
type manager struct {
type slug struct {
slug string
}
func NewManager() Manager {
return &manager{
func New() Slug {
return &slug{
slug: "",
}
}
func (s *manager) Get() string {
func (s *slug) String() string {
return s.slug
}
func (s *manager) Set(slug string) {
func (s *slug) Set(slug string) {
s.slug = slug
}
+99
View File
@@ -0,0 +1,99 @@
package slug
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type SlugTestSuite struct {
suite.Suite
slug Slug
}
func (suite *SlugTestSuite) SetupTest() {
suite.slug = New()
}
func TestNew(t *testing.T) {
s := New()
assert.NotNil(t, s, "New() should return a non-nil Slug")
assert.Implements(t, (*Slug)(nil), s, "New() should return a type that implements Slug interface")
assert.Equal(t, "", s.String(), "New() should initialize with empty string")
}
func (suite *SlugTestSuite) TestString() {
assert.Equal(suite.T(), "", suite.slug.String(), "String() should return empty string initially")
suite.slug.Set("test-slug")
assert.Equal(suite.T(), "test-slug", suite.slug.String(), "String() should return the set value")
}
func (suite *SlugTestSuite) TestSet() {
testCases := []struct {
name string
input string
expected string
}{
{
name: "simple slug",
input: "hello-world",
expected: "hello-world",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "slug with numbers",
input: "test-123",
expected: "test-123",
},
{
name: "slug with special characters",
input: "hello_world-123",
expected: "hello_world-123",
},
{
name: "overwrite existing slug",
input: "new-slug",
expected: "new-slug",
},
}
for _, tc := range testCases {
suite.Run(tc.name, func() {
suite.slug.Set(tc.input)
assert.Equal(suite.T(), tc.expected, suite.slug.String())
})
}
}
func (suite *SlugTestSuite) TestMultipleSet() {
suite.slug.Set("first-slug")
assert.Equal(suite.T(), "first-slug", suite.slug.String())
suite.slug.Set("second-slug")
assert.Equal(suite.T(), "second-slug", suite.slug.String())
suite.slug.Set("")
assert.Equal(suite.T(), "", suite.slug.String())
}
func TestSlugIsolation(t *testing.T) {
slug1 := New()
slug2 := New()
slug1.Set("slug-one")
slug2.Set("slug-two")
assert.Equal(t, "slug-one", slug1.String(), "First slug should maintain its value")
assert.Equal(t, "slug-two", slug2.String(), "Second slug should maintain its value")
}
func TestSlugTestSuite(t *testing.T) {
suite.Run(t, new(SlugTestSuite))
}
+37 -7
View File
@@ -1,20 +1,50 @@
package types
type Status string
import "time"
type SessionStatus int
const (
INITIALIZING Status = "INITIALIZING"
RUNNING Status = "RUNNING"
SETUP Status = "SETUP"
SessionStatusINITIALIZING SessionStatus = iota
SessionStatusRUNNING
SessionStatusCLOSED
)
type TunnelType string
type InteractiveMode int
const (
HTTP TunnelType = "HTTP"
TCP TunnelType = "TCP"
InteractiveModeINTERACTIVE InteractiveMode = iota + 1
InteractiveModeHEADLESS
)
type TunnelType int
const (
TunnelTypeUNKNOWN TunnelType = iota
TunnelTypeHTTP
TunnelTypeTCP
)
type ServerMode int
const (
ServerModeSTANDALONE = iota + 1
ServerModeNODE
)
type SessionKey struct {
Id string
Type TunnelType
}
type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"`
UserID string `json:"user_id,omitempty"`
Active bool `json:"active,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
}
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
"Content-Length: 11\r\n" +
"Content-Type: text/plain\r\n\r\n" +