From b0249c45ae3e77ce353a6a0de9786fc3e0f67e17 Mon Sep 17 00:00:00 2001 From: bagas Date: Thu, 22 Jan 2026 19:22:35 +0700 Subject: [PATCH] test(transport): add unit tests for transport behavior using Testify --- go.mod | 5 + go.sum | 10 + internal/transport/http.go | 2 +- internal/transport/http_test.go | 112 ++++ internal/transport/httphandler.go | 37 +- internal/transport/httphandler_test.go | 755 +++++++++++++++++++++++++ internal/transport/https.go | 2 +- internal/transport/https_test.go | 100 ++++ internal/transport/tcp.go | 6 +- internal/transport/tcp_test.go | 146 +++++ internal/transport/tls_test.go | 178 ++++++ internal/transport/transport.go | 4 + internal/transport/transport_test.go | 15 + 13 files changed, 1346 insertions(+), 26 deletions(-) create mode 100644 internal/transport/http_test.go create mode 100644 internal/transport/httphandler_test.go create mode 100644 internal/transport/https_test.go create mode 100644 internal/transport/tcp_test.go create mode 100644 internal/transport/tls_test.go create mode 100644 internal/transport/transport_test.go diff --git a/go.mod b/go.mod index 958657e..214f1c5 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/joho/godotenv v1.5.1 github.com/libdns/cloudflare v0.2.2 github.com/muesli/termenv v0.16.0 + github.com/stretchr/testify v1.8.1 golang.org/x/crypto v0.47.0 google.golang.org/grpc v1.78.0 google.golang.org/protobuf v1.36.11 @@ -27,6 +28,7 @@ require ( github.com/clipperhouse/displaywidth v0.6.2 // indirect github.com/clipperhouse/stringish v0.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/libdns/libdns v1.1.1 // indirect @@ -38,8 +40,10 @@ require ( github.com/miekg/dns v1.1.69 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/zeebo/blake3 v0.2.4 // indirect go.uber.org/multierr v1.11.0 // indirect @@ -52,4 +56,5 @@ require ( golang.org/x/text v0.33.0 // indirect golang.org/x/tools v0.40.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 11912af..4356e9d 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,7 @@ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfa github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= @@ -80,6 +81,12 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA= github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= @@ -138,5 +145,8 @@ google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/transport/http.go b/internal/transport/http.go index dd091c3..8ea5c4d 100644 --- a/internal/transport/http.go +++ b/internal/transport/http.go @@ -35,6 +35,6 @@ func (ht *httpServer) Serve(listener net.Listener) error { continue } - go ht.handler.handler(conn, false) + go ht.handler.Handler(conn, false) } } diff --git a/internal/transport/http_test.go b/internal/transport/http_test.go new file mode 100644 index 0000000..348dbb0 --- /dev/null +++ b/internal/transport/http_test.go @@ -0,0 +1,112 @@ +package transport + +import ( + "errors" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestNewHTTPServer(t *testing.T) { + msr := new(MockSessionRegistry) + domain := "example.com" + port := "8080" + redirectTLS := true + + srv := NewHTTPServer(domain, port, msr, redirectTLS) + assert.NotNil(t, srv) + + httpSrv, ok := srv.(*httpServer) + assert.True(t, ok) + assert.Equal(t, port, httpSrv.port) + assert.Equal(t, domain, httpSrv.handler.domain) + assert.Equal(t, msr, httpSrv.handler.sessionRegistry) + assert.Equal(t, redirectTLS, httpSrv.handler.redirectTLS) +} + +func TestHTTPServer_Listen(t *testing.T) { + msr := new(MockSessionRegistry) + srv := NewHTTPServer("example.com", "0", msr, false) + + listener, err := srv.Listen() + assert.NoError(t, err) + assert.NotNil(t, listener) + listener.Close() +} + +func TestHTTPServer_Serve(t *testing.T) { + msr := new(MockSessionRegistry) + srv := NewHTTPServer("example.com", "0", msr, false) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + + go func() { + time.Sleep(100 * time.Millisecond) + listener.Close() + }() + + err = srv.Serve(listener) + assert.True(t, errors.Is(err, net.ErrClosed)) +} + +func TestHTTPServer_Serve_AcceptError(t *testing.T) { + msr := new(MockSessionRegistry) + srv := NewHTTPServer("example.com", "0", msr, false) + + ml := new(mockListener) + ml.On("Accept").Return(nil, errors.New("accept error")).Once() + ml.On("Accept").Return(nil, net.ErrClosed).Once() + + err := srv.Serve(ml) + assert.True(t, errors.Is(err, net.ErrClosed)) + ml.AssertExpectations(t) +} + +func TestHTTPServer_Serve_Success(t *testing.T) { + msr := new(MockSessionRegistry) + srv := NewHTTPServer("example.com", "0", msr, false) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + + go func() { + _ = srv.Serve(listener) + }() + + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + assert.NoError(t, err) + + _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n")) + + time.Sleep(100 * time.Millisecond) + conn.Close() + listener.Close() +} + +type mockListener struct { + mock.Mock +} + +func (m *mockListener) Accept() (net.Conn, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(net.Conn), args.Error(1) +} + +func (m *mockListener) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *mockListener) Addr() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go index 8bab4a0..800d2c8 100644 --- a/internal/transport/httphandler.go +++ b/internal/transport/httphandler.go @@ -19,6 +19,8 @@ import ( "golang.org/x/crypto/ssh" ) +var openChannelTimeout = 5 * time.Second + type httpHandler struct { domain string sessionRegistry registry.Registry @@ -52,7 +54,7 @@ func (hh *httpHandler) badRequest(conn net.Conn) error { return nil } -func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { +func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) { defer hh.closeConnection(conn) dstReader := bufio.NewReader(conn) @@ -69,7 +71,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { } if hh.shouldRedirectToTLS(isTLS) { - _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain)) + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.domain)) return } @@ -77,7 +79,10 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) { return } - sshSession, err := hh.getSession(slug) + sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ + Id: slug, + Type: types.TunnelTypeHTTP, + }) if err != nil { _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug)) return @@ -102,7 +107,7 @@ func (hh *httpHandler) closeConnection(conn net.Conn) { func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) { host := strings.Split(reqhf.Value("Host"), ".") - if len(host) < 1 { + if len(host) <= 1 { return "", errors.New("invalid host") } return host[0], nil @@ -128,21 +133,11 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { )) if err != nil { log.Println("Failed to write 200 OK:", err) + return true } return true } -func (hh *httpHandler) getSession(slug string) (registry.Session, error) { - sshSession, err := hh.sessionRegistry.Get(types.SessionKey{ - Id: slug, - Type: types.TunnelTypeHTTP, - }) - if err != nil { - return nil, err - } - return sshSession, nil -} - func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) { channel, err := hh.openForwardedChannel(hw, sshSession) if err != nil { @@ -180,11 +175,7 @@ func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry. go func() { channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload) - select { - case resultChan <- channelResult{channel, reqs, err}: - default: - hh.cleanupUnusedChannel(channel, reqs) - } + resultChan <- channelResult{channel, reqs, err} }() select { @@ -194,7 +185,11 @@ func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry. } go ssh.DiscardRequests(result.reqs) return result.channel, nil - case <-time.After(5 * time.Second): + case <-time.After(openChannelTimeout): + go func() { + result := <-resultChan + hh.cleanupUnusedChannel(result.channel, result.reqs) + }() return nil, errors.New("timeout opening forwarded-tcpip channel") } } diff --git a/internal/transport/httphandler_test.go b/internal/transport/httphandler_test.go new file mode 100644 index 0000000..3159366 --- /dev/null +++ b/internal/transport/httphandler_test.go @@ -0,0 +1,755 @@ +package transport + +import ( + "bytes" + "fmt" + "io" + "net" + "strings" + "sync" + "testing" + "time" + "tunnel_pls/internal/port" + "tunnel_pls/internal/registry" + "tunnel_pls/session/forwarder" + "tunnel_pls/session/interaction" + "tunnel_pls/session/lifecycle" + "tunnel_pls/session/slug" + "tunnel_pls/types" + + "golang.org/x/crypto/ssh" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockSessionRegistry struct { + mock.Mock +} + +func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) { + args := m.Called(key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} + +func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) { + args := m.Called(user, key) + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).(registry.Session), args.Error(1) +} + +func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error { + args := m.Called(user, oldKey, newKey) + return args.Error(0) +} + +func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool { + args := m.Called(key, session) + return args.Bool(0) +} + +func (m *MockSessionRegistry) Remove(key registry.Key) { + m.Called(key) +} + +func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session { + args := m.Called(user) + return args.Get(0).([]registry.Session) +} + +func (m *MockSessionRegistry) Slug() slug.Slug { + args := m.Called() + return args.Get(0).(slug.Slug) +} + +type MockSession struct { + mock.Mock +} + +func (m *MockSession) Lifecycle() lifecycle.Lifecycle { + args := m.Called() + return args.Get(0).(lifecycle.Lifecycle) +} + +func (m *MockSession) Interaction() interaction.Interaction { + args := m.Called() + return args.Get(0).(interaction.Interaction) +} + +func (m *MockSession) Forwarder() forwarder.Forwarder { + args := m.Called() + return args.Get(0).(forwarder.Forwarder) +} + +func (m *MockSession) Slug() slug.Slug { + args := m.Called() + return args.Get(0).(slug.Slug) +} + +func (m *MockSession) Detail() *types.Detail { + args := m.Called() + return args.Get(0).(*types.Detail) +} + +type MockLifecycle struct { + mock.Mock +} + +func (m *MockLifecycle) Connection() ssh.Conn { + args := m.Called() + return args.Get(0).(ssh.Conn) +} + +func (m *MockLifecycle) PortRegistry() port.Port { + args := m.Called() + return args.Get(0).(port.Port) +} + +func (m *MockLifecycle) User() string { + args := m.Called() + return args.String(0) +} + +func (m *MockLifecycle) SetChannel(channel ssh.Channel) { + m.Called(channel) +} + +func (m *MockLifecycle) SetStatus(status types.SessionStatus) { + m.Called(status) +} + +func (m *MockLifecycle) IsActive() bool { + args := m.Called() + return args.Bool(0) +} + +func (m *MockLifecycle) StartedAt() time.Time { + args := m.Called() + return args.Get(0).(time.Time) +} + +func (m *MockLifecycle) Close() error { + args := m.Called() + return args.Error(0) +} + +type MockSSHConn struct { + ssh.Conn + mock.Mock +} + +func (m *MockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(name, data) + if args.Get(0) == nil { + return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) + } + return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) +} + +type MockSSHChannel struct { + ssh.Channel + mock.Mock +} + +func (m *MockSSHChannel) Write(data []byte) (int, error) { + args := m.Called(data) + return args.Int(0), args.Error(1) +} + +func (m *MockSSHChannel) Close() error { + args := m.Called() + return args.Error(0) +} + +type MockForwarder struct { + mock.Mock +} + +func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte { + args := m.Called(origin) + return args.Get(0).([]byte) +} + +func (m *MockForwarder) WriteBadGatewayResponse(dst io.Writer) { + m.Called(dst) +} + +func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) { + m.Called(dst, src) +} + +func (m *MockForwarder) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockForwarder) TunnelType() types.TunnelType { + args := m.Called() + return args.Get(0).(types.TunnelType) +} + +func (m *MockForwarder) ForwardedPort() uint16 { + args := m.Called() + return uint16(args.Int(0)) +} + +func (m *MockForwarder) SetType(tunnelType types.TunnelType) { + m.Called(tunnelType) +} + +func (m *MockForwarder) SetForwardedPort(port uint16) { + m.Called(port) +} + +func (m *MockForwarder) SetListener(listener net.Listener) { + m.Called(listener) +} + +func (m *MockForwarder) Listener() net.Listener { + args := m.Called() + return args.Get(0).(net.Listener) +} + +func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) { + args := m.Called(payload) + if args.Get(0) == nil { + return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2) + } + return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2) +} + +type MockConn struct { + net.Conn + mock.Mock + ReadBuffer *bytes.Buffer +} + +func (m *MockConn) Read(b []byte) (n int, err error) { + if m.ReadBuffer != nil { + return m.ReadBuffer.Read(b) + } + args := m.Called(b) + return args.Int(0), args.Error(1) +} + +func (m *MockConn) Write(b []byte) (n int, err error) { + args := m.Called(b) + return args.Int(0), args.Error(1) +} + +func (m *MockConn) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockConn) RemoteAddr() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} + +type wrappedConn struct { + net.Conn + remoteAddr net.Addr +} + +func (c *wrappedConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func TestNewHTTPHandler(t *testing.T) { + msr := new(MockSessionRegistry) + hh := newHTTPHandler("domain", msr, true) + assert.NotNil(t, hh) + assert.Equal(t, "domain", hh.domain) + assert.Equal(t, msr, hh.sessionRegistry) + assert.True(t, hh.redirectTLS) +} + +func TestHandler(t *testing.T) { + tests := []struct { + name string + isTLS bool + redirectTLS bool + request []byte + expected []byte + setupMocks func(*MockSessionRegistry) + setupConn func() (net.Conn, net.Conn) + expectError bool + }{ + { + name: "bad request - invalid host", + isTLS: false, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"), + expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "bad request - missing host", + isTLS: false, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\n\r\n"), + expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "isTLS true and redirectTLS true - no redirect", + isTLS: true, + redirectTLS: true, + request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "redirect to TLS", + isTLS: false, + redirectTLS: true, + request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://example.domain/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "handle ping request", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "session not found", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + msr.On("Get", types.SessionKey{ + Id: "test", + Type: types.TunnelTypeHTTP, + }).Return((registry.Session)(nil), fmt.Errorf("session not found")) + }, + }, + { + name: "bad request - invalid http", + isTLS: false, + redirectTLS: false, + request: []byte("INVALID\r\n\r\n"), + expected: []byte(""), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "forwarding - open channel fails", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockLifecycle := new(MockLifecycle) + mockSSHConn := new(MockSSHConn) + + msr.On("Get", types.SessionKey{ + Id: "test", + Type: types.TunnelTypeHTTP, + }).Return(mockSession, nil) + + mockSession.On("Forwarder").Return(mockForwarder) + mockSession.On("Lifecycle").Return(mockLifecycle) + mockLifecycle.On("Connection").Return(mockSSHConn) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed")) + mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) { + w := args.Get(0).(io.Writer) + _, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) + }) + }, + }, + { + name: "forwarding - send initial request fails", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte(""), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockLifecycle := new(MockLifecycle) + mockSSHConn := new(MockSSHConn) + mockSSHChannel := new(MockSSHChannel) + + msr.On("Get", types.SessionKey{ + Id: "test", + Type: types.TunnelTypeHTTP, + }).Return(mockSession, nil) + + mockSession.On("Forwarder").Return(mockForwarder) + mockSession.On("Lifecycle").Return(mockLifecycle) + mockLifecycle.On("Connection").Return(mockSSHConn) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + + reqCh := make(chan *ssh.Request) + mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + + mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) + mockSSHChannel.On("Close").Return(nil) + + go func() { + for range reqCh { + } + }() + }, + }, + { + name: "forwarding - success", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockLifecycle := new(MockLifecycle) + mockSSHConn := new(MockSSHConn) + mockSSHChannel := new(MockSSHChannel) + + msr.On("Get", types.SessionKey{ + Id: "test", + Type: types.TunnelTypeHTTP, + }).Return(mockSession, nil) + + mockSession.On("Forwarder").Return(mockForwarder) + mockSession.On("Lifecycle").Return(mockLifecycle) + mockLifecycle.On("Connection").Return(mockSSHConn) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + + reqCh := make(chan *ssh.Request) + mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + + mockSSHChannel.On("Write", mock.Anything).Return(0, nil) + mockSSHChannel.On("Close").Return(nil) + + mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) { + w := args.Get(0).(io.ReadWriter) + _, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello")) + }) + + go func() { + for range reqCh { + } + }() + }, + }, + { + name: "redirect - write failure", + isTLS: false, + redirectTLS: true, + request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"), + expected: []byte(""), + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n")) + mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) + mc.On("Close").Return(nil) + return mc, nil + }, + }, + { + name: "bad request - write failure", + isTLS: false, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\n\r\n"), + expected: []byte(""), + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n")) + mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) + mc.On("Close").Return(nil) + return mc, nil + }, + }, + { + name: "handle ping request - write failure", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"), + expected: []byte(""), + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n")) + mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) + mc.On("Close").Return(nil) + return mc, nil + }, + }, + { + name: "close connection - error", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"), + expected: []byte(""), + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n")) + mc.On("Write", mock.Anything).Return(182, nil) + mc.On("Close").Return(fmt.Errorf("close error")) + return mc, nil + }, + }, + { + name: "forwarding - stream close error", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte(""), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockLifecycle := new(MockLifecycle) + mockSSHConn := new(MockSSHConn) + mockSSHChannel := new(MockSSHChannel) + + msr.On("Get", mock.Anything).Return(mockSession, nil) + mockSession.On("Forwarder").Return(mockForwarder) + mockSession.On("Lifecycle").Return(mockLifecycle) + mockLifecycle.On("Connection").Return(mockSSHConn) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + reqCh := make(chan *ssh.Request) + mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + + mockSSHChannel.On("Write", mock.Anything).Return(0, nil) + mockSSHChannel.On("Close").Return(nil) + + mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return() + }, + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n")) + mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2) + addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + mc.On("RemoteAddr").Return(addr) + return mc, nil + }, + }, + { + name: "forwarding - middleware failure", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte(""), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool { + return k.Id == "test" + })).Return(mockSession, nil) + mockSession.On("Forwarder").Return(mockForwarder) + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + + mockLifecycle := new(MockLifecycle) + mockSession.On("Lifecycle").Return(mockLifecycle) + mockSSHConn := new(MockSSHConn) + mockLifecycle.On("Connection").Return(mockSSHConn) + mockSSHChannel := new(MockSSHChannel) + reqCh := make(chan *ssh.Request) + mockSSHConn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + mockSSHChannel.On("Close").Return(nil) + }, + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n")) + mc.On("Close").Return(nil).Times(2) + mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")}) + return mc, nil + }, + }, + { + name: "forwarding - channel close error", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"), + setupMocks: func(msr *MockSessionRegistry) { + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockLifecycle := new(MockLifecycle) + mockSSHConn := new(MockSSHConn) + mockSSHChannel := new(MockSSHChannel) + + msr.On("Get", mock.Anything).Return(mockSession, nil) + mockSession.On("Forwarder").Return(mockForwarder) + mockSession.On("Lifecycle").Return(mockLifecycle) + mockLifecycle.On("Connection").Return(mockSSHConn) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + reqCh := make(chan *ssh.Request) + mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + + mockSSHChannel.On("Write", mock.Anything).Return(0, nil) + mockSSHChannel.On("Close").Return(fmt.Errorf("close error")) + + mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) { + w := args.Get(0).(io.ReadWriter) + _, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello")) + }) + }, + }, + { + name: "forwarding - open channel timeout", + isTLS: true, + redirectTLS: false, + request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"), + expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + oldTimeout := openChannelTimeout + openChannelTimeout = 10 * time.Millisecond + t.Cleanup(func() { + openChannelTimeout = oldTimeout + time.Sleep(100 * time.Millisecond) + }) + + mockSession := new(MockSession) + mockForwarder := new(MockForwarder) + mockLifecycle := new(MockLifecycle) + mockSSHConn := new(MockSSHConn) + mockSSHChannel := new(MockSSHChannel) + + msr.On("Get", mock.Anything).Return(mockSession, nil) + mockSession.On("Forwarder").Return(mockForwarder) + mockSession.On("Lifecycle").Return(mockLifecycle) + mockLifecycle.On("Connection").Return(mockSSHConn) + + mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload")) + mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) { + w := args.Get(0).(io.Writer) + _, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) + }) + + reqCh := make(chan *ssh.Request) + mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Run(func(args mock.Arguments) { + time.Sleep(50 * time.Millisecond) + }).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil) + + mockSSHChannel.On("Close").Return(fmt.Errorf("cleanup close error")) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSessionRegistry := new(MockSessionRegistry) + hh := &httpHandler{ + domain: "domain", + sessionRegistry: mockSessionRegistry, + redirectTLS: tt.redirectTLS, + } + + if tt.setupMocks != nil { + tt.setupMocks(mockSessionRegistry) + } + + var serverConn, clientConn net.Conn + if tt.setupConn != nil { + serverConn, clientConn = tt.setupConn() + } else { + serverConn, clientConn = net.Pipe() + } + + if clientConn != nil { + defer clientConn.Close() + } + + remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") + var wrappedServerConn net.Conn + if _, ok := serverConn.(*MockConn); ok { + wrappedServerConn = serverConn + } else { + wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr} + } + + responseChan := make(chan []byte, 1) + doneChan := make(chan struct{}) + + if clientConn != nil { + go func() { + defer close(doneChan) + var res []byte + for { + buf := make([]byte, 4096) + n, err := clientConn.Read(buf) + if err != nil { + if err != io.EOF { + t.Logf("Error reading response: %v", err) + } + break + } + res = append(res, buf[:n]...) + if len(tt.expected) > 0 && len(res) >= len(tt.expected) { + break + } + } + responseChan <- res + }() + + go func() { + _, err := clientConn.Write(tt.request) + if err != nil { + t.Logf("Error writing request: %v", err) + } + }() + } else { + close(responseChan) + close(doneChan) + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + hh.Handler(wrappedServerConn, tt.isTLS) + }() + + select { + case response := <-responseChan: + if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" { + resStr := string(response) + assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n")) + assert.Contains(t, resStr, "Content-Length: 5\r\n") + assert.Contains(t, resStr, "Server: Tunnel Please\r\n") + assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello")) + } else { + assert.Equal(t, string(tt.expected), string(response)) + } + case <-time.After(2 * time.Second): + if clientConn != nil { + t.Fatal("Test timeout - no response received") + } + } + + wg.Wait() + if clientConn != nil { + <-doneChan + } + + mockSessionRegistry.AssertExpectations(t) + if mc, ok := serverConn.(*MockConn); ok { + mc.AssertExpectations(t) + } + }) + } +} diff --git a/internal/transport/https.go b/internal/transport/https.go index 88ffe27..10be94f 100644 --- a/internal/transport/https.go +++ b/internal/transport/https.go @@ -40,6 +40,6 @@ func (ht *https) Serve(listener net.Listener) error { continue } - go ht.httpHandler.handler(conn, true) + go ht.httpHandler.Handler(conn, true) } } diff --git a/internal/transport/https_test.go b/internal/transport/https_test.go new file mode 100644 index 0000000..e956430 --- /dev/null +++ b/internal/transport/https_test.go @@ -0,0 +1,100 @@ +package transport + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewHTTPSServer(t *testing.T) { + msr := new(MockSessionRegistry) + domain := "example.com" + port := "443" + redirectTLS := false + tlsConfig := &tls.Config{} + + srv := NewHTTPSServer(domain, port, msr, redirectTLS, tlsConfig) + assert.NotNil(t, srv) + + httpsSrv, ok := srv.(*https) + assert.True(t, ok) + assert.Equal(t, port, httpsSrv.port) + assert.Equal(t, domain, httpsSrv.domain) + assert.Equal(t, tlsConfig, httpsSrv.tlsConfig) + assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry) +} + +func TestHTTPSServer_Listen(t *testing.T) { + msr := new(MockSessionRegistry) + + tlsConfig := &tls.Config{ + GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return nil, nil + }, + } + srv := NewHTTPSServer("example.com", "0", msr, false, tlsConfig) + + listener, err := srv.Listen() + if err != nil { + t.Skip("Skipping tls.Listen test as it requires valid certificates/setup:", err) + return + } + assert.NotNil(t, listener) + listener.Close() +} + +func TestHTTPSServer_Serve(t *testing.T) { + msr := new(MockSessionRegistry) + srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{}) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + + go func() { + time.Sleep(100 * time.Millisecond) + listener.Close() + }() + + err = srv.Serve(listener) + assert.True(t, errors.Is(err, net.ErrClosed)) +} + +func TestHTTPSServer_Serve_AcceptError(t *testing.T) { + msr := new(MockSessionRegistry) + srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{}) + + ml := new(mockListener) + ml.On("Accept").Return(nil, errors.New("accept error")).Once() + ml.On("Accept").Return(nil, net.ErrClosed).Once() + + err := srv.Serve(ml) + assert.True(t, errors.Is(err, net.ErrClosed)) + ml.AssertExpectations(t) +} + +func TestHTTPSServer_Serve_Success(t *testing.T) { + msr := new(MockSessionRegistry) + srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{}) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + + go func() { + _ = srv.Serve(listener) + }() + + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + assert.NoError(t, err) + + _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n")) + + time.Sleep(100 * time.Millisecond) + conn.Close() + listener.Close() +} diff --git a/internal/transport/tcp.go b/internal/transport/tcp.go index 99670d2..91ab0b0 100644 --- a/internal/transport/tcp.go +++ b/internal/transport/tcp.go @@ -12,16 +12,16 @@ import ( type tcp struct { port uint16 - forwarder forwarder + forwarder Forwarder } -type forwarder interface { +type Forwarder interface { CreateForwardedTCPIPPayload(origin net.Addr) []byte OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) HandleConnection(dst io.ReadWriter, src ssh.Channel) } -func NewTCPServer(port uint16, forwarder forwarder) Transport { +func NewTCPServer(port uint16, forwarder Forwarder) Transport { return &tcp{ port: port, forwarder: forwarder, diff --git a/internal/transport/tcp_test.go b/internal/transport/tcp_test.go new file mode 100644 index 0000000..d7c3f8b --- /dev/null +++ b/internal/transport/tcp_test.go @@ -0,0 +1,146 @@ +package transport + +import ( + "errors" + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "golang.org/x/crypto/ssh" +) + +func TestNewTCPServer(t *testing.T) { + mf := new(MockForwarder) + port := uint16(9000) + + srv := NewTCPServer(port, mf) + assert.NotNil(t, srv) + + tcpSrv, ok := srv.(*tcp) + assert.True(t, ok) + assert.Equal(t, port, tcpSrv.port) + assert.Equal(t, mf, tcpSrv.forwarder) +} + +func TestTCPServer_Listen(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf) + + listener, err := srv.Listen() + assert.NoError(t, err) + assert.NotNil(t, listener) + listener.Close() +} + +func TestTCPServer_Serve(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + + go func() { + time.Sleep(100 * time.Millisecond) + listener.Close() + }() + + err = srv.Serve(listener) + assert.Nil(t, err) +} + +func TestTCPServer_Serve_AcceptError(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf) + + ml := new(mockListener) + ml.On("Accept").Return(nil, errors.New("accept error")).Once() + ml.On("Accept").Return(nil, net.ErrClosed).Once() + + err := srv.Serve(ml) + assert.Nil(t, err) + ml.AssertExpectations(t) +} + +func TestTCPServer_Serve_Success(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + assert.NoError(t, err) + port := listener.Addr().(*net.TCPAddr).Port + + payload := []byte("test-payload") + mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) + reqs := make(chan *ssh.Request) + mf.On("OpenForwardedChannel", payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil) + mf.On("HandleConnection", mock.Anything, mock.Anything).Return() + + go func() { + _ = srv.Serve(listener) + }() + + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + assert.NoError(t, err) + + time.Sleep(100 * time.Millisecond) + conn.Close() + listener.Close() + + mf.AssertExpectations(t) +} + +func TestTCPServer_handleTcp_Success(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf).(*tcp) + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + payload := []byte("test-payload") + mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) + + reqs := make(chan *ssh.Request) + mockChannel := new(MockSSHChannel) + mf.On("OpenForwardedChannel", payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil) + + mf.On("HandleConnection", serverConn, mockChannel).Return() + + srv.handleTcp(serverConn) + + mf.AssertExpectations(t) +} + +func TestTCPServer_handleTcp_CloseError(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf).(*tcp) + + mc := new(MockConn) + mc.On("Close").Return(errors.New("close error")) + mc.On("RemoteAddr").Return(&net.TCPAddr{}) + + payload := []byte("test-payload") + mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) + mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) + + srv.handleTcp(mc) + mc.AssertExpectations(t) +} + +func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) { + mf := new(MockForwarder) + srv := NewTCPServer(0, mf).(*tcp) + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + payload := []byte("test-payload") + mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload) + mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error")) + + srv.handleTcp(serverConn) + + mf.AssertExpectations(t) +} diff --git a/internal/transport/tls_test.go b/internal/transport/tls_test.go new file mode 100644 index 0000000..17e7214 --- /dev/null +++ b/internal/transport/tls_test.go @@ -0,0 +1,178 @@ +package transport + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "testing" + "time" + "tunnel_pls/types" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockConfig struct { + mock.Mock +} + +func (m *MockConfig) Domain() string { return m.Called().String(0) } +func (m *MockConfig) SSHPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPPort() string { return m.Called().String(0) } +func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) } +func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) } +func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) } +func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) } +func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } +func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } +func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } +func (m *MockConfig) PprofPort() string { return m.Called().String(0) } +func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } +func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) } +func (m *MockConfig) GRPCPort() string { return m.Called().String(0) } +func (m *MockConfig) NodeToken() string { return m.Called().String(0) } + +func TestValidateCertDomains_NotFound(t *testing.T) { + result := ValidateCertDomains("nonexistent.pem", "example.com") + assert.False(t, result) +} + +func TestValidateCertDomains_InvalidPEM(t *testing.T) { + tmpFile, err := os.CreateTemp("", "invalid*.pem") + assert.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + _, _ = tmpFile.WriteString("not a pem") + tmpFile.Close() + + result := ValidateCertDomains(tmpFile.Name(), "example.com") + assert.False(t, result) +} + +func TestTLSManager_getTLSConfig(t *testing.T) { + tm := &tlsManager{ + useCertMagic: false, + } + cfg := tm.getTLSConfig() + assert.NotNil(t, cfg) + assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion) +} + +func TestTLSManager_getCertificate_Magic(t *testing.T) { + tm := &tlsManager{ + useCertMagic: true, + } + hello := &tls.ClientHelloInfo{} + assert.Panics(t, func() { + _, _ = tm.getCertificate(hello) + }) +} + +func TestTLSManager_userCertsExistAndValid(t *testing.T) { + tm := &tlsManager{ + certPath: "nonexistent.pem", + keyPath: "nonexistent.key", + } + assert.False(t, tm.userCertsExistAndValid()) + + keyFile, _ := os.CreateTemp("", "key*.pem") + defer os.Remove(keyFile.Name()) + tm.keyPath = keyFile.Name() + assert.False(t, tm.userCertsExistAndValid()) +} + +func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, soon bool) (string, string) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + + notAfter := time.Now().Add(365 * 24 * time.Hour) + if expired { + notAfter = time.Now().Add(-24 * time.Hour) + } else if soon { + notAfter = time.Now().Add(15 * 24 * time.Hour) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: domain, + }, + NotBefore: time.Now().Add(-24 * time.Hour), + NotAfter: notAfter, + DNSNames: []string{domain}, + } + + if wildcard { + template.DNSNames = append(template.DNSNames, "*."+domain) + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + assert.NoError(t, err) + + certOut, _ := os.CreateTemp("", "cert*.pem") + _ = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + certOut.Close() + + keyOut, _ := os.CreateTemp("", "key*.pem") + _ = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + keyOut.Close() + + return certOut.Name(), keyOut.Name() +} + +func TestValidateCertDomains_Success(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + result := ValidateCertDomains(certPath, "example.com") + assert.True(t, result) +} + +func TestValidateCertDomains_Expired(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, true, false) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + result := ValidateCertDomains(certPath, "example.com") + assert.False(t, result) +} + +func TestValidateCertDomains_ExpiringSoon(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, true) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + result := ValidateCertDomains(certPath, "example.com") + assert.False(t, result) +} + +func TestValidateCertDomains_MissingWildcard(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", false, false, false) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + result := ValidateCertDomains(certPath, "example.com") + assert.False(t, result) +} + +func TestTLSManager_loadUserCerts_Success(t *testing.T) { + certPath, keyPath := createTestCert(t, "example.com", true, false, false) + defer os.Remove(certPath) + defer os.Remove(keyPath) + + tm := &tlsManager{ + certPath: certPath, + keyPath: keyPath, + } + err := tm.loadUserCerts() + assert.NoError(t, err) + assert.NotNil(t, tm.userCert) +} diff --git a/internal/transport/transport.go b/internal/transport/transport.go index ca27061..31219fd 100644 --- a/internal/transport/transport.go +++ b/internal/transport/transport.go @@ -8,3 +8,7 @@ type Transport interface { Listen() (net.Listener, error) Serve(listener net.Listener) error } + +type HTTP interface { + Handler(conn net.Conn, isTLS bool) +} diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go new file mode 100644 index 0000000..016445d --- /dev/null +++ b/internal/transport/transport_test.go @@ -0,0 +1,15 @@ +package transport + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTransportInterface(t *testing.T) { + var _ Transport = (*httpServer)(nil) + var _ Transport = (*https)(nil) + var _ Transport = (*tcp)(nil) + + assert.True(t, true) +}