fix: prevent OOM by bounding io.Copy buffer usage
All checks were successful
Docker Build and Push / build-and-push (push) Successful in 3m47s

This commit is contained in:
2025-12-18 21:09:12 +07:00
parent 7bc5a01ba7
commit 6dff735216
3 changed files with 28 additions and 2 deletions

View File

@@ -29,6 +29,7 @@ The following environment variables can be configured in the `.env` file:
| `SSH_PRIVATE_KEY` | Path to SSH private key (auto-generated if missing) | `certs/id_rsa` | No |
| `CORS_LIST` | Comma-separated list of allowed CORS origins | - | No |
| `ALLOWED_PORTS` | Port range for TCP tunnels (e.g., 40000-41000) | `40000-41000` | No |
| `BUFFER_SIZE` | Buffer size for io.Copy operations in bytes (4096-1048576) | `32768` | No |
| `PPROF_ENABLED` | Enable pprof profiling server | `false` | No |
| `PPROF_PORT` | Port for pprof server | `6060` | No |

View File

@@ -8,12 +8,27 @@ import (
"log"
"net"
"strconv"
"sync"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"tunnel_pls/utils"
"golang.org/x/crypto/ssh"
)
var bufferPool = sync.Pool{
New: func() interface{} {
bufSize := utils.GetBufferSize()
return make([]byte, bufSize)
},
}
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
buf := bufferPool.Get().([]byte)
defer bufferPool.Put(buf)
return io.CopyBuffer(dst, src, buf)
}
type Forwarder struct {
Listener net.Listener
TunnelType types.TunnelType
@@ -103,7 +118,7 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
done := make(chan struct{}, 2)
go func() {
_, err := io.Copy(src, dst)
_, err := copyWithBuffer(src, dst)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying from conn.Reader to channel: %v", err)
}
@@ -111,7 +126,7 @@ func (f *Forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteA
}()
go func() {
_, err := io.Copy(dst, src)
_, err := copyWithBuffer(dst, src)
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
log.Printf("Error copying from channel to conn.Writer: %v", err)
}

View File

@@ -9,6 +9,7 @@ import (
mathrand "math/rand"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
@@ -62,6 +63,15 @@ func Getenv(key, defaultValue string) string {
return val
}
func GetBufferSize() int {
sizeStr := Getenv("BUFFER_SIZE", "32768")
size, err := strconv.Atoi(sizeStr)
if err != nil || size < 4096 || size > 1048576 {
return 32768
}
return size
}
func GenerateSSHKeyIfNotExist(keyPath string) error {
if _, err := os.Stat(keyPath); err == nil {
log.Printf("SSH key already exists at %s", keyPath)