127 lines
2.8 KiB
Go
127 lines
2.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
)
|
|
|
|
type mockRequestHeader struct {
|
|
mock.Mock
|
|
}
|
|
|
|
func (m *mockRequestHeader) Value(key string) string {
|
|
return m.Called(key).String(0)
|
|
}
|
|
|
|
func (m *mockRequestHeader) Set(key string, value string) {
|
|
m.Called(key, value)
|
|
}
|
|
|
|
func (m *mockRequestHeader) Remove(key string) {
|
|
m.Called(key)
|
|
}
|
|
|
|
func (m *mockRequestHeader) Finalize() []byte {
|
|
return m.Called().Get(0).([]byte)
|
|
}
|
|
|
|
func (m *mockRequestHeader) Method() string {
|
|
return m.Called().String(0)
|
|
}
|
|
|
|
func (m *mockRequestHeader) Path() string {
|
|
return m.Called().String(0)
|
|
}
|
|
|
|
func (m *mockRequestHeader) Version() string {
|
|
return m.Called().String(0)
|
|
}
|
|
|
|
func TestForwardedFor_HandleRequest(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
addr net.Addr
|
|
expectedHost string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "valid IPv4 address",
|
|
addr: &net.TCPAddr{IP: net.ParseIP("192.168.1.100"), Port: 8080},
|
|
expectedHost: "192.168.1.100",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "valid IPv6 address",
|
|
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 8080},
|
|
expectedHost: "2001:db8::ff00:42:8329",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "invalid address format",
|
|
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
|
expectedHost: "",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "valid IPv4 address with port",
|
|
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234},
|
|
expectedHost: "127.0.0.1",
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ff := NewForwardedFor(tc.addr)
|
|
reqHeader := new(mockRequestHeader)
|
|
|
|
if !tc.expectError {
|
|
reqHeader.On("Set", "X-Forwarded-For", tc.expectedHost).Return()
|
|
}
|
|
|
|
err := ff.HandleRequest(reqHeader)
|
|
|
|
if tc.expectError {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
reqHeader.AssertExpectations(t)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNewForwardedFor(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
addr net.Addr
|
|
expectAddr net.Addr
|
|
}{
|
|
{
|
|
name: "IPv4 address",
|
|
addr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
|
|
expectAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080},
|
|
},
|
|
{
|
|
name: "IPv6 address",
|
|
addr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
|
|
expectAddr: &net.TCPAddr{IP: net.ParseIP("2001:db8::ff00:42:8329"), Port: 0},
|
|
},
|
|
{
|
|
name: "Unix address",
|
|
addr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
|
expectAddr: &net.UnixAddr{Name: "/tmp/socket", Net: "unix"},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ff := NewForwardedFor(tc.addr)
|
|
assert.Equal(t, tc.expectAddr.String(), ff.addr.String())
|
|
})
|
|
}
|
|
}
|