main #52

Merged
bagas merged 3 commits from main into staging 2025-12-29 14:58:10 +00:00
21 changed files with 573 additions and 227 deletions
Showing only changes of commit 08565d845f - Show all commits

View File

@@ -5,6 +5,14 @@ on:
branches: branches:
- main - main
- staging - staging
paths:
- '**.go'
- 'go.mod'
- 'go.sum'
- 'Dockerfile'
- 'Dockerfile.*'
- '.dockerignore'
- '.gitea/workflows/build.yml'
jobs: jobs:
build-and-push: build-and-push:

View File

@@ -17,4 +17,5 @@ jobs:
env: env:
RENOVATE_CONFIG_FILE: ${{ gitea.workspace }}/renovate-config.js RENOVATE_CONFIG_FILE: ${{ gitea.workspace }}/renovate-config.js
LOG_LEVEL: "debug" LOG_LEVEL: "debug"
RENOVATE_TOKEN: ${{ secrets.RENOVATE_TOKEN }} RENOVATE_TOKEN: ${{ secrets.RENOVATE_TOKEN }}
GITHUB_COM_TOKEN: ${{ secrets.COM_TOKEN }}

View File

@@ -16,7 +16,7 @@ COPY . .
RUN --mount=type=cache,target=/go/pkg/mod \ RUN --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 \ CGO_ENABLED=0 GOOS=linux \
go build -trimpath \ go build -trimpath \
-ldflags="-w -s" \ -ldflags="-w -s" \
-o /app/tunnel_pls \ -o /app/tunnel_pls \

105
README.md
View File

@@ -6,7 +6,6 @@ A lightweight SSH-based tunnel server written in Go that enables secure TCP and
- SSH interactive session with real-time command handling - SSH interactive session with real-time command handling
- Custom subdomain management for HTTP tunnels - Custom subdomain management for HTTP tunnels
- Active connection control with drop functionality
- Dual protocol support: HTTP and TCP tunnels - Dual protocol support: HTTP and TCP tunnels
- Real-time connection monitoring - Real-time connection monitoring
## Requirements ## Requirements
@@ -116,6 +115,110 @@ go tool pprof http://localhost:6060/debug/pprof/profile?seconds=30
go tool pprof http://localhost:6060/debug/pprof/heap go tool pprof http://localhost:6060/debug/pprof/heap
``` ```
## Docker Deployment
Three Docker Compose configurations are available for different deployment scenarios. Each configuration uses the image `git.fossy.my.id/bagas/tunnel-please:latest`.
### Configuration Options
#### 1. Root with Host Networking (RECOMMENDED)
**File:** `docker-compose.root.yml`
**Advantages:**
- Full TCP port forwarding support (ports 40000-41000)
- Direct binding to privileged ports (80, 443, 2200)
- Best performance with no NAT overhead
- Maximum flexibility for all tunnel types
- No port mapping limitations
**Use Case:** Production deployments where you need unrestricted TCP forwarding and maximum performance.
**Deploy:**
```bash
docker-compose -f docker-compose.root.yml up -d
```
#### 2. Standard (HTTP/HTTPS Only)
**File:** `docker-compose.standard.yml`
**Advantages:**
- Runs with unprivileged user (more secure)
- Standard port mappings (2200, 80, 443)
- Simple and predictable networking
- TCP port forwarding disabled (`ALLOWED_PORTS=none`)
**Use Case:** Deployments where you only need HTTP/HTTPS tunneling without custom TCP port forwarding.
**Deploy:**
```bash
docker-compose -f docker-compose.standard.yml up -d
```
#### 3. Limited TCP Forwarding
**File:** `docker-compose.tcp.yml`
**Advantages:**
- Runs with unprivileged user (more secure)
- Standard port mappings (2200, 80, 443)
- Limited TCP forwarding (ports 30000-31000)
- Controlled port range exposure
**Use Case:** Deployments where you need both HTTP/HTTPS tunneling and limited TCP forwarding within a specific port range.
**Deploy:**
```bash
docker-compose -f docker-compose.tcp.yml up -d
```
### Quick Start
1. **Choose your configuration** based on your requirements
2. **Edit the environment variables** in the chosen compose file:
- `DOMAIN`: Your domain name (e.g., `example.com`)
- `ACME_EMAIL`: Your email for Let's Encrypt
- `CF_API_TOKEN`: Your Cloudflare API token (if using automatic TLS)
3. **Deploy:**
```bash
docker-compose -f docker-compose.root.yml up -d
```
4. **Check logs:**
```bash
docker-compose -f docker-compose.root.yml logs -f
```
5. **Stop the service:**
```bash
docker-compose -f docker-compose.root.yml down
```
### Volume Management
All configurations use a named volume `certs` for persistent storage:
- SSH keys: `/app/certs/ssh/`
- TLS certificates: `/app/certs/tls/`
To backup certificates:
```bash
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar czf /backup/certs-backup.tar.gz -C /data .
```
To restore certificates:
```bash
docker run --rm -v tunnel_pls_certs:/data -v $(pwd):/backup alpine tar xzf /backup/certs-backup.tar.gz -C /data
```
### Recommendation
**Use `docker-compose.root.yml`** for production deployments if you need:
- Full TCP port forwarding capabilities
- Any port range configuration
- Direct port binding without mapping overhead
- Maximum performance and flexibility
This is the recommended configuration for most use cases as it provides the complete feature set without limitations.
## Contributing ## Contributing
Contributions are welcome! Contributions are welcome!

37
docker-compose.root.yml Normal file
View File

@@ -0,0 +1,37 @@
version: '3.8'
services:
tunnel-please:
image: git.fossy.my.id/bagas/tunnel-please:latest
container_name: tunnel-please-root
user: root
network_mode: host
restart: unless-stopped
volumes:
- certs:/app/certs
environment:
DOMAIN: example.com
PORT: 2200
HTTP_PORT: 8080
HTTPS_PORT: 8443
TLS_ENABLED: "true"
TLS_REDIRECT: "true"
ACME_EMAIL: admin@example.com
CF_API_TOKEN: your_cloudflare_api_token_here
ACME_STAGING: "false"
CORS_LIST: http://localhost:3000,https://example.com
ALLOWED_PORTS: 40000-41000
BUFFER_SIZE: 32768
PPROF_ENABLED: "false"
PPROF_PORT: 6060
healthcheck:
test: ["CMD", "/bin/sh", "-c", "netstat -tln | grep -q :2200"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
volumes:
certs:
driver: local

View File

@@ -0,0 +1,39 @@
version: '3.8'
services:
tunnel-please:
image: git.fossy.my.id/bagas/tunnel-please:latest
container_name: tunnel-please-standard
restart: unless-stopped
ports:
- "2200:2200"
- "80:8080"
- "443:8443"
volumes:
- certs:/app/certs
environment:
DOMAIN: example.com
PORT: 2200
HTTP_PORT: 8080
HTTPS_PORT: 8443
TLS_ENABLED: "true"
TLS_REDIRECT: "true"
ACME_EMAIL: admin@example.com
CF_API_TOKEN: your_cloudflare_api_token_here
ACME_STAGING: "false"
CORS_LIST: http://localhost:3000,https://example.com
ALLOWED_PORTS: none
BUFFER_SIZE: 32768
PPROF_ENABLED: "false"
PPROF_PORT: 6060
healthcheck:
test: ["CMD", "/bin/sh", "-c", "netstat -tln | grep -q :2200"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
volumes:
certs:
driver: local

40
docker-compose.tcp.yml Normal file
View File

@@ -0,0 +1,40 @@
version: '3.8'
services:
tunnel-please:
image: git.fossy.my.id/bagas/tunnel-please:latest
container_name: tunnel-please-tcp
restart: unless-stopped
ports:
- "2200:2200"
- "80:8080"
- "443:8443"
- "30000-31000:30000-31000"
volumes:
- certs:/app/certs
environment:
DOMAIN: example.com
PORT: 2200
HTTP_PORT: 8080
HTTPS_PORT: 8443
TLS_ENABLED: "true"
TLS_REDIRECT: "true"
ACME_EMAIL: admin@example.com
CF_API_TOKEN: your_cloudflare_api_token_here
ACME_STAGING: "false"
CORS_LIST: http://localhost:3000,https://example.com
ALLOWED_PORTS: 30000-31000
BUFFER_SIZE: 32768
PPROF_ENABLED: "false"
PPROF_PORT: 6060
healthcheck:
test: ["CMD", "/bin/sh", "-c", "netstat -tln | grep -q :2200"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
volumes:
certs:
driver: local

View File

@@ -1,6 +1,6 @@
module.exports = { module.exports = {
"endpoint": "https://git.fossy.my.id/api/v1", "endpoint": "https://git.fossy.my.id/api/v1",
"gitAuthor": "Renovate Bot <renovate-bot@fossy.my.id>", "gitAuthor": "Renovate-Clanker <renovate-bot@fossy.my.id>",
"platform": "gitea", "platform": "gitea",
"onboardingConfigFileName": "renovate.json", "onboardingConfigFileName": "renovate.json",
"autodiscover": true, "autodiscover": true,

View File

@@ -10,7 +10,10 @@
"pin", "pin",
"digest" "digest"
], ],
"automerge": true "automerge": true,
"baseBranches": [
"staging"
]
} }
] ]
} }

View File

@@ -9,7 +9,7 @@ import (
) )
func (s *Server) handleConnection(conn net.Conn) { func (s *Server) handleConnection(conn net.Conn) {
sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.Config) sshConn, chans, forwardingReqs, err := ssh.NewServerConn(conn, s.config)
if err != nil { if err != nil {
log.Printf("failed to establish SSH connection: %v", err) log.Printf("failed to establish SSH connection: %v", err)
err := conn.Close() err := conn.Close()

View File

@@ -14,21 +14,38 @@ type HeaderManager interface {
Finalize() []byte Finalize() []byte
} }
type ResponseHeaderFactory struct { type ResponseHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type RequestHeaderManager interface {
Get(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
}
type responseHeaderFactory struct {
startLine []byte startLine []byte
headers map[string]string headers map[string]string
} }
type RequestHeaderFactory struct { type requestHeaderFactory struct {
Method string method string
Path string path string
Version string version string
startLine []byte startLine []byte
headers map[string]string headers map[string]string
} }
func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) { func NewRequestHeaderFactory(br *bufio.Reader) (RequestHeaderManager, error) {
header := &RequestHeaderFactory{ header := &requestHeaderFactory{
headers: make(map[string]string), headers: make(map[string]string),
} }
@@ -44,9 +61,9 @@ func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) {
return nil, fmt.Errorf("invalid request line") return nil, fmt.Errorf("invalid request line")
} }
header.Method = parts[0] header.method = parts[0]
header.Path = parts[1] header.path = parts[1]
header.Version = parts[2] header.version = parts[2]
for { for {
line, err := br.ReadString('\n') line, err := br.ReadString('\n')
@@ -69,8 +86,8 @@ func NewRequestHeaderFactory(br *bufio.Reader) (*RequestHeaderFactory, error) {
return header, nil return header, nil
} }
func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory { func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager {
header := &ResponseHeaderFactory{ header := &responseHeaderFactory{
startLine: nil, startLine: nil,
headers: make(map[string]string), headers: make(map[string]string),
} }
@@ -96,19 +113,19 @@ func NewResponseHeaderFactory(startLine []byte) *ResponseHeaderFactory {
return header return header
} }
func (resp *ResponseHeaderFactory) Get(key string) string { func (resp *responseHeaderFactory) Get(key string) string {
return resp.headers[key] return resp.headers[key]
} }
func (resp *ResponseHeaderFactory) Set(key string, value string) { func (resp *responseHeaderFactory) Set(key string, value string) {
resp.headers[key] = value resp.headers[key] = value
} }
func (resp *ResponseHeaderFactory) Remove(key string) { func (resp *responseHeaderFactory) Remove(key string) {
delete(resp.headers, key) delete(resp.headers, key)
} }
func (resp *ResponseHeaderFactory) Finalize() []byte { func (resp *responseHeaderFactory) Finalize() []byte {
var buf bytes.Buffer var buf bytes.Buffer
buf.Write(resp.startLine) buf.Write(resp.startLine)
@@ -125,7 +142,7 @@ func (resp *ResponseHeaderFactory) Finalize() []byte {
return buf.Bytes() return buf.Bytes()
} }
func (req *RequestHeaderFactory) Get(key string) string { func (req *requestHeaderFactory) Get(key string) string {
val, ok := req.headers[key] val, ok := req.headers[key]
if !ok { if !ok {
return "" return ""
@@ -133,15 +150,27 @@ func (req *RequestHeaderFactory) Get(key string) string {
return val return val
} }
func (req *RequestHeaderFactory) Set(key string, value string) { func (req *requestHeaderFactory) Set(key string, value string) {
req.headers[key] = value req.headers[key] = value
} }
func (req *RequestHeaderFactory) Remove(key string) { func (req *requestHeaderFactory) Remove(key string) {
delete(req.headers, key) delete(req.headers, key)
} }
func (req *RequestHeaderFactory) Finalize() []byte { func (req *requestHeaderFactory) GetMethod() string {
return req.method
}
func (req *requestHeaderFactory) GetPath() string {
return req.path
}
func (req *requestHeaderFactory) GetVersion() string {
return req.version
}
func (req *requestHeaderFactory) Finalize() []byte {
var buf bytes.Buffer var buf bytes.Buffer
buf.Write(req.startLine) buf.Write(req.startLine)

View File

@@ -20,25 +20,63 @@ import (
type Interaction interface { type Interaction interface {
SendMessage(message string) SendMessage(message string)
} }
type CustomWriter struct {
RemoteAddr net.Addr type HTTPWriter interface {
io.Reader
io.Writer
SetInteraction(interaction Interaction)
AddInteraction(interaction Interaction)
GetRemoteAddr() net.Addr
GetWriter() io.Writer
AddResponseMiddleware(mw ResponseMiddleware)
AddRequestStartMiddleware(mw RequestMiddleware)
SetRequestHeader(header RequestHeaderManager)
GetRequestStartMiddleware() []RequestMiddleware
}
type customWriter struct {
remoteAddr net.Addr
writer io.Writer writer io.Writer
reader io.Reader reader io.Reader
headerBuf []byte headerBuf []byte
buf []byte buf []byte
respHeader *ResponseHeaderFactory respHeader ResponseHeaderManager
reqHeader *RequestHeaderFactory reqHeader RequestHeaderManager
interaction Interaction interaction Interaction
respMW []ResponseMiddleware respMW []ResponseMiddleware
reqStartMW []RequestMiddleware reqStartMW []RequestMiddleware
reqEndMW []RequestMiddleware reqEndMW []RequestMiddleware
} }
func (cw *CustomWriter) SetInteraction(interaction Interaction) { func (cw *customWriter) SetInteraction(interaction Interaction) {
cw.interaction = interaction cw.interaction = interaction
} }
func (cw *CustomWriter) Read(p []byte) (int, error) { func (cw *customWriter) GetRemoteAddr() net.Addr {
return cw.remoteAddr
}
func (cw *customWriter) GetWriter() io.Writer {
return cw.writer
}
func (cw *customWriter) AddResponseMiddleware(mw ResponseMiddleware) {
cw.respMW = append(cw.respMW, mw)
}
func (cw *customWriter) AddRequestStartMiddleware(mw RequestMiddleware) {
cw.reqStartMW = append(cw.reqStartMW, mw)
}
func (cw *customWriter) SetRequestHeader(header RequestHeaderManager) {
cw.reqHeader = header
}
func (cw *customWriter) GetRequestStartMiddleware() []RequestMiddleware {
return cw.reqStartMW
}
func (cw *customWriter) Read(p []byte) (int, error) {
tmp := make([]byte, len(p)) tmp := make([]byte, len(p))
read, err := cw.reader.Read(tmp) read, err := cw.reader.Read(tmp)
if read == 0 && err != nil { if read == 0 && err != nil {
@@ -95,9 +133,9 @@ func (cw *CustomWriter) Read(p []byte) (int, error) {
return n, nil return n, nil
} }
func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *CustomWriter { func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
return &CustomWriter{ return &customWriter{
RemoteAddr: remoteAddr, remoteAddr: remoteAddr,
writer: writer, writer: writer,
reader: reader, reader: reader,
buf: make([]byte, 0, 4096), buf: make([]byte, 0, 4096),
@@ -129,7 +167,7 @@ func isHTTPHeader(buf []byte) bool {
return true return true
} }
func (cw *CustomWriter) Write(p []byte) (int, error) { func (cw *customWriter) Write(p []byte) (int, error) {
if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" { if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" {
cw.respHeader = nil cw.respHeader = nil
} }
@@ -186,7 +224,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) {
return len(p), nil return len(p), nil
} }
func (cw *CustomWriter) AddInteraction(interaction Interaction) { func (cw *customWriter) AddInteraction(interaction Interaction) {
cw.interaction = interaction cw.interaction = interaction
} }
@@ -292,13 +330,13 @@ func Handler(conn net.Conn) {
return return
} }
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
cw.SetInteraction(sshSession.Interaction) cw.SetInteraction(sshSession.GetInteraction())
forwardRequest(cw, reqhf, sshSession) forwardRequest(cw, reqhf, sshSession)
return return
} }
func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { func forwardRequest(cw HTTPWriter, initialRequest RequestHeaderManager, sshSession *session.SSHSession) {
payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr) payload := sshSession.GetForwarder().CreateForwardedTCPIPPayload(cw.GetRemoteAddr())
type channelResult struct { type channelResult struct {
channel ssh.Channel channel ssh.Channel
@@ -308,7 +346,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
resultChan := make(chan channelResult, 1) resultChan := make(chan channelResult, 1)
go func() { go func() {
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := sshSession.GetLifecycle().GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err} resultChan <- channelResult{channel, reqs, err}
}() }()
@@ -319,29 +357,28 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
case result := <-resultChan: case result := <-resultChan:
if result.err != nil { if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err) log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.Forwarder.WriteBadGatewayResponse(cw.writer) sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
return return
} }
channel = result.channel channel = result.channel
reqs = result.reqs reqs = result.reqs
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel") log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.Forwarder.WriteBadGatewayResponse(cw.writer) sshSession.GetForwarder().WriteBadGatewayResponse(cw.GetWriter())
return return
} }
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
fingerprintMiddleware := NewTunnelFingerprint() fingerprintMiddleware := NewTunnelFingerprint()
forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr) forwardedForMiddleware := NewForwardedFor(cw.GetRemoteAddr())
cw.respMW = append(cw.respMW, fingerprintMiddleware) cw.AddResponseMiddleware(fingerprintMiddleware)
cw.reqStartMW = append(cw.reqStartMW, forwardedForMiddleware) cw.AddRequestStartMiddleware(forwardedForMiddleware)
cw.reqEndMW = nil cw.SetRequestHeader(initialRequest)
cw.reqHeader = initialRequest
for _, m := range cw.reqStartMW { for _, m := range cw.GetRequestStartMiddleware() {
if err := m.HandleRequest(cw.reqHeader); err != nil { if err := m.HandleRequest(initialRequest); err != nil {
log.Printf("Error handling request: %v", err) log.Printf("Error handling request: %v", err)
return return
} }
@@ -353,6 +390,6 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
return return
} }
sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr) sshSession.GetForwarder().HandleConnection(cw, channel, cw.GetRemoteAddr())
return return
} }

View File

@@ -104,7 +104,7 @@ func HandlerTLS(conn net.Conn) {
return return
} }
cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr()) cw := NewCustomWriter(conn, dstReader, conn.RemoteAddr())
cw.SetInteraction(sshSession.Interaction) cw.SetInteraction(sshSession.GetInteraction())
forwardRequest(cw, reqhf, sshSession) forwardRequest(cw, reqhf, sshSession)
return return
} }

View File

@@ -5,11 +5,11 @@ import (
) )
type RequestMiddleware interface { type RequestMiddleware interface {
HandleRequest(header *RequestHeaderFactory) error HandleRequest(header RequestHeaderManager) error
} }
type ResponseMiddleware interface { type ResponseMiddleware interface {
HandleResponse(header *ResponseHeaderFactory, body []byte) error HandleResponse(header ResponseHeaderManager, body []byte) error
} }
type TunnelFingerprint struct{} type TunnelFingerprint struct{}
@@ -18,16 +18,11 @@ func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{} return &TunnelFingerprint{}
} }
func (h *TunnelFingerprint) HandleResponse(header *ResponseHeaderFactory, body []byte) error { func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error {
header.Set("Server", "Tunnel Please") header.Set("Server", "Tunnel Please")
return nil return nil
} }
type RequestLogger struct {
interaction Interaction
remoteAddr net.Addr
}
type ForwardedFor struct { type ForwardedFor struct {
addr net.Addr addr net.Addr
} }
@@ -36,7 +31,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr} return &ForwardedFor{addr: addr}
} }
func (ff *ForwardedFor) HandleRequest(header *RequestHeaderFactory) error { func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error {
host, _, err := net.SplitHostPort(ff.addr.String()) host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil { if err != nil {
return err return err

View File

@@ -11,9 +11,21 @@ import (
) )
type Server struct { type Server struct {
Conn *net.Listener conn *net.Listener
Config *ssh.ServerConfig config *ssh.ServerConfig
HttpServer *http.Server httpServer *http.Server
}
func (s *Server) GetConn() *net.Listener {
return s.conn
}
func (s *Server) GetConfig() *ssh.ServerConfig {
return s.config
}
func (s *Server) GetHttpServer() *http.Server {
return s.httpServer
} }
func NewServer(config *ssh.ServerConfig) *Server { func NewServer(config *ssh.ServerConfig) *Server {
@@ -33,15 +45,15 @@ func NewServer(config *ssh.ServerConfig) *Server {
log.Fatalf("failed to start http server: %v", err) log.Fatalf("failed to start http server: %v", err)
} }
return &Server{ return &Server{
Conn: &listener, conn: &listener,
Config: config, config: config,
} }
} }
func (s *Server) Start() { func (s *Server) Start() {
log.Println("SSH server is starting on port 2200...") log.Println("SSH server is starting on port 2200...")
for { for {
conn, err := (*s.Conn).Accept() conn, err := (*s.conn).Accept()
if err != nil { if err != nil {
log.Printf("failed to accept connection: %v", err) log.Printf("failed to accept connection: %v", err)
continue continue

View File

@@ -16,7 +16,16 @@ import (
"github.com/libdns/cloudflare" "github.com/libdns/cloudflare"
) )
type TLSManager struct { type TLSManager interface {
userCertsExistAndValid() bool
loadUserCerts() error
startCertWatcher()
initCertMagic() error
getTLSConfig() *tls.Config
getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error)
}
type tlsManager struct {
domain string domain string
certPath string certPath string
keyPath string keyPath string
@@ -30,7 +39,7 @@ type TLSManager struct {
useCertMagic bool useCertMagic bool
} }
var tlsManager *TLSManager var globalTLSManager TLSManager
var tlsManagerOnce sync.Once var tlsManagerOnce sync.Once
func NewTLSConfig(domain string) (*tls.Config, error) { func NewTLSConfig(domain string) (*tls.Config, error) {
@@ -41,7 +50,7 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
keyPath := "certs/tls/privkey.pem" keyPath := "certs/tls/privkey.pem"
storagePath := "certs/tls/certmagic" storagePath := "certs/tls/certmagic"
tm := &TLSManager{ tm := &tlsManager{
domain: domain, domain: domain,
certPath: certPath, certPath: certPath,
keyPath: keyPath, keyPath: keyPath,
@@ -72,14 +81,14 @@ func NewTLSConfig(domain string) (*tls.Config, error) {
tm.useCertMagic = true tm.useCertMagic = true
} }
tlsManager = tm globalTLSManager = tm
}) })
if initErr != nil { if initErr != nil {
return nil, initErr return nil, initErr
} }
return tlsManager.getTLSConfig(), nil return globalTLSManager.getTLSConfig(), nil
} }
func isACMEConfigComplete() bool { func isACMEConfigComplete() bool {
@@ -87,7 +96,7 @@ func isACMEConfigComplete() bool {
return cfAPIToken != "" return cfAPIToken != ""
} }
func (tm *TLSManager) userCertsExistAndValid() bool { func (tm *tlsManager) userCertsExistAndValid() bool {
if _, err := os.Stat(tm.certPath); os.IsNotExist(err) { if _, err := os.Stat(tm.certPath); os.IsNotExist(err) {
log.Printf("Certificate file not found: %s", tm.certPath) log.Printf("Certificate file not found: %s", tm.certPath)
return false return false
@@ -158,7 +167,7 @@ func ValidateCertDomains(certPath, domain string) bool {
return hasBase && hasWildcard return hasBase && hasWildcard
} }
func (tm *TLSManager) loadUserCerts() error { func (tm *tlsManager) loadUserCerts() error {
cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath) cert, err := tls.LoadX509KeyPair(tm.certPath, tm.keyPath)
if err != nil { if err != nil {
return err return err
@@ -172,7 +181,7 @@ func (tm *TLSManager) loadUserCerts() error {
return nil return nil
} }
func (tm *TLSManager) startCertWatcher() { func (tm *tlsManager) startCertWatcher() {
go func() { go func() {
var lastCertMod, lastKeyMod time.Time var lastCertMod, lastKeyMod time.Time
@@ -227,7 +236,7 @@ func (tm *TLSManager) startCertWatcher() {
}() }()
} }
func (tm *TLSManager) initCertMagic() error { func (tm *tlsManager) initCertMagic() error {
if err := os.MkdirAll(tm.storagePath, 0700); err != nil { if err := os.MkdirAll(tm.storagePath, 0700); err != nil {
return fmt.Errorf("failed to create cert storage directory: %w", err) return fmt.Errorf("failed to create cert storage directory: %w", err)
} }
@@ -289,14 +298,14 @@ func (tm *TLSManager) initCertMagic() error {
return nil return nil
} }
func (tm *TLSManager) getTLSConfig() *tls.Config { func (tm *tlsManager) getTLSConfig() *tls.Config {
return &tls.Config{ return &tls.Config{
GetCertificate: tm.getCertificate, GetCertificate: tm.getCertificate,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
} }
} }
func (tm *TLSManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { func (tm *tlsManager) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if tm.useCertMagic { if tm.useCertMagic {
return tm.magic.GetCertificate(hello) return tm.magic.GetCertificate(hello)
} }

View File

@@ -31,11 +31,21 @@ func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
} }
type Forwarder struct { type Forwarder struct {
Listener net.Listener listener net.Listener
TunnelType types.TunnelType tunnelType types.TunnelType
ForwardedPort uint16 forwardedPort uint16
SlugManager slug.Manager slugManager slug.Manager
Lifecycle Lifecycle lifecycle Lifecycle
}
func NewForwarder(slugManager slug.Manager) *Forwarder {
return &Forwarder{
listener: nil,
tunnelType: "",
forwardedPort: 0,
slugManager: slugManager,
lifecycle: nil,
}
} }
type Lifecycle interface { type Lifecycle interface {
@@ -58,7 +68,7 @@ type ForwardingController interface {
} }
func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) { func (f *Forwarder) SetLifecycle(lifecycle Lifecycle) {
f.Lifecycle = lifecycle f.lifecycle = lifecycle
} }
func (f *Forwarder) AcceptTCPConnections() { func (f *Forwarder) AcceptTCPConnections() {
@@ -90,7 +100,7 @@ func (f *Forwarder) AcceptTCPConnections() {
resultChan := make(chan channelResult, 1) resultChan := make(chan channelResult, 1)
go func() { go func() {
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := f.lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
resultChan <- channelResult{channel, reqs, err} resultChan <- channelResult{channel, reqs, err}
}() }()
@@ -164,27 +174,27 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
} }
func (f *Forwarder) SetType(tunnelType types.TunnelType) { func (f *Forwarder) SetType(tunnelType types.TunnelType) {
f.TunnelType = tunnelType f.tunnelType = tunnelType
} }
func (f *Forwarder) GetTunnelType() types.TunnelType { func (f *Forwarder) GetTunnelType() types.TunnelType {
return f.TunnelType return f.tunnelType
} }
func (f *Forwarder) GetForwardedPort() uint16 { func (f *Forwarder) GetForwardedPort() uint16 {
return f.ForwardedPort return f.forwardedPort
} }
func (f *Forwarder) SetForwardedPort(port uint16) { func (f *Forwarder) SetForwardedPort(port uint16) {
f.ForwardedPort = port f.forwardedPort = port
} }
func (f *Forwarder) SetListener(listener net.Listener) { func (f *Forwarder) SetListener(listener net.Listener) {
f.Listener = listener f.listener = listener
} }
func (f *Forwarder) GetListener() net.Listener { func (f *Forwarder) GetListener() net.Listener {
return f.Listener return f.listener
} }
func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) { func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
@@ -197,7 +207,7 @@ func (f *Forwarder) WriteBadGatewayResponse(dst io.Writer) {
func (f *Forwarder) Close() error { func (f *Forwarder) Close() error {
if f.GetListener() != nil { if f.GetListener() != nil {
return f.Listener.Close() return f.listener.Close()
} }
return nil return nil
} }

View File

@@ -49,7 +49,7 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -59,13 +59,13 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
var rawPortToBind uint32 var rawPortToBind uint32
if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil { if err := binary.Read(reader, binary.BigEndian, &rawPortToBind); err != nil {
log.Println("Failed to read port from payload:", err) log.Println("Failed to read port from payload:", err)
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind)) s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02) \r\n", rawPortToBind))
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -73,13 +73,13 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
} }
if rawPortToBind > 65535 { if rawPortToBind > 65535 {
s.Interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind)) s.interaction.SendMessage(fmt.Sprintf("Port %d is larger then allowed port of 65535. (02)\r\n", rawPortToBind))
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -89,13 +89,13 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
portToBind := uint16(rawPortToBind) portToBind := uint16(rawPortToBind)
if isBlockedPort(portToBind) { if isBlockedPort(portToBind) {
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind)) s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (02)\r\n", portToBind))
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -110,26 +110,26 @@ func (s *SSHSession) HandleTCPIPForward(req *ssh.Request) {
unassign, success := portUtil.Default.GetUnassignedPort() unassign, success := portUtil.Default.GetUnassignedPort()
portToBind = unassign portToBind = unassign
if !success { if !success {
s.Interaction.SendMessage("No available port\r\n") s.interaction.SendMessage("No available port\r\n")
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
return return
} }
} else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse { } else if isUse, isExist := portUtil.Default.GetPortStatus(portToBind); isExist && isUse {
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind)) s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port. (03)\r\n", portToBind))
err := req.Reply(false, nil) err := req.Reply(false, nil)
if err != nil { if err != nil {
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -193,21 +193,21 @@ func (s *SSHSession) HandleHTTPForward(req *ssh.Request, portToBind uint16) {
return return
} }
s.Forwarder.SetType(types.HTTP) s.forwarder.SetType(types.HTTP)
s.Forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.SlugManager.Set(slug) s.slugManager.Set(slug)
s.Interaction.SendMessage("\033[H\033[2J") s.interaction.SendMessage("\033[H\033[2J")
s.Interaction.ShowWelcomeMessage() s.interaction.ShowWelcomeMessage()
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain)) s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s\r\n", protocol, slug, domain))
s.Lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.RUNNING)
s.Interaction.HandleUserInput() s.interaction.HandleUserInput()
} }
func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) { func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind uint16) {
log.Printf("Requested forwarding on %s:%d", addr, portToBind) log.Printf("Requested forwarding on %s:%d", addr, portToBind)
listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind)) listener, err := net.Listen("tcp", fmt.Sprintf("0.0.0.0:%d", portToBind))
if err != nil { if err != nil {
s.Interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind)) s.interaction.SendMessage(fmt.Sprintf("Port %d is already in use or restricted. Please choose a different port.\r\n", portToBind))
if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil { if setErr := portUtil.Default.SetPortStatus(portToBind, false); setErr != nil {
log.Printf("Failed to reset port status: %v", setErr) log.Printf("Failed to reset port status: %v", setErr)
} }
@@ -216,7 +216,7 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
log.Println("Failed to reply to request:", err) log.Println("Failed to reply to request:", err)
return return
} }
err = s.Lifecycle.Close() err = s.lifecycle.Close()
if err != nil { if err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
@@ -253,15 +253,15 @@ func (s *SSHSession) HandleTCPForward(req *ssh.Request, addr string, portToBind
return return
} }
s.Forwarder.SetType(types.TCP) s.forwarder.SetType(types.TCP)
s.Forwarder.SetListener(listener) s.forwarder.SetListener(listener)
s.Forwarder.SetForwardedPort(portToBind) s.forwarder.SetForwardedPort(portToBind)
s.Interaction.SendMessage("\033[H\033[2J") s.interaction.SendMessage("\033[H\033[2J")
s.Interaction.ShowWelcomeMessage() s.interaction.ShowWelcomeMessage()
s.Interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("DOMAIN", "localhost"), s.Forwarder.GetForwardedPort())) s.interaction.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", utils.Getenv("DOMAIN", "localhost"), s.forwarder.GetForwardedPort()))
s.Lifecycle.SetStatus(types.RUNNING) s.lifecycle.SetStatus(types.RUNNING)
go s.Forwarder.AcceptTCPConnections() go s.forwarder.AcceptTCPConnections()
s.Interaction.HandleUserInput() s.interaction.HandleUserInput()
} }
func generateUniqueSlug() string { func generateUniqueSlug() string {

View File

@@ -42,21 +42,37 @@ type Forwarder interface {
} }
type Interaction struct { type Interaction struct {
InputLength int inputLength int
CommandBuffer *bytes.Buffer commandBuffer *bytes.Buffer
InteractiveMode bool interactiveMode bool
InteractionType types.InteractionType interactionType types.InteractionType
EditSlug string editSlug string
channel ssh.Channel channel ssh.Channel
SlugManager slug.Manager slugManager slug.Manager
Forwarder Forwarder forwarder Forwarder
Lifecycle Lifecycle lifecycle Lifecycle
pendingExit bool pendingExit bool
updateClientSlug func(oldSlug, newSlug string) bool updateClientSlug func(oldSlug, newSlug string) bool
} }
func NewInteraction(slugManager slug.Manager, forwarder Forwarder) *Interaction {
return &Interaction{
inputLength: 0,
commandBuffer: bytes.NewBuffer(make([]byte, 0, 20)),
interactiveMode: false,
interactionType: "",
editSlug: "",
channel: nil,
slugManager: slugManager,
forwarder: forwarder,
lifecycle: nil,
pendingExit: false,
updateClientSlug: nil,
}
}
func (i *Interaction) SetLifecycle(lifecycle Lifecycle) { func (i *Interaction) SetLifecycle(lifecycle Lifecycle) {
i.Lifecycle = lifecycle i.lifecycle = lifecycle
} }
func (i *Interaction) SetChannel(channel ssh.Channel) { func (i *Interaction) SetChannel(channel ssh.Channel) {
@@ -77,7 +93,7 @@ func (i *Interaction) SendMessage(message string) {
func (i *Interaction) HandleUserInput() { func (i *Interaction) HandleUserInput() {
buf := make([]byte, 1) buf := make([]byte, 1)
i.InteractiveMode = false i.interactiveMode = false
for { for {
n, err := i.channel.Read(buf) n, err := i.channel.Read(buf)
@@ -99,7 +115,7 @@ func (i *Interaction) handleReadError(err error) {
} }
func (i *Interaction) processCharacter(char byte) { func (i *Interaction) processCharacter(char byte) {
if i.InteractiveMode { if i.interactiveMode {
i.handleInteractiveMode(char) i.handleInteractiveMode(char)
return return
} }
@@ -113,7 +129,7 @@ func (i *Interaction) processCharacter(char byte) {
} }
func (i *Interaction) handleInteractiveMode(char byte) { func (i *Interaction) handleInteractiveMode(char byte) {
switch i.InteractionType { switch i.interactionType {
case types.Slug: case types.Slug:
i.HandleSlugEditMode(char) i.HandleSlugEditMode(char)
} }
@@ -123,7 +139,7 @@ func (i *Interaction) handleExitSequence(char byte) bool {
if char == ctrlC { if char == ctrlC {
if i.pendingExit { if i.pendingExit {
i.SendMessage("Closing connection...\r\n") i.SendMessage("Closing connection...\r\n")
if err := i.Lifecycle.Close(); err != nil { if err := i.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
return true return true
@@ -147,37 +163,37 @@ func (i *Interaction) handleNonInteractiveInput(char byte) {
i.handleBackspace() i.handleBackspace()
case char == forwardSlash: case char == forwardSlash:
i.handleCommandStart() i.handleCommandStart()
case i.CommandBuffer.Len() > 0: case i.commandBuffer.Len() > 0:
i.handleCommandInput(char) i.handleCommandInput(char)
case char == enterChar: case char == enterChar:
i.SendMessage(clearLine) i.SendMessage(clearLine)
default: default:
i.InputLength++ i.inputLength++
} }
} }
func (i *Interaction) handleBackspace() { func (i *Interaction) handleBackspace() {
if i.InputLength > 0 { if i.inputLength > 0 {
i.SendMessage(backspaceSeq) i.SendMessage(backspaceSeq)
} }
if i.CommandBuffer.Len() > 0 { if i.commandBuffer.Len() > 0 {
i.CommandBuffer.Truncate(i.CommandBuffer.Len() - 1) i.commandBuffer.Truncate(i.commandBuffer.Len() - 1)
} }
} }
func (i *Interaction) handleCommandStart() { func (i *Interaction) handleCommandStart() {
i.CommandBuffer.Reset() i.commandBuffer.Reset()
i.CommandBuffer.WriteByte(forwardSlash) i.commandBuffer.WriteByte(forwardSlash)
} }
func (i *Interaction) handleCommandInput(char byte) { func (i *Interaction) handleCommandInput(char byte) {
if char == enterChar { if char == enterChar {
i.SendMessage(clearLine) i.SendMessage(clearLine)
i.HandleCommand(i.CommandBuffer.String()) i.HandleCommand(i.commandBuffer.String())
return return
} }
i.CommandBuffer.WriteByte(char) i.commandBuffer.WriteByte(char)
i.InputLength++ i.inputLength++
} }
func (i *Interaction) HandleSlugEditMode(char byte) { func (i *Interaction) HandleSlugEditMode(char byte) {
@@ -194,15 +210,15 @@ func (i *Interaction) HandleSlugEditMode(char byte) {
} }
func (i *Interaction) handleSlugBackspace() { func (i *Interaction) handleSlugBackspace() {
if len(i.EditSlug) > 0 { if len(i.editSlug) > 0 {
i.EditSlug = i.EditSlug[:len(i.EditSlug)-1] i.editSlug = i.editSlug[:len(i.editSlug)-1]
i.refreshSlugDisplay() i.refreshSlugDisplay()
} }
} }
func (i *Interaction) appendToSlug(char byte) { func (i *Interaction) appendToSlug(char byte) {
if isValidSlugChar(char) { if len(i.editSlug) < maxSlugLength {
i.EditSlug += string(char) i.editSlug += string(char)
i.refreshSlugDisplay() i.refreshSlugDisplay()
} }
} }
@@ -210,16 +226,16 @@ func (i *Interaction) appendToSlug(char byte) {
func (i *Interaction) refreshSlugDisplay() { func (i *Interaction) refreshSlugDisplay() {
domain := utils.Getenv("DOMAIN", "localhost") domain := utils.Getenv("DOMAIN", "localhost")
i.SendMessage(clearToLineEnd) i.SendMessage(clearToLineEnd)
i.SendMessage("➤ " + i.EditSlug + "." + domain) i.SendMessage("➤ " + i.editSlug + "." + domain)
} }
func (i *Interaction) HandleSlugSave() { func (i *Interaction) HandleSlugSave() {
i.SendMessage(clearScreen) i.SendMessage(clearScreen)
switch { switch {
case isForbiddenSlug(i.EditSlug): case isForbiddenSlug(i.editSlug):
i.showForbiddenSlugMessage() i.showForbiddenSlugMessage()
case !isValidSlug(i.EditSlug): case !isValidSlug(i.editSlug):
i.showInvalidSlugMessage() i.showInvalidSlugMessage()
default: default:
i.updateSlug() i.updateSlug()
@@ -230,8 +246,8 @@ func (i *Interaction) HandleSlugSave() {
} }
func (i *Interaction) updateSlug() { func (i *Interaction) updateSlug() {
oldSlug := i.SlugManager.Get() oldSlug := i.slugManager.Get()
newSlug := i.EditSlug newSlug := i.editSlug
if !i.updateClientSlug(oldSlug, newSlug) { if !i.updateClientSlug(oldSlug, newSlug) {
i.HandleSlugUpdateError() i.HandleSlugUpdateError()
@@ -262,8 +278,8 @@ func (i *Interaction) returnToMainScreen() {
i.SendMessage(clearScreen) i.SendMessage(clearScreen)
i.ShowWelcomeMessage() i.ShowWelcomeMessage()
i.ShowForwardingMessage() i.ShowForwardingMessage()
i.InteractiveMode = false i.interactiveMode = false
i.CommandBuffer.Reset() i.commandBuffer.Reset()
} }
func (i *Interaction) HandleSlugCancel() { func (i *Interaction) HandleSlugCancel() {
@@ -271,8 +287,8 @@ func (i *Interaction) HandleSlugCancel() {
i.SendMessage("\r\n\r\n⚠ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n") i.SendMessage("\r\n\r\n⚠ SUBDOMAIN EDIT CANCELLED ⚠️\r\n\r\n")
i.SendMessage("Press any key to continue...\r\n") i.SendMessage("Press any key to continue...\r\n")
i.InteractiveMode = false i.interactiveMode = false
i.InteractionType = "" i.interactionType = ""
i.WaitForKeyPress() i.WaitForKeyPress()
i.SendMessage(clearScreen) i.SendMessage(clearScreen)
@@ -289,7 +305,7 @@ func (i *Interaction) HandleSlugUpdateError() {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
if err := i.Lifecycle.Close(); err != nil { if err := i.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
} }
@@ -308,12 +324,12 @@ func (i *Interaction) HandleCommand(command string) {
i.SendMessage("Unknown command\r\n") i.SendMessage("Unknown command\r\n")
} }
i.CommandBuffer.Reset() i.commandBuffer.Reset()
} }
func (i *Interaction) handleByeCommand() { func (i *Interaction) handleByeCommand() {
i.SendMessage("Closing connection...\r\n") i.SendMessage("Closing connection...\r\n")
if err := i.Lifecycle.Close(); err != nil { if err := i.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
} }
@@ -329,32 +345,32 @@ func (i *Interaction) handleClearCommand() {
} }
func (i *Interaction) handleSlugCommand() { func (i *Interaction) handleSlugCommand() {
if i.Forwarder.GetTunnelType() != types.HTTP { if i.forwarder.GetTunnelType() != types.HTTP {
i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains\r\n", i.Forwarder.GetTunnelType())) i.SendMessage(fmt.Sprintf("\r\n%s tunnels cannot have custom subdomains\r\n", i.forwarder.GetTunnelType()))
return return
} }
i.InteractiveMode = true i.interactiveMode = true
i.InteractionType = types.Slug i.interactionType = types.Slug
i.EditSlug = i.SlugManager.Get() i.editSlug = i.slugManager.Get()
i.SendMessage(clearScreen) i.SendMessage(clearScreen)
i.DisplaySlugEditor() i.DisplaySlugEditor()
domain := utils.Getenv("DOMAIN", "localhost") domain := utils.Getenv("DOMAIN", "localhost")
i.SendMessage("➤ " + i.EditSlug + "." + domain) i.SendMessage("➤ " + i.editSlug + "." + domain)
} }
func (i *Interaction) ShowForwardingMessage() { func (i *Interaction) ShowForwardingMessage() {
domain := utils.Getenv("DOMAIN", "localhost") domain := utils.Getenv("DOMAIN", "localhost")
if i.Forwarder.GetTunnelType() == types.HTTP { if i.forwarder.GetTunnelType() == types.HTTP {
protocol := "http" protocol := "http"
if utils.Getenv("TLS_ENABLED", "false") == "true" { if utils.Getenv("TLS_ENABLED", "false") == "true" {
protocol = "https" protocol = "https"
} }
i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.SlugManager.Get(), domain)) i.SendMessage(fmt.Sprintf("Forwarding your traffic to %s://%s.%s \r\n", protocol, i.slugManager.Get(), domain))
} else { } else {
i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.Forwarder.GetForwardedPort())) i.SendMessage(fmt.Sprintf("Forwarding your traffic to tcp://%s:%d \r\n", domain, i.forwarder.GetForwardedPort()))
} }
} }
@@ -385,7 +401,7 @@ func (i *Interaction) ShowWelcomeMessage() {
func (i *Interaction) DisplaySlugEditor() { func (i *Interaction) DisplaySlugEditor() {
domain := utils.Getenv("DOMAIN", "localhost") domain := utils.Getenv("DOMAIN", "localhost")
fullDomain := i.SlugManager.Get() + "." + domain fullDomain := i.slugManager.Get() + "." + domain
contentLine := " ║ Current: " + fullDomain contentLine := " ║ Current: " + fullDomain
boxWidth := calculateBoxWidth(contentLine) boxWidth := calculateBoxWidth(contentLine)

View File

@@ -22,16 +22,27 @@ type Forwarder interface {
} }
type Lifecycle struct { type Lifecycle struct {
Status types.Status status types.Status
Conn ssh.Conn conn ssh.Conn
Channel ssh.Channel channel ssh.Channel
interaction Interaction
Interaction Interaction forwarder Forwarder
Forwarder Forwarder slugManager slug.Manager
SlugManager slug.Manager
unregisterClient func(slug string) unregisterClient func(slug string)
} }
func NewLifecycle(conn ssh.Conn, interaction Interaction, forwarder Forwarder, slugManager slug.Manager) *Lifecycle {
return &Lifecycle{
status: "",
conn: conn,
channel: nil,
interaction: interaction,
forwarder: forwarder,
slugManager: slugManager,
unregisterClient: nil,
}
}
func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) { func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
l.unregisterClient = unregisterClient l.unregisterClient = unregisterClient
} }
@@ -46,46 +57,46 @@ type SessionLifecycle interface {
} }
func (l *Lifecycle) GetChannel() ssh.Channel { func (l *Lifecycle) GetChannel() ssh.Channel {
return l.Channel return l.channel
} }
func (l *Lifecycle) SetChannel(channel ssh.Channel) { func (l *Lifecycle) SetChannel(channel ssh.Channel) {
l.Channel = channel l.channel = channel
} }
func (l *Lifecycle) GetConnection() ssh.Conn { func (l *Lifecycle) GetConnection() ssh.Conn {
return l.Conn return l.conn
} }
func (l *Lifecycle) SetStatus(status types.Status) { func (l *Lifecycle) SetStatus(status types.Status) {
l.Status = status l.status = status
} }
func (l *Lifecycle) Close() error { func (l *Lifecycle) Close() error {
err := l.Forwarder.Close() err := l.forwarder.Close()
if err != nil && !errors.Is(err, net.ErrClosed) { if err != nil && !errors.Is(err, net.ErrClosed) {
return err return err
} }
if l.Channel != nil { if l.channel != nil {
err := l.Channel.Close() err := l.channel.Close()
if err != nil && !errors.Is(err, io.EOF) { if err != nil && !errors.Is(err, io.EOF) {
return err return err
} }
} }
if l.Conn != nil { if l.conn != nil {
err := l.Conn.Close() err := l.conn.Close()
if err != nil && !errors.Is(err, net.ErrClosed) { if err != nil && !errors.Is(err, net.ErrClosed) {
return err return err
} }
} }
clientSlug := l.SlugManager.Get() clientSlug := l.slugManager.Get()
if clientSlug != "" { if clientSlug != "" {
l.unregisterClient(clientSlug) l.unregisterClient(clientSlug)
} }
if l.Forwarder.GetTunnelType() == types.TCP { if l.forwarder.GetTunnelType() == types.TCP {
err := portUtil.Default.SetPortStatus(l.Forwarder.GetForwardedPort(), false) err := portUtil.Default.SetPortStatus(l.forwarder.GetForwardedPort(), false)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -1,7 +1,6 @@
package session package session
import ( import (
"bytes"
"fmt" "fmt"
"log" "log"
"sync" "sync"
@@ -28,36 +27,33 @@ type Session interface {
} }
type SSHSession struct { type SSHSession struct {
Lifecycle lifecycle.SessionLifecycle lifecycle lifecycle.SessionLifecycle
Interaction interaction.Controller interaction interaction.Controller
Forwarder forwarder.ForwardingController forwarder forwarder.ForwardingController
SlugManager slug.Manager slugManager slug.Manager
}
func (s *SSHSession) GetLifecycle() lifecycle.SessionLifecycle {
return s.lifecycle
}
func (s *SSHSession) GetInteraction() interaction.Controller {
return s.interaction
}
func (s *SSHSession) GetForwarder() forwarder.ForwardingController {
return s.forwarder
}
func (s *SSHSession) GetSlugManager() slug.Manager {
return s.slugManager
} }
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
slugManager := slug.NewManager() slugManager := slug.NewManager()
forwarderManager := &forwarder.Forwarder{ forwarderManager := forwarder.NewForwarder(slugManager)
Listener: nil, interactionManager := interaction.NewInteraction(slugManager, forwarderManager)
TunnelType: "", lifecycleManager := lifecycle.NewLifecycle(conn, interactionManager, forwarderManager, slugManager)
ForwardedPort: 0,
SlugManager: slugManager,
}
interactionManager := &interaction.Interaction{
CommandBuffer: bytes.NewBuffer(make([]byte, 0, 20)),
InteractiveMode: false,
EditSlug: "",
SlugManager: slugManager,
Forwarder: forwarderManager,
Lifecycle: nil,
}
lifecycleManager := &lifecycle.Lifecycle{
Status: "",
Conn: conn,
Channel: nil,
Interaction: interactionManager,
Forwarder: forwarderManager,
SlugManager: slugManager,
}
interactionManager.SetLifecycle(lifecycleManager) interactionManager.SetLifecycle(lifecycleManager)
interactionManager.SetSlugModificator(updateClientSlug) interactionManager.SetSlugModificator(updateClientSlug)
@@ -65,10 +61,10 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
lifecycleManager.SetUnregisterClient(unregisterClient) lifecycleManager.SetUnregisterClient(unregisterClient)
session := &SSHSession{ session := &SSHSession{
Lifecycle: lifecycleManager, lifecycle: lifecycleManager,
Interaction: interactionManager, interaction: interactionManager,
Forwarder: forwarderManager, forwarder: forwarderManager,
SlugManager: slugManager, slugManager: slugManager,
} }
var once sync.Once var once sync.Once
@@ -79,13 +75,13 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
continue continue
} }
once.Do(func() { once.Do(func() {
session.Lifecycle.SetChannel(ch) session.lifecycle.SetChannel(ch)
session.Interaction.SetChannel(ch) session.interaction.SetChannel(ch)
tcpipReq := session.waitForTCPIPForward(forwardingReq) tcpipReq := session.waitForTCPIPForward(forwardingReq)
if tcpipReq == nil { if tcpipReq == nil {
session.Interaction.SendMessage(fmt.Sprintf("Port forwarding request not received.\r\nEnsure you ran the correct command with -R flag.\r\nExample: ssh %s -p %s -R 80:localhost:3000\r\nFor more details, visit https://tunnl.live.\r\n\r\n", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200"))) session.interaction.SendMessage(fmt.Sprintf("Port forwarding request not received.\r\nEnsure you ran the correct command with -R flag.\r\nExample: ssh %s -p %s -R 80:localhost:3000\r\nFor more details, visit https://tunnl.live.\r\n\r\n", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200")))
if err := session.Lifecycle.Close(); err != nil { if err := session.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
return return
@@ -94,7 +90,7 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
}) })
go session.HandleGlobalRequest(reqs) go session.HandleGlobalRequest(reqs)
} }
if err := session.Lifecycle.Close(); err != nil { if err := session.lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err) log.Printf("failed to close session: %v", err)
} }
} }
@@ -134,7 +130,7 @@ func updateClientSlug(oldSlug, newSlug string) bool {
} }
delete(Clients, oldSlug) delete(Clients, oldSlug)
client.SlugManager.Set(newSlug) client.slugManager.Set(newSlug)
Clients[newSlug] = client Clients[newSlug] = client
return true return true
} }