test(stream): add unit tests for stream behavior
- Fix duplicating EOF error when closing SSH connection - Add new SessionStatusCLOSED type
This commit is contained in:
@@ -317,11 +317,28 @@ func TestSetRemainingHeaders(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestParseHeadersFromReaderEdgeCases(t *testing.T) {
|
||||
t.Run("malformed header line", func(t *testing.T) {
|
||||
data := []byte("GET / HTTP/1.1\r\nMalformedLine\r\nK1: V1\r\n\r\n")
|
||||
br := bufio.NewReader(bytes.NewReader(data))
|
||||
req, err := parseHeadersFromReader(br)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "V1", req.Value("K1"))
|
||||
})
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expectHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "malformed header line",
|
||||
data: []byte("GET / HTTP/1.1\r\nMalformedLine\r\nK1: V1\r\n\r\n"),
|
||||
expectHeaders: map[string]string{
|
||||
"K1": "V1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
br := bufio.NewReader(bytes.NewReader(tt.data))
|
||||
req, err := parseHeadersFromReader(br)
|
||||
assert.NoError(t, err)
|
||||
for k, v := range tt.expectHeaders {
|
||||
assert.Equal(t, v, req.Value(k))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -72,7 +72,10 @@ func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware {
|
||||
}
|
||||
|
||||
func (hs *http) Close() error {
|
||||
return hs.writer.(io.Closer).Close()
|
||||
if closer, ok := hs.writer.(io.Closer); ok {
|
||||
return closer.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hs *http) CloseWrite() error {
|
||||
|
||||
@@ -0,0 +1,579 @@
|
||||
package stream
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"tunnel_pls/internal/http/header"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (m *mockAddr) String() string { return m.addr }
|
||||
func (m *mockAddr) Network() string { return "tcp" }
|
||||
|
||||
type mockRequestMiddleware struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockRequestMiddleware) HandleRequest(h header.RequestHeader) error {
|
||||
if m.err == nil {
|
||||
h.Set("X-Middleware", "true")
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
type mockResponseMiddleware struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error {
|
||||
if m.err == nil {
|
||||
h.Set("X-Resp-Middleware", "true")
|
||||
}
|
||||
return m.err
|
||||
}
|
||||
|
||||
type mockReadWriter struct {
|
||||
bytes.Buffer
|
||||
closed bool
|
||||
writeClosed bool
|
||||
}
|
||||
|
||||
func (m *mockReadWriter) Close() error {
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReadWriter) CloseWrite() error {
|
||||
m.writeClosed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHTTPMethods(t *testing.T) {
|
||||
addr := &mockAddr{addr: "1.2.3.4:1234"}
|
||||
rw := &mockReadWriter{}
|
||||
hs := New(rw, rw, addr)
|
||||
|
||||
assert.Equal(t, addr, hs.RemoteAddr())
|
||||
|
||||
reqMW := &mockRequestMiddleware{}
|
||||
hs.UseRequestMiddleware(reqMW)
|
||||
assert.Equal(t, 1, len(hs.RequestMiddlewares()))
|
||||
assert.Equal(t, reqMW, hs.RequestMiddlewares()[0])
|
||||
|
||||
respMW := &mockResponseMiddleware{}
|
||||
hs.UseResponseMiddleware(respMW)
|
||||
assert.Equal(t, 1, len(hs.ResponseMiddlewares()))
|
||||
assert.Equal(t, respMW, hs.ResponseMiddlewares()[0])
|
||||
|
||||
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
hs.SetRequestHeader(reqH)
|
||||
}
|
||||
|
||||
func TestApplyMiddlewares(t *testing.T) {
|
||||
addr := &mockAddr{addr: "1.2.3.4:1234"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(HTTP)
|
||||
apply func(HTTP, header.RequestHeader, header.ResponseHeader) error
|
||||
verify func(*testing.T, header.RequestHeader, header.ResponseHeader)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "apply request middleware success",
|
||||
setup: func(hs HTTP) {
|
||||
hs.UseRequestMiddleware(&mockRequestMiddleware{})
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyRequestMiddlewares(reqH)
|
||||
},
|
||||
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
|
||||
assert.Equal(t, "true", reqH.Value("X-Middleware"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "apply response middleware success",
|
||||
setup: func(hs HTTP) {
|
||||
hs.UseResponseMiddleware(&mockResponseMiddleware{})
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
||||
},
|
||||
verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) {
|
||||
assert.Equal(t, "true", respH.Value("X-Resp-Middleware"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "apply request middleware error",
|
||||
setup: func(hs HTTP) {
|
||||
hs.UseRequestMiddleware(&mockRequestMiddleware{err: assert.AnError})
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyRequestMiddlewares(reqH)
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "apply response middleware error",
|
||||
setup: func(hs HTTP) {
|
||||
hs.UseResponseMiddleware(&mockResponseMiddleware{err: assert.AnError})
|
||||
},
|
||||
apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error {
|
||||
return hs.ApplyResponseMiddlewares(respH, []byte("body"))
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n"))
|
||||
rw := &mockReadWriter{}
|
||||
hs := New(rw, rw, addr)
|
||||
tt.setup(hs)
|
||||
err := tt.apply(hs, reqH, respH)
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, reqH, respH)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockWriterOnly struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func TestCloseMethods(t *testing.T) {
|
||||
addr := &mockAddr{addr: "1.2.3.4:1234"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
writer any
|
||||
op func(HTTP) error
|
||||
verify func(*testing.T, any)
|
||||
}{
|
||||
{
|
||||
name: "Close success",
|
||||
writer: &mockReadWriter{},
|
||||
op: func(hs HTTP) error { return hs.Close() },
|
||||
verify: func(t *testing.T, w any) {
|
||||
assert.True(t, w.(*mockReadWriter).closed)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CloseWrite with CloseWrite implementation",
|
||||
writer: &mockReadWriter{},
|
||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||
verify: func(t *testing.T, w any) {
|
||||
assert.True(t, w.(*mockReadWriter).writeClosed)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CloseWrite fallback to Close",
|
||||
writer: &mockReadWriterOnlyCloser{},
|
||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||
verify: func(t *testing.T, w any) {
|
||||
assert.True(t, w.(*mockReadWriterOnlyCloser).closed)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Close with No Closer",
|
||||
writer: &mockWriterOnly{},
|
||||
op: func(hs HTTP) error { return hs.Close() },
|
||||
},
|
||||
{
|
||||
name: "CloseWrite with No CloseWrite and No Closer",
|
||||
writer: &mockWriterOnly{},
|
||||
op: func(hs HTTP) error { return hs.CloseWrite() },
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hs := New(tt.writer.(io.Writer), tt.writer.(io.Reader), addr)
|
||||
assert.NotPanics(t, func() {
|
||||
err := tt.op(hs)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, tt.writer)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockReadWriterOnlyCloser struct {
|
||||
bytes.Buffer
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (m *mockReadWriterOnlyCloser) Close() error {
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSplitHeaderAndBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
delimiterIdx int
|
||||
expectHeader []byte
|
||||
expectBody []byte
|
||||
}{
|
||||
{
|
||||
name: "standard",
|
||||
data: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nBodyContent"),
|
||||
delimiterIdx: 31,
|
||||
expectHeader: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"),
|
||||
expectBody: []byte("BodyContent"),
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
delimiterIdx: 15,
|
||||
expectHeader: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
expectBody: []byte(""),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, b := splitHeaderAndBody(tt.data, tt.delimiterIdx)
|
||||
assert.Equal(t, tt.expectHeader, h)
|
||||
assert.Equal(t, tt.expectBody, b)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsHTTPHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
buf []byte
|
||||
expect bool
|
||||
}{
|
||||
{
|
||||
name: "valid request",
|
||||
buf: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "valid response",
|
||||
buf: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "invalid start line",
|
||||
buf: []byte("NOT_HTTP /path\r\nHost: example.com\r\n\r\n"),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header line (no colon)",
|
||||
buf: []byte("GET / HTTP/1.1\r\nInvalidHeaderLine\r\n\r\n"),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header line (colon at 0)",
|
||||
buf: []byte("GET / HTTP/1.1\r\n: value\r\n\r\n"),
|
||||
expect: false,
|
||||
},
|
||||
{
|
||||
name: "empty header section",
|
||||
buf: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
{
|
||||
name: "multiple headers",
|
||||
buf: []byte("GET / HTTP/1.1\r\nH1: V1\r\nH2: V2\r\n\r\n"),
|
||||
expect: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isHTTPHeader(tt.buf)
|
||||
assert.Equal(t, tt.expect, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
readLen int
|
||||
expectContent string
|
||||
expectRead int
|
||||
expectErr bool
|
||||
middlewareErr error
|
||||
}{
|
||||
{
|
||||
name: "valid http request",
|
||||
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\nBody"),
|
||||
readLen: 100,
|
||||
expectContent: "Body",
|
||||
expectRead: 54,
|
||||
},
|
||||
{
|
||||
name: "non-http data",
|
||||
input: []byte("Some random data\r\n\r\nMore data"),
|
||||
readLen: 100,
|
||||
expectContent: "Some random data\r\n\r\nMore data",
|
||||
expectRead: 29,
|
||||
},
|
||||
{
|
||||
name: "no delimiter",
|
||||
input: []byte("Partial data without delimiter"),
|
||||
readLen: 100,
|
||||
expectContent: "Partial data without delimiter",
|
||||
expectRead: 30,
|
||||
},
|
||||
{
|
||||
name: "middleware error",
|
||||
input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\n"),
|
||||
readLen: 100,
|
||||
middlewareErr: assert.AnError,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rw := &mockReadWriter{}
|
||||
rw.Write(tt.input)
|
||||
hs := New(rw, rw, &mockAddr{})
|
||||
if tt.middlewareErr != nil {
|
||||
hs.UseRequestMiddleware(&mockRequestMiddleware{err: tt.middlewareErr})
|
||||
} else {
|
||||
hs.UseRequestMiddleware(&mockRequestMiddleware{})
|
||||
}
|
||||
|
||||
p := make([]byte, tt.readLen)
|
||||
n, err := hs.Read(p)
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.expectRead, n)
|
||||
if tt.name == "valid http request" {
|
||||
content := string(p[:n])
|
||||
assert.Contains(t, content, "GET / HTTP/1.1\r\n")
|
||||
assert.Contains(t, content, "Host: test\r\n")
|
||||
assert.Contains(t, content, "X-Middleware: true\r\n")
|
||||
assert.True(t, bytes.HasSuffix(p[:n], []byte("\r\n\r\nBody")))
|
||||
} else {
|
||||
assert.Equal(t, tt.expectContent, string(p[:n]))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
writes [][]byte
|
||||
expectWritten string
|
||||
expectErr bool
|
||||
middlewareErr error
|
||||
}{
|
||||
{
|
||||
name: "valid http response in one write",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"),
|
||||
},
|
||||
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
||||
},
|
||||
{
|
||||
name: "valid http response in multiple writes",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\n"),
|
||||
[]byte("Content-Length: 4\r\n\r\n"),
|
||||
[]byte("Body"),
|
||||
},
|
||||
expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody",
|
||||
},
|
||||
{
|
||||
name: "non-http data",
|
||||
writes: [][]byte{
|
||||
[]byte("Random data with delimiter\r\n\r\nFlush"),
|
||||
},
|
||||
expectWritten: "Random data with delimiter\r\n\r\nFlush",
|
||||
},
|
||||
{
|
||||
name: "bypass buffering",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
},
|
||||
expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" +
|
||||
"HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n",
|
||||
},
|
||||
{
|
||||
name: "middleware error",
|
||||
writes: [][]byte{
|
||||
[]byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
},
|
||||
middlewareErr: assert.AnError,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rw := &mockReadWriter{}
|
||||
hs := New(rw, rw, &mockAddr{})
|
||||
if tt.middlewareErr != nil {
|
||||
hs.UseResponseMiddleware(&mockResponseMiddleware{err: tt.middlewareErr})
|
||||
} else {
|
||||
hs.UseResponseMiddleware(&mockResponseMiddleware{})
|
||||
}
|
||||
|
||||
var totalN int
|
||||
var err error
|
||||
for _, w := range tt.writes {
|
||||
var n int
|
||||
n, err = hs.Write(w)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
totalN += n
|
||||
}
|
||||
|
||||
if tt.expectErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if strings.HasPrefix(tt.expectWritten, "HTTP/") {
|
||||
written := rw.String()
|
||||
assert.Contains(t, written, "HTTP/1.1 200 OK\r\n")
|
||||
assert.Contains(t, written, "X-Resp-Middleware: true\r\n")
|
||||
if strings.Contains(tt.expectWritten, "Content-Length: 4") {
|
||||
assert.Contains(t, written, "Content-Length: 4\r\n")
|
||||
}
|
||||
assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n"))
|
||||
} else {
|
||||
assert.Equal(t, tt.expectWritten, rw.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteErrors(t *testing.T) {
|
||||
addr := &mockAddr{addr: "1.2.3.4:1234"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
writer any
|
||||
data []byte
|
||||
}{
|
||||
{
|
||||
name: "write error in writeHeaderAndBody",
|
||||
writer: &mockErrorWriteCloser{},
|
||||
data: []byte("HTTP/1.1 200 OK\r\n\r\n"),
|
||||
},
|
||||
{
|
||||
name: "write error in writeHeaderAndBody second write",
|
||||
writer: &mockFailSecondWriteCloser{},
|
||||
data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"),
|
||||
},
|
||||
{
|
||||
name: "write error in writeRawBuffer",
|
||||
writer: &mockErrorWriteCloser{},
|
||||
data: []byte("Not HTTP\r\n\r\nFlush"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hs := New(tt.writer.(io.Writer), tt.writer.(io.Reader), addr)
|
||||
_, err := hs.Write(tt.data)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadEOF(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reader io.Reader
|
||||
expectN int
|
||||
expectErr error
|
||||
expectContent string
|
||||
}{
|
||||
{
|
||||
name: "read eof",
|
||||
reader: &mockEOFReader{},
|
||||
expectN: 4,
|
||||
expectErr: io.EOF,
|
||||
expectContent: "data",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
hs := New(nil, tt.reader, &mockAddr{})
|
||||
p := make([]byte, 100)
|
||||
n, err := hs.Read(p)
|
||||
assert.Equal(t, tt.expectN, n)
|
||||
assert.Equal(t, tt.expectErr, err)
|
||||
assert.Equal(t, tt.expectContent, string(p[:n]))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockEOFReader struct {
|
||||
mockReadWriter
|
||||
}
|
||||
|
||||
func (m *mockEOFReader) Read(p []byte) (int, error) {
|
||||
copy(p, "data")
|
||||
return 4, io.EOF
|
||||
}
|
||||
|
||||
type mockFailSecondWriteCloser struct {
|
||||
count int
|
||||
}
|
||||
|
||||
func (m *mockFailSecondWriteCloser) Write(p []byte) (int, error) {
|
||||
m.count++
|
||||
if m.count == 2 {
|
||||
return 0, assert.AnError
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (m *mockFailSecondWriteCloser) Close() error { return nil }
|
||||
func (m *mockFailSecondWriteCloser) Read(p []byte) (int, error) { return 0, nil }
|
||||
|
||||
type mockErrorWriteCloser struct {
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (m *mockErrorWriteCloser) Write(p []byte) (int, error) {
|
||||
return 0, assert.AnError
|
||||
}
|
||||
|
||||
func (m *mockErrorWriteCloser) Close() error {
|
||||
m.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockErrorWriteCloser) Read(p []byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package random
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -44,18 +45,32 @@ func TestRandom_String(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRandomWithFailingReader_String(t *testing.T) {
|
||||
var randomizer Random
|
||||
var errBrainrot = fmt.Errorf("you are not sigma enough")
|
||||
randomizer = &random{reader: &brainrotReader{err: errBrainrot}}
|
||||
t.Run("test failing reader", func(t *testing.T) {
|
||||
result, err := randomizer.String(20)
|
||||
if !errors.Is(err, errBrainrot) {
|
||||
t.Errorf("String() error = %v, wantErr %v", err, errBrainrot)
|
||||
return
|
||||
}
|
||||
errBrainrot := fmt.Errorf("you are not sigma enough")
|
||||
|
||||
if result != "" {
|
||||
t.Errorf("String() result = %v, want an empty string due to error", result)
|
||||
}
|
||||
})
|
||||
tests := []struct {
|
||||
name string
|
||||
reader io.Reader
|
||||
expectErr error
|
||||
}{
|
||||
{
|
||||
name: "failing reader",
|
||||
reader: &brainrotReader{err: errBrainrot},
|
||||
expectErr: errBrainrot,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
randomizer := &random{reader: tt.reader}
|
||||
result, err := randomizer.String(20)
|
||||
if !errors.Is(err, tt.expectErr) {
|
||||
t.Errorf("String() error = %v, wantErr %v", err, tt.expectErr)
|
||||
return
|
||||
}
|
||||
|
||||
if result != "" {
|
||||
t.Errorf("String() result = %v, want an empty string due to error", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -468,89 +468,108 @@ func TestRegistry_Register(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegistry_GetAllSessionFromUser(t *testing.T) {
|
||||
t.Run("user has no sessions", func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
}
|
||||
sessions := r.GetAllSessionFromUser("user1")
|
||||
if len(sessions) != 0 {
|
||||
t.Errorf("expected 0 sessions, got %d", len(sessions))
|
||||
}
|
||||
})
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(r *registry) string
|
||||
expectN int
|
||||
}{
|
||||
{
|
||||
name: "user has no sessions",
|
||||
setupFunc: func(r *registry) string {
|
||||
return "user1"
|
||||
},
|
||||
expectN: 0,
|
||||
},
|
||||
{
|
||||
name: "user has multiple sessions",
|
||||
setupFunc: func(r *registry) string {
|
||||
user := "user1"
|
||||
key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||
key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
|
||||
r.mu.Lock()
|
||||
r.byUser[user] = map[Key]Session{
|
||||
key1: &mockSession{user: user},
|
||||
key2: &mockSession{user: user},
|
||||
}
|
||||
r.mu.Unlock()
|
||||
return user
|
||||
},
|
||||
expectN: 2,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("user has multiple sessions", func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
|
||||
user := "user1"
|
||||
key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||
key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP}
|
||||
session1 := &mockSession{user: user}
|
||||
session2 := &mockSession{user: user}
|
||||
|
||||
r.mu.Lock()
|
||||
r.byUser[user] = map[Key]Session{
|
||||
key1: session1,
|
||||
key2: session2,
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
sessions := r.GetAllSessionFromUser(user)
|
||||
if len(sessions) != 2 {
|
||||
t.Errorf("expected 2 sessions, got %d", len(sessions))
|
||||
}
|
||||
|
||||
found := map[Session]bool{}
|
||||
for _, s := range sessions {
|
||||
found[s] = true
|
||||
}
|
||||
if !found[session1] || !found[session2] {
|
||||
t.Errorf("returned sessions do not match expected")
|
||||
}
|
||||
})
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
user := tt.setupFunc(r)
|
||||
sessions := r.GetAllSessionFromUser(user)
|
||||
if len(sessions) != tt.expectN {
|
||||
t.Errorf("expected %d sessions, got %d", tt.expectN, len(sessions))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Remove(t *testing.T) {
|
||||
t.Run("remove existing key", func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
setupFunc func(r *registry) (string, types.SessionKey)
|
||||
key types.SessionKey
|
||||
verify func(*testing.T, *registry, string, types.SessionKey)
|
||||
}{
|
||||
{
|
||||
name: "remove existing key",
|
||||
setupFunc: func(r *registry) (string, types.SessionKey) {
|
||||
user := "user1"
|
||||
key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||
session := &mockSession{user: user}
|
||||
r.mu.Lock()
|
||||
r.byUser[user] = map[Key]Session{key: session}
|
||||
r.slugIndex[key] = user
|
||||
r.mu.Unlock()
|
||||
return user, key
|
||||
},
|
||||
verify: func(t *testing.T, r *registry, user string, key types.SessionKey) {
|
||||
if _, ok := r.byUser[user][key]; ok {
|
||||
t.Errorf("expected key to be removed from byUser")
|
||||
}
|
||||
if _, ok := r.slugIndex[key]; ok {
|
||||
t.Errorf("expected key to be removed from slugIndex")
|
||||
}
|
||||
if _, ok := r.byUser[user]; ok {
|
||||
t.Errorf("expected user to be removed from byUser map")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove non-existing key",
|
||||
setupFunc: func(r *registry) (string, types.SessionKey) {
|
||||
return "", types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
user := "user1"
|
||||
key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP}
|
||||
session := &mockSession{user: user}
|
||||
|
||||
r.mu.Lock()
|
||||
r.byUser[user] = map[Key]Session{key: session}
|
||||
r.slugIndex[key] = user
|
||||
r.mu.Unlock()
|
||||
|
||||
r.Remove(key)
|
||||
|
||||
if _, ok := r.byUser[user][key]; ok {
|
||||
t.Errorf("expected key to be removed from byUser")
|
||||
}
|
||||
if _, ok := r.slugIndex[key]; ok {
|
||||
t.Errorf("expected key to be removed from slugIndex")
|
||||
}
|
||||
if _, ok := r.byUser[user]; ok {
|
||||
t.Errorf("expected user to be removed from byUser map")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("remove non-existing key", func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
}
|
||||
r.Remove(types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP})
|
||||
})
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := ®istry{
|
||||
byUser: make(map[string]map[Key]Session),
|
||||
slugIndex: make(map[Key]string),
|
||||
mu: sync.RWMutex{},
|
||||
}
|
||||
user, key := tt.setupFunc(r)
|
||||
if user == "" {
|
||||
key = tt.key
|
||||
}
|
||||
r.Remove(key)
|
||||
if tt.verify != nil {
|
||||
tt.verify(t, r, user, key)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidSlug(t *testing.T) {
|
||||
|
||||
@@ -175,7 +175,11 @@ func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
select {
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
default:
|
||||
hh.cleanupUnusedChannel(channel, reqs)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
|
||||
+3
-2
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
@@ -85,7 +86,7 @@ func (s *server) handleConnection(conn net.Conn) {
|
||||
|
||||
defer func(sshConn *ssh.ServerConn) {
|
||||
err = sshConn.Close()
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) {
|
||||
log.Printf("failed to close SSH server: %v", err)
|
||||
}
|
||||
}(sshConn)
|
||||
@@ -101,7 +102,7 @@ func (s *server) handleConnection(conn net.Conn) {
|
||||
sshSession := session.New(s.randomizer, s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user)
|
||||
err = sshSession.Start()
|
||||
if err != nil {
|
||||
log.Printf("SSH session ended with error: %v", err)
|
||||
log.Printf("SSH session ended with error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
return
|
||||
|
||||
@@ -256,8 +256,6 @@ func (i *interaction) Start() {
|
||||
i.program.Kill()
|
||||
i.program = nil
|
||||
if i.closeFunc != nil {
|
||||
if err := i.closeFunc(); err != nil {
|
||||
log.Printf("Cannot close session: %s \n", err)
|
||||
}
|
||||
_ = i.closeFunc()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,9 @@ package lifecycle
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
portUtil "tunnel_pls/internal/port"
|
||||
@@ -22,7 +25,9 @@ type SessionRegistry interface {
|
||||
}
|
||||
|
||||
type lifecycle struct {
|
||||
mu sync.Mutex
|
||||
status types.SessionStatus
|
||||
closeErr error
|
||||
conn ssh.Conn
|
||||
channel ssh.Channel
|
||||
forwarder Forwarder
|
||||
@@ -73,29 +78,41 @@ func (l *lifecycle) Connection() ssh.Conn {
|
||||
return l.conn
|
||||
}
|
||||
func (l *lifecycle) SetStatus(status types.SessionStatus) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.status = status
|
||||
if status == types.SessionStatusRUNNING && l.startedAt.IsZero() {
|
||||
l.startedAt = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
func closeIfNotNil(c interface{ Close() error }) error {
|
||||
if c != nil {
|
||||
return c.Close()
|
||||
}
|
||||
return nil
|
||||
func (l *lifecycle) IsActive() bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return l.status == types.SessionStatusRUNNING
|
||||
}
|
||||
|
||||
func (l *lifecycle) Close() error {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if l.status == types.SessionStatusCLOSED {
|
||||
return l.closeErr
|
||||
}
|
||||
l.status = types.SessionStatusCLOSED
|
||||
|
||||
var errs []error
|
||||
tunnelType := l.forwarder.TunnelType()
|
||||
|
||||
if err := closeIfNotNil(l.channel); err != nil {
|
||||
errs = append(errs, err)
|
||||
if l.channel != nil {
|
||||
if err := l.channel.Close(); err != nil && !isClosedError(err) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := closeIfNotNil(l.conn); err != nil {
|
||||
errs = append(errs, err)
|
||||
if l.conn != nil {
|
||||
if err := l.conn.Close(); err != nil && !isClosedError(err) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
clientSlug := l.slug.String()
|
||||
@@ -114,11 +131,15 @@ func (l *lifecycle) Close() error {
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Join(errs...)
|
||||
l.closeErr = errors.Join(errs...)
|
||||
return l.closeErr
|
||||
}
|
||||
|
||||
func (l *lifecycle) IsActive() bool {
|
||||
return l.status == types.SessionStatusRUNNING
|
||||
func isClosedError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || err.Error() == "EOF"
|
||||
}
|
||||
|
||||
func (l *lifecycle) StartedAt() time.Time {
|
||||
|
||||
@@ -7,6 +7,7 @@ type SessionStatus int
|
||||
const (
|
||||
SessionStatusINITIALIZING SessionStatus = iota
|
||||
SessionStatusRUNNING
|
||||
SessionStatusCLOSED
|
||||
)
|
||||
|
||||
type InteractiveMode int
|
||||
|
||||
Reference in New Issue
Block a user