From 6dff7352169adad7c86da58f969d5b526c76d764 Mon Sep 17 00:00:00 2001 From: bagas Date: Thu, 18 Dec 2025 21:09:12 +0700 Subject: [PATCH] fix: prevent OOM by bounding io.Copy buffer usage --- README.md | 1 + session/forwarder/forwarder.go | 19 +++++++++++++++++-- utils/utils.go | 10 ++++++++++ 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 9e8bf54..d66785f 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ The following environment variables can be configured in the `.env` file: | `SSH_PRIVATE_KEY` | Path to SSH private key (auto-generated if missing) | `certs/id_rsa` | No | | `CORS_LIST` | Comma-separated list of allowed CORS origins | - | No | | `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No | +| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No | | `PPROF_ENABLED` | Enable pprof profiling server | `false` | No | | `PPROF_PORT` | Port for pprof server | `6060` | No | diff --git a/session/forwarder/forwarder.go b/session/forwarder/forwarder.go index 9d94abe..3bf41bb 100644 --- a/session/forwarder/forwarder.go +++ b/session/forwarder/forwarder.go @@ -8,12 +8,27 @@ import ( "log" "net" "strconv" + "sync" "tunnel_pls/session/slug" "tunnel_pls/types" + "tunnel_pls/utils" "golang.org/x/crypto/ssh" ) +var bufferPool = sync.Pool{ + New: func() interface{} { + bufSize := utils.GetBufferSize() + return make([]byte, bufSize) + }, +} + +func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) { + buf := bufferPool.Get().([]byte) + defer bufferPool.Put(buf) + return io.CopyBuffer(dst, src, buf) +} + type Forwarder struct { Listener net.Listener TunnelType types.TunnelType @@ -103,7 +118,7 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA done := make(chan struct{}, 2) go func() { - _, err := io.Copy(src, dst) + _, err := copyWithBuffer(src, dst) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { log.Printf("Error copying from conn.Reader to channel: %v", err) } @@ -111,7 +126,7 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA }() go func() { - _, err := io.Copy(dst, src) + _, err := copyWithBuffer(dst, src) if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) { log.Printf("Error copying from channel to conn.Writer: %v", err) } diff --git a/utils/utils.go b/utils/utils.go index 2518627..d2087d1 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -9,6 +9,7 @@ import ( mathrand "math/rand" "os" "path/filepath" + "strconv" "strings" "sync" "time" @@ -62,6 +63,15 @@ func Getenv(key, defaultValue string) string { return val } +func GetBufferSize() int { + sizeStr := Getenv("BUFFER_SIZE", "32768") + size, err := strconv.Atoi(sizeStr) + if err != nil || size < 4096 || size > 1048576 { + return 32768 + } + return size +} + func GenerateSSHKeyIfNotExist(keyPath string) error { if _, err := os.Stat(keyPath); err == nil { log.Printf("SSH key already exists at %s", keyPath)