diff --git a/internal/http/stream/stream_test.go b/internal/http/stream/stream_test.go index 65c660b..60c9d65 100644 --- a/internal/http/stream/stream_test.go +++ b/internal/http/stream/stream_test.go @@ -9,66 +9,160 @@ import ( "tunnel_pls/internal/http/header" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) -type mockAddr struct { - addr string +type MockAddr struct { + mock.Mock } -func (m *mockAddr) String() string { return m.addr } -func (m *mockAddr) Network() string { return "tcp" } - -type mockRequestMiddleware struct { - err error +func (m *MockAddr) String() string { + args := m.Called() + return args.String(0) } -func (m *mockRequestMiddleware) HandleRequest(h header.RequestHeader) error { - if m.err == nil { - h.Set("X-Middleware", "true") - } - return m.err +func (m *MockAddr) Network() string { + args := m.Called() + return args.String(0) } -type mockResponseMiddleware struct { - err error +type MockRequestMiddleware struct { + mock.Mock } -func (m *mockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error { - if m.err == nil { - h.Set("X-Resp-Middleware", "true") - } - return m.err +func (m *MockRequestMiddleware) HandleRequest(h header.RequestHeader) error { + args := m.Called(h) + return args.Error(0) } -type mockReadWriter struct { +type MockResponseMiddleware struct { + mock.Mock +} + +func (m *MockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error { + args := m.Called(h, body) + return args.Error(0) +} + +type MockReadWriter struct { + mock.Mock bytes.Buffer - closed bool - writeClosed bool } -func (m *mockReadWriter) Close() error { - m.closed = true - return nil +func (m *MockReadWriter) Read(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) } -func (m *mockReadWriter) CloseWrite() error { - m.writeClosed = true - return nil +func (m *MockReadWriter) Write(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *MockReadWriter) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockReadWriter) CloseWrite() error { + args := m.Called() + return args.Error(0) +} + +type MockReadWriterOnlyCloser struct { + mock.Mock + bytes.Buffer +} + +func (m *MockReadWriterOnlyCloser) Read(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *MockReadWriterOnlyCloser) Write(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *MockReadWriterOnlyCloser) Close() error { + args := m.Called() + return args.Error(0) +} + +type MockWriterOnly struct { + mock.Mock +} + +func (m *MockWriterOnly) Write(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +func (m *MockWriterOnly) Read(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +type MockReader struct { + mock.Mock +} + +func (m *MockReader) Read(p []byte) (int, error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +type MockWriter struct { + mock.Mock +} + +func (m *MockWriter) Write(p []byte) (int, error) { + ret := m.Called(p) + + var n int + var err error + + switch v := ret.Get(0).(type) { + case func([]byte) int: + n = v(p) + case int: + n = v + default: + n = len(p) + } + + switch v := ret.Get(1).(type) { + case func([]byte) error: + err = v(p) + case error: + err = v + default: + err = nil + } + + return n, err +} + +func (m *MockWriter) Close() error { + args := m.Called() + return args.Error(0) } func TestHTTPMethods(t *testing.T) { - addr := &mockAddr{addr: "1.2.3.4:1234"} - rw := &mockReadWriter{} + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + rw := new(MockReadWriter) hs := New(rw, rw, addr) assert.Equal(t, addr, hs.RemoteAddr()) - reqMW := &mockRequestMiddleware{} + reqMW := new(MockRequestMiddleware) hs.UseRequestMiddleware(reqMW) assert.Equal(t, 1, len(hs.RequestMiddlewares())) assert.Equal(t, reqMW, hs.RequestMiddlewares()[0]) - respMW := &mockResponseMiddleware{} + respMW := new(MockResponseMiddleware) hs.UseResponseMiddleware(respMW) assert.Equal(t, 1, len(hs.ResponseMiddlewares())) assert.Equal(t, respMW, hs.ResponseMiddlewares()[0]) @@ -78,19 +172,21 @@ func TestHTTPMethods(t *testing.T) { } func TestApplyMiddlewares(t *testing.T) { - addr := &mockAddr{addr: "1.2.3.4:1234"} - tests := []struct { name string - setup func(HTTP) + setup func(HTTP, *MockRequestMiddleware, *MockResponseMiddleware) apply func(HTTP, header.RequestHeader, header.ResponseHeader) error verify func(*testing.T, header.RequestHeader, header.ResponseHeader) expectErr bool }{ { name: "apply request middleware success", - setup: func(hs HTTP) { - hs.UseRequestMiddleware(&mockRequestMiddleware{}) + setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) { + reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) { + h := args.Get(0).(header.RequestHeader) + h.Set("X-Middleware", "true") + }).Return(nil) + hs.UseRequestMiddleware(reqMW) }, apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { return hs.ApplyRequestMiddlewares(reqH) @@ -101,8 +197,12 @@ func TestApplyMiddlewares(t *testing.T) { }, { name: "apply response middleware success", - setup: func(hs HTTP) { - hs.UseResponseMiddleware(&mockResponseMiddleware{}) + setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) { + respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + h := args.Get(0).(header.ResponseHeader) + h.Set("X-Resp-Middleware", "true") + }).Return(nil) + hs.UseResponseMiddleware(respMW) }, apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { return hs.ApplyResponseMiddlewares(respH, []byte("body")) @@ -113,8 +213,9 @@ func TestApplyMiddlewares(t *testing.T) { }, { name: "apply request middleware error", - setup: func(hs HTTP) { - hs.UseRequestMiddleware(&mockRequestMiddleware{err: assert.AnError}) + setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) { + reqMW.On("HandleRequest", mock.Anything).Return(assert.AnError) + hs.UseRequestMiddleware(reqMW) }, apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { return hs.ApplyRequestMiddlewares(reqH) @@ -123,8 +224,9 @@ func TestApplyMiddlewares(t *testing.T) { }, { name: "apply response middleware error", - setup: func(hs HTTP) { - hs.UseResponseMiddleware(&mockResponseMiddleware{err: assert.AnError}) + setup: func(hs HTTP, reqMW *MockRequestMiddleware, respMW *MockResponseMiddleware) { + respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(assert.AnError) + hs.UseResponseMiddleware(respMW) }, apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { return hs.ApplyResponseMiddlewares(respH, []byte("body")) @@ -137,9 +239,17 @@ func TestApplyMiddlewares(t *testing.T) { t.Run(tt.name, func(t *testing.T) { reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n")) respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n")) - rw := &mockReadWriter{} + + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + rw := new(MockReadWriter) hs := New(rw, rw, addr) - tt.setup(hs) + + reqMW := new(MockRequestMiddleware) + respMW := new(MockResponseMiddleware) + tt.setup(hs, reqMW, respMW) + err := tt.apply(hs, reqH, respH) if tt.expectErr { assert.Error(t, err) @@ -149,83 +259,96 @@ func TestApplyMiddlewares(t *testing.T) { tt.verify(t, reqH, respH) } } + + reqMW.AssertExpectations(t) + respMW.AssertExpectations(t) }) } } -type mockWriterOnly struct { - bytes.Buffer -} - func TestCloseMethods(t *testing.T) { - addr := &mockAddr{addr: "1.2.3.4:1234"} - tests := []struct { name string - writer any + setup func() (io.Writer, io.Reader) op func(HTTP) error - verify func(*testing.T, any) + verify func(*testing.T, io.Writer) }{ { - name: "Close success", - writer: &mockReadWriter{}, - op: func(hs HTTP) error { return hs.Close() }, - verify: func(t *testing.T, w any) { - assert.True(t, w.(*mockReadWriter).closed) + name: "Close success", + setup: func() (io.Writer, io.Reader) { + rw := new(MockReadWriter) + rw.On("Close").Return(nil) + return rw, rw + }, + op: func(hs HTTP) error { return hs.Close() }, + verify: func(t *testing.T, w io.Writer) { + w.(*MockReadWriter).AssertCalled(t, "Close") }, }, { - name: "CloseWrite with CloseWrite implementation", - writer: &mockReadWriter{}, - op: func(hs HTTP) error { return hs.CloseWrite() }, - verify: func(t *testing.T, w any) { - assert.True(t, w.(*mockReadWriter).writeClosed) + name: "CloseWrite with CloseWrite implementation", + setup: func() (io.Writer, io.Reader) { + rw := new(MockReadWriter) + rw.On("CloseWrite").Return(nil) + return rw, rw + }, + op: func(hs HTTP) error { return hs.CloseWrite() }, + verify: func(t *testing.T, w io.Writer) { + w.(*MockReadWriter).AssertCalled(t, "CloseWrite") }, }, { - name: "CloseWrite fallback to Close", - writer: &mockReadWriterOnlyCloser{}, - op: func(hs HTTP) error { return hs.CloseWrite() }, - verify: func(t *testing.T, w any) { - assert.True(t, w.(*mockReadWriterOnlyCloser).closed) + name: "CloseWrite fallback to Close", + setup: func() (io.Writer, io.Reader) { + rw := new(MockReadWriterOnlyCloser) + rw.On("Close").Return(nil) + return rw, rw + }, + op: func(hs HTTP) error { return hs.CloseWrite() }, + verify: func(t *testing.T, w io.Writer) { + w.(*MockReadWriterOnlyCloser).AssertCalled(t, "Close") }, }, { - name: "Close with No Closer", - writer: &mockWriterOnly{}, - op: func(hs HTTP) error { return hs.Close() }, + name: "Close with No Closer", + setup: func() (io.Writer, io.Reader) { + w := new(MockWriterOnly) + r := new(MockReader) + return w, r + }, + op: func(hs HTTP) error { return hs.Close() }, }, { - name: "CloseWrite with No CloseWrite and No Closer", - writer: &mockWriterOnly{}, - op: func(hs HTTP) error { return hs.CloseWrite() }, + name: "CloseWrite with No CloseWrite and No Closer", + setup: func() (io.Writer, io.Reader) { + w := new(MockWriterOnly) + r := new(MockReader) + return w, r + }, + op: func(hs HTTP) error { return hs.CloseWrite() }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hs := New(tt.writer.(io.Writer), tt.writer.(io.Reader), addr) + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + w, r := tt.setup() + hs := New(w, r, addr) + assert.NotPanics(t, func() { err := tt.op(hs) assert.NoError(t, err) }) + if tt.verify != nil { - tt.verify(t, tt.writer) + tt.verify(t, w) } }) } } -type mockReadWriterOnlyCloser struct { - bytes.Buffer - closed bool -} - -func (m *mockReadWriterOnlyCloser) Close() error { - m.closed = true - return nil -} - func TestSplitHeaderAndBody(t *testing.T) { tests := []struct { name string @@ -319,6 +442,7 @@ func TestRead(t *testing.T) { expectRead int expectErr bool middlewareErr error + isHTTP bool }{ { name: "valid http request", @@ -326,6 +450,7 @@ func TestRead(t *testing.T) { readLen: 100, expectContent: "Body", expectRead: 54, + isHTTP: true, }, { name: "non-http data", @@ -333,6 +458,7 @@ func TestRead(t *testing.T) { readLen: 100, expectContent: "Some random data\r\n\r\nMore data", expectRead: 29, + isHTTP: false, }, { name: "no delimiter", @@ -340,6 +466,7 @@ func TestRead(t *testing.T) { readLen: 100, expectContent: "Partial data without delimiter", expectRead: 30, + isHTTP: false, }, { name: "middleware error", @@ -347,20 +474,45 @@ func TestRead(t *testing.T) { readLen: 100, middlewareErr: assert.AnError, expectErr: true, + isHTTP: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rw := &mockReadWriter{} - rw.Write(tt.input) - hs := New(rw, rw, &mockAddr{}) - if tt.middlewareErr != nil { - hs.UseRequestMiddleware(&mockRequestMiddleware{err: tt.middlewareErr}) + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + reader := new(MockReader) + writer := new(MockWriterOnly) + + if tt.expectErr || tt.name == "valid http request" { + reader.On("Read", mock.Anything).Run(func(args mock.Arguments) { + p := args.Get(0).([]byte) + copy(p, tt.input) + }).Return(len(tt.input), io.EOF).Once() } else { - hs.UseRequestMiddleware(&mockRequestMiddleware{}) + reader.On("Read", mock.Anything).Run(func(args mock.Arguments) { + p := args.Get(0).([]byte) + copy(p, tt.input) + }).Return(len(tt.input), nil).Once() } + hs := New(writer, reader, addr) + + reqMW := new(MockRequestMiddleware) + if tt.isHTTP { + if tt.middlewareErr != nil { + reqMW.On("HandleRequest", mock.Anything).Return(tt.middlewareErr) + } else { + reqMW.On("HandleRequest", mock.Anything).Run(func(args mock.Arguments) { + h := args.Get(0).(header.RequestHeader) + h.Set("X-Middleware", "true") + }).Return(nil) + } + } + hs.UseRequestMiddleware(reqMW) + p := make([]byte, tt.readLen) n, err := hs.Read(p) @@ -379,6 +531,11 @@ func TestRead(t *testing.T) { assert.Equal(t, tt.expectContent, string(p[:n])) } } + + if tt.isHTTP { + reqMW.AssertExpectations(t) + } + reader.AssertExpectations(t) }) } } @@ -390,6 +547,7 @@ func TestWrite(t *testing.T) { expectWritten string expectErr bool middlewareErr error + isHTTP bool }{ { name: "valid http response in one write", @@ -397,6 +555,7 @@ func TestWrite(t *testing.T) { []byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"), }, expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody", + isHTTP: true, }, { name: "valid http response in multiple writes", @@ -406,6 +565,7 @@ func TestWrite(t *testing.T) { []byte("Body"), }, expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody", + isHTTP: true, }, { name: "non-http data", @@ -413,6 +573,7 @@ func TestWrite(t *testing.T) { []byte("Random data with delimiter\r\n\r\nFlush"), }, expectWritten: "Random data with delimiter\r\n\r\nFlush", + isHTTP: false, }, { name: "bypass buffering", @@ -422,6 +583,7 @@ func TestWrite(t *testing.T) { }, expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" + "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n", + isHTTP: true, }, { name: "middleware error", @@ -430,18 +592,40 @@ func TestWrite(t *testing.T) { }, middlewareErr: assert.AnError, expectErr: true, + isHTTP: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rw := &mockReadWriter{} - hs := New(rw, rw, &mockAddr{}) - if tt.middlewareErr != nil { - hs.UseResponseMiddleware(&mockResponseMiddleware{err: tt.middlewareErr}) - } else { - hs.UseResponseMiddleware(&mockResponseMiddleware{}) + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + var writtenData bytes.Buffer + writer := new(MockWriter) + + writer.On("Write", mock.Anything).Run(func(args mock.Arguments) { + p := args.Get(0).([]byte) + writtenData.Write(p) + }).Return(func(p []byte) int { + return len(p) + }, nil) + + reader := new(MockReader) + hs := New(writer, reader, addr) + + respMW := new(MockResponseMiddleware) + if tt.isHTTP { + if tt.middlewareErr != nil { + respMW.On("HandleResponse", mock.Anything, mock.Anything).Return(tt.middlewareErr) + } else { + respMW.On("HandleResponse", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + h := args.Get(0).(header.ResponseHeader) + h.Set("X-Resp-Middleware", "true") + }).Return(nil) + } } + hs.UseResponseMiddleware(respMW) var totalN int var err error @@ -458,8 +642,8 @@ func TestWrite(t *testing.T) { assert.Error(t, err) } else { assert.NoError(t, err) + written := writtenData.String() if strings.HasPrefix(tt.expectWritten, "HTTP/") { - written := rw.String() assert.Contains(t, written, "HTTP/1.1 200 OK\r\n") assert.Contains(t, written, "X-Resp-Middleware: true\r\n") if strings.Contains(tt.expectWritten, "Content-Length: 4") { @@ -467,43 +651,71 @@ func TestWrite(t *testing.T) { } assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n")) } else { - assert.Equal(t, tt.expectWritten, rw.String()) + assert.Equal(t, tt.expectWritten, written) } } + + if tt.isHTTP { + respMW.AssertExpectations(t) + } + if tt.middlewareErr == nil { + writer.AssertExpectations(t) + } }) } } func TestWriteErrors(t *testing.T) { - addr := &mockAddr{addr: "1.2.3.4:1234"} - tests := []struct { - name string - writer any - data []byte + name string + setup func() (io.Writer, io.Reader) + data []byte }{ { - name: "write error in writeHeaderAndBody", - writer: &mockErrorWriteCloser{}, - data: []byte("HTTP/1.1 200 OK\r\n\r\n"), + name: "write error in writeHeaderAndBody", + setup: func() (io.Writer, io.Reader) { + writer := new(MockWriter) + writer.On("Write", mock.Anything).Return(0, assert.AnError) + reader := new(MockReader) + return writer, reader + }, + data: []byte("HTTP/1.1 200 OK\r\n\r\n"), }, { - name: "write error in writeHeaderAndBody second write", - writer: &mockFailSecondWriteCloser{}, - data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"), + name: "write error in writeHeaderAndBody second write", + setup: func() (io.Writer, io.Reader) { + writer := new(MockWriter) + writer.On("Write", mock.Anything).Return(len([]byte("HTTP/1.1 200 OK\r\n\r\n")), nil).Once() + writer.On("Write", mock.Anything).Return(0, assert.AnError).Once() + reader := new(MockReader) + return writer, reader + }, + data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"), }, { - name: "write error in writeRawBuffer", - writer: &mockErrorWriteCloser{}, - data: []byte("Not HTTP\r\n\r\nFlush"), + name: "write error in writeRawBuffer", + setup: func() (io.Writer, io.Reader) { + writer := new(MockWriter) + writer.On("Write", mock.Anything).Return(0, assert.AnError) + reader := new(MockReader) + return writer, reader + }, + data: []byte("Not HTTP\r\n\r\nFlush"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hs := New(tt.writer.(io.Writer), tt.writer.(io.Reader), addr) + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + w, r := tt.setup() + hs := New(w, r, addr) + _, err := hs.Write(tt.data) assert.Error(t, err) + + w.(*MockWriter).AssertExpectations(t) }) } } @@ -511,14 +723,21 @@ func TestWriteErrors(t *testing.T) { func TestReadEOF(t *testing.T) { tests := []struct { name string - reader io.Reader + setup func() io.Reader expectN int expectErr error expectContent string }{ { - name: "read eof", - reader: &mockEOFReader{}, + name: "read eof", + setup: func() io.Reader { + reader := new(MockReader) + reader.On("Read", mock.Anything).Run(func(args mock.Arguments) { + p := args.Get(0).([]byte) + copy(p, "data") + }).Return(4, io.EOF) + return reader + }, expectN: 4, expectErr: io.EOF, expectContent: "data", @@ -527,53 +746,20 @@ func TestReadEOF(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - hs := New(nil, tt.reader, &mockAddr{}) + addr := new(MockAddr) + addr.On("String").Return("1.2.3.4:1234") + + reader := tt.setup() + hs := New(nil, reader, addr) + p := make([]byte, 100) n, err := hs.Read(p) + assert.Equal(t, tt.expectN, n) assert.Equal(t, tt.expectErr, err) assert.Equal(t, tt.expectContent, string(p[:n])) + + reader.(*MockReader).AssertExpectations(t) }) } } - -type mockEOFReader struct { - mockReadWriter -} - -func (m *mockEOFReader) Read(p []byte) (int, error) { - copy(p, "data") - return 4, io.EOF -} - -type mockFailSecondWriteCloser struct { - count int -} - -func (m *mockFailSecondWriteCloser) Write(p []byte) (int, error) { - m.count++ - if m.count == 2 { - return 0, assert.AnError - } - return len(p), nil -} - -func (m *mockFailSecondWriteCloser) Close() error { return nil } -func (m *mockFailSecondWriteCloser) Read(p []byte) (int, error) { return 0, nil } - -type mockErrorWriteCloser struct { - closed bool -} - -func (m *mockErrorWriteCloser) Write(p []byte) (int, error) { - return 0, assert.AnError -} - -func (m *mockErrorWriteCloser) Close() error { - m.closed = true - return nil -} - -func (m *mockErrorWriteCloser) Read(p []byte) (int, error) { - return 0, nil -}