test(stream): migrate mocking to testify
SonarQube Scan / SonarQube Trigger (push) Successful in 2m21s

This commit is contained in:
2026-01-25 18:14:29 +07:00
parent 8b44e4db4e
commit 9cdce24030
+346 -160
View File
@@ -9,66 +9,160 @@ import (
"tunnel_pls/internal/http/header" "tunnel_pls/internal/http/header"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
) )
type mockAddr struct { type MockAddr struct {
addr string mock.Mock
} }
func (m *mockAddr) String() string { return m.addr } func (m *MockAddr) String() string {
func (m *mockAddr) Network() string { return "tcp" } args := m.Called()
return args.String(0)
type mockRequestMiddleware struct {
err error
} }
func (m *mockRequestMiddleware) HandleRequest(h header.RequestHeader) error { func (m *MockAddr) Network() string {
if m.err == nil { args := m.Called()
h.Set("X-Middleware", "true") return args.String(0)
}
return m.err
} }
type mockResponseMiddleware struct { type MockRequestMiddleware struct {
err error mock.Mock
} }
func (m *mockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error { func (m *MockRequestMiddleware) HandleRequest(h header.RequestHeader) error {
if m.err == nil { args := m.Called(h)
h.Set("X-Resp-Middleware", "true") return args.Error(0)
}
return m.err
} }
type mockReadWriter struct { type MockResponseMiddleware struct {
mock.Mock
}
func (m *MockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error {
args := m.Called(h, body)
return args.Error(0)
}
type MockReadWriter struct {
mock.Mock
bytes.Buffer bytes.Buffer
closed bool
writeClosed bool
} }
func (m *mockReadWriter) Close() error { func (m *MockReadWriter) Read(p []byte) (int, error) {
m.closed = true args := m.Called(p)
return nil return args.Int(0), args.Error(1)
} }
func (m *mockReadWriter) CloseWrite() error { func (m *MockReadWriter) Write(p []byte) (int, error) {
m.writeClosed = true args := m.Called(p)
return nil return args.Int(0), args.Error(1)
}
func (m *MockReadWriter) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockReadWriter) CloseWrite() error {
args := m.Called()
return args.Error(0)
}
type MockReadWriterOnlyCloser struct {
mock.Mock
bytes.Buffer
}
func (m *MockReadWriterOnlyCloser) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockReadWriterOnlyCloser) Write(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockReadWriterOnlyCloser) Close() error {
args := m.Called()
return args.Error(0)
}
type MockWriterOnly struct {
mock.Mock
}
func (m *MockWriterOnly) Write(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
func (m *MockWriterOnly) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
type MockReader struct {
mock.Mock
}
func (m *MockReader) Read(p []byte) (int, error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}
type MockWriter struct {
mock.Mock
}
func (m *MockWriter) Write(p []byte) (int, error) {
ret := m.Called(p)
var n int
var err error
switch v := ret.Get(0).(type) {
case func([]byte) int:
n = v(p)
case int:
n = v
default:
n = len(p)
}
switch v := ret.Get(1).(type) {
case func([]byte) error:
err = v(p)
case error:
err = v
default:
err = nil
}
return n, err
}
func (m *MockWriter) Close() error {
args := m.Called()
return args.Error(0)
} }
func TestHTTPMethods(t *testing.T) { func TestHTTPMethods(t *testing.T) {
addr := &mockAddr{addr: "1.2.3.4:1234"} addr := new(MockAddr)
rw := &mockReadWriter{} addr.On("String").Return("1.2.3.4:1234")
rw := new(MockReadWriter)
hs := New(rw, rw, addr) hs := New(rw, rw, addr)
assert.Equal(t, addr, hs.RemoteAddr()) assert.Equal(t, addr, hs.RemoteAddr())
reqMW := &mockRequestMiddleware{} reqMW := new(MockRequestMiddleware)
hs.UseRequestMiddleware(reqMW) hs.UseRequestMiddleware(reqMW)
assert.Equal(t, 1, len(hs.RequestMiddlewares())) assert.Equal(t, 1, len(hs.RequestMiddlewares()))
assert.Equal(t, reqMW, hs.RequestMiddlewares()[0]) assert.Equal(t, reqMW, hs.RequestMiddlewares()[0])
respMW := &mockResponseMiddleware{} respMW := new(MockResponseMiddleware)
hs.UseResponseMiddleware(respMW) hs.UseResponseMiddleware(respMW)
assert.Equal(t, 1, len(hs.ResponseMiddlewares())) assert.Equal(t, 1, len(hs.ResponseMiddlewares()))
assert.Equal(t, respMW, hs.ResponseMiddlewares()[0]) assert.Equal(t, respMW, hs.ResponseMiddlewares()[0])
@@ -78,19 +172,21 @@ func TestHTTPMethods(t *testing.T) {
} }
func TestApplyMiddlewares(t *testing.T) { func TestApplyMiddlewares(t *testing.T) {
addr := &mockAddr{addr: "1.2.3.4:1234"}
tests := []struct { tests := []struct {
name string name string
setup func(HTTP) setup func(HTTP, *MockRequestMiddleware, *MockResponseMiddleware)
apply func(HTTP, header.RequestHeader, header.ResponseHeader) error apply func(HTTP, header.RequestHeader, header.ResponseHeader) error
verify func(*testing.T, header.RequestHeader, header.ResponseHeader) verify func(*testing.T, header.RequestHeader, header.ResponseHeader)
expectErr bool expectErr bool
}{ }{
{ {
name: "apply request middleware success", name: "apply request middleware success",
setup: func(hs HTTP) { setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
hs.UseRequestMiddleware(&mockRequestMiddleware{}) reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.RequestHeader)
h.Set("X-Middleware", "true")
}).Return(nil)
hs.UseRequestMiddleware(reqMW)
}, },
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyRequestMiddlewares(reqH) return hs.ApplyRequestMiddlewares(reqH)
@@ -101,8 +197,12 @@ func TestApplyMiddlewares(t *testing.T) {
}, },
{ {
name: "apply response middleware success", name: "apply response middleware success",
setup: func(hs HTTP) { setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
hs.UseResponseMiddleware(&mockResponseMiddleware{}) respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.ResponseHeader)
h.Set("X-Resp-Middleware", "true")
}).Return(nil)
hs.UseResponseMiddleware(respMW)
}, },
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyResponseMiddlewares(respH, []byte("body")) return hs.ApplyResponseMiddlewares(respH, []byte("body"))
@@ -113,8 +213,9 @@ func TestApplyMiddlewares(t *testing.T) {
}, },
{ {
name: "apply request middleware error", name: "apply request middleware error",
setup: func(hs HTTP) { setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
hs.UseRequestMiddleware(&mockRequestMiddleware{err: assert.AnError}) reqMW.On("HandleRequest", mock.Anything).Return(assert.AnError)
hs.UseRequestMiddleware(reqMW)
}, },
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyRequestMiddlewares(reqH) return hs.ApplyRequestMiddlewares(reqH)
@@ -123,8 +224,9 @@ func TestApplyMiddlewares(t *testing.T) {
}, },
{ {
name: "apply response middleware error", name: "apply response middleware error",
setup: func(hs HTTP) { setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) {
hs.UseResponseMiddleware(&mockResponseMiddleware{err: assert.AnError}) respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(assert.AnError)
hs.UseResponseMiddleware(respMW)
}, },
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
return hs.ApplyResponseMiddlewares(respH, []byte("body")) return hs.ApplyResponseMiddlewares(respH, []byte("body"))
@@ -137,9 +239,17 @@ func TestApplyMiddlewares(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n")) reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n")) respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n"))
rw := &mockReadWriter{}
addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
rw := new(MockReadWriter)
hs := New(rw, rw, addr) hs := New(rw, rw, addr)
tt.setup(hs)
reqMW := new(MockRequestMiddleware)
respMW := new(MockResponseMiddleware)
tt.setup(hs, reqMW, respMW)
err := tt.apply(hs, reqH, respH) err := tt.apply(hs, reqH, respH)
if tt.expectErr { if tt.expectErr {
assert.Error(t, err) assert.Error(t, err)
@@ -149,83 +259,96 @@ func TestApplyMiddlewares(t *testing.T) {
tt.verify(t, reqH, respH) tt.verify(t, reqH, respH)
} }
} }
reqMW.AssertExpectations(t)
respMW.AssertExpectations(t)
}) })
} }
} }
type mockWriterOnly struct {
bytes.Buffer
}
func TestCloseMethods(t *testing.T) { func TestCloseMethods(t *testing.T) {
addr := &mockAddr{addr: "1.2.3.4:1234"}
tests := []struct { tests := []struct {
name string name string
writer any setup func() (io.Writer, io.Reader)
op func(HTTP) error op func(HTTP) error
verify func(*testing.T, any) verify func(*testing.T, io.Writer)
}{ }{
{ {
name: "Close success", name: "Close success",
writer: &mockReadWriter{}, setup: func() (io.Writer, io.Reader) {
op: func(hs HTTP) error { return hs.Close() }, rw := new(MockReadWriter)
verify: func(t *testing.T, w any) { rw.On("Close").Return(nil)
assert.True(t, w.(*mockReadWriter).closed) return rw, rw
},
op: func(hs HTTP) error { return hs.Close() },
verify: func(t *testing.T, w io.Writer) {
w.(*MockReadWriter).AssertCalled(t, "Close")
}, },
}, },
{ {
name: "CloseWrite with CloseWrite implementation", name: "CloseWrite with CloseWrite implementation",
writer: &mockReadWriter{}, setup: func() (io.Writer, io.Reader) {
op: func(hs HTTP) error { return hs.CloseWrite() }, rw := new(MockReadWriter)
verify: func(t *testing.T, w any) { rw.On("CloseWrite").Return(nil)
assert.True(t, w.(*mockReadWriter).writeClosed) return rw, rw
},
op: func(hs HTTP) error { return hs.CloseWrite() },
verify: func(t *testing.T, w io.Writer) {
w.(*MockReadWriter).AssertCalled(t, "CloseWrite")
}, },
}, },
{ {
name: "CloseWrite fallback to Close", name: "CloseWrite fallback to Close",
writer: &mockReadWriterOnlyCloser{}, setup: func() (io.Writer, io.Reader) {
op: func(hs HTTP) error { return hs.CloseWrite() }, rw := new(MockReadWriterOnlyCloser)
verify: func(t *testing.T, w any) { rw.On("Close").Return(nil)
assert.True(t, w.(*mockReadWriterOnlyCloser).closed) return rw, rw
},
op: func(hs HTTP) error { return hs.CloseWrite() },
verify: func(t *testing.T, w io.Writer) {
w.(*MockReadWriterOnlyCloser).AssertCalled(t, "Close")
}, },
}, },
{ {
name: "Close with No Closer", name: "Close with No Closer",
writer: &mockWriterOnly{}, setup: func() (io.Writer, io.Reader) {
op: func(hs HTTP) error { return hs.Close() }, w := new(MockWriterOnly)
r := new(MockReader)
return w, r
},
op: func(hs HTTP) error { return hs.Close() },
}, },
{ {
name: "CloseWrite with No CloseWrite and No Closer", name: "CloseWrite with No CloseWrite and No Closer",
writer: &mockWriterOnly{}, setup: func() (io.Writer, io.Reader) {
op: func(hs HTTP) error { return hs.CloseWrite() }, w := new(MockWriterOnly)
r := new(MockReader)
return w, r
},
op: func(hs HTTP) error { return hs.CloseWrite() },
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
hs := New(tt.writer.(io.Writer), tt.writer.(io.Reader), addr) addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
w, r := tt.setup()
hs := New(w, r, addr)
assert.NotPanics(t, func() { assert.NotPanics(t, func() {
err := tt.op(hs) err := tt.op(hs)
assert.NoError(t, err) assert.NoError(t, err)
}) })
if tt.verify != nil { if tt.verify != nil {
tt.verify(t, tt.writer) tt.verify(t, w)
} }
}) })
} }
} }
type mockReadWriterOnlyCloser struct {
bytes.Buffer
closed bool
}
func (m *mockReadWriterOnlyCloser) Close() error {
m.closed = true
return nil
}
func TestSplitHeaderAndBody(t *testing.T) { func TestSplitHeaderAndBody(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -319,6 +442,7 @@ func TestRead(t *testing.T) {
expectRead int expectRead int
expectErr bool expectErr bool
middlewareErr error middlewareErr error
isHTTP bool
}{ }{
{ {
name: "valid http request", name: "valid http request",
@@ -326,6 +450,7 @@ func TestRead(t *testing.T) {
readLen: 100, readLen: 100,
expectContent: "Body", expectContent: "Body",
expectRead: 54, expectRead: 54,
isHTTP: true,
}, },
{ {
name: "non-http data", name: "non-http data",
@@ -333,6 +458,7 @@ func TestRead(t *testing.T) {
readLen: 100, readLen: 100,
expectContent: "Some random data\r\n\r\nMore data", expectContent: "Some random data\r\n\r\nMore data",
expectRead: 29, expectRead: 29,
isHTTP: false,
}, },
{ {
name: "no delimiter", name: "no delimiter",
@@ -340,6 +466,7 @@ func TestRead(t *testing.T) {
readLen: 100, readLen: 100,
expectContent: "Partial data without delimiter", expectContent: "Partial data without delimiter",
expectRead: 30, expectRead: 30,
isHTTP: false,
}, },
{ {
name: "middleware error", name: "middleware error",
@@ -347,20 +474,45 @@ func TestRead(t *testing.T) {
readLen: 100, readLen: 100,
middlewareErr: assert.AnError, middlewareErr: assert.AnError,
expectErr: true, expectErr: true,
isHTTP: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
rw := &mockReadWriter{} addr := new(MockAddr)
rw.Write(tt.input) addr.On("String").Return("1.2.3.4:1234")
hs := New(rw, rw, &mockAddr{})
if tt.middlewareErr != nil { reader := new(MockReader)
hs.UseRequestMiddleware(&mockRequestMiddleware{err: tt.middlewareErr}) writer := new(MockWriterOnly)
if tt.expectErr || tt.name == "valid http request" {
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
copy(p, tt.input)
}).Return(len(tt.input), io.EOF).Once()
} else { } else {
hs.UseRequestMiddleware(&mockRequestMiddleware{}) reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
copy(p, tt.input)
}).Return(len(tt.input), nil).Once()
} }
hs := New(writer, reader, addr)
reqMW := new(MockRequestMiddleware)
if tt.isHTTP {
if tt.middlewareErr != nil {
reqMW.On("HandleRequest", mock.Anything).Return(tt.middlewareErr)
} else {
reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.RequestHeader)
h.Set("X-Middleware", "true")
}).Return(nil)
}
}
hs.UseRequestMiddleware(reqMW)
p := make([]byte, tt.readLen) p := make([]byte, tt.readLen)
n, err := hs.Read(p) n, err := hs.Read(p)
@@ -379,6 +531,11 @@ func TestRead(t *testing.T) {
assert.Equal(t, tt.expectContent, string(p[:n])) assert.Equal(t, tt.expectContent, string(p[:n]))
} }
} }
if tt.isHTTP {
reqMW.AssertExpectations(t)
}
reader.AssertExpectations(t)
}) })
} }
} }
@@ -390,6 +547,7 @@ func TestWrite(t *testing.T) {
expectWritten string expectWritten string
expectErr bool expectErr bool
middlewareErr error middlewareErr error
isHTTP bool
}{ }{
{ {
name: "valid http response in one write", name: "valid http response in one write",
@@ -397,6 +555,7 @@ func TestWrite(t *testing.T) {
[]byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"), []byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"),
}, },
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody", expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
isHTTP: true,
}, },
{ {
name: "valid http response in multiple writes", name: "valid http response in multiple writes",
@@ -406,6 +565,7 @@ func TestWrite(t *testing.T) {
[]byte("Body"), []byte("Body"),
}, },
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody", expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
isHTTP: true,
}, },
{ {
name: "non-http data", name: "non-http data",
@@ -413,6 +573,7 @@ func TestWrite(t *testing.T) {
[]byte("Random data with delimiter\r\n\r\nFlush"), []byte("Random data with delimiter\r\n\r\nFlush"),
}, },
expectWritten: "Random data with delimiter\r\n\r\nFlush", expectWritten: "Random data with delimiter\r\n\r\nFlush",
isHTTP: false,
}, },
{ {
name: "bypass buffering", name: "bypass buffering",
@@ -422,6 +583,7 @@ func TestWrite(t *testing.T) {
}, },
expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" + expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" +
"HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n", "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n",
isHTTP: true,
}, },
{ {
name: "middleware error", name: "middleware error",
@@ -430,18 +592,40 @@ func TestWrite(t *testing.T) {
}, },
middlewareErr: assert.AnError, middlewareErr: assert.AnError,
expectErr: true, expectErr: true,
isHTTP: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
rw := &mockReadWriter{} addr := new(MockAddr)
hs := New(rw, rw, &mockAddr{}) addr.On("String").Return("1.2.3.4:1234")
if tt.middlewareErr != nil {
hs.UseResponseMiddleware(&mockResponseMiddleware{err: tt.middlewareErr}) var writtenData bytes.Buffer
} else { writer := new(MockWriter)
hs.UseResponseMiddleware(&mockResponseMiddleware{})
writer.On("Write", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
writtenData.Write(p)
}).Return(func(p []byte) int {
return len(p)
}, nil)
reader := new(MockReader)
hs := New(writer, reader, addr)
respMW := new(MockResponseMiddleware)
if tt.isHTTP {
if tt.middlewareErr != nil {
respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(tt.middlewareErr)
} else {
respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
h := args.Get(0).(header.ResponseHeader)
h.Set("X-Resp-Middleware", "true")
}).Return(nil)
}
} }
hs.UseResponseMiddleware(respMW)
var totalN int var totalN int
var err error var err error
@@ -458,8 +642,8 @@ func TestWrite(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
written := writtenData.String()
if strings.HasPrefix(tt.expectWritten, "HTTP/") { if strings.HasPrefix(tt.expectWritten, "HTTP/") {
written := rw.String()
assert.Contains(t, written, "HTTP/1.1 200 OK\r\n") assert.Contains(t, written, "HTTP/1.1 200 OK\r\n")
assert.Contains(t, written, "X-Resp-Middleware: true\r\n") assert.Contains(t, written, "X-Resp-Middleware: true\r\n")
if strings.Contains(tt.expectWritten, "Content-Length: 4") { if strings.Contains(tt.expectWritten, "Content-Length: 4") {
@@ -467,43 +651,71 @@ func TestWrite(t *testing.T) {
} }
assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n")) assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n"))
} else { } else {
assert.Equal(t, tt.expectWritten, rw.String()) assert.Equal(t, tt.expectWritten, written)
} }
} }
if tt.isHTTP {
respMW.AssertExpectations(t)
}
if tt.middlewareErr == nil {
writer.AssertExpectations(t)
}
}) })
} }
} }
func TestWriteErrors(t *testing.T) { func TestWriteErrors(t *testing.T) {
addr := &mockAddr{addr: "1.2.3.4:1234"}
tests := []struct { tests := []struct {
name string name string
writer any setup func() (io.Writer, io.Reader)
data []byte data []byte
}{ }{
{ {
name: "write error in writeHeaderAndBody", name: "write error in writeHeaderAndBody",
writer: &mockErrorWriteCloser{}, setup: func() (io.Writer, io.Reader) {
data: []byte("HTTP/1.1 200 OK\r\n\r\n"), writer := new(MockWriter)
writer.On("Write", mock.Anything).Return(0, assert.AnError)
reader := new(MockReader)
return writer, reader
},
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
}, },
{ {
name: "write error in writeHeaderAndBody second write", name: "write error in writeHeaderAndBody second write",
writer: &mockFailSecondWriteCloser{}, setup: func() (io.Writer, io.Reader) {
data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"), writer := new(MockWriter)
writer.On("Write", mock.Anything).Return(len([]byte("HTTP/1.1 200 OK\r\n\r\n")), nil).Once()
writer.On("Write", mock.Anything).Return(0, assert.AnError).Once()
reader := new(MockReader)
return writer, reader
},
data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"),
}, },
{ {
name: "write error in writeRawBuffer", name: "write error in writeRawBuffer",
writer: &mockErrorWriteCloser{}, setup: func() (io.Writer, io.Reader) {
data: []byte("Not HTTP\r\n\r\nFlush"), writer := new(MockWriter)
writer.On("Write", mock.Anything).Return(0, assert.AnError)
reader := new(MockReader)
return writer, reader
},
data: []byte("Not HTTP\r\n\r\nFlush"),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
hs := New(tt.writer.(io.Writer), tt.writer.(io.Reader), addr) addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
w, r := tt.setup()
hs := New(w, r, addr)
_, err := hs.Write(tt.data) _, err := hs.Write(tt.data)
assert.Error(t, err) assert.Error(t, err)
w.(*MockWriter).AssertExpectations(t)
}) })
} }
} }
@@ -511,14 +723,21 @@ func TestWriteErrors(t *testing.T) {
func TestReadEOF(t *testing.T) { func TestReadEOF(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
reader io.Reader setup func() io.Reader
expectN int expectN int
expectErr error expectErr error
expectContent string expectContent string
}{ }{
{ {
name: "read eof", name: "read eof",
reader: &mockEOFReader{}, setup: func() io.Reader {
reader := new(MockReader)
reader.On("Read", mock.Anything).Run(func(args mock.Arguments) {
p := args.Get(0).([]byte)
copy(p, "data")
}).Return(4, io.EOF)
return reader
},
expectN: 4, expectN: 4,
expectErr: io.EOF, expectErr: io.EOF,
expectContent: "data", expectContent: "data",
@@ -527,53 +746,20 @@ func TestReadEOF(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
hs := New(nil, tt.reader, &mockAddr{}) addr := new(MockAddr)
addr.On("String").Return("1.2.3.4:1234")
reader := tt.setup()
hs := New(nil, reader, addr)
p := make([]byte, 100) p := make([]byte, 100)
n, err := hs.Read(p) n, err := hs.Read(p)
assert.Equal(t, tt.expectN, n) assert.Equal(t, tt.expectN, n)
assert.Equal(t, tt.expectErr, err) assert.Equal(t, tt.expectErr, err)
assert.Equal(t, tt.expectContent, string(p[:n])) assert.Equal(t, tt.expectContent, string(p[:n]))
reader.(*MockReader).AssertExpectations(t)
}) })
} }
} }
type mockEOFReader struct {
mockReadWriter
}
func (m *mockEOFReader) Read(p []byte) (int, error) {
copy(p, "data")
return 4, io.EOF
}
type mockFailSecondWriteCloser struct {
count int
}
func (m *mockFailSecondWriteCloser) Write(p []byte) (int, error) {
m.count++
if m.count == 2 {
return 0, assert.AnError
}
return len(p), nil
}
func (m *mockFailSecondWriteCloser) Close() error { return nil }
func (m *mockFailSecondWriteCloser) Read(p []byte) (int, error) { return 0, nil }
type mockErrorWriteCloser struct {
closed bool
}
func (m *mockErrorWriteCloser) Write(p []byte) (int, error) {
return 0, assert.AnError
}
func (m *mockErrorWriteCloser) Close() error {
m.closed = true
return nil
}
func (m *mockErrorWriteCloser) Read(p []byte) (int, error) {
return 0, nil
}