test(stream): migrate mocking to testify
SonarQube Scan / SonarQube Trigger (push) Successful in 2m21s
SonarQube Scan / SonarQube Trigger (push) Successful in 2m21s
This commit is contained in:
+346
-160
@@ -9,66 +9,160 @@ import (
|
|||||||
"tunnel_pls/internal/http/header"
|
"tunnel_pls/internal/http/header"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockAddr struct {
|
type MockAddr struct {
|
||||||
addr string
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockAddr) String() string { return m.addr }
|
func (m *MockAddr) String() string {
|
||||||
func (m *mockAddr) Network() string { return "tcp" }
|
args := m.Called()
|
||||||
|
return args.String(0)
|
||||||
type mockRequestMiddleware struct {
|
|
||||||
err error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockRequestMiddleware) HandleRequest(h header.RequestHeader) error {
|
func (m *MockAddr) Network() string {
|
||||||
if m.err == nil {
|
args := m.Called()
|
||||||
h.Set("X-Middleware", "true")
|
return args.String(0)
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockResponseMiddleware struct {
|
type MockRequestMiddleware struct {
|
||||||
err error
|
mock.Mock
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error {
|
func (m *MockRequestMiddleware) HandleRequest(h header.RequestHeader) error {
|
||||||
if m.err == nil {
|
args := m.Called(h)
|
||||||
h.Set("X-Resp-Middleware", "true")
|
return args.Error(0)
|
||||||
}
|
|
||||||
return m.err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockReadWriter struct {
|
type MockResponseMiddleware struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error {
|
||||||
|
args := m.Called(h, body)
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockReadWriter struct {
|
||||||
|
mock.Mock
|
||||||
bytes.Buffer
|
bytes.Buffer
|
||||||
closed bool
|
|
||||||
writeClosed bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReadWriter) Close() error {
|
func (m *MockReadWriter) Read(p []byte) (int, error) {
|
||||||
m.closed = true
|
args := m.Called(p)
|
||||||
return nil
|
return args.Int(0), args.Error(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockReadWriter) CloseWrite() error {
|
func (m *MockReadWriter) Write(p []byte) (int, error) {
|
||||||
m.writeClosed = true
|
args := m.Called(p)
|
||||||
return nil
|
return args.Int(0), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockReadWriter) Close() error {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockReadWriter) CloseWrite() error {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockReadWriterOnlyCloser struct {
|
||||||
|
mock.Mock
|
||||||
|
bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockReadWriterOnlyCloser) Read(p []byte) (int, error) {
|
||||||
|
args := m.Called(p)
|
||||||
|
return args.Int(0), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockReadWriterOnlyCloser) Write(p []byte) (int, error) {
|
||||||
|
args := m.Called(p)
|
||||||
|
return args.Int(0), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockReadWriterOnlyCloser) Close() error {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockWriterOnly struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWriterOnly) Write(p []byte) (int, error) {
|
||||||
|
args := m.Called(p)
|
||||||
|
return args.Int(0), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWriterOnly) Read(p []byte) (int, error) {
|
||||||
|
args := m.Called(p)
|
||||||
|
return args.Int(0), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockReader struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockReader) Read(p []byte) (int, error) {
|
||||||
|
args := m.Called(p)
|
||||||
|
return args.Int(0), args.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockWriter struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWriter) Write(p []byte) (int, error) {
|
||||||
|
ret := m.Called(p)
|
||||||
|
|
||||||
|
var n int
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch v := ret.Get(0).(type) {
|
||||||
|
case func([]byte) int:
|
||||||
|
n = v(p)
|
||||||
|
case int:
|
||||||
|
n = v
|
||||||
|
default:
|
||||||
|
n = len(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := ret.Get(1).(type) {
|
||||||
|
case func([]byte) error:
|
||||||
|
err = v(p)
|
||||||
|
case error:
|
||||||
|
err = v
|
||||||
|
default:
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockWriter) Close() error {
|
||||||
|
args := m.Called()
|
||||||
|
return args.Error(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHTTPMethods(t *testing.T) {
|
func TestHTTPMethods(t *testing.T) {
|
||||||
addr := &mockAddr{addr: "1.2.3.4:1234"}
|
addr := new(MockAddr)
|
||||||
rw := &mockReadWriter{}
|
addr.On("String").Return("1.2.3.4:1234")
|
||||||
|
|
||||||
|
rw := new(MockReadWriter)
|
||||||
hs := New(rw, rw, addr)
|
hs := New(rw, rw, addr)
|
||||||
|
|
||||||
assert.Equal(t, addr, hs.RemoteAddr())
|
assert.Equal(t, addr, hs.RemoteAddr())
|
||||||
|
|
||||||
reqMW := &mockRequestMiddleware{}
|
reqMW := new(MockRequestMiddleware)
|
||||||
hs.UseRequestMiddleware(reqMW)
|
hs.UseRequestMiddleware(reqMW)
|
||||||
assert.Equal(t, 1, len(hs.RequestMiddlewares()))
|
assert.Equal(t, 1, len(hs.RequestMiddlewares()))
|
||||||
assert.Equal(t, reqMW, hs.RequestMiddlewares()[0])
|
assert.Equal(t, reqMW, hs.RequestMiddlewares()[0])
|
||||||
|
|
||||||
respMW := &mockResponseMiddleware{}
|
respMW := new(MockResponseMiddleware)
|
||||||
hs.UseResponseMiddleware(respMW)
|
hs.UseResponseMiddleware(respMW)
|
||||||
assert.Equal(t, 1, len(hs.ResponseMiddlewares()))
|
assert.Equal(t, 1, len(hs.ResponseMiddlewares()))
|
||||||
assert.Equal(t, respMW, hs.ResponseMiddlewares()[0])
|
assert.Equal(t, respMW, hs.ResponseMiddlewares()[0])
|
||||||
@@ -78,19 +172,21 @@ func TestHTTPMethods(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestApplyMiddlewares(t *testing.T) {
|
func TestApplyMiddlewares(t *testing.T) {
|
||||||
addr := &mockAddr{addr: "1.2.3.4:1234"}
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setup func(HTTP)
|
setup func(HTTP, *MockRequestMiddleware, *MockResponseMiddleware)
|
||||||
apply func(HTTP, header.RequestHeader, header.ResponseHeader) error
|
apply func(HTTP, header.RequestHeader, header.ResponseHeader) error
|
||||||
verify func(*testing.T, header.RequestHeader, header.ResponseHeader)
|
verify func(*testing.T, header.RequestHeader, header.ResponseHeader)
|
||||||
expectErr bool
|
expectErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "apply request middleware success",
|
name: "apply request middleware success",
|
||||||
setup: func(hs HTTP) {
|
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||||
hs.UseRequestMiddleware(&mockRequestMiddleware{})
|
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
h := args.Get(0).(header.RequestHeader)
|
||||||
|
h.Set("X-Middleware", "true")
|
||||||
|
}).Return(nil)
|
||||||
|
hs.UseRequestMiddleware(reqMW)
|
||||||
},
|
},
|
||||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||||
return hs.ApplyRequestMiddlewares(reqH)
|
return hs.ApplyRequestMiddlewares(reqH)
|
||||||
@@ -101,8 +197,12 @@ func TestApplyMiddlewares(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "apply response middleware success",
|
name: "apply response middleware success",
|
||||||
setup: func(hs HTTP) {
|
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||||
hs.UseResponseMiddleware(&mockResponseMiddleware{})
|
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
h := args.Get(0).(header.ResponseHeader)
|
||||||
|
h.Set("X-Resp-Middleware", "true")
|
||||||
|
}).Return(nil)
|
||||||
|
hs.UseResponseMiddleware(respMW)
|
||||||
},
|
},
|
||||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||||
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
||||||
@@ -113,8 +213,9 @@ func TestApplyMiddlewares(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "apply request middleware error",
|
name: "apply request middleware error",
|
||||||
setup: func(hs HTTP) {
|
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||||
hs.UseRequestMiddleware(&mockRequestMiddleware{err: assert.AnError})
|
reqMW.On("HandleRequest", mock.Anything).Return(assert.AnError)
|
||||||
|
hs.UseRequestMiddleware(reqMW)
|
||||||
},
|
},
|
||||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||||
return hs.ApplyRequestMiddlewares(reqH)
|
return hs.ApplyRequestMiddlewares(reqH)
|
||||||
@@ -123,8 +224,9 @@ func TestApplyMiddlewares(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "apply response middleware error",
|
name: "apply response middleware error",
|
||||||
setup: func(hs HTTP) {
|
setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
|
||||||
hs.UseResponseMiddleware(&mockResponseMiddleware{err: assert.AnError})
|
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(assert.AnError)
|
||||||
|
hs.UseResponseMiddleware(respMW)
|
||||||
},
|
},
|
||||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||||
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
||||||
@@ -137,9 +239,17 @@ func TestApplyMiddlewares(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
|
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||||
respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
||||||
rw := &mockReadWriter{}
|
|
||||||
|
addr := new(MockAddr)
|
||||||
|
addr.On("String").Return("1.2.3.4:1234")
|
||||||
|
|
||||||
|
rw := new(MockReadWriter)
|
||||||
hs := New(rw, rw, addr)
|
hs := New(rw, rw, addr)
|
||||||
tt.setup(hs)
|
|
||||||
|
reqMW := new(MockRequestMiddleware)
|
||||||
|
respMW := new(MockResponseMiddleware)
|
||||||
|
tt.setup(hs, reqMW, respMW)
|
||||||
|
|
||||||
err := tt.apply(hs, reqH, respH)
|
err := tt.apply(hs, reqH, respH)
|
||||||
if tt.expectErr {
|
if tt.expectErr {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -149,83 +259,96 @@ func TestApplyMiddlewares(t *testing.T) {
|
|||||||
tt.verify(t, reqH, respH)
|
tt.verify(t, reqH, respH)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
reqMW.AssertExpectations(t)
|
||||||
|
respMW.AssertExpectations(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockWriterOnly struct {
|
|
||||||
bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloseMethods(t *testing.T) {
|
func TestCloseMethods(t *testing.T) {
|
||||||
addr := &mockAddr{addr: "1.2.3.4:1234"}
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
writer any
|
setup func() (io.Writer, io.Reader)
|
||||||
op func(HTTP) error
|
op func(HTTP) error
|
||||||
verify func(*testing.T, any)
|
verify func(*testing.T, io.Writer)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Close success",
|
name: "Close success",
|
||||||
writer: &mockReadWriter{},
|
setup: func() (io.Writer, io.Reader) {
|
||||||
op: func(hs HTTP) error { return hs.Close() },
|
rw := new(MockReadWriter)
|
||||||
verify: func(t *testing.T, w any) {
|
rw.On("Close").Return(nil)
|
||||||
assert.True(t, w.(*mockReadWriter).closed)
|
return rw, rw
|
||||||
|
},
|
||||||
|
op: func(hs HTTP) error { return hs.Close() },
|
||||||
|
verify: func(t *testing.T, w io.Writer) {
|
||||||
|
w.(*MockReadWriter).AssertCalled(t, "Close")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "CloseWrite with CloseWrite implementation",
|
name: "CloseWrite with CloseWrite implementation",
|
||||||
writer: &mockReadWriter{},
|
setup: func() (io.Writer, io.Reader) {
|
||||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
rw := new(MockReadWriter)
|
||||||
verify: func(t *testing.T, w any) {
|
rw.On("CloseWrite").Return(nil)
|
||||||
assert.True(t, w.(*mockReadWriter).writeClosed)
|
return rw, rw
|
||||||
|
},
|
||||||
|
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||||
|
verify: func(t *testing.T, w io.Writer) {
|
||||||
|
w.(*MockReadWriter).AssertCalled(t, "CloseWrite")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "CloseWrite fallback to Close",
|
name: "CloseWrite fallback to Close",
|
||||||
writer: &mockReadWriterOnlyCloser{},
|
setup: func() (io.Writer, io.Reader) {
|
||||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
rw := new(MockReadWriterOnlyCloser)
|
||||||
verify: func(t *testing.T, w any) {
|
rw.On("Close").Return(nil)
|
||||||
assert.True(t, w.(*mockReadWriterOnlyCloser).closed)
|
return rw, rw
|
||||||
|
},
|
||||||
|
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||||
|
verify: func(t *testing.T, w io.Writer) {
|
||||||
|
w.(*MockReadWriterOnlyCloser).AssertCalled(t, "Close")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Close with No Closer",
|
name: "Close with No Closer",
|
||||||
writer: &mockWriterOnly{},
|
setup: func() (io.Writer, io.Reader) {
|
||||||
op: func(hs HTTP) error { return hs.Close() },
|
w := new(MockWriterOnly)
|
||||||
|
r := new(MockReader)
|
||||||
|
return w, r
|
||||||
|
},
|
||||||
|
op: func(hs HTTP) error { return hs.Close() },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "CloseWrite with No CloseWrite and No Closer",
|
name: "CloseWrite with No CloseWrite and No Closer",
|
||||||
writer: &mockWriterOnly{},
|
setup: func() (io.Writer, io.Reader) {
|
||||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
w := new(MockWriterOnly)
|
||||||
|
r := new(MockReader)
|
||||||
|
return w, r
|
||||||
|
},
|
||||||
|
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
hs := New(tt.writer.(io.Writer), tt.writer.(io.Reader), addr)
|
addr := new(MockAddr)
|
||||||
|
addr.On("String").Return("1.2.3.4:1234")
|
||||||
|
|
||||||
|
w, r := tt.setup()
|
||||||
|
hs := New(w, r, addr)
|
||||||
|
|
||||||
assert.NotPanics(t, func() {
|
assert.NotPanics(t, func() {
|
||||||
err := tt.op(hs)
|
err := tt.op(hs)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
if tt.verify != nil {
|
if tt.verify != nil {
|
||||||
tt.verify(t, tt.writer)
|
tt.verify(t, w)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockReadWriterOnlyCloser struct {
|
|
||||||
bytes.Buffer
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockReadWriterOnlyCloser) Close() error {
|
|
||||||
m.closed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSplitHeaderAndBody(t *testing.T) {
|
func TestSplitHeaderAndBody(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -319,6 +442,7 @@ func TestRead(t *testing.T) {
|
|||||||
expectRead int
|
expectRead int
|
||||||
expectErr bool
|
expectErr bool
|
||||||
middlewareErr error
|
middlewareErr error
|
||||||
|
isHTTP bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "valid http request",
|
name: "valid http request",
|
||||||
@@ -326,6 +450,7 @@ func TestRead(t *testing.T) {
|
|||||||
readLen: 100,
|
readLen: 100,
|
||||||
expectContent: "Body",
|
expectContent: "Body",
|
||||||
expectRead: 54,
|
expectRead: 54,
|
||||||
|
isHTTP: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "non-http data",
|
name: "non-http data",
|
||||||
@@ -333,6 +458,7 @@ func TestRead(t *testing.T) {
|
|||||||
readLen: 100,
|
readLen: 100,
|
||||||
expectContent: "Some random data\r\n\r\nMore data",
|
expectContent: "Some random data\r\n\r\nMore data",
|
||||||
expectRead: 29,
|
expectRead: 29,
|
||||||
|
isHTTP: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no delimiter",
|
name: "no delimiter",
|
||||||
@@ -340,6 +466,7 @@ func TestRead(t *testing.T) {
|
|||||||
readLen: 100,
|
readLen: 100,
|
||||||
expectContent: "Partial data without delimiter",
|
expectContent: "Partial data without delimiter",
|
||||||
expectRead: 30,
|
expectRead: 30,
|
||||||
|
isHTTP: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "middleware error",
|
name: "middleware error",
|
||||||
@@ -347,20 +474,45 @@ func TestRead(t *testing.T) {
|
|||||||
readLen: 100,
|
readLen: 100,
|
||||||
middlewareErr: assert.AnError,
|
middlewareErr: assert.AnError,
|
||||||
expectErr: true,
|
expectErr: true,
|
||||||
|
isHTTP: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
rw := &mockReadWriter{}
|
addr := new(MockAddr)
|
||||||
rw.Write(tt.input)
|
addr.On("String").Return("1.2.3.4:1234")
|
||||||
hs := New(rw, rw, &mockAddr{})
|
|
||||||
if tt.middlewareErr != nil {
|
reader := new(MockReader)
|
||||||
hs.UseRequestMiddleware(&mockRequestMiddleware{err: tt.middlewareErr})
|
writer := new(MockWriterOnly)
|
||||||
|
|
||||||
|
if tt.expectErr || tt.name == "valid http request" {
|
||||||
|
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
p := args.Get(0).([]byte)
|
||||||
|
copy(p, tt.input)
|
||||||
|
}).Return(len(tt.input), io.EOF).Once()
|
||||||
} else {
|
} else {
|
||||||
hs.UseRequestMiddleware(&mockRequestMiddleware{})
|
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
p := args.Get(0).([]byte)
|
||||||
|
copy(p, tt.input)
|
||||||
|
}).Return(len(tt.input), nil).Once()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hs := New(writer, reader, addr)
|
||||||
|
|
||||||
|
reqMW := new(MockRequestMiddleware)
|
||||||
|
if tt.isHTTP {
|
||||||
|
if tt.middlewareErr != nil {
|
||||||
|
reqMW.On("HandleRequest", mock.Anything).Return(tt.middlewareErr)
|
||||||
|
} else {
|
||||||
|
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
h := args.Get(0).(header.RequestHeader)
|
||||||
|
h.Set("X-Middleware", "true")
|
||||||
|
}).Return(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
hs.UseRequestMiddleware(reqMW)
|
||||||
|
|
||||||
p := make([]byte, tt.readLen)
|
p := make([]byte, tt.readLen)
|
||||||
n, err := hs.Read(p)
|
n, err := hs.Read(p)
|
||||||
|
|
||||||
@@ -379,6 +531,11 @@ func TestRead(t *testing.T) {
|
|||||||
assert.Equal(t, tt.expectContent, string(p[:n]))
|
assert.Equal(t, tt.expectContent, string(p[:n]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tt.isHTTP {
|
||||||
|
reqMW.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
reader.AssertExpectations(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -390,6 +547,7 @@ func TestWrite(t *testing.T) {
|
|||||||
expectWritten string
|
expectWritten string
|
||||||
expectErr bool
|
expectErr bool
|
||||||
middlewareErr error
|
middlewareErr error
|
||||||
|
isHTTP bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "valid http response in one write",
|
name: "valid http response in one write",
|
||||||
@@ -397,6 +555,7 @@ func TestWrite(t *testing.T) {
|
|||||||
[]byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"),
|
[]byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"),
|
||||||
},
|
},
|
||||||
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
||||||
|
isHTTP: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid http response in multiple writes",
|
name: "valid http response in multiple writes",
|
||||||
@@ -406,6 +565,7 @@ func TestWrite(t *testing.T) {
|
|||||||
[]byte("Body"),
|
[]byte("Body"),
|
||||||
},
|
},
|
||||||
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
||||||
|
isHTTP: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "non-http data",
|
name: "non-http data",
|
||||||
@@ -413,6 +573,7 @@ func TestWrite(t *testing.T) {
|
|||||||
[]byte("Random data with delimiter\r\n\r\nFlush"),
|
[]byte("Random data with delimiter\r\n\r\nFlush"),
|
||||||
},
|
},
|
||||||
expectWritten: "Random data with delimiter\r\n\r\nFlush",
|
expectWritten: "Random data with delimiter\r\n\r\nFlush",
|
||||||
|
isHTTP: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "bypass buffering",
|
name: "bypass buffering",
|
||||||
@@ -422,6 +583,7 @@ func TestWrite(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" +
|
expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" +
|
||||||
"HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n",
|
"HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n",
|
||||||
|
isHTTP: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "middleware error",
|
name: "middleware error",
|
||||||
@@ -430,18 +592,40 @@ func TestWrite(t *testing.T) {
|
|||||||
},
|
},
|
||||||
middlewareErr: assert.AnError,
|
middlewareErr: assert.AnError,
|
||||||
expectErr: true,
|
expectErr: true,
|
||||||
|
isHTTP: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
rw := &mockReadWriter{}
|
addr := new(MockAddr)
|
||||||
hs := New(rw, rw, &mockAddr{})
|
addr.On("String").Return("1.2.3.4:1234")
|
||||||
if tt.middlewareErr != nil {
|
|
||||||
hs.UseResponseMiddleware(&mockResponseMiddleware{err: tt.middlewareErr})
|
var writtenData bytes.Buffer
|
||||||
} else {
|
writer := new(MockWriter)
|
||||||
hs.UseResponseMiddleware(&mockResponseMiddleware{})
|
|
||||||
|
writer.On("Write", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
p := args.Get(0).([]byte)
|
||||||
|
writtenData.Write(p)
|
||||||
|
}).Return(func(p []byte) int {
|
||||||
|
return len(p)
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
reader := new(MockReader)
|
||||||
|
hs := New(writer, reader, addr)
|
||||||
|
|
||||||
|
respMW := new(MockResponseMiddleware)
|
||||||
|
if tt.isHTTP {
|
||||||
|
if tt.middlewareErr != nil {
|
||||||
|
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(tt.middlewareErr)
|
||||||
|
} else {
|
||||||
|
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
h := args.Get(0).(header.ResponseHeader)
|
||||||
|
h.Set("X-Resp-Middleware", "true")
|
||||||
|
}).Return(nil)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
hs.UseResponseMiddleware(respMW)
|
||||||
|
|
||||||
var totalN int
|
var totalN int
|
||||||
var err error
|
var err error
|
||||||
@@ -458,8 +642,8 @@ func TestWrite(t *testing.T) {
|
|||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
} else {
|
} else {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
written := writtenData.String()
|
||||||
if strings.HasPrefix(tt.expectWritten, "HTTP/") {
|
if strings.HasPrefix(tt.expectWritten, "HTTP/") {
|
||||||
written := rw.String()
|
|
||||||
assert.Contains(t, written, "HTTP/1.1 200 OK\r\n")
|
assert.Contains(t, written, "HTTP/1.1 200 OK\r\n")
|
||||||
assert.Contains(t, written, "X-Resp-Middleware: true\r\n")
|
assert.Contains(t, written, "X-Resp-Middleware: true\r\n")
|
||||||
if strings.Contains(tt.expectWritten, "Content-Length: 4") {
|
if strings.Contains(tt.expectWritten, "Content-Length: 4") {
|
||||||
@@ -467,43 +651,71 @@ func TestWrite(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n"))
|
assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n"))
|
||||||
} else {
|
} else {
|
||||||
assert.Equal(t, tt.expectWritten, rw.String())
|
assert.Equal(t, tt.expectWritten, written)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if tt.isHTTP {
|
||||||
|
respMW.AssertExpectations(t)
|
||||||
|
}
|
||||||
|
if tt.middlewareErr == nil {
|
||||||
|
writer.AssertExpectations(t)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWriteErrors(t *testing.T) {
|
func TestWriteErrors(t *testing.T) {
|
||||||
addr := &mockAddr{addr: "1.2.3.4:1234"}
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
writer any
|
setup func() (io.Writer, io.Reader)
|
||||||
data []byte
|
data []byte
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "write error in writeHeaderAndBody",
|
name: "write error in writeHeaderAndBody",
|
||||||
writer: &mockErrorWriteCloser{},
|
setup: func() (io.Writer, io.Reader) {
|
||||||
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
writer := new(MockWriter)
|
||||||
|
writer.On("Write", mock.Anything).Return(0, assert.AnError)
|
||||||
|
reader := new(MockReader)
|
||||||
|
return writer, reader
|
||||||
|
},
|
||||||
|
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "write error in writeHeaderAndBody second write",
|
name: "write error in writeHeaderAndBody second write",
|
||||||
writer: &mockFailSecondWriteCloser{},
|
setup: func() (io.Writer, io.Reader) {
|
||||||
data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"),
|
writer := new(MockWriter)
|
||||||
|
writer.On("Write", mock.Anything).Return(len([]byte("HTTP/1.1 200 OK\r\n\r\n")), nil).Once()
|
||||||
|
writer.On("Write", mock.Anything).Return(0, assert.AnError).Once()
|
||||||
|
reader := new(MockReader)
|
||||||
|
return writer, reader
|
||||||
|
},
|
||||||
|
data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "write error in writeRawBuffer",
|
name: "write error in writeRawBuffer",
|
||||||
writer: &mockErrorWriteCloser{},
|
setup: func() (io.Writer, io.Reader) {
|
||||||
data: []byte("Not HTTP\r\n\r\nFlush"),
|
writer := new(MockWriter)
|
||||||
|
writer.On("Write", mock.Anything).Return(0, assert.AnError)
|
||||||
|
reader := new(MockReader)
|
||||||
|
return writer, reader
|
||||||
|
},
|
||||||
|
data: []byte("Not HTTP\r\n\r\nFlush"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
hs := New(tt.writer.(io.Writer), tt.writer.(io.Reader), addr)
|
addr := new(MockAddr)
|
||||||
|
addr.On("String").Return("1.2.3.4:1234")
|
||||||
|
|
||||||
|
w, r := tt.setup()
|
||||||
|
hs := New(w, r, addr)
|
||||||
|
|
||||||
_, err := hs.Write(tt.data)
|
_, err := hs.Write(tt.data)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
w.(*MockWriter).AssertExpectations(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -511,14 +723,21 @@ func TestWriteErrors(t *testing.T) {
|
|||||||
func TestReadEOF(t *testing.T) {
|
func TestReadEOF(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
reader io.Reader
|
setup func() io.Reader
|
||||||
expectN int
|
expectN int
|
||||||
expectErr error
|
expectErr error
|
||||||
expectContent string
|
expectContent string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "read eof",
|
name: "read eof",
|
||||||
reader: &mockEOFReader{},
|
setup: func() io.Reader {
|
||||||
|
reader := new(MockReader)
|
||||||
|
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
p := args.Get(0).([]byte)
|
||||||
|
copy(p, "data")
|
||||||
|
}).Return(4, io.EOF)
|
||||||
|
return reader
|
||||||
|
},
|
||||||
expectN: 4,
|
expectN: 4,
|
||||||
expectErr: io.EOF,
|
expectErr: io.EOF,
|
||||||
expectContent: "data",
|
expectContent: "data",
|
||||||
@@ -527,53 +746,20 @@ func TestReadEOF(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
hs := New(nil, tt.reader, &mockAddr{})
|
addr := new(MockAddr)
|
||||||
|
addr.On("String").Return("1.2.3.4:1234")
|
||||||
|
|
||||||
|
reader := tt.setup()
|
||||||
|
hs := New(nil, reader, addr)
|
||||||
|
|
||||||
p := make([]byte, 100)
|
p := make([]byte, 100)
|
||||||
n, err := hs.Read(p)
|
n, err := hs.Read(p)
|
||||||
|
|
||||||
assert.Equal(t, tt.expectN, n)
|
assert.Equal(t, tt.expectN, n)
|
||||||
assert.Equal(t, tt.expectErr, err)
|
assert.Equal(t, tt.expectErr, err)
|
||||||
assert.Equal(t, tt.expectContent, string(p[:n]))
|
assert.Equal(t, tt.expectContent, string(p[:n]))
|
||||||
|
|
||||||
|
reader.(*MockReader).AssertExpectations(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockEOFReader struct {
|
|
||||||
mockReadWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockEOFReader) Read(p []byte) (int, error) {
|
|
||||||
copy(p, "data")
|
|
||||||
return 4, io.EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
type mockFailSecondWriteCloser struct {
|
|
||||||
count int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockFailSecondWriteCloser) Write(p []byte) (int, error) {
|
|
||||||
m.count++
|
|
||||||
if m.count == 2 {
|
|
||||||
return 0, assert.AnError
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockFailSecondWriteCloser) Close() error { return nil }
|
|
||||||
func (m *mockFailSecondWriteCloser) Read(p []byte) (int, error) { return 0, nil }
|
|
||||||
|
|
||||||
type mockErrorWriteCloser struct {
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockErrorWriteCloser) Write(p []byte) (int, error) {
|
|
||||||
return 0, assert.AnError
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockErrorWriteCloser) Close() error {
|
|
||||||
m.closed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockErrorWriteCloser) Read(p []byte) (int, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user