test(transport): add unit tests for transport behavior using Testify
This commit is contained in:
@@ -0,0 +1,755 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||
args := m.Called(user, oldKey, newKey)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||
args := m.Called(key, session)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
args := m.Called(user)
|
||||
return args.Get(0).([]registry.Session)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
|
||||
args := m.Called()
|
||||
return args.Get(0).(lifecycle.Lifecycle)
|
||||
}
|
||||
|
||||
func (m *MockSession) Interaction() interaction.Interaction {
|
||||
args := m.Called()
|
||||
return args.Get(0).(interaction.Interaction)
|
||||
}
|
||||
|
||||
func (m *MockSession) Forwarder() forwarder.Forwarder {
|
||||
args := m.Called()
|
||||
return args.Get(0).(forwarder.Forwarder)
|
||||
}
|
||||
|
||||
func (m *MockSession) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
func (m *MockSession) Detail() *types.Detail {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*types.Detail)
|
||||
}
|
||||
|
||||
type MockLifecycle struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) Connection() ssh.Conn {
|
||||
args := m.Called()
|
||||
return args.Get(0).(ssh.Conn)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) PortRegistry() port.Port {
|
||||
args := m.Called()
|
||||
return args.Get(0).(port.Port)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) User() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) SetChannel(channel ssh.Channel) {
|
||||
m.Called(channel)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) SetStatus(status types.SessionStatus) {
|
||||
m.Called(status)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) IsActive() bool {
|
||||
args := m.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) StartedAt() time.Time {
|
||||
args := m.Called()
|
||||
return args.Get(0).(time.Time)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockSSHConn struct {
|
||||
ssh.Conn
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
args := m.Called(name, data)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
|
||||
type MockSSHChannel struct {
|
||||
ssh.Channel
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSSHChannel) Write(data []byte) (int, error) {
|
||||
args := m.Called(data)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSSHChannel) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockForwarder struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||
args := m.Called(origin)
|
||||
return args.Get(0).([]byte)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||
m.Called(dst)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||
m.Called(dst, src)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) TunnelType() types.TunnelType {
|
||||
args := m.Called()
|
||||
return args.Get(0).(types.TunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) ForwardedPort() uint16 {
|
||||
args := m.Called()
|
||||
return uint16(args.Int(0))
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
|
||||
m.Called(tunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetForwardedPort(port uint16) {
|
||||
m.Called(port)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetListener(listener net.Listener) {
|
||||
m.Called(listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Listener() net.Listener {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
args := m.Called(payload)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
|
||||
type MockConn struct {
|
||||
net.Conn
|
||||
mock.Mock
|
||||
ReadBuffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
if m.ReadBuffer != nil {
|
||||
return m.ReadBuffer.Read(b)
|
||||
}
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
type wrappedConn struct {
|
||||
net.Conn
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (c *wrappedConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func TestNewHTTPHandler(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
hh := newHTTPHandler("domain", msr, true)
|
||||
assert.NotNil(t, hh)
|
||||
assert.Equal(t, "domain", hh.domain)
|
||||
assert.Equal(t, msr, hh.sessionRegistry)
|
||||
assert.True(t, hh.redirectTLS)
|
||||
}
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isTLS bool
|
||||
redirectTLS bool
|
||||
request []byte
|
||||
expected []byte
|
||||
setupMocks func(*MockSessionRegistry)
|
||||
setupConn func() (net.Conn, net.Conn)
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "bad request - invalid host",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - missing host",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "isTLS true and redirectTLS true - no redirect",
|
||||
isTLS: true,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redirect to TLS",
|
||||
isTLS: false,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://example.domain/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle ping request",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "session not found",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return((registry.Session)(nil), fmt.Errorf("session not found"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - invalid http",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("INVALID\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - open channel fails",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
||||
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.Writer)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
|
||||
})
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - send initial request fails",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
go func() {
|
||||
for range reqCh {
|
||||
}
|
||||
}()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - success",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.ReadWriter)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
|
||||
})
|
||||
|
||||
go func() {
|
||||
for range reqCh {
|
||||
}
|
||||
}()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redirect - write failure",
|
||||
isTLS: false,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - write failure",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle ping request - write failure",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "close connection - error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
||||
mc.On("Write", mock.Anything).Return(182, nil)
|
||||
mc.On("Close").Return(fmt.Errorf("close error"))
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - stream close error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return()
|
||||
},
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
||||
mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
|
||||
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
mc.On("RemoteAddr").Return(addr)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - middleware failure",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
|
||||
return k.Id == "test"
|
||||
})).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
},
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
||||
mc.On("Close").Return(nil).Times(2)
|
||||
mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - channel close error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.ReadWriter)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
|
||||
})
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - open channel timeout",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
oldTimeout := openChannelTimeout
|
||||
openChannelTimeout = 10 * time.Millisecond
|
||||
t.Cleanup(func() {
|
||||
openChannelTimeout = oldTimeout
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.Writer)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
|
||||
})
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Run(func(args mock.Arguments) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Close").Return(fmt.Errorf("cleanup close error"))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSessionRegistry := new(MockSessionRegistry)
|
||||
hh := &httpHandler{
|
||||
domain: "domain",
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
redirectTLS: tt.redirectTLS,
|
||||
}
|
||||
|
||||
if tt.setupMocks != nil {
|
||||
tt.setupMocks(mockSessionRegistry)
|
||||
}
|
||||
|
||||
var serverConn, clientConn net.Conn
|
||||
if tt.setupConn != nil {
|
||||
serverConn, clientConn = tt.setupConn()
|
||||
} else {
|
||||
serverConn, clientConn = net.Pipe()
|
||||
}
|
||||
|
||||
if clientConn != nil {
|
||||
defer clientConn.Close()
|
||||
}
|
||||
|
||||
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
var wrappedServerConn net.Conn
|
||||
if _, ok := serverConn.(*MockConn); ok {
|
||||
wrappedServerConn = serverConn
|
||||
} else {
|
||||
wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr}
|
||||
}
|
||||
|
||||
responseChan := make(chan []byte, 1)
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
if clientConn != nil {
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
var res []byte
|
||||
for {
|
||||
buf := make([]byte, 4096)
|
||||
n, err := clientConn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
t.Logf("Error reading response: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
res = append(res, buf[:n]...)
|
||||
if len(tt.expected) > 0 && len(res) >= len(tt.expected) {
|
||||
break
|
||||
}
|
||||
}
|
||||
responseChan <- res
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := clientConn.Write(tt.request)
|
||||
if err != nil {
|
||||
t.Logf("Error writing request: %v", err)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
close(responseChan)
|
||||
close(doneChan)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
hh.Handler(wrappedServerConn, tt.isTLS)
|
||||
}()
|
||||
|
||||
select {
|
||||
case response := <-responseChan:
|
||||
if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" {
|
||||
resStr := string(response)
|
||||
assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n"))
|
||||
assert.Contains(t, resStr, "Content-Length: 5\r\n")
|
||||
assert.Contains(t, resStr, "Server: Tunnel Please\r\n")
|
||||
assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello"))
|
||||
} else {
|
||||
assert.Equal(t, string(tt.expected), string(response))
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
if clientConn != nil {
|
||||
t.Fatal("Test timeout - no response received")
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if clientConn != nil {
|
||||
<-doneChan
|
||||
}
|
||||
|
||||
mockSessionRegistry.AssertExpectations(t)
|
||||
if mc, ok := serverConn.(*MockConn); ok {
|
||||
mc.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user