1803 lines
41 KiB
Go
1803 lines
41 KiB
Go
package forwarder
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
"tunnel_pls/session/slug"
|
|
"tunnel_pls/types"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type mockConfig struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockConfig) Domain() string { return m.Called().String(0) }
|
|
func (m *mockConfig) SSHPort() string { return m.Called().String(0) }
|
|
func (m *mockConfig) HTTPPort() string { return m.Called().String(0) }
|
|
func (m *mockConfig) HTTPSPort() string { return m.Called().String(0) }
|
|
func (m *mockConfig) KeyLoc() string { return m.Called().String(0) }
|
|
func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
|
func (m *mockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
|
|
func (m *mockConfig) TLSStoragePath() string { return m.Called().String(0) }
|
|
func (m *mockConfig) ACMEEmail() string { return m.Called().String(0) }
|
|
func (m *mockConfig) CFAPIToken() string { return m.Called().String(0) }
|
|
func (m *mockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
|
func (m *mockConfig) AllowedPortsStart() uint16 { return m.Called().Get(0).(uint16) }
|
|
func (m *mockConfig) AllowedPortsEnd() uint16 { return m.Called().Get(0).(uint16) }
|
|
func (m *mockConfig) BufferSize() int { return m.Called().Int(0) }
|
|
func (m *mockConfig) HeaderSize() int { return m.Called().Int(0) }
|
|
func (m *mockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
|
func (m *mockConfig) PprofPort() string { return m.Called().String(0) }
|
|
func (m *mockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
|
|
func (m *mockConfig) GRPCAddress() string { return m.Called().String(0) }
|
|
func (m *mockConfig) GRPCPort() string { return m.Called().String(0) }
|
|
func (m *mockConfig) NodeToken() string { return m.Called().String(0) }
|
|
|
|
type mockConn struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (c *mockConn) Close() error { return c.Called().Error(0) }
|
|
func (c *mockConn) User() string { return c.Called().String(0) }
|
|
func (c *mockConn) SessionID() []byte { return c.Called().Get(0).([]byte) }
|
|
func (c *mockConn) ClientVersion() []byte { return c.Called().Get(0).([]byte) }
|
|
func (c *mockConn) ServerVersion() []byte { return c.Called().Get(0).([]byte) }
|
|
func (c *mockConn) RemoteAddr() net.Addr { return c.Called().Get(0).(net.Addr) }
|
|
func (c *mockConn) LocalAddr() net.Addr { return c.Called().Get(0).(net.Addr) }
|
|
func (c *mockConn) SendRequest(s string, b bool, d []byte) (bool, []byte, error) {
|
|
args := c.Called(s, b, d)
|
|
return args.Bool(0), args.Get(1).([]byte), args.Error(2)
|
|
}
|
|
func (c *mockConn) Wait() error { return c.Called().Error(0) }
|
|
|
|
func (c *mockConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
|
args := c.Called(name, data)
|
|
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
|
}
|
|
|
|
type testChannel struct {
|
|
mock.Mock
|
|
readBuf *syncBuffer
|
|
writeBuf *syncBuffer
|
|
closedWrite atomic.Bool
|
|
}
|
|
|
|
func (c *testChannel) Read(b []byte) (int, error) {
|
|
return c.readBuf.Read(b)
|
|
}
|
|
|
|
func (c *testChannel) Write(b []byte) (int, error) {
|
|
return c.writeBuf.Write(b)
|
|
}
|
|
|
|
func (c *testChannel) Close() error {
|
|
return c.Called().Error(0)
|
|
}
|
|
|
|
func (c *testChannel) CloseWrite() error {
|
|
c.closedWrite.Store(true)
|
|
return c.writeBuf.Close()
|
|
}
|
|
|
|
func (c *testChannel) Stderr() io.ReadWriter {
|
|
return c.Called().Get(0).(io.ReadWriter)
|
|
}
|
|
|
|
func (c *testChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
|
|
args := c.Called(name, wantReply, payload)
|
|
return args.Bool(0), args.Error(1)
|
|
}
|
|
|
|
func (c *testChannel) AckRequest(ok bool, payload []byte) error {
|
|
return c.Called(ok, payload).Error(0)
|
|
}
|
|
|
|
// syncBuffer is a goroutine-safe, unbounded byte queue with blocking reads.
// Write appends data and wakes waiting readers. Read blocks while the queue is
// empty and still open. After Close, queued bytes can still be drained, then
// Read returns io.EOF; Write fails with io.ErrClosedPipe.
type syncBuffer struct {
	mu     sync.Mutex
	buf    []byte
	closed bool
	cond   *sync.Cond
}

// newSyncBuffer returns an empty, open syncBuffer whose condition variable is
// bound to its own mutex.
func newSyncBuffer() *syncBuffer {
	b := new(syncBuffer)
	b.cond = sync.NewCond(&b.mu)
	return b
}

// Write appends p to the queue without waiting for readers.
func (q *syncBuffer) Write(p []byte) (int, error) {
	q.mu.Lock()
	defer q.mu.Unlock()

	if !q.closed {
		q.buf = append(q.buf, p...)
		q.cond.Broadcast()
		return len(p), nil
	}
	return 0, io.ErrClosedPipe
}

// Read copies up to len(p) queued bytes into p, waiting while nothing is
// queued and the buffer has not been closed.
func (q *syncBuffer) Read(p []byte) (int, error) {
	q.mu.Lock()
	defer q.mu.Unlock()

	for len(q.buf) == 0 && !q.closed {
		q.cond.Wait()
	}
	if len(q.buf) == 0 {
		// Closed and fully drained.
		return 0, io.EOF
	}

	n := copy(p, q.buf)
	q.buf = q.buf[n:]
	return n, nil
}

// Close marks the queue closed and wakes every waiting reader. It always
// returns nil.
func (q *syncBuffer) Close() error {
	q.mu.Lock()
	q.closed = true
	q.cond.Broadcast()
	q.mu.Unlock()
	return nil
}
|
|
|
|
func newChannelPair() (*testChannel, *testChannelPeer) {
|
|
peerToChBuf := newSyncBuffer()
|
|
chToPeerBuf := newSyncBuffer()
|
|
|
|
channel := &testChannel{
|
|
readBuf: peerToChBuf,
|
|
writeBuf: chToPeerBuf,
|
|
}
|
|
|
|
peer := &testChannelPeer{
|
|
readBuf: chToPeerBuf,
|
|
writeBuf: peerToChBuf,
|
|
}
|
|
|
|
channel.On("Close").Return(nil).Maybe()
|
|
|
|
return channel, peer
|
|
}
|
|
|
|
type testChannelPeer struct {
|
|
readBuf *syncBuffer
|
|
writeBuf *syncBuffer
|
|
}
|
|
|
|
func (p *testChannelPeer) Read(b []byte) (int, error) {
|
|
return p.readBuf.Read(b)
|
|
}
|
|
|
|
func (p *testChannelPeer) Write(b []byte) (int, error) {
|
|
return p.writeBuf.Write(b)
|
|
}
|
|
|
|
func (p *testChannelPeer) CloseWrite() error {
|
|
return p.writeBuf.Close()
|
|
}
|
|
|
|
func newPipePair() (*pipeConn, *pipeConn) {
|
|
r1, w1 := io.Pipe()
|
|
r2, w2 := io.Pipe()
|
|
|
|
conn1 := &pipeConn{
|
|
reader: r1,
|
|
writer: w2,
|
|
}
|
|
|
|
conn2 := &pipeConn{
|
|
reader: r2,
|
|
writer: w1,
|
|
}
|
|
|
|
return conn1, conn2
|
|
}
|
|
|
|
// pipeConn is an in-memory connection assembled from two io.Pipe halves:
// reads come from reader and writes go to writer. Both addresses report
// 127.0.0.1:0, and the deadline setters accept and ignore their argument.
type pipeConn struct {
	reader *io.PipeReader
	writer *io.PipeWriter
}

// Read receives bytes from the peer's write half.
func (p *pipeConn) Read(b []byte) (int, error) {
	return p.reader.Read(b)
}

// Write sends bytes to the peer; it blocks until the peer reads them.
func (p *pipeConn) Write(b []byte) (int, error) {
	return p.writer.Write(b)
}

// Close closes both halves of the connection. The write half is closed even
// if closing the read half reports an error, so the peer's pending reads are
// always released; the first error encountered is returned.
func (p *pipeConn) Close() error {
	readErr := p.reader.Close()
	writeErr := p.writer.Close()
	if readErr != nil {
		return readErr
	}
	return writeErr
}

// CloseWrite closes only the write half; the peer's reads then return io.EOF.
func (p *pipeConn) CloseWrite() error {
	return p.writer.Close()
}

// LocalAddr returns a fixed loopback address.
func (p *pipeConn) LocalAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}

// RemoteAddr returns a fixed loopback address.
func (p *pipeConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}

// Deadlines are not supported; these are no-ops that always succeed.
func (p *pipeConn) SetDeadline(t time.Time) error      { return nil }
func (p *pipeConn) SetReadDeadline(t time.Time) error  { return nil }
func (p *pipeConn) SetWriteDeadline(t time.Time) error { return nil }
|
|
|
|
func TestNew(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
bufferSize int
|
|
wantBufLen int
|
|
}{
|
|
{
|
|
name: "default buffer size",
|
|
bufferSize: 16,
|
|
wantBufLen: 16,
|
|
},
|
|
{
|
|
name: "custom buffer size",
|
|
bufferSize: 32,
|
|
wantBufLen: 32,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
|
|
s := slug.New()
|
|
conn := &mockConn{}
|
|
|
|
forwarder := New(cfg, s, conn).(*forwarder)
|
|
|
|
buf := forwarder.bufferPool.Get().([]byte)
|
|
require.Len(t, buf, tt.wantBufLen)
|
|
forwarder.bufferPool.Put(buf)
|
|
|
|
assert.Equal(t, types.TunnelTypeUNKNOWN, forwarder.TunnelType())
|
|
assert.Equal(t, uint16(0), forwarder.ForwardedPort())
|
|
assert.Equal(t, conn, forwarder.conn)
|
|
assert.Equal(t, s, forwarder.slug)
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleConnection(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
bufferSize int
|
|
messageToDst []byte
|
|
messageToSrc []byte
|
|
}{
|
|
{
|
|
name: "small messages",
|
|
bufferSize: 4,
|
|
messageToDst: []byte("hi"),
|
|
messageToSrc: []byte("yo"),
|
|
},
|
|
{
|
|
name: "medium messages",
|
|
bufferSize: 8,
|
|
messageToDst: []byte("hello"),
|
|
messageToSrc: []byte("world"),
|
|
},
|
|
{
|
|
name: "larger messages",
|
|
bufferSize: 16,
|
|
messageToDst: []byte("I love femboy"),
|
|
messageToSrc: []byte("mee too"),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
channel, channelPeer := newChannelPair()
|
|
dstEndpoint, dstPeer := newPipePair()
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
forwarder.HandleConnection(dstEndpoint, channel)
|
|
close(done)
|
|
}()
|
|
|
|
readDst := make(chan struct {
|
|
data []byte
|
|
err error
|
|
}, 1)
|
|
go func() {
|
|
buf := make([]byte, len(tt.messageToDst))
|
|
n, err := io.ReadFull(dstPeer, buf)
|
|
readDst <- struct {
|
|
data []byte
|
|
err error
|
|
}{data: buf[:n], err: err}
|
|
}()
|
|
|
|
_, err := channelPeer.Write(tt.messageToDst)
|
|
require.NoError(t, err)
|
|
|
|
dstResult := <-readDst
|
|
require.NoError(t, dstResult.err)
|
|
assert.Equal(t, tt.messageToDst, dstResult.data)
|
|
|
|
readSrc := make(chan struct {
|
|
data []byte
|
|
err error
|
|
}, 1)
|
|
go func() {
|
|
buf := make([]byte, len(tt.messageToSrc))
|
|
n, err := io.ReadFull(channelPeer, buf)
|
|
readSrc <- struct {
|
|
data []byte
|
|
err error
|
|
}{data: buf[:n], err: err}
|
|
}()
|
|
|
|
_, err = dstPeer.Write(tt.messageToSrc)
|
|
require.NoError(t, err)
|
|
|
|
srcResult := <-readSrc
|
|
require.NoError(t, srcResult.err)
|
|
assert.Equal(t, tt.messageToSrc, srcResult.data)
|
|
|
|
require.NoError(t, channelPeer.CloseWrite())
|
|
require.NoError(t, dstPeer.CloseWrite())
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("HandleConnection did not complete")
|
|
}
|
|
assert.True(t, channel.closedWrite.Load())
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleConnection_Error(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
bufferSize int
|
|
messageToDst []byte
|
|
messageToSrc []byte
|
|
}{
|
|
{
|
|
name: "small messages",
|
|
bufferSize: 4,
|
|
messageToDst: []byte("hi"),
|
|
messageToSrc: []byte("yo"),
|
|
},
|
|
{
|
|
name: "medium messages",
|
|
bufferSize: 8,
|
|
messageToDst: []byte("hello"),
|
|
messageToSrc: []byte("world"),
|
|
},
|
|
{
|
|
name: "larger messages",
|
|
bufferSize: 16,
|
|
messageToDst: []byte("I love femboy"),
|
|
messageToSrc: []byte("mee too"),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
channel, _ := newChannelPair()
|
|
dstEndpoint, _ := newPipePair()
|
|
|
|
go func() {
|
|
forwarder.HandleConnection(dstEndpoint, channel)
|
|
}()
|
|
|
|
err := dstEndpoint.Close()
|
|
assert.NoError(t, err)
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOpenForwardedChannel(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
forwardedPort uint16
|
|
originIP string
|
|
originPort int
|
|
wantDestAddr string
|
|
wantDestPort uint32
|
|
wantOrigAddr string
|
|
wantOrigPort uint32
|
|
}{
|
|
{
|
|
name: "localhost origin",
|
|
forwardedPort: 2222,
|
|
originIP: "127.0.0.1",
|
|
originPort: 9000,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 2222,
|
|
wantOrigAddr: "127.0.0.1",
|
|
wantOrigPort: 9000,
|
|
},
|
|
{
|
|
name: "remote origin",
|
|
forwardedPort: 8080,
|
|
originIP: "192.168.1.100",
|
|
originPort: 5000,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 8080,
|
|
wantOrigAddr: "192.168.1.100",
|
|
wantOrigPort: 5000,
|
|
},
|
|
{
|
|
name: "different port",
|
|
forwardedPort: 3000,
|
|
originIP: "10.0.0.1",
|
|
originPort: 7777,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 3000,
|
|
wantOrigAddr: "10.0.0.1",
|
|
wantOrigPort: 7777,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(8).Maybe()
|
|
channel := &testChannel{
|
|
readBuf: newSyncBuffer(),
|
|
writeBuf: newSyncBuffer(),
|
|
}
|
|
requests := make(chan *ssh.Request)
|
|
|
|
var capturedData []byte
|
|
conn := &mockConn{}
|
|
conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Run(func(args mock.Arguments) {
|
|
data := args.Get(1).([]byte)
|
|
capturedData = make([]byte, len(data))
|
|
copy(capturedData, data)
|
|
}).Return(channel, (<-chan *ssh.Request)(requests), nil)
|
|
|
|
forwarder := New(cfg, slug.New(), conn).(*forwarder)
|
|
forwarder.SetForwardedPort(tt.forwardedPort)
|
|
|
|
origin := &net.TCPAddr{IP: net.ParseIP(tt.originIP), Port: tt.originPort}
|
|
ch, reqs, err := forwarder.OpenForwardedChannel(context.Background(), origin)
|
|
require.NoError(t, err)
|
|
assert.Same(t, channel, ch)
|
|
assert.NotNil(t, reqs)
|
|
|
|
var payload struct {
|
|
DestAddr string
|
|
DestPort uint32
|
|
OriginAddr string
|
|
OriginPort uint32
|
|
}
|
|
err = ssh.Unmarshal(capturedData, &payload)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.wantDestAddr, payload.DestAddr)
|
|
assert.Equal(t, tt.wantDestPort, payload.DestPort)
|
|
assert.Equal(t, tt.wantOrigAddr, payload.OriginAddr)
|
|
assert.Equal(t, tt.wantOrigPort, payload.OriginPort)
|
|
|
|
conn.AssertExpectations(t)
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOpenForwardedChannelContextCancellation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cancelBefore bool
|
|
cancelDuring bool
|
|
wantErr bool
|
|
wantErrType error
|
|
}{
|
|
{
|
|
name: "cancel during open",
|
|
cancelBefore: false,
|
|
cancelDuring: true,
|
|
wantErr: true,
|
|
wantErrType: context.Canceled,
|
|
},
|
|
{
|
|
name: "cancel before open",
|
|
cancelBefore: true,
|
|
cancelDuring: false,
|
|
wantErr: true,
|
|
wantErrType: context.Canceled,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(8).Maybe()
|
|
channel := &testChannel{
|
|
readBuf: newSyncBuffer(),
|
|
writeBuf: newSyncBuffer(),
|
|
}
|
|
channel.On("Close").Return(nil)
|
|
requests := make(chan *ssh.Request)
|
|
|
|
openChannelCalled := make(chan struct{})
|
|
openChannelBlock := make(chan struct{})
|
|
|
|
conn := &mockConn{}
|
|
conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Run(func(args mock.Arguments) {
|
|
close(openChannelCalled)
|
|
<-openChannelBlock
|
|
}).Return(channel, (<-chan *ssh.Request)(requests), nil).Maybe()
|
|
|
|
forwarder := New(cfg, slug.New(), conn).(*forwarder)
|
|
forwarder.SetForwardedPort(8080)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
if tt.cancelBefore {
|
|
cancel()
|
|
}
|
|
|
|
var (
|
|
openedChannel ssh.Channel
|
|
openedReqs <-chan *ssh.Request
|
|
openErr error
|
|
)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
origin := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 7000}
|
|
openedChannel, openedReqs, openErr = forwarder.OpenForwardedChannel(ctx, origin)
|
|
close(done)
|
|
}()
|
|
|
|
if tt.cancelDuring {
|
|
<-openChannelCalled
|
|
cancel()
|
|
}
|
|
close(openChannelBlock)
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("OpenForwardedChannel did not return after cancellation")
|
|
}
|
|
|
|
if tt.wantErr {
|
|
require.Error(t, openErr)
|
|
assert.True(t, errors.Is(openErr, tt.wantErrType))
|
|
assert.Nil(t, openedChannel)
|
|
assert.Nil(t, openedReqs)
|
|
} else {
|
|
require.NoError(t, openErr)
|
|
assert.NotNil(t, openedChannel)
|
|
assert.NotNil(t, openedReqs)
|
|
}
|
|
|
|
conn.AssertExpectations(t)
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCreateForwardedTCPIPPayload(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
originIP string
|
|
originPort int
|
|
forwardedPort uint16
|
|
wantDestAddr string
|
|
wantDestPort uint32
|
|
wantOriginAddr string
|
|
wantOriginPort uint32
|
|
}{
|
|
{
|
|
name: "standard case",
|
|
originIP: "192.0.2.10",
|
|
originPort: 5050,
|
|
forwardedPort: 8080,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 8080,
|
|
wantOriginAddr: "192.0.2.10",
|
|
wantOriginPort: 5050,
|
|
},
|
|
{
|
|
name: "localhost origin",
|
|
originIP: "127.0.0.1",
|
|
originPort: 3000,
|
|
forwardedPort: 9000,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 9000,
|
|
wantOriginAddr: "127.0.0.1",
|
|
wantOriginPort: 3000,
|
|
},
|
|
{
|
|
name: "high port numbers",
|
|
originIP: "10.0.0.1",
|
|
originPort: 65535,
|
|
forwardedPort: 65534,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 65534,
|
|
wantOriginAddr: "10.0.0.1",
|
|
wantOriginPort: 65535,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
origin := &net.TCPAddr{IP: net.ParseIP(tt.originIP), Port: tt.originPort}
|
|
payload := createForwardedTCPIPPayload(origin, tt.forwardedPort)
|
|
|
|
var decoded struct {
|
|
DestAddr string
|
|
DestPort uint32
|
|
OriginAddr string
|
|
OriginPort uint32
|
|
}
|
|
|
|
err := ssh.Unmarshal(payload, &decoded)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.wantDestAddr, decoded.DestAddr)
|
|
assert.Equal(t, tt.wantDestPort, decoded.DestPort)
|
|
assert.Equal(t, tt.wantOriginAddr, decoded.OriginAddr)
|
|
assert.Equal(t, tt.wantOriginPort, decoded.OriginPort)
|
|
})
|
|
}
|
|
}
|
|
|
|
type mockReader struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockReader) Read(p []byte) (int, error) {
|
|
args := m.Called(p)
|
|
return args.Int(0), args.Error(1)
|
|
}
|
|
|
|
type mockWriter struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockWriter) Write(p []byte) (int, error) {
|
|
args := m.Called(p)
|
|
return args.Int(0), args.Error(1)
|
|
}
|
|
|
|
func (m *mockWriter) CloseWrite() error {
|
|
return m.Called().Error(0)
|
|
}
|
|
|
|
type mockWriteCloser struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockWriteCloser) Write(p []byte) (int, error) {
|
|
args := m.Called(p)
|
|
return args.Int(0), args.Error(1)
|
|
}
|
|
|
|
func (m *mockWriteCloser) Close() error {
|
|
return m.Called().Error(0)
|
|
}
|
|
|
|
func TestCopyAndClose(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupSrc func() io.Reader
|
|
setupDst func() io.Writer
|
|
direction string
|
|
wantErr bool
|
|
wantErrMsg string
|
|
checkErrTypes []error
|
|
}{
|
|
{
|
|
name: "successful copy with EOF",
|
|
setupSrc: func() io.Reader {
|
|
r := &mockReader{}
|
|
r.On("Read", mock.Anything).Return(5, nil).Once()
|
|
r.On("Read", mock.Anything).Return(0, io.EOF).Once()
|
|
return r
|
|
},
|
|
setupDst: func() io.Writer {
|
|
w := &mockWriter{}
|
|
w.On("Write", mock.Anything).Return(5, nil).Once()
|
|
w.On("CloseWrite").Return(nil).Once()
|
|
return w
|
|
},
|
|
direction: "src->dst",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "copy error - not EOF or ErrClosed",
|
|
setupSrc: func() io.Reader {
|
|
r := &mockReader{}
|
|
customErr := errors.New("custom read error")
|
|
r.On("Read", mock.Anything).Return(0, customErr).Once()
|
|
return r
|
|
},
|
|
setupDst: func() io.Writer {
|
|
w := &mockWriter{}
|
|
w.On("CloseWrite").Return(nil).Once()
|
|
return w
|
|
},
|
|
direction: "src->dst",
|
|
wantErr: true,
|
|
wantErrMsg: "copy error (src->dst)",
|
|
},
|
|
{
|
|
name: "copy error - ErrClosed should be ignored",
|
|
setupSrc: func() io.Reader {
|
|
r := &mockReader{}
|
|
r.On("Read", mock.Anything).Return(0, net.ErrClosed).Once()
|
|
return r
|
|
},
|
|
setupDst: func() io.Writer {
|
|
w := &mockWriter{}
|
|
w.On("CloseWrite").Return(nil).Once()
|
|
return w
|
|
},
|
|
direction: "src->dst",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "close writer error - not EOF",
|
|
setupSrc: func() io.Reader {
|
|
r := &mockReader{}
|
|
r.On("Read", mock.Anything).Return(0, io.EOF).Once()
|
|
return r
|
|
},
|
|
setupDst: func() io.Writer {
|
|
w := &mockWriter{}
|
|
closeErr := errors.New("close error")
|
|
w.On("CloseWrite").Return(closeErr).Once()
|
|
return w
|
|
},
|
|
direction: "src->dst",
|
|
wantErr: true,
|
|
wantErrMsg: "close stream error (src->dst)",
|
|
},
|
|
{
|
|
name: "close writer error - EOF should be ignored",
|
|
setupSrc: func() io.Reader {
|
|
r := &mockReader{}
|
|
r.On("Read", mock.Anything).Return(0, io.EOF).Once()
|
|
return r
|
|
},
|
|
setupDst: func() io.Writer {
|
|
w := &mockWriter{}
|
|
w.On("CloseWrite").Return(io.EOF).Once()
|
|
return w
|
|
},
|
|
direction: "src->dst",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "both copy and close errors",
|
|
setupSrc: func() io.Reader {
|
|
r := &mockReader{}
|
|
copyErr := errors.New("copy error")
|
|
r.On("Read", mock.Anything).Return(0, copyErr).Once()
|
|
return r
|
|
},
|
|
setupDst: func() io.Writer {
|
|
w := &mockWriter{}
|
|
closeErr := errors.New("close error")
|
|
w.On("CloseWrite").Return(closeErr).Once()
|
|
return w
|
|
},
|
|
direction: "src->dst",
|
|
wantErr: true,
|
|
wantErrMsg: "copy error (src->dst)",
|
|
},
|
|
{
|
|
name: "successful copy with WriteCloser",
|
|
setupSrc: func() io.Reader {
|
|
r := &mockReader{}
|
|
r.On("Read", mock.Anything).Return(5, nil).Once()
|
|
r.On("Read", mock.Anything).Return(0, io.EOF).Once()
|
|
return r
|
|
},
|
|
setupDst: func() io.Writer {
|
|
w := &mockWriteCloser{}
|
|
w.On("Write", mock.Anything).Return(5, nil).Once()
|
|
w.On("Close").Return(nil).Once()
|
|
return w
|
|
},
|
|
direction: "dst->src",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "WriteCloser close error",
|
|
setupSrc: func() io.Reader {
|
|
r := &mockReader{}
|
|
r.On("Read", mock.Anything).Return(0, io.EOF).Once()
|
|
return r
|
|
},
|
|
setupDst: func() io.Writer {
|
|
w := &mockWriteCloser{}
|
|
closeErr := errors.New("writeCloser close error")
|
|
w.On("Close").Return(closeErr).Once()
|
|
return w
|
|
},
|
|
direction: "dst->src",
|
|
wantErr: true,
|
|
wantErrMsg: "close stream error (dst->src)",
|
|
},
|
|
{
|
|
name: "copy with multiple reads before EOF",
|
|
setupSrc: func() io.Reader {
|
|
r := &mockReader{}
|
|
r.On("Read", mock.Anything).Return(10, nil).Once()
|
|
r.On("Read", mock.Anything).Return(15, nil).Once()
|
|
r.On("Read", mock.Anything).Return(5, nil).Once()
|
|
r.On("Read", mock.Anything).Return(0, io.EOF).Once()
|
|
return r
|
|
},
|
|
setupDst: func() io.Writer {
|
|
w := &mockWriter{}
|
|
w.On("Write", mock.Anything).Return(10, nil).Once()
|
|
w.On("Write", mock.Anything).Return(15, nil).Once()
|
|
w.On("Write", mock.Anything).Return(5, nil).Once()
|
|
w.On("CloseWrite").Return(nil).Once()
|
|
return w
|
|
},
|
|
direction: "src->dst",
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(32).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
src := tt.setupSrc()
|
|
dst := tt.setupDst()
|
|
|
|
err := forwarder.copyAndClose(dst, src, tt.direction)
|
|
|
|
if tt.wantErr {
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), tt.wantErrMsg)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
if mr, ok := src.(*mockReader); ok {
|
|
mr.AssertExpectations(t)
|
|
}
|
|
if mw, ok := dst.(*mockWriter); ok {
|
|
mw.AssertExpectations(t)
|
|
}
|
|
if mwc, ok := dst.(*mockWriteCloser); ok {
|
|
mwc.AssertExpectations(t)
|
|
}
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCopyAndCloseJoinedErrors(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(32).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
src := &mockReader{}
|
|
copyErr := errors.New("copy failed")
|
|
src.On("Read", mock.Anything).Return(0, copyErr).Once()
|
|
|
|
dst := &mockWriter{}
|
|
closeErr := errors.New("close failed")
|
|
dst.On("CloseWrite").Return(closeErr).Once()
|
|
|
|
err := forwarder.copyAndClose(dst, src, "test")
|
|
|
|
require.Error(t, err)
|
|
|
|
assert.Contains(t, err.Error(), "copy error (test)")
|
|
assert.Contains(t, err.Error(), "close stream error (test)")
|
|
assert.Contains(t, err.Error(), "copy failed")
|
|
assert.Contains(t, err.Error(), "close failed")
|
|
|
|
src.AssertExpectations(t)
|
|
dst.AssertExpectations(t)
|
|
cfg.AssertExpectations(t)
|
|
}
|
|
|
|
func TestCopyWithBuffer(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
bufferSize int
|
|
setupSrc func() io.Reader
|
|
setupDst func() io.Writer
|
|
wantBytesCount int64
|
|
wantErr bool
|
|
wantErrType error
|
|
}{
|
|
{
|
|
name: "successful copy small data",
|
|
bufferSize: 16,
|
|
setupSrc: func() io.Reader {
|
|
return io.NopCloser(bytes.NewReader([]byte("hello world")))
|
|
},
|
|
setupDst: func() io.Writer {
|
|
return &bytes.Buffer{}
|
|
},
|
|
wantBytesCount: 11,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "successful copy large data",
|
|
bufferSize: 8,
|
|
setupSrc: func() io.Reader {
|
|
data := make([]byte, 1024)
|
|
for i := range data {
|
|
data[i] = byte(i % 256)
|
|
}
|
|
return io.NopCloser(bytes.NewReader(data))
|
|
},
|
|
setupDst: func() io.Writer {
|
|
return &bytes.Buffer{}
|
|
},
|
|
wantBytesCount: 1024,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "empty data",
|
|
bufferSize: 16,
|
|
setupSrc: func() io.Reader {
|
|
return io.NopCloser(bytes.NewReader([]byte{}))
|
|
},
|
|
setupDst: func() io.Writer {
|
|
return &bytes.Buffer{}
|
|
},
|
|
wantBytesCount: 0,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "read error",
|
|
bufferSize: 16,
|
|
setupSrc: func() io.Reader {
|
|
r := &mockReader{}
|
|
r.On("Read", mock.Anything).Return(0, errors.New("read error")).Once()
|
|
return r
|
|
},
|
|
setupDst: func() io.Writer {
|
|
return &bytes.Buffer{}
|
|
},
|
|
wantBytesCount: 0,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "write error",
|
|
bufferSize: 16,
|
|
setupSrc: func() io.Reader {
|
|
return io.NopCloser(bytes.NewReader([]byte("test data")))
|
|
},
|
|
setupDst: func() io.Writer {
|
|
w := &mockWriter{}
|
|
w.On("Write", mock.Anything).Return(0, errors.New("write error")).Once()
|
|
return w
|
|
},
|
|
wantBytesCount: 0,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "partial write continues",
|
|
bufferSize: 16,
|
|
setupSrc: func() io.Reader {
|
|
return io.NopCloser(bytes.NewReader([]byte("testing")))
|
|
},
|
|
setupDst: func() io.Writer {
|
|
buf := &bytes.Buffer{}
|
|
return buf
|
|
},
|
|
wantBytesCount: 7,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "multiple buffer fills",
|
|
bufferSize: 4,
|
|
setupSrc: func() io.Reader {
|
|
return io.NopCloser(bytes.NewReader([]byte("this is a longer message")))
|
|
},
|
|
setupDst: func() io.Writer {
|
|
return &bytes.Buffer{}
|
|
},
|
|
wantBytesCount: 24,
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
src := tt.setupSrc()
|
|
dst := tt.setupDst()
|
|
|
|
n, err := forwarder.copyWithBuffer(dst, src)
|
|
|
|
if tt.wantErr {
|
|
require.Error(t, err)
|
|
} else {
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tt.wantBytesCount, n)
|
|
}
|
|
|
|
if buf, ok := dst.(*bytes.Buffer); ok && !tt.wantErr {
|
|
assert.Equal(t, tt.wantBytesCount, int64(buf.Len()))
|
|
}
|
|
|
|
if mr, ok := src.(*mockReader); ok {
|
|
mr.AssertExpectations(t)
|
|
}
|
|
if mw, ok := dst.(*mockWriter); ok {
|
|
mw.AssertExpectations(t)
|
|
}
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCopyWithBufferReusesBuffer(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(16).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
buf1 := forwarder.bufferPool.Get().([]byte)
|
|
initialPtr := &buf1[0]
|
|
forwarder.bufferPool.Put(buf1)
|
|
|
|
src := io.NopCloser(bytes.NewReader([]byte("test")))
|
|
dst := &bytes.Buffer{}
|
|
_, err := forwarder.copyWithBuffer(dst, src)
|
|
require.NoError(t, err)
|
|
|
|
buf2 := forwarder.bufferPool.Get().([]byte)
|
|
secondPtr := &buf2[0]
|
|
forwarder.bufferPool.Put(buf2)
|
|
|
|
assert.Equal(t, len(buf1), len(buf2))
|
|
|
|
assert.Len(t, buf2, 16)
|
|
|
|
_ = initialPtr
|
|
_ = secondPtr
|
|
cfg.AssertExpectations(t)
|
|
}
|
|
|
|
func TestSetType(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tunnelType types.TunnelType
|
|
}{
|
|
{
|
|
name: "set to HTTP",
|
|
tunnelType: types.TunnelTypeHTTP,
|
|
},
|
|
{
|
|
name: "set to TCP",
|
|
tunnelType: types.TunnelTypeTCP,
|
|
},
|
|
{
|
|
name: "set to UNKNOWN",
|
|
tunnelType: types.TunnelTypeUNKNOWN,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(16).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
assert.Equal(t, types.TunnelTypeUNKNOWN, forwarder.TunnelType())
|
|
|
|
forwarder.SetType(tt.tunnelType)
|
|
|
|
assert.Equal(t, tt.tunnelType, forwarder.TunnelType())
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTunnelType(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tunnelType types.TunnelType
|
|
}{
|
|
{
|
|
name: "get HTTP type",
|
|
tunnelType: types.TunnelTypeHTTP,
|
|
},
|
|
{
|
|
name: "get TCP type",
|
|
tunnelType: types.TunnelTypeTCP,
|
|
},
|
|
{
|
|
name: "get UNKNOWN type",
|
|
tunnelType: types.TunnelTypeUNKNOWN,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(16).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
forwarder.SetType(tt.tunnelType)
|
|
result := forwarder.TunnelType()
|
|
|
|
assert.Equal(t, tt.tunnelType, result)
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSetForwardedPort(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
port uint16
|
|
}{
|
|
{
|
|
name: "set standard port",
|
|
port: 8080,
|
|
},
|
|
{
|
|
name: "set low port",
|
|
port: 80,
|
|
},
|
|
{
|
|
name: "set high port",
|
|
port: 65535,
|
|
},
|
|
{
|
|
name: "set zero port",
|
|
port: 0,
|
|
},
|
|
{
|
|
name: "set custom port",
|
|
port: 3000,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(16).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
assert.Equal(t, uint16(0), forwarder.ForwardedPort())
|
|
|
|
forwarder.SetForwardedPort(tt.port)
|
|
|
|
assert.Equal(t, tt.port, forwarder.ForwardedPort())
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestForwardedPort(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
port uint16
|
|
}{
|
|
{
|
|
name: "get default port",
|
|
port: 0,
|
|
},
|
|
{
|
|
name: "get standard port",
|
|
port: 8080,
|
|
},
|
|
{
|
|
name: "get high port",
|
|
port: 65535,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(16).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
if tt.port != 0 {
|
|
forwarder.SetForwardedPort(tt.port)
|
|
}
|
|
|
|
result := forwarder.ForwardedPort()
|
|
assert.Equal(t, tt.port, result)
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSetListener(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupListener func() net.Listener
|
|
}{
|
|
{
|
|
name: "set TCP listener",
|
|
setupListener: func() net.Listener {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
return listener
|
|
},
|
|
},
|
|
{
|
|
name: "set nil listener",
|
|
setupListener: func() net.Listener {
|
|
return nil
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(16).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
listener := tt.setupListener()
|
|
if listener != nil {
|
|
defer func(listener net.Listener) {
|
|
err := listener.Close()
|
|
assert.NoError(t, err)
|
|
}(listener)
|
|
}
|
|
|
|
assert.Nil(t, forwarder.Listener())
|
|
|
|
forwarder.SetListener(listener)
|
|
|
|
assert.Equal(t, listener, forwarder.Listener())
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestListener(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupListener func() net.Listener
|
|
}{
|
|
{
|
|
name: "get nil listener",
|
|
setupListener: func() net.Listener {
|
|
return nil
|
|
},
|
|
},
|
|
{
|
|
name: "get TCP listener",
|
|
setupListener: func() net.Listener {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
return listener
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(16).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
listener := tt.setupListener()
|
|
if listener != nil {
|
|
defer func(listener net.Listener) {
|
|
err := listener.Close()
|
|
assert.NoError(t, err)
|
|
}(listener)
|
|
forwarder.SetListener(listener)
|
|
}
|
|
|
|
result := forwarder.Listener()
|
|
assert.Equal(t, listener, result)
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestClose(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupListener func() net.Listener
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "close with nil listener",
|
|
setupListener: func() net.Listener {
|
|
return nil
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "close with active listener",
|
|
setupListener: func() net.Listener {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
return listener
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "close already closed listener",
|
|
setupListener: func() net.Listener {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
err = listener.Close()
|
|
assert.NoError(t, err)
|
|
return listener
|
|
},
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(16).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
listener := tt.setupListener()
|
|
if listener != nil {
|
|
forwarder.SetListener(listener)
|
|
}
|
|
|
|
err := forwarder.Close()
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCloseWriter(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setup func() io.Writer
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "close writer with CloseWrite method",
|
|
setup: func() io.Writer {
|
|
w := &mockWriter{}
|
|
w.On("CloseWrite").Return(nil).Once()
|
|
return w
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "close writer with CloseWrite error",
|
|
setup: func() io.Writer {
|
|
w := &mockWriter{}
|
|
w.On("CloseWrite").Return(errors.New("close write error")).Once()
|
|
return w
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "close WriteCloser",
|
|
setup: func() io.Writer {
|
|
w := &mockWriteCloser{}
|
|
w.On("Close").Return(nil).Once()
|
|
return w
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "close WriteCloser with error",
|
|
setup: func() io.Writer {
|
|
w := &mockWriteCloser{}
|
|
w.On("Close").Return(errors.New("close error")).Once()
|
|
return w
|
|
},
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "close plain writer (no close method)",
|
|
setup: func() io.Writer {
|
|
return &bytes.Buffer{}
|
|
},
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
writer := tt.setup()
|
|
|
|
err := closeWriter(writer)
|
|
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
if mw, ok := writer.(*mockWriter); ok {
|
|
mw.AssertExpectations(t)
|
|
}
|
|
if mwc, ok := writer.(*mockWriteCloser); ok {
|
|
mwc.AssertExpectations(t)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleConnectionWithErrors(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
bufferSize int
|
|
setupChannel func() (*testChannel, *testChannelPeer)
|
|
setupDst func() (net.Conn, *pipeConn)
|
|
simulateErr func(channel *testChannelPeer, dst *pipeConn)
|
|
}{
|
|
{
|
|
name: "handle read error from channel",
|
|
bufferSize: 16,
|
|
setupChannel: func() (*testChannel, *testChannelPeer) {
|
|
return newChannelPair()
|
|
},
|
|
setupDst: func() (net.Conn, *pipeConn) {
|
|
return newPipePair()
|
|
},
|
|
simulateErr: func(channel *testChannelPeer, dst *pipeConn) {
|
|
err := channel.CloseWrite()
|
|
assert.NoError(t, err)
|
|
err = dst.CloseWrite()
|
|
assert.NoError(t, err)
|
|
},
|
|
},
|
|
{
|
|
name: "handle write error to destination",
|
|
bufferSize: 16,
|
|
setupChannel: func() (*testChannel, *testChannelPeer) {
|
|
return newChannelPair()
|
|
},
|
|
setupDst: func() (net.Conn, *pipeConn) {
|
|
return newPipePair()
|
|
},
|
|
simulateErr: func(channel *testChannelPeer, dst *pipeConn) {
|
|
err := dst.Close()
|
|
assert.NoError(t, err)
|
|
time.Sleep(10 * time.Millisecond)
|
|
write, err := channel.Write([]byte("test"))
|
|
assert.NotZero(t, write)
|
|
assert.NoError(t, err)
|
|
err = channel.CloseWrite()
|
|
assert.NoError(t, err)
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(tt.bufferSize).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
channel, channelPeer := tt.setupChannel()
|
|
dstEndpoint, dstPeer := tt.setupDst()
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
forwarder.HandleConnection(dstEndpoint, channel)
|
|
close(done)
|
|
}()
|
|
|
|
tt.simulateErr(channelPeer, dstPeer)
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("HandleConnection did not complete")
|
|
}
|
|
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleConnectionDiscardOnExit(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(16).Maybe()
|
|
forwarder := New(cfg, slug.New(), nil).(*forwarder)
|
|
|
|
channel, channelPeer := newChannelPair()
|
|
dstEndpoint, dstPeer := newPipePair()
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
forwarder.HandleConnection(dstEndpoint, channel)
|
|
close(done)
|
|
}()
|
|
|
|
_, err := channelPeer.Write([]byte("test data"))
|
|
require.NoError(t, err)
|
|
require.NoError(t, channelPeer.CloseWrite())
|
|
require.NoError(t, dstPeer.Close())
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(10 * time.Second):
|
|
t.Fatal("HandleConnection did not complete")
|
|
}
|
|
|
|
cfg.AssertExpectations(t)
|
|
}
|
|
|
|
func TestOpenForwardedChannelSuccess(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
forwardedPort uint16
|
|
originAddr string
|
|
originPort int
|
|
}{
|
|
{
|
|
name: "open channel standard port",
|
|
forwardedPort: 8080,
|
|
originAddr: "127.0.0.1",
|
|
originPort: 9000,
|
|
},
|
|
{
|
|
name: "open channel high port",
|
|
forwardedPort: 65534,
|
|
originAddr: "192.168.1.100",
|
|
originPort: 5000,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(8).Maybe()
|
|
channel := &testChannel{
|
|
readBuf: newSyncBuffer(),
|
|
writeBuf: newSyncBuffer(),
|
|
}
|
|
requests := make(chan *ssh.Request)
|
|
|
|
conn := &mockConn{}
|
|
conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).
|
|
Return(channel, (<-chan *ssh.Request)(requests), nil)
|
|
|
|
forwarder := New(cfg, slug.New(), conn).(*forwarder)
|
|
forwarder.SetForwardedPort(tt.forwardedPort)
|
|
|
|
origin := &net.TCPAddr{IP: net.ParseIP(tt.originAddr), Port: tt.originPort}
|
|
ch, reqs, err := forwarder.OpenForwardedChannel(context.Background(), origin)
|
|
|
|
require.NoError(t, err)
|
|
assert.NotNil(t, ch)
|
|
assert.NotNil(t, reqs)
|
|
|
|
conn.AssertExpectations(t)
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOpenForwardedChannelError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupConn func() *mockConn
|
|
wantErr bool
|
|
wantErrMsg string
|
|
}{
|
|
{
|
|
name: "open channel returns error",
|
|
setupConn: func() *mockConn {
|
|
conn := &mockConn{}
|
|
conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).
|
|
Return((*testChannel)(nil), (<-chan *ssh.Request)(nil), errors.New("channel open failed"))
|
|
return conn
|
|
},
|
|
wantErr: true,
|
|
wantErrMsg: "channel open failed",
|
|
},
|
|
{
|
|
name: "open channel with nil channel",
|
|
setupConn: func() *mockConn {
|
|
conn := &mockConn{}
|
|
conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).
|
|
Return((*testChannel)(nil), (<-chan *ssh.Request)(nil), nil)
|
|
return conn
|
|
},
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(8).Maybe()
|
|
|
|
conn := tt.setupConn()
|
|
forwarder := New(cfg, slug.New(), conn).(*forwarder)
|
|
forwarder.SetForwardedPort(8080)
|
|
|
|
origin := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 9000}
|
|
_, _, err := forwarder.OpenForwardedChannel(context.Background(), origin)
|
|
|
|
if tt.wantErr {
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), tt.wantErrMsg)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
conn.AssertExpectations(t)
|
|
cfg.AssertExpectations(t)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOpenForwardedChannelContextCancelledDuringOpen(t *testing.T) {
|
|
cfg := &mockConfig{}
|
|
cfg.On("BufferSize").Return(8).Maybe()
|
|
|
|
channel := &testChannel{
|
|
readBuf: newSyncBuffer(),
|
|
writeBuf: newSyncBuffer(),
|
|
}
|
|
channel.On("Close").Return(nil).Maybe()
|
|
|
|
requests := make(chan *ssh.Request)
|
|
|
|
openChannelStarted := make(chan struct{})
|
|
openChannelBlock := make(chan struct{})
|
|
|
|
conn := &mockConn{}
|
|
conn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Run(func(args mock.Arguments) {
|
|
close(openChannelStarted)
|
|
<-openChannelBlock
|
|
}).Return(channel, (<-chan *ssh.Request)(requests), nil)
|
|
|
|
forwarder := New(cfg, slug.New(), conn).(*forwarder)
|
|
forwarder.SetForwardedPort(8080)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
resultChan := make(chan error, 1)
|
|
go func() {
|
|
origin := &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 7000}
|
|
_, _, err := forwarder.OpenForwardedChannel(ctx, origin)
|
|
resultChan <- err
|
|
}()
|
|
|
|
<-openChannelStarted
|
|
|
|
cancel()
|
|
|
|
close(openChannelBlock)
|
|
|
|
select {
|
|
case err := <-resultChan:
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "context cancelled")
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("OpenForwardedChannel did not return")
|
|
}
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
conn.AssertExpectations(t)
|
|
cfg.AssertExpectations(t)
|
|
channel.AssertExpectations(t)
|
|
}
|
|
|
|
func TestCreateForwardedTCPIPPayloadEdgeCases(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
originAddr string
|
|
destPort uint16
|
|
wantDestAddr string
|
|
wantDestPort uint32
|
|
}{
|
|
{
|
|
name: "IPv4 localhost",
|
|
originAddr: "127.0.0.1:5000",
|
|
destPort: 8080,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 8080,
|
|
},
|
|
{
|
|
name: "IPv6 address",
|
|
originAddr: "[::1]:3000",
|
|
destPort: 9000,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 9000,
|
|
},
|
|
{
|
|
name: "private network",
|
|
originAddr: "192.168.1.1:12345",
|
|
destPort: 443,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 443,
|
|
},
|
|
{
|
|
name: "port 1",
|
|
originAddr: "10.0.0.1:1",
|
|
destPort: 1,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 1,
|
|
},
|
|
{
|
|
name: "max port",
|
|
originAddr: "172.16.0.1:65535",
|
|
destPort: 65535,
|
|
wantDestAddr: "localhost",
|
|
wantDestPort: 65535,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
addr, err := net.ResolveTCPAddr("tcp", tt.originAddr)
|
|
require.NoError(t, err)
|
|
|
|
payload := createForwardedTCPIPPayload(addr, tt.destPort)
|
|
|
|
var decoded struct {
|
|
DestAddr string
|
|
DestPort uint32
|
|
OriginAddr string
|
|
OriginPort uint32
|
|
}
|
|
|
|
err = ssh.Unmarshal(payload, &decoded)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, tt.wantDestAddr, decoded.DestAddr)
|
|
assert.Equal(t, tt.wantDestPort, decoded.DestPort)
|
|
})
|
|
}
|
|
}
|