Improve concurrency and resource management #2

Merged
bagas merged 12 commits from staging into main 2025-07-23 06:51:09 +00:00
7 changed files with 279 additions and 202 deletions
Showing only changes of commit 9f18cfa954 - Show all commits

View File

@ -5,17 +5,14 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/net/context"
"log" "log"
"net" "net"
"strconv"
"strings" "strings"
"time"
"tunnel_pls/session" "tunnel_pls/session"
"tunnel_pls/utils" "tunnel_pls/utils"
) )
var redirectTLS bool = false var redirectTLS = false
func NewHTTPServer() error { func NewHTTPServer() error {
listener, err := net.Listen("tcp", ":80") listener, err := net.Listen("tcp", ":80")
@ -81,23 +78,10 @@ func Handler(conn net.Conn) {
conn.Close() conn.Close()
return return
} }
keepalive, timeout := parseConnectionDetails(headers)
var ctx context.Context
var cancel context.CancelFunc
if keepalive {
if timeout >= 300 {
timeout = 300
}
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(timeout)*time.Second))
} else {
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
}
sshSession.HandleForwardedConnection(session.UserConnection{ sshSession.HandleForwardedConnection(session.UserConnection{
Reader: reader, Reader: reader,
Writer: conn, Writer: conn,
Context: ctx,
Cancel: cancel,
}, sshSession.Connection) }, sshSession.Connection)
return return
} }
@ -131,42 +115,3 @@ func parseHostFromHeader(data []byte) string {
} }
return "" return ""
} }
func parseConnectionDetails(data []byte) (keepAlive bool, timeout int) {
keepAlive = false
timeout = 30
lines := strings.Split(string(data), "\r\n")
for _, line := range lines {
if strings.HasPrefix(strings.ToLower(line), "connection:") {
value := strings.TrimSpace(strings.TrimPrefix(strings.ToLower(line), "connection:"))
keepAlive = (value == "keep-alive")
break
}
}
if keepAlive {
for _, line := range lines {
if strings.HasPrefix(strings.ToLower(line), "keep-alive:") {
value := strings.TrimSpace(strings.TrimPrefix(line, "Keep-Alive:"))
if strings.Contains(value, "timeout=") {
parts := strings.Split(value, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "timeout=") {
timeoutStr := strings.TrimPrefix(part, "timeout=")
if t, err := strconv.Atoi(timeoutStr); err == nil {
timeout = t
}
}
}
}
break
}
}
}
return keepAlive, timeout
}

View File

@ -4,11 +4,9 @@ import (
"bufio" "bufio"
"crypto/tls" "crypto/tls"
"errors" "errors"
"golang.org/x/net/context"
"log" "log"
"net" "net"
"strings" "strings"
"time"
"tunnel_pls/session" "tunnel_pls/session"
"tunnel_pls/utils" "tunnel_pls/utils"
) )
@ -70,23 +68,10 @@ func HandlerTLS(conn net.Conn) {
conn.Close() conn.Close()
return return
} }
keepalive, timeout := parseConnectionDetails(headers)
var ctx context.Context
var cancel context.CancelFunc
if keepalive {
if timeout >= 300 {
timeout = 300
}
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Duration(timeout)*time.Second))
} else {
ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
}
sshSession.HandleForwardedConnection(session.UserConnection{ sshSession.HandleForwardedConnection(session.UserConnection{
Reader: reader, Reader: reader,
Writer: conn, Writer: conn,
Context: ctx,
Cancel: cancel,
}, sshSession.Connection) }, sshSession.Connection)
return return
} }

View File

@ -3,6 +3,7 @@ package session
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -16,7 +17,6 @@ import (
portUtil "tunnel_pls/internal/port" portUtil "tunnel_pls/internal/port"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/net/context"
"tunnel_pls/utils" "tunnel_pls/utils"
) )
@ -29,10 +29,8 @@ const (
) )
type UserConnection struct { type UserConnection struct {
Reader io.Reader Reader io.Reader
Writer net.Conn Writer net.Conn
Context context.Context
Cancel context.CancelFunc
} }
var ( var (
@ -78,6 +76,13 @@ func updateClientSlug(oldSlug, newSlug string) bool {
return true return true
} }
func (s *Session) safeClose() {
s.once.Do(func() {
close(s.ChannelChan)
close(s.Done)
})
}
func (s *Session) Close() { func (s *Session) Close() {
if s.Listener != nil { if s.Listener != nil {
s.Listener.Close() s.Listener.Close()
@ -99,7 +104,7 @@ func (s *Session) Close() {
portUtil.Manager.SetPortStatus(s.ForwardedPort, false) portUtil.Manager.SetPortStatus(s.ForwardedPort, false)
} }
close(s.Done) s.safeClose()
} }
func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) { func (s *Session) HandleGlobalRequest(GlobalRequest <-chan *ssh.Request) {
@ -267,9 +272,8 @@ func (s *Session) acceptTCPConnections() {
} }
go s.HandleForwardedConnection(UserConnection{ go s.HandleForwardedConnection(UserConnection{
Reader: nil, Reader: nil,
Writer: conn, Writer: conn,
Context: context.Background(),
}, s.Connection) }, s.Connection)
} }
} }
@ -538,40 +542,105 @@ func (s *Session) HandleForwardedConnection(conn UserConnection, sshConn *ssh.Se
} }
defer channel.Close() defer channel.Close()
go handleChannelRequests(reqs, conn, channel) go func() {
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in request handler: %v", r)
}
}()
for req := range reqs {
req.Reply(false, nil)
}
}()
if conn.Reader == nil { if conn.Reader == nil {
conn.Reader = bufio.NewReader(conn.Writer) conn.Reader = bufio.NewReader(conn.Writer)
} }
go io.Copy(channel, conn.Reader) ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var wg sync.WaitGroup
errChan := make(chan error, 2)
wg.Add(1)
go func() {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in reader copy: %v", r)
errChan <- fmt.Errorf("panic: %v", r)
}
}()
_, err := io.Copy(channel, conn.Reader)
if err != nil && !errors.Is(err, io.EOF) {
log.Printf("Error copying from conn.Reader to channel: %v", err)
errChan <- err
}
cancel()
}()
reader := bufio.NewReader(channel) reader := bufio.NewReader(channel)
_, err = reader.Peek(1)
if err == io.EOF { peekDone := make(chan error, 1)
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, "Could not forward request to the tunnel addr")) go func() {
_, err := reader.Peek(1)
peekDone <- err
}()
select {
case err := <-peekDone:
if err == io.EOF {
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType))
sendBadGatewayResponse(conn.Writer)
return
}
if err != nil {
log.Printf("Error peeking channel data: %v", err)
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType))
sendBadGatewayResponse(conn.Writer)
return
}
case <-time.After(5 * time.Second):
log.Printf("Timeout waiting for channel data from %s", conn.Writer.RemoteAddr())
s.sendMessage(fmt.Sprintf("\033[33m%s -> [%s] WARNING -- \"Could not forward request to the tunnel addr\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType))
sendBadGatewayResponse(conn.Writer) sendBadGatewayResponse(conn.Writer)
conn.Writer.Close() return
channel.Close() case <-ctx.Done():
return return
} }
s.sendMessage(fmt.Sprintf("\033[32m %s -> [%s] TUNNEL ADDRESS -- \"%s\" \r\n \033[0m", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp)) s.sendMessage(fmt.Sprintf("\033[32m%s -> [%s] TUNNEL ADDRESS -- \"%s\"\033[0m\r\n", conn.Writer.RemoteAddr().String(), s.TunnelType, timestamp))
io.Copy(conn.Writer, reader) wg.Add(1)
} go func() {
defer wg.Done()
defer func() {
if r := recover(); r != nil {
log.Printf("Panic in writer copy: %v", r)
errChan <- fmt.Errorf("panic: %v", r)
}
}()
func handleChannelRequests(reqs <-chan *ssh.Request, conn UserConnection, channel ssh.Channel) { _, err := io.Copy(conn.Writer, reader)
select { if err != nil && !errors.Is(err, io.EOF) {
case <-reqs: log.Printf("Error copying from channel to conn.Writer: %v", err)
for req := range reqs { errChan <- err
req.Reply(false, nil) }
cancel()
}()
go func() {
wg.Wait()
close(errChan)
}()
for err := range errChan {
if err != nil {
log.Printf("Connection error: %v", err)
break
} }
case <-conn.Context.Done():
conn.Writer.Close()
channel.Close()
log.Println("Connection closed by timeout")
return
} }
} }

View File

@ -3,6 +3,7 @@ package session
import ( import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"net" "net"
"sync"
) )
type TunnelType string type TunnelType string
@ -24,6 +25,7 @@ type Session struct {
Slug string Slug string
ChannelChan chan ssh.NewChannel ChannelChan chan ssh.NewChannel
Done chan bool Done chan bool
once sync.Once
} }
func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session { func New(conn *ssh.ServerConn, forwardingReq <-chan *ssh.Request) *Session {