- Rename the customWriter struct to httpWriter for clarity.
- Add a closeWriter field to properly close the write side of connections.
- Update all cw variable references to hw.
- Merge handlerTLS into the handler function to reduce code duplication.
- Extract the handler into smaller, focused methods.
- Split Read/Write/forwardRequest into composable functions.

Fixes a resource leak where connections weren't properly closed on the write side, matching the forwarder's CloseWrite() pattern.
246 lines
5.7 KiB
Go
246 lines
5.7 KiB
Go
package forwarder
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
"tunnel_pls/internal/config"
|
|
"tunnel_pls/session/slug"
|
|
"tunnel_pls/types"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
var bufferPool = sync.Pool{
|
|
New: func() interface{} {
|
|
bufSize := config.GetBufferSize()
|
|
return make([]byte, bufSize)
|
|
},
|
|
}
|
|
|
|
func copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
|
buf := bufferPool.Get().([]byte)
|
|
defer bufferPool.Put(buf)
|
|
return io.CopyBuffer(dst, src, buf)
|
|
}
|
|
|
|
type forwarder struct {
|
|
listener net.Listener
|
|
tunnelType types.TunnelType
|
|
forwardedPort uint16
|
|
slug slug.Slug
|
|
conn ssh.Conn
|
|
}
|
|
|
|
func New(slug slug.Slug, conn ssh.Conn) Forwarder {
|
|
return &forwarder{
|
|
listener: nil,
|
|
tunnelType: types.UNKNOWN,
|
|
forwardedPort: 0,
|
|
slug: slug,
|
|
conn: conn,
|
|
}
|
|
}
|
|
|
|
type Forwarder interface {
|
|
SetType(tunnelType types.TunnelType)
|
|
SetForwardedPort(port uint16)
|
|
SetListener(listener net.Listener)
|
|
Listener() net.Listener
|
|
TunnelType() types.TunnelType
|
|
ForwardedPort() uint16
|
|
HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr)
|
|
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
|
WriteBadGatewayResponse(dst io.Writer)
|
|
AcceptTCPConnections()
|
|
Close() error
|
|
}
|
|
|
|
func (f *forwarder) AcceptTCPConnections() {
|
|
for {
|
|
conn, err := f.Listener().Accept()
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
return
|
|
}
|
|
log.Printf("Error accepting connection: %v", err)
|
|
continue
|
|
}
|
|
|
|
payload := f.CreateForwardedTCPIPPayload(conn.RemoteAddr())
|
|
|
|
type channelResult struct {
|
|
channel ssh.Channel
|
|
reqs <-chan *ssh.Request
|
|
err error
|
|
}
|
|
resultChan := make(chan channelResult, 1)
|
|
|
|
go func() {
|
|
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
|
select {
|
|
case resultChan <- channelResult{channel, reqs, err}:
|
|
default:
|
|
if channel != nil {
|
|
err := channel.Close()
|
|
if err != nil {
|
|
log.Printf("Failed to close unused channel: %v", err)
|
|
return
|
|
}
|
|
go ssh.DiscardRequests(reqs)
|
|
}
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case result := <-resultChan:
|
|
if result.err != nil {
|
|
log.Printf("Failed to open forwarded-tcpip channel: %v", result.err)
|
|
if closeErr := conn.Close(); closeErr != nil {
|
|
log.Printf("Failed to close connection: %v", closeErr)
|
|
}
|
|
continue
|
|
}
|
|
go ssh.DiscardRequests(result.reqs)
|
|
go f.HandleConnection(conn, result.channel, conn.RemoteAddr())
|
|
case <-time.After(5 * time.Second):
|
|
log.Printf("Timeout opening forwarded-tcpip channel")
|
|
if closeErr := conn.Close(); closeErr != nil {
|
|
log.Printf("Failed to close connection: %v", closeErr)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel, remoteAddr net.Addr) {
|
|
defer func() {
|
|
_, err := io.Copy(io.Discard, src)
|
|
if err != nil {
|
|
log.Printf("Failed to discard connection: %v", err)
|
|
}
|
|
}()
|
|
|
|
log.Printf("Handling new forwarded connection from %s", remoteAddr)
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
_, err := copyWithBuffer(dst, src)
|
|
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
|
log.Printf("Error copying src to dst: %v", err)
|
|
}
|
|
if conn, ok := dst.(interface{ CloseWrite() error }); ok {
|
|
if err = conn.CloseWrite(); err != nil {
|
|
log.Printf("Error closing write side of dst: %v", err)
|
|
}
|
|
} else {
|
|
if closer, closerOk := dst.(io.Closer); closerOk {
|
|
if err = closer.Close(); err != nil && !errors.Is(err, io.EOF) {
|
|
log.Printf("Error closing dst connection: %v", err)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
_, err := copyWithBuffer(src, dst)
|
|
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
|
log.Printf("Error copying dst to src: %v", err)
|
|
}
|
|
if err = src.CloseWrite(); err != nil {
|
|
log.Printf("Error closing write side of src: %v", err)
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func (f *forwarder) SetType(tunnelType types.TunnelType) {
|
|
f.tunnelType = tunnelType
|
|
}
|
|
|
|
func (f *forwarder) TunnelType() types.TunnelType {
|
|
return f.tunnelType
|
|
}
|
|
|
|
func (f *forwarder) ForwardedPort() uint16 {
|
|
return f.forwardedPort
|
|
}
|
|
|
|
func (f *forwarder) SetForwardedPort(port uint16) {
|
|
f.forwardedPort = port
|
|
}
|
|
|
|
func (f *forwarder) SetListener(listener net.Listener) {
|
|
f.listener = listener
|
|
}
|
|
|
|
func (f *forwarder) Listener() net.Listener {
|
|
return f.listener
|
|
}
|
|
|
|
func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
|
_, err := dst.Write(types.BadGatewayResponse)
|
|
if err != nil {
|
|
log.Printf("failed to write Bad Gateway response: %v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (f *forwarder) Close() error {
|
|
if f.Listener() != nil {
|
|
return f.listener.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
|
var buf bytes.Buffer
|
|
|
|
host, originPort := parseAddr(origin.String())
|
|
|
|
writeSSHString(&buf, "localhost")
|
|
err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
|
|
if err != nil {
|
|
log.Printf("Failed to write string to buffer: %v", err)
|
|
return nil
|
|
}
|
|
|
|
writeSSHString(&buf, host)
|
|
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
|
if err != nil {
|
|
log.Printf("Failed to write string to buffer: %v", err)
|
|
return nil
|
|
}
|
|
|
|
return buf.Bytes()
|
|
}
|
|
|
|
// parseAddr splits addr ("host:port" or "[ipv6]:port") into host and port.
//
// If addr cannot be split it logs the problem and returns ("0.0.0.0", 0).
// If the port is not a decimal number in the uint16 range it logs the problem
// and returns the host with port 0, rather than a silently truncated value
// (previously "70000" became 4464 and "-1" became 65535).
func parseAddr(addr string) (string, uint16) {
	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
		return "0.0.0.0", uint16(0)
	}

	port, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		log.Printf("Failed to parse origin port %q from address %s: %v", portStr, addr, err)
		return host, 0
	}
	return host, uint16(port)
}
|
|
|
|
// writeSSHString appends str to buffer in SSH wire "string" form (RFC 4251
// §5): a big-endian uint32 byte length followed by the raw bytes of str.
// Writes to a bytes.Buffer cannot fail, so there is no error to report.
func writeSSHString(buffer *bytes.Buffer, str string) {
	var length [4]byte
	binary.BigEndian.PutUint32(length[:], uint32(len(str)))
	buffer.Write(length[:])
	buffer.WriteString(str)
}
|