staging #41

Merged
bagas merged 10 commits from staging into main 2025-12-28 07:55:00 +00:00
7 changed files with 130 additions and 89 deletions
Showing only changes of commit 76d1202b8e - Show all commits

View File

@@ -39,6 +39,21 @@ The following environment variables can be configured in the `.env` file:
If the SSH private key specified in `SSH_PRIVATE_KEY` doesn't exist, the application will automatically generate a new 4096-bit RSA key pair at the specified location. This makes it easier to get started without manually creating SSH keys. If the SSH private key specified in `SSH_PRIVATE_KEY` doesn't exist, the application will automatically generate a new 4096-bit RSA key pair at the specified location. This makes it easier to get started without manually creating SSH keys.
### Memory Optimization
The application uses a buffer pool with controlled buffer sizes to prevent excessive memory usage under high concurrent loads. The `BUFFER_SIZE` environment variable controls the size of buffers used for io.Copy operations:
- **Default:** 32768 bytes (32 KB) - Good balance for most scenarios
- **Minimum:** 4096 bytes (4 KB) - Lower memory usage, more CPU overhead
- **Maximum:** 1048576 bytes (1 MB) - Higher throughput, more memory usage
**Recommended settings based on load:**
- **Low traffic (<100 concurrent):** `BUFFER_SIZE=32768` (default)
- **High traffic (>100 concurrent):** `BUFFER_SIZE=16384` or `BUFFER_SIZE=8192`
- **Very high traffic (>1000 concurrent):** `BUFFER_SIZE=8192` or `BUFFER_SIZE=4096`
The buffer pool reuses buffers across connections, preventing memory fragmentation and reducing garbage collection pressure.
### Profiling with pprof ### Profiling with pprof
To enable profiling for performance analysis: To enable profiling for performance analysis:

View File

@@ -10,8 +10,11 @@ import (
"net" "net"
"regexp" "regexp"
"strings" "strings"
"time"
"tunnel_pls/session" "tunnel_pls/session"
"tunnel_pls/utils" "tunnel_pls/utils"
"golang.org/x/crypto/ssh"
) )
type Interaction interface { type Interaction interface {
@@ -295,30 +298,38 @@ func Handler(conn net.Conn) {
func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) { func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshSession *session.SSHSession) {
payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr) payload := sshSession.Forwarder.CreateForwardedTCPIPPayload(cw.RemoteAddr)
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() {
channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) channel, reqs, err := sshSession.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
if err != nil { resultChan <- channelResult{channel, reqs, err}
log.Printf("Failed to open forwarded-tcpip channel: %v", err) }()
if closer, ok := cw.writer.(io.Closer); ok {
if closeErr := closer.Close(); closeErr != nil { var channel ssh.Channel
log.Printf("Failed to close connection: %v", closeErr) var reqs <-chan *ssh.Request
}
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
sshSession.Forwarder.WriteBadGatewayResponse(cw.writer)
return
} }
channel = result.channel
reqs = result.reqs
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
sshSession.Forwarder.WriteBadGatewayResponse(cw.writer)
return return
} }
go func() { go ssh.DiscardRequests(reqs)
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in request handler goroutine: %v", r)
}
}()
for req := range reqs {
if err := req.Reply(false, nil); err != nil {
log.Printf("Failed to reply to request: %v", err)
return
}
}
}()
fingerprintMiddleware := NewTunnelFingerprint() fingerprintMiddleware := NewTunnelFingerprint()
forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr) forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr)
@@ -329,14 +340,13 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS
cw.reqHeader = initialRequest cw.reqHeader = initialRequest
for _, m := range cw.reqStartMW { for _, m := range cw.reqStartMW {
err = m.HandleRequest(cw.reqHeader) if err := m.HandleRequest(cw.reqHeader); err != nil {
if err != nil {
log.Printf("Error handling request: %v", err) log.Printf("Error handling request: %v", err)
return return
} }
} }
_, err = channel.Write(initialRequest.Finalize()) _, err := channel.Write(initialRequest.Finalize())
if err != nil { if err != nil {
log.Printf("Failed to forward request: %v", err) log.Printf("Failed to forward request: %v", err)
return return

View File

@@ -23,20 +23,15 @@ func NewServer(config *ssh.ServerConfig) *Server {
return nil return nil
} }
if utils.Getenv("TLS_ENABLED", "false") == "true" { if utils.Getenv("TLS_ENABLED", "false") == "true" {
go func() {
err = NewHTTPSServer() err = NewHTTPSServer()
if err != nil { if err != nil {
log.Fatalf("failed to start https server: %v", err) log.Fatalf("failed to start https server: %v", err)
} }
return
}()
} }
go func() {
err = NewHTTPServer() err = NewHTTPServer()
if err != nil { if err != nil {
log.Fatalf("failed to start http server: %v", err) log.Fatalf("failed to start http server: %v", err)
} }
}()
return &Server{ return &Server{
Conn: &listener, Conn: &listener,
Config: config, Config: config,

View File

@@ -9,6 +9,7 @@ import (
"net" "net"
"strconv" "strconv"
"sync" "sync"
"time"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
"tunnel_pls/utils" "tunnel_pls/utils"
@@ -70,26 +71,52 @@ func (f *Forwarder) AcceptTCPConnections() {
log.Printf("Error accepting connection: %v", err) log.Printf("Error accepting connection: %v", err)
continue continue
} }
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload) if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
if err != nil { log.Printf("Failed to set connection deadline: %v", err)
log.Printf("Failed to open forwarded-tcpip channel: %v", err)
if closeErr := conn.Close(); closeErr != nil { if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr) log.Printf("Failed to close connection: %v", closeErr)
} }
continue continue
} }
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
type channelResult struct {
channel ssh.Channel
reqs <-chan *ssh.Request
err error
}
resultChan := make(chan channelResult, 1)
go func() { go func() {
for req := range reqs { channel, reqs, err := f.Lifecycle.GetConnection().OpenChannel("forwarded-tcpip", payload)
err := req.Reply(false, nil) resultChan <- channelResult{channel, reqs, err}
if err != nil {
log.Printf("Failed to reply to request: %v", err)
return
}
}
}() }()
go f.HandleConnection(conn, channel, conn.RemoteAddr())
select {
case result := <-resultChan:
if result.err != nil {
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
continue
}
if err := conn.SetDeadline(time.Time{}); err != nil {
log.Printf("Failed to clear connection deadline: %v", err)
}
go ssh.DiscardRequests(result.reqs)
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
case <-time.After(5 * time.Second):
log.Printf("Timeout opening forwarded-tcpip channel")
if closeErr := conn.Close(); closeErr != nil {
log.Printf("Failed to close connection: %v", closeErr)
}
}
} }
} }

View File

@@ -19,9 +19,6 @@ var blockedReservedPorts = []uint16{1080, 1433, 1521, 1900, 2049, 3306, 3389, 54
func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { func (s *SSHSession) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
for req := range GlobalRequest { for req := range GlobalRequest {
switch req.Type { switch req.Type {
case "tcpip-forward":
s.HandleTCPIPForward(req)
return
case "shell", "pty-req", "window-change": case "shell", "pty-req", "window-change":
err := req.Reply(true, nil) err := req.Reply(true, nil)
if err != nil { if err != nil {

View File

@@ -2,11 +2,8 @@ package lifecycle
import ( import (
"errors" "errors"
"fmt"
"io" "io"
"log"
"net" "net"
"time"
portUtil "tunnel_pls/internal/port" portUtil "tunnel_pls/internal/port"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/types"
@@ -41,7 +38,6 @@ func (l *Lifecycle) SetUnregisterClient(unregisterClient func(slug string)) {
type SessionLifecycle interface { type SessionLifecycle interface {
Close() error Close() error
WaitForRunningStatus()
SetStatus(status types.Status) SetStatus(status types.Status)
GetConnection() ssh.Conn GetConnection() ssh.Conn
GetChannel() ssh.Channel GetChannel() ssh.Channel
@@ -62,33 +58,6 @@ func (l *Lifecycle) GetConnection() ssh.Conn {
func (l *Lifecycle) SetStatus(status types.Status) { func (l *Lifecycle) SetStatus(status types.Status) {
l.Status = status l.Status = status
} }
func (l *Lifecycle) WaitForRunningStatus() {
timeout := time.After(3 * time.Second)
ticker := time.NewTicker(150 * time.Millisecond)
defer ticker.Stop()
frames := []string{"-", "\\", "|", "/"}
i := 0
for {
select {
case <-ticker.C:
l.Interaction.SendMessage(fmt.Sprintf("\rLoading %s", frames[i]))
i = (i + 1) % len(frames)
if l.Status == types.RUNNING {
l.Interaction.SendMessage("\r\033[K")
return
}
case <-timeout:
l.Interaction.SendMessage("\r\033[K")
l.Interaction.SendMessage("TCP/IP request not received in time.\r\nCheck your internet connection and confirm the server responds within 3000ms.\r\nEnsure you ran the correct command. For more details, visit https://tunnl.live.\r\n\r\n")
err := l.Close()
if err != nil {
log.Printf("failed to close session: %v", err)
}
log.Println("Timeout waiting for session to start running")
return
}
}
}
func (l *Lifecycle) Close() error { func (l *Lifecycle) Close() error {
err := l.Forwarder.Close() err := l.Forwarder.Close()

View File

@@ -2,13 +2,15 @@ package session
import ( import (
"bytes" "bytes"
"fmt"
"log" "log"
"sync" "sync"
"time"
"tunnel_pls/session/forwarder" "tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction" "tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle" "tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug" "tunnel_pls/session/slug"
"tunnel_pls/types" "tunnel_pls/utils"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
@@ -30,8 +32,6 @@ type SSHSession struct {
Interaction interaction.Controller Interaction interaction.Controller
Forwarder forwarder.ForwardingController Forwarder forwarder.ForwardingController
SlugManager slug.Manager SlugManager slug.Manager
channelOnce sync.Once
} }
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) { func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan ssh.NewChannel) {
@@ -71,20 +71,27 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
SlugManager: slugManager, SlugManager: slugManager,
} }
var once sync.Once
for channel := range sshChan { for channel := range sshChan {
ch, reqs, err := channel.Accept() ch, reqs, err := channel.Accept()
if err != nil { if err != nil {
log.Printf("failed to accept channel: %v", err) log.Printf("failed to accept channel: %v", err)
continue continue
} }
session.channelOnce.Do(func() { once.Do(func() {
session.Lifecycle.SetChannel(ch) session.Lifecycle.SetChannel(ch)
session.Interaction.SetChannel(ch) session.Interaction.SetChannel(ch)
session.Lifecycle.SetStatus(types.SETUP)
go session.HandleGlobalRequest(forwardingReq)
session.Lifecycle.WaitForRunningStatus()
})
tcpipReq := session.waitForTCPIPForward(forwardingReq)
if tcpipReq == nil {
session.Interaction.SendMessage(fmt.Sprintf("Port forwarding request not received.\r\nEnsure you ran the correct command with -R flag.\r\nExample: ssh %s -p %s -R 80:localhost:3000\r\nFor more details, visit https://tunnl.live.\r\n\r\n", utils.Getenv("DOMAIN", "localhost"), utils.Getenv("PORT", "2200")))
if err := session.Lifecycle.Close(); err != nil {
log.Printf("failed to close session: %v", err)
}
return
}
session.HandleTCPIPForward(tcpipReq)
})
go session.HandleGlobalRequest(reqs) go session.HandleGlobalRequest(reqs)
} }
if err := session.Lifecycle.Close(); err != nil { if err := session.Lifecycle.Close(); err != nil {
@@ -92,6 +99,27 @@ func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request, sshChan <-chan
} }
} }
func (s *SSHSession) waitForTCPIPForward(forwardingReq <-chan *ssh.Request) *ssh.Request {
select {
case req, ok := <-forwardingReq:
if !ok {
log.Println("Forwarding request channel closed")
return nil
}
if req.Type == "tcpip-forward" {
return req
}
if err := req.Reply(false, nil); err != nil {
log.Printf("Failed to reply to request: %v", err)
}
log.Printf("Expected tcpip-forward request, got: %s", req.Type)
return nil
case <-time.After(500 * time.Millisecond):
log.Println("No forwarding request received")
return nil
}
}
func updateClientSlug(oldSlug, newSlug string) bool { func updateClientSlug(oldSlug, newSlug string) bool {
clientsMutex.Lock() clientsMutex.Lock()
defer clientsMutex.Unlock() defer clientsMutex.Unlock()