860 lines
21 KiB
Go
860 lines
21 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
"tunnel_pls/internal/registry"
|
|
"tunnel_pls/session/slug"
|
|
"tunnel_pls/types"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"golang.org/x/crypto/ssh"
|
|
"google.golang.org/grpc"
|
|
)
|
|
|
|
type MockRandom struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockRandom) String(length int) (string, error) {
|
|
args := m.Called(length)
|
|
return args.String(0), args.Error(1)
|
|
}
|
|
|
|
type MockConfig struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockConfig) Domain() string { return m.Called().String(0) }
|
|
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
|
|
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
|
|
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
|
|
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
|
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
|
|
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
|
|
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
|
|
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
|
|
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
|
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
|
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
|
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
|
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
|
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
|
func (m *MockConfig) Mode() types.ServerMode {
|
|
args := m.Called()
|
|
if args.Get(0) == nil {
|
|
return 0
|
|
}
|
|
switch v := args.Get(0).(type) {
|
|
case types.ServerMode:
|
|
return v
|
|
case int:
|
|
return types.ServerMode(v)
|
|
default:
|
|
return types.ServerMode(args.Int(0))
|
|
}
|
|
}
|
|
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
|
|
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
|
|
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
|
func (m *MockConfig) KeyLoc() string { return m.Called().String(0) }
|
|
|
|
type MockSessionRegistry struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
|
args := m.Called(key)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(registry.Session), args.Error(1)
|
|
}
|
|
|
|
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
|
args := m.Called(user, key)
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(registry.Session), args.Error(1)
|
|
}
|
|
|
|
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
|
args := m.Called(user, oldKey, newKey)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
|
args := m.Called(key, session)
|
|
return args.Bool(0)
|
|
}
|
|
|
|
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
|
m.Called(key)
|
|
}
|
|
|
|
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
|
args := m.Called(user)
|
|
return args.Get(0).([]registry.Session)
|
|
}
|
|
|
|
func (m *MockSessionRegistry) Slug() slug.Slug {
|
|
args := m.Called()
|
|
return args.Get(0).(slug.Slug)
|
|
}
|
|
|
|
type MockGRPCClient struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockGRPCClient) ClientConn() *grpc.ClientConn {
|
|
args := m.Called()
|
|
return args.Get(0).(*grpc.ClientConn)
|
|
}
|
|
|
|
func (m *MockGRPCClient) AuthorizeConn(ctx context.Context, token string) (authorized bool, user string, err error) {
|
|
args := m.Called(ctx, token)
|
|
return args.Bool(0), args.String(1), args.Error(2)
|
|
}
|
|
|
|
func (m *MockGRPCClient) CheckServerHealth(ctx context.Context) error {
|
|
args := m.Called(ctx)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockGRPCClient) SubscribeEvents(ctx context.Context, domain, token string) error {
|
|
args := m.Called(ctx, domain, token)
|
|
return args.Error(0)
|
|
}
|
|
|
|
func (m *MockGRPCClient) Close() error {
|
|
args := m.Called()
|
|
return args.Error(0)
|
|
}
|
|
|
|
type MockPort struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockPort) AddRange(startPort, endPort uint16) error {
|
|
return m.Called(startPort, endPort).Error(0)
|
|
}
|
|
|
|
func (m *MockPort) Unassigned() (uint16, bool) {
|
|
args := m.Called()
|
|
return uint16(args.Int(0)), args.Bool(1)
|
|
}
|
|
|
|
func (m *MockPort) SetStatus(port uint16, assigned bool) error {
|
|
return m.Called(port, assigned).Error(0)
|
|
}
|
|
|
|
func (m *MockPort) Claim(port uint16) bool {
|
|
return m.Called(port).Bool(0)
|
|
}
|
|
|
|
type MockListener struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockListener) Accept() (net.Conn, error) {
|
|
args := m.Called()
|
|
if args.Get(0) == nil {
|
|
return nil, args.Error(1)
|
|
}
|
|
return args.Get(0).(net.Conn), args.Error(1)
|
|
}
|
|
|
|
func (m *MockListener) Close() error {
|
|
return m.Called().Error(0)
|
|
}
|
|
|
|
func (m *MockListener) Addr() net.Addr {
|
|
return m.Called().Get(0).(net.Addr)
|
|
}
|
|
|
|
type MockSession struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *MockSession) Start() error {
|
|
return m.Called().Error(0)
|
|
}
|
|
|
|
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
|
|
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
|
signer, _ := ssh.NewSignerFromKey(key)
|
|
config := &ssh.ServerConfig{
|
|
NoClientAuth: true,
|
|
}
|
|
config.AddHostKey(signer)
|
|
return config, signer
|
|
}
|
|
|
|
func TestNew(t *testing.T) {
|
|
mr := new(MockRandom)
|
|
mc := new(MockConfig)
|
|
mreg := new(MockSessionRegistry)
|
|
mg := new(MockGRPCClient)
|
|
mp := new(MockPort)
|
|
sc, _ := getTestSSHConfig()
|
|
|
|
tests := []struct {
|
|
name string
|
|
port string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "success",
|
|
port: "0",
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid port",
|
|
port: "invalid",
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
s, err := New(mr, mc, sc, mreg, mg, mp, tt.port)
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, s)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, s)
|
|
_ = s.Close()
|
|
}
|
|
})
|
|
}
|
|
|
|
t.Run("port already in use", func(t *testing.T) {
|
|
l, err := net.Listen("tcp", ":0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
port := l.Addr().(*net.TCPAddr).Port
|
|
defer l.Close()
|
|
|
|
s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
|
|
assert.Error(t, err)
|
|
assert.Nil(t, s)
|
|
})
|
|
}
|
|
|
|
func TestClose(t *testing.T) {
|
|
mr := new(MockRandom)
|
|
mc := new(MockConfig)
|
|
mreg := new(MockSessionRegistry)
|
|
mg := new(MockGRPCClient)
|
|
mp := new(MockPort)
|
|
sc, _ := getTestSSHConfig()
|
|
|
|
t.Run("successful close", func(t *testing.T) {
|
|
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
|
err := s.Close()
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("close already closed listener", func(t *testing.T) {
|
|
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
|
_ = s.Close()
|
|
err := s.Close()
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("close with nil listener", func(t *testing.T) {
|
|
s := &server{
|
|
sshListener: nil,
|
|
}
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
assert.NotNil(t, r)
|
|
}
|
|
}()
|
|
_ = s.Close()
|
|
t.Fatal("expected panic for nil listener")
|
|
})
|
|
}
|
|
|
|
func TestStart(t *testing.T) {
|
|
mr := new(MockRandom)
|
|
mc := new(MockConfig)
|
|
mreg := new(MockSessionRegistry)
|
|
mg := new(MockGRPCClient)
|
|
mp := new(MockPort)
|
|
sc, _ := getTestSSHConfig()
|
|
|
|
t.Run("normal stop", func(t *testing.T) {
|
|
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
|
go func() {
|
|
time.Sleep(100 * time.Millisecond)
|
|
_ = s.Close()
|
|
}()
|
|
s.Start()
|
|
})
|
|
|
|
t.Run("accept error - temporary error continues loop", func(t *testing.T) {
|
|
ml := new(MockListener)
|
|
s := &server{
|
|
sshListener: ml,
|
|
sshPort: "0",
|
|
}
|
|
|
|
ml.On("Accept").Return(nil, errors.New("temporary error")).Once()
|
|
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
|
|
s.Start()
|
|
ml.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("accept error - immediate close", func(t *testing.T) {
|
|
ml := new(MockListener)
|
|
s := &server{
|
|
sshListener: ml,
|
|
sshPort: "0",
|
|
}
|
|
|
|
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
|
|
s.Start()
|
|
ml.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("accept success - connection fails SSH handshake", func(t *testing.T) {
|
|
mockRandom := &MockRandom{}
|
|
mockConfig := &MockConfig{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
mockGrpcClient := &MockGRPCClient{}
|
|
mockPort := &MockPort{}
|
|
|
|
sshConfig, _ := getTestSSHConfig()
|
|
|
|
serverConn, clientConn := net.Pipe()
|
|
|
|
mockListener := &MockListener{}
|
|
mockListener.On("Accept").Return(serverConn, nil).Once()
|
|
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
|
|
s := &server{
|
|
randomizer: mockRandom,
|
|
config: mockConfig,
|
|
sshPort: "0",
|
|
sshListener: mockListener,
|
|
sshConfig: sshConfig,
|
|
grpcClient: mockGrpcClient,
|
|
sessionRegistry: mockSessionRegistry,
|
|
portRegistry: mockPort,
|
|
}
|
|
|
|
go s.Start()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
clientConn.Close()
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
mockListener.AssertExpectations(t)
|
|
})
|
|
|
|
t.Run("accept success - valid SSH connection without auth", func(t *testing.T) {
|
|
mockRandom := &MockRandom{}
|
|
mockConfig := &MockConfig{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
mockPort := &MockPort{}
|
|
|
|
sshConfig, _ := getTestSSHConfig()
|
|
|
|
serverConn, clientConn := net.Pipe()
|
|
|
|
mockListener := &MockListener{}
|
|
mockListener.On("Accept").Return(serverConn, nil).Once()
|
|
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
|
|
s := &server{
|
|
randomizer: mockRandom,
|
|
config: mockConfig,
|
|
sshPort: "0",
|
|
sshListener: mockListener,
|
|
sshConfig: sshConfig,
|
|
grpcClient: nil,
|
|
sessionRegistry: mockSessionRegistry,
|
|
portRegistry: mockPort,
|
|
}
|
|
|
|
go s.Start()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
clientConn.Close()
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
mockListener.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestHandleConnection(t *testing.T) {
|
|
t.Run("SSH handshake fails - connection closed", func(t *testing.T) {
|
|
mockRandom := &MockRandom{}
|
|
mockConfig := &MockConfig{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
mockGrpcClient := &MockGRPCClient{}
|
|
mockPort := &MockPort{}
|
|
|
|
sshConfig, _ := getTestSSHConfig()
|
|
|
|
serverConn, clientConn := net.Pipe()
|
|
|
|
s := &server{
|
|
randomizer: mockRandom,
|
|
config: mockConfig,
|
|
sshPort: "0",
|
|
sshConfig: sshConfig,
|
|
grpcClient: mockGrpcClient,
|
|
sessionRegistry: mockSessionRegistry,
|
|
portRegistry: mockPort,
|
|
}
|
|
|
|
clientConn.Close()
|
|
s.handleConnection(serverConn)
|
|
})
|
|
|
|
// SSH SERVER SUCH PAIN IN THE ASS TO BE UNIT TEST, I FUCKING HATE THIS
|
|
// GONNA IMPLEMENT THIS UNIT TEST LATER
|
|
|
|
//t.Run("SSH handshake fails - invalid protocol", func(t *testing.T) {
|
|
// mockRandom := &MockRandom{}
|
|
// mockConfig := &MockConfig{}
|
|
// mockSessionRegistry := &MockSessionRegistry{}
|
|
// mockGrpcClient := &MockGRPCClient{}
|
|
// mockPort := &MockPort{}
|
|
//
|
|
// sshConfig, _ := getTestSSHConfig()
|
|
//
|
|
// serverConn, clientConn := net.Pipe()
|
|
//
|
|
// s := &server{
|
|
// randomizer: mockRandom,
|
|
// config: mockConfig,
|
|
// sshPort: "0",
|
|
// sshConfig: sshConfig,
|
|
// grpcClient: mockGrpcClient,
|
|
// sessionRegistry: mockSessionRegistry,
|
|
// portRegistry: mockPort,
|
|
// }
|
|
//
|
|
// done := make(chan bool, 1)
|
|
//
|
|
// go func() {
|
|
// s.handleConnection(serverConn)
|
|
// done <- true
|
|
// }()
|
|
//
|
|
// go func() {
|
|
// clientConn.Write([]byte("invalid ssh protocol\n"))
|
|
// clientConn.Close()
|
|
// }()
|
|
//
|
|
// select {
|
|
// case <-done:
|
|
// case <-time.After(1 * time.Second):
|
|
// t.Fatal("handleConnection did not complete in time")
|
|
// }
|
|
//})
|
|
|
|
t.Run("SSH connection established without gRPC client", func(t *testing.T) {
|
|
mockRandom := &MockRandom{}
|
|
mockConfig := &MockConfig{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
mockPort := &MockPort{}
|
|
|
|
serverConfig, _ := getTestSSHConfig()
|
|
|
|
mockConfig.On("Domain").Return("test.com")
|
|
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
|
mockConfig.On("SSHPort").Return("2200")
|
|
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
|
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
serverAddr := listener.Addr().String()
|
|
|
|
s := &server{
|
|
randomizer: mockRandom,
|
|
config: mockConfig,
|
|
sshPort: "0",
|
|
sshConfig: serverConfig,
|
|
grpcClient: nil,
|
|
sessionRegistry: mockSessionRegistry,
|
|
portRegistry: mockPort,
|
|
}
|
|
|
|
done := make(chan bool, 1)
|
|
|
|
go func() {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.handleConnection(conn)
|
|
done <- true
|
|
}()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
clientConfig := &ssh.ClientConfig{
|
|
User: "testuser",
|
|
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 2 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
|
if err != nil {
|
|
t.Logf("Client dial failed: %v", err)
|
|
return
|
|
}
|
|
defer client.Close()
|
|
|
|
type forwardPayload struct {
|
|
BindAddr string
|
|
BindPort uint32
|
|
}
|
|
|
|
payload := ssh.Marshal(forwardPayload{
|
|
BindAddr: "localhost",
|
|
BindPort: 80,
|
|
})
|
|
|
|
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
|
if err != nil {
|
|
t.Logf("Forward request failed: %v", err)
|
|
}
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
t.Log("handleConnection completed")
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("handleConnection did not complete in time")
|
|
}
|
|
})
|
|
|
|
t.Run("SSH connection established with gRPC authorization", func(t *testing.T) {
|
|
mockRandom := &MockRandom{}
|
|
mockConfig := &MockConfig{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
mockGrpcClient := &MockGRPCClient{}
|
|
mockPort := &MockPort{}
|
|
|
|
serverConfig, _ := getTestSSHConfig()
|
|
|
|
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
|
|
mockConfig.On("Domain").Return("test.com")
|
|
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
|
mockConfig.On("SSHPort").Return("2200")
|
|
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
|
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
serverAddr := listener.Addr().String()
|
|
|
|
s := &server{
|
|
randomizer: mockRandom,
|
|
config: mockConfig,
|
|
sshPort: "0",
|
|
sshConfig: serverConfig,
|
|
grpcClient: mockGrpcClient,
|
|
sessionRegistry: mockSessionRegistry,
|
|
portRegistry: mockPort,
|
|
}
|
|
|
|
done := make(chan bool, 1)
|
|
|
|
go func() {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.handleConnection(conn)
|
|
done <- true
|
|
}()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
clientConfig := &ssh.ClientConfig{
|
|
User: "testuser",
|
|
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 2 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
|
if err != nil {
|
|
t.Logf("Client dial failed: %v", err)
|
|
return
|
|
}
|
|
defer client.Close()
|
|
|
|
type forwardPayload struct {
|
|
BindAddr string
|
|
BindPort uint32
|
|
}
|
|
|
|
payload := ssh.Marshal(forwardPayload{
|
|
BindAddr: "localhost",
|
|
BindPort: 80,
|
|
})
|
|
|
|
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
|
if err != nil {
|
|
t.Logf("Forward request failed: %v", err)
|
|
}
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
mockGrpcClient.AssertExpectations(t)
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("handleConnection did not complete in time")
|
|
}
|
|
})
|
|
|
|
t.Run("SSH connection with gRPC authorization error", func(t *testing.T) {
|
|
mockRandom := &MockRandom{}
|
|
mockConfig := &MockConfig{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
mockGrpcClient := &MockGRPCClient{}
|
|
mockPort := &MockPort{}
|
|
|
|
serverConfig, _ := getTestSSHConfig()
|
|
|
|
mockGrpcClient.On("AuthorizeConn", mock.Anything, "testuser").Return(true, "authorized_user", nil)
|
|
mockConfig.On("Domain").Return("test.com")
|
|
mockConfig.On("Mode").Return(types.ServerModeNODE)
|
|
mockConfig.On("SSHPort").Return("2200")
|
|
mockRandom.On("String", mock.Anything).Return("ilovefemboy", nil)
|
|
mockSessionRegistry.On("Register", mock.Anything, mock.Anything).Return(true)
|
|
mockSessionRegistry.On("Remove", mock.Anything).Return(nil)
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer listener.Close()
|
|
|
|
serverAddr := listener.Addr().String()
|
|
|
|
s := &server{
|
|
randomizer: mockRandom,
|
|
config: mockConfig,
|
|
sshPort: "0",
|
|
sshConfig: serverConfig,
|
|
grpcClient: mockGrpcClient,
|
|
sessionRegistry: mockSessionRegistry,
|
|
portRegistry: mockPort,
|
|
}
|
|
|
|
done := make(chan bool, 1)
|
|
|
|
go func() {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.handleConnection(conn)
|
|
done <- true
|
|
}()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
|
|
clientConfig := &ssh.ClientConfig{
|
|
User: "testuser",
|
|
Auth: []ssh.AuthMethod{ssh.Password("password")},
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
Timeout: 2 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
client, err := ssh.Dial("tcp", serverAddr, clientConfig)
|
|
if err != nil {
|
|
t.Logf("Client dial failed: %v", err)
|
|
return
|
|
}
|
|
defer client.Close()
|
|
|
|
type forwardPayload struct {
|
|
BindAddr string
|
|
BindPort uint32
|
|
}
|
|
|
|
payload := ssh.Marshal(forwardPayload{
|
|
BindAddr: "localhost",
|
|
BindPort: 8080,
|
|
})
|
|
|
|
_, _, err = client.SendRequest("tcpip-forward", true, payload)
|
|
if err != nil {
|
|
t.Logf("Forward request failed: %v", err)
|
|
}
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
mockGrpcClient.AssertExpectations(t)
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("handleConnection did not complete in time")
|
|
}
|
|
})
|
|
|
|
t.Run("connection cleanup on close", func(t *testing.T) {
|
|
mockRandom := &MockRandom{}
|
|
mockConfig := &MockConfig{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
mockPort := &MockPort{}
|
|
|
|
serverConfig, _ := getTestSSHConfig()
|
|
|
|
serverConn, clientConn := net.Pipe()
|
|
|
|
s := &server{
|
|
randomizer: mockRandom,
|
|
config: mockConfig,
|
|
sshPort: "0",
|
|
sshConfig: serverConfig,
|
|
grpcClient: nil,
|
|
sessionRegistry: mockSessionRegistry,
|
|
portRegistry: mockPort,
|
|
}
|
|
|
|
done := make(chan bool, 1)
|
|
|
|
go func() {
|
|
s.handleConnection(serverConn)
|
|
done <- true
|
|
}()
|
|
|
|
clientConn.Close()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(1 * time.Second):
|
|
t.Fatal("handleConnection did not complete in time")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestIntegration(t *testing.T) {
|
|
t.Run("full server lifecycle", func(t *testing.T) {
|
|
mr := new(MockRandom)
|
|
mc := new(MockConfig)
|
|
mreg := new(MockSessionRegistry)
|
|
mg := new(MockGRPCClient)
|
|
mp := new(MockPort)
|
|
sc, _ := getTestSSHConfig()
|
|
|
|
s, err := New(mr, mc, sc, mreg, mg, mp, "0")
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, s)
|
|
|
|
go func() {
|
|
time.Sleep(100 * time.Millisecond)
|
|
err := s.Close()
|
|
assert.NoError(t, err)
|
|
}()
|
|
|
|
s.Start()
|
|
})
|
|
|
|
t.Run("multiple connections", func(t *testing.T) {
|
|
mockRandom := &MockRandom{}
|
|
mockConfig := &MockConfig{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
mockPort := &MockPort{}
|
|
|
|
sshConfig, _ := getTestSSHConfig()
|
|
|
|
conn1Server, conn1Client := net.Pipe()
|
|
conn2Server, conn2Client := net.Pipe()
|
|
|
|
mockListener := &MockListener{}
|
|
mockListener.On("Accept").Return(conn1Server, nil).Once()
|
|
mockListener.On("Accept").Return(conn2Server, nil).Once()
|
|
mockListener.On("Accept").Return(nil, net.ErrClosed).Once()
|
|
|
|
s := &server{
|
|
randomizer: mockRandom,
|
|
config: mockConfig,
|
|
sshPort: "0",
|
|
sshListener: mockListener,
|
|
sshConfig: sshConfig,
|
|
grpcClient: nil,
|
|
sessionRegistry: mockSessionRegistry,
|
|
portRegistry: mockPort,
|
|
}
|
|
|
|
go s.Start()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
conn1Client.Close()
|
|
time.Sleep(50 * time.Millisecond)
|
|
conn2Client.Close()
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
mockListener.AssertExpectations(t)
|
|
})
|
|
}
|
|
|
|
func TestErrorHandling(t *testing.T) {
|
|
t.Run("write error during SSH handshake", func(t *testing.T) {
|
|
mockRandom := &MockRandom{}
|
|
mockConfig := &MockConfig{}
|
|
mockSessionRegistry := &MockSessionRegistry{}
|
|
mockPort := &MockPort{}
|
|
|
|
sshConfig, _ := getTestSSHConfig()
|
|
|
|
serverConn, clientConn := net.Pipe()
|
|
clientConn.Close()
|
|
|
|
s := &server{
|
|
randomizer: mockRandom,
|
|
config: mockConfig,
|
|
sshPort: "0",
|
|
sshConfig: sshConfig,
|
|
grpcClient: nil,
|
|
sessionRegistry: mockSessionRegistry,
|
|
portRegistry: mockPort,
|
|
}
|
|
|
|
s.handleConnection(serverConn)
|
|
})
|
|
}
|