feat(testing): add comprehensive test coverage and code quality improvements #76

Merged
bagas merged 47 commits from feat/testing into staging 2026-01-27 16:36:40 +07:00
Showing only changes of commit 58f1fdabe1 - Show all commits
+268
View File
@@ -0,0 +1,268 @@
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"net"
"testing"
"time"
"tunnel_pls/internal/registry"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/crypto/ssh"
"google.golang.org/grpc"
)
type mockRandom struct {
mock.Mock
}
func (m *mockRandom) String(length int) (string, error) {
args := m.Called(length)
return args.String(0), args.Error(1)
}
type mockConfig struct {
mock.Mock
}
func (m *mockConfig) Domain() string { return m.Called().String(0) }
func (m *mockConfig) SSHPort() string { return m.Called().String(0) }
func (m *mockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *mockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *mockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *mockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *mockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *mockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *mockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *mockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *mockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *mockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *mockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *mockConfig) PprofPort() string { return m.Called().String(0) }
func (m *mockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
func (m *mockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *mockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *mockConfig) NodeToken() string { return m.Called().String(0) }
type mockRegistry struct {
mock.Mock
}
func (m *mockRegistry) Get(key types.SessionKey) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *mockRegistry) GetWithUser(user string, key types.SessionKey) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *mockRegistry) Register(key types.SessionKey, session registry.Session) bool {
return m.Called(key, session).Bool(0)
}
func (m *mockRegistry) Update(user string, oldKey types.SessionKey, newKey types.SessionKey) error {
return m.Called(user, oldKey, newKey).Error(0)
}
func (m *mockRegistry) Remove(key types.SessionKey) {
m.Called(key)
}
func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session {
return m.Called(user).Get(0).([]registry.Session)
}
type mockGrpcClient struct {
mock.Mock
}
func (m *mockGrpcClient) SubscribeEvents(ctx context.Context, identity string, authToken string) error {
return m.Called(ctx, identity, authToken).Error(0)
}
func (m *mockGrpcClient) ClientConn() *grpc.ClientConn {
args := m.Called()
if args.Get(0) == nil {
return nil
}
return args.Get(0).(*grpc.ClientConn)
}
func (m *mockGrpcClient) AuthorizeConn(ctx context.Context, token string) (bool, string, error) {
args := m.Called(ctx, token)
return args.Bool(0), args.String(1), args.Error(2)
}
func (m *mockGrpcClient) CheckServerHealth(ctx context.Context) error {
return m.Called(ctx).Error(0)
}
func (m *mockGrpcClient) Close() error {
return m.Called().Error(0)
}
type mockPort struct {
mock.Mock
}
func (m *mockPort) AddRange(startPort, endPort uint16) error {
return m.Called(startPort, endPort).Error(0)
}
func (m *mockPort) Unassigned() (uint16, bool) {
args := m.Called()
return uint16(args.Int(0)), args.Bool(1)
}
func (m *mockPort) SetStatus(port uint16, assigned bool) error {
return m.Called(port, assigned).Error(0)
}
func (m *mockPort) Claim(port uint16) bool {
return m.Called(port).Bool(0)
}
type mockListener struct {
mock.Mock
}
func (m *mockListener) Accept() (net.Conn, error) {
args := m.Called()
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(net.Conn), args.Error(1)
}
func (m *mockListener) Close() error {
return m.Called().Error(0)
}
func (m *mockListener) Addr() net.Addr {
return m.Called().Get(0).(net.Addr)
}
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
key, _ := rsa.GenerateKey(rand.Reader, 2048)
signer, _ := ssh.NewSignerFromKey(key)
config := &ssh.ServerConfig{
NoClientAuth: true,
}
config.AddHostKey(signer)
return config, signer
}
func TestNew(t *testing.T) {
mr := new(mockRandom)
mc := new(mockConfig)
mreg := new(mockRegistry)
mg := new(mockGrpcClient)
mp := new(mockPort)
sc, _ := getTestSSHConfig()
tests := []struct {
name string
port string
wantErr bool
}{
{
name: "success",
port: "0",
wantErr: false,
},
{
name: "invalid port",
port: "invalid",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, err := New(mr, mc, sc, mreg, mg, mp, tt.port)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, s)
} else {
assert.NoError(t, err)
assert.NotNil(t, s)
_ = s.Close()
}
})
}
t.Run("port already in use", func(t *testing.T) {
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
port := l.Addr().(*net.TCPAddr).Port
defer l.Close()
s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
assert.Error(t, err)
assert.Nil(t, s)
})
}
func TestClose(t *testing.T) {
mr := new(mockRandom)
mc := new(mockConfig)
mreg := new(mockRegistry)
mg := new(mockGrpcClient)
mp := new(mockPort)
sc, _ := getTestSSHConfig()
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
err := s.Close()
assert.NoError(t, err)
err = s.Close()
assert.Error(t, err)
}
func TestStart(t *testing.T) {
mr := new(mockRandom)
mc := new(mockConfig)
mreg := new(mockRegistry)
mg := new(mockGrpcClient)
mp := new(mockPort)
sc, _ := getTestSSHConfig()
t.Run("normal stop", func(t *testing.T) {
s, _ := New(mr, mc, sc, mreg, mg, mp, "0")
go func() {
time.Sleep(100 * time.Millisecond)
_ = s.Close()
}()
s.Start()
})
t.Run("accept error", func(t *testing.T) {
ml := new(mockListener)
s := &server{
sshListener: ml,
sshPort: "0",
}
ml.On("Accept").Return(nil, errors.New("temporary error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
s.Start()
ml.AssertExpectations(t)
})
}