package server

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"errors"
	"fmt"
	"net"
	"testing"
	"time"

	"tunnel_pls/internal/registry"
	"tunnel_pls/types"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"golang.org/x/crypto/ssh"
	"google.golang.org/grpc"
)

// mockRandom is a testify mock exposing a String(length) generator whose
// results are scripted through mock expectations.
type mockRandom struct {
	mock.Mock
}

// String records the call and returns the configured string and error.
func (m *mockRandom) String(length int) (string, error) {
	args := m.Called(length)
	return args.String(0), args.Error(1)
}

// MockConfig is a testify mock of the server configuration accessors. Every
// getter records the call and returns the value configured for it via On(...).
type MockConfig struct {
	mock.Mock
}

func (m *MockConfig) Domain() string         { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string        { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string       { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string      { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool       { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool      { return m.Called().Bool(0) }
func (m *MockConfig) TLSStoragePath() string { return m.Called().String(0) }
func (m *MockConfig) ACMEEmail() string      { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string     { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool      { return m.Called().Bool(0) }

// AllowedPortsStart narrows the configured int to uint16.
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }

// AllowedPortsEnd narrows the configured int to uint16.
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }

func (m *MockConfig) BufferSize() int     { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool  { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string   { return m.Called().String(0) }
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string    { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string   { return m.Called().String(0) }
func (m *MockConfig) KeyLoc() string      { return m.Called().String(0) }

// Mode returns the configured value asserted to types.ServerMode; the
// expectation must therefore be set up with a types.ServerMode value.
func (m *MockConfig) Mode() types.ServerMode {
	return m.Called().Get(0).(types.ServerMode)
}

// mockRegistry is a testify mock of the session registry operations.
type mockRegistry struct {
	mock.Mock
}

// Get returns the configured session, or a nil session when the expectation
// was set up with an untyped nil first return value.
func (m *mockRegistry) Get(key types.SessionKey) (registry.Session, error) {
	args := m.Called(key)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(registry.Session), args.Error(1)
}

// GetWithUser behaves like Get but also records the user argument.
func (m *mockRegistry) GetWithUser(user string, key types.SessionKey) (registry.Session, error) {
	args := m.Called(user, key)
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(registry.Session), args.Error(1)
}

// Register records the call and returns the configured boolean.
func (m *mockRegistry) Register(key types.SessionKey, session registry.Session) bool {
	return m.Called(key, session).Bool(0)
}

// Update records the call and returns the configured error.
func (m *mockRegistry) Update(user string, oldKey types.SessionKey, newKey types.SessionKey) error {
	return m.Called(user, oldKey, newKey).Error(0)
}

// Remove records the call; it has no return value.
func (m *mockRegistry) Remove(key types.SessionKey) {
	m.Called(key)
}

// GetAllSessionFromUser returns the configured session slice. An untyped nil
// return value yields a nil slice instead of panicking on the type assertion,
// matching the nil handling of Get and GetWithUser.
func (m *mockRegistry) GetAllSessionFromUser(user string) []registry.Session {
	args := m.Called(user)
	if args.Get(0) == nil {
		return nil
	}
	return args.Get(0).([]registry.Session)
}

// mockGrpcClient is a testify mock of the gRPC client operations.
type mockGrpcClient struct {
	mock.Mock
}

// SubscribeEvents records the call and returns the configured error.
func (m *mockGrpcClient) SubscribeEvents(ctx context.Context, identity string, authToken string) error {
	return m.Called(ctx, identity, authToken).Error(0)
}

// ClientConn returns the configured connection, or nil when the expectation
// was set up with an untyped nil.
func (m *mockGrpcClient) ClientConn() *grpc.ClientConn {
	args := m.Called()
	if args.Get(0) == nil {
		return nil
	}
	return args.Get(0).(*grpc.ClientConn)
}

// AuthorizeConn records the call and returns the configured bool, string and
// error, in that order.
func (m *mockGrpcClient) AuthorizeConn(ctx context.Context, token string) (bool, string, error) {
	args := m.Called(ctx, token)
	return args.Bool(0), args.String(1), args.Error(2)
}

// CheckServerHealth records the call and returns the configured error.
func (m *mockGrpcClient) CheckServerHealth(ctx context.Context) error {
	return m.Called(ctx).Error(0)
}

// Close records the call and returns the configured error.
func (m *mockGrpcClient) Close() error {
	return m.Called().Error(0)
}

// mockPort is a testify mock of the port allocation operations.
type mockPort struct {
	mock.Mock
}

// AddRange records the call and returns the configured error.
func (m *mockPort) AddRange(startPort, endPort uint16) error {
	return m.Called(startPort, endPort).Error(0)
}

// Unassigned returns the configured int narrowed to uint16 together with the
// configured boolean.
func (m *mockPort) Unassigned() (uint16, bool) {
	args := m.Called()
	return uint16(args.Int(0)), args.Bool(1)
}

// SetStatus records the call and returns the configured error.
func (m *mockPort) SetStatus(port uint16, assigned bool) error {
	return m.Called(port, assigned).Error(0)
}

// Claim records the call and returns the configured boolean.
func (m *mockPort) Claim(port uint16) bool {
	return m.Called(port).Bool(0)
}

// mockListener is a testify mock providing the Accept, Close and Addr methods
// of net.Listener.
type mockListener struct {
	mock.Mock
}

// Accept returns the configured connection and error; an untyped nil
// connection is passed through as a nil net.Conn.
func (m *mockListener) Accept() (net.Conn, error) {
	args := m.Called()
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}
	return args.Get(0).(net.Conn), args.Error(1)
}

// Close records the call and returns the configured error.
func (m *mockListener) Close() error {
	return m.Called().Error(0)
}

// Addr returns the configured address. An untyped nil return value yields a
// nil net.Addr instead of panicking on the type assertion, matching Accept.
func (m *mockListener) Addr() net.Addr {
	args := m.Called()
	if args.Get(0) == nil {
		return nil
	}
	return args.Get(0).(net.Addr)
}

// getTestSSHConfig builds an ssh.ServerConfig with client authentication
// disabled and a freshly generated 2048-bit RSA host key, returning the config
// together with the signer that was added to it.
//
// It panics if key generation or signer creation fails: the helper has no
// *testing.T to report through, and continuing with a nil key would only
// surface later as an unrelated nil dereference.
func getTestSSHConfig() (*ssh.ServerConfig, ssh.Signer) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		panic(fmt.Sprintf("generating test RSA key: %v", err))
	}

	signer, err := ssh.NewSignerFromKey(key)
	if err != nil {
		panic(fmt.Sprintf("creating test SSH signer: %v", err))
	}

	config := &ssh.ServerConfig{
		NoClientAuth: true,
	}
	config.AddHostKey(signer)
	return config, signer
}

// TestNew checks that New succeeds for port "0", and fails for a non-numeric
// port and for a port that is already bound.
func TestNew(t *testing.T) {
	mr := new(mockRandom)
	mc := new(MockConfig)
	mreg := new(mockRegistry)
	mg := new(mockGrpcClient)
	mp := new(mockPort)
	sc, _ := getTestSSHConfig()

	tests := []struct {
		name    string
		port    string
		wantErr bool
	}{
		{
			name:    "success",
			port:    "0",
			wantErr: false,
		},
		{
			name:    "invalid port",
			port:    "invalid",
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s, err := New(mr, mc, sc, mreg, mg, mp, tt.port)
			if tt.wantErr {
				assert.Error(t, err)
				assert.Nil(t, s)
				return
			}

			// Stop before Close when construction failed; calling Close on a
			// nil server would panic and hide the real assertion failure.
			if !assert.NoError(t, err) || !assert.NotNil(t, s) {
				return
			}
			assert.NoError(t, s.Close())
		})
	}

	t.Run("port already in use", func(t *testing.T) {
		l, err := net.Listen("tcp", ":0")
		if err != nil {
			t.Fatal(err)
		}
		// Best-effort cleanup of the helper listener; its close error is
		// irrelevant to the behavior under test.
		defer func() { _ = l.Close() }()

		port := l.Addr().(*net.TCPAddr).Port

		s, err := New(mr, mc, sc, mreg, mg, mp, fmt.Sprintf("%d", port))
		assert.Error(t, err)
		assert.Nil(t, s)
	})
}

// TestClose checks that the first Close succeeds and a second Close on the
// same server reports an error.
func TestClose(t *testing.T) {
	mr := new(mockRandom)
	mc := new(MockConfig)
	mreg := new(mockRegistry)
	mg := new(mockGrpcClient)
	mp := new(mockPort)
	sc, _ := getTestSSHConfig()

	s, err := New(mr, mc, sc, mreg, mg, mp, "0")
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	err = s.Close()
	assert.NoError(t, err)

	err = s.Close()
	assert.Error(t, err)
}

// TestStart checks that Start returns once the server is closed, and that it
// keeps accepting after a non-terminal Accept error until the listener reports
// net.ErrClosed.
func TestStart(t *testing.T) {
	mr := new(mockRandom)
	mc := new(MockConfig)
	mreg := new(mockRegistry)
	mg := new(mockGrpcClient)
	mp := new(mockPort)
	sc, _ := getTestSSHConfig()

	t.Run("normal stop", func(t *testing.T) {
		s, err := New(mr, mc, sc, mreg, mg, mp, "0")
		if err != nil {
			t.Fatalf("New: %v", err)
		}

		done := make(chan struct{})
		go func() {
			defer close(done)
			// Give Start time to enter its accept loop before closing.
			time.Sleep(100 * time.Millisecond)
			// The close error is not under test here; Start returning is.
			_ = s.Close()
		}()

		s.Start()
		// Wait for the closer goroutine so it cannot outlive the subtest.
		<-done
	})

	t.Run("accept error", func(t *testing.T) {
		ml := new(mockListener)
		s := &server{
			sshListener: ml,
			sshPort:     "0",
		}

		ml.On("Accept").Return(nil, errors.New("temporary error")).Once()
		ml.On("Accept").Return(nil, net.ErrClosed).Once()

		s.Start()
		ml.AssertExpectations(t)
	})
}