revert-54069ad305 #11

Closed
bagas wants to merge 217 commits from revert-54069ad305 into main
13 changed files with 1346 additions and 26 deletions
Showing only changes of commit 29cabe42d3 - Show all commits
+5
View File
@@ -11,6 +11,7 @@ require (
github.com/joho/godotenv v1.5.1
github.com/libdns/cloudflare v0.2.2
github.com/muesli/termenv v0.16.0
github.com/stretchr/testify v1.8.1
golang.org/x/crypto v0.47.0
google.golang.org/grpc v1.78.0
google.golang.org/protobuf v1.36.11
@@ -27,6 +28,7 @@ require (
github.com/clipperhouse/displaywidth v0.6.2 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/libdns/libdns v1.1.1 // indirect
@@ -38,8 +40,10 @@ require (
github.com/miekg/dns v1.1.69 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/sahilm/fuzzy v0.1.1 // indirect
github.com/stretchr/objx v0.5.0 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/zeebo/blake3 v0.2.4 // indirect
go.uber.org/multierr v1.11.0 // indirect
@@ -52,4 +56,5 @@ require (
golang.org/x/text v0.33.0 // indirect
golang.org/x/tools v0.40.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
+10
View File
@@ -32,6 +32,7 @@ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfa
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4=
github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
@@ -80,6 +81,12 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/sahilm/fuzzy v0.1.1 h1:ceu5RHF8DGgoi+/dR5PsECjCDH1BE3Fnmpo7aVXOdRA=
github.com/sahilm/fuzzy v0.1.1/go.mod h1:VFvziUEIMCrT6A6tw2RFIXPXXmzXbOsSHF0DOI8ZK9Y=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
@@ -138,5 +145,8 @@ google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+1 -1
View File
@@ -35,6 +35,6 @@ func (ht *httpServer) Serve(listener net.Listener) error {
continue
}
go ht.handler.handler(conn, false)
go ht.handler.Handler(conn, false)
}
}
+112
View File
@@ -0,0 +1,112 @@
package transport
import (
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestNewHTTPServer(t *testing.T) {
msr := new(MockSessionRegistry)
domain := "example.com"
port := "8080"
redirectTLS := true
srv := NewHTTPServer(domain, port, msr, redirectTLS)
assert.NotNil(t, srv)
httpSrv, ok := srv.(*httpServer)
assert.True(t, ok)
assert.Equal(t, port, httpSrv.port)
assert.Equal(t, domain, httpSrv.handler.domain)
assert.Equal(t, msr, httpSrv.handler.sessionRegistry)
assert.Equal(t, redirectTLS, httpSrv.handler.redirectTLS)
}
func TestHTTPServer_Listen(t *testing.T) {
msr := new(MockSessionRegistry)
srv := NewHTTPServer("example.com", "0", msr, false)
listener, err := srv.Listen()
assert.NoError(t, err)
assert.NotNil(t, listener)
listener.Close()
}
func TestHTTPServer_Serve(t *testing.T) {
msr := new(MockSessionRegistry)
srv := NewHTTPServer("example.com", "0", msr, false)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
go func() {
time.Sleep(100 * time.Millisecond)
listener.Close()
}()
err = srv.Serve(listener)
assert.True(t, errors.Is(err, net.ErrClosed))
}
func TestHTTPServer_Serve_AcceptError(t *testing.T) {
msr := new(MockSessionRegistry)
srv := NewHTTPServer("example.com", "0", msr, false)
ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
err := srv.Serve(ml)
assert.True(t, errors.Is(err, net.ErrClosed))
ml.AssertExpectations(t)
}
func TestHTTPServer_Serve_Success(t *testing.T) {
msr := new(MockSessionRegistry)
srv := NewHTTPServer("example.com", "0", msr, false)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
go func() {
_ = srv.Serve(listener)
}()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
assert.NoError(t, err)
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
time.Sleep(100 * time.Millisecond)
conn.Close()
listener.Close()
}
type mockListener struct {
mock.Mock
}
func (m *mockListener) Accept() (net.Conn, error) {
args := m.Called()
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(net.Conn), args.Error(1)
}
func (m *mockListener) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *mockListener) Addr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
+16 -21
View File
@@ -19,6 +19,8 @@ import (
"golang.org/x/crypto/ssh"
)
var openChannelTimeout = 5 * time.Second
type httpHandler struct {
domain string
sessionRegistry registry.Registry
@@ -52,7 +54,7 @@ func (hh *httpHandler) badRequest(conn net.Conn) error {
return nil
}
func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
defer hh.closeConnection(conn)
dstReader := bufio.NewReader(conn)
@@ -69,7 +71,7 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
}
if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("Location: https://%s.%s/\r\n", slug, hh.domain))
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.domain))
return
}
@@ -77,7 +79,10 @@ func (hh *httpHandler) handler(conn net.Conn, isTLS bool) {
return
}
sshSession, err := hh.getSession(slug)
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.TunnelTypeHTTP,
})
if err != nil {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://tunnl.live/tunnel-not-found?slug=%s\r\n", slug))
return
@@ -102,7 +107,7 @@ func (hh *httpHandler) closeConnection(conn net.Conn) {
func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
host := strings.Split(reqhf.Value("Host"), ".")
if len(host) < 1 {
if len(host) <= 1 {
return "", errors.New("invalid host")
}
return host[0], nil
@@ -128,19 +133,9 @@ func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
))
if err != nil {
log.Println("Failed to write 200 OK:", err)
}
return true
}
func (hh *httpHandler) getSession(slug string) (registry.Session, error) {
sshSession, err := hh.sessionRegistry.Get(types.SessionKey{
Id: slug,
Type: types.TunnelTypeHTTP,
})
if err != nil {
return nil, err
}
return sshSession, nil
return true
}
func (hh *httpHandler) forwardRequest(hw stream.HTTP, initialRequest header.RequestHeader, sshSession registry.Session) {
@@ -180,11 +175,7 @@ func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.
go func() {
channel, reqs, err := sshSession.Lifecycle().Connection().OpenChannel("forwarded-tcpip", payload)
select {
case resultChan <- channelResult{channel, reqs, err}:
default:
hh.cleanupUnusedChannel(channel, reqs)
}
resultChan <- channelResult{channel, reqs, err}
}()
select {
@@ -194,7 +185,11 @@ func (hh *httpHandler) openForwardedChannel(hw stream.HTTP, sshSession registry.
}
go ssh.DiscardRequests(result.reqs)
return result.channel, nil
case <-time.After(5 * time.Second):
case <-time.After(openChannelTimeout):
go func() {
result := <-resultChan
hh.cleanupUnusedChannel(result.channel, result.reqs)
}()
return nil, errors.New("timeout opening forwarded-tcpip channel")
}
}
+755
View File
@@ -0,0 +1,755 @@
package transport
import (
"bytes"
"fmt"
"io"
"net"
"strings"
"sync"
"testing"
"time"
"tunnel_pls/internal/port"
"tunnel_pls/internal/registry"
"tunnel_pls/session/forwarder"
"tunnel_pls/session/interaction"
"tunnel_pls/session/lifecycle"
"tunnel_pls/session/slug"
"tunnel_pls/types"
"golang.org/x/crypto/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockSessionRegistry struct {
mock.Mock
}
func (m *MockSessionRegistry) Get(key registry.Key) (registry.Session, error) {
args := m.Called(key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) GetWithUser(user string, key registry.Key) (registry.Session, error) {
args := m.Called(user, key)
if args.Get(0) == nil {
return nil, args.Error(1)
}
return args.Get(0).(registry.Session), args.Error(1)
}
func (m *MockSessionRegistry) Update(user string, oldKey, newKey registry.Key) error {
args := m.Called(user, oldKey, newKey)
return args.Error(0)
}
func (m *MockSessionRegistry) Register(key registry.Key, session registry.Session) bool {
args := m.Called(key, session)
return args.Bool(0)
}
func (m *MockSessionRegistry) Remove(key registry.Key) {
m.Called(key)
}
func (m *MockSessionRegistry) GetAllSessionFromUser(user string) []registry.Session {
args := m.Called(user)
return args.Get(0).([]registry.Session)
}
func (m *MockSessionRegistry) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
type MockSession struct {
mock.Mock
}
func (m *MockSession) Lifecycle() lifecycle.Lifecycle {
args := m.Called()
return args.Get(0).(lifecycle.Lifecycle)
}
func (m *MockSession) Interaction() interaction.Interaction {
args := m.Called()
return args.Get(0).(interaction.Interaction)
}
func (m *MockSession) Forwarder() forwarder.Forwarder {
args := m.Called()
return args.Get(0).(forwarder.Forwarder)
}
func (m *MockSession) Slug() slug.Slug {
args := m.Called()
return args.Get(0).(slug.Slug)
}
func (m *MockSession) Detail() *types.Detail {
args := m.Called()
return args.Get(0).(*types.Detail)
}
type MockLifecycle struct {
mock.Mock
}
func (m *MockLifecycle) Connection() ssh.Conn {
args := m.Called()
return args.Get(0).(ssh.Conn)
}
func (m *MockLifecycle) PortRegistry() port.Port {
args := m.Called()
return args.Get(0).(port.Port)
}
func (m *MockLifecycle) User() string {
args := m.Called()
return args.String(0)
}
func (m *MockLifecycle) SetChannel(channel ssh.Channel) {
m.Called(channel)
}
func (m *MockLifecycle) SetStatus(status types.SessionStatus) {
m.Called(status)
}
func (m *MockLifecycle) IsActive() bool {
args := m.Called()
return args.Bool(0)
}
func (m *MockLifecycle) StartedAt() time.Time {
args := m.Called()
return args.Get(0).(time.Time)
}
func (m *MockLifecycle) Close() error {
args := m.Called()
return args.Error(0)
}
type MockSSHConn struct {
ssh.Conn
mock.Mock
}
func (m *MockSSHConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) {
args := m.Called(name, data)
if args.Get(0) == nil {
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
type MockSSHChannel struct {
ssh.Channel
mock.Mock
}
func (m *MockSSHChannel) Write(data []byte) (int, error) {
args := m.Called(data)
return args.Int(0), args.Error(1)
}
func (m *MockSSHChannel) Close() error {
args := m.Called()
return args.Error(0)
}
type MockForwarder struct {
mock.Mock
}
func (m *MockForwarder) CreateForwardedTCPIPPayload(origin net.Addr) []byte {
args := m.Called(origin)
return args.Get(0).([]byte)
}
func (m *MockForwarder) WriteBadGatewayResponse(dst io.Writer) {
m.Called(dst)
}
func (m *MockForwarder) HandleConnection(dst io.ReadWriter, src ssh.Channel) {
m.Called(dst, src)
}
func (m *MockForwarder) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockForwarder) TunnelType() types.TunnelType {
args := m.Called()
return args.Get(0).(types.TunnelType)
}
func (m *MockForwarder) ForwardedPort() uint16 {
args := m.Called()
return uint16(args.Int(0))
}
func (m *MockForwarder) SetType(tunnelType types.TunnelType) {
m.Called(tunnelType)
}
func (m *MockForwarder) SetForwardedPort(port uint16) {
m.Called(port)
}
func (m *MockForwarder) SetListener(listener net.Listener) {
m.Called(listener)
}
func (m *MockForwarder) Listener() net.Listener {
args := m.Called()
return args.Get(0).(net.Listener)
}
func (m *MockForwarder) OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error) {
args := m.Called(payload)
if args.Get(0) == nil {
return nil, args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
return args.Get(0).(ssh.Channel), args.Get(1).(<-chan *ssh.Request), args.Error(2)
}
type MockConn struct {
net.Conn
mock.Mock
ReadBuffer *bytes.Buffer
}
func (m *MockConn) Read(b []byte) (n int, err error) {
if m.ReadBuffer != nil {
return m.ReadBuffer.Read(b)
}
args := m.Called(b)
return args.Int(0), args.Error(1)
}
func (m *MockConn) Write(b []byte) (n int, err error) {
args := m.Called(b)
return args.Int(0), args.Error(1)
}
func (m *MockConn) Close() error {
args := m.Called()
return args.Error(0)
}
func (m *MockConn) RemoteAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
type wrappedConn struct {
net.Conn
remoteAddr net.Addr
}
func (c *wrappedConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func TestNewHTTPHandler(t *testing.T) {
msr := new(MockSessionRegistry)
hh := newHTTPHandler("domain", msr, true)
assert.NotNil(t, hh)
assert.Equal(t, "domain", hh.domain)
assert.Equal(t, msr, hh.sessionRegistry)
assert.True(t, hh.redirectTLS)
}
func TestHandler(t *testing.T) {
tests := []struct {
name string
isTLS bool
redirectTLS bool
request []byte
expected []byte
setupMocks func(*MockSessionRegistry)
setupConn func() (net.Conn, net.Conn)
expectError bool
}{
{
name: "bad request - invalid host",
isTLS: false,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: invalid\r\n\r\n"),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "bad request - missing host",
isTLS: false,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\n\r\n"),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "isTLS true and redirectTLS true - no redirect",
isTLS: true,
redirectTLS: true,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "redirect to TLS",
isTLS: false,
redirectTLS: true,
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://example.domain/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "handle ping request",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 0\r\nConnection: close\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Methods: GET, HEAD, OPTIONS\r\nAccess-Control-Allow-Headers: *\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "session not found",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnl.live/tunnel-not-found?slug=test\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return((registry.Session)(nil), fmt.Errorf("session not found"))
},
},
{
name: "bad request - invalid http",
isTLS: false,
redirectTLS: false,
request: []byte("INVALID\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "forwarding - open channel fails",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return((ssh.Channel)(nil), (<-chan *ssh.Request)(nil), fmt.Errorf("open channel failed"))
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
w := args.Get(0).(io.Writer)
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
})
},
},
{
name: "forwarding - send initial request fails",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mockSSHChannel.On("Close").Return(nil)
go func() {
for range reqCh {
}
}()
},
},
{
name: "forwarding - success",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", types.SessionKey{
Id: "test",
Type: types.TunnelTypeHTTP,
}).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(nil)
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
w := args.Get(0).(io.ReadWriter)
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
})
go func() {
for range reqCh {
}
}()
},
},
{
name: "redirect - write failure",
isTLS: false,
redirectTLS: true,
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "bad request - write failure",
isTLS: false,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "handle ping request - write failure",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{
name: "close connection - error",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
mc.On("Write", mock.Anything).Return(182, nil)
mc.On("Close").Return(fmt.Errorf("close error"))
return mc, nil
},
},
{
name: "forwarding - stream close error",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(nil)
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Return()
},
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
mc.On("RemoteAddr").Return(addr)
return mc, nil
},
},
{
name: "forwarding - middleware failure",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte(""),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
msr.On("Get", mock.MatchedBy(func(k types.SessionKey) bool {
return k.Id == "test"
})).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockLifecycle := new(MockLifecycle)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockSSHConn := new(MockSSHConn)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockSSHChannel := new(MockSSHChannel)
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", mock.Anything).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Close").Return(nil)
},
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
mc.On("Close").Return(nil).Times(2)
mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
return mc, nil
},
},
{
name: "forwarding - channel close error",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\nServer: Tunnel Please\r\n\r\nhello"),
setupMocks: func(msr *MockSessionRegistry) {
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Write", mock.Anything).Return(0, nil)
mockSSHChannel.On("Close").Return(fmt.Errorf("close error"))
mockForwarder.On("HandleConnection", mock.Anything, mockSSHChannel).Run(func(args mock.Arguments) {
w := args.Get(0).(io.ReadWriter)
_, _ = w.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\nhello"))
})
},
},
{
name: "forwarding - open channel timeout",
isTLS: true,
redirectTLS: false,
request: []byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"),
expected: []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
oldTimeout := openChannelTimeout
openChannelTimeout = 10 * time.Millisecond
t.Cleanup(func() {
openChannelTimeout = oldTimeout
time.Sleep(100 * time.Millisecond)
})
mockSession := new(MockSession)
mockForwarder := new(MockForwarder)
mockLifecycle := new(MockLifecycle)
mockSSHConn := new(MockSSHConn)
mockSSHChannel := new(MockSSHChannel)
msr.On("Get", mock.Anything).Return(mockSession, nil)
mockSession.On("Forwarder").Return(mockForwarder)
mockSession.On("Lifecycle").Return(mockLifecycle)
mockLifecycle.On("Connection").Return(mockSSHConn)
mockForwarder.On("CreateForwardedTCPIPPayload", mock.Anything).Return([]byte("payload"))
mockForwarder.On("WriteBadGatewayResponse", mock.Anything).Run(func(args mock.Arguments) {
w := args.Get(0).(io.Writer)
_, _ = w.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n"))
})
reqCh := make(chan *ssh.Request)
mockSSHConn.On("OpenChannel", "forwarded-tcpip", []byte("payload")).Run(func(args mock.Arguments) {
time.Sleep(50 * time.Millisecond)
}).Return(mockSSHChannel, (<-chan *ssh.Request)(reqCh), nil)
mockSSHChannel.On("Close").Return(fmt.Errorf("cleanup close error"))
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockSessionRegistry := new(MockSessionRegistry)
hh := &httpHandler{
domain: "domain",
sessionRegistry: mockSessionRegistry,
redirectTLS: tt.redirectTLS,
}
if tt.setupMocks != nil {
tt.setupMocks(mockSessionRegistry)
}
var serverConn, clientConn net.Conn
if tt.setupConn != nil {
serverConn, clientConn = tt.setupConn()
} else {
serverConn, clientConn = net.Pipe()
}
if clientConn != nil {
defer clientConn.Close()
}
remoteAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
var wrappedServerConn net.Conn
if _, ok := serverConn.(*MockConn); ok {
wrappedServerConn = serverConn
} else {
wrappedServerConn = &wrappedConn{Conn: serverConn, remoteAddr: remoteAddr}
}
responseChan := make(chan []byte, 1)
doneChan := make(chan struct{})
if clientConn != nil {
go func() {
defer close(doneChan)
var res []byte
for {
buf := make([]byte, 4096)
n, err := clientConn.Read(buf)
if err != nil {
if err != io.EOF {
t.Logf("Error reading response: %v", err)
}
break
}
res = append(res, buf[:n]...)
if len(tt.expected) > 0 && len(res) >= len(tt.expected) {
break
}
}
responseChan <- res
}()
go func() {
_, err := clientConn.Write(tt.request)
if err != nil {
t.Logf("Error writing request: %v", err)
}
}()
} else {
close(responseChan)
close(doneChan)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
hh.Handler(wrappedServerConn, tt.isTLS)
}()
select {
case response := <-responseChan:
if tt.name == "forwarding - success" || tt.name == "forwarding - channel close error" {
resStr := string(response)
assert.True(t, strings.HasPrefix(resStr, "HTTP/1.1 200 OK\r\n"))
assert.Contains(t, resStr, "Content-Length: 5\r\n")
assert.Contains(t, resStr, "Server: Tunnel Please\r\n")
assert.True(t, strings.HasSuffix(resStr, "\r\n\r\nhello"))
} else {
assert.Equal(t, string(tt.expected), string(response))
}
case <-time.After(2 * time.Second):
if clientConn != nil {
t.Fatal("Test timeout - no response received")
}
}
wg.Wait()
if clientConn != nil {
<-doneChan
}
mockSessionRegistry.AssertExpectations(t)
if mc, ok := serverConn.(*MockConn); ok {
mc.AssertExpectations(t)
}
})
}
}
+1 -1
View File
@@ -40,6 +40,6 @@ func (ht *https) Serve(listener net.Listener) error {
continue
}
go ht.httpHandler.handler(conn, true)
go ht.httpHandler.Handler(conn, true)
}
}
+100
View File
@@ -0,0 +1,100 @@
package transport
import (
"crypto/tls"
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestNewHTTPSServer(t *testing.T) {
msr := new(MockSessionRegistry)
domain := "example.com"
port := "443"
redirectTLS := false
tlsConfig := &tls.Config{}
srv := NewHTTPSServer(domain, port, msr, redirectTLS, tlsConfig)
assert.NotNil(t, srv)
httpsSrv, ok := srv.(*https)
assert.True(t, ok)
assert.Equal(t, port, httpsSrv.port)
assert.Equal(t, domain, httpsSrv.domain)
assert.Equal(t, tlsConfig, httpsSrv.tlsConfig)
assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry)
}
func TestHTTPSServer_Listen(t *testing.T) {
msr := new(MockSessionRegistry)
tlsConfig := &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, nil
},
}
srv := NewHTTPSServer("example.com", "0", msr, false, tlsConfig)
listener, err := srv.Listen()
if err != nil {
t.Skip("Skipping tls.Listen test as it requires valid certificates/setup:", err)
return
}
assert.NotNil(t, listener)
listener.Close()
}
func TestHTTPSServer_Serve(t *testing.T) {
msr := new(MockSessionRegistry)
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{})
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
go func() {
time.Sleep(100 * time.Millisecond)
listener.Close()
}()
err = srv.Serve(listener)
assert.True(t, errors.Is(err, net.ErrClosed))
}
func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
msr := new(MockSessionRegistry)
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{})
ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
err := srv.Serve(ml)
assert.True(t, errors.Is(err, net.ErrClosed))
ml.AssertExpectations(t)
}
func TestHTTPSServer_Serve_Success(t *testing.T) {
msr := new(MockSessionRegistry)
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{})
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
go func() {
_ = srv.Serve(listener)
}()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
assert.NoError(t, err)
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
time.Sleep(100 * time.Millisecond)
conn.Close()
listener.Close()
}
+3 -3
View File
@@ -12,16 +12,16 @@ import (
type tcp struct {
port uint16
forwarder forwarder
forwarder Forwarder
}
type forwarder interface {
type Forwarder interface {
CreateForwardedTCPIPPayload(origin net.Addr) []byte
OpenForwardedChannel(payload []byte) (ssh.Channel, <-chan *ssh.Request, error)
HandleConnection(dst io.ReadWriter, src ssh.Channel)
}
func NewTCPServer(port uint16, forwarder forwarder) Transport {
func NewTCPServer(port uint16, forwarder Forwarder) Transport {
return &tcp{
port: port,
forwarder: forwarder,
+146
View File
@@ -0,0 +1,146 @@
package transport
import (
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.org/x/crypto/ssh"
)
func TestNewTCPServer(t *testing.T) {
mf := new(MockForwarder)
port := uint16(9000)
srv := NewTCPServer(port, mf)
assert.NotNil(t, srv)
tcpSrv, ok := srv.(*tcp)
assert.True(t, ok)
assert.Equal(t, port, tcpSrv.port)
assert.Equal(t, mf, tcpSrv.forwarder)
}
func TestTCPServer_Listen(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
listener, err := srv.Listen()
assert.NoError(t, err)
assert.NotNil(t, listener)
listener.Close()
}
func TestTCPServer_Serve(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
go func() {
time.Sleep(100 * time.Millisecond)
listener.Close()
}()
err = srv.Serve(listener)
assert.Nil(t, err)
}
func TestTCPServer_Serve_AcceptError(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
ml.On("Accept").Return(nil, net.ErrClosed).Once()
err := srv.Serve(ml)
assert.Nil(t, err)
ml.AssertExpectations(t)
}
func TestTCPServer_Serve_Success(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf)
listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
payload := []byte("test-payload")
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
reqs := make(chan *ssh.Request)
mf.On("OpenForwardedChannel", payload).Return(new(MockSSHChannel), (<-chan *ssh.Request)(reqs), nil)
mf.On("HandleConnection", mock.Anything, mock.Anything).Return()
go func() {
_ = srv.Serve(listener)
}()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
conn.Close()
listener.Close()
mf.AssertExpectations(t)
}
func TestTCPServer_handleTcp_Success(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf).(*tcp)
serverConn, clientConn := net.Pipe()
defer clientConn.Close()
payload := []byte("test-payload")
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
reqs := make(chan *ssh.Request)
mockChannel := new(MockSSHChannel)
mf.On("OpenForwardedChannel", payload).Return(mockChannel, (<-chan *ssh.Request)(reqs), nil)
mf.On("HandleConnection", serverConn, mockChannel).Return()
srv.handleTcp(serverConn)
mf.AssertExpectations(t)
}
func TestTCPServer_handleTcp_CloseError(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf).(*tcp)
mc := new(MockConn)
mc.On("Close").Return(errors.New("close error"))
mc.On("RemoteAddr").Return(&net.TCPAddr{})
payload := []byte("test-payload")
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
srv.handleTcp(mc)
mc.AssertExpectations(t)
}
func TestTCPServer_handleTcp_OpenChannelError(t *testing.T) {
mf := new(MockForwarder)
srv := NewTCPServer(0, mf).(*tcp)
serverConn, clientConn := net.Pipe()
defer clientConn.Close()
payload := []byte("test-payload")
mf.On("CreateForwardedTCPIPPayload", mock.Anything).Return(payload)
mf.On("OpenForwardedChannel", payload).Return(nil, (<-chan *ssh.Request)(nil), errors.New("open error"))
srv.handleTcp(serverConn)
mf.AssertExpectations(t)
}
+178
View File
@@ -0,0 +1,178 @@
package transport
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"testing"
"time"
"tunnel_pls/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
type MockConfig struct {
mock.Mock
}
func (m *MockConfig) Domain() string { return m.Called().String(0) }
func (m *MockConfig) SSHPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPPort() string { return m.Called().String(0) }
func (m *MockConfig) HTTPSPort() string { return m.Called().String(0) }
func (m *MockConfig) TLSEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) TLSRedirect() bool { return m.Called().Bool(0) }
func (m *MockConfig) ACMEEmail() string { return m.Called().String(0) }
func (m *MockConfig) CFAPIToken() string { return m.Called().String(0) }
func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
func (m *MockConfig) GRPCAddress() string { return m.Called().String(0) }
func (m *MockConfig) GRPCPort() string { return m.Called().String(0) }
func (m *MockConfig) NodeToken() string { return m.Called().String(0) }
func TestValidateCertDomains_NotFound(t *testing.T) {
result := ValidateCertDomains("nonexistent.pem", "example.com")
assert.False(t, result)
}
func TestValidateCertDomains_InvalidPEM(t *testing.T) {
tmpFile, err := os.CreateTemp("", "invalid*.pem")
assert.NoError(t, err)
defer os.Remove(tmpFile.Name())
_, _ = tmpFile.WriteString("not a pem")
tmpFile.Close()
result := ValidateCertDomains(tmpFile.Name(), "example.com")
assert.False(t, result)
}
func TestTLSManager_getTLSConfig(t *testing.T) {
tm := &tlsManager{
useCertMagic: false,
}
cfg := tm.getTLSConfig()
assert.NotNil(t, cfg)
assert.Equal(t, uint16(tls.VersionTLS13), cfg.MinVersion)
}
func TestTLSManager_getCertificate_Magic(t *testing.T) {
tm := &tlsManager{
useCertMagic: true,
}
hello := &tls.ClientHelloInfo{}
assert.Panics(t, func() {
_, _ = tm.getCertificate(hello)
})
}
func TestTLSManager_userCertsExistAndValid(t *testing.T) {
tm := &tlsManager{
certPath: "nonexistent.pem",
keyPath: "nonexistent.key",
}
assert.False(t, tm.userCertsExistAndValid())
keyFile, _ := os.CreateTemp("", "key*.pem")
defer os.Remove(keyFile.Name())
tm.keyPath = keyFile.Name()
assert.False(t, tm.userCertsExistAndValid())
}
func createTestCert(t *testing.T, domain string, wildcard bool, expired bool, soon bool) (string, string) {
priv, _ := rsa.GenerateKey(rand.Reader, 2048)
notAfter := time.Now().Add(365 * 24 * time.Hour)
if expired {
notAfter = time.Now().Add(-24 * time.Hour)
} else if soon {
notAfter = time.Now().Add(15 * 24 * time.Hour)
}
template := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: domain,
},
NotBefore: time.Now().Add(-24 * time.Hour),
NotAfter: notAfter,
DNSNames: []string{domain},
}
if wildcard {
template.DNSNames = append(template.DNSNames, "*."+domain)
}
derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
assert.NoError(t, err)
certOut, _ := os.CreateTemp("", "cert*.pem")
_ = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
certOut.Close()
keyOut, _ := os.CreateTemp("", "key*.pem")
_ = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
keyOut.Close()
return certOut.Name(), keyOut.Name()
}
func TestValidateCertDomains_Success(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
result := ValidateCertDomains(certPath, "example.com")
assert.True(t, result)
}
func TestValidateCertDomains_Expired(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, true, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
result := ValidateCertDomains(certPath, "example.com")
assert.False(t, result)
}
func TestValidateCertDomains_ExpiringSoon(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, true)
defer os.Remove(certPath)
defer os.Remove(keyPath)
result := ValidateCertDomains(certPath, "example.com")
assert.False(t, result)
}
func TestValidateCertDomains_MissingWildcard(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", false, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
result := ValidateCertDomains(certPath, "example.com")
assert.False(t, result)
}
func TestTLSManager_loadUserCerts_Success(t *testing.T) {
certPath, keyPath := createTestCert(t, "example.com", true, false, false)
defer os.Remove(certPath)
defer os.Remove(keyPath)
tm := &tlsManager{
certPath: certPath,
keyPath: keyPath,
}
err := tm.loadUserCerts()
assert.NoError(t, err)
assert.NotNil(t, tm.userCert)
}
+4
View File
@@ -8,3 +8,7 @@ type Transport interface {
Listen() (net.Listener, error)
Serve(listener net.Listener) error
}
type HTTP interface {
Handler(conn net.Conn, isTLS bool)
}
+15
View File
@@ -0,0 +1,15 @@
package transport
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestTransportInterface(t *testing.T) {
var _ Transport = (*httpServer)(nil)
var _ Transport = (*https)(nil)
var _ Transport = (*tcp)(nil)
assert.True(t, true)
}