Files
tunnel-please/session/forwarder/forwarder_test.go

1803 lines
41 KiB
Go

package forwarder
import (
"bytes"
"context"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
// mockConfig is a testify mock that the tests in this file pass to New as its
// configuration argument. Each accessor records the call on the embedded
// mock.Mock and returns the first value registered for that method.
type mockConfig struct {
	mock.Mock
}

// Accessors that return a string.
func (cfg *mockConfig) Domain() string { return cfg.Called().String(0) }
func (cfg *mockConfig) SSHPort() string { return cfg.Called().String(0) }
func (cfg *mockConfig) HTTPPort() string { return cfg.Called().String(0) }
func (cfg *mockConfig) HTTPSPort() string { return cfg.Called().String(0) }
func (cfg *mockConfig) KeyLoc() string { return cfg.Called().String(0) }
func (cfg *mockConfig) TLSStoragePath() string { return cfg.Called().String(0) }
func (cfg *mockConfig) ACMEEmail() string { return cfg.Called().String(0) }
func (cfg *mockConfig) CFAPIToken() string { return cfg.Called().String(0) }
func (cfg *mockConfig) PprofPort() string { return cfg.Called().String(0) }
func (cfg *mockConfig) GRPCAddress() string { return cfg.Called().String(0) }
func (cfg *mockConfig) GRPCPort() string { return cfg.Called().String(0) }
func (cfg *mockConfig) NodeToken() string { return cfg.Called().String(0) }

// Accessors that return a bool.
func (cfg *mockConfig) TLSEnabled() bool { return cfg.Called().Bool(0) }
func (cfg *mockConfig) TLSRedirect() bool { return cfg.Called().Bool(0) }
func (cfg *mockConfig) ACMEStaging() bool { return cfg.Called().Bool(0) }
func (cfg *mockConfig) PprofEnabled() bool { return cfg.Called().Bool(0) }

// Accessors that return an int.
func (cfg *mockConfig) BufferSize() int { return cfg.Called().Int(0) }
func (cfg *mockConfig) HeaderSize() int { return cfg.Called().Int(0) }

// Accessors whose registered value must have exactly the asserted type;
// a value of any other type makes the type assertion panic.
func (cfg *mockConfig) AllowedPortsStart() uint16 { return cfg.Called().Get(0).(uint16) }
func (cfg *mockConfig) AllowedPortsEnd() uint16 { return cfg.Called().Get(0).(uint16) }
func (cfg *mockConfig) Mode() types.ServerMode { return cfg.Called().Get(0).(types.ServerMode) }
// mockConn is a testify mock of the connection handed to New. Every method
// records its call on the embedded mock.Mock and returns the values the test
// registered with Return.
//
// Non-error return values are extracted nil-tolerantly: an expectation
// registered with an untyped nil (for example on an error path, where only the
// error matters) yields the zero value instead of panicking on the type
// assertion. A non-nil value of the wrong type still panics, so a misconfigured
// expectation is not silently hidden.
type mockConn struct {
	mock.Mock
}

// Close returns the error registered for "Close".
func (c *mockConn) Close() error { return c.Called().Error(0) }

// User returns the string registered for "User".
func (c *mockConn) User() string { return c.Called().String(0) }

// SessionID returns the bytes registered for "SessionID", or nil.
func (c *mockConn) SessionID() []byte {
	if v := c.Called().Get(0); v != nil {
		return v.([]byte)
	}
	return nil
}

// ClientVersion returns the bytes registered for "ClientVersion", or nil.
func (c *mockConn) ClientVersion() []byte {
	if v := c.Called().Get(0); v != nil {
		return v.([]byte)
	}
	return nil
}

// ServerVersion returns the bytes registered for "ServerVersion", or nil.
func (c *mockConn) ServerVersion() []byte {
	if v := c.Called().Get(0); v != nil {
		return v.([]byte)
	}
	return nil
}

// RemoteAddr returns the address registered for "RemoteAddr", or nil.
func (c *mockConn) RemoteAddr() net.Addr {
	if v := c.Called().Get(0); v != nil {
		return v.(net.Addr)
	}
	return nil
}

// LocalAddr returns the address registered for "LocalAddr", or nil.
func (c *mockConn) LocalAddr() net.Addr {
	if v := c.Called().Get(0); v != nil {
		return v.(net.Addr)
	}
	return nil
}

// SendRequest records the call with its arguments and returns the registered
// (ok, payload, error) triple. A nil payload is returned as a nil slice.
func (c *mockConn) SendRequest(s string, b bool, d []byte) (bool, []byte, error) {
	args := c.Called(s, b, d)
	var payload []byte
	if v := args.Get(1); v != nil {
		payload = v.([]byte)
	}
	return args.Bool(0), payload, args.Error(2)
}

// Wait returns the error registered for "Wait".
func (c *mockConn) Wait() error { return c.Called().Error(0) }

// OpenChannel records the call with its arguments and returns the registered
// (channel, requests, error) triple. Either of the first two values may be
// registered as nil, in which case nil is returned for it.
func (c *mockConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
	args := c.Called(name, data)
	var (
		ch   ssh.Channel
		reqs <-chan *ssh.Request
	)
	if v := args.Get(0); v != nil {
		ch = v.(ssh.Channel)
	}
	if v := args.Get(1); v != nil {
		reqs = v.(<-chan *ssh.Request)
	}
	return ch, reqs, args.Error(2)
}
// testChannel is a partly real, partly mocked channel used by the tests.
// Read, Write and CloseWrite operate on in-memory syncBuffers; Close, Stderr,
// SendRequest and AckRequest are delegated to the embedded testify mock.
type testChannel struct {
	mock.Mock
	readBuf     *syncBuffer // source for Read
	writeBuf    *syncBuffer // sink for Write; closed by CloseWrite
	closedWrite atomic.Bool // set to true once CloseWrite has been called
}

// Read takes bytes from readBuf.
func (tc *testChannel) Read(b []byte) (int, error) {
	return tc.readBuf.Read(b)
}

// Write appends bytes to writeBuf.
func (tc *testChannel) Write(b []byte) (int, error) {
	return tc.writeBuf.Write(b)
}

// Close returns the error registered for "Close" on the mock.
func (tc *testChannel) Close() error {
	return tc.Called().Error(0)
}

// CloseWrite records that it was called and then closes writeBuf.
func (tc *testChannel) CloseWrite() error {
	tc.closedWrite.Store(true)
	return tc.writeBuf.Close()
}

// Stderr returns the io.ReadWriter registered for "Stderr" on the mock.
func (tc *testChannel) Stderr() io.ReadWriter {
	return tc.Called().Get(0).(io.ReadWriter)
}

// SendRequest records the call and returns the registered (ok, error) pair.
func (tc *testChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
	ret := tc.Called(name, wantReply, payload)
	return ret.Bool(0), ret.Error(1)
}

// AckRequest records the call and returns the registered error.
func (tc *testChannel) AckRequest(ok bool, payload []byte) error {
	return tc.Called(ok, payload).Error(0)
}
// syncBuffer is a mutex-guarded in-memory byte queue with blocking reads.
// Writers append to buf; a reader that finds buf empty waits on cond until
// data arrives or the buffer is closed.
type syncBuffer struct {
	mu     sync.Mutex
	buf    []byte
	closed bool
	cond   *sync.Cond
}

// newSyncBuffer returns an empty, open syncBuffer whose condition variable is
// bound to its own mutex.
func newSyncBuffer() *syncBuffer {
	b := new(syncBuffer)
	b.cond = sync.NewCond(&b.mu)
	return b
}

// Write appends p and wakes any waiting readers. After Close it stores
// nothing and returns io.ErrClosedPipe.
func (b *syncBuffer) Write(p []byte) (int, error) {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		return 0, io.ErrClosedPipe
	}
	b.buf = append(b.buf, p...)
	b.cond.Broadcast()
	return len(p), nil
}

// Read blocks until data is available or the buffer is closed. Buffered data
// is still delivered after Close; io.EOF is returned only once the buffer is
// both closed and empty.
func (b *syncBuffer) Read(p []byte) (int, error) {
	b.mu.Lock()
	defer b.mu.Unlock()

	for len(b.buf) == 0 && !b.closed {
		b.cond.Wait()
	}
	if len(b.buf) == 0 {
		// The loop only exits with an empty buffer when it has been closed.
		return 0, io.EOF
	}
	n := copy(p, b.buf)
	b.buf = b.buf[n:]
	return n, nil
}

// Close marks the buffer closed and wakes every waiting reader. It always
// returns nil.
func (b *syncBuffer) Close() error {
	b.mu.Lock()
	defer b.mu.Unlock()

	b.closed = true
	b.cond.Broadcast()
	return nil
}
// newChannelPair returns a testChannel and the peer wired to its other end.
// Two syncBuffers are cross-connected: whatever the peer writes the channel
// reads, and vice versa. The channel's mocked Close is pre-registered to
// return nil and is optional (Maybe), so tests need not call it.
func newChannelPair() (*testChannel, *testChannelPeer) {
	toChannel, toPeer := newSyncBuffer(), newSyncBuffer()

	ch := &testChannel{readBuf: toChannel, writeBuf: toPeer}
	ch.On("Close").Return(nil).Maybe()

	return ch, &testChannelPeer{readBuf: toPeer, writeBuf: toChannel}
}
// testChannelPeer is the test-side end of a pair created by newChannelPair.
// It reads from one syncBuffer and writes to another.
type testChannelPeer struct {
	readBuf  *syncBuffer
	writeBuf *syncBuffer
}

// Read takes bytes from readBuf.
func (peer *testChannelPeer) Read(b []byte) (int, error) {
	return peer.readBuf.Read(b)
}

// Write appends bytes to writeBuf.
func (peer *testChannelPeer) Write(b []byte) (int, error) {
	return peer.writeBuf.Write(b)
}

// CloseWrite closes writeBuf, so readers of that buffer see io.EOF once it
// has been drained.
func (peer *testChannelPeer) CloseWrite() error {
	return peer.writeBuf.Close()
}
// newPipePair returns two pipeConn values joined by a pair of io.Pipes:
// bytes written to one are read from the other, in both directions.
func newPipePair() (*pipeConn, *pipeConn) {
	firstRead, secondWrite := io.Pipe()
	secondRead, firstWrite := io.Pipe()

	first := &pipeConn{reader: firstRead, writer: firstWrite}
	second := &pipeConn{reader: secondRead, writer: secondWrite}
	return first, second
}

// pipeConn is an in-memory connection built from one pipe's read end and
// another pipe's write end.
type pipeConn struct {
	reader *io.PipeReader
	writer *io.PipeWriter
}

// Read reads from the pipe's read end.
func (pc *pipeConn) Read(b []byte) (int, error) {
	return pc.reader.Read(b)
}

// Write writes to the pipe's write end.
func (pc *pipeConn) Write(b []byte) (int, error) {
	return pc.writer.Write(b)
}

// Close closes the read end and then the write end. If closing the read end
// fails, that error is returned and the write end is left untouched.
func (pc *pipeConn) Close() error {
	if err := pc.reader.Close(); err != nil {
		return err
	}
	return pc.writer.Close()
}

// CloseWrite closes only the write end.
func (pc *pipeConn) CloseWrite() error {
	return pc.writer.Close()
}

// LocalAddr reports a fixed placeholder address, 127.0.0.1 port 0.
func (pc *pipeConn) LocalAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}

// RemoteAddr reports a fixed placeholder address, 127.0.0.1 port 0.
func (pc *pipeConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}

// The deadline setters are no-ops that always return nil.
func (pc *pipeConn) SetDeadline(_ time.Time) error { return nil }
func (pc *pipeConn) SetReadDeadline(_ time.Time) error { return nil }
func (pc *pipeConn) SetWriteDeadline(_ time.Time) error { return nil }
// TestNew checks the state of a freshly constructed forwarder: buffers taken
// from its pool have the configured length, the tunnel type is UNKNOWN, the
// forwarded port is zero, and the supplied connection and slug are stored.
func TestNew(t *testing.T) {
	type testCase struct {
		name       string
		bufferSize int
		wantBufLen int
	}
	cases := []testCase{
		{name: "default buffer size", bufferSize: 16, wantBufLen: 16},
		{name: "custom buffer size", bufferSize: 32, wantBufLen: 32},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(tc.bufferSize).Maybe()
			sl := slug.New()
			conn := new(mockConn)

			fwd := New(cfg, sl, conn).(*forwarder)

			// Borrow one buffer from the pool to inspect its length, then return it.
			pooled := fwd.bufferPool.Get().([]byte)
			require.Len(t, pooled, tc.wantBufLen)
			fwd.bufferPool.Put(pooled)

			assert.Equal(t, types.TunnelTypeUNKNOWN, fwd.TunnelType())
			assert.Equal(t, uint16(0), fwd.ForwardedPort())
			assert.Equal(t, conn, fwd.conn)
			assert.Equal(t, sl, fwd.slug)
			cfg.AssertExpectations(t)
		})
	}
}
// TestHandleConnection runs HandleConnection between an in-memory channel and
// an in-memory destination connection and checks that bytes written on either
// peer arrive unchanged on the other. After both peers close their write
// side, HandleConnection must return within two seconds and must have called
// CloseWrite on the channel.
func TestHandleConnection(t *testing.T) {
	// readResult carries the outcome of one background io.ReadFull.
	type readResult struct {
		data []byte
		err  error
	}
	// readAsync starts reading exactly want bytes from r in a goroutine and
	// returns a buffered channel that receives the single result.
	readAsync := func(r io.Reader, want int) <-chan readResult {
		out := make(chan readResult, 1)
		go func() {
			buf := make([]byte, want)
			n, err := io.ReadFull(r, buf)
			out <- readResult{data: buf[:n], err: err}
		}()
		return out
	}

	cases := []struct {
		name         string
		bufferSize   int
		messageToDst []byte
		messageToSrc []byte
	}{
		{name: "small messages", bufferSize: 4, messageToDst: []byte("hi"), messageToSrc: []byte("yo")},
		{name: "medium messages", bufferSize: 8, messageToDst: []byte("hello"), messageToSrc: []byte("world")},
		{name: "larger messages", bufferSize: 16, messageToDst: []byte("I love femboy"), messageToSrc: []byte("mee too")},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(tc.bufferSize).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			channel, channelPeer := newChannelPair()
			dstEndpoint, dstPeer := newPipePair()

			finished := make(chan struct{})
			go func() {
				fwd.HandleConnection(dstEndpoint, channel)
				close(finished)
			}()

			// Channel -> destination. The reader is started first because
			// writes into an io.Pipe block until they are consumed.
			fromDst := readAsync(dstPeer, len(tc.messageToDst))
			_, err := channelPeer.Write(tc.messageToDst)
			require.NoError(t, err)
			gotDst := <-fromDst
			require.NoError(t, gotDst.err)
			assert.Equal(t, tc.messageToDst, gotDst.data)

			// Destination -> channel.
			fromSrc := readAsync(channelPeer, len(tc.messageToSrc))
			_, err = dstPeer.Write(tc.messageToSrc)
			require.NoError(t, err)
			gotSrc := <-fromSrc
			require.NoError(t, gotSrc.err)
			assert.Equal(t, tc.messageToSrc, gotSrc.data)

			// Half-close both directions so each copy loop sees end of input.
			require.NoError(t, channelPeer.CloseWrite())
			require.NoError(t, dstPeer.CloseWrite())

			select {
			case <-finished:
			case <-time.After(2 * time.Second):
				t.Fatal("HandleConnection did not complete")
			}

			assert.True(t, channel.closedWrite.Load())
			cfg.AssertExpectations(t)
		})
	}
}
// TestHandleConnection_Error starts HandleConnection in the background and
// then closes the destination endpoint, asserting only that the Close call
// itself succeeds.
//
// NOTE(review): the HandleConnection goroutine is never awaited and both peers
// are discarded, so the test neither observes how HandleConnection reacts nor
// guarantees that the goroutine exits — confirm whether that is intended. The
// messageToDst/messageToSrc table fields are not used by the test body.
func TestHandleConnection_Error(t *testing.T) {
	cases := []struct {
		name         string
		bufferSize   int
		messageToDst []byte
		messageToSrc []byte
	}{
		{name: "small messages", bufferSize: 4, messageToDst: []byte("hi"), messageToSrc: []byte("yo")},
		{name: "medium messages", bufferSize: 8, messageToDst: []byte("hello"), messageToSrc: []byte("world")},
		{name: "larger messages", bufferSize: 16, messageToDst: []byte("I love femboy"), messageToSrc: []byte("mee too")},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(tc.bufferSize).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			channel, _ := newChannelPair()
			dstEndpoint, _ := newPipePair()

			go fwd.HandleConnection(dstEndpoint, channel)

			closeErr := dstEndpoint.Close()
			assert.NoError(t, closeErr)
			cfg.AssertExpectations(t)
		})
	}
}
// TestOpenForwardedChannel verifies that OpenForwardedChannel opens a
// "forwarded-tcpip" channel on the connection, returns the channel and request
// stream the connection produced, and sends a payload that decodes to the
// expected destination and origin address/port.
func TestOpenForwardedChannel(t *testing.T) {
	// expectation holds the values the decoded payload must contain.
	type expectation struct {
		destAddr string
		destPort uint32
		origAddr string
		origPort uint32
	}
	cases := []struct {
		name          string
		forwardedPort uint16
		originIP      string
		originPort    int
		want          expectation
	}{
		{"localhost origin", 2222, "127.0.0.1", 9000, expectation{"localhost", 2222, "127.0.0.1", 9000}},
		{"remote origin", 8080, "192.168.1.100", 5000, expectation{"localhost", 8080, "192.168.1.100", 5000}},
		{"different port", 3000, "10.0.0.1", 7777, expectation{"localhost", 3000, "10.0.0.1", 7777}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(8).Maybe()

			sshChannel := &testChannel{readBuf: newSyncBuffer(), writeBuf: newSyncBuffer()}
			reqCh := make(chan *ssh.Request)

			// Keep a private copy of the payload handed to OpenChannel.
			var sent []byte
			conn := new(mockConn)
			conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).
				Run(func(args mock.Arguments) {
					raw := args.Get(1).([]byte)
					sent = append(make([]byte, 0, len(raw)), raw...)
				}).
				Return(sshChannel, (<-chan *ssh.Request)(reqCh), nil)

			fwd := New(cfg, slug.New(), conn).(*forwarder)
			fwd.SetForwardedPort(tc.forwardedPort)

			origin := &net.TCPAddr{IP: net.ParseIP(tc.originIP), Port: tc.originPort}
			gotChannel, gotReqs, err := fwd.OpenForwardedChannel(context.Background(), origin)
			require.NoError(t, err)
			assert.Same(t, sshChannel, gotChannel)
			assert.NotNil(t, gotReqs)

			var decoded struct {
				DestAddr   string
				DestPort   uint32
				OriginAddr string
				OriginPort uint32
			}
			err = ssh.Unmarshal(sent, &decoded)
			assert.NoError(t, err)
			assert.Equal(t, tc.want.destAddr, decoded.DestAddr)
			assert.Equal(t, tc.want.destPort, decoded.DestPort)
			assert.Equal(t, tc.want.origAddr, decoded.OriginAddr)
			assert.Equal(t, tc.want.origPort, decoded.OriginPort)

			conn.AssertExpectations(t)
			cfg.AssertExpectations(t)
		})
	}
}
func TestOpenForwardedChannelContextCancellation(t *testing.T) {
tests := []struct {
name string
cancelBefore bool
cancelDuring bool
wantErr bool
wantErrType error
}{
{
name: "cancel during open",
cancelBefore: false,
cancelDuring: true,
wantErr: true,
wantErrType: context.Canceled,
},
{
name: "cancel before open",
cancelBefore: true,
cancelDuring: false,
wantErr: true,
wantErrType: context.Canceled,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &mockConfig{}
cfg.On("BufferSize").Return(8).Maybe()
channel := &testChannel{
readBuf: newSyncBuffer(),
writeBuf: newSyncBuffer(),
}
channel.On("Close").Return(nil)
requests := make(chan *ssh.Request)
openChannelCalled := make(chan struct{})
openChannelBlock := make(chan struct{})
conn := &mockConn{}
conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Run(func(args mock.Arguments) {
close(openChannelCalled)
<-openChannelBlock
}).Return(channel, (<-chan *ssh.Request)(requests), nil).Maybe()
forwarder := New(cfg, slug.New(), conn).(*forwarder)
forwarder.SetForwardedPort(8080)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if tt.cancelBefore {
cancel()
}
var (
openedChannel ssh.Channel
openedReqs <-chan *ssh.Request
openErr error
)
done := make(chan struct{})
go func() {
origin := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 7000}
openedChannel, openedReqs, openErr = forwarder.OpenForwardedChannel(ctx, origin)
close(done)
}()
if tt.cancelDuring {
<-openChannelCalled
cancel()
}
close(openChannelBlock)
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("OpenForwardedChannel did not return after cancellation")
}
if tt.wantErr {
require.Error(t, openErr)
assert.True(t, errors.Is(openErr, tt.wantErrType))
assert.Nil(t, openedChannel)
assert.Nil(t, openedReqs)
} else {
require.NoError(t, openErr)
assert.NotNil(t, openedChannel)
assert.NotNil(t, openedReqs)
}
conn.AssertExpectations(t)
cfg.AssertExpectations(t)
})
}
}
// TestCreateForwardedTCPIPPayload decodes the bytes produced by
// createForwardedTCPIPPayload with ssh.Unmarshal and checks the four fields:
// destination address and port, origin address and port.
func TestCreateForwardedTCPIPPayload(t *testing.T) {
	// wirePayload lists the fields in the order they are decoded.
	type wirePayload struct {
		DestAddr   string
		DestPort   uint32
		OriginAddr string
		OriginPort uint32
	}

	cases := []struct {
		name          string
		originIP      string
		originPort    int
		forwardedPort uint16
		want          wirePayload
	}{
		{"standard case", "192.0.2.10", 5050, 8080, wirePayload{"localhost", 8080, "192.0.2.10", 5050}},
		{"localhost origin", "127.0.0.1", 3000, 9000, wirePayload{"localhost", 9000, "127.0.0.1", 3000}},
		{"high port numbers", "10.0.0.1", 65535, 65534, wirePayload{"localhost", 65534, "10.0.0.1", 65535}},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			origin := &net.TCPAddr{IP: net.ParseIP(tc.originIP), Port: tc.originPort}
			raw := createForwardedTCPIPPayload(origin, tc.forwardedPort)

			var got wirePayload
			err := ssh.Unmarshal(raw, &got)
			assert.NoError(t, err)
			assert.Equal(t, tc.want.DestAddr, got.DestAddr)
			assert.Equal(t, tc.want.DestPort, got.DestPort)
			assert.Equal(t, tc.want.OriginAddr, got.OriginAddr)
			assert.Equal(t, tc.want.OriginPort, got.OriginPort)
		})
	}
}
// mockReader is a testify mock with a Read method. Read never touches the
// supplied slice; it only returns the (n, err) pair registered for the call.
type mockReader struct {
	mock.Mock
}

// Read records the call and returns the registered byte count and error.
func (r *mockReader) Read(p []byte) (int, error) {
	ret := r.Called(p)
	return ret.Int(0), ret.Error(1)
}
// mockWriter is a testify mock with Write and CloseWrite methods (it has no
// Close method).
type mockWriter struct {
	mock.Mock
}

// Write records the call and returns the registered byte count and error.
func (w *mockWriter) Write(p []byte) (int, error) {
	ret := w.Called(p)
	return ret.Int(0), ret.Error(1)
}

// CloseWrite records the call and returns the registered error.
func (w *mockWriter) CloseWrite() error {
	return w.Called().Error(0)
}
// mockWriteCloser is a testify mock with Write and Close methods (it has no
// CloseWrite method).
type mockWriteCloser struct {
	mock.Mock
}

// Write records the call and returns the registered byte count and error.
func (wc *mockWriteCloser) Write(p []byte) (int, error) {
	ret := wc.Called(p)
	return ret.Int(0), ret.Error(1)
}

// Close records the call and returns the registered error.
func (wc *mockWriteCloser) Close() error {
	return wc.Called().Error(0)
}
// TestCopyAndClose drives copyAndClose with scripted mock readers and writers.
// It covers a clean copy ending in EOF, read errors (including net.ErrClosed,
// which is expected to be ignored), errors from closing the destination
// (including io.EOF, which is expected to be ignored), and destinations that
// expose Close instead of CloseWrite.
func TestCopyAndClose(t *testing.T) {
	// step is one scripted (n, err) result for a mocked Read or Write call.
	type step struct {
		n   int
		err error
	}
	// scriptedReader returns a factory for a mockReader whose Read calls
	// yield the given steps in order, each exactly once.
	scriptedReader := func(steps ...step) func() io.Reader {
		return func() io.Reader {
			r := &mockReader{}
			for _, s := range steps {
				r.On("Read", mock.Anything).Return(s.n, s.err).Once()
			}
			return r
		}
	}
	// halfCloser returns a factory for a mockWriter that accepts the scripted
	// writes and whose single CloseWrite call returns closeErr.
	halfCloser := func(closeErr error, writes ...step) func() io.Writer {
		return func() io.Writer {
			w := &mockWriter{}
			for _, s := range writes {
				w.On("Write", mock.Anything).Return(s.n, s.err).Once()
			}
			w.On("CloseWrite").Return(closeErr).Once()
			return w
		}
	}
	// fullCloser is like halfCloser but builds a mockWriteCloser, whose
	// single Close call returns closeErr.
	fullCloser := func(closeErr error, writes ...step) func() io.Writer {
		return func() io.Writer {
			w := &mockWriteCloser{}
			for _, s := range writes {
				w.On("Write", mock.Anything).Return(s.n, s.err).Once()
			}
			w.On("Close").Return(closeErr).Once()
			return w
		}
	}

	cases := []struct {
		name          string
		setupSrc      func() io.Reader
		setupDst      func() io.Writer
		direction     string
		wantErr       bool
		wantErrMsg    string
		checkErrTypes []error
	}{
		{
			name:      "successful copy with EOF",
			setupSrc:  scriptedReader(step{5, nil}, step{0, io.EOF}),
			setupDst:  halfCloser(nil, step{5, nil}),
			direction: "src->dst",
			wantErr:   false,
		},
		{
			name:       "copy error - not EOF or ErrClosed",
			setupSrc:   scriptedReader(step{0, errors.New("custom read error")}),
			setupDst:   halfCloser(nil),
			direction:  "src->dst",
			wantErr:    true,
			wantErrMsg: "copy error (src->dst)",
		},
		{
			name:      "copy error - ErrClosed should be ignored",
			setupSrc:  scriptedReader(step{0, net.ErrClosed}),
			setupDst:  halfCloser(nil),
			direction: "src->dst",
			wantErr:   false,
		},
		{
			name:       "close writer error - not EOF",
			setupSrc:   scriptedReader(step{0, io.EOF}),
			setupDst:   halfCloser(errors.New("close error")),
			direction:  "src->dst",
			wantErr:    true,
			wantErrMsg: "close stream error (src->dst)",
		},
		{
			name:      "close writer error - EOF should be ignored",
			setupSrc:  scriptedReader(step{0, io.EOF}),
			setupDst:  halfCloser(io.EOF),
			direction: "src->dst",
			wantErr:   false,
		},
		{
			name:       "both copy and close errors",
			setupSrc:   scriptedReader(step{0, errors.New("copy error")}),
			setupDst:   halfCloser(errors.New("close error")),
			direction:  "src->dst",
			wantErr:    true,
			wantErrMsg: "copy error (src->dst)",
		},
		{
			name:      "successful copy with WriteCloser",
			setupSrc:  scriptedReader(step{5, nil}, step{0, io.EOF}),
			setupDst:  fullCloser(nil, step{5, nil}),
			direction: "dst->src",
			wantErr:   false,
		},
		{
			name:       "WriteCloser close error",
			setupSrc:   scriptedReader(step{0, io.EOF}),
			setupDst:   fullCloser(errors.New("writeCloser close error")),
			direction:  "dst->src",
			wantErr:    true,
			wantErrMsg: "close stream error (dst->src)",
		},
		{
			name:      "copy with multiple reads before EOF",
			setupSrc:  scriptedReader(step{10, nil}, step{15, nil}, step{5, nil}, step{0, io.EOF}),
			setupDst:  halfCloser(nil, step{10, nil}, step{15, nil}, step{5, nil}),
			direction: "src->dst",
			wantErr:   false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(32).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			src := tc.setupSrc()
			dst := tc.setupDst()

			err := fwd.copyAndClose(dst, src, tc.direction)
			if tc.wantErr {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tc.wantErrMsg)
			} else {
				assert.NoError(t, err)
			}

			// Every scripted call must have been consumed.
			if r, ok := src.(*mockReader); ok {
				r.AssertExpectations(t)
			}
			switch w := dst.(type) {
			case *mockWriter:
				w.AssertExpectations(t)
			case *mockWriteCloser:
				w.AssertExpectations(t)
			}
			cfg.AssertExpectations(t)
		})
	}
}
// TestCopyAndCloseJoinedErrors makes both the read and the CloseWrite fail and
// checks that the error returned by copyAndClose mentions both failures: the
// two wrapper prefixes as well as the two underlying messages.
func TestCopyAndCloseJoinedErrors(t *testing.T) {
	cfg := new(mockConfig)
	cfg.On("BufferSize").Return(32).Maybe()
	fwd := New(cfg, slug.New(), nil).(*forwarder)

	reader := new(mockReader)
	reader.On("Read", mock.Anything).Return(0, errors.New("copy failed")).Once()

	writer := new(mockWriter)
	writer.On("CloseWrite").Return(errors.New("close failed")).Once()

	err := fwd.copyAndClose(writer, reader, "test")
	require.Error(t, err)

	msg := err.Error()
	for _, fragment := range []string{
		"copy error (test)",
		"close stream error (test)",
		"copy failed",
		"close failed",
	} {
		assert.Contains(t, msg, fragment)
	}

	reader.AssertExpectations(t)
	writer.AssertExpectations(t)
	cfg.AssertExpectations(t)
}
// TestCopyWithBuffer exercises copyWithBuffer with several pool buffer sizes
// and inputs. On success it checks the returned byte count and, when the
// destination is a bytes.Buffer, the number of bytes that landed in it. On
// failure it only checks that an error is returned.
func TestCopyWithBuffer(t *testing.T) {
	// fromBytes returns a factory for a reader over data, wrapped in
	// io.NopCloser.
	fromBytes := func(data []byte) func() io.Reader {
		return func() io.Reader {
			return io.NopCloser(bytes.NewReader(data))
		}
	}
	// toBuffer returns a fresh, empty bytes.Buffer destination.
	toBuffer := func() io.Writer {
		return &bytes.Buffer{}
	}
	// sequential returns a factory for a reader over n bytes holding
	// 0, 1, ..., 255, 0, 1, ...
	sequential := func(n int) func() io.Reader {
		return func() io.Reader {
			data := make([]byte, n)
			for i := range data {
				data[i] = byte(i % 256)
			}
			return io.NopCloser(bytes.NewReader(data))
		}
	}

	cases := []struct {
		name           string
		bufferSize     int
		setupSrc       func() io.Reader
		setupDst       func() io.Writer
		wantBytesCount int64
		wantErr        bool
		wantErrType    error
	}{
		{
			name:           "successful copy small data",
			bufferSize:     16,
			setupSrc:       fromBytes([]byte("hello world")),
			setupDst:       toBuffer,
			wantBytesCount: 11,
			wantErr:        false,
		},
		{
			name:           "successful copy large data",
			bufferSize:     8,
			setupSrc:       sequential(1024),
			setupDst:       toBuffer,
			wantBytesCount: 1024,
			wantErr:        false,
		},
		{
			name:           "empty data",
			bufferSize:     16,
			setupSrc:       fromBytes([]byte{}),
			setupDst:       toBuffer,
			wantBytesCount: 0,
			wantErr:        false,
		},
		{
			name:       "read error",
			bufferSize: 16,
			setupSrc: func() io.Reader {
				failing := &mockReader{}
				failing.On("Read", mock.Anything).Return(0, errors.New("read error")).Once()
				return failing
			},
			setupDst:       toBuffer,
			wantBytesCount: 0,
			wantErr:        true,
		},
		{
			name:       "write error",
			bufferSize: 16,
			setupSrc:   fromBytes([]byte("test data")),
			setupDst: func() io.Writer {
				failing := &mockWriter{}
				failing.On("Write", mock.Anything).Return(0, errors.New("write error")).Once()
				return failing
			},
			wantBytesCount: 0,
			wantErr:        true,
		},
		{
			name:           "partial write continues",
			bufferSize:     16,
			setupSrc:       fromBytes([]byte("testing")),
			setupDst:       toBuffer,
			wantBytesCount: 7,
			wantErr:        false,
		},
		{
			name:           "multiple buffer fills",
			bufferSize:     4,
			setupSrc:       fromBytes([]byte("this is a longer message")),
			setupDst:       toBuffer,
			wantBytesCount: 24,
			wantErr:        false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(tc.bufferSize).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			src := tc.setupSrc()
			dst := tc.setupDst()

			copied, err := fwd.copyWithBuffer(dst, src)
			if tc.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
				assert.Equal(t, tc.wantBytesCount, copied)
			}

			switch out := dst.(type) {
			case *bytes.Buffer:
				if !tc.wantErr {
					assert.Equal(t, tc.wantBytesCount, int64(out.Len()))
				}
			case *mockWriter:
				out.AssertExpectations(t)
			}
			if in, ok := src.(*mockReader); ok {
				in.AssertExpectations(t)
			}
			cfg.AssertExpectations(t)
		})
	}
}
// TestCopyWithBufferReusesBuffer checks that the forwarder's buffer pool still
// hands out buffers of the configured length after copyWithBuffer has run.
//
// It deliberately does not compare buffer identity before and after the copy:
// sync.Pool gives no guarantee that a Put value is the one returned by the
// next Get, so only the length is asserted.
func TestCopyWithBufferReusesBuffer(t *testing.T) {
	const bufferSize = 16

	cfg := &mockConfig{}
	cfg.On("BufferSize").Return(bufferSize).Maybe()
	fwd := New(cfg, slug.New(), nil).(*forwarder)

	// Buffer length before any copy has taken place.
	before := fwd.bufferPool.Get().([]byte)
	require.Len(t, before, bufferSize)
	fwd.bufferPool.Put(before)

	src := io.NopCloser(bytes.NewReader([]byte("test")))
	dst := &bytes.Buffer{}
	_, err := fwd.copyWithBuffer(dst, src)
	require.NoError(t, err)

	// Buffer length after the copy must be unchanged.
	after := fwd.bufferPool.Get().([]byte)
	assert.Len(t, after, bufferSize)
	assert.Equal(t, len(before), len(after))
	fwd.bufferPool.Put(after)

	cfg.AssertExpectations(t)
}
// TestSetType checks that a new forwarder starts with TunnelTypeUNKNOWN and
// that TunnelType reports whatever SetType stored.
func TestSetType(t *testing.T) {
	cases := []struct {
		name       string
		tunnelType types.TunnelType
	}{
		{"set to HTTP", types.TunnelTypeHTTP},
		{"set to TCP", types.TunnelTypeTCP},
		{"set to UNKNOWN", types.TunnelTypeUNKNOWN},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(16).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			assert.Equal(t, types.TunnelTypeUNKNOWN, fwd.TunnelType())
			fwd.SetType(tc.tunnelType)
			assert.Equal(t, tc.tunnelType, fwd.TunnelType())

			cfg.AssertExpectations(t)
		})
	}
}
// TestTunnelType checks that TunnelType returns the value most recently passed
// to SetType.
func TestTunnelType(t *testing.T) {
	cases := []struct {
		name       string
		tunnelType types.TunnelType
	}{
		{"get HTTP type", types.TunnelTypeHTTP},
		{"get TCP type", types.TunnelTypeTCP},
		{"get UNKNOWN type", types.TunnelTypeUNKNOWN},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(16).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			fwd.SetType(tc.tunnelType)
			got := fwd.TunnelType()
			assert.Equal(t, tc.tunnelType, got)

			cfg.AssertExpectations(t)
		})
	}
}
// TestSetForwardedPort checks that a new forwarder reports port 0 and that
// ForwardedPort reports whatever SetForwardedPort stored.
func TestSetForwardedPort(t *testing.T) {
	cases := []struct {
		name string
		port uint16
	}{
		{"set standard port", 8080},
		{"set low port", 80},
		{"set high port", 65535},
		{"set zero port", 0},
		{"set custom port", 3000},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(16).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			assert.Equal(t, uint16(0), fwd.ForwardedPort())
			fwd.SetForwardedPort(tc.port)
			assert.Equal(t, tc.port, fwd.ForwardedPort())

			cfg.AssertExpectations(t)
		})
	}
}
// TestForwardedPort checks the ForwardedPort getter. For the zero-port case
// SetForwardedPort is not called at all, so the value read back is the one a
// new forwarder starts with.
func TestForwardedPort(t *testing.T) {
	cases := []struct {
		name string
		port uint16
	}{
		{"get default port", 0},
		{"get standard port", 8080},
		{"get high port", 65535},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(16).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			if tc.port != 0 {
				fwd.SetForwardedPort(tc.port)
			}
			got := fwd.ForwardedPort()
			assert.Equal(t, tc.port, got)

			cfg.AssertExpectations(t)
		})
	}
}
// TestSetListener checks that a new forwarder has no listener and that
// Listener returns whatever SetListener stored, including nil.
//
// The setup functions receive the subtest's *testing.T. require.NoError ends
// the test via FailNow, which must run on the goroutine of the test it is
// called on; using the parent test's t from inside a subtest would break that
// rule and attribute the failure to the wrong test.
func TestSetListener(t *testing.T) {
	tests := []struct {
		name          string
		setupListener func(t *testing.T) net.Listener
	}{
		{
			name: "set TCP listener",
			setupListener: func(t *testing.T) net.Listener {
				t.Helper()
				listener, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)
				return listener
			},
		},
		{
			name: "set nil listener",
			setupListener: func(*testing.T) net.Listener {
				return nil
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &mockConfig{}
			cfg.On("BufferSize").Return(16).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			listener := tt.setupListener(t)
			if listener != nil {
				// Release the port once the subtest is over.
				t.Cleanup(func() {
					assert.NoError(t, listener.Close())
				})
			}

			assert.Nil(t, fwd.Listener())
			fwd.SetListener(listener)
			assert.Equal(t, listener, fwd.Listener())

			cfg.AssertExpectations(t)
		})
	}
}
// TestListener checks the Listener getter: it returns nil when nothing was
// set and otherwise the listener passed to SetListener.
//
// The setup functions receive the subtest's *testing.T. require.NoError ends
// the test via FailNow, which must run on the goroutine of the test it is
// called on; using the parent test's t from inside a subtest would break that
// rule and attribute the failure to the wrong test.
func TestListener(t *testing.T) {
	tests := []struct {
		name          string
		setupListener func(t *testing.T) net.Listener
	}{
		{
			name: "get nil listener",
			setupListener: func(*testing.T) net.Listener {
				return nil
			},
		},
		{
			name: "get TCP listener",
			setupListener: func(t *testing.T) net.Listener {
				t.Helper()
				listener, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)
				return listener
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &mockConfig{}
			cfg.On("BufferSize").Return(16).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			listener := tt.setupListener(t)
			if listener != nil {
				// Release the port once the subtest is over.
				t.Cleanup(func() {
					assert.NoError(t, listener.Close())
				})
				fwd.SetListener(listener)
			}

			result := fwd.Listener()
			assert.Equal(t, listener, result)

			cfg.AssertExpectations(t)
		})
	}
}
// TestClose checks forwarder.Close with no listener, with an open listener,
// and with a listener that was already closed; only the last case expects an
// error.
//
// The setup functions receive the subtest's *testing.T so that a setup failure
// stops and is reported on the subtest that hit it. require's FailNow must run
// on the goroutine of the test it is called on, which the parent test's t does
// not satisfy inside a subtest.
func TestClose(t *testing.T) {
	tests := []struct {
		name          string
		setupListener func(t *testing.T) net.Listener
		wantErr       bool
	}{
		{
			name: "close with nil listener",
			setupListener: func(*testing.T) net.Listener {
				return nil
			},
			wantErr: false,
		},
		{
			name: "close with active listener",
			setupListener: func(t *testing.T) net.Listener {
				t.Helper()
				listener, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)
				return listener
			},
			wantErr: false,
		},
		{
			name: "close already closed listener",
			setupListener: func(t *testing.T) net.Listener {
				t.Helper()
				listener, err := net.Listen("tcp", "127.0.0.1:0")
				require.NoError(t, err)
				// Close it up front so the forwarder's Close hits an
				// already-closed listener.
				err = listener.Close()
				assert.NoError(t, err)
				return listener
			},
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &mockConfig{}
			cfg.On("BufferSize").Return(16).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			if listener := tt.setupListener(t); listener != nil {
				fwd.SetListener(listener)
			}

			err := fwd.Close()
			if tt.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}

			cfg.AssertExpectations(t)
		})
	}
}
// TestCloseWriter checks closeWriter against three kinds of writer: one with a
// CloseWrite method, one with only a Close method, and a plain bytes.Buffer
// with neither. The error from the close method must be passed through, and
// the plain writer must yield no error.
func TestCloseWriter(t *testing.T) {
	// withCloseWrite builds a mockWriter whose single CloseWrite call
	// returns err.
	withCloseWrite := func(err error) func() io.Writer {
		return func() io.Writer {
			w := &mockWriter{}
			w.On("CloseWrite").Return(err).Once()
			return w
		}
	}
	// withClose builds a mockWriteCloser whose single Close call returns err.
	withClose := func(err error) func() io.Writer {
		return func() io.Writer {
			w := &mockWriteCloser{}
			w.On("Close").Return(err).Once()
			return w
		}
	}

	cases := []struct {
		name    string
		setup   func() io.Writer
		wantErr bool
	}{
		{name: "close writer with CloseWrite method", setup: withCloseWrite(nil), wantErr: false},
		{name: "close writer with CloseWrite error", setup: withCloseWrite(errors.New("close write error")), wantErr: true},
		{name: "close WriteCloser", setup: withClose(nil), wantErr: false},
		{name: "close WriteCloser with error", setup: withClose(errors.New("close error")), wantErr: true},
		{
			name:    "close plain writer (no close method)",
			setup:   func() io.Writer { return &bytes.Buffer{} },
			wantErr: false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			target := tc.setup()

			err := closeWriter(target)
			if tc.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}

			switch w := target.(type) {
			case *mockWriter:
				w.AssertExpectations(t)
			case *mockWriteCloser:
				w.AssertExpectations(t)
			}
		})
	}
}
// TestHandleConnectionWithErrors runs HandleConnection in the background,
// lets each case disturb the two peers, and requires HandleConnection to
// return within two seconds.
//
// simulateErr receives the subtest's *testing.T so that its assertions are
// recorded on the subtest that is actually running rather than on the parent
// test.
func TestHandleConnectionWithErrors(t *testing.T) {
	tests := []struct {
		name         string
		bufferSize   int
		setupChannel func() (*testChannel, *testChannelPeer)
		setupDst     func() (net.Conn, *pipeConn)
		simulateErr  func(t *testing.T, channel *testChannelPeer, dst *pipeConn)
	}{
		{
			name:       "handle read error from channel",
			bufferSize: 16,
			setupChannel: func() (*testChannel, *testChannelPeer) {
				return newChannelPair()
			},
			setupDst: func() (net.Conn, *pipeConn) {
				return newPipePair()
			},
			simulateErr: func(t *testing.T, channel *testChannelPeer, dst *pipeConn) {
				t.Helper()
				// Half-close both peers without sending any data.
				assert.NoError(t, channel.CloseWrite())
				assert.NoError(t, dst.CloseWrite())
			},
		},
		{
			name:       "handle write error to destination",
			bufferSize: 16,
			setupChannel: func() (*testChannel, *testChannelPeer) {
				return newChannelPair()
			},
			setupDst: func() (net.Conn, *pipeConn) {
				return newPipePair()
			},
			simulateErr: func(t *testing.T, channel *testChannelPeer, dst *pipeConn) {
				t.Helper()
				// Close the destination peer first, then push data from the
				// channel side after a short pause.
				assert.NoError(t, dst.Close())
				time.Sleep(10 * time.Millisecond)
				written, err := channel.Write([]byte("test"))
				assert.NotZero(t, written)
				assert.NoError(t, err)
				assert.NoError(t, channel.CloseWrite())
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &mockConfig{}
			cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
			fwd := New(cfg, slug.New(), nil).(*forwarder)

			channel, channelPeer := tt.setupChannel()
			dstEndpoint, dstPeer := tt.setupDst()

			done := make(chan struct{})
			go func() {
				fwd.HandleConnection(dstEndpoint, channel)
				close(done)
			}()

			tt.simulateErr(t, channelPeer, dstPeer)

			select {
			case <-done:
			case <-time.After(2 * time.Second):
				t.Fatal("HandleConnection did not complete")
			}

			cfg.AssertExpectations(t)
		})
	}
}
// TestHandleConnectionDiscardOnExit writes data into the channel peer,
// half-closes it, closes the destination peer, and then requires
// HandleConnection to return within ten seconds.
func TestHandleConnectionDiscardOnExit(t *testing.T) {
	cfg := new(mockConfig)
	cfg.On("BufferSize").Return(16).Maybe()

	fwd := New(cfg, slug.New(), nil).(*forwarder)
	sshSide, sshPeer := newChannelPair()
	dstSide, dstPeer := newPipePair()

	// finished is closed once HandleConnection has returned.
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		fwd.HandleConnection(dstSide, sshSide)
	}()

	_, writeErr := sshPeer.Write([]byte("test data"))
	require.NoError(t, writeErr)
	require.NoError(t, sshPeer.CloseWrite())
	require.NoError(t, dstPeer.Close())

	deadline := time.NewTimer(10 * time.Second)
	defer deadline.Stop()

	select {
	case <-finished:
	case <-deadline.C:
		t.Fatal("HandleConnection did not complete")
	}

	cfg.AssertExpectations(t)
}
// TestOpenForwardedChannelSuccess checks that OpenForwardedChannel returns a
// non-nil channel and request stream, and no error, when the underlying
// connection's OpenChannel call for "forwarded-tcpip" succeeds.
func TestOpenForwardedChannelSuccess(t *testing.T) {
	cases := []struct {
		name          string
		forwardedPort uint16
		originAddr    string
		originPort    int
	}{
		{name: "open channel standard port", forwardedPort: 8080, originAddr: "127.0.0.1", originPort: 9000},
		{name: "open channel high port", forwardedPort: 65534, originAddr: "192.168.1.100", originPort: 5000},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(8).Maybe()

			// The mocked connection hands back this channel together with
			// an unbuffered, receive-only request stream.
			opened := &testChannel{
				readBuf:  newSyncBuffer(),
				writeBuf: newSyncBuffer(),
			}
			var reqStream <-chan *ssh.Request = make(chan *ssh.Request)

			conn := new(mockConn)
			conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).
				Return(opened, reqStream, nil)

			fwd := New(cfg, slug.New(), conn).(*forwarder)
			fwd.SetForwardedPort(tc.forwardedPort)

			origin := &net.TCPAddr{
				IP:   net.ParseIP(tc.originAddr),
				Port: tc.originPort,
			}
			gotChannel, gotRequests, err := fwd.OpenForwardedChannel(context.Background(), origin)

			require.NoError(t, err)
			assert.NotNil(t, gotChannel)
			assert.NotNil(t, gotRequests)

			conn.AssertExpectations(t)
			cfg.AssertExpectations(t)
		})
	}
}
// TestOpenForwardedChannelError covers two OpenChannel outcomes on the mocked
// connection: one where it reports an error, which OpenForwardedChannel must
// surface, and one where it returns a nil channel without an error, which
// OpenForwardedChannel must not turn into an error.
func TestOpenForwardedChannelError(t *testing.T) {
	// connReturning builds a factory for a mock connection whose OpenChannel
	// yields a nil *testChannel, a nil request stream and the given error.
	connReturning := func(openErr error) func() *mockConn {
		return func() *mockConn {
			c := new(mockConn)
			c.On("OpenChannel", "forwarded-tcpip", mock.Anything).
				Return((*testChannel)(nil), (<-chan *ssh.Request)(nil), openErr)
			return c
		}
	}

	cases := []struct {
		name       string
		setupConn  func() *mockConn
		wantErr    bool
		wantErrMsg string
	}{
		{
			name:       "open channel returns error",
			setupConn:  connReturning(errors.New("channel open failed")),
			wantErr:    true,
			wantErrMsg: "channel open failed",
		},
		{
			name:      "open channel with nil channel",
			setupConn: connReturning(nil),
			wantErr:   false,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := new(mockConfig)
			cfg.On("BufferSize").Return(8).Maybe()

			conn := tc.setupConn()
			fwd := New(cfg, slug.New(), conn).(*forwarder)
			fwd.SetForwardedPort(8080)

			origin := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 9000}
			_, _, err := fwd.OpenForwardedChannel(context.Background(), origin)

			switch {
			case tc.wantErr:
				require.Error(t, err)
				assert.Contains(t, err.Error(), tc.wantErrMsg)
			default:
				assert.NoError(t, err)
			}

			conn.AssertExpectations(t)
			cfg.AssertExpectations(t)
		})
	}
}
// TestOpenForwardedChannelContextCancelledDuringOpen cancels the context while
// the mocked OpenChannel call is still blocked, then lets that call complete,
// and expects OpenForwardedChannel to return an error containing
// "context cancelled" within two seconds.
func TestOpenForwardedChannelContextCancelledDuringOpen(t *testing.T) {
	cfg := new(mockConfig)
	cfg.On("BufferSize").Return(8).Maybe()

	// lateChannel is what OpenChannel eventually returns, after the context
	// has already been cancelled. Close is allowed but not required.
	lateChannel := &testChannel{
		readBuf:  newSyncBuffer(),
		writeBuf: newSyncBuffer(),
	}
	lateChannel.On("Close").Return(nil).Maybe()

	var reqStream <-chan *ssh.Request = make(chan *ssh.Request)

	entered := make(chan struct{}) // closed when OpenChannel has been entered
	release := make(chan struct{}) // closed to let OpenChannel return

	conn := new(mockConn)
	conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).
		Run(func(mock.Arguments) {
			close(entered)
			<-release
		}).
		Return(lateChannel, reqStream, nil)

	fwd := New(cfg, slug.New(), conn).(*forwarder)
	fwd.SetForwardedPort(8080)

	ctx, cancel := context.WithCancel(context.Background())

	// Buffered so the goroutine can always deliver its result and exit.
	errCh := make(chan error, 1)
	go func() {
		origin := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 7000}
		_, _, openErr := fwd.OpenForwardedChannel(ctx, origin)
		errCh <- openErr
	}()

	// Order matters: cancel only once OpenChannel is in flight, and unblock
	// it only after the cancellation has been issued.
	<-entered
	cancel()
	close(release)

	select {
	case got := <-errCh:
		require.Error(t, got)
		assert.Contains(t, got.Error(), "context cancelled")
	case <-time.After(2 * time.Second):
		t.Fatal("OpenForwardedChannel did not return")
	}

	// NOTE(review): presumably this pause gives the forwarder time to deal
	// with the late-arriving channel (e.g. closing it) before expectations
	// are checked; confirm against OpenForwardedChannel.
	time.Sleep(50 * time.Millisecond)

	conn.AssertExpectations(t)
	cfg.AssertExpectations(t)
	lateChannel.AssertExpectations(t)
}
// TestCreateForwardedTCPIPPayloadEdgeCases builds payloads for a range of
// origin addresses and destination ports, decodes them with ssh.Unmarshal,
// and checks the decoded destination address and port.
//
// The origin fields are decoded as part of the payload layout but are not
// asserted by this test.
func TestCreateForwardedTCPIPPayloadEdgeCases(t *testing.T) {
	// forwardedPayload lists the fields this test decodes from the payload,
	// in wire order.
	type forwardedPayload struct {
		DestAddr   string
		DestPort   uint32
		OriginAddr string
		OriginPort uint32
	}

	cases := []struct {
		name         string
		originAddr   string
		destPort     uint16
		wantDestAddr string
		wantDestPort uint32
	}{
		{name: "IPv4 localhost", originAddr: "127.0.0.1:5000", destPort: 8080, wantDestAddr: "localhost", wantDestPort: 8080},
		{name: "IPv6 address", originAddr: "[::1]:3000", destPort: 9000, wantDestAddr: "localhost", wantDestPort: 9000},
		{name: "private network", originAddr: "192.168.1.1:12345", destPort: 443, wantDestAddr: "localhost", wantDestPort: 443},
		{name: "port 1", originAddr: "10.0.0.1:1", destPort: 1, wantDestAddr: "localhost", wantDestPort: 1},
		{name: "max port", originAddr: "172.16.0.1:65535", destPort: 65535, wantDestAddr: "localhost", wantDestPort: 65535},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			origin, resolveErr := net.ResolveTCPAddr("tcp", tc.originAddr)
			require.NoError(t, resolveErr)

			raw := createForwardedTCPIPPayload(origin, tc.destPort)

			var got forwardedPayload
			require.NoError(t, ssh.Unmarshal(raw, &got))

			assert.Equal(t, tc.wantDestAddr, got.DestAddr)
			assert.Equal(t, tc.wantDestPort, got.DestPort)
		})
	}
}