1361 lines
35 KiB
Go
1361 lines
35 KiB
Go
package session
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/binary"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
"tunnel_pls/internal/config"
|
|
"tunnel_pls/internal/registry"
|
|
"tunnel_pls/session/lifecycle"
|
|
"tunnel_pls/types"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type mockRandom struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockRandom) String(length int) (string, error) {
|
|
args := m.Called(length)
|
|
return args.String(0), args.Error(1)
|
|
}
|
|
|
|
type mockConfig struct {
|
|
mock.Mock
|
|
config.Config
|
|
}
|
|
|
|
func (m *mockConfig) Domain() string { return m.Called().String(0) }
|
|
func (m *mockConfig) SSHPort() string { return m.Called().String(0) }
|
|
func (m *mockConfig) Mode() types.ServerMode {
|
|
args := m.Called()
|
|
if args.Get(0) == nil {
|
|
return 0
|
|
}
|
|
switch v := args.Get(0).(type) {
|
|
case types.ServerMode:
|
|
return v
|
|
case int:
|
|
return types.ServerMode(v)
|
|
default:
|
|
return types.ServerMode(args.Int(0))
|
|
}
|
|
}
|
|
func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
|
|
|
type mockRegistry struct {
|
|
mock.Mock
|
|
registry.Registry
|
|
removedKey types.SessionKey
|
|
}
|
|
|
|
func (m *mockRegistry) Register(key types.SessionKey, session registry.Session) bool {
|
|
return m.Called(key, session).Bool(0)
|
|
}
|
|
|
|
func (m *mockRegistry) Remove(key types.SessionKey) {
|
|
m.removedKey = key
|
|
}
|
|
|
|
type mockPort struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockPort) AddRange(startPort, endPort uint16) error {
|
|
return m.Called(startPort, endPort).Error(0)
|
|
}
|
|
func (m *mockPort) Unassigned() (uint16, bool) {
|
|
args := m.Called()
|
|
var port uint16
|
|
if args.Get(0) != nil {
|
|
switch v := args.Get(0).(type) {
|
|
case int:
|
|
port = uint16(v)
|
|
case uint16:
|
|
port = v
|
|
case uint32:
|
|
port = uint16(v)
|
|
case int32:
|
|
port = uint16(v)
|
|
case float64:
|
|
port = uint16(v)
|
|
default:
|
|
port = uint16(args.Int(0))
|
|
}
|
|
}
|
|
return port, args.Bool(1)
|
|
}
|
|
func (m *mockPort) SetStatus(port uint16, assigned bool) error {
|
|
return m.Called(port, assigned).Error(0)
|
|
}
|
|
func (m *mockPort) Claim(port uint16) bool {
|
|
return m.Called(port).Bool(0)
|
|
}
|
|
|
|
type mockSSHConn struct {
|
|
ssh.Conn
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockSSHConn) Wait() error {
|
|
return m.Called().Error(0)
|
|
}
|
|
|
|
func (m *mockSSHConn) Close() error {
|
|
return m.Called().Error(0)
|
|
}
|
|
|
|
func (m *mockSSHConn) User() string {
|
|
return m.Called().String(0)
|
|
}
|
|
|
|
func setupSSH(t *testing.T) (sConn *ssh.ServerConn, sReqs <-chan *ssh.Request, sChans <-chan ssh.NewChannel, cConn ssh.Conn, cleanup func()) {
|
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
|
|
privDER := x509.MarshalPKCS1PrivateKey(key)
|
|
privBlock := pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: privDER,
|
|
}
|
|
pk, err := ssh.ParsePrivateKey(pem.EncodeToMemory(&privBlock))
|
|
require.NoError(t, err)
|
|
|
|
sCfg := &ssh.ServerConfig{
|
|
NoClientAuth: true,
|
|
}
|
|
sCfg.AddHostKey(pk)
|
|
|
|
cCfg := &ssh.ClientConfig{
|
|
User: "test",
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
var sConnObj *ssh.ServerConn
|
|
var sChansChan <-chan ssh.NewChannel
|
|
var sReqsChan <-chan *ssh.Request
|
|
|
|
errChan := make(chan error, 1)
|
|
go func() {
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
sConnObj, sChansChan, sReqsChan, err = ssh.NewServerConn(conn, sCfg)
|
|
errChan <- err
|
|
}()
|
|
|
|
conn, err := net.Dial("tcp", l.Addr().String())
|
|
require.NoError(t, err)
|
|
cConnObj, cChans, cReqs, err := ssh.NewClientConn(conn, "pipe", cCfg)
|
|
require.NoError(t, err)
|
|
|
|
go ssh.DiscardRequests(cReqs)
|
|
go func() {
|
|
for newChan := range cChans {
|
|
if newChan.ChannelType() == "session" {
|
|
continue
|
|
}
|
|
err = newChan.Reject(ssh.Prohibited, "")
|
|
assert.NoError(t, err)
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case err := <-errChan:
|
|
require.NoError(t, err)
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("SSH handshake timed out")
|
|
}
|
|
|
|
return sConnObj, sReqsChan, sChansChan, cConnObj, func() {
|
|
_ = cConnObj.Close()
|
|
_ = sConnObj.Close()
|
|
_ = l.Close()
|
|
}
|
|
}
|
|
|
|
func TestNew(t *testing.T) {
|
|
conf := &Config{
|
|
Randomizer: &mockRandom{},
|
|
Config: &mockConfig{},
|
|
Conn: &ssh.ServerConn{},
|
|
InitialReq: make(chan *ssh.Request),
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: &mockRegistry{},
|
|
PortRegistry: &mockPort{},
|
|
User: "testuser",
|
|
}
|
|
|
|
s := New(conf)
|
|
assert.NotNil(t, s)
|
|
assert.NotNil(t, s.Lifecycle())
|
|
assert.NotNil(t, s.Interaction())
|
|
assert.NotNil(t, s.Forwarder())
|
|
assert.NotNil(t, s.Slug())
|
|
}
|
|
|
|
func TestDetail(t *testing.T) {
|
|
conf := &Config{
|
|
Randomizer: &mockRandom{},
|
|
Config: &mockConfig{},
|
|
Conn: &ssh.ServerConn{},
|
|
InitialReq: make(chan *ssh.Request),
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: &mockRegistry{},
|
|
PortRegistry: &mockPort{},
|
|
User: "testuser",
|
|
}
|
|
|
|
s := New(conf).(*session)
|
|
s.forwarder.SetType(types.TunnelTypeHTTP)
|
|
s.slug.Set("test-slug")
|
|
s.lifecycle.SetStatus(types.SessionStatusRUNNING)
|
|
|
|
detail := s.Detail()
|
|
assert.Equal(t, "HTTP", detail.ForwardingType)
|
|
assert.Equal(t, "test-slug", detail.Slug)
|
|
assert.Equal(t, "testuser", detail.UserID)
|
|
assert.True(t, detail.Active)
|
|
|
|
s.forwarder.SetType(types.TunnelTypeTCP)
|
|
detail = s.Detail()
|
|
assert.Equal(t, "TCP", detail.ForwardingType)
|
|
|
|
s.forwarder.SetType(types.TunnelTypeUNKNOWN)
|
|
detail = s.Detail()
|
|
assert.Equal(t, "UNKNOWN", detail.ForwardingType)
|
|
}
|
|
|
|
func TestIsBlockedPort(t *testing.T) {
|
|
tests := []struct {
|
|
port uint16
|
|
expected bool
|
|
}{
|
|
{80, false},
|
|
{443, false},
|
|
{22, true},
|
|
{1023, true},
|
|
{1024, false},
|
|
{1080, true},
|
|
{3306, true},
|
|
{8080, true},
|
|
{0, false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(fmt.Sprintf("Port %d", tt.port), func(t *testing.T) {
|
|
assert.Equal(t, tt.expected, isBlockedPort(tt.port))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHandleGlobalRequest(t *testing.T) {
|
|
_, sReqs, _, cConn, cleanup := setupSSH(t)
|
|
defer cleanup()
|
|
|
|
conf := &Config{
|
|
Randomizer: &mockRandom{},
|
|
Config: &mockConfig{},
|
|
Conn: &ssh.ServerConn{},
|
|
InitialReq: make(chan *ssh.Request),
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: &mockRegistry{},
|
|
PortRegistry: &mockPort{},
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
_ = s.HandleGlobalRequest(sReqs)
|
|
close(done)
|
|
}()
|
|
|
|
tests := []struct {
|
|
name string
|
|
reqType string
|
|
payload []byte
|
|
wantReply bool
|
|
expected bool
|
|
}{
|
|
{"shell", "shell", nil, true, true},
|
|
{"pty-req", "pty-req", nil, true, true},
|
|
{"window-change valid", "window-change", make([]byte, 16), true, true},
|
|
{"window-change invalid", "window-change", make([]byte, 4), true, false},
|
|
{"unknown", "unknown", nil, true, false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ok, _, err := cConn.SendRequest(tt.reqType, tt.wantReply, tt.payload)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expected, ok)
|
|
})
|
|
}
|
|
|
|
err := cConn.Close()
|
|
assert.NoError(t, err)
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("HandleGlobalRequest timed out after cConn.Close()")
|
|
}
|
|
}
|
|
|
|
func TestHandleTCPIPForward_Table(t *testing.T) {
|
|
setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) {
|
|
sConn, sReqs, _, cConn, cleanup := setupSSH(t)
|
|
mRegistry := &mockRegistry{}
|
|
mPort := &mockPort{}
|
|
mRandom := &mockRandom{}
|
|
conf := &Config{
|
|
Randomizer: mRandom,
|
|
Config: &mockConfig{},
|
|
Conn: sConn,
|
|
InitialReq: make(chan *ssh.Request),
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: mRegistry,
|
|
PortRegistry: mPort,
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
return s, mRegistry, mPort, mRandom, sConn, sReqs, cConn, cleanup
|
|
}
|
|
|
|
t.Run("HTTP Forward Success", func(t *testing.T) {
|
|
s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mRandom.On("String", 20).Return("test-slug-1234567890", nil)
|
|
mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 80)
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
var req *ssh.Request
|
|
select {
|
|
case req = <-sReqs:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for tcpip-forward request")
|
|
}
|
|
|
|
err := s.HandleTCPIPForward(req)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "test-slug-1234567890", s.slug.String())
|
|
})
|
|
|
|
t.Run("TCP Forward Success", func(t *testing.T) {
|
|
s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mPort.On("Claim", mock.Anything).Return(true)
|
|
mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 0)
|
|
|
|
mPort.On("Unassigned").Return(uint16(12345), true)
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
var req *ssh.Request
|
|
select {
|
|
case req = <-sReqs:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for tcpip-forward request")
|
|
}
|
|
|
|
err := s.HandleTCPIPForward(req)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, uint16(12345), s.forwarder.ForwardedPort())
|
|
})
|
|
|
|
t.Run("Invalid Payload", func(t *testing.T) {
|
|
s, _, _, _, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
payload := []byte{0, 0, 0}
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
var req *ssh.Request
|
|
select {
|
|
case req = <-sReqs:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for tcpip-forward request")
|
|
}
|
|
|
|
err := s.HandleTCPIPForward(req)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("Blocked Port", func(t *testing.T) {
|
|
s, _, _, _, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 22)
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
var req *ssh.Request
|
|
select {
|
|
case req = <-sReqs:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for tcpip-forward request")
|
|
}
|
|
|
|
err := s.HandleTCPIPForward(req)
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestStart_Table(t *testing.T) {
|
|
setup := func(t *testing.T) (*session, *Config, ssh.Conn, func()) {
|
|
sConn, sReqs, sChans, cConn, cleanup := setupSSH(t)
|
|
mRegistry := &mockRegistry{}
|
|
mPort := &mockPort{}
|
|
mRandom := &mockRandom{}
|
|
mConfig := &mockConfig{}
|
|
mConfig.On("Mode").Return(types.ServerModeSTANDALONE)
|
|
mConfig.On("Domain").Return("example.com")
|
|
mConfig.On("SSHPort").Return("2222")
|
|
|
|
conf := &Config{
|
|
Randomizer: mRandom,
|
|
Config: mConfig,
|
|
Conn: sConn,
|
|
InitialReq: sReqs,
|
|
SshChan: sChans,
|
|
SessionRegistry: mRegistry,
|
|
PortRegistry: mPort,
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
return s, conf, cConn, cleanup
|
|
}
|
|
|
|
t.Run("Full Success TCP", func(t *testing.T) {
|
|
s, conf, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 0)
|
|
|
|
conf.PortRegistry.(*mockPort).On("Claim", mock.Anything).Return(true)
|
|
conf.PortRegistry.(*mockPort).On("Unassigned").Return(uint16(0), true)
|
|
conf.PortRegistry.(*mockPort).On("SetStatus", mock.AnythingOfType("uint16"), mock.Anything).Return(nil)
|
|
conf.SessionRegistry.(*mockRegistry).On("Register", mock.Anything, mock.Anything).Return(true)
|
|
conf.Config.(*mockConfig).On("TLSEnabled").Return(false)
|
|
go func() {
|
|
time.Sleep(200 * time.Millisecond)
|
|
ch, reqs, err := cConn.OpenChannel("session", nil)
|
|
if err == nil {
|
|
go ssh.DiscardRequests(reqs)
|
|
time.Sleep(200 * time.Millisecond)
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
time.Sleep(200 * time.Millisecond)
|
|
write, err := ch.Write([]byte("q"))
|
|
assert.NoError(t, err)
|
|
assert.NotZero(t, write)
|
|
time.Sleep(100 * time.Millisecond)
|
|
_ = ch.Close()
|
|
}
|
|
_ = cConn.Close()
|
|
}()
|
|
|
|
err := s.Start()
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("Headless mode success", func(t *testing.T) {
|
|
s, conf, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 80)
|
|
|
|
conf.Randomizer.(*mockRandom).On("String", 20).Return("headless-slug", nil)
|
|
conf.SessionRegistry.(*mockRegistry).On("Register", mock.Anything, mock.Anything).Return(true)
|
|
|
|
go func() {
|
|
time.Sleep(600 * time.Millisecond)
|
|
_, _, err := cConn.SendRequest("tcpip-forward", true, payload)
|
|
assert.NoError(t, err)
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
err = cConn.Close()
|
|
assert.NoError(t, err)
|
|
|
|
}()
|
|
|
|
err := s.Start()
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("Missing Forward Request", func(t *testing.T) {
|
|
s, _, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
|
|
go func() {
|
|
time.Sleep(1200 * time.Millisecond)
|
|
_ = cConn.Close()
|
|
}()
|
|
|
|
err := s.Start()
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "no forwarding Request")
|
|
})
|
|
|
|
t.Run("Unauthorized Headless", func(t *testing.T) {
|
|
_, conf, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
|
|
conf.User = "UNAUTHORIZED"
|
|
s := New(conf).(*session)
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 80)
|
|
|
|
go func() {
|
|
time.Sleep(600 * time.Millisecond)
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
err := s.Start()
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestForwardingFailures(t *testing.T) {
|
|
setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) {
|
|
sConn, sReqs, _, cConn, cleanup := setupSSH(t)
|
|
mRegistry := &mockRegistry{}
|
|
mPort := &mockPort{}
|
|
mRandom := &mockRandom{}
|
|
conf := &Config{
|
|
Randomizer: mRandom,
|
|
Config: &mockConfig{},
|
|
Conn: sConn,
|
|
InitialReq: make(chan *ssh.Request),
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: mRegistry,
|
|
PortRegistry: mPort,
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
return s, mRegistry, mPort, mRandom, sConn, sReqs, cConn, cleanup
|
|
}
|
|
|
|
t.Run("HTTP Registration Failed", func(t *testing.T) {
|
|
s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mRandom.On("String", 20).Return("test-slug", nil)
|
|
mRegistry.On("Register", mock.Anything, mock.Anything).Return(false)
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 80)
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
var req *ssh.Request
|
|
select {
|
|
case req = <-sReqs:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for tcpip-forward request")
|
|
}
|
|
|
|
err := s.HandleTCPIPForward(req)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("TCP Port Claim Failed", func(t *testing.T) {
|
|
s, _, mPort, _, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mPort.On("Claim", mock.Anything).Return(false)
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 1234)
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
var req *ssh.Request
|
|
select {
|
|
case req = <-sReqs:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timed out waiting for tcpip-forward request")
|
|
}
|
|
|
|
err := s.HandleTCPIPForward(req)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("HTTP Randomizer Error", func(t *testing.T) {
|
|
s, _, _, mRandom, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mRandom.On("String", 20).Return("", fmt.Errorf("random error"))
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 80)
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
req := <-sReqs
|
|
err := s.HandleTCPIPForward(req)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "random error")
|
|
})
|
|
|
|
t.Run("Port Registry No Port", func(t *testing.T) {
|
|
s, _, mPort, _, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mPort.On("Unassigned").Return(uint16(0), false)
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 0)
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
req := <-sReqs
|
|
err := s.HandleTCPIPForward(req)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "no available port")
|
|
})
|
|
|
|
t.Run("Port too large", func(t *testing.T) {
|
|
s, _, _, _, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 70000)
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
req := <-sReqs
|
|
err := s.HandleTCPIPForward(req)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "port is larger than allowed")
|
|
})
|
|
|
|
t.Run("TCP Registration Failed", func(t *testing.T) {
|
|
s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mPort.On("Claim", mock.Anything).Return(true)
|
|
mRegistry.On("Register", mock.Anything, mock.Anything).Return(false)
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 1234)
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
req := <-sReqs
|
|
err := s.HandleTCPIPForward(req)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "Failed to register TunnelTypeTCP client")
|
|
})
|
|
|
|
t.Run("Finalize Forwarding Failure", func(t *testing.T) {
|
|
s, mRegistry, _, mRandom, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mRandom.On("String", 20).Return("test-slug", nil)
|
|
mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], 80)
|
|
|
|
go func() {
|
|
_, _, err := cConn.SendRequest("tcpip-forward", true, payload)
|
|
assert.Error(t, err, io.EOF)
|
|
}()
|
|
|
|
req := <-sReqs
|
|
err := cConn.Close()
|
|
assert.NoError(t, err)
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
err = s.HandleTCPIPForward(req)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("TCP Listen Failure", func(t *testing.T) {
|
|
s, mRegistry, mPort, _, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mPort.On("Claim", mock.Anything).Return(true)
|
|
mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
|
|
l, err := net.Listen("tcp", "0.0.0.0:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer func(l net.Listener) {
|
|
err = l.Close()
|
|
assert.NoError(t, err)
|
|
}(l)
|
|
_, portStr, _ := net.SplitHostPort(l.Addr().String())
|
|
port, _ := strconv.Atoi(portStr)
|
|
|
|
payload := make([]byte, 4+9+4)
|
|
binary.BigEndian.PutUint32(payload[0:4], 9)
|
|
copy(payload[4:13], "localhost")
|
|
binary.BigEndian.PutUint32(payload[13:17], uint32(port))
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, payload)
|
|
}()
|
|
|
|
req := <-sReqs
|
|
err = s.HandleTCPIPForward(req)
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "is already in use or restricted")
|
|
})
|
|
}
|
|
|
|
func TestSetupInteractiveMode_Error(t *testing.T) {
|
|
sConn, _, sChans, _, cleanup := setupSSH(t)
|
|
defer cleanup()
|
|
|
|
conf := &Config{
|
|
Randomizer: &mockRandom{},
|
|
Config: &mockConfig{},
|
|
Conn: sConn,
|
|
InitialReq: make(chan *ssh.Request),
|
|
SshChan: sChans,
|
|
SessionRegistry: &mockRegistry{},
|
|
PortRegistry: &mockPort{},
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
|
|
mockChan := &mockNewChanFail{}
|
|
err := s.setupInteractiveMode(mockChan)
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
}
|
|
}
|
|
|
|
type mockNewChanFail struct {
|
|
ssh.NewChannel
|
|
}
|
|
|
|
func (m *mockNewChanFail) Accept() (ssh.Channel, <-chan *ssh.Request, error) {
|
|
return nil, nil, fmt.Errorf("accept failed")
|
|
}
|
|
|
|
func TestWaitForTCPIPForward_EdgeCases(t *testing.T) {
|
|
t.Run("Wrong Request Type", func(t *testing.T) {
|
|
_, sReqs, _, cConn, cleanup := setupSSH(t)
|
|
defer cleanup()
|
|
|
|
s := &session{initialReq: sReqs}
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("not-tcpip-forward", true, nil)
|
|
}()
|
|
|
|
req := s.waitForTCPIPForward()
|
|
if req != nil {
|
|
t.Error("expected nil request")
|
|
}
|
|
})
|
|
|
|
t.Run("Channel Closed", func(t *testing.T) {
|
|
initialReq := make(chan *ssh.Request)
|
|
s := &session{initialReq: initialReq}
|
|
close(initialReq)
|
|
|
|
req := s.waitForTCPIPForward()
|
|
if req != nil {
|
|
t.Error("expected nil request")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestSetupSessionMode_ChannelClosed(t *testing.T) {
|
|
sshChan := make(chan ssh.NewChannel)
|
|
s := &session{sshChan: sshChan}
|
|
close(sshChan)
|
|
|
|
err := s.setupSessionMode()
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestStart_SetupSessionModeError(t *testing.T) {
|
|
sshChan := make(chan ssh.NewChannel, 1)
|
|
conf := &Config{
|
|
Randomizer: &mockRandom{},
|
|
Config: &mockConfig{},
|
|
Conn: &ssh.ServerConn{},
|
|
InitialReq: make(chan *ssh.Request),
|
|
SshChan: sshChan,
|
|
SessionRegistry: &mockRegistry{},
|
|
PortRegistry: &mockPort{},
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
|
|
mockChan := &mockNewChanFail{}
|
|
sshChan <- mockChan
|
|
|
|
err := s.Start()
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
}
|
|
}
|
|
|
|
func TestWaitForSessionEnd_Error(t *testing.T) {
|
|
mConn := &mockSSHConn{}
|
|
mConn.On("Wait").Return(fmt.Errorf("wait error"))
|
|
mConn.On("Close").Return(nil)
|
|
|
|
mForwarder := &mockLifecycleForwarder{}
|
|
mForwarder.On("TunnelType").Return(types.TunnelTypeTCP)
|
|
mForwarder.On("ForwardedPort").Return(uint16(80))
|
|
mForwarder.On("Close").Return(fmt.Errorf("close error"))
|
|
|
|
mSlug := &mockLifecycleSlug{}
|
|
mSlug.On("String").Return("slug")
|
|
|
|
mPort := &mockPort{}
|
|
mPort.On("SetStatus", mock.Anything, mock.Anything).Return(nil)
|
|
|
|
mRegistry := &mockRegistry{}
|
|
mRegistry.On("Remove", mock.Anything).Return()
|
|
|
|
l := lifecycle.New(mConn, mForwarder, mSlug, mPort, mRegistry, "testuser")
|
|
s := &session{
|
|
lifecycle: l,
|
|
}
|
|
|
|
err := s.waitForSessionEnd()
|
|
assert.Error(t, err)
|
|
}
|
|
|
|
type mockLifecycleForwarder struct {
|
|
mock.Mock
|
|
lifecycle.Forwarder
|
|
}
|
|
|
|
func (m *mockLifecycleForwarder) TunnelType() types.TunnelType {
|
|
return m.Called().Get(0).(types.TunnelType)
|
|
}
|
|
func (m *mockLifecycleForwarder) ForwardedPort() uint16 {
|
|
args := m.Called()
|
|
if args.Get(0) == nil {
|
|
return 0
|
|
}
|
|
switch v := args.Get(0).(type) {
|
|
case uint16:
|
|
return v
|
|
case uint32:
|
|
return uint16(v)
|
|
case uint64:
|
|
return uint16(v)
|
|
case uint8:
|
|
return uint16(v)
|
|
case uint:
|
|
return uint16(v)
|
|
case int:
|
|
return uint16(v)
|
|
case int8:
|
|
return uint16(v)
|
|
case int16:
|
|
return uint16(v)
|
|
case int32:
|
|
return uint16(v)
|
|
case int64:
|
|
return uint16(v)
|
|
case float32:
|
|
return uint16(v)
|
|
case float64:
|
|
return uint16(v)
|
|
default:
|
|
return uint16(args.Int(0))
|
|
}
|
|
}
|
|
func (m *mockLifecycleForwarder) Close() error {
|
|
return m.Called().Error(0)
|
|
}
|
|
|
|
type mockLifecycleSlug struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockLifecycleSlug) String() string { return m.Called().String(0) }
|
|
func (m *mockLifecycleSlug) Set(slug string) {
|
|
m.Called(slug)
|
|
}
|
|
|
|
func TestHandleMissingForwardRequest(t *testing.T) {
|
|
mConn := &mockSSHConn{}
|
|
mConfig := &mockConfig{}
|
|
mConfig.On("Domain").Return("example.com")
|
|
mConfig.On("SSHPort").Return("2222")
|
|
mConn.On("Close").Return(nil)
|
|
|
|
conf := &Config{
|
|
Randomizer: &mockRandom{},
|
|
Config: mConfig,
|
|
Conn: &ssh.ServerConn{Conn: mConn},
|
|
InitialReq: make(chan *ssh.Request),
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: &mockRegistry{},
|
|
PortRegistry: &mockPort{},
|
|
User: "testuser",
|
|
}
|
|
|
|
s := New(conf).(*session)
|
|
|
|
err := s.handleMissingForwardRequest()
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
}
|
|
}
|
|
|
|
func TestParseForwardPayload_Errors(t *testing.T) {
|
|
s := &session{}
|
|
|
|
t.Run("Short Address", func(t *testing.T) {
|
|
_, _, err := s.parseForwardPayload([]byte{0, 0, 0, 4})
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
}
|
|
})
|
|
|
|
t.Run("Short Port", func(t *testing.T) {
|
|
payload := append([]byte{0, 0, 0, 4}, []byte("addr")...)
|
|
_, _, err := s.parseForwardPayload(payload)
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
}
|
|
})
|
|
|
|
t.Run("Blocked Port", func(t *testing.T) {
|
|
payload := append([]byte{0, 0, 0, 4}, []byte("addr")...)
|
|
portBuf := make([]byte, 4)
|
|
binary.BigEndian.PutUint32(portBuf, 22)
|
|
payload = append(payload, portBuf...)
|
|
_, _, err := s.parseForwardPayload(payload)
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), "port is block") {
|
|
t.Errorf("expected error to contain %q, got %q", "port is block", err.Error())
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDenyForwardingRequest_TunnelNotSetupYet(t *testing.T) {
|
|
sConn, sReqs, _, cConn, cleanup := setupSSH(t)
|
|
defer cleanup()
|
|
|
|
mRegistry := &mockRegistry{}
|
|
mPort := &mockPort{}
|
|
mRandom := &mockRandom{}
|
|
conf := &Config{
|
|
Randomizer: mRandom,
|
|
Config: &mockConfig{},
|
|
Conn: sConn,
|
|
InitialReq: sReqs,
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: mRegistry,
|
|
PortRegistry: mPort,
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
|
|
go func() {
|
|
_, _, _ = cConn.SendRequest("tcpip-forward", true, nil)
|
|
}()
|
|
|
|
var req *ssh.Request
|
|
select {
|
|
case req = <-sReqs:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
key := &types.SessionKey{Id: "", Type: types.TunnelTypeUNKNOWN}
|
|
err := s.denyForwardingRequest(req, key, &mockCloser{}, "test error")
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), "test error") {
|
|
t.Errorf("expected error to contain %q, got %q", "test error", err.Error())
|
|
}
|
|
assert.Equal(t, *key, mRegistry.removedKey)
|
|
}
|
|
|
|
func TestDenyForwardingRequest_Full(t *testing.T) {
|
|
setup := func(t *testing.T) (*session, *mockRegistry, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) {
|
|
sConn, sReqs, _, cConn, cleanup := setupSSH(t)
|
|
mRegistry := &mockRegistry{}
|
|
conf := &Config{
|
|
Randomizer: &mockRandom{},
|
|
Config: &mockConfig{},
|
|
Conn: sConn,
|
|
InitialReq: sReqs,
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: mRegistry,
|
|
PortRegistry: &mockPort{},
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
return s, mRegistry, sConn, sReqs, cConn, cleanup
|
|
}
|
|
|
|
getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request {
|
|
go func() {
|
|
_, _, _ = client.SendRequest("tcpip-forward", true, nil)
|
|
}()
|
|
select {
|
|
case req, ok := <-serverReqs:
|
|
if !ok {
|
|
t.Fatal("channel closed")
|
|
}
|
|
return req
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout getting request")
|
|
return nil
|
|
}
|
|
}
|
|
|
|
t.Run("All Success", func(t *testing.T) {
|
|
s, mRegistry, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
req := getReq(t, cConn, sReqs)
|
|
key := &types.SessionKey{Id: "test", Type: types.TunnelTypeHTTP}
|
|
|
|
s.slug.Set("test")
|
|
s.forwarder.SetType(types.TunnelTypeHTTP)
|
|
|
|
mCloser := &mockCloser{}
|
|
err := s.denyForwardingRequest(req, key, mCloser, "error")
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), "error") {
|
|
t.Errorf("expected error to contain %q, got %q", "error", err.Error())
|
|
}
|
|
assert.Equal(t, *key, mRegistry.removedKey)
|
|
})
|
|
|
|
t.Run("Listener Close error", func(t *testing.T) {
|
|
s, _, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
req := getReq(t, cConn, sReqs)
|
|
mCloser := &mockCloser{err: fmt.Errorf("close error")}
|
|
err := s.denyForwardingRequest(req, nil, mCloser, "error")
|
|
assert.Error(t, err, net.ErrClosed)
|
|
})
|
|
|
|
t.Run("Reply error", func(t *testing.T) {
|
|
s, _, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
req := getReq(t, cConn, sReqs)
|
|
err := cConn.Close()
|
|
assert.NoError(t, err)
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
err = s.denyForwardingRequest(req, nil, nil, assert.AnError.Error())
|
|
assert.Error(t, err, assert.AnError)
|
|
})
|
|
}
|
|
|
|
func TestHandleTCPForward_Failures(t *testing.T) {
|
|
setup := func(t *testing.T) (*session, *mockRegistry, *mockPort, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) {
|
|
sConn, sReqs, _, cConn, cleanup := setupSSH(t)
|
|
mRegistry := &mockRegistry{}
|
|
mPort := &mockPort{}
|
|
conf := &Config{
|
|
Randomizer: &mockRandom{},
|
|
Config: &mockConfig{},
|
|
Conn: sConn,
|
|
InitialReq: sReqs,
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: mRegistry,
|
|
PortRegistry: mPort,
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
return s, mRegistry, mPort, sConn, sReqs, cConn, cleanup
|
|
}
|
|
|
|
getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request {
|
|
go func() {
|
|
_, _, _ = client.SendRequest("tcpip-forward", true, nil)
|
|
}()
|
|
select {
|
|
case req, ok := <-serverReqs:
|
|
if !ok {
|
|
t.Fatal("channel closed")
|
|
}
|
|
return req
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout getting request")
|
|
return nil
|
|
}
|
|
}
|
|
|
|
t.Run("Port Claim fail", func(t *testing.T) {
|
|
s, _, mPort, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mPort.On("Claim", mock.Anything).Return(false)
|
|
err := s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", 1234)
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), "already in use") {
|
|
t.Errorf("expected error to contain %q, got %q", "already in use", err.Error())
|
|
}
|
|
})
|
|
|
|
t.Run("Listen fail", func(t *testing.T) {
|
|
s, _, mPort, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mPort.On("Claim", mock.Anything).Return(true)
|
|
l, err := net.Listen("tcp", "0.0.0.0:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer func(l net.Listener) {
|
|
err = l.Close()
|
|
assert.NoError(t, err)
|
|
}(l)
|
|
port := uint16(l.Addr().(*net.TCPAddr).Port)
|
|
|
|
err = s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", port)
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), "already in use") {
|
|
t.Errorf("expected error to contain %q, got %q", "already in use", err.Error())
|
|
}
|
|
})
|
|
|
|
t.Run("Registry Register fail", func(t *testing.T) {
|
|
s, mRegistry, mPort, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mPort.On("Claim", mock.Anything).Return(true)
|
|
mRegistry.On("Register", mock.Anything, mock.Anything).Return(false)
|
|
err := s.HandleTCPForward(getReq(t, cConn, sReqs), "localhost", 0)
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), "Failed to register") {
|
|
t.Errorf("expected error to contain %q, got %q", "Failed to register", err.Error())
|
|
}
|
|
})
|
|
|
|
t.Run("Finalize fail (Reply fail)", func(t *testing.T) {
|
|
s, mRegistry, mPort, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mPort.On("Claim", mock.Anything).Return(true)
|
|
mRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
req := getReq(t, cConn, sReqs)
|
|
err := cConn.Close()
|
|
assert.NoError(t, err)
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
err = s.HandleTCPForward(req, "localhost", 0)
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), "Failed to finalize forwarding") {
|
|
t.Errorf("expected error to contain %q, got %q", "Failed to finalize forwarding", err.Error())
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHandleHTTPForward_Failures(t *testing.T) {
|
|
setup := func(t *testing.T) (*session, *mockRegistry, *mockRandom, *ssh.ServerConn, <-chan *ssh.Request, ssh.Conn, func()) {
|
|
sConn, sReqs, _, cConn, cleanup := setupSSH(t)
|
|
mRegistry := &mockRegistry{}
|
|
mRandom := &mockRandom{}
|
|
s := New(&Config{
|
|
Randomizer: mRandom,
|
|
Config: &mockConfig{},
|
|
Conn: sConn,
|
|
InitialReq: sReqs,
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: mRegistry,
|
|
PortRegistry: &mockPort{},
|
|
User: "testuser",
|
|
}).(*session)
|
|
return s, mRegistry, mRandom, sConn, sReqs, cConn, cleanup
|
|
}
|
|
|
|
getReq := func(t *testing.T, client ssh.Conn, serverReqs <-chan *ssh.Request) *ssh.Request {
|
|
go func() { _, _, _ = client.SendRequest("tcpip-forward", true, nil) }()
|
|
return <-serverReqs
|
|
}
|
|
|
|
t.Run("Random fail", func(t *testing.T) {
|
|
s, _, mRandom, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mRandom.On("String", 20).Return("", fmt.Errorf("random error"))
|
|
err := s.HandleHTTPForward(getReq(t, cConn, sReqs), 80)
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), "Failed to create slug") {
|
|
t.Errorf("expected error to contain %q, got %q", "Failed to create slug", err.Error())
|
|
}
|
|
})
|
|
|
|
t.Run("Register fail", func(t *testing.T) {
|
|
s, mRegistry, mRandom, _, sReqs, cConn, cleanup := setup(t)
|
|
defer cleanup()
|
|
mRandom.On("String", 20).Return("slug", nil)
|
|
mRegistry.On("Register", mock.Anything, mock.Anything).Return(false)
|
|
err := s.HandleHTTPForward(getReq(t, cConn, sReqs), 80)
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), "Failed to register") {
|
|
t.Errorf("expected error to contain %q, got %q", "Failed to register", err.Error())
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestHandleGlobalRequest_Failures(t *testing.T) {
|
|
_, sReqs, _, cConn, cleanup := setupSSH(t)
|
|
defer cleanup()
|
|
|
|
conf := &Config{
|
|
Randomizer: &mockRandom{},
|
|
Config: &mockConfig{},
|
|
Conn: &ssh.ServerConn{},
|
|
InitialReq: make(chan *ssh.Request),
|
|
SshChan: make(chan ssh.NewChannel),
|
|
SessionRegistry: &mockRegistry{},
|
|
PortRegistry: &mockPort{},
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
_ = s.HandleGlobalRequest(sReqs)
|
|
close(done)
|
|
}()
|
|
|
|
tests := []struct {
|
|
name string
|
|
reqType string
|
|
payload []byte
|
|
wantReply bool
|
|
expected bool
|
|
}{
|
|
{"shell", "shell", nil, true, true},
|
|
{"pty-req", "pty-req", nil, true, true},
|
|
{"window-change valid", "window-change", make([]byte, 16), true, true},
|
|
{"window-change invalid", "window-change", make([]byte, 4), true, false},
|
|
{"unknown", "unknown", nil, true, false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ok, _, err := cConn.SendRequest(tt.reqType, tt.wantReply, tt.payload)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expected, ok)
|
|
})
|
|
}
|
|
|
|
err := cConn.Close()
|
|
assert.NoError(t, err)
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("HandleGlobalRequest timed out after cConn.Close()")
|
|
}
|
|
}
|
|
|
|
func TestSetupInteractiveMode_GlobalRequestError(t *testing.T) {
|
|
sConn, _, sChans, _, cleanup := setupSSH(t)
|
|
defer cleanup()
|
|
|
|
conf := &Config{
|
|
Randomizer: &mockRandom{},
|
|
Config: &mockConfig{},
|
|
Conn: sConn,
|
|
InitialReq: make(chan *ssh.Request),
|
|
SshChan: sChans,
|
|
SessionRegistry: &mockRegistry{},
|
|
PortRegistry: &mockPort{},
|
|
User: "testuser",
|
|
}
|
|
s := New(conf).(*session)
|
|
|
|
mockChan := &mockNewChanFail{}
|
|
err := s.setupInteractiveMode(mockChan)
|
|
if err == nil {
|
|
t.Error("expected error, got nil")
|
|
}
|
|
}
|
|
|
|
type mockCloser struct {
|
|
err error
|
|
}
|
|
|
|
func (m *mockCloser) Close() error { return m.err }
|