From 7159300fa283379dd218c0c1d2715bd2a0de6c38 Mon Sep 17 00:00:00 2001 From: bagas Date: Thu, 22 Jan 2026 21:04:05 +0700 Subject: [PATCH] test(stream): add unit tests for stream behavior - Fix duplicating EOF error when closing SSH connection - Add new SessionStatusCLOSED type --- internal/http/header/header_test.go | 31 +- internal/http/stream/stream.go | 5 +- internal/http/stream/stream_test.go | 579 ++++++++++++++++++++++++++++ internal/random/random_test.go | 41 +- internal/registry/registry_test.go | 175 +++++---- internal/transport/httphandler.go | 6 +- server/server.go | 5 +- session/interaction/interaction.go | 4 +- session/lifecycle/lifecycle.go | 45 ++- types/types.go | 1 + 10 files changed, 775 insertions(+), 117 deletions(-) create mode 100644 internal/http/stream/stream_test.go diff --git a/internal/http/header/header_test.go b/internal/http/header/header_test.go index 10fe1b0..5207b1c 100644 --- a/internal/http/header/header_test.go +++ b/internal/http/header/header_test.go @@ -317,11 +317,28 @@ func TestSetRemainingHeaders(t *testing.T) { } func TestParseHeadersFromReaderEdgeCases(t *testing.T) { - t.Run("malformed header line", func(t *testing.T) { - data := []byte("GET / HTTP/1.1\r\nMalformedLine\r\nK1: V1\r\n\r\n") - br := bufio.NewReader(bytes.NewReader(data)) - req, err := parseHeadersFromReader(br) - assert.NoError(t, err) - assert.Equal(t, "V1", req.Value("K1")) - }) + tests := []struct { + name string + data []byte + expectHeaders map[string]string + }{ + { + name: "malformed header line", + data: []byte("GET / HTTP/1.1\r\nMalformedLine\r\nK1: V1\r\n\r\n"), + expectHeaders: map[string]string{ + "K1": "V1", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + br := bufio.NewReader(bytes.NewReader(tt.data)) + req, err := parseHeadersFromReader(br) + assert.NoError(t, err) + for k, v := range tt.expectHeaders { + assert.Equal(t, v, req.Value(k)) + } + }) + } } diff --git a/internal/http/stream/stream.go b/internal/http/stream/stream.go index 97d2752..dcc09f9 100644 --- a/internal/http/stream/stream.go +++ b/internal/http/stream/stream.go @@ -72,7 +72,10 @@ func (hs *http) ResponseMiddlewares() []middleware.ResponseMiddleware { } func (hs *http) Close() error { - return hs.writer.(io.Closer).Close() + if closer, ok := hs.writer.(io.Closer); ok { + return closer.Close() + } + return nil } func (hs *http) CloseWrite() error { diff --git a/internal/http/stream/stream_test.go b/internal/http/stream/stream_test.go new file mode 100644 index 0000000..65c660b --- /dev/null +++ b/internal/http/stream/stream_test.go @@ -0,0 +1,579 @@ +package stream + +import ( + "bytes" + "io" + "strings" + "testing" + + "tunnel_pls/internal/http/header" + + "github.com/stretchr/testify/assert" +) + +type mockAddr struct { + addr string +} + +func (m *mockAddr) String() string { return m.addr } +func (m *mockAddr) Network() string { return "tcp" } + +type mockRequestMiddleware struct { + err error +} + +func (m *mockRequestMiddleware) HandleRequest(h header.RequestHeader) error { + if m.err == nil { + h.Set("X-Middleware", "true") + } + return m.err +} + +type mockResponseMiddleware struct { + err error +} + +func (m *mockResponseMiddleware) HandleResponse(h header.ResponseHeader, body []byte) error { + if m.err == nil { + h.Set("X-Resp-Middleware", "true") + } + return m.err +} + +type mockReadWriter struct { + bytes.Buffer + closed bool + writeClosed bool +} + +func (m *mockReadWriter) Close() error { + m.closed = true + return nil +} + +func (m *mockReadWriter) CloseWrite() error { + m.writeClosed = true + return nil +} + +func TestHTTPMethods(t *testing.T) { + addr := &mockAddr{addr: "1.2.3.4:1234"} + rw := &mockReadWriter{} + hs := New(rw, rw, addr) + + assert.Equal(t, addr, hs.RemoteAddr()) + + reqMW := &mockRequestMiddleware{} + hs.UseRequestMiddleware(reqMW) + assert.Equal(t, 1, len(hs.RequestMiddlewares())) + assert.Equal(t, reqMW, hs.RequestMiddlewares()[0]) + + respMW := &mockResponseMiddleware{} + hs.UseResponseMiddleware(respMW) + assert.Equal(t, 1, len(hs.ResponseMiddlewares())) + assert.Equal(t, respMW, hs.ResponseMiddlewares()[0]) + + reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n")) + hs.SetRequestHeader(reqH) +} + +func TestApplyMiddlewares(t *testing.T) { + addr := &mockAddr{addr: "1.2.3.4:1234"} + + tests := []struct { + name string + setup func(HTTP) + apply func(HTTP, header.RequestHeader, header.ResponseHeader) error + verify func(*testing.T, header.RequestHeader, header.ResponseHeader) + expectErr bool + }{ + { + name: "apply request middleware success", + setup: func(hs HTTP) { + hs.UseRequestMiddleware(&mockRequestMiddleware{}) + }, + apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { + return hs.ApplyRequestMiddlewares(reqH) + }, + verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) { + assert.Equal(t, "true", reqH.Value("X-Middleware")) + }, + }, + { + name: "apply response middleware success", + setup: func(hs HTTP) { + hs.UseResponseMiddleware(&mockResponseMiddleware{}) + }, + apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { + return hs.ApplyResponseMiddlewares(respH, []byte("body")) + }, + verify: func(t *testing.T, reqH header.RequestHeader, respH header.ResponseHeader) { + assert.Equal(t, "true", respH.Value("X-Resp-Middleware")) + }, + }, + { + name: "apply request middleware error", + setup: func(hs HTTP) { + hs.UseRequestMiddleware(&mockRequestMiddleware{err: assert.AnError}) + }, + apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { + return hs.ApplyRequestMiddlewares(reqH) + }, + expectErr: true, + }, + { + name: "apply response middleware error", + setup: func(hs HTTP) { + hs.UseResponseMiddleware(&mockResponseMiddleware{err: assert.AnError}) + }, + apply: func(hs HTTP, reqH header.RequestHeader, respH header.ResponseHeader) error { + return hs.ApplyResponseMiddlewares(respH, []byte("body")) + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reqH, _ := header.NewRequest([]byte("GET / HTTP/1.1\r\n\r\n")) + respH, _ := header.NewResponse([]byte("HTTP/1.1 200 OK\r\n\r\n")) + rw := &mockReadWriter{} + hs := New(rw, rw, addr) + tt.setup(hs) + err := tt.apply(hs, reqH, respH) + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if tt.verify != nil { + tt.verify(t, reqH, respH) + } + } + }) + } +} + +type mockWriterOnly struct { + bytes.Buffer +} + +func TestCloseMethods(t *testing.T) { + addr := &mockAddr{addr: "1.2.3.4:1234"} + + tests := []struct { + name string + writer any + op func(HTTP) error + verify func(*testing.T, any) + }{ + { + name: "Close success", + writer: &mockReadWriter{}, + op: func(hs HTTP) error { return hs.Close() }, + verify: func(t *testing.T, w any) { + assert.True(t, w.(*mockReadWriter).closed) + }, + }, + { + name: "CloseWrite with CloseWrite implementation", + writer: &mockReadWriter{}, + op: func(hs HTTP) error { return hs.CloseWrite() }, + verify: func(t *testing.T, w any) { + assert.True(t, w.(*mockReadWriter).writeClosed) + }, + }, + { + name: "CloseWrite fallback to Close", + writer: &mockReadWriterOnlyCloser{}, + op: func(hs HTTP) error { return hs.CloseWrite() }, + verify: func(t *testing.T, w any) { + assert.True(t, w.(*mockReadWriterOnlyCloser).closed) + }, + }, + { + name: "Close with No Closer", + writer: &mockWriterOnly{}, + op: func(hs HTTP) error { return hs.Close() }, + }, + { + name: "CloseWrite with No CloseWrite and No Closer", + writer: &mockWriterOnly{}, + op: func(hs HTTP) error { return hs.CloseWrite() }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hs := New(tt.writer.(io.Writer), tt.writer.(io.Reader), addr) + assert.NotPanics(t, func() { + err := tt.op(hs) + assert.NoError(t, err) + }) + if tt.verify != nil { + tt.verify(t, tt.writer) + } + }) + } +} + +type mockReadWriterOnlyCloser struct { + bytes.Buffer + closed bool +} + +func (m *mockReadWriterOnlyCloser) Close() error { + m.closed = true + return nil +} + +func TestSplitHeaderAndBody(t *testing.T) { + tests := []struct { + name string + data []byte + delimiterIdx int + expectHeader []byte + expectBody []byte + }{ + { + name: "standard", + data: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\nBodyContent"), + delimiterIdx: 31, + expectHeader: []byte("GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"), + expectBody: []byte("BodyContent"), + }, + { + name: "empty body", + data: []byte("HTTP/1.1 200 OK\r\n\r\n"), + delimiterIdx: 15, + expectHeader: []byte("HTTP/1.1 200 OK\r\n\r\n"), + expectBody: []byte(""), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, b := splitHeaderAndBody(tt.data, tt.delimiterIdx) + assert.Equal(t, tt.expectHeader, h) + assert.Equal(t, tt.expectBody, b) + }) + } +} + +func TestIsHTTPHeader(t *testing.T) { + tests := []struct { + name string + buf []byte + expect bool + }{ + { + name: "valid request", + buf: []byte("GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n"), + expect: true, + }, + { + name: "valid response", + buf: []byte("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\n\r\n"), + expect: true, + }, + { + name: "invalid start line", + buf: []byte("NOT_HTTP /path\r\nHost: example.com\r\n\r\n"), + expect: false, + }, + { + name: "invalid header line (no colon)", + buf: []byte("GET / HTTP/1.1\r\nInvalidHeaderLine\r\n\r\n"), + expect: false, + }, + { + name: "invalid header line (colon at 0)", + buf: []byte("GET / HTTP/1.1\r\n: value\r\n\r\n"), + expect: false, + }, + { + name: "empty header section", + buf: []byte("GET / HTTP/1.1\r\n\r\n"), + expect: true, + }, + { + name: "multiple headers", + buf: []byte("GET / HTTP/1.1\r\nH1: V1\r\nH2: V2\r\n\r\n"), + expect: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isHTTPHeader(tt.buf) + assert.Equal(t, tt.expect, result) + }) + } +} + +func TestRead(t *testing.T) { + tests := []struct { + name string + input []byte + readLen int + expectContent string + expectRead int + expectErr bool + middlewareErr error + }{ + { + name: "valid http request", + input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\nBody"), + readLen: 100, + expectContent: "Body", + expectRead: 54, + }, + { + name: "non-http data", + input: []byte("Some random data\r\n\r\nMore data"), + readLen: 100, + expectContent: "Some random data\r\n\r\nMore data", + expectRead: 29, + }, + { + name: "no delimiter", + input: []byte("Partial data without delimiter"), + readLen: 100, + expectContent: "Partial data without delimiter", + expectRead: 30, + }, + { + name: "middleware error", + input: []byte("GET / HTTP/1.1\r\nHost: test\r\n\r\n"), + readLen: 100, + middlewareErr: assert.AnError, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rw := &mockReadWriter{} + rw.Write(tt.input) + hs := New(rw, rw, &mockAddr{}) + if tt.middlewareErr != nil { + hs.UseRequestMiddleware(&mockRequestMiddleware{err: tt.middlewareErr}) + } else { + hs.UseRequestMiddleware(&mockRequestMiddleware{}) + } + + p := make([]byte, tt.readLen) + n, err := hs.Read(p) + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectRead, n) + if tt.name == "valid http request" { + content := string(p[:n]) + assert.Contains(t, content, "GET / HTTP/1.1\r\n") + assert.Contains(t, content, "Host: test\r\n") + assert.Contains(t, content, "X-Middleware: true\r\n") + assert.True(t, bytes.HasSuffix(p[:n], []byte("\r\n\r\nBody"))) + } else { + assert.Equal(t, tt.expectContent, string(p[:n])) + } + } + }) + } +} + +func TestWrite(t *testing.T) { + tests := []struct { + name string + writes [][]byte + expectWritten string + expectErr bool + middlewareErr error + }{ + { + name: "valid http response in one write", + writes: [][]byte{ + []byte("HTTP/1.1 200 OK\r\nContent-Length: 4\r\n\r\nBody"), + }, + expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody", + }, + { + name: "valid http response in multiple writes", + writes: [][]byte{ + []byte("HTTP/1.1 200 OK\r\n"), + []byte("Content-Length: 4\r\n\r\n"), + []byte("Body"), + }, + expectWritten: "HTTP/1.1 200 OK\r\nContent-Length: 4\r\nX-Resp-Middleware: true\r\n\r\nBody", + }, + { + name: "non-http data", + writes: [][]byte{ + []byte("Random data with delimiter\r\n\r\nFlush"), + }, + expectWritten: "Random data with delimiter\r\n\r\nFlush", + }, + { + name: "bypass buffering", + writes: [][]byte{ + []byte("HTTP/1.1 200 OK\r\n\r\n"), + []byte("HTTP/1.1 200 OK\r\n\r\n"), + }, + expectWritten: "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n" + + "HTTP/1.1 200 OK\r\nX-Resp-Middleware: true\r\n\r\n", + }, + { + name: "middleware error", + writes: [][]byte{ + []byte("HTTP/1.1 200 OK\r\n\r\n"), + }, + middlewareErr: assert.AnError, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rw := &mockReadWriter{} + hs := New(rw, rw, &mockAddr{}) + if tt.middlewareErr != nil { + hs.UseResponseMiddleware(&mockResponseMiddleware{err: tt.middlewareErr}) + } else { + hs.UseResponseMiddleware(&mockResponseMiddleware{}) + } + + var totalN int + var err error + for _, w := range tt.writes { + var n int + n, err = hs.Write(w) + if err != nil { + break + } + totalN += n + } + + if tt.expectErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + if strings.HasPrefix(tt.expectWritten, "HTTP/") { + written := rw.String() + assert.Contains(t, written, "HTTP/1.1 200 OK\r\n") + assert.Contains(t, written, "X-Resp-Middleware: true\r\n") + if strings.Contains(tt.expectWritten, "Content-Length: 4") { + assert.Contains(t, written, "Content-Length: 4\r\n") + } + assert.True(t, strings.HasSuffix(written, "\r\n\r\nBody") || strings.HasSuffix(written, "\r\n\r\n")) + } else { + assert.Equal(t, tt.expectWritten, rw.String()) + } + } + }) + } +} + +func TestWriteErrors(t *testing.T) { + addr := &mockAddr{addr: "1.2.3.4:1234"} + + tests := []struct { + name string + writer any + data []byte + }{ + { + name: "write error in writeHeaderAndBody", + writer: &mockErrorWriteCloser{}, + data: []byte("HTTP/1.1 200 OK\r\n\r\n"), + }, + { + name: "write error in writeHeaderAndBody second write", + writer: &mockFailSecondWriteCloser{}, + data: []byte("HTTP/1.1 200 OK\r\n\r\nBody"), + }, + { + name: "write error in writeRawBuffer", + writer: &mockErrorWriteCloser{}, + data: []byte("Not HTTP\r\n\r\nFlush"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hs := New(tt.writer.(io.Writer), tt.writer.(io.Reader), addr) + _, err := hs.Write(tt.data) + assert.Error(t, err) + }) + } +} + +func TestReadEOF(t *testing.T) { + tests := []struct { + name string + reader io.Reader + expectN int + expectErr error + expectContent string + }{ + { + name: "read eof", + reader: &mockEOFReader{}, + expectN: 4, + expectErr: io.EOF, + expectContent: "data", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hs := New(nil, tt.reader, &mockAddr{}) + p := make([]byte, 100) + n, err := hs.Read(p) + assert.Equal(t, tt.expectN, n) + assert.Equal(t, tt.expectErr, err) + assert.Equal(t, tt.expectContent, string(p[:n])) + }) + } +} + +type mockEOFReader struct { + mockReadWriter +} + +func (m *mockEOFReader) Read(p []byte) (int, error) { + copy(p, "data") + return 4, io.EOF +} + +type mockFailSecondWriteCloser struct { + count int +} + +func (m *mockFailSecondWriteCloser) Write(p []byte) (int, error) { + m.count++ + if m.count == 2 { + return 0, assert.AnError + } + return len(p), nil +} + +func (m *mockFailSecondWriteCloser) Close() error { return nil } +func (m *mockFailSecondWriteCloser) Read(p []byte) (int, error) { return 0, nil } + +type mockErrorWriteCloser struct { + closed bool +} + +func (m *mockErrorWriteCloser) Write(p []byte) (int, error) { + return 0, assert.AnError +} + +func (m *mockErrorWriteCloser) Close() error { + m.closed = true + return nil +} + +func (m *mockErrorWriteCloser) Read(p []byte) (int, error) { + return 0, nil +} diff --git a/internal/random/random_test.go b/internal/random/random_test.go index 8c6787e..057487b 100644 --- a/internal/random/random_test.go +++ b/internal/random/random_test.go @@ -3,6 +3,7 @@ package random import ( "errors" "fmt" + "io" "testing" ) @@ -44,18 +45,32 @@ func TestRandom_String(t *testing.T) { } func TestRandomWithFailingReader_String(t *testing.T) { - var randomizer Random - var errBrainrot = fmt.Errorf("you are not sigma enough") - randomizer = &random{reader: &brainrotReader{err: errBrainrot}} - t.Run("test failing reader", func(t *testing.T) { - result, err := randomizer.String(20) - if !errors.Is(err, errBrainrot) { - t.Errorf("String() error = %v, wantErr %v", err, errBrainrot) - return - } + errBrainrot := fmt.Errorf("you are not sigma enough") - if result != "" { - t.Errorf("String() result = %v, want an empty string due to error", result) - } - }) + tests := []struct { + name string + reader io.Reader + expectErr error + }{ + { + name: "failing reader", + reader: &brainrotReader{err: errBrainrot}, + expectErr: errBrainrot, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + randomizer := &random{reader: tt.reader} + result, err := randomizer.String(20) + if !errors.Is(err, tt.expectErr) { + t.Errorf("String() error = %v, wantErr %v", err, tt.expectErr) + return + } + + if result != "" { + t.Errorf("String() result = %v, want an empty string due to error", result) + } + }) + } } diff --git a/internal/registry/registry_test.go b/internal/registry/registry_test.go index d4734ed..9b3d026 100644 --- a/internal/registry/registry_test.go +++ b/internal/registry/registry_test.go @@ -468,89 +468,108 @@ func TestRegistry_Register(t *testing.T) { } func TestRegistry_GetAllSessionFromUser(t *testing.T) { - t.Run("user has no sessions", func(t *testing.T) { - r := ®istry{ - byUser: make(map[string]map[Key]Session), - slugIndex: make(map[Key]string), - } - sessions := r.GetAllSessionFromUser("user1") - if len(sessions) != 0 { - t.Errorf("expected 0 sessions, got %d", len(sessions)) - } - }) + tests := []struct { + name string + setupFunc func(r *registry) string + expectN int + }{ + { + name: "user has no sessions", + setupFunc: func(r *registry) string { + return "user1" + }, + expectN: 0, + }, + { + name: "user has multiple sessions", + setupFunc: func(r *registry) string { + user := "user1" + key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP} + key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP} + r.mu.Lock() + r.byUser[user] = map[Key]Session{ + key1: &mockSession{user: user}, + key2: &mockSession{user: user}, + } + r.mu.Unlock() + return user + }, + expectN: 2, + }, + } - t.Run("user has multiple sessions", func(t *testing.T) { - r := ®istry{ - byUser: make(map[string]map[Key]Session), - slugIndex: make(map[Key]string), - mu: sync.RWMutex{}, - } - - user := "user1" - key1 := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP} - key2 := types.SessionKey{Id: "b", Type: types.TunnelTypeTCP} - session1 := &mockSession{user: user} - session2 := &mockSession{user: user} - - r.mu.Lock() - r.byUser[user] = map[Key]Session{ - key1: session1, - key2: session2, - } - r.mu.Unlock() - - sessions := r.GetAllSessionFromUser(user) - if len(sessions) != 2 { - t.Errorf("expected 2 sessions, got %d", len(sessions)) - } - - found := map[Session]bool{} - for _, s := range sessions { - found[s] = true - } - if !found[session1] || !found[session2] { - t.Errorf("returned sessions do not match expected") - } - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := ®istry{ + byUser: make(map[string]map[Key]Session), + slugIndex: make(map[Key]string), + mu: sync.RWMutex{}, + } + user := tt.setupFunc(r) + sessions := r.GetAllSessionFromUser(user) + if len(sessions) != tt.expectN { + t.Errorf("expected %d sessions, got %d", tt.expectN, len(sessions)) + } + }) + } } func TestRegistry_Remove(t *testing.T) { - t.Run("remove existing key", func(t *testing.T) { - r := ®istry{ - byUser: make(map[string]map[Key]Session), - slugIndex: make(map[Key]string), - mu: sync.RWMutex{}, - } + tests := []struct { + name string + setupFunc func(r *registry) (string, types.SessionKey) + key types.SessionKey + verify func(*testing.T, *registry, string, types.SessionKey) + }{ + { + name: "remove existing key", + setupFunc: func(r *registry) (string, types.SessionKey) { + user := "user1" + key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP} + session := &mockSession{user: user} + r.mu.Lock() + r.byUser[user] = map[Key]Session{key: session} + r.slugIndex[key] = user + r.mu.Unlock() + return user, key + }, + verify: func(t *testing.T, r *registry, user string, key types.SessionKey) { + if _, ok := r.byUser[user][key]; ok { + t.Errorf("expected key to be removed from byUser") + } + if _, ok := r.slugIndex[key]; ok { + t.Errorf("expected key to be removed from slugIndex") + } + if _, ok := r.byUser[user]; ok { + t.Errorf("expected user to be removed from byUser map") + } + }, + }, + { + name: "remove non-existing key", + setupFunc: func(r *registry) (string, types.SessionKey) { + return "", types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP} + }, + }, + } - user := "user1" - key := types.SessionKey{Id: "a", Type: types.TunnelTypeHTTP} - session := &mockSession{user: user} - - r.mu.Lock() - r.byUser[user] = map[Key]Session{key: session} - r.slugIndex[key] = user - r.mu.Unlock() - - r.Remove(key) - - if _, ok := r.byUser[user][key]; ok { - t.Errorf("expected key to be removed from byUser") - } - if _, ok := r.slugIndex[key]; ok { - t.Errorf("expected key to be removed from slugIndex") - } - if _, ok := r.byUser[user]; ok { - t.Errorf("expected user to be removed from byUser map") - } - }) - - t.Run("remove non-existing key", func(t *testing.T) { - r := ®istry{ - byUser: make(map[string]map[Key]Session), - slugIndex: make(map[Key]string), - } - r.Remove(types.SessionKey{Id: "nonexist", Type: types.TunnelTypeHTTP}) - }) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := ®istry{ + byUser: make(map[string]map[Key]Session), + slugIndex: make(map[Key]string), + mu: sync.RWMutex{}, + } + user, key := tt.setupFunc(r) + if user == "" { + key = tt.key + } + r.Remove(key) + if tt.verify != nil { + tt.verify(t, r, user, key) + } + }) + } } func TestIsValidSlug(t *testing.T) { diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go index 800d2c8..2a43f95 100644 --- a/internal/transport/httphandler.go +++ b/internal/transport/httphandler.go @@ -175,7 +175,11 @@ func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry. go func() { channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) - resultChan <- channelResult{channel, reqs, err} + select { + case resultChan <- channelResult{channel, reqs, err}: + default: + hh.cleanupUnusedChannel(channel, reqs) + } }() select { diff --git a/server/server.go b/server/server.go index a1990b4..b6a66c1 100644 --- a/server/server.go +++ b/server/server.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "log" "net" "time" @@ -85,7 +86,7 @@ func (s *server) handleConnection(conn net.Conn) { defer func(sshConn *ssh.ServerConn) { err = sshConn.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { + if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.EOF) { log.Printf("failed to close SSH server: %v", err) } }(sshConn) @@ -101,7 +102,7 @@ func (s *server) handleConnection(conn net.Conn) { sshSession := session.New(s.randomizer, s.config, sshConn, forwardingReqs, chans, s.sessionRegistry, s.portRegistry, user) err = sshSession.Start() if err != nil { - log.Printf("SSH session ended with error: %v", err) + log.Printf("SSH session ended with error: %s", err.Error()) return } return diff --git a/session/interaction/interaction.go b/session/interaction/interaction.go index fe5b496..0e6570b 100644 --- a/session/interaction/interaction.go +++ b/session/interaction/interaction.go @@ -256,8 +256,6 @@ func (i *interaction) Start() { i.program.Kill() i.program = nil if i.closeFunc != nil { - if err := i.closeFunc(); err != nil { - log.Printf("Cannot close session: %s \n", err) - } + _ = i.closeFunc() } } diff --git a/session/lifecycle/lifecycle.go b/session/lifecycle/lifecycle.go index f9f9d6e..a775c22 100644 --- a/session/lifecycle/lifecycle.go +++ b/session/lifecycle/lifecycle.go @@ -2,6 +2,9 @@ package lifecycle import ( "errors" + "io" + "net" + "sync" "time" portUtil "tunnel_pls/internal/port" @@ -22,7 +25,9 @@ type SessionRegistry interface { } type lifecycle struct { + mu sync.Mutex status types.SessionStatus + closeErr error conn ssh.Conn channel ssh.Channel forwarder Forwarder @@ -73,29 +78,41 @@ func (l *lifecycle) Connection() ssh.Conn { return l.conn } func (l *lifecycle) SetStatus(status types.SessionStatus) { + l.mu.Lock() + defer l.mu.Unlock() l.status = status if status == types.SessionStatusRUNNING && l.startedAt.IsZero() { l.startedAt = time.Now() } } -func closeIfNotNil(c interface{ Close() error }) error { - if c != nil { - return c.Close() - } - return nil +func (l *lifecycle) IsActive() bool { + l.mu.Lock() + defer l.mu.Unlock() + return l.status == types.SessionStatusRUNNING } func (l *lifecycle) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + if l.status == types.SessionStatusCLOSED { + return l.closeErr + } + l.status = types.SessionStatusCLOSED + var errs []error tunnelType := l.forwarder.TunnelType() - if err := closeIfNotNil(l.channel); err != nil { - errs = append(errs, err) + if l.channel != nil { + if err := l.channel.Close(); err != nil && !isClosedError(err) { + errs = append(errs, err) + } } - if err := closeIfNotNil(l.conn); err != nil { - errs = append(errs, err) + if l.conn != nil { + if err := l.conn.Close(); err != nil && !isClosedError(err) { + errs = append(errs, err) + } } clientSlug := l.slug.String() @@ -114,11 +131,15 @@ func (l *lifecycle) Close() error { } } - return errors.Join(errs...) + l.closeErr = errors.Join(errs...) + return l.closeErr } -func (l *lifecycle) IsActive() bool { - return l.status == types.SessionStatusRUNNING +func isClosedError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || err.Error() == "EOF" } func (l *lifecycle) StartedAt() time.Time { diff --git a/types/types.go b/types/types.go index 34ccfb4..77d6ac4 100644 --- a/types/types.go +++ b/types/types.go @@ -7,6 +7,7 @@ type SessionStatus int const ( SessionStatusINITIALIZING SessionStatus = iota SessionStatusRUNNING + SessionStatusCLOSED ) type InteractiveMode int