14 Commits

Author SHA1 Message Date
1e12373359 chore(restructure): reorganize project layout
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 13m1s
- Reorganize internal packages and overall project structure
- Update imports and wiring to match the new layout
- Separate HTTP parsing and streaming from the server package
- Separate middleware from the server package
- Separate session registry from the session package
- Move HTTP, HTTPS, and TCP servers to the transport package
- Session package no longer starts the TCP server directly
- Server package no longer starts HTTP/HTTPS servers on initialization
- Forwarder no longer handles accepting TCP requests
- Move session details to the types package
- HTTP/HTTPS initialization is now the responsibility of main
2026-01-21 14:06:46 +07:00
9a4539cc02 refactor(httpheader): extract header parsing into dedicated package
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m19s
Moved HTTP header parsing and building logic from server package to internal/httpheader
2026-01-20 21:15:34 +07:00
e3ead4d52f refactor: optimize header parsing and remove factory naming
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m20s
- Remove factory naming
- Use direct byte indexing instead of bytes.TrimRight
- Extract parseStartLine and setRemainingHeaders helpers
2026-01-20 20:56:08 +07:00
aa1a465178 refactor(forwarder): improve connection handling and cleanup
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Has been cancelled
- Extract copyAndClose method for bidirectional data transfe
- Add closeWriter helper for graceful connection shutdown
- Add handleIncomingConnection helper
- Add openForwardedChannel helper
2026-01-20 19:01:15 +07:00
27f49879af refactor(server): enhance HTTP handler modularity and fix resource leak
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m43s
- Rename customWriter struct to httpWriter for clarity
- Add closeWriter field to properly close write side of connections
- Update all cw variable references to hw
- Merge handlerTLS into handler function to reduce code duplication
- Extract handler into smaller, focused methods
- Split Read/Write/forwardRequest into composable functions

Fixes resource leak where connections weren't properly closed on the
write side, matching the forwarder's CloseWrite() pattern.
2026-01-19 22:41:04 +07:00
adb0264bb5 refactor(session): simplify Start() and unify forwarding logic
- Extract helper functions from Start() for better code organization
- Eliminate duplication with finalizeForwarding() method
- Consolidate denial logic into denyForwardingRequest()
- Update all handler methods to return errors instead of logging internally
- Improve error handling consistency across all operations
2026-01-19 15:53:16 +07:00
8fb19af5a6 fix: resolve copy goroutine deadlock on early connection close
- Add proper CloseWrite handling to signal EOF to other goroutine
- Ensure both copy goroutines terminate when either side closes
- Prevent goroutine leaks for SSH forwarded-tcpip channels:
    - Use select with default when sending result to resultChan
    - Close unused SSH channels and discard requests if main goroutine has already timed out
2026-01-19 00:20:28 +07:00
41fdb5639c Merge pull request 'refactor: explicit initialization and dependency injection' (#70) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 9m49s
Reviewed-on: #70
2026-01-18 21:46:59 +07:00
44d224f491 refactor: explicit initialization and dependency injection
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 10m10s
- Replace init() with config.Load() function when loading env variables
- Inject portRegistry into session, server, and lifecycle structs
- Inject sessionRegistry directly into interaction and lifecycle
- Remove SetSessionRegistry function and global port variables
- Pass ssh.Conn directly to forwarder constructor instead of lifecycle interface
- Pass user and closeFunc callback to interaction constructor instead of lifecycle interface
- Eliminate circular dependencies between lifecycle, forwarder, and interaction
- Remove setter methods (SetLifecycle) from forwarder and interaction interfaces
2026-01-18 21:20:05 +07:00
9be0328e24 Merge pull request 'staging' (#69) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 9m30s
Reviewed-on: #69
2026-01-17 19:15:40 +07:00
2b9bca65d5 refactor(interaction): separate view and update logic into modular files
Docker Build and Push / build-and-push-branches (push) Has been skipped
Docker Build and Push / build-and-push-tags (push) Successful in 11m44s
- Extract slug editing logic to slug.go (slugView/slugUpdate)
- Extract commands menu logic to commands.go (commandsView/commandsUpdate)
- Extract coming soon modal to coming_soon.go (comingSoonView/comingSoonUpdate)
- Extract main dashboard logic to dashboard.go (dashboardView/dashboardUpdate)
- Create model.go for shared model struct and helper functions
- Replace math/rand with crypto/rand for random subdomain generation
- Remove legacy TLS cipher suite configuration
2026-01-17 17:33:10 +07:00
6587dc0f39 refactor(interaction): separate view and update logic into modular files
- Extract slug editing logic to slug.go (slugView/slugUpdate)
- Extract commands menu logic to commands.go (commandsView/commandsUpdate)
- Extract coming soon modal to coming_soon.go (comingSoonView/comingSoonUpdate)
- Extract main dashboard logic to dashboard.go (dashboardView/dashboardUpdate)
- Create model.go for shared model struct and helper functions
- Replace math/rand with crypto/rand for random subdomain generation
- Remove legacy TLS cipher suite configuration
2026-01-17 17:30:21 +07:00
f421781f44 Merge pull request 'refactor: convert structs to interfaces and rename accessors' (#68) from staging into main
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 10m34s
Reviewed-on: #68
2026-01-16 16:41:22 +07:00
1a04af8873 Merge branch 'main' into staging
Docker Build and Push / build-and-push-branches (push) Successful in 11m35s
Docker Build and Push / build-and-push-tags (push) Has been skipped
2026-01-16 16:28:39 +07:00
40 changed files with 2136 additions and 1905 deletions
+2 -16
View File
@@ -1,7 +1,3 @@
git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0 h1:RhcBKUG41/om4jgN+iF/vlY/RojTeX1QhBa4p4428ec=
git.fossy.my.id/bagas/tunnel-please-grpc v1.3.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY=
git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0 h1:tpJSKjaSmV+vxxbVx6qnStjxFVXjj2M0rygWXxLb99o=
git.fossy.my.id/bagas/tunnel-please-grpc v1.4.0/go.mod h1:fG+VkArdkceGB0bNA7IFQus9GetLAwdF5Oi4jdMlXtY=
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0 h1:3xszIhck4wo9CoeRq9vnkar4PhY7kz9QrR30qj2XszA=
git.fossy.my.id/bagas/tunnel-please-grpc v1.5.0/go.mod h1:Weh6ZujgWmT8XxD3Qba7sJ6r5eyUMB9XSWynqdyOoLo=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
@@ -10,12 +6,8 @@ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiE
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/aymanbagabas/go-udiff v0.2.0 h1:TK0fH4MteXUDspT88n8CKzvK0X9O2xu9yQjWpi6yML8=
github.com/aymanbagabas/go-udiff v0.2.0/go.mod h1:RE4Ex0qsGkTAJoQdQQCA0uG+nAzJO/pI/QwceO5fgrA=
github.com/caddyserver/certmagic v0.25.0 h1:VMleO/XA48gEWes5l+Fh6tRWo9bHkhwAEhx63i+F5ic=
github.com/caddyserver/certmagic v0.25.0/go.mod h1:m9yB7Mud24OQbPHOiipAoyKPn9pKHhpSJxXR1jydBxA=
github.com/caddyserver/certmagic v0.25.1 h1:4sIKKbOt5pg6+sL7tEwymE1x2bj6CHr80da1CRRIPbY=
github.com/caddyserver/certmagic v0.25.1/go.mod h1:VhyvndxtVton/Fo/wKhRoC46Rbw1fmjvQ3GjHYSQTEY=
github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA=
github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/caddyserver/zerossl v0.1.4 h1:CVJOE3MZeFisCERZjkxIcsqIH4fnFdlYWnPYeFtBHRw=
github.com/caddyserver/zerossl v0.1.4/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
@@ -118,8 +110,6 @@ go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.uber.org/zap/exp v0.3.0 h1:6JYzdifzYkGmTdRR59oYH+Ng7k49H9qVpWwNSsGJj3U=
go.uber.org/zap/exp v0.3.0/go.mod h1:5I384qq7XGxYyByIhHm6jg5CHkGY0nsTfbDLgDDlgJQ=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
@@ -132,14 +122,10 @@ golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
+3 -5
View File
@@ -1,19 +1,17 @@
package config
import (
"log"
"os"
"strconv"
"github.com/joho/godotenv"
)
func init() {
func Load() error {
if _, err := os.Stat(".env"); err == nil {
if err := godotenv.Load(".env"); err != nil {
log.Printf("Warning: Failed to load .env file: %s", err)
}
return godotenv.Load(".env")
}
return nil
}
func Getenv(key, defaultValue string) string {
+3 -4
View File
@@ -8,10 +8,9 @@ import (
"log"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/registry"
"tunnel_pls/types"
"tunnel_pls/session"
proto "git.fossy.my.id/bagas/tunnel-please-grpc/gen"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
@@ -32,13 +31,13 @@ type Client interface {
type client struct {
conn *grpc.ClientConn
address string
sessionRegistry session.Registry
sessionRegistry registry.Registry
eventService proto.EventServiceClient
authorizeConnectionService proto.UserServiceClient
closing bool
}
func New(address string, sessionRegistry session.Registry) (Client, error) {
func New(address string, sessionRegistry registry.Registry) (Client, error) {
var opts []grpc.DialOption
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
+30
View File
@@ -0,0 +1,30 @@
package header
type ResponseHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type responseHeader struct {
startLine []byte
headers map[string]string
}
type RequestHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
}
type requestHeader struct {
method string
path string
version string
startLine []byte
headers map[string]string
}
+148
View File
@@ -0,0 +1,148 @@
package header
import (
"bufio"
"bytes"
"fmt"
)
func setRemainingHeaders(remaining []byte, header interface {
Set(key string, value string)
}) {
for len(remaining) > 0 {
lineEnd := bytes.Index(remaining, []byte("\r\n"))
if lineEnd == -1 {
lineEnd = len(remaining)
}
line := remaining[:lineEnd]
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx != -1 {
key := bytes.TrimSpace(line[:colonIdx])
value := bytes.TrimSpace(line[colonIdx+1:])
header.Set(string(key), string(value))
}
if lineEnd == len(remaining) {
break
}
remaining = remaining[lineEnd+2:]
}
}
func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) {
header := &requestHeader{
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
}
startLine := headerData[:lineEnd]
header.startLine = startLine
var err error
header.method, header.path, header.version, err = parseStartLine(startLine)
if err != nil {
return nil, err
}
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func parseStartLine(startLine []byte) (method, path, version string, err error) {
firstSpace := bytes.IndexByte(startLine, ' ')
if firstSpace == -1 {
return "", "", "", fmt.Errorf("invalid start line: missing method")
}
secondSpace := bytes.IndexByte(startLine[firstSpace+1:], ' ')
if secondSpace == -1 {
return "", "", "", fmt.Errorf("invalid start line: missing version")
}
secondSpace += firstSpace + 1
method = string(startLine[:firstSpace])
path = string(startLine[firstSpace+1 : secondSpace])
version = string(startLine[secondSpace+1:])
return method, path, version, nil
}
func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) {
header := &requestHeader{
headers: make(map[string]string, 16),
}
startLineBytes, err := br.ReadSlice('\n')
if err != nil {
return nil, err
}
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
header.startLine = make([]byte, len(startLineBytes))
copy(header.startLine, startLineBytes)
header.method, header.path, header.version, err = parseStartLine(header.startLine)
if err != nil {
return nil, err
}
for {
lineBytes, err := br.ReadSlice('\n')
if err != nil {
return nil, err
}
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
if len(lineBytes) == 0 {
break
}
colonIdx := bytes.IndexByte(lineBytes, ':')
if colonIdx == -1 {
continue
}
key := bytes.TrimSpace(lineBytes[:colonIdx])
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
header.headers[string(key)] = string(value)
}
return header, nil
}
func finalize(startLine []byte, headers map[string]string) []byte {
size := len(startLine) + 2
for key, val := range headers {
size += len(key) + 2 + len(val) + 2
}
size += 2
buf := make([]byte, 0, size)
buf = append(buf, startLine...)
buf = append(buf, '\r', '\n')
for key, val := range headers {
buf = append(buf, key...)
buf = append(buf, ':', ' ')
buf = append(buf, val...)
buf = append(buf, '\r', '\n')
}
buf = append(buf, '\r', '\n')
return buf
}
+49
View File
@@ -0,0 +1,49 @@
package header
import (
"bufio"
"fmt"
)
func NewRequest(r interface{}) (RequestHeader, error) {
switch v := r.(type) {
case []byte:
return parseHeadersFromBytes(v)
case *bufio.Reader:
return parseHeadersFromReader(v)
default:
return nil, fmt.Errorf("unsupported type: %T", r)
}
}
func (req *requestHeader) Value(key string) string {
val, ok := req.headers[key]
if !ok {
return ""
}
return val
}
func (req *requestHeader) Set(key string, value string) {
req.headers[key] = value
}
func (req *requestHeader) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeader) GetMethod() string {
return req.method
}
func (req *requestHeader) GetPath() string {
return req.path
}
func (req *requestHeader) GetVersion() string {
return req.version
}
func (req *requestHeader) Finalize() []byte {
return finalize(req.startLine, req.headers)
}
+40
View File
@@ -0,0 +1,40 @@
package header
import (
"bytes"
"fmt"
)
func NewResponse(headerData []byte) (ResponseHeader, error) {
header := &responseHeader{
startLine: nil,
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid response: no CRLF found in start line")
}
header.startLine = headerData[:lineEnd]
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func (resp *responseHeader) Value(key string) string {
return resp.headers[key]
}
func (resp *responseHeader) Set(key string, value string) {
resp.headers[key] = value
}
func (resp *responseHeader) Remove(key string) {
delete(resp.headers, key)
}
func (resp *responseHeader) Finalize() []byte {
return finalize(resp.startLine, resp.headers)
}
+29
View File
@@ -0,0 +1,29 @@
package stream
import "bytes"
func splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) {
headerByte := data[:delimiterIdx+len(DELIMITER)]
body := data[delimiterIdx+len(DELIMITER):]
return headerByte, body
}
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}
+50
View File
@@ -0,0 +1,50 @@
package stream
import (
"bytes"
"tunnel_pls/internal/http/header"
)
func (hs *http) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := hs.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
headerEndIdx := bytes.Index(tmp, DELIMITER)
if headerEndIdx == -1 {
return handleNoDelimiter(p, tmp, err)
}
headerByte, bodyByte := splitHeaderAndBody(tmp, headerEndIdx)
if !isHTTPHeader(headerByte) {
copy(p, tmp)
return read, nil
}
return hs.processHTTPRequest(p, headerByte, bodyByte)
}
func (hs *http) processHTTPRequest(p, headerByte, bodyByte []byte) (int, error) {
reqhf, err := header.NewRequest(headerByte)
if err != nil {
return 0, err
}
if err = hs.ApplyRequestMiddlewares(reqhf); err != nil {
return 0, err
}
hs.reqHeader = reqhf
combined := append(reqhf.Finalize(), bodyByte...)
return copy(p, combined), nil
}
func handleNoDelimiter(p, tmp []byte, err error) (int, error) {
copy(p, tmp)
return len(tmp), err
}
+103
View File
@@ -0,0 +1,103 @@
package stream
import (
"io"
"log"
"net"
"regexp"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/middleware"
)
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
type HTTP interface {
io.ReadWriteCloser
CloseWrite() error
RemoteAddr() net.Addr
UseResponseMiddleware(mw middleware.ResponseMiddleware)
UseRequestMiddleware(mw middleware.RequestMiddleware)
SetRequestHeader(header header.RequestHeader)
RequestMiddlewares() []middleware.RequestMiddleware
ResponseMiddlewares() []middleware.ResponseMiddleware
ApplyResponseMiddlewares(resphf header.ResponseHeader, body []byte) error
ApplyRequestMiddlewares(reqhf header.RequestHeader) error
}
type http struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
headerBuf []byte
buf []byte
respHeader header.ResponseHeader
reqHeader header.RequestHeader
respMW []middleware.ResponseMiddleware
reqMW []middleware.RequestMiddleware
}
func New(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTP {
return &http{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
func (hs *http) RemoteAddr() net.Addr {
return hs.remoteAddr
}
func (hs *http) UseResponseMiddleware(mw middleware.ResponseMiddleware) {
hs.respMW = append(hs.respMW, mw)
}
func (hs *http) UseRequestMiddleware(mw middleware.RequestMiddleware) {
hs.reqMW = append(hs.reqMW, mw)
}
func (hs *http) SetRequestHeader(header header.RequestHeader) {
hs.reqHeader = header
}
func (hs *http) RequestMiddlewares() []middleware.RequestMiddleware {
return hs.reqMW
}
func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
return hs.respMW
}
func (hs *http) Close() error {
return hs.writer.(io.Closer).Close()
}
func (hs *http) CloseWrite() error {
if closer, ok := hs.writer.(interface{ CloseWrite() error }); ok {
return closer.CloseWrite()
}
return hs.Close()
}
func (hs *http) ApplyRequestMiddlewares(reqhf header.RequestHeader) error {
for _, m := range hs.RequestMiddlewares() {
if err := m.HandleRequest(reqhf); err != nil {
log.Printf("Error when applying request middleware: %v", err)
return err
}
}
return nil
}
func (hs *http) ApplyResponseMiddlewares(resphf header.ResponseHeader, bodyByte []byte) error {
for _, m := range hs.ResponseMiddlewares() {
if err := m.HandleResponse(resphf, bodyByte); err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return err
}
}
return nil
}
+88
View File
@@ -0,0 +1,88 @@
package stream
import (
"bytes"
"tunnel_pls/internal/http/header"
)
func (hs *http) Write(p []byte) (int, error) {
if hs.shouldBypassBuffering(p) {
hs.respHeader = nil
}
if hs.respHeader != nil {
return hs.writer.Write(p)
}
hs.buf = append(hs.buf, p...)
headerEndIdx := bytes.Index(hs.buf, DELIMITER)
if headerEndIdx == -1 {
return len(p), nil
}
return hs.processBufferedResponse(p, headerEndIdx)
}
func (hs *http) shouldBypassBuffering(p []byte) bool {
return hs.respHeader != nil && len(hs.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/"
}
func (hs *http) processBufferedResponse(p []byte, delimiterIdx int) (int, error) {
headerByte, bodyByte := splitHeaderAndBody(hs.buf, delimiterIdx)
if !isHTTPHeader(headerByte) {
return hs.writeRawBuffer()
}
if err := hs.processHTTPResponse(headerByte, bodyByte); err != nil {
return 0, err
}
hs.buf = nil
return len(p), nil
}
func (hs *http) writeRawBuffer() (int, error) {
_, err := hs.writer.Write(hs.buf)
length := len(hs.buf)
hs.buf = nil
if err != nil {
return 0, err
}
return length, nil
}
func (hs *http) processHTTPResponse(headerByte, bodyByte []byte) error {
resphf, err := header.NewResponse(headerByte)
if err != nil {
return err
}
if err = hs.ApplyResponseMiddlewares(resphf, bodyByte); err != nil {
return err
}
hs.respHeader = resphf
finalHeader := resphf.Finalize()
if err = hs.writeHeaderAndBody(finalHeader, bodyByte); err != nil {
return err
}
return nil
}
func (hs *http) writeHeaderAndBody(header, bodyByte []byte) error {
if _, err := hs.writer.Write(header); err != nil {
return err
}
if len(bodyByte) > 0 {
if _, err := hs.writer.Write(bodyByte); err != nil {
return err
}
}
return nil
}
+23
View File
@@ -0,0 +1,23 @@
package middleware
import (
"net"
"tunnel_pls/internal/http/header"
)
type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header header.RequestHeader) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
+13
View File
@@ -0,0 +1,13 @@
package middleware
import (
"tunnel_pls/internal/http/header"
)
type RequestMiddleware interface {
HandleRequest(header header.RequestHeader) error
}
type ResponseMiddleware interface {
HandleResponse(header header.ResponseHeader, body []byte) error
}
+16
View File
@@ -0,0 +1,16 @@
package middleware
import (
"tunnel_pls/internal/http/header"
)
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleResponse(header header.ResponseHeader, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
+22 -45
View File
@@ -3,63 +3,40 @@ package port
import (
"fmt"
"sort"
"strconv"
"strings"
"sync"
"tunnel_pls/internal/config"
)
type Manager interface {
AddPortRange(startPort, endPort uint16) error
GetUnassignedPort() (uint16, bool)
SetPortStatus(port uint16, assigned bool) error
ClaimPort(port uint16) (claimed bool)
type Port interface {
AddRange(startPort, endPort uint16) error
Unassigned() (uint16, bool)
SetStatus(port uint16, assigned bool) error
Claim(port uint16) (claimed bool)
}
type manager struct {
type port struct {
mu sync.RWMutex
ports map[uint16]bool
sortedPorts []uint16
}
var Default Manager = &manager{
ports: make(map[uint16]bool),
sortedPorts: []uint16{},
func New() Port {
return &port{
ports: make(map[uint16]bool),
sortedPorts: []uint16{},
}
}
func init() {
rawRange := config.Getenv("ALLOWED_PORTS", "")
if rawRange == "" {
return
}
splitRange := strings.Split(rawRange, "-")
if len(splitRange) != 2 {
return
}
start, err := strconv.ParseUint(splitRange[0], 10, 16)
if err != nil {
return
}
end, err := strconv.ParseUint(splitRange[1], 10, 16)
if err != nil {
return
}
_ = Default.AddPortRange(uint16(start), uint16(end))
}
func (pm *manager) AddPortRange(startPort, endPort uint16) error {
func (pm *port) AddRange(startPort, endPort uint16) error {
pm.mu.Lock()
defer pm.mu.Unlock()
if startPort > endPort {
return fmt.Errorf("start port cannot be greater than end port")
}
for port := startPort; port <= endPort; port++ {
if _, exists := pm.ports[port]; !exists {
pm.ports[port] = false
pm.sortedPorts = append(pm.sortedPorts, port)
for index := startPort; index <= endPort; index++ {
if _, exists := pm.ports[index]; !exists {
pm.ports[index] = false
pm.sortedPorts = append(pm.sortedPorts, index)
}
}
sort.Slice(pm.sortedPorts, func(i, j int) bool {
@@ -68,19 +45,19 @@ func (pm *manager) AddPortRange(startPort, endPort uint16) error {
return nil
}
func (pm *manager) GetUnassignedPort() (uint16, bool) {
func (pm *port) Unassigned() (uint16, bool) {
pm.mu.Lock()
defer pm.mu.Unlock()
for _, port := range pm.sortedPorts {
if !pm.ports[port] {
return port, true
for _, index := range pm.sortedPorts {
if !pm.ports[index] {
return index, true
}
}
return 0, false
}
func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
func (pm *port) SetStatus(port uint16, assigned bool) error {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -88,7 +65,7 @@ func (pm *manager) SetPortStatus(port uint16, assigned bool) error {
return nil
}
func (pm *manager) ClaimPort(port uint16) (claimed bool) {
func (pm *port) Claim(port uint16) (claimed bool) {
pm.mu.Lock()
defer pm.mu.Unlock()
+13 -13
View File
@@ -1,18 +1,18 @@
package random
import (
mathrand "math/rand"
"strings"
"time"
)
import "crypto/rand"
func GenerateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyz"
seededRand := mathrand.New(mathrand.NewSource(time.Now().UnixNano() + int64(mathrand.Intn(9999))))
var result strings.Builder
for i := 0; i < length; i++ {
randomIndex := seededRand.Intn(len(charset))
result.WriteString(string(charset[randomIndex]))
func GenerateRandomString(length int) (string, error) {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, length)
if _, err := rand.Read(b); err != nil {
return "", err
}
return result.String()
for i := range b {
b[i] = charset[int(b[i])%len(charset)]
}
return string(b), nil
}
@@ -1,13 +1,25 @@
package session
package registry
import (
"fmt"
"sync"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
)
type Key = types.SessionKey
type Session interface {
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *types.Detail
}
type Registry interface {
Get(key Key) (session Session, err error)
GetWithUser(user string, key Key) (session Session, err error)
@@ -35,12 +47,12 @@ func (r *registry) Get(key Key) (session Session, err error) {
userID, ok := r.slugIndex[key]
if !ok {
return nil, fmt.Errorf("session not found")
return nil, fmt.Errorf("Session not found")
}
client, ok := r.byUser[userID][key]
if !ok {
return nil, fmt.Errorf("session not found")
return nil, fmt.Errorf("Session not found")
}
return client, nil
}
@@ -51,7 +63,7 @@ func (r *registry) GetWithUser(user string, key Key) (session Session, err error
client, ok := r.byUser[user][key]
if !ok {
return nil, fmt.Errorf("session not found")
return nil, fmt.Errorf("Session not found")
}
return client, nil
}
@@ -81,7 +93,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
}
client, ok := r.byUser[user][oldKey]
if !ok {
return fmt.Errorf("session not found")
return fmt.Errorf("Session not found")
}
delete(r.byUser[user], oldKey)
@@ -97,7 +109,7 @@ func (r *registry) Update(user string, oldKey, newKey Key) error {
return nil
}
func (r *registry) Register(key Key, session Session) (success bool) {
func (r *registry) Register(key Key, userSession Session) (success bool) {
r.mu.Lock()
defer r.mu.Unlock()
@@ -105,12 +117,12 @@ func (r *registry) Register(key Key, session Session) (success bool) {
return false
}
userID := session.Lifecycle().User()
userID := userSession.Lifecycle().User()
if r.byUser[userID] == nil {
r.byUser[userID] = make(map[Key]Session)
}
r.byUser[userID][key] = session
r.byUser[userID][key] = userSession
r.slugIndex[key] = userID
return true
}
+40
View File
@@ -0,0 +1,40 @@
package transport
import (
"errors"
"log"
"net"
"tunnel_pls/internal/registry"
)
type httpServer struct {
handler *httpHandler
port string
}
func NewHTTPServer(port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &httpServer{
handler: newHTTPHandler(sessionRegistry, redirectTLS),
port: port,
}
}
func (ht *httpServer) Listen() (net.Listener, error) {
return net.Listen("tcp", ":"+ht.port)
}
func (ht *httpServer) Serve(listener net.Listener) error {
log.Printf("HTTP server is starting on port %s", ht.port)
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
}
log.Printf("Error accepting connection: %v", err)
continue
}
go ht.handler.handler(conn, false)
}
}
+227
View File
@@ -0,0 +1,227 @@
package transport
import (
"bufio"
"errors"
"fmt"
"log"
"net"
"net/http"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware"
"tunnel_pls/internal/registry"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type httpHandler struct {
sessionRegistry registry.Registry
redirectTLS bool
}
func newHTTPHandler(sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
return &httpHandler{
sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS,
}
}
func (hh *httpHandler) redirect(conn net.Conn, status int, location string) error {
_, err := conn.Write([]byte(fmt.Sprintf("HTTP/1.1 %d Moved Permanently\r\n", status) +
fmt.Sprintf("Location: %s", location) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
return err
}
return nil
}
func (hh *httpHandler) badRequest(conn net.Conn) error {
if _, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n")); err != nil {
return err
}
return nil
}
func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
defer hh.closeConnection(conn)
dstReader := bufio.NewReader(conn)
reqhf, err := header.NewRequest(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
slug, err := hh.extractSlug(reqhf)
if err != nil {
_ = hh.badRequest(conn)
return
}
if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")))
return
}
if hh.handlePingRequest(slug, conn) {
return
}
sshSession, err := hh.getSession(slug)
if err != nil {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
return
}
hw := stream.New(conn, dstReader, conn.RemoteAddr())
defer func(hw stream.HTTP) {
err = hw.Close()
if err != nil {
log.Printf("Error closing HTTP stream: %v", err)
}
}(hw)
hh.forwardRequest(hw, reqhf, sshSession)
}
func (hh *httpHandler) closeConnection(conn net.Conn) {
err := conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err)
}
}
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
host := strings.Split(reqhf.Value("Host"), ".")
if len(host) < 1 {
return "", errors.New("invalid host")
}
return host[0], nil
}
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
return !isTLS && hh.redirectTLS
}
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
if slug != "ping" {
return false
}
_, err := conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
}
return true
}
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil {
return nil, err
}
return sshSession, nil
}
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
channel, err := hh.openForwardedChannel(hw, sshSession)
defer func() {
err = channel.Close()
if err != nil {
log.Printf("Error closing forwarded channel: %v", err)
}
}()
if err != nil {
log.Printf("Failed to establish channel: %v", err)
sshSession.Forwarder().WriteBadGatewayResponse(hw)
return
}
hh.setupMiddlewares(hw)
if err = hh.sendInitialRequest(hw, initialRequest, channel); err != nil {
log.Printf("Failed to forward initial request: %v", err)
return
}
sshSession.Forwarder().HandleConnection(hw, channel)
}
func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.Session) (ssh.Channel, error) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(hw.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
select {
case resultChan <- channelResult{channel, reqs, err}:
default:
hh.cleanupUnusedChannel(channel, reqs)
}
}()
select {
case result := <-resultChan:
if result.err != nil {
return nil, result.err
}
go ssh.DiscardRequests(result.reqs)
return result.channel, nil
case <-time.After(5 * time.Second):
return nil, errors.New("timeout opening forwarded-tcpip channel")
}
}
func (hh *httpHandler) cleanupUnusedChannel(channel ssh.Channel, reqs <-chan *ssh.Request) {
if channel != nil {
if err := channel.Close(); err != nil {
log.Printf("Failed to close unused channel: %v", err)
}
go ssh.DiscardRequests(reqs)
}
}
func (hh *httpHandler) setupMiddlewares(hw stream.HTTP) {
fingerprintMiddleware := middleware.NewTunnelFingerprint()
forwardedForMiddleware := middleware.NewForwardedFor(hw.RemoteAddr())
hw.UseResponseMiddleware(fingerprintMiddleware)
hw.UseRequestMiddleware(forwardedForMiddleware)
}
func (hh *httpHandler) sendInitialRequest(hw stream.HTTP, initialRequest header.RequestHeader, channel ssh.Channel) error {
hw.SetRequestHeader(initialRequest)
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
return fmt.Errorf("error applying request middlewares: %w", err)
}
if _, err := channel.Write(initialRequest.Finalize()); err != nil {
return fmt.Errorf("error writing to channel: %w", err)
}
return nil
}
+48
View File
@@ -0,0 +1,48 @@
package transport
import (
"crypto/tls"
"errors"
"log"
"net"
"tunnel_pls/internal/registry"
)
type https struct {
httpHandler *httpHandler
domain string
port string
}
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
return &https{
httpHandler: newHTTPHandler(sessionRegistry, redirectTLS),
domain: domain,
port: port,
}
}
func (ht *https) Listen() (net.Listener, error) {
tlsConfig, err := NewTLSConfig(ht.domain)
if err != nil {
return nil, err
}
return tls.Listen("tcp", ":"+ht.port, tlsConfig)
}
func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.port)
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return err
}
log.Printf("Error accepting connection: %v", err)
continue
}
go ht.httpHandler.handler(conn, true)
}
}
+66
View File
@@ -0,0 +1,66 @@
package transport
import (
"errors"
"fmt"
"io"
"log"
"net"
"golang.org/x/crypto/ssh"
)
type tcp struct {
port uint16
forwarder forwarder
}
type forwarder interface {
CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
HandleConnection(dst io.ReadWriter, src ssh.Channel)
}
func NewTCPServer(port uint16, forwarder forwarder) Transport {
return &tcp{
port: port,
forwarder: forwarder,
}
}
func (tt *tcp) Listen() (net.Listener, error) {
return net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", tt.port))
}
func (tt *tcp) Serve(listener net.Listener) error {
for {
conn, err := listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
log.Printf("Error accepting connection: %v", err)
continue
}
go tt.handleTcp(conn)
}
}
func (tt *tcp) handleTcp(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil {
log.Printf("Failed to close connection: %v", err)
}
}()
payload := tt.forwarder.CreateForwardedTCPIPPayload(conn.RemoteAddr())
channel, reqs, err := tt.forwarder.OpenForwardedChannel(payload)
if err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
tt.forwarder.HandleConnection(conn, channel)
}
+5 -11
View File
@@ -1,4 +1,4 @@
package server
package transport
import (
"context"
@@ -301,22 +301,16 @@ func (tm *tlsManager) initCertMagic() error {
func (tm *tlsManager) getTLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: tm.getCertificate,
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
SessionTicketsDisabled: false,
CipherSuites: []uint16{
tls.TLS_AES_128_GCM_SHA256,
tls.TLS_CHACHA20_POLY1305_SHA256,
},
MinVersion: tls.VersionTLS13,
MaxVersion: tls.VersionTLS13,
CurvePreferences: []tls.CurveID{
tls.X25519,
},
ClientAuth: tls.NoClientCert,
NextProtos: nil,
SessionTicketsDisabled: false,
ClientAuth: tls.NoClientCert,
}
}
+10
View File
@@ -0,0 +1,10 @@
package transport
import (
"net"
)
type Transport interface {
Listen() (net.Listener, error)
Serve(listener net.Listener) error
}
+82 -6
View File
@@ -4,19 +4,23 @@ import (
"context"
"fmt"
"log"
"net"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/key"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/internal/version"
"tunnel_pls/server"
"tunnel_pls/session"
"tunnel_pls/version"
"golang.org/x/crypto/ssh"
)
@@ -32,6 +36,12 @@ func main() {
log.Printf("Starting %s", version.GetVersion())
err := config.Load()
if err != nil {
log.Fatalf("Failed to load configuration: %s", err)
return
}
mode := strings.ToLower(config.Getenv("MODE", "standalone"))
isNodeMode := mode == "node"
@@ -41,7 +51,7 @@ func main() {
go func() {
pprofAddr := fmt.Sprintf("localhost:%s", pprofPort)
log.Printf("Starting pprof server on http://%s/debug/pprof/", pprofAddr)
if err := http.ListenAndServe(pprofAddr, nil); err != nil {
if err = http.ListenAndServe(pprofAddr, nil); err != nil {
log.Printf("pprof server error: %v", err)
}
}()
@@ -53,7 +63,7 @@ func main() {
}
sshKeyPath := "certs/ssh/id_rsa"
if err := key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
if err = key.GenerateSSHKeyIfNotExist(sshKeyPath); err != nil {
log.Fatalf("Failed to generate SSH key: %s", err)
}
@@ -68,7 +78,7 @@ func main() {
}
sshConfig.AddHostKey(private)
sessionRegistry := session.NewRegistry()
sessionRegistry := registry.NewRegistry()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -107,9 +117,75 @@ func main() {
}()
}
portManager := port.New()
rawRange := config.Getenv("ALLOWED_PORTS", "")
if rawRange != "" {
splitRange := strings.Split(rawRange, "-")
if len(splitRange) == 2 {
var start, end uint64
start, err = strconv.ParseUint(splitRange[0], 10, 16)
if err != nil {
log.Fatalf("Failed to parse start port: %s", err)
}
end, err = strconv.ParseUint(splitRange[1], 10, 16)
if err != nil {
log.Fatalf("Failed to parse end port: %s", err)
}
if err = portManager.AddRange(uint16(start), uint16(end)); err != nil {
log.Fatalf("Failed to add port range: %s", err)
}
log.Printf("PortRegistry range configured: %d-%d", start, end)
} else {
log.Printf("Invalid ALLOWED_PORTS format, expected 'start-end', got: %s", rawRange)
}
}
tlsEnabled := config.Getenv("TLS_ENABLED", "false") == "true"
redirectTLS := config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true"
go func() {
httpPort := config.Getenv("HTTP_PORT", "8080")
var httpListener net.Listener
httpserver := transport.NewHTTPServer(httpPort, sessionRegistry, redirectTLS)
httpListener, err = httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpserver.Serve(httpListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
if tlsEnabled {
go func() {
httpsPort := config.Getenv("HTTPS_PORT", "8443")
domain := config.Getenv("DOMAIN", "localhost")
var httpListener net.Listener
httpserver := transport.NewHTTPSServer(domain, httpsPort, sessionRegistry, redirectTLS)
httpListener, err = httpserver.Listen()
if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err)
return
}
err = httpserver.Serve(httpListener)
if err != nil {
errChan <- fmt.Errorf("error when serving http server: %w", err)
return
}
}()
}
var app server.Server
go func() {
app, err = server.New(sshConfig, sessionRegistry, grpcClient)
sshPort := config.Getenv("PORT", "2200")
app, err = server.New(sshConfig, sessionRegistry, grpcClient, portManager, sshPort)
if err != nil {
errChan <- fmt.Errorf("failed to start server: %s", err)
return
-276
View File
@@ -1,276 +0,0 @@
package server
import (
"bufio"
"bytes"
"fmt"
)
type HeaderManager interface {
Get(key string) []byte
Set(key string, value []byte)
Remove(key string)
Finalize() []byte
}
type ResponseHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type RequestHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
}
type responseHeaderFactory struct {
startLine []byte
headers map[string]string
}
type requestHeaderFactory struct {
method string
path string
version string
startLine []byte
headers map[string]string
}
func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) {
switch v := r.(type) {
case []byte:
return parseHeadersFromBytes(v)
case *bufio.Reader:
return parseHeadersFromReader(v)
default:
return nil, fmt.Errorf("unsupported type: %T", r)
}
}
func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) {
header := &requestHeaderFactory{
headers: make(map[string]string, 16),
}
lineEnd := bytes.IndexByte(headerData, '\n')
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no newline found")
}
startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n")
header.startLine = make([]byte, len(startLine))
copy(header.startLine, startLine)
parts := bytes.Split(startLine, []byte{' '})
if len(parts) < 3 {
return nil, fmt.Errorf("invalid request line")
}
header.method = string(parts[0])
header.path = string(parts[1])
header.version = string(parts[2])
remaining := headerData[lineEnd+1:]
for len(remaining) > 0 {
lineEnd = bytes.IndexByte(remaining, '\n')
if lineEnd == -1 {
lineEnd = len(remaining)
}
line := bytes.TrimRight(remaining[:lineEnd], "\r\n")
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx != -1 {
key := bytes.TrimSpace(line[:colonIdx])
value := bytes.TrimSpace(line[colonIdx+1:])
header.headers[string(key)] = string(value)
}
if lineEnd == len(remaining) {
break
}
remaining = remaining[lineEnd+1:]
}
return header, nil
}
func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) {
header := &requestHeaderFactory{
headers: make(map[string]string, 16),
}
startLineBytes, err := br.ReadSlice('\n')
if err != nil {
if err == bufio.ErrBufferFull {
var startLine string
startLine, err = br.ReadString('\n')
if err != nil {
return nil, err
}
startLineBytes = []byte(startLine)
} else {
return nil, err
}
}
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
header.startLine = make([]byte, len(startLineBytes))
copy(header.startLine, startLineBytes)
parts := bytes.Split(startLineBytes, []byte{' '})
if len(parts) < 3 {
return nil, fmt.Errorf("invalid request line")
}
header.method = string(parts[0])
header.path = string(parts[1])
header.version = string(parts[2])
for {
lineBytes, err := br.ReadSlice('\n')
if err != nil {
if err == bufio.ErrBufferFull {
var line string
line, err = br.ReadString('\n')
if err != nil {
return nil, err
}
lineBytes = []byte(line)
} else {
return nil, err
}
}
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
if len(lineBytes) == 0 {
break
}
colonIdx := bytes.IndexByte(lineBytes, ':')
if colonIdx == -1 {
continue
}
key := bytes.TrimSpace(lineBytes[:colonIdx])
value := bytes.TrimSpace(lineBytes[colonIdx+1:])
header.headers[string(key)] = string(value)
}
return header, nil
}
func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager {
header := &responseHeaderFactory{
startLine: nil,
headers: make(map[string]string),
}
lines := bytes.Split(startLine, []byte("\r\n"))
if len(lines) == 0 {
return header
}
header.startLine = lines[0]
for _, h := range lines[1:] {
if len(h) == 0 {
continue
}
parts := bytes.SplitN(h, []byte(":"), 2)
if len(parts) < 2 {
continue
}
key := parts[0]
val := bytes.TrimSpace(parts[1])
header.headers[string(key)] = string(val)
}
return header
}
func (resp *responseHeaderFactory) Get(key string) string {
return resp.headers[key]
}
func (resp *responseHeaderFactory) Set(key string, value string) {
resp.headers[key] = value
}
func (resp *responseHeaderFactory) Remove(key string) {
delete(resp.headers, key)
}
func (resp *responseHeaderFactory) Finalize() []byte {
var buf bytes.Buffer
buf.Write(resp.startLine)
buf.WriteString("\r\n")
for key, val := range resp.headers {
buf.WriteString(key)
buf.WriteString(": ")
buf.WriteString(val)
buf.WriteString("\r\n")
}
buf.WriteString("\r\n")
return buf.Bytes()
}
func (req *requestHeaderFactory) Get(key string) string {
val, ok := req.headers[key]
if !ok {
return ""
}
return val
}
func (req *requestHeaderFactory) Set(key string, value string) {
req.headers[key] = value
}
func (req *requestHeaderFactory) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeaderFactory) GetMethod() string {
return req.method
}
func (req *requestHeaderFactory) GetPath() string {
return req.path
}
func (req *requestHeaderFactory) GetVersion() string {
return req.version
}
func (req *requestHeaderFactory) Finalize() []byte {
var buf bytes.Buffer
buf.Write(req.startLine)
buf.WriteString("\r\n")
for key, val := range req.headers {
buf.WriteString(key)
buf.WriteString(": ")
buf.WriteString(val)
buf.WriteString("\r\n")
}
buf.WriteString("\r\n")
return buf.Bytes()
}
-395
View File
@@ -1,395 +0,0 @@
package server
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"log"
"net"
"regexp"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/session"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
type HTTPWriter interface {
io.Reader
io.Writer
GetRemoteAddr() net.Addr
GetWriter() io.Writer
AddResponseMiddleware(mw ResponseMiddleware)
AddRequestStartMiddleware(mw RequestMiddleware)
SetRequestHeader(header RequestHeaderManager)
GetRequestStartMiddleware() []RequestMiddleware
}
type customWriter struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
headerBuf []byte
buf []byte
respHeader ResponseHeaderManager
reqHeader RequestHeaderManager
respMW []ResponseMiddleware
reqStartMW []RequestMiddleware
reqEndMW []RequestMiddleware
}
func (cw *customWriter) GetRemoteAddr() net.Addr {
return cw.remoteAddr
}
func (cw *customWriter) GetWriter() io.Writer {
return cw.writer
}
func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) {
cw.respMW = append(cw.respMW, mw)
}
func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) {
cw.reqStartMW = append(cw.reqStartMW, mw)
}
func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) {
cw.reqHeader = header
}
func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware {
return cw.reqStartMW
}
func (cw *customWriter) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := cw.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
idx := bytes.Index(tmp, DELIMITER)
if idx == -1 {
copy(p, tmp)
if err != nil {
return read, err
}
return read, nil
}
header := tmp[:idx+len(DELIMITER)]
body := tmp[idx+len(DELIMITER):]
if !isHTTPHeader(header) {
copy(p, tmp)
return read, nil
}
for _, m := range cw.reqEndMW {
err = m.HandleRequest(cw.reqHeader)
if err != nil {
log.Printf("Error when applying request middleware: %v", err)
return 0, err
}
}
reqhf, err := NewRequestHeaderFactory(header)
if err != nil {
return 0, err
}
for _, m := range cw.reqStartMW {
if mwErr := m.HandleRequest(reqhf); mwErr != nil {
log.Printf("Error when applying request middleware: %v", mwErr)
return 0, mwErr
}
}
cw.reqHeader = reqhf
finalHeader := reqhf.Finalize()
combined := append(finalHeader, body...)
n := copy(p, combined)
return n, nil
}
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
return &customWriter{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}
func (cw *customWriter) Write(p []byte) (int, error) {
if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" {
cw.respHeader = nil
}
if cw.respHeader != nil {
n, err := cw.writer.Write(p)
if err != nil {
return n, err
}
return n, nil
}
cw.buf = append(cw.buf, p...)
idx := bytes.Index(cw.buf, DELIMITER)
if idx == -1 {
return len(p), nil
}
header := cw.buf[:idx+len(DELIMITER)]
body := cw.buf[idx+len(DELIMITER):]
if !isHTTPHeader(header) {
_, err := cw.writer.Write(cw.buf)
cw.buf = nil
if err != nil {
return 0, err
}
return len(p), nil
}
resphf := NewResponseHeaderFactory(header)
for _, m := range cw.respMW {
err := m.HandleResponse(resphf, body)
if err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return 0, err
}
}
header = resphf.Finalize()
cw.respHeader = resphf
_, err := cw.writer.Write(header)
if err != nil {
return 0, err
}
if len(body) > 0 {
_, err = cw.writer.Write(body)
if err != nil {
return 0, err
}
}
cw.buf = nil
return len(p), nil
}
var redirectTLS = false
type HTTPServer interface {
ListenAndServe() error
ListenAndServeTLS() error
handler(conn net.Conn)
handlerTLS(conn net.Conn)
}
type httpServer struct {
sessionRegistry session.Registry
}
func NewHTTPServer(sessionRegistry session.Registry) HTTPServer {
return &httpServer{sessionRegistry: sessionRegistry}
}
func (hs *httpServer) ListenAndServe() error {
httpPort := config.Getenv("HTTP_PORT", "8080")
listener, err := net.Listen("tcp", ":"+httpPort)
if err != nil {
return errors.New("Error listening: " + err.Error())
}
if config.Getenv("TLS_ENABLED", "false") == "true" && config.Getenv("TLS_REDIRECT", "false") == "true" {
redirectTLS = true
}
go func() {
for {
var conn net.Conn
conn, err = listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
go hs.handler(conn)
}
}()
return nil
}
func (hs *httpServer) handler(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) {
log.Printf("Error closing connection: %v", err)
return
}
return
}()
dstReader := bufio.NewReader(conn)
reqhf, err := NewRequestHeaderFactory(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
host := strings.Split(reqhf.Get("Host"), ".")
if len(host) < 1 {
_, err := conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
if err != nil {
log.Println("Failed to write 400 Bad Request:", err)
return
}
return
}
slug := host[0]
if redirectTLS {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://%s.%s/\r\n", slug, config.Getenv("DOMAIN", "localhost")) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
if slug == "ping" {
_, err = conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
return
}
return
}
sshSession, err := hs.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
forwardRequest(cw, reqhf, sshSession)
return
}
func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) {
payload := sshSession.Forwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
var channel ssh.Channel
var reqs <-chan *ssh.Request
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter())
return
}
channel = result.channel
reqs = result.reqs
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.Forwarder().WriteBadGatewayResponse(cw.GetWriter())
return
}
go ssh.DiscardRequests(reqs)
fingerprintMiddleware := NewTunnelFingerprint()
forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr())
cw.AddResponseMiddleware(fingerprintMiddleware)
cw.AddRequestStartMiddleware(forwardedForMiddleware)
cw.SetRequestHeader(initialRequest)
for _, m := range cw.GetRequestStartMiddleware() {
if err := m.HandleRequest(initialRequest); err != nil {
log.Printf("Error handling request: %v", err)
return
}
}
_, err := channel.Write(initialRequest.Finalize())
if err != nil {
log.Printf("Failed to forward request: %v", err)
return
}
sshSession.Forwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
return
}
-112
View File
@@ -1,112 +0,0 @@
package server
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"log"
"net"
"strings"
"tunnel_pls/internal/config"
"tunnel_pls/types"
)
func (hs *httpServer) ListenAndServeTLS() error {
domain := config.Getenv("DOMAIN", "localhost")
httpsPort := config.Getenv("HTTPS_PORT", "8443")
tlsConfig, err := NewTLSConfig(domain)
if err != nil {
return fmt.Errorf("failed to initialize TLS config: %w", err)
}
ln, err := tls.Listen("tcp", ":"+httpsPort, tlsConfig)
if err != nil {
return err
}
go func() {
for {
var conn net.Conn
conn, err = ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Println("https server closed")
}
log.Printf("Error accepting connection: %v", err)
continue
}
go hs.handlerTLS(conn)
}
}()
return nil
}
func (hs *httpServer) handlerTLS(conn net.Conn) {
defer func() {
err := conn.Close()
if err != nil {
log.Printf("Error closing connection: %v", err)
return
}
return
}()
dstReader := bufio.NewReader(conn)
reqhf, err := NewRequestHeaderFactory(dstReader)
if err != nil {
log.Printf("Error creating request header: %v", err)
return
}
host := strings.Split(reqhf.Get("Host"), ".")
if len(host) < 1 {
_, err = conn.Write([]byte("HTTP/1.1 400 Bad Request\r\n\r\n"))
if err != nil {
log.Println("Failed to write 400 Bad Request:", err)
return
}
return
}
slug := host[0]
if slug == "ping" {
_, err = conn.Write([]byte(
"HTTP/1.1 200 OK\r\n" +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"Access-Control-Allow-Origin: *\r\n" +
"Access-Control-Allow-Methods: GET, HEAD, OPTIONS\r\n" +
"Access-Control-Allow-Headers: *\r\n" +
"\r\n",
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
return
}
return
}
sshSession, err := hs.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.HTTP,
})
if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 301 Moved Permanently\r\n" +
fmt.Sprintf("Location: https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug) +
"Content-Length: 0\r\n" +
"Connection: close\r\n" +
"\r\n"))
if err != nil {
log.Println("Failed to write 301 Moved Permanently:", err)
return
}
return
}
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
forwardRequest(cw, reqhf, sshSession)
return
}
-41
View File
@@ -1,41 +0,0 @@
package server
import (
"net"
)
type RequestMiddleware interface {
HandleRequest(header RequestHeaderManager) error
}
type ResponseMiddleware interface {
HandleResponse(header ResponseHeaderManager, body []byte) error
}
type TunnelFingerprint struct{}
func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{}
}
func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error {
header.Set("Server", "Tunnel Please")
return nil
}
type ForwardedFor struct {
addr net.Addr
}
func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr}
}
func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error {
host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil {
return err
}
header.Set("X-Forwarded-For", host)
return nil
}
+16 -27
View File
@@ -7,8 +7,9 @@ import (
"log"
"net"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/grpc/client"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session"
"golang.org/x/crypto/ssh"
@@ -19,46 +20,34 @@ type Server interface {
Close() error
}
type server struct {
listener net.Listener
sshPort string
sshListener net.Listener
config *ssh.ServerConfig
sessionRegistry session.Registry
grpcClient client.Client
sessionRegistry registry.Registry
portRegistry port.Port
}
func New(sshConfig *ssh.ServerConfig, sessionRegistry session.Registry, grpcClient client.Client) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", config.Getenv("PORT", "2200")))
func New(sshConfig *ssh.ServerConfig, sessionRegistry registry.Registry, grpcClient client.Client, portRegistry port.Port, sshPort string) (Server, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", sshPort))
if err != nil {
log.Fatalf("failed to listen on port 2200: %v", err)
return nil, err
}
HttpServer := NewHTTPServer(sessionRegistry)
err = HttpServer.ListenAndServe()
if err != nil {
log.Fatalf("failed to start http server: %v", err)
return nil, err
}
if config.Getenv("TLS_ENABLED", "false") == "true" {
err = HttpServer.ListenAndServeTLS()
if err != nil {
log.Fatalf("failed to start https server: %v", err)
return nil, err
}
}
return &server{
listener: listener,
sshPort: sshPort,
sshListener: listener,
config: sshConfig,
sessionRegistry: sessionRegistry,
grpcClient: grpcClient,
sessionRegistry: sessionRegistry,
portRegistry: portRegistry,
}, nil
}
func (s *server) Start() {
log.Println("SSH server is starting on port 2200...")
log.Printf("SSH server is starting on port %s", s.sshPort)
for {
conn, err := s.listener.Accept()
conn, err := s.sshListener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
log.Println("listener closed, stopping server")
@@ -73,7 +62,7 @@ func (s *server) Start() {
}
func (s *server) Close() error {
return s.listener.Close()
return s.sshListener.Close()
}
func (s *server) handleConnection(conn net.Conn) {
@@ -103,7 +92,7 @@ func (s *server) handleConnection(conn net.Conn) {
cancel()
}
log.Println("SSH connection established:", sshConn.User())
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, user)
sshSession := session.New(sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
err = sshSession.Start()
if err != nil {
log.Printf("SSH session ended with error: %v", err)
+62 -86
View File
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net"
@@ -35,139 +36,114 @@ type forwarder struct {
tunnelType types.TunnelType
forwardedPort uint16
slug slug.Slug
lifecycle Lifecycle
conn ssh.Conn
}
func New(slug slug.Slug) Forwarder {
func New(slug slug.Slug, conn ssh.Conn) Forwarder {
return &forwarder{
listener: nil,
tunnelType: types.UNKNOWN,
forwardedPort: 0,
slug: slug,
lifecycle: nil,
conn: conn,
}
}
type Lifecycle interface {
Connection() ssh.Conn
}
type Forwarder interface {
SetType(tunnelType types.TunnelType)
SetLifecycle(lifecycle Lifecycle)
SetForwardedPort(port uint16)
SetListener(listener net.Listener)
Listener() net.Listener
TunnelType() types.TunnelType
ForwardedPort() uint16
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
HandleConnection(dst io.ReadWriter, src ssh.Channel)
CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
WriteBadGatewayResponse(dst io.Writer)
AcceptTCPConnections()
Close() error
}
func (f *forwarder) SetLifecycle(lifecycle Lifecycle) {
f.lifecycle = lifecycle
}
func (f *forwarder) AcceptTCPConnections() {
for {
conn, err := f.Listener().Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
log.Printf("Error accepting connection: %v", err)
continue
}
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
log.Printf("Failed to set connection deadline: %v", err)
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
continue
}
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := f.lifecycle.Connection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err}
}()
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
case resultChan <- channelResult{channel, reqs, err}:
default:
if channel != nil {
err = channel.Close()
if err != nil {
log.Printf("Failed to close unused channel: %v", err)
return
}
continue
}
if err := conn.SetDeadline(time.Time{}); err != nil {
log.Printf("Failed to clear connection deadline: %v", err)
}
go ssh.DiscardRequests(result.reqs)
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
go ssh.DiscardRequests(reqs)
}
}
}()
select {
case result := <-resultChan:
return result.channel, result.reqs, result.err
case <-time.After(5 * time.Second):
return nil, nil, errors.New("timeout opening forwarded-tcpip channel")
}
}
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
func closeWriter(w io.Writer) error {
if cw, ok := w.(interface{ CloseWrite() error }); ok {
return cw.CloseWrite()
}
if closer, ok := w.(io.Closer); ok {
return closer.Close()
}
return nil
}
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
var errs []error
_, err := copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
}
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err))
}
return errors.Join(errs...)
}
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
defer func() {
_, err := io.Copy(io.Discard, src)
if err != nil {
log.Printf("Failed to discard connection: %v", err)
}
err = src.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing source channel: %v", err)
}
if closer, ok := dst.(io.Closer); ok {
err = closer.Close()
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error closing destination connection: %v", err)
}
}
}()
log.Printf("Handling new forwarded connection from %s", remoteAddr)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, err := copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying src→dst: %v", err)
err := f.copyAndClose(dst, src, "src to dst")
if err != nil {
log.Println("Error during copy: ", err)
return
}
}()
go func() {
defer wg.Done()
_, err := copyWithBuffer(src, dst)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying dst→src: %v", err)
err := f.copyAndClose(src, dst, "dst to src")
if err != nil {
log.Println("Error during copy: ", err)
return
}
}()
+83
View File
@@ -0,0 +1,83 @@
package interaction
import (
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) comingSoonUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.showingComingSoon = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
func (m *model) comingSoonView() string {
isCompact := shouldUseCompactLayout(m.width, 60)
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 3
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
messageBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
messageBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Background(lipgloss.Color("#1A1A2E")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(messageBoxWidth).
Align(lipgloss.Center)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Coming Soon"
} else {
title = "⏳ Coming Soon"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
var message string
if shouldUseCompactLayout(m.width, 50) {
message = "Coming soon!\nStay tuned."
} else {
message = "🚀 This feature is coming very soon!\n Stay tuned for updates."
}
b.WriteString(messageBoxStyle.Render(message))
b.WriteString("\n\n")
var helpText string
if shouldUseCompactLayout(m.width, 60) {
helpText = "Press any key..."
} else {
helpText = "This message will disappear in 5 seconds or press any key..."
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
+83
View File
@@ -0,0 +1,83 @@
package interaction
import (
"strings"
"time"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) commandsUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch {
case key.Matches(msg, m.keymap.quit):
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case msg.String() == "enter":
selectedItem := m.commandList.SelectedItem()
if selectedItem != nil {
item := selectedItem.(commandItem)
if item.name == "slug" {
m.showingCommands = false
m.editingSlug = true
m.slugInput.SetValue(m.interaction.slug.String())
m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} else if item.name == "tunnel-type" {
m.showingCommands = false
m.showingComingSoon = true
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
}
m.showingCommands = false
return m, nil
}
case msg.String() == "esc":
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
m.commandList, cmd = m.commandList.Update(msg)
return m, cmd
}
func (m *model) commandsView() string {
isCompact := shouldUseCompactLayout(m.width, 60)
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Commands"
} else {
title = "⚡ Commands"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
b.WriteString(m.commandList.View())
b.WriteString("\n")
var helpText string
if isCompact {
helpText = "↑/↓ Nav • Enter Select • Esc Cancel"
} else {
helpText = "↑/↓ Navigate • Enter Select • Esc Cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
+186
View File
@@ -0,0 +1,186 @@
package interaction
import (
"fmt"
"strings"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) dashboardUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch {
case key.Matches(msg, m.keymap.quit):
m.quitting = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink, tea.Quit)
case key.Matches(msg, m.keymap.command):
m.showingCommands = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
return m, nil
}
func (m *model) dashboardView() string {
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1)
subtitleStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Italic(true)
urlStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Underline(true).
Italic(true)
urlBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Bold(true).
Italic(true)
keyHintStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Bold(true)
var b strings.Builder
isCompact := shouldUseCompactLayout(m.width, 85)
var asciiArtMargin int
if isCompact {
asciiArtMargin = 0
} else {
asciiArtMargin = 1
}
asciiArtStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
MarginBottom(asciiArtMargin)
var asciiArt string
if shouldUseCompactLayout(m.width, 50) {
asciiArt = "TUNNEL PLS"
} else if isCompact {
asciiArt = `
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
} else {
asciiArt = `
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
}
b.WriteString(asciiArtStyle.Render(asciiArt))
b.WriteString("\n")
if !shouldUseCompactLayout(m.width, 60) {
b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • "))
b.WriteString(urlStyle.Render("https://fossy.my.id"))
b.WriteString("\n\n")
} else {
b.WriteString("\n")
}
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
responsiveInfoBox := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(boxMaxWidth)
authenticatedUser := m.interaction.user
userInfoStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Bold(true)
sectionHeaderStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Bold(true)
addressStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA"))
var infoContent string
if shouldUseCompactLayout(m.width, 70) {
infoContent = fmt.Sprintf("👤 %s\n\n%s\n%s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(fmt.Sprintf(" %s", urlBoxStyle.Render(m.getTunnelURL()))))
} else {
infoContent = fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(urlBoxStyle.Render(m.getTunnelURL())))
}
b.WriteString(responsiveInfoBox.Render(infoContent))
b.WriteString("\n")
var quickActionsTitle string
if shouldUseCompactLayout(m.width, 50) {
quickActionsTitle = "Actions"
} else if isCompact {
quickActionsTitle = "Quick Actions"
} else {
quickActionsTitle = "✨ Quick Actions"
}
b.WriteString(titleStyle.Render(quickActionsTitle))
b.WriteString("\n")
var featureMargin int
if isCompact {
featureMargin = 1
} else {
featureMargin = 2
}
compactFeatureStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginLeft(featureMargin)
var commandsText string
var quitText string
if shouldUseCompactLayout(m.width, 60) {
commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]"))
} else {
commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]"))
}
b.WriteString(compactFeatureStyle.Render(commandsText))
b.WriteString("\n")
b.WriteString(compactFeatureStyle.Render(quitText))
if !shouldUseCompactLayout(m.width, 70) {
b.WriteString("\n\n")
footerStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true)
b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings"))
}
return b.String()
}
+23 -623
View File
@@ -2,12 +2,8 @@ package interaction
import (
"context"
"fmt"
"log"
"strings"
"time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/random"
"tunnel_pls/session/slug"
"tunnel_pls/types"
@@ -21,20 +17,9 @@ import (
"golang.org/x/crypto/ssh"
)
type Lifecycle interface {
Close() error
User() string
}
type SessionRegistry interface {
Update(user string, oldKey, newKey types.SessionKey) error
}
type Interaction interface {
Mode() types.Mode
SetChannel(channel ssh.Channel)
SetLifecycle(lifecycle Lifecycle)
SetSessionRegistry(registry SessionRegistry)
SetMode(m types.Mode)
SetWH(w, h int)
Start()
@@ -42,17 +27,23 @@ type Interaction interface {
Send(message string) error
}
type SessionRegistry interface {
Update(user string, oldKey, newKey types.SessionKey) error
}
type Forwarder interface {
Close() error
TunnelType() types.TunnelType
ForwardedPort() uint16
}
type CloseFunc func() error
type interaction struct {
channel ssh.Channel
slug slug.Slug
forwarder Forwarder
lifecycle Lifecycle
closeFunc CloseFunc
user string
sessionRegistry SessionRegistry
program *tea.Program
ctx context.Context
@@ -84,67 +75,21 @@ func (i *interaction) SetWH(w, h int) {
}
}
type commandItem struct {
name string
desc string
}
type model struct {
domain string
protocol string
tunnelType types.TunnelType
port uint16
keymap keymap
help help.Model
quitting bool
showingCommands bool
editingSlug bool
showingComingSoon bool
commandList list.Model
slugInput textinput.Model
slugError string
interaction *interaction
width int
height int
}
func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
}
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
}
type keymap struct {
quit key.Binding
command key.Binding
random key.Binding
}
type tickMsg time.Time
func New(slug slug.Slug, forwarder Forwarder) Interaction {
func New(slug slug.Slug, forwarder Forwarder, sessionRegistry SessionRegistry, user string, closeFunc CloseFunc) Interaction {
ctx, cancel := context.WithCancel(context.Background())
return &interaction{
channel: nil,
slug: slug,
forwarder: forwarder,
lifecycle: nil,
sessionRegistry: nil,
closeFunc: closeFunc,
user: user,
sessionRegistry: sessionRegistry,
program: nil,
ctx: ctx,
cancel: cancel,
}
}
func (i *interaction) SetSessionRegistry(registry SessionRegistry) {
i.sessionRegistry = registry
}
func (i *interaction) SetLifecycle(lifecycle Lifecycle) {
i.lifecycle = lifecycle
}
func (i *interaction) SetChannel(channel ssh.Channel) {
i.channel = channel
}
@@ -159,47 +104,7 @@ func (i *interaction) Stop() {
}
}
func getResponsiveWidth(screenWidth, padding, minWidth, maxWidth int) int {
width := screenWidth - padding
if width > maxWidth {
width = maxWidth
}
if width < minWidth {
width = minWidth
}
return width
}
func shouldUseCompactLayout(width int, threshold int) bool {
return width < threshold
}
func truncateString(s string, maxLength int) string {
if len(s) <= maxLength {
return s
}
if maxLength < 4 {
return s[:maxLength]
}
return s[:maxLength-3] + "..."
}
func (i commandItem) FilterValue() string { return i.name }
func (i commandItem) Title() string { return i.name }
func (i commandItem) Description() string { return i.desc }
func tickCmd(d time.Duration) tea.Cmd {
return tea.Tick(d, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func (m *model) Init() tea.Cmd {
return tea.Batch(textinput.Blink, tea.WindowSize())
}
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
switch msg := msg.(type) {
case tickMsg:
@@ -225,93 +130,18 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case tea.KeyMsg:
if m.showingComingSoon {
m.showingComingSoon = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
return m.comingSoonUpdate(msg)
}
if m.editingSlug {
if m.tunnelType != types.HTTP {
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
switch msg.String() {
case "esc":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter":
inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.lifecycle.User(), types.SessionKey{
Id: m.interaction.slug.String(),
Type: types.HTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.HTTP,
}); err != nil {
m.slugError = err.Error()
return m, nil
}
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "ctrl+c":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
default:
if key.Matches(msg, m.keymap.random) {
newSubdomain := generateRandomSubdomain()
m.slugInput.SetValue(newSubdomain)
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
return m, cmd
}
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
return m, cmd
}
return m.slugUpdate(msg)
}
if m.showingCommands {
switch {
case key.Matches(msg, m.keymap.quit):
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case msg.String() == "enter":
selectedItem := m.commandList.SelectedItem()
if selectedItem != nil {
item := selectedItem.(commandItem)
if item.name == "slug" {
m.showingCommands = false
m.editingSlug = true
m.slugInput.SetValue(m.interaction.slug.String())
m.slugInput.Focus()
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
} else if item.name == "tunnel-type" {
m.showingCommands = false
m.showingComingSoon = true
return m, tea.Batch(tickCmd(5*time.Second), tea.ClearScreen, textinput.Blink)
}
m.showingCommands = false
return m, nil
}
case msg.String() == "esc":
m.showingCommands = false
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
m.commandList, cmd = m.commandList.Update(msg)
return m, cmd
return m.commandsUpdate(msg)
}
switch {
case key.Matches(msg, m.keymap.quit):
m.quitting = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink, tea.Quit)
case key.Matches(msg, m.keymap.command):
m.showingCommands = true
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
return m.dashboardUpdate(msg)
}
return m, nil
@@ -323,448 +153,24 @@ func (i *interaction) Redraw() {
}
}
func (m *model) helpView() string {
return "\n" + m.help.ShortHelpView([]key.Binding{
m.keymap.command,
m.keymap.quit,
})
}
func (m *model) View() string {
if m.quitting {
return ""
}
if m.showingComingSoon {
isCompact := shouldUseCompactLayout(m.width, 60)
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 3
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
messageBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
messageBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Background(lipgloss.Color("#1A1A2E")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(messageBoxWidth).
Align(lipgloss.Center)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Coming Soon"
} else {
title = "⏳ Coming Soon"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
var message string
if shouldUseCompactLayout(m.width, 50) {
message = "Coming soon!\nStay tuned."
} else {
message = "🚀 This feature is coming very soon!\n Stay tuned for updates."
}
b.WriteString(messageBoxStyle.Render(message))
b.WriteString("\n\n")
var helpText string
if shouldUseCompactLayout(m.width, 60) {
helpText = "Press any key..."
} else {
helpText = "This message will disappear in 5 seconds or press any key..."
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
return m.comingSoonView()
}
if m.editingSlug {
isCompact := shouldUseCompactLayout(m.width, 70)
isVeryCompact := shouldUseCompactLayout(m.width, 50)
var boxPadding int
var boxMargin int
if isVeryCompact {
boxPadding = 1
boxMargin = 1
} else if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
instructionStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginTop(1)
inputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
errorBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FF0000")).
Background(lipgloss.Color("#3D0000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1)
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
rulesBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1).
Width(rulesBoxWidth)
var b strings.Builder
var title string
if isVeryCompact {
title = "Edit Subdomain"
} else {
title = "🔧 Edit Subdomain"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
if m.tunnelType != types.HTTP {
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFA500")).
Background(lipgloss.Color("#3D2000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FFA500")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(warningBoxWidth)
var warningText string
if isVeryCompact {
warningText = "⚠️ TCP tunnels don't support custom subdomains."
} else {
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
}
b.WriteString(warningBoxStyle.Render(warningText))
b.WriteString("\n\n")
var helpText string
if isVeryCompact {
helpText = "Press any key to go back"
} else {
helpText = "Press Enter or Esc to go back"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
var rulesContent string
if isVeryCompact {
rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
} else if isCompact {
rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
} else {
rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
}
b.WriteString(rulesBoxStyle.Render(rulesContent))
b.WriteString("\n")
var instruction string
if isVeryCompact {
instruction = "Custom subdomain:"
} else {
instruction = "Enter your custom subdomain:"
}
b.WriteString(instructionStyle.Render(instruction))
b.WriteString("\n")
if m.slugError != "" {
errorInputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(1)
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
b.WriteString("\n")
} else {
b.WriteString(inputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
}
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
if len(previewURL) > previewWidth-10 {
previewURL = truncateString(previewURL, previewWidth-10)
}
previewStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Italic(true).
Width(previewWidth)
b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)))
b.WriteString("\n")
var helpText string
if isVeryCompact {
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
} else {
helpText = "Press Enter to save • CTRL+R for random • Esc to cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
return m.slugView()
}
if m.showingCommands {
isCompact := shouldUseCompactLayout(m.width, 60)
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
var b strings.Builder
b.WriteString("\n")
var title string
if shouldUseCompactLayout(m.width, 40) {
title = "Commands"
} else {
title = "⚡ Commands"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
b.WriteString(m.commandList.View())
b.WriteString("\n")
var helpText string
if isCompact {
helpText = "↑/↓ Nav • Enter Select • Esc Cancel"
} else {
helpText = "↑/↓ Navigate • Enter Select • Esc Cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
return m.commandsView()
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1)
subtitleStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Italic(true)
urlStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Underline(true).
Italic(true)
urlBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Bold(true).
Italic(true)
keyHintStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#7D56F4")).
Bold(true)
var b strings.Builder
isCompact := shouldUseCompactLayout(m.width, 85)
var asciiArtMargin int
if isCompact {
asciiArtMargin = 0
} else {
asciiArtMargin = 1
}
asciiArtStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
MarginBottom(asciiArtMargin)
var asciiArt string
if shouldUseCompactLayout(m.width, 50) {
asciiArt = "TUNNEL PLS"
} else if isCompact {
asciiArt = `
▀█▀ █ █ █▄ █ █▄ █ ██▀ █ ▄▀▀ █ ▄▀▀
█ ▀▄█ █ ▀█ █ ▀█ █▄▄ █▄▄ ▄█▀ █▄▄ ▄█▀`
} else {
asciiArt = `
████████╗██╗ ██╗███╗ ██╗███╗ ██╗███████╗██╗ ██████╗ ██╗ ███████╗
╚══██╔══╝██║ ██║████╗ ██║████╗ ██║██╔════╝██║ ██╔══██╗██║ ██╔════╝
██║ ██║ ██║██╔██╗ ██║██╔██╗ ██║█████╗ ██║ ██████╔╝██║ ███████╗
██║ ██║ ██║██║╚██╗██║██║╚██╗██║██╔══╝ ██║ ██╔═══╝ ██║ ╚════██║
██║ ╚██████╔╝██║ ╚████║██║ ╚████║███████╗███████╗ ██║ ███████╗███████║
╚═╝ ╚═════╝ ╚═╝ ╚═══╝╚═╝ ╚═══╝╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝`
}
b.WriteString(asciiArtStyle.Render(asciiArt))
b.WriteString("\n")
if !shouldUseCompactLayout(m.width, 60) {
b.WriteString(subtitleStyle.Render("Secure tunnel service by Bagas • "))
b.WriteString(urlStyle.Render("https://fossy.my.id"))
b.WriteString("\n\n")
} else {
b.WriteString("\n")
}
boxMaxWidth := getResponsiveWidth(m.width, 10, 40, 80)
var boxPadding int
var boxMargin int
if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
responsiveInfoBox := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(boxMaxWidth)
authenticatedUser := m.interaction.lifecycle.User()
userInfoStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Bold(true)
sectionHeaderStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#888888")).
Bold(true)
addressStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA"))
var infoContent string
if shouldUseCompactLayout(m.width, 70) {
infoContent = fmt.Sprintf("👤 %s\n\n%s\n%s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(fmt.Sprintf(" %s", urlBoxStyle.Render(m.getTunnelURL()))))
} else {
infoContent = fmt.Sprintf("👤 Authenticated as: %s\n\n%s\n %s",
userInfoStyle.Render(authenticatedUser),
sectionHeaderStyle.Render("🌐 FORWARDING ADDRESS:"),
addressStyle.Render(urlBoxStyle.Render(m.getTunnelURL())))
}
b.WriteString(responsiveInfoBox.Render(infoContent))
b.WriteString("\n")
var quickActionsTitle string
if shouldUseCompactLayout(m.width, 50) {
quickActionsTitle = "Actions"
} else if isCompact {
quickActionsTitle = "Quick Actions"
} else {
quickActionsTitle = "✨ Quick Actions"
}
b.WriteString(titleStyle.Render(quickActionsTitle))
b.WriteString("\n")
var featureMargin int
if isCompact {
featureMargin = 1
} else {
featureMargin = 2
}
compactFeatureStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginLeft(featureMargin)
var commandsText string
var quitText string
if shouldUseCompactLayout(m.width, 60) {
commandsText = fmt.Sprintf(" %s Commands", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit", keyHintStyle.Render("[Q]"))
} else {
commandsText = fmt.Sprintf(" %s Open commands menu", keyHintStyle.Render("[C]"))
quitText = fmt.Sprintf(" %s Quit application", keyHintStyle.Render("[Q]"))
}
b.WriteString(compactFeatureStyle.Render(commandsText))
b.WriteString("\n")
b.WriteString(compactFeatureStyle.Render(quitText))
if !shouldUseCompactLayout(m.width, 70) {
b.WriteString("\n\n")
footerStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true)
b.WriteString(footerStyle.Render("Press 'C' to customize your tunnel settings"))
}
return b.String()
return m.dashboardView()
}
func (i *interaction) Start() {
@@ -844,15 +250,9 @@ func (i *interaction) Start() {
}
i.program.Kill()
i.program = nil
if err := m.interaction.lifecycle.Close(); err != nil {
log.Printf("Cannot close session: %s \n", err)
if i.closeFunc != nil {
if err := i.closeFunc(); err != nil {
log.Printf("Cannot close session: %s \n", err)
}
}
}
func buildURL(protocol, subdomain, domain string) string {
return fmt.Sprintf("%s://%s.%s", protocol, subdomain, domain)
}
func generateRandomSubdomain() string {
return random.GenerateRandomString(20)
}
+95
View File
@@ -0,0 +1,95 @@
package interaction
import (
"fmt"
"time"
"tunnel_pls/types"
"github.com/charmbracelet/bubbles/help"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/list"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
)
type commandItem struct {
name string
desc string
}
func (i commandItem) FilterValue() string { return i.name }
func (i commandItem) Title() string { return i.name }
func (i commandItem) Description() string { return i.desc }
type model struct {
domain string
protocol string
tunnelType types.TunnelType
port uint16
keymap keymap
help help.Model
quitting bool
showingCommands bool
editingSlug bool
showingComingSoon bool
commandList list.Model
slugInput textinput.Model
slugError string
interaction *interaction
width int
height int
}
func (m *model) getTunnelURL() string {
if m.tunnelType == types.HTTP {
return buildURL(m.protocol, m.interaction.slug.String(), m.domain)
}
return fmt.Sprintf("tcp://%s:%d", m.domain, m.port)
}
type keymap struct {
quit key.Binding
command key.Binding
random key.Binding
}
type tickMsg time.Time
func (m *model) Init() tea.Cmd {
return tea.Batch(textinput.Blink, tea.WindowSize())
}
func getResponsiveWidth(screenWidth, padding, minWidth, maxWidth int) int {
width := screenWidth - padding
if width > maxWidth {
width = maxWidth
}
if width < minWidth {
width = minWidth
}
return width
}
func shouldUseCompactLayout(width int, threshold int) bool {
return width < threshold
}
func truncateString(s string, maxLength int) string {
if len(s) <= maxLength {
return s
}
if maxLength < 4 {
return s[:maxLength]
}
return s[:maxLength-3] + "..."
}
func tickCmd(d time.Duration) tea.Cmd {
return tea.Tick(d, func(t time.Time) tea.Msg {
return tickMsg(t)
})
}
func buildURL(protocol, subdomain, domain string) string {
return fmt.Sprintf("%s://%s.%s", protocol, subdomain, domain)
}
+224
View File
@@ -0,0 +1,224 @@
package interaction
import (
"fmt"
"strings"
"tunnel_pls/internal/random"
"tunnel_pls/types"
"github.com/charmbracelet/bubbles/key"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
func (m *model) slugUpdate(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
var cmd tea.Cmd
if m.tunnelType != types.HTTP {
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
}
switch msg.String() {
case "esc":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "enter":
inputValue := m.slugInput.Value()
if err := m.interaction.sessionRegistry.Update(m.interaction.user, types.SessionKey{
Id: m.interaction.slug.String(),
Type: types.HTTP,
}, types.SessionKey{
Id: inputValue,
Type: types.HTTP,
}); err != nil {
m.slugError = err.Error()
return m, nil
}
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
case "ctrl+c":
m.editingSlug = false
m.slugError = ""
return m, tea.Batch(tea.ClearScreen, textinput.Blink)
default:
if key.Matches(msg, m.keymap.random) {
newSubdomain, err := random.GenerateRandomString(20)
if err != nil {
return m, cmd
}
m.slugInput.SetValue(newSubdomain)
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
}
m.slugError = ""
m.slugInput, cmd = m.slugInput.Update(msg)
return m, cmd
}
}
func (m *model) slugView() string {
isCompact := shouldUseCompactLayout(m.width, 70)
isVeryCompact := shouldUseCompactLayout(m.width, 50)
var boxPadding int
var boxMargin int
if isVeryCompact {
boxPadding = 1
boxMargin = 1
} else if isCompact {
boxPadding = 1
boxMargin = 1
} else {
boxPadding = 2
boxMargin = 2
}
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(lipgloss.Color("#7D56F4")).
PaddingTop(1).
PaddingBottom(1)
instructionStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
MarginTop(1)
inputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin)
helpStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#666666")).
Italic(true).
MarginTop(1)
errorBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FF0000")).
Background(lipgloss.Color("#3D0000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1)
rulesBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
rulesBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FAFAFA")).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#7D56F4")).
Padding(0, boxPadding).
MarginTop(1).
MarginBottom(1).
Width(rulesBoxWidth)
var b strings.Builder
var title string
if isVeryCompact {
title = "Edit Subdomain"
} else {
title = "🔧 Edit Subdomain"
}
b.WriteString(titleStyle.Render(title))
b.WriteString("\n\n")
if m.tunnelType != types.HTTP {
warningBoxWidth := getResponsiveWidth(m.width, 10, 30, 60)
warningBoxStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#FFA500")).
Background(lipgloss.Color("#3D2000")).
Bold(true).
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FFA500")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(boxMargin).
Width(warningBoxWidth)
var warningText string
if isVeryCompact {
warningText = "⚠️ TCP tunnels don't support custom subdomains."
} else {
warningText = "⚠️ TCP tunnels cannot have custom subdomains. Only HTTP/HTTPS tunnels support subdomain customization."
}
b.WriteString(warningBoxStyle.Render(warningText))
b.WriteString("\n\n")
var helpText string
if isVeryCompact {
helpText = "Press any key to go back"
} else {
helpText = "Press Enter or Esc to go back"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
var rulesContent string
if isVeryCompact {
rulesContent = "Rules:\n3-20 chars\na-z, 0-9, -\nNo leading/trailing -"
} else if isCompact {
rulesContent = "📋 Rules:\n • 3-20 chars\n • a-z, 0-9, -\n • No leading/trailing -"
} else {
rulesContent = "📋 Rules: \n\t• 3-20 chars \n\t• a-z, 0-9, - \n\t• No leading/trailing -"
}
b.WriteString(rulesBoxStyle.Render(rulesContent))
b.WriteString("\n")
var instruction string
if isVeryCompact {
instruction = "Custom subdomain:"
} else {
instruction = "Enter your custom subdomain:"
}
b.WriteString(instructionStyle.Render(instruction))
b.WriteString("\n")
if m.slugError != "" {
errorInputBoxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("#FF0000")).
Padding(1, boxPadding).
MarginTop(boxMargin).
MarginBottom(1)
b.WriteString(errorInputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
b.WriteString(errorBoxStyle.Render("❌ " + m.slugError))
b.WriteString("\n")
} else {
b.WriteString(inputBoxStyle.Render(m.slugInput.View()))
b.WriteString("\n")
}
previewURL := buildURL(m.protocol, m.slugInput.Value(), m.domain)
previewWidth := getResponsiveWidth(m.width, 10, 30, 80)
if len(previewURL) > previewWidth-10 {
previewURL = truncateString(previewURL, previewWidth-10)
}
previewStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("#04B575")).
Italic(true).
Width(previewWidth)
b.WriteString(previewStyle.Render(fmt.Sprintf("Preview: %s", previewURL)))
b.WriteString("\n")
var helpText string
if isVeryCompact {
helpText = "Enter: save • CTRL+R: random • Esc: cancel"
} else {
helpText = "Press Enter to save • CTRL+R for random • Esc to cancel"
}
b.WriteString(helpStyle.Render(helpText))
return b.String()
}
+11 -14
View File
@@ -28,47 +28,44 @@ type lifecycle struct {
conn ssh.Conn
channel ssh.Channel
forwarder Forwarder
sessionRegistry SessionRegistry
slug slug.Slug
startedAt time.Time
sessionRegistry SessionRegistry
portRegistry portUtil.Port
user string
}
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, user string) Lifecycle {
func New(conn ssh.Conn, forwarder Forwarder, slugManager slug.Slug, port portUtil.Port, sessionRegistry SessionRegistry, user string) Lifecycle {
return &lifecycle{
status: types.INITIALIZING,
conn: conn,
channel: nil,
forwarder: forwarder,
slug: slugManager,
sessionRegistry: nil,
startedAt: time.Now(),
sessionRegistry: sessionRegistry,
portRegistry: port,
user: user,
}
}
func (l *lifecycle) SetSessionRegistry(registry SessionRegistry) {
l.sessionRegistry = registry
}
type Lifecycle interface {
Connection() ssh.Conn
Channel() ssh.Channel
PortRegistry() portUtil.Port
User() string
SetChannel(channel ssh.Channel)
SetSessionRegistry(registry SessionRegistry)
SetStatus(status types.Status)
IsActive() bool
StartedAt() time.Time
Close() error
}
func (l *lifecycle) User() string {
return l.user
func (l *lifecycle) PortRegistry() portUtil.Port {
return l.portRegistry
}
func (l *lifecycle) Channel() ssh.Channel {
return l.channel
func (l *lifecycle) User() string {
return l.user
}
func (l *lifecycle) SetChannel(channel ssh.Channel) {
@@ -116,7 +113,7 @@ func (l *lifecycle) Close() error {
l.sessionRegistry.Remove(key)
if tunnelType == types.TCP {
if err := portUtil.Default.SetPortStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
if err := l.PortRegistry().SetStatus(l.forwarder.ForwardedPort(), false); err != nil && firstErr == nil {
firstErr = err
}
}
+213 -223
View File
@@ -12,6 +12,8 @@ import (
"tunnel_pls/internal/config"
portUtil "tunnel_pls/internal/port"
"tunnel_pls/internal/random"
"tunnel_pls/internal/registry"
"tunnel_pls/internal/transport"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
@@ -21,24 +23,16 @@ import (
"golang.org/x/crypto/ssh"
)
type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"`
UserID string `json:"user_id,omitempty"`
Active bool `json:"active,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
}
type Session interface {
HandleGlobalRequest(ch <-chan *ssh.Request)
HandleTCPIPForward(req *ssh.Request)
HandleHTTPForward(req *ssh.Request, port uint16)
HandleTCPForward(req *ssh.Request, addr string, port uint16)
HandleGlobalRequest(ch <-chan *ssh.Request) error
HandleTCPIPForward(req *ssh.Request) error
HandleHTTPForward(req *ssh.Request, port uint16) error
HandleTCPForward(req *ssh.Request, addr string, port uint16) error
Lifecycle() lifecycle.Lifecycle
Interaction() interaction.Interaction
Forwarder() forwarder.Forwarder
Slug() slug.Slug
Detail() *Detail
Detail() *types.Detail
Start() error
}
@@ -49,21 +43,16 @@ type session struct {
interaction interaction.Interaction
forwarder forwarder.Forwarder
slug slug.Slug
registry Registry
registry registry.Registry
}
var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 5432, 5900, 6379, 8080, 8443, 9000, 9200, 27017}
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry Registry, user string) Session {
func New(conn *ssh.ServerConn, initialReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel, sessionRegistry registry.Registry, portRegistry portUtil.Port, user string) Session {
slugManager := slug.New()
forwarderManager := forwarder.New(slugManager)
interactionManager := interaction.New(slugManager, forwarderManager)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, user)
interactionManager.SetLifecycle(lifecycleManager)
forwarderManager.SetLifecycle(lifecycleManager)
interactionManager.SetSessionRegistry(sessionRegistry)
lifecycleManager.SetSessionRegistry(sessionRegistry)
forwarderManager := forwarder.New(slugManager, conn)
lifecycleManager := lifecycle.New(conn, forwarderManager, slugManager, portRegistry, sessionRegistry, user)
interactionManager := interaction.New(slugManager, forwarderManager, sessionRegistry, user, lifecycleManager.Close)
return &session{
initialReq: initialReq,
@@ -92,16 +81,17 @@ func (s *session) Slug() slug.Slug {
return s.slug
}
func (s *session) Detail() *Detail {
var tunnelType string
if s.forwarder.TunnelType() == types.HTTP {
tunnelType = "HTTP"
} else if s.forwarder.TunnelType() == types.TCP {
tunnelType = "TCP"
} else {
func (s *session) Detail() *types.Detail {
tunnelTypeMap := map[types.TunnelType]string{
types.HTTP: "HTTP",
types.TCP: "TCP",
}
tunnelType, ok := tunnelTypeMap[s.forwarder.TunnelType()]
if !ok {
tunnelType = "UNKNOWN"
}
return &Detail{
return &types.Detail{
ForwardingType: tunnelType,
Slug: s.slug.String(),
UserID: s.lifecycle.User(),
@@ -111,55 +101,80 @@ func (s *session) Detail() *Detail {
}
func (s *session) Start() error {
var channel ssh.NewChannel
var ok bool
select {
case channel, ok = <-s.sshChan:
if !ok {
log.Println("Forwarding request channel closed")
return nil
}
ch, reqs, err := channel.Accept()
if err != nil {
log.Printf("failed to accept channel: %v", err)
return err
}
go s.HandleGlobalRequest(reqs)
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
s.interaction.SetMode(types.INTERACTIVE)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.HEADLESS)
if err := s.setupSessionMode(); err != nil {
return err
}
tcpipReq := s.waitForTCPIPForward()
if tcpipReq == nil {
err := s.interaction.Send(fmt.Sprintf("Port forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
if err != nil {
return err
}
if err = s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("no forwarding Request")
return s.handleMissingForwardRequest()
}
if (s.interaction.Mode() == types.HEADLESS && config.Getenv("MODE", "standalone") == "standalone") && s.lifecycle.User() == "UNAUTHORIZED" {
if err := tcpipReq.Reply(false, nil); err != nil {
log.Printf("cannot reply to tcpip req: %s\n", err)
return err
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
return err
}
return nil
if s.shouldRejectUnauthorized() {
return s.denyForwardingRequest(tcpipReq, nil, nil, fmt.Sprintf("headless forwarding only allowed on node mode"))
}
s.HandleTCPIPForward(tcpipReq)
if err := s.HandleTCPIPForward(tcpipReq); err != nil {
return err
}
s.interaction.Start()
return s.waitForSessionEnd()
}
func (s *session) setupSessionMode() error {
select {
case channel, ok := <-s.sshChan:
if !ok {
log.Println("Forwarding request channel closed")
return nil
}
return s.setupInteractiveMode(channel)
case <-time.After(500 * time.Millisecond):
s.interaction.SetMode(types.HEADLESS)
return nil
}
}
func (s *session) setupInteractiveMode(channel ssh.NewChannel) error {
ch, reqs, err := channel.Accept()
if err != nil {
log.Printf("failed to accept channel: %v", err)
return err
}
go func() {
err = s.HandleGlobalRequest(reqs)
if err != nil {
log.Printf("global request handler error: %v", err)
}
}()
s.lifecycle.SetChannel(ch)
s.interaction.SetChannel(ch)
s.interaction.SetMode(types.INTERACTIVE)
return nil
}
func (s *session) handleMissingForwardRequest() error {
err := s.interaction.Send(fmt.Sprintf("PortRegistry forwarding request not received. Ensure you ran the correct command with -R flag. Example: ssh %s -p %s -R 80:localhost:3000", config.Getenv("DOMAIN", "localhost"), config.Getenv("PORT", "2200")))
if err != nil {
return err
}
if err = s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return fmt.Errorf("no forwarding Request")
}
func (s *session) shouldRejectUnauthorized() bool {
return s.interaction.Mode() == types.HEADLESS &&
config.Getenv("MODE", "standalone") == "standalone" &&
s.lifecycle.User() == "UNAUTHORIZED"
}
func (s *session) waitForSessionEnd() error {
if err := s.lifecycle.Connection().Wait(); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("ssh connection closed with error: %v", err)
}
@@ -192,216 +207,191 @@ func (s *session) waitForTCPIPForward() *ssh.Request {
}
}
func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
func (s *session) handleWindowChange(req *ssh.Request) error {
p := req.Payload
if len(p) < 16 {
log.Println("invalid window-change payload")
return req.Reply(false, nil)
}
cols := binary.BigEndian.Uint32(p[0:4])
rows := binary.BigEndian.Uint32(p[4:8])
s.interaction.SetWH(int(cols), int(rows))
return req.Reply(true, nil)
}
func (s *session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) error {
for req := range GlobalRequest {
switch req.Type {
case "shell", "pty-req":
err := req.Reply(true, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
return err
}
case "window-change":
p := req.Payload
if len(p) < 16 {
log.Println("invalid window-change payload")
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
}
return
}
cols := binary.BigEndian.Uint32(p[0:4])
rows := binary.BigEndian.Uint32(p[4:8])
s.interaction.SetWH(int(cols), int(rows))
err := req.Reply(true, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
if err := s.handleWindowChange(req); err != nil {
return err
}
default:
log.Println("Unknown request type:", req.Type)
err := req.Reply(false, nil)
if err != nil {
log.Println("Failed to reply to request:", err)
return
return err
}
}
}
return nil
}
func (s *session) HandleTCPIPForward(req *ssh.Request) {
log.Println("Port forwarding request detected")
fail := func(msg string) {
log.Println(msg)
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
return
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
}
reader := bytes.NewReader(req.Payload)
addr, err := readSSHString(reader)
func (s *session) parseForwardPayload(payloadReader io.Reader) (address string, port uint16, err error) {
address, err = readSSHString(payloadReader)
if err != nil {
fail(fmt.Sprintf("Failed to read address from payload: %v", err))
return
return "", 0, err
}
var rawPortToBind uint32
if err = binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
fail(fmt.Sprintf("Failed to read port from payload: %v", err))
return
if err = binary.Read(payloadReader, binary.BigEndian, &rawPortToBind); err != nil {
return "", 0, err
}
if rawPortToBind > 65535 {
fail(fmt.Sprintf("Port %d is larger than allowed port of 65535", rawPortToBind))
return
return "", 0, fmt.Errorf("port is larger than allowed port of 65535")
}
portToBind := uint16(rawPortToBind)
if isBlockedPort(portToBind) {
fail(fmt.Sprintf("Port %d is blocked or restricted", portToBind))
return
port = uint16(rawPortToBind)
if isBlockedPort(port) {
return "", 0, fmt.Errorf("port is block")
}
switch portToBind {
case 80, 443:
s.HandleHTTPForward(req, portToBind)
default:
s.HandleTCPForward(req, addr, portToBind)
if port == 0 {
unassigned, ok := s.lifecycle.PortRegistry().Unassigned()
if !ok {
return "", 0, fmt.Errorf("no available port")
}
return address, unassigned, err
}
return address, port, err
}
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
fail := func(msg string, key *types.SessionKey) {
log.Println(msg)
if key != nil {
s.registry.Remove(*key)
}
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
func (s *session) denyForwardingRequest(req *ssh.Request, key *types.SessionKey, listener io.Closer, msg string) error {
var errs []error
if key != nil {
s.registry.Remove(*key)
}
if listener != nil {
if err := listener.Close(); err != nil {
errs = append(errs, fmt.Errorf("close listener: %w", err))
}
}
randomString := random.GenerateRandomString(20)
key := types.SessionKey{Id: randomString, Type: types.HTTP}
if !s.registry.Register(key, s) {
fail(fmt.Sprintf("Failed to register client with slug: %s", randomString), nil)
return
if err := req.Reply(false, nil); err != nil {
errs = append(errs, fmt.Errorf("reply request: %w", err))
}
if err := s.lifecycle.Close(); err != nil {
errs = append(errs, fmt.Errorf("close session: %w", err))
}
errs = append(errs, fmt.Errorf("deny forwarding request: %s", msg))
return errors.Join(errs...)
}
func (s *session) approveForwardingRequest(req *ssh.Request, port uint16) (err error) {
buf := new(bytes.Buffer)
err := binary.Write(buf, binary.BigEndian, uint32(portToBind))
err = binary.Write(buf, binary.BigEndian, uint32(port))
if err != nil {
fail(fmt.Sprintf("Failed to write port to buffer: %v", err), &key)
return
return err
}
log.Printf("HTTP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes())
if err != nil {
fail(fmt.Sprintf("Failed to reply to request: %v", err), &key)
return
return err
}
s.forwarder.SetType(types.HTTP)
s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(randomString)
s.lifecycle.SetStatus(types.RUNNING)
return nil
}
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
fail := func(msg string) {
log.Println(msg)
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
return
}
if err := s.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
}
cleanup := func(msg string, port uint16, listener net.Listener, key *types.SessionKey) {
log.Println(msg)
if key != nil {
s.registry.Remove(*key)
}
if port != 0 {
if setErr := portUtil.Default.SetPortStatus(port, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr)
}
}
if listener != nil {
if closeErr := listener.Close(); closeErr != nil {
log.Printf("Failed to close listener: %v", closeErr)
}
}
if err := req.Reply(false, nil); err != nil {
log.Println("Failed to reply to request:", err)
}
_ = s.lifecycle.Close()
}
if portToBind == 0 {
unassigned, ok := portUtil.Default.GetUnassignedPort()
if !ok {
fail("No available port")
return
}
portToBind = unassigned
}
if claimed := portUtil.Default.ClaimPort(portToBind); !claimed {
fail(fmt.Sprintf("Port %d is already in use or restricted", portToBind))
return
}
log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
func (s *session) finalizeForwarding(req *ssh.Request, portToBind uint16, listener net.Listener, tunnelType types.TunnelType, slug string) error {
err := s.approveForwardingRequest(req, portToBind)
if err != nil {
cleanup(fmt.Sprintf("Port %d is already in use or restricted", portToBind), portToBind, nil, nil)
return
return err
}
s.forwarder.SetType(tunnelType)
s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(slug)
s.lifecycle.SetStatus(types.RUNNING)
if listener != nil {
s.forwarder.SetListener(listener)
}
return nil
}
func (s *session) HandleTCPIPForward(req *ssh.Request) error {
reader := bytes.NewReader(req.Payload)
address, port, err := s.parseForwardPayload(reader)
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("cannot parse forwarded payload: %s", err.Error()))
}
switch port {
case 80, 443:
return s.HandleHTTPForward(req, port)
default:
return s.HandleTCPForward(req, address, port)
}
}
func (s *session) HandleHTTPForward(req *ssh.Request, portToBind uint16) error {
randomString, err := random.GenerateRandomString(20)
if err != nil {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to create slug: %s", err))
}
key := types.SessionKey{Id: randomString, Type: types.HTTP}
if !s.registry.Register(key, s) {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("Failed to register client with slug: %s", randomString))
}
err = s.finalizeForwarding(req, portToBind, nil, types.HTTP, key.Id)
if err != nil {
return s.denyForwardingRequest(req, &key, nil, fmt.Sprintf("Failed to finalize forwarding: %s", err))
}
return nil
}
func (s *session) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) error {
if claimed := s.lifecycle.PortRegistry().Claim(portToBind); !claimed {
return s.denyForwardingRequest(req, nil, nil, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
}
tcpServer := transport.NewTCPServer(portToBind, s.forwarder)
listener, err := tcpServer.Listen()
if err != nil {
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("PortRegistry %d is already in use or restricted", portToBind))
}
key := types.SessionKey{Id: fmt.Sprintf("%d", portToBind), Type: types.TCP}
if !s.registry.Register(key, s) {
cleanup(fmt.Sprintf("Failed to register TCP client with id: %s", key.Id), portToBind, listener, nil)
return
return s.denyForwardingRequest(req, nil, listener, fmt.Sprintf("Failed to register TCP client with id: %s", key.Id))
}
buf := new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, uint32(portToBind))
err = s.finalizeForwarding(req, portToBind, listener, types.TCP, key.Id)
if err != nil {
cleanup(fmt.Sprintf("Failed to write port to buffer: %v", err), portToBind, listener, &key)
return
return s.denyForwardingRequest(req, &key, listener, fmt.Sprintf("Failed to finalize forwarding: %s", err))
}
log.Printf("TCP forwarding approved on port: %d", portToBind)
err = req.Reply(true, buf.Bytes())
if err != nil {
cleanup(fmt.Sprintf("Failed to reply to request: %v", err), portToBind, listener, &key)
return
}
go func() {
err = tcpServer.Serve(listener)
if err != nil {
log.Printf("Failed serving tcp server: %s\n", err)
}
}()
s.forwarder.SetType(types.TCP)
s.forwarder.SetListener(listener)
s.forwarder.SetForwardedPort(portToBind)
s.slug.Set(key.Id)
s.lifecycle.SetStatus(types.RUNNING)
go s.forwarder.AcceptTCPConnections()
return nil
}
func readSSHString(reader *bytes.Reader) (string, error) {
func readSSHString(reader io.Reader) (string, error) {
var length uint32
if err := binary.Read(reader, binary.BigEndian, &length); err != nil {
return "", err
+10
View File
@@ -1,5 +1,7 @@
package types
import "time"
type Status int
const (
@@ -27,6 +29,14 @@ type SessionKey struct {
Type TunnelType
}
type Detail struct {
ForwardingType string `json:"forwarding_type,omitempty"`
Slug string `json:"slug,omitempty"`
UserID string `json:"user_id,omitempty"`
Active bool `json:"active,omitempty"`
StartedAt time.Time `json:"started_at,omitempty"`
}
var BadGatewayResponse = []byte("HTTP/1.1 502 Bad Gateway\r\n" +
"Content-Length: 11\r\n" +
"Content-Type: text/plain\r\n\r\n" +