- Centralize environment-variable loading in `config.MustLoad`.
- Parse and validate all environment variables once, at initialization.
- Make config fields private and read-only.
- Replace public `Getenv` usage with typed accessors.
- Improve validation and initialization order.
- Normalize enum naming to be idiomatic and avoid constant collisions.
231 lines
5.2 KiB
Go
231 lines
5.2 KiB
Go
package forwarder
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
"tunnel_pls/internal/config"
|
|
"tunnel_pls/session/slug"
|
|
"tunnel_pls/types"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type Forwarder interface {
|
|
SetType(tunnelType types.TunnelType)
|
|
SetForwardedPort(port uint16)
|
|
SetListener(listener net.Listener)
|
|
Listener() net.Listener
|
|
TunnelType() types.TunnelType
|
|
ForwardedPort() uint16
|
|
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
|
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
|
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
|
WriteBadGatewayResponse(dst io.Writer)
|
|
Close() error
|
|
}
|
|
type forwarder struct {
|
|
listener net.Listener
|
|
tunnelType types.TunnelType
|
|
forwardedPort uint16
|
|
slug slug.Slug
|
|
conn ssh.Conn
|
|
bufferPool sync.Pool
|
|
}
|
|
|
|
func New(config config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
|
|
return &forwarder{
|
|
listener: nil,
|
|
tunnelType: types.TunnelTypeUNKNOWN,
|
|
forwardedPort: 0,
|
|
slug: slug,
|
|
conn: conn,
|
|
bufferPool: sync.Pool{
|
|
New: func() interface{} {
|
|
bufSize := config.BufferSize()
|
|
return make([]byte, bufSize)
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
|
|
buf := f.bufferPool.Get().([]byte)
|
|
defer f.bufferPool.Put(buf)
|
|
return io.CopyBuffer(dst, src, buf)
|
|
}
|
|
|
|
func (f *forwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
|
type channelResult struct {
|
|
channel ssh.Channel
|
|
reqs <-chan *ssh.Request
|
|
err error
|
|
}
|
|
resultChan := make(chan channelResult, 1)
|
|
|
|
go func() {
|
|
channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
|
|
select {
|
|
case resultChan <- channelResult{channel, reqs, err}:
|
|
default:
|
|
if channel != nil {
|
|
err = channel.Close()
|
|
if err != nil {
|
|
log.Printf("Failed to close unused channel: %v", err)
|
|
return
|
|
}
|
|
go ssh.DiscardRequests(reqs)
|
|
}
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case result := <-resultChan:
|
|
return result.channel, result.reqs, result.err
|
|
case <-time.After(5 * time.Second):
|
|
return nil, nil, errors.New("timeout opening forwarded-tcpip channel")
|
|
}
|
|
}
|
|
|
|
// closeWriter ends the write side of w. A half-close via CloseWrite is
// preferred when w supports it. Otherwise w is fully closed if it is an
// io.Closer. If w supports neither, nothing happens and nil is returned.
func closeWriter(w io.Writer) error {
	switch c := w.(type) {
	case interface{ CloseWrite() error }:
		return c.CloseWrite()
	case io.Closer:
		return c.Close()
	default:
		return nil
	}
}
|
|
|
|
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
|
|
var errs []error
|
|
_, err := f.copyWithBuffer(dst, src)
|
|
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, net.ErrClosed) {
|
|
errs = append(errs, fmt.Errorf("copy error (%s): %w", direction, err))
|
|
}
|
|
|
|
if err = closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
|
|
errs = append(errs, fmt.Errorf("close stream error (%s): %w", direction, err))
|
|
}
|
|
return errors.Join(errs...)
|
|
}
|
|
|
|
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
|
defer func() {
|
|
_, err := io.Copy(io.Discard, src)
|
|
if err != nil {
|
|
log.Printf("Failed to discard connection: %v", err)
|
|
}
|
|
}()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
err := f.copyAndClose(dst, src, "src to dst")
|
|
if err != nil {
|
|
log.Println("Error during copy: ", err)
|
|
return
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
err := f.copyAndClose(src, dst, "dst to src")
|
|
if err != nil {
|
|
log.Println("Error during copy: ", err)
|
|
return
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func (f *forwarder) SetType(tunnelType types.TunnelType) {
|
|
f.tunnelType = tunnelType
|
|
}
|
|
|
|
func (f *forwarder) TunnelType() types.TunnelType {
|
|
return f.tunnelType
|
|
}
|
|
|
|
func (f *forwarder) ForwardedPort() uint16 {
|
|
return f.forwardedPort
|
|
}
|
|
|
|
func (f *forwarder) SetForwardedPort(port uint16) {
|
|
f.forwardedPort = port
|
|
}
|
|
|
|
func (f *forwarder) SetListener(listener net.Listener) {
|
|
f.listener = listener
|
|
}
|
|
|
|
func (f *forwarder) Listener() net.Listener {
|
|
return f.listener
|
|
}
|
|
|
|
func (f *forwarder) WriteBadGatewayResponse(dst io.Writer) {
|
|
_, err := dst.Write(types.BadGatewayResponse)
|
|
if err != nil {
|
|
log.Printf("failed to write Bad Gateway response: %v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (f *forwarder) Close() error {
|
|
if f.Listener() != nil {
|
|
return f.listener.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *forwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
|
var buf bytes.Buffer
|
|
|
|
host, originPort := parseAddr(origin.String())
|
|
|
|
writeSSHString(&buf, "localhost")
|
|
err := binary.Write(&buf, binary.BigEndian, uint32(f.ForwardedPort()))
|
|
if err != nil {
|
|
log.Printf("Failed to write string to buffer: %v", err)
|
|
return nil
|
|
}
|
|
|
|
writeSSHString(&buf, host)
|
|
err = binary.Write(&buf, binary.BigEndian, uint32(originPort))
|
|
if err != nil {
|
|
log.Printf("Failed to write string to buffer: %v", err)
|
|
return nil
|
|
}
|
|
|
|
return buf.Bytes()
|
|
}
|
|
|
|
// parseAddr splits a "host:port" string, as produced by net.Addr.String, into
// its host and numeric port.
//
// If addr cannot be split, the error is logged and ("0.0.0.0", 0) is
// returned. If the port is not a decimal number in the range 0-65535, the
// error is logged and (host, 0) is returned. Previously such values were
// silently wrapped into 16 bits: "70000" became 4464 and "-1" became 65535.
func parseAddr(addr string) (string, uint16) {
	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		log.Printf("Failed to parse origin address: %s from address %s", err.Error(), addr)
		return "0.0.0.0", uint16(0)
	}

	// bitSize 16 makes ParseUint reject anything that does not fit in uint16.
	port, err := strconv.ParseUint(portStr, 10, 16)
	if err != nil {
		log.Printf("Failed to parse origin port %q from address %s: %v", portStr, addr, err)
		return host, 0
	}

	return host, uint16(port)
}
|
|
|
|
// writeSSHString appends str to buffer as a length-prefixed string: a 32-bit
// big-endian byte count followed by the raw bytes. If the length prefix
// cannot be written, the error is logged and nothing else is appended.
func writeSSHString(buffer *bytes.Buffer, str string) {
	length := uint32(len(str))
	if err := binary.Write(buffer, binary.BigEndian, length); err != nil {
		log.Printf("Failed to write string to buffer: %v", err)
		return
	}
	buffer.WriteString(str)
}
|