staging #77
@@ -11,6 +11,7 @@ require (
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/libdns/cloudflare v0.2.2
|
||||
github.com/muesli/termenv v0.16.0
|
||||
github.com/stretchr/testify v1.8.1
|
||||
golang.org/x/crypto v0.47.0
|
||||
google.golang.org/grpc v1.78.0
|
||||
google.golang.org/protobuf v1.36.11
|
||||
@@ -27,6 +28,7 @@ require (
|
||||
github.com/clipperhouse/displaywidth v0.6.2 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/libdns/libdns v1.1.1 // indirect
|
||||
@@ -38,8 +40,10 @@ require (
|
||||
github.com/miekg/dns v1.1.69 // indirect
|
||||
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
|
||||
github.com/muesli/cancelreader v0.2.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/sahilm/fuzzy v0.1.1 // indirect
|
||||
github.com/stretchr/objx v0.5.0 // indirect
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
|
||||
github.com/zeebo/blake3 v0.2.4 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
@@ -52,4 +56,5 @@ require (
|
||||
golang.org/x/text v0.33.0 // indirect
|
||||
golang.org/x/tools v0.40.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -32,6 +32,7 @@ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfa
|
||||
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
|
||||
@@ -80,6 +81,12 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
|
||||
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
|
||||
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
|
||||
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
|
||||
@@ -138,5 +145,8 @@ google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
|
||||
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -35,6 +35,6 @@ func (ht *httpServer) Serve(listener net.Listener) error {
|
||||
continue
|
||||
}
|
||||
|
||||
go ht.handler.handler(conn, false)
|
||||
go ht.handler.Handler(conn, false)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func TestNewHTTPServer(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
domain := "example.com"
|
||||
port := "8080"
|
||||
redirectTLS := true
|
||||
|
||||
srv := NewHTTPServer(domain, port, msr, redirectTLS)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
httpSrv, ok := srv.(*httpServer)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, port, httpSrv.port)
|
||||
assert.Equal(t, domain, httpSrv.handler.domain)
|
||||
assert.Equal(t, msr, httpSrv.handler.sessionRegistry)
|
||||
assert.Equal(t, redirectTLS, httpSrv.handler.redirectTLS)
|
||||
}
|
||||
|
||||
func TestHTTPServer_Listen(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPServer("example.com", "0", msr, false)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, listener)
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
func TestHTTPServer_Serve(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPServer("example.com", "0", msr, false)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
err = srv.Serve(listener)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
}
|
||||
|
||||
func TestHTTPServer_Serve_AcceptError(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPServer("example.com", "0", msr, false)
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
err := srv.Serve(ml)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
ml.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestHTTPServer_Serve_Success(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPServer("example.com", "0", msr, false)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
conn.Close()
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
type mockListener struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *mockListener) Accept() (net.Conn, error) {
|
||||
args := m.Called()
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(net.Conn), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *mockListener) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *mockListener) Addr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var openChannelTimeout = 5 * time.Second
|
||||
|
||||
type httpHandler struct {
|
||||
domain string
|
||||
sessionRegistry registry.Registry
|
||||
@@ -52,7 +54,7 @@ func (hh *httpHandler) badRequest(conn net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
||||
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
|
||||
defer hh.closeConnection(conn)
|
||||
|
||||
dstReader := bufio.NewReader(conn)
|
||||
@@ -69,7 +71,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
||||
}
|
||||
|
||||
if hh.shouldRedirectToTLS(isTLS) {
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.domain))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -77,7 +79,10 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
|
||||
return
|
||||
}
|
||||
|
||||
sshSession, err := hh.getSession(slug)
|
||||
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
|
||||
Id: slug,
|
||||
Type: types.TunnelTypeHTTP,
|
||||
})
|
||||
if err != nil {
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
|
||||
return
|
||||
@@ -102,7 +107,7 @@ func (hh *httpHandler) closeConnection(conn net.Conn) {
|
||||
|
||||
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
|
||||
host := strings.Split(reqhf.Value("Host"), ".")
|
||||
if len(host) < 1 {
|
||||
if len(host) <= 1 {
|
||||
return "", errors.New("invalid host")
|
||||
}
|
||||
return host[0], nil
|
||||
@@ -128,21 +133,11 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
||||
))
|
||||
if err != nil {
|
||||
log.Println("Failed to write 200 OK:", err)
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
|
||||
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
|
||||
Id: slug,
|
||||
Type: types.TunnelTypeHTTP,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sshSession, nil
|
||||
}
|
||||
|
||||
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
|
||||
channel, err := hh.openForwardedChannel(hw, sshSession)
|
||||
if err != nil {
|
||||
@@ -180,11 +175,7 @@ func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.
|
||||
|
||||
go func() {
|
||||
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
|
||||
select {
|
||||
case resultChan <- channelResult{channel, reqs, err}:
|
||||
default:
|
||||
hh.cleanupUnusedChannel(channel, reqs)
|
||||
}
|
||||
resultChan <- channelResult{channel, reqs, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -194,7 +185,11 @@ func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.
|
||||
}
|
||||
go ssh.DiscardRequests(result.reqs)
|
||||
return result.channel, nil
|
||||
case <-time.After(5 * time.Second):
|
||||
case <-time.After(openChannelTimeout):
|
||||
go func() {
|
||||
result := <-resultChan
|
||||
hh.cleanupUnusedChannel(result.channel, result.reqs)
|
||||
}()
|
||||
return nil, errors.New("timeout opening forwarded-tcpip channel")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,755 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/internal/port"
|
||||
"tunnel_pls/internal/registry"
|
||||
"tunnel_pls/session/forwarder"
|
||||
"tunnel_pls/session/interaction"
|
||||
"tunnel_pls/session/lifecycle"
|
||||
"tunnel_pls/session/slug"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockSessionRegistry struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
|
||||
args := m.Called(user, key)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
return args.Get(0).(registry.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
|
||||
args := m.Called(user, oldKey, newKey)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
|
||||
args := m.Called(key, session)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Remove(key registry.Key) {
|
||||
m.Called(key)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
|
||||
args := m.Called(user)
|
||||
return args.Get(0).([]registry.Session)
|
||||
}
|
||||
|
||||
func (m *MockSessionRegistry) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
|
||||
args := m.Called()
|
||||
return args.Get(0).(lifecycle.Lifecycle)
|
||||
}
|
||||
|
||||
func (m *MockSession) Interaction() interaction.Interaction {
|
||||
args := m.Called()
|
||||
return args.Get(0).(interaction.Interaction)
|
||||
}
|
||||
|
||||
func (m *MockSession) Forwarder() forwarder.Forwarder {
|
||||
args := m.Called()
|
||||
return args.Get(0).(forwarder.Forwarder)
|
||||
}
|
||||
|
||||
func (m *MockSession) Slug() slug.Slug {
|
||||
args := m.Called()
|
||||
return args.Get(0).(slug.Slug)
|
||||
}
|
||||
|
||||
func (m *MockSession) Detail() *types.Detail {
|
||||
args := m.Called()
|
||||
return args.Get(0).(*types.Detail)
|
||||
}
|
||||
|
||||
type MockLifecycle struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) Connection() ssh.Conn {
|
||||
args := m.Called()
|
||||
return args.Get(0).(ssh.Conn)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) PortRegistry() port.Port {
|
||||
args := m.Called()
|
||||
return args.Get(0).(port.Port)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) User() string {
|
||||
args := m.Called()
|
||||
return args.String(0)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) SetChannel(channel ssh.Channel) {
|
||||
m.Called(channel)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) SetStatus(status types.SessionStatus) {
|
||||
m.Called(status)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) IsActive() bool {
|
||||
args := m.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) StartedAt() time.Time {
|
||||
args := m.Called()
|
||||
return args.Get(0).(time.Time)
|
||||
}
|
||||
|
||||
func (m *MockLifecycle) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockSSHConn struct {
|
||||
ssh.Conn
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
args := m.Called(name, data)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
|
||||
type MockSSHChannel struct {
|
||||
ssh.Channel
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockSSHChannel) Write(data []byte) (int, error) {
|
||||
args := m.Called(data)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockSSHChannel) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type MockForwarder struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
|
||||
args := m.Called(origin)
|
||||
return args.Get(0).([]byte)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) WriteBadGatewayResponse(dst io.Writer) {
|
||||
m.Called(dst)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
|
||||
m.Called(dst, src)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) TunnelType() types.TunnelType {
|
||||
args := m.Called()
|
||||
return args.Get(0).(types.TunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) ForwardedPort() uint16 {
|
||||
args := m.Called()
|
||||
return uint16(args.Int(0))
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
|
||||
m.Called(tunnelType)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetForwardedPort(port uint16) {
|
||||
m.Called(port)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) SetListener(listener net.Listener) {
|
||||
m.Called(listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) Listener() net.Listener {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Listener)
|
||||
}
|
||||
|
||||
func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
|
||||
args := m.Called(payload)
|
||||
if args.Get(0) == nil {
|
||||
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
|
||||
}
|
||||
|
||||
type MockConn struct {
|
||||
net.Conn
|
||||
mock.Mock
|
||||
ReadBuffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
if m.ReadBuffer != nil {
|
||||
return m.ReadBuffer.Read(b)
|
||||
}
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
args := m.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
type wrappedConn struct {
|
||||
net.Conn
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (c *wrappedConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func TestNewHTTPHandler(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
hh := newHTTPHandler("domain", msr, true)
|
||||
assert.NotNil(t, hh)
|
||||
assert.Equal(t, "domain", hh.domain)
|
||||
assert.Equal(t, msr, hh.sessionRegistry)
|
||||
assert.True(t, hh.redirectTLS)
|
||||
}
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isTLS bool
|
||||
redirectTLS bool
|
||||
request []byte
|
||||
expected []byte
|
||||
setupMocks func(*MockSessionRegistry)
|
||||
setupConn func() (net.Conn, net.Conn)
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "bad request - invalid host",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - missing host",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "isTLS true and redirectTLS true - no redirect",
|
||||
isTLS: true,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redirect to TLS",
|
||||
isTLS: false,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://example.domain/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle ping request",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "session not found",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return((registry.Session)(nil), fmt.Errorf("session not found"))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - invalid http",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("INVALID\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - open channel fails",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
|
||||
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.Writer)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
|
||||
})
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - send initial request fails",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
go func() {
|
||||
for range reqCh {
|
||||
}
|
||||
}()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - success",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", types.SessionKey{
|
||||
Id: "test",
|
||||
Type: types.TunnelTypeHTTP,
|
||||
}).Return(mockSession, nil)
|
||||
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.ReadWriter)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
|
||||
})
|
||||
|
||||
go func() {
|
||||
for range reqCh {
|
||||
}
|
||||
}()
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "redirect - write failure",
|
||||
isTLS: false,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - write failure",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle ping request - write failure",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "close connection - error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
||||
mc.On("Write", mock.Anything).Return(182, nil)
|
||||
mc.On("Close").Return(fmt.Errorf("close error"))
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - stream close error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return()
|
||||
},
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
||||
mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
|
||||
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
mc.On("RemoteAddr").Return(addr)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - middleware failure",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
|
||||
return k.Id == "test"
|
||||
})).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
mockSSHChannel.On("Close").Return(nil)
|
||||
},
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
||||
mc.On("Close").Return(nil).Times(2)
|
||||
mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - channel close error",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
|
||||
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
|
||||
|
||||
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.ReadWriter)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
|
||||
})
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "forwarding - open channel timeout",
|
||||
isTLS: true,
|
||||
redirectTLS: false,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
oldTimeout := openChannelTimeout
|
||||
openChannelTimeout = 10 * time.Millisecond
|
||||
t.Cleanup(func() {
|
||||
openChannelTimeout = oldTimeout
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
mockSession := new(MockSession)
|
||||
mockForwarder := new(MockForwarder)
|
||||
mockLifecycle := new(MockLifecycle)
|
||||
mockSSHConn := new(MockSSHConn)
|
||||
mockSSHChannel := new(MockSSHChannel)
|
||||
|
||||
msr.On("Get", mock.Anything).Return(mockSession, nil)
|
||||
mockSession.On("Forwarder").Return(mockForwarder)
|
||||
mockSession.On("Lifecycle").Return(mockLifecycle)
|
||||
mockLifecycle.On("Connection").Return(mockSSHConn)
|
||||
|
||||
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
|
||||
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
|
||||
w := args.Get(0).(io.Writer)
|
||||
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
|
||||
})
|
||||
|
||||
reqCh := make(chan *ssh.Request)
|
||||
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Run(func(args mock.Arguments) {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
|
||||
|
||||
mockSSHChannel.On("Close").Return(fmt.Errorf("cleanup close error"))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSessionRegistry := new(MockSessionRegistry)
|
||||
hh := &httpHandler{
|
||||
domain: "domain",
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
redirectTLS: tt.redirectTLS,
|
||||
}
|
||||
|
||||
if tt.setupMocks != nil {
|
||||
tt.setupMocks(mockSessionRegistry)
|
||||
}
|
||||
|
||||
var serverConn, clientConn net.Conn
|
||||
if tt.setupConn != nil {
|
||||
serverConn, clientConn = tt.setupConn()
|
||||
} else {
|
||||
serverConn, clientConn = net.Pipe()
|
||||
}
|
||||
|
||||
if clientConn != nil {
|
||||
defer clientConn.Close()
|
||||
}
|
||||
|
||||
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
var wrappedServerConn net.Conn
|
||||
if _, ok := serverConn.(*MockConn); ok {
|
||||
wrappedServerConn = serverConn
|
||||
} else {
|
||||
wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr}
|
||||
}
|
||||
|
||||
responseChan := make(chan []byte, 1)
|
||||
doneChan := make(chan struct{})
|
||||
|
||||
if clientConn != nil {
|
||||
go func() {
|
||||
defer close(doneChan)
|
||||
var res []byte
|
||||
for {
|
||||
buf := make([]byte, 4096)
|
||||
n, err := clientConn.Read(buf)
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
t.Logf("Error reading response: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
res = append(res, buf[:n]...)
|
||||
if len(tt.expected) > 0 && len(res) >= len(tt.expected) {
|
||||
break
|
||||
}
|
||||
}
|
||||
responseChan <- res
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := clientConn.Write(tt.request)
|
||||
if err != nil {
|
||||
t.Logf("Error writing request: %v", err)
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
close(responseChan)
|
||||
close(doneChan)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
hh.Handler(wrappedServerConn, tt.isTLS)
|
||||
}()
|
||||
|
||||
select {
|
||||
case response := <-responseChan:
|
||||
if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" {
|
||||
resStr := string(response)
|
||||
assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n"))
|
||||
assert.Contains(t, resStr, "Content-Length: 5\r\n")
|
||||
assert.Contains(t, resStr, "Server: Tunnel Please\r\n")
|
||||
assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello"))
|
||||
} else {
|
||||
assert.Equal(t, string(tt.expected), string(response))
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
if clientConn != nil {
|
||||
t.Fatal("Test timeout - no response received")
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if clientConn != nil {
|
||||
<-doneChan
|
||||
}
|
||||
|
||||
mockSessionRegistry.AssertExpectations(t)
|
||||
if mc, ok := serverConn.(*MockConn); ok {
|
||||
mc.AssertExpectations(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -40,6 +40,6 @@ func (ht *https) Serve(listener net.Listener) error {
|
||||
continue
|
||||
}
|
||||
|
||||
go ht.httpHandler.handler(conn, true)
|
||||
go ht.httpHandler.Handler(conn, true)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewHTTPSServer(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
domain := "example.com"
|
||||
port := "443"
|
||||
redirectTLS := false
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
srv := NewHTTPSServer(domain, port, msr, redirectTLS, tlsConfig)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
httpsSrv, ok := srv.(*https)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, port, httpsSrv.port)
|
||||
assert.Equal(t, domain, httpsSrv.domain)
|
||||
assert.Equal(t, tlsConfig, httpsSrv.tlsConfig)
|
||||
assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry)
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Listen(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
srv := NewHTTPSServer("example.com", "0", msr, false, tlsConfig)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
if err != nil {
|
||||
t.Skip("Skipping tls.Listen test as it requires valid certificates/setup:", err)
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, listener)
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Serve(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{})
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
err = srv.Serve(listener)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{})
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
err := srv.Serve(ml)
|
||||
assert.True(t, errors.Is(err, net.ErrClosed))
|
||||
ml.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Serve_Success(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{})
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
conn.Close()
|
||||
listener.Close()
|
||||
}
|
||||
@@ -12,16 +12,16 @@ import (
|
||||
|
||||
type tcp struct {
|
||||
port uint16
|
||||
forwarder forwarder
|
||||
forwarder Forwarder
|
||||
}
|
||||
|
||||
type forwarder interface {
|
||||
type Forwarder interface {
|
||||
CreateForwardedTCPIPPayload(origin net.Addr) []byte
|
||||
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
|
||||
HandleConnection(dst io.ReadWriter, src ssh.Channel)
|
||||
}
|
||||
|
||||
func NewTCPServer(port uint16, forwarder forwarder) Transport {
|
||||
func NewTCPServer(port uint16, forwarder Forwarder) Transport {
|
||||
return &tcp{
|
||||
port: port,
|
||||
forwarder: forwarder,
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func TestNewTCPServer(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
port := uint16(9000)
|
||||
|
||||
srv := NewTCPServer(port, mf)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
tcpSrv, ok := srv.(*tcp)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, port, tcpSrv.port)
|
||||
assert.Equal(t, mf, tcpSrv.forwarder)
|
||||
}
|
||||
|
||||
func TestTCPServer_Listen(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, listener)
|
||||
listener.Close()
|
||||
}
|
||||
|
||||
func TestTCPServer_Serve(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
err = srv.Serve(listener)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestTCPServer_Serve_AcceptError(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
ml.On("Accept").Return(nil, net.ErrClosed).Once()
|
||||
|
||||
err := srv.Serve(ml)
|
||||
assert.Nil(t, err)
|
||||
ml.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_Serve_Success(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
payload := []byte("test-payload")
|
||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
||||
reqs := make(chan *ssh.Request)
|
||||
mf.On("OpenForwardedChannel", payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
|
||||
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
assert.NoError(t, err)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
conn.Close()
|
||||
listener.Close()
|
||||
|
||||
mf.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_handleTcp_Success(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf).(*tcp)
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
|
||||
payload := []byte("test-payload")
|
||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
||||
|
||||
reqs := make(chan *ssh.Request)
|
||||
mockChannel := new(MockSSHChannel)
|
||||
mf.On("OpenForwardedChannel", payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
|
||||
|
||||
mf.On("HandleConnection", serverConn, mockChannel).Return()
|
||||
|
||||
srv.handleTcp(serverConn)
|
||||
|
||||
mf.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_handleTcp_CloseError(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf).(*tcp)
|
||||
|
||||
mc := new(MockConn)
|
||||
mc.On("Close").Return(errors.New("close error"))
|
||||
mc.On("RemoteAddr").Return(&net.TCPAddr{})
|
||||
|
||||
payload := []byte("test-payload")
|
||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
||||
mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||
|
||||
srv.handleTcp(mc)
|
||||
mc.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
|
||||
mf := new(MockForwarder)
|
||||
srv := NewTCPServer(0, mf).(*tcp)
|
||||
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
|
||||
payload := []byte("test-payload")
|
||||
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
|
||||
mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
|
||||
|
||||
srv.handleTcp(serverConn)
|
||||
|
||||
mf.AssertExpectations(t)
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
"tunnel_pls/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type MockConfig struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockConfig) Domain() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
|
||||
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
|
||||
|
||||
func TestValidateCertDomains_NotFound(t *testing.T) {
|
||||
result := ValidateCertDomains("nonexistent.pem", "example.com")
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
func TestValidateCertDomains_InvalidPEM(t *testing.T) {
|
||||
tmpFile, err := os.CreateTemp("", "invalid*.pem")
|
||||
assert.NoError(t, err)
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
_, _ = tmpFile.WriteString("not a pem")
|
||||
tmpFile.Close()
|
||||
|
||||
result := ValidateCertDomains(tmpFile.Name(), "example.com")
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
func TestTLSManager_getTLSConfig(t *testing.T) {
|
||||
tm := &tlsManager{
|
||||
useCertMagic: false,
|
||||
}
|
||||
cfg := tm.getTLSConfig()
|
||||
assert.NotNil(t, cfg)
|
||||
assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
|
||||
}
|
||||
|
||||
func TestTLSManager_getCertificate_Magic(t *testing.T) {
|
||||
tm := &tlsManager{
|
||||
useCertMagic: true,
|
||||
}
|
||||
hello := &tls.ClientHelloInfo{}
|
||||
assert.Panics(t, func() {
|
||||
_, _ = tm.getCertificate(hello)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTLSManager_userCertsExistAndValid(t *testing.T) {
|
||||
tm := &tlsManager{
|
||||
certPath: "nonexistent.pem",
|
||||
keyPath: "nonexistent.key",
|
||||
}
|
||||
assert.False(t, tm.userCertsExistAndValid())
|
||||
|
||||
keyFile, _ := os.CreateTemp("", "key*.pem")
|
||||
defer os.Remove(keyFile.Name())
|
||||
tm.keyPath = keyFile.Name()
|
||||
assert.False(t, tm.userCertsExistAndValid())
|
||||
}
|
||||
|
||||
func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, soon bool) (string, string) {
|
||||
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
|
||||
notAfter := time.Now().Add(365 * 24 * time.Hour)
|
||||
if expired {
|
||||
notAfter = time.Now().Add(-24 * time.Hour)
|
||||
} else if soon {
|
||||
notAfter = time.Now().Add(15 * 24 * time.Hour)
|
||||
}
|
||||
|
||||
template := x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{
|
||||
CommonName: domain,
|
||||
},
|
||||
NotBefore: time.Now().Add(-24 * time.Hour),
|
||||
NotAfter: notAfter,
|
||||
DNSNames: []string{domain},
|
||||
}
|
||||
|
||||
if wildcard {
|
||||
template.DNSNames = append(template.DNSNames, "*."+domain)
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
|
||||
assert.NoError(t, err)
|
||||
|
||||
certOut, _ := os.CreateTemp("", "cert*.pem")
|
||||
_ = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
certOut.Close()
|
||||
|
||||
keyOut, _ := os.CreateTemp("", "key*.pem")
|
||||
_ = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
||||
keyOut.Close()
|
||||
|
||||
return certOut.Name(), keyOut.Name()
|
||||
}
|
||||
|
||||
func TestValidateCertDomains_Success(t *testing.T) {
|
||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||
defer os.Remove(certPath)
|
||||
defer os.Remove(keyPath)
|
||||
|
||||
result := ValidateCertDomains(certPath, "example.com")
|
||||
assert.True(t, result)
|
||||
}
|
||||
|
||||
func TestValidateCertDomains_Expired(t *testing.T) {
|
||||
certPath, keyPath := createTestCert(t, "example.com", true, true, false)
|
||||
defer os.Remove(certPath)
|
||||
defer os.Remove(keyPath)
|
||||
|
||||
result := ValidateCertDomains(certPath, "example.com")
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
func TestValidateCertDomains_ExpiringSoon(t *testing.T) {
|
||||
certPath, keyPath := createTestCert(t, "example.com", true, false, true)
|
||||
defer os.Remove(certPath)
|
||||
defer os.Remove(keyPath)
|
||||
|
||||
result := ValidateCertDomains(certPath, "example.com")
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
func TestValidateCertDomains_MissingWildcard(t *testing.T) {
|
||||
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
|
||||
defer os.Remove(certPath)
|
||||
defer os.Remove(keyPath)
|
||||
|
||||
result := ValidateCertDomains(certPath, "example.com")
|
||||
assert.False(t, result)
|
||||
}
|
||||
|
||||
func TestTLSManager_loadUserCerts_Success(t *testing.T) {
|
||||
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
|
||||
defer os.Remove(certPath)
|
||||
defer os.Remove(keyPath)
|
||||
|
||||
tm := &tlsManager{
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
}
|
||||
err := tm.loadUserCerts()
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, tm.userCert)
|
||||
}
|
||||
@@ -8,3 +8,7 @@ type Transport interface {
|
||||
Listen() (net.Listener, error)
|
||||
Serve(listener net.Listener) error
|
||||
}
|
||||
|
||||
type HTTP interface {
|
||||
Handler(conn net.Conn, isTLS bool)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTransportInterface(t *testing.T) {
|
||||
var _ Transport = (*httpServer)(nil)
|
||||
var _ Transport = (*https)(nil)
|
||||
var _ Transport = (*tcp)(nil)
|
||||
|
||||
assert.True(t, true)
|
||||
}
|
||||
Reference in New Issue
Block a user