chore(tests): migrate to Testify for mocking and assertions

This commit is contained in:
2026-01-26 11:53:00 +07:00
parent 65df01fee5
commit ee1dc3c3cd
10 changed files with 530 additions and 547 deletions
+19 -24
View File
@@ -1,40 +1,42 @@
package middleware
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"net"
"testing"
)
type mockRequestHeader struct {
headers map[string]string
mock.Mock
}
func (m *mockRequestHeader) Value(key string) string {
return m.headers[key]
return m.Called(key).String(0)
}
func (m *mockRequestHeader) Set(key string, value string) {
m.headers[key] = value
m.Called(key, value)
}
func (m *mockRequestHeader) Remove(key string) {
delete(m.headers, key)
m.Called(key)
}
func (m *mockRequestHeader) Finalize() []byte {
return []byte{}
return m.Called().Get(0).([]byte)
}
func (m *mockRequestHeader) Method() string {
return ""
return m.Called().String(0)
}
func (m *mockRequestHeader) Path() string {
return ""
return m.Called().String(0)
}
func (m *mockRequestHeader) Version() string {
return ""
return m.Called().String(0)
}
func TestForwardedFor_HandleRequest(t *testing.T) {
@@ -73,23 +75,19 @@ func TestForwardedFor_HandleRequest(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ff := NewForwardedFor(tc.addr)
reqHeader := &mockRequestHeader{headers: make(map[string]string)}
reqHeader := new(mockRequestHeader)
if !tc.expectError {
reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return()
}
err := ff.HandleRequest(reqHeader)
if tc.expectError {
if err == nil {
t.Fatalf("expected error but got none")
}
assert.Error(t, err)
} else {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
host := reqHeader.Value("X-Forwarded-For")
if host != tc.expectedHost {
t.Errorf("expected X-Forwarded-For header to be '%s', got '%s'", tc.expectedHost, host)
}
assert.NoError(t, err)
reqHeader.AssertExpectations(t)
}
})
}
@@ -121,10 +119,7 @@ func TestNewForwardedFor(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ff := NewForwardedFor(tc.addr)
if ff.addr.String() != tc.expectAddr.String() {
t.Errorf("expected addr to be '%v', got '%v'", tc.expectAddr, ff.addr)
}
assert.Equal(t, tc.expectAddr.String(), ff.addr.String())
})
}
}
+27 -34
View File
@@ -1,76 +1,69 @@
package middleware
import (
"errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"testing"
)
type mockResponseHeader struct {
headers map[string]string
mock.Mock
}
func (m *mockResponseHeader) Value(key string) string {
return m.headers[key]
return m.Called(key).String(0)
}
func (m *mockResponseHeader) Set(key string, value string) {
m.headers[key] = value
m.Called(key, value)
}
func (m *mockResponseHeader) Remove(key string) {
delete(m.headers, key)
m.Called(key)
}
func (m *mockResponseHeader) Finalize() []byte {
return nil
return m.Called().Get(0).([]byte)
}
func TestTunnelFingerprintHandleResponse(t *testing.T) {
tests := []struct {
name string
initialState map[string]string
expected map[string]string
body []byte
wantErr error
name string
expected map[string]string
body []byte
wantErr error
}{
{
name: "Sets Server Header",
initialState: map[string]string{},
expected: map[string]string{"Server": "Tunnel Please"},
body: []byte("Sample body"),
wantErr: nil,
name: "Sets Server Header",
expected: map[string]string{"Server": "Tunnel Please"},
body: []byte("Sample body"),
wantErr: nil,
},
{
name: "Overwrites Server Header",
initialState: map[string]string{"Server": "Old Value"},
expected: map[string]string{"Server": "Tunnel Please"},
body: nil,
wantErr: nil,
name: "Overwrites Server Header",
expected: map[string]string{"Server": "Tunnel Please"},
body: nil,
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockHeader := &mockResponseHeader{headers: tt.initialState}
mockHeader := new(mockResponseHeader)
for k, v := range tt.expected {
mockHeader.On("Set", k, v).Return()
}
tunnelFingerprint := NewTunnelFingerprint()
err := tunnelFingerprint.HandleResponse(mockHeader, tt.body)
if !errors.Is(err, tt.wantErr) {
t.Fatalf("unexpected error, got: %v, want: %v", err, tt.wantErr)
}
for key, expectedValue := range tt.expected {
if val := mockHeader.Value(key); val != expectedValue {
t.Errorf("header[%q] = %q; want %q", key, val, expectedValue)
}
}
assert.ErrorIs(t, err, tt.wantErr)
mockHeader.AssertExpectations(t)
})
}
}
func TestNewTunnelFingerprint(t *testing.T) {
instance := NewTunnelFingerprint()
if instance == nil {
t.Errorf("NewTunnelFingerprint() = nil; want non-nil instance")
}
assert.NotNil(t, instance)
}