Files
tunnel-please/session/forwarder/forwarder.go

191 lines
4.2 KiB
Go

package forwarder
import (
"context"
"errors"
"fmt"
"io"
"log"
"net"
"strconv"
"sync"
"tunnel_pls/internal/config"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
)
// Forwarder holds the state of one forwarded tunnel (its type, forwarded
// port and optional listener) and relays data between a local connection and
// an SSH channel.
type Forwarder interface {
	// SetType records the kind of tunnel being forwarded.
	SetType(tunnelType types.TunnelType)
	// TunnelType reports the kind of tunnel being forwarded.
	TunnelType() types.TunnelType

	// SetForwardedPort records the port announced when forwarded channels
	// are opened.
	SetForwardedPort(port uint16)
	// ForwardedPort reports the recorded forwarded port.
	ForwardedPort() uint16

	// SetListener stores the listener associated with this tunnel.
	SetListener(listener net.Listener)
	// Listener returns the stored listener, which may be nil.
	Listener() net.Listener

	// OpenForwardedChannel opens a "forwarded-tcpip" channel toward the SSH
	// peer on behalf of a connection from origin, giving up when ctx is done.
	OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error)
	// HandleConnection copies data in both directions between dst and src
	// and returns once both directions have finished.
	HandleConnection(dst io.ReadWriter, src ssh.Channel)

	// Close releases the tunnel's resources.
	Close() error
}
type forwarder struct {
listener net.Listener
tunnelType types.TunnelType
forwardedPort uint16
slug slug.Slug
conn ssh.Conn
bufferPool sync.Pool
}
// New returns a Forwarder that opens channels over conn. The listener starts
// as nil, the tunnel type as types.TunnelTypeUNKNOWN and the forwarded port as
// 0; use SetListener, SetType and SetForwardedPort to configure them.
//
// Copy buffers are pooled. Each new buffer is sized by cfg.BufferSize(),
// falling back to 32 KiB (the size io.Copy allocates internally) when the
// configured size is not positive.
func New(cfg config.Config, slug slug.Slug, conn ssh.Conn) Forwarder {
	return &forwarder{
		listener:      nil,
		tunnelType:    types.TunnelTypeUNKNOWN,
		forwardedPort: 0,
		slug:          slug,
		conn:          conn,
		bufferPool: sync.Pool{
			// The pool holds *[]byte rather than []byte: storing a slice
			// header in an interface allocates on every Put (staticcheck SA6002).
			New: func() any {
				size := cfg.BufferSize()
				if size <= 0 {
					// make panics on a negative length and io.CopyBuffer
					// panics on a zero-length buffer.
					size = 32 * 1024
				}
				buf := make([]byte, size)
				return &buf
			},
		},
	}
}

// copyWithBuffer copies src to dst with io.CopyBuffer, using a buffer borrowed
// from f.bufferPool and returning it to the pool afterwards. It returns the
// number of bytes written and the first error encountered, as io.CopyBuffer does.
func (f *forwarder) copyWithBuffer(dst io.Writer, src io.Reader) (written int64, err error) {
	bufPtr := f.bufferPool.Get().(*[]byte)
	defer f.bufferPool.Put(bufPtr)
	return io.CopyBuffer(dst, src, *bufPtr)
}
// OpenForwardedChannel opens a "forwarded-tcpip" channel on the SSH
// connection. The payload announces origin as the originator and the
// forwarder's current forwarded port as the connected-to port.
//
// If ctx is done before the channel is opened, it returns an error wrapping
// ctx.Err(). A channel that finishes opening after the caller has given up is
// closed, and its request stream is drained, so it does not leak.
func (f *forwarder) OpenForwardedChannel(ctx context.Context, origin net.Addr) (ssh.Channel, <-chan *ssh.Request, error) {
	// Don't start a channel open for a caller that has already gone away.
	if err := ctx.Err(); err != nil {
		return nil, nil, fmt.Errorf("context cancelled: %w", err)
	}

	payload := createForwardedTCPIPPayload(origin, f.forwardedPort)

	type channelResult struct {
		channel ssh.Channel
		reqs    <-chan *ssh.Request
		err     error
	}
	// Unbuffered on purpose: a send only completes while the caller is still
	// waiting in the select below, so exactly one side ends up owning an
	// opened channel. With a buffer, the send could complete after the caller
	// had already returned on ctx.Done, leaking the channel and leaving its
	// requests unserviced.
	resultChan := make(chan channelResult)

	go func() {
		channel, reqs, err := f.conn.OpenChannel("forwarded-tcpip", payload)
		select {
		case resultChan <- channelResult{channel, reqs, err}:
		case <-ctx.Done():
			// The caller has given up; release the channel here.
			if channel != nil {
				_ = channel.Close()
				go ssh.DiscardRequests(reqs)
			}
		}
	}()

	select {
	case result := <-resultChan:
		return result.channel, result.reqs, result.err
	case <-ctx.Done():
		return nil, nil, fmt.Errorf("context cancelled: %w", ctx.Err())
	}
}
// closeWriter signals end-of-stream on w. If w has a CloseWrite method, that
// is used; for *net.TCPConn and ssh.Channel it ends only the sending side.
// Otherwise, if w is an io.Closer, it is closed fully. Writers with neither
// method are left untouched and nil is returned.
func closeWriter(w io.Writer) error {
	switch c := w.(type) {
	case interface{ CloseWrite() error }:
		return c.CloseWrite()
	case io.Closer:
		return c.Close()
	default:
		return nil
	}
}
// copyAndClose copies src into dst and then signals end-of-stream on dst via
// closeWriter, even when the copy failed. The direction argument labels any
// reported errors.
//
// The following are treated as normal termination and not reported:
//   - io.EOF and net.ErrClosed from the copy;
//   - io.EOF from the close.
//
// The result joins the remaining copy and close errors, or is nil when there
// are none.
func (f *forwarder) copyAndClose(dst io.Writer, src io.Reader, direction string) error {
	var copyErr error
	if _, err := f.copyWithBuffer(dst, src); err != nil {
		benign := errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)
		if !benign {
			copyErr = fmt.Errorf("copy error (%s): %w", direction, err)
		}
	}

	var closeErr error
	if err := closeWriter(dst); err != nil && !errors.Is(err, io.EOF) {
		closeErr = fmt.Errorf("close stream error (%s): %w", direction, err)
	}

	// errors.Join drops nil arguments and returns nil if both are nil.
	return errors.Join(copyErr, closeErr)
}
// HandleConnection relays data in both directions between dst and src
// concurrently. Each direction runs through copyAndClose, which signals
// end-of-stream on its destination when the copy ends. Errors are logged,
// not returned.
//
// Once both directions have finished, anything still readable from src is
// read and discarded; this blocks until reading src yields EOF or an error.
func (f *forwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
	defer func() {
		_, _ = io.Copy(io.Discard, src)
	}()

	var wg sync.WaitGroup
	relay := func(to io.Writer, from io.Reader, direction string) {
		defer wg.Done()
		if err := f.copyAndClose(to, from, direction); err != nil {
			log.Println("Error during copy: ", err)
		}
	}

	wg.Add(2)
	go relay(dst, src, "src to dst")
	go relay(src, dst, "dst to src")
	wg.Wait()
}
// SetType records the kind of tunnel this forwarder serves.
func (f *forwarder) SetType(tunnelType types.TunnelType) { f.tunnelType = tunnelType }

// TunnelType reports the value most recently passed to SetType
// (types.TunnelTypeUNKNOWN for a forwarder built by New).
func (f *forwarder) TunnelType() types.TunnelType { return f.tunnelType }

// SetForwardedPort records the port that OpenForwardedChannel announces as
// the connected-to port.
func (f *forwarder) SetForwardedPort(port uint16) { f.forwardedPort = port }

// ForwardedPort reports the value most recently passed to SetForwardedPort.
func (f *forwarder) ForwardedPort() uint16 { return f.forwardedPort }

// SetListener stores the listener that Close will close.
func (f *forwarder) SetListener(listener net.Listener) { f.listener = listener }

// Listener returns the listener stored by SetListener, or nil if none was set.
func (f *forwarder) Listener() net.Listener { return f.listener }
// Close closes the listener stored via SetListener and returns its error.
// When no listener is set it does nothing and returns nil.
func (f *forwarder) Close() error {
	l := f.Listener()
	if l == nil {
		return nil
	}
	return l.Close()
}
// createForwardedTCPIPPayload builds the ssh.Marshal-encoded extra data for a
// "forwarded-tcpip" channel open. The fields are, in order: the connected-to
// address and port, then the originator's address and port.
//
// The connected-to address is always "localhost" with destPort. The
// originator host and port come from origin.String(). The originator fields
// fall back instead of failing the channel open:
//   - host "" and port 0 when origin is nil or cannot be split into host and
//     port;
//   - port 0 when the port is not a valid 16-bit number.
func createForwardedTCPIPPayload(origin net.Addr, destPort uint16) []byte {
	var originHost string
	var originPort uint32
	if origin != nil {
		if host, portStr, err := net.SplitHostPort(origin.String()); err == nil {
			originHost = host
			// ParseUint with bitSize 16 rejects negative and out-of-range
			// values, which strconv.Atoi followed by a uint32 conversion
			// would silently wrap into huge port numbers.
			if p, err := strconv.ParseUint(portStr, 10, 16); err == nil {
				originPort = uint32(p)
			}
		}
	}

	// Field order matters: ssh.Marshal encodes fields in declaration order.
	forwardPayload := struct {
		DestAddr   string
		DestPort   uint32
		OriginAddr string
		OriginPort uint32
	}{
		DestAddr:   "localhost",
		DestPort:   uint32(destPort),
		OriginAddr: originHost,
		OriginPort: originPort,
	}
	return ssh.Marshal(forwardPayload)
}