test(server): add unit tests for server startup behavior
This commit is contained in:
@@ -0,0 +1,268 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
type mockRandom struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockRandom) String(length int) (string, error) {
|
||||
args := m.Called(length)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
type mockConfig struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockConfig) Domain() string { return m.Called().String(0) }
|
||||
func (m *mockConfig) SSHPort() string { return m.Called().String(0) }
|
||||
func (m *mockConfig) HTTPPort() string { return m.Called().String(0) }
|
||||
func (m *mockConfig) HTTPSPort() string { return m.Called().String(0) }
|
||||
func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *mockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
|
||||
func (m *mockConfig) ACMEEmail() string { return m.Called().String(0) }
|
||||
func (m *mockConfig) CFAPIToken() string { return m.Called().String(0) }
|
||||
func (m *mockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
||||
func (m *mockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *mockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *mockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *mockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *mockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *mockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
|
||||
func (m *mockConfig) GRPCAddress() string { return m.Called().String(0) }
|
||||
func (m *mockConfig) GRPCPort() string { return m.Called().String(0) }
|
||||
func (m *mockConfig) NodeToken() string { return m.Called().String(0) }
|
||||
|
||||
type mockRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Get(key types.SessionKey) (registry.Session, error) {
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetWithUser(user string, key types.SessionKey) (registry.Session, error) {
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Register(key types.SessionKey, session registry.Session) bool {
|
||||
return m.Called(key, session).Bool(0)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Update(user string, oldKey types.SessionKey, newKey types.SessionKey) error {
|
||||
return m.Called(user, oldKey, newKey).Error(0)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Remove(key types.SessionKey) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
return m.Called(user).Get(0).([]registry.Session)
|
||||
}
|
||||
|
||||
type mockGrpcClient struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockGrpcClient) SubscribeEvents(ctx context.Context, identity string, authToken string) error {
|
||||
return m.Called(ctx, identity, authToken).Error(0)
|
||||
}
|
||||
|
||||
func (m *mockGrpcClient) ClientConn() *grpc.ClientConn {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil
|
||||
}
|
||||
return args.Get(0).(*grpc.ClientConn)
|
||||
}
|
||||
|
||||
func (m *mockGrpcClient) AuthorizeConn(ctx context.Context, token string) (bool, string, error) {
|
||||
args := m.Called(ctx, token)
|
||||
return args.Bool(0), args.String(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (m *mockGrpcClient) CheckServerHealth(ctx context.Context) error {
|
||||
return m.Called(ctx).Error(0)
|
||||
}
|
||||
|
||||
func (m *mockGrpcClient) Close() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
type mockPort struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockPort) AddRange(startPort, endPort uint16) error {
|
||||
return m.Called(startPort, endPort).Error(0)
|
||||
}
|
||||
|
||||
func (m *mockPort) Unassigned() (uint16, bool) {
|
||||
args := m.Called()
|
||||
return uint16(args.Int(0)), args.Bool(1)
|
||||
}
|
||||
|
||||
func (m *mockPort) SetStatus(port uint16, assigned bool) error {
|
||||
return m.Called(port, assigned).Error(0)
|
||||
}
|
||||
|
||||
func (m *mockPort) Claim(port uint16) bool {
|
||||
return m.Called(port).Bool(0)
|
||||
}
|
||||
|
||||
type mockListener struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockListener) Accept() (net.Conn, error) {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(net.Conn), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockListener) Close() error {
|
||||
return m.Called().Error(0)
|
||||
}
|
||||
|
||||
func (m *mockListener) Addr() net.Addr {
|
||||
return m.Called().Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
|
||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
signer, _ := ssh.NewSignerFromKey(key)
|
||||
config := &ssh.ServerConfig{
|
||||
NoClientAuth: true,
|
||||
}
|
||||
config.AddHostKey(signer)
|
||||
return config, signer
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
mr := new(mockRandom)
|
||||
mc := new(mockConfig)
|
||||
mreg := new(mockRegistry)
|
||||
mg := new(mockGrpcClient)
|
||||
mp := new(mockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
port string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
port: "0",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid port",
|
||||
port: "invalid",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, err := New(mr, mc, sc, mreg, mg, mp, tt.port)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, s)
|
||||
_ = s.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("port already in use", func(t *testing.T) {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := l.Addr().(*net.TCPAddr).Port
|
||||
defer l.Close()
|
||||
|
||||
s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, s)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
mr := new(mockRandom)
|
||||
mc := new(mockConfig)
|
||||
mreg := new(mockRegistry)
|
||||
mg := new(mockGrpcClient)
|
||||
mp := new(mockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
err := s.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = s.Close()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStart(t *testing.T) {
|
||||
mr := new(mockRandom)
|
||||
mc := new(mockConfig)
|
||||
mreg := new(mockRegistry)
|
||||
mg := new(mockGrpcClient)
|
||||
mp := new(mockPort)
|
||||
sc, _ := getTestSSHConfig()
|
||||
|
||||
t.Run("normal stop", func(t *testing.T) {
|
||||
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_ = s.Close()
|
||||
}()
|
||||
s.Start()
|
||||
})
|
||||
|
||||
t.Run("accept error", func(t *testing.T) {
|
||||
ml := new(mockListener)
|
||||
s := &server{
|
||||
sshListener: ml,
|
||||
sshPort: "0",
|
||||
}
|
||||
|
||||
ml.On("Accept").Return(nil, errors.New("temporary error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
s.Start()
|
||||
ml.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user