From a350692e81436f2b8b055941cceca1023eb29fcd Mon Sep 17 00:00:00 2001 From: bagas Date: Thu, 22 Jan 2026 12:30:20 +0700 Subject: [PATCH] test(middleware): add unit tests for middleware behavior - remove redundant check on registry.Update and check if slug exist before locking the mutex - Update SonarQube action to not use Go cache when setting up Go --- .gitea/workflows/sonarqube.yml | 1 + internal/middleware/forwardedfor_test.go | 130 ++++++++++++++++++ internal/middleware/tunnelfingerprint_test.go | 76 ++++++++++ internal/registry/registry.go | 10 +- 4 files changed, 211 insertions(+), 6 deletions(-) create mode 100644 internal/middleware/forwardedfor_test.go create mode 100644 internal/middleware/tunnelfingerprint_test.go diff --git a/.gitea/workflows/sonarqube.yml b/.gitea/workflows/sonarqube.yml index 4a00cc5..dc03a14 100644 --- a/.gitea/workflows/sonarqube.yml +++ b/.gitea/workflows/sonarqube.yml @@ -18,6 +18,7 @@ jobs: uses: actions/setup-go@v6 with: go-version: '1.25.5' + cache: false - name: Install dependencies run: go mod tidy diff --git a/internal/middleware/forwardedfor_test.go b/internal/middleware/forwardedfor_test.go new file mode 100644 index 0000000..ef6a536 --- /dev/null +++ b/internal/middleware/forwardedfor_test.go @@ -0,0 +1,130 @@ +package middleware + +import ( + "net" + "testing" +) + +type mockRequestHeader struct { + headers map[string]string +} + +func (m *mockRequestHeader) Value(key string) string { + return m.headers[key] +} + +func (m *mockRequestHeader) Set(key string, value string) { + m.headers[key] = value +} + +func (m *mockRequestHeader) Remove(key string) { + delete(m.headers, key) +} + +func (m *mockRequestHeader) Finalize() []byte { + return []byte{} +} + +func (m *mockRequestHeader) Method() string { + return "" +} + +func (m *mockRequestHeader) Path() string { + return "" +} + +func (m *mockRequestHeader) Version() string { + return "" +} + +func TestForwardedFor_HandleRequest(t *testing.T) { + tests := []struct { + name string + addr net.Addr + expectedHost string + expectError bool + }{ + { + name: "valid IPv4 address", + addr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 8080}, + expectedHost: "192.168.1.100", + expectError: false, + }, + { + name: "valid IPv6 address", + addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 8080}, + expectedHost: "2001:db8::ff00:42:8329", + expectError: false, + }, + { + name: "invalid address format", + addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"}, + expectedHost: "", + expectError: true, + }, + { + name: "valid IPv4 address with port", + addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234}, + expectedHost: "127.0.0.1", + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ff := NewForwardedFor(tc.addr) + reqHeader := &mockRequestHeader{headers: make(map[string]string)} + + err := ff.HandleRequest(reqHeader) + + if tc.expectError { + if err == nil { + t.Fatalf("expected error but got none") + } + } else { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + host := reqHeader.Value("X-Forwarded-For") + if host != tc.expectedHost { + t.Errorf("expected X-Forwarded-For header to be '%s', got '%s'", tc.expectedHost, host) + } + } + }) + } +} + +func TestNewForwardedFor(t *testing.T) { + tests := []struct { + name string + addr net.Addr + expectAddr net.Addr + }{ + { + name: "IPv4 address", + addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, + expectAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, + }, + { + name: "IPv6 address", + addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0}, + expectAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0}, + }, + { + name: "Unix address", + addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"}, + expectAddr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ff := NewForwardedFor(tc.addr) + + if ff.addr.String() != tc.expectAddr.String() { + t.Errorf("expected addr to be '%v', got '%v'", tc.expectAddr, ff.addr) + } + }) + } +} diff --git a/internal/middleware/tunnelfingerprint_test.go b/internal/middleware/tunnelfingerprint_test.go new file mode 100644 index 0000000..4753ac0 --- /dev/null +++ b/internal/middleware/tunnelfingerprint_test.go @@ -0,0 +1,76 @@ +package middleware + +import ( + "errors" + "testing" +) + +type mockResponseHeader struct { + headers map[string]string +} + +func (m *mockResponseHeader) Value(key string) string { + return m.headers[key] +} + +func (m *mockResponseHeader) Set(key string, value string) { + m.headers[key] = value +} + +func (m *mockResponseHeader) Remove(key string) { + delete(m.headers, key) +} + +func (m *mockResponseHeader) Finalize() []byte { + return nil +} + +func TestTunnelFingerprintHandleResponse(t *testing.T) { + tests := []struct { + name string + initialState map[string]string + expected map[string]string + body []byte + wantErr error + }{ + { + name: "Sets Server Header", + initialState: map[string]string{}, + expected: map[string]string{"Server": "Tunnel Please"}, + body: []byte("Sample body"), + wantErr: nil, + }, + { + name: "Overwrites Server Header", + initialState: map[string]string{"Server": "Old Value"}, + expected: map[string]string{"Server": "Tunnel Please"}, + body: nil, + wantErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockHeader := &mockResponseHeader{headers: tt.initialState} + tunnelFingerprint := NewTunnelFingerprint() + + err := tunnelFingerprint.HandleResponse(mockHeader, tt.body) + if !errors.Is(err, tt.wantErr) { + t.Fatalf("unexpected error, got: %v, want: %v", err, tt.wantErr) + } + + for key, expectedValue := range tt.expected { + if val := mockHeader.Value(key); val != expectedValue { + t.Errorf("header[%q] = %q; want %q", key, val, expectedValue) + } + } + }) + } +} + +func TestNewTunnelFingerprint(t *testing.T) { + instance := NewTunnelFingerprint() + if instance == nil { + t.Errorf("NewTunnelFingerprint() = nil; want non-nil instance") + } +} diff --git a/internal/registry/registry.go b/internal/registry/registry.go index 89cac48..e12ea0b 100644 --- a/internal/registry/registry.go +++ b/internal/registry/registry.go @@ -94,12 +94,13 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { return ErrInvalidSlug } - r.mu.Lock() - defer r.mu.Unlock() - if _, exists := r.slugIndex[newKey]; exists && newKey != oldKey { return ErrSlugInUse } + + r.mu.Lock() + defer r.mu.Unlock() + client, ok := r.byUser[user][oldKey] if !ok { return ErrSessionNotFound @@ -111,9 +112,6 @@ func (r *registry) Update(user string, oldKey, newKey Key) error { client.Slug().Set(newKey.Id) r.slugIndex[newKey] = user - if r.byUser[user] == nil { - r.byUser[user] = make(map[Key]Session) - } r.byUser[user][newKey] = client return nil }