revert-54069ad305 #11

Closed
bagas wants to merge 217 commits from revert-54069ad305 into main
11 changed files with 182 additions and 64 deletions
Showing only changes of commit 79fd292a77 - Show all commits
+2 -2
View File
@@ -98,7 +98,7 @@ func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, err
} }
func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) { func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), registry, conf.TLSRedirect()) httpserver := transport.NewHTTPServer(conf, registry)
ln, err := httpserver.Listen() ln, err := httpserver.Listen()
if err != nil { if err != nil {
errChan <- fmt.Errorf("failed to start http server: %w", err) errChan <- fmt.Errorf("failed to start http server: %w", err)
@@ -115,7 +115,7 @@ func startHTTPSServer(conf config.Config, registry registry.Registry, errChan ch
errChan <- fmt.Errorf("failed to create TLS config: %w", err) errChan <- fmt.Errorf("failed to create TLS config: %w", err)
return return
} }
httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), registry, conf.TLSRedirect(), tlsCfg) httpsServer := transport.NewHTTPSServer(conf, registry, tlsCfg)
ln, err := httpsServer.Listen() ln, err := httpsServer.Listen()
if err != nil { if err != nil {
errChan <- fmt.Errorf("failed to create TLS config: %w", err) errChan <- fmt.Errorf("failed to create TLS config: %w", err)
+1
View File
@@ -131,6 +131,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode { func (m *MockConfig) Mode() types.ServerMode {
+1
View File
@@ -887,6 +887,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
+7 -6
View File
@@ -4,27 +4,28 @@ import (
"errors" "errors"
"log" "log"
"net" "net"
"tunnel_pls/internal/config"
"tunnel_pls/internal/registry" "tunnel_pls/internal/registry"
) )
type httpServer struct { type httpServer struct {
handler *httpHandler handler *httpHandler
port string config config.Config
} }
func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { func NewHTTPServer(config config.Config, sessionRegistry registry.Registry) Transport {
return &httpServer{ return &httpServer{
handler: newHTTPHandler(domain, sessionRegistry, redirectTLS), handler: newHTTPHandler(config, sessionRegistry),
port: port, config: config,
} }
} }
func (ht *httpServer) Listen() (net.Listener, error) { func (ht *httpServer) Listen() (net.Listener, error) {
return net.Listen("tcp", ":"+ht.port) return net.Listen("tcp", ":"+ht.config.HTTPPort())
} }
func (ht *httpServer) Serve(listener net.Listener) error { func (ht *httpServer) Serve(listener net.Listener) error {
log.Printf("HTTP server is starting on port %s", ht.port) log.Printf("HTTP server is starting on port %s", ht.config.HTTPPort())
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
+30 -13
View File
@@ -13,24 +13,27 @@ import (
func TestNewHTTPServer(t *testing.T) { func TestNewHTTPServer(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
domain := "example.com" mockConfig := &MockConfig{}
port := "8080" port := "0"
redirectTLS := true mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(domain, port, msr, redirectTLS) srv := NewHTTPServer(mockConfig, msr)
assert.NotNil(t, srv) assert.NotNil(t, srv)
httpSrv, ok := srv.(*httpServer) httpSrv, ok := srv.(*httpServer)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, port, httpSrv.port)
assert.Equal(t, domain, httpSrv.handler.domain)
assert.Equal(t, msr, httpSrv.handler.sessionRegistry) assert.Equal(t, msr, httpSrv.handler.sessionRegistry)
assert.Equal(t, redirectTLS, httpSrv.handler.redirectTLS) assert.NotNil(t, srv)
} }
func TestHTTPServer_Listen(t *testing.T) { func TestHTTPServer_Listen(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
srv := NewHTTPServer("example.com", "0", msr, false) mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
listener, err := srv.Listen() listener, err := srv.Listen()
assert.NoError(t, err) assert.NoError(t, err)
@@ -40,7 +43,11 @@ func TestHTTPServer_Listen(t *testing.T) {
func TestHTTPServer_Serve(t *testing.T) { func TestHTTPServer_Serve(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
srv := NewHTTPServer("example.com", "0", msr, false) mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err) assert.NoError(t, err)
@@ -56,7 +63,11 @@ func TestHTTPServer_Serve(t *testing.T) {
func TestHTTPServer_Serve_AcceptError(t *testing.T) { func TestHTTPServer_Serve_AcceptError(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
srv := NewHTTPServer("example.com", "0", msr, false) mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
srv := NewHTTPServer(mockConfig, msr)
ml := new(mockListener) ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once() ml.On("Accept").Return(nil, errors.New("accept error")).Once()
@@ -69,17 +80,23 @@ func TestHTTPServer_Serve_AcceptError(t *testing.T) {
func TestHTTPServer_Serve_Success(t *testing.T) { func TestHTTPServer_Serve_Success(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
srv := NewHTTPServer("example.com", "0", msr, false) mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
mockConfig.On("HeaderSize").Return(4096)
mockConfig.On("TLSRedirect").Return(false)
srv := NewHTTPServer(mockConfig, msr)
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err) assert.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port listenerport := listener.Addr().(*net.TCPAddr).Port
go func() { go func() {
_ = srv.Serve(listener) _ = srv.Serve(listener)
}() }()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
assert.NoError(t, err) assert.NoError(t, err)
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n")) _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
+25 -11
View File
@@ -1,7 +1,7 @@
package transport package transport
import ( import (
"bufio" "bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@@ -11,6 +11,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"time" "time"
"tunnel_pls/internal/config"
"tunnel_pls/internal/http/header" "tunnel_pls/internal/http/header"
"tunnel_pls/internal/http/stream" "tunnel_pls/internal/http/stream"
"tunnel_pls/internal/middleware" "tunnel_pls/internal/middleware"
@@ -21,16 +22,14 @@ import (
) )
type httpHandler struct { type httpHandler struct {
domain string config config.Config
sessionRegistry registry.Registry sessionRegistry registry.Registry
redirectTLS bool
} }
func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler {
return &httpHandler{ return &httpHandler{
domain: domain, config: config,
sessionRegistry: sessionRegistry, sessionRegistry: sessionRegistry,
redirectTLS: redirectTLS,
} }
} }
@@ -56,10 +55,25 @@ func (hh *httpHandler) badRequest(conn net.Conn) error {
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) { func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
defer hh.closeConnection(conn) defer hh.closeConnection(conn)
dstReader := bufio.NewReader(conn) _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
reqhf, err := header.NewRequest(dstReader) buf := make([]byte, hh.config.HeaderSize())
n, err := conn.Read(buf)
if err != nil {
_ = hh.badRequest(conn)
return
}
if idx := bytes.Index(buf[:n], []byte("\r\n\r\n")); idx == -1 {
_ = hh.badRequest(conn)
return
}
_ = conn.SetReadDeadline(time.Time{})
reqhf, err := header.NewRequest(buf[:n])
if err != nil { if err != nil {
log.Printf("Error creating request header: %v", err) log.Printf("Error creating request header: %v", err)
_ = hh.badRequest(conn)
return return
} }
@@ -70,7 +84,7 @@ func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
} }
if hh.shouldRedirectToTLS(isTLS) { if hh.shouldRedirectToTLS(isTLS) {
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.domain)) _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.config.Domain()))
return return
} }
@@ -87,7 +101,7 @@ func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
return return
} }
hw := stream.New(conn, dstReader, conn.RemoteAddr()) hw := stream.New(conn, conn, conn.RemoteAddr())
defer func(hw stream.HTTP) { defer func(hw stream.HTTP) {
err = hw.Close() err = hw.Close()
if err != nil { if err != nil {
@@ -113,7 +127,7 @@ func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
} }
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool { func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
return !isTLS && hh.redirectTLS return !isTLS && hh.config.TLSRedirect()
} }
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
+77 -10
View File
@@ -230,11 +230,30 @@ func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, payload []byte
} }
type MockConn struct { type MockConn struct {
net.Conn
mock.Mock mock.Mock
ReadBuffer *bytes.Buffer ReadBuffer *bytes.Buffer
} }
func (m *MockConn) LocalAddr() net.Addr {
args := m.Called()
return args.Get(0).(net.Addr)
}
func (m *MockConn) SetDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetReadDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) SetWriteDeadline(t time.Time) error {
args := m.Called(t)
return args.Error(0)
}
func (m *MockConn) Read(b []byte) (n int, err error) { func (m *MockConn) Read(b []byte) (n int, err error) {
if m.ReadBuffer != nil { if m.ReadBuffer != nil {
return m.ReadBuffer.Read(b) return m.ReadBuffer.Read(b)
@@ -245,6 +264,9 @@ func (m *MockConn) Read(b []byte) (n int, err error) {
func (m *MockConn) Write(b []byte) (n int, err error) { func (m *MockConn) Write(b []byte) (n int, err error) {
args := m.Called(b) args := m.Called(b)
if args.Int(0) == -1 {
return len(b), args.Error(1)
}
return args.Int(0), args.Error(1) return args.Int(0), args.Error(1)
} }
@@ -269,11 +291,12 @@ func (c *wrappedConn) RemoteAddr() net.Addr {
func TestNewHTTPHandler(t *testing.T) { func TestNewHTTPHandler(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
hh := newHTTPHandler("domain", msr, true) mockConfig := &MockConfig{}
mockConfig.On("Domain").Return("domain")
mockConfig.On("TLSRedirect").Return(false)
hh := newHTTPHandler(mockConfig, msr)
assert.NotNil(t, hh) assert.NotNil(t, hh)
assert.Equal(t, "domain", hh.domain)
assert.Equal(t, msr, hh.sessionRegistry) assert.Equal(t, msr, hh.sessionRegistry)
assert.True(t, hh.redirectTLS)
} }
func TestHandler(t *testing.T) { func TestHandler(t *testing.T) {
@@ -318,8 +341,8 @@ func TestHandler(t *testing.T) {
name: "redirect to TLS", name: "redirect to TLS",
isTLS: false, isTLS: false,
redirectTLS: true, redirectTLS: true,
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"), request: []byte("GET / HTTP/1.1\r\nHost: tunnel.example.com\r\n\r\n"),
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://example.domain/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"), expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnel.example.com/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) { setupMocks: func(msr *MockSessionRegistry) {
}, },
}, },
@@ -350,7 +373,25 @@ func TestHandler(t *testing.T) {
isTLS: false, isTLS: false,
redirectTLS: false, redirectTLS: false,
request: []byte("INVALID\r\n\r\n"), request: []byte("INVALID\r\n\r\n"),
expected: []byte(""), expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "bad request - header too large",
isTLS: false,
redirectTLS: false,
request: []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: test.domain\r\n%s\r\n\r\n", strings.Repeat("test", 10000))),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) {
},
},
{
name: "bad request - no request",
isTLS: false,
redirectTLS: false,
request: []byte(""),
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
setupMocks: func(msr *MockSessionRegistry) { setupMocks: func(msr *MockSessionRegistry) {
}, },
}, },
@@ -453,7 +494,8 @@ func TestHandler(t *testing.T) {
setupConn: func() (net.Conn, net.Conn) { setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn) mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n")) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(-1, fmt.Errorf("write error"))
mc.On("Close").Return(nil) mc.On("Close").Return(nil)
return mc, nil return mc, nil
}, },
@@ -467,11 +509,27 @@ func TestHandler(t *testing.T) {
setupConn: func() (net.Conn, net.Conn) { setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn) mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n")) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Close").Return(nil) mc.On("Close").Return(nil)
return mc, nil return mc, nil
}, },
}, },
{
name: "read error - connection failure",
isTLS: false,
redirectTLS: false,
request: []byte(""),
expected: []byte(""),
setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn)
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Read", mock.Anything).Return(0, fmt.Errorf("connection reset by peer"))
mc.On("Close").Return(nil)
return mc, nil
},
},
{ {
name: "handle ping request - write failure", name: "handle ping request - write failure",
isTLS: true, isTLS: true,
@@ -481,6 +539,7 @@ func TestHandler(t *testing.T) {
setupConn: func() (net.Conn, net.Conn) { setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn) mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n")) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
mc.On("Close").Return(nil) mc.On("Close").Return(nil)
return mc, nil return mc, nil
@@ -495,6 +554,7 @@ func TestHandler(t *testing.T) {
setupConn: func() (net.Conn, net.Conn) { setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn) mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n")) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Write", mock.Anything).Return(182, nil) mc.On("Write", mock.Anything).Return(182, nil)
mc.On("Close").Return(fmt.Errorf("close error")) mc.On("Close").Return(fmt.Errorf("close error"))
return mc, nil return mc, nil
@@ -527,6 +587,7 @@ func TestHandler(t *testing.T) {
setupConn: func() (net.Conn, net.Conn) { setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn) mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n")) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2) mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
mc.On("RemoteAddr").Return(addr) mc.On("RemoteAddr").Return(addr)
@@ -557,6 +618,7 @@ func TestHandler(t *testing.T) {
setupConn: func() (net.Conn, net.Conn) { setupConn: func() (net.Conn, net.Conn) {
mc := new(MockConn) mc := new(MockConn)
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n")) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
mc.On("SetReadDeadline", mock.Anything).Return(nil)
mc.On("Close").Return(nil).Times(2) mc.On("Close").Return(nil).Times(2)
mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")}) mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
return mc, nil return mc, nil
@@ -615,10 +677,15 @@ func TestHandler(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
mockSessionRegistry := new(MockSessionRegistry) mockSessionRegistry := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return("example.com")
mockConfig.On("HTTPPort").Return(port)
mockConfig.On("HeaderSize").Return(4096)
mockConfig.On("TLSRedirect").Return(true)
hh := &httpHandler{ hh := &httpHandler{
domain: "domain",
sessionRegistry: mockSessionRegistry, sessionRegistry: mockSessionRegistry,
redirectTLS: tt.redirectTLS, config: mockConfig,
} }
if tt.setupMocks != nil { if tt.setupMocks != nil {
+7 -8
View File
@@ -5,31 +5,30 @@ import (
"errors" "errors"
"log" "log"
"net" "net"
"tunnel_pls/internal/config"
"tunnel_pls/internal/registry" "tunnel_pls/internal/registry"
) )
type https struct { type https struct {
config config.Config
tlsConfig *tls.Config tlsConfig *tls.Config
httpHandler *httpHandler httpHandler *httpHandler
domain string
port string
} }
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport { func NewHTTPSServer(config config.Config, sessionRegistry registry.Registry, tlsConfig *tls.Config) Transport {
return &https{ return &https{
config: config,
tlsConfig: tlsConfig, tlsConfig: tlsConfig,
httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS), httpHandler: newHTTPHandler(config, sessionRegistry),
domain: domain,
port: port,
} }
} }
func (ht *https) Listen() (net.Listener, error) { func (ht *https) Listen() (net.Listener, error) {
return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig) return tls.Listen("tcp", ":"+ht.config.HTTPSPort(), ht.tlsConfig)
} }
func (ht *https) Serve(listener net.Listener) error { func (ht *https) Serve(listener net.Listener) error {
log.Printf("HTTPS server is starting on port %s", ht.port) log.Printf("HTTPS server is starting on port %s", ht.config.HTTPSPort())
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
+30 -14
View File
@@ -13,31 +13,32 @@ import (
func TestNewHTTPSServer(t *testing.T) { func TestNewHTTPSServer(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
domain := "example.com" mockConfig := &MockConfig{}
port := "443" port := "0"
redirectTLS := false
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
mockConfig.On("Domain").Return(mockConfig)
srv := NewHTTPSServer(domain, port, msr, redirectTLS, tlsConfig) mockConfig.On("HTTPSPort").Return(port)
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
assert.NotNil(t, srv) assert.NotNil(t, srv)
httpsSrv, ok := srv.(*https) httpsSrv, ok := srv.(*https)
assert.True(t, ok) assert.True(t, ok)
assert.Equal(t, port, httpsSrv.port)
assert.Equal(t, domain, httpsSrv.domain)
assert.Equal(t, tlsConfig, httpsSrv.tlsConfig) assert.Equal(t, tlsConfig, httpsSrv.tlsConfig)
assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry) assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry)
} }
func TestHTTPSServer_Listen(t *testing.T) { func TestHTTPSServer_Listen(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return nil, nil return nil, nil
}, },
} }
srv := NewHTTPSServer("example.com", "0", msr, false, tlsConfig) srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
listener, err := srv.Listen() listener, err := srv.Listen()
if err != nil { if err != nil {
@@ -50,7 +51,11 @@ func TestHTTPSServer_Listen(t *testing.T) {
func TestHTTPSServer_Serve(t *testing.T) { func TestHTTPSServer_Serve(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{}) mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err) assert.NoError(t, err)
@@ -66,7 +71,12 @@ func TestHTTPSServer_Serve(t *testing.T) {
func TestHTTPSServer_Serve_AcceptError(t *testing.T) { func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{})
mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
ml := new(mockListener) ml := new(mockListener)
ml.On("Accept").Return(nil, errors.New("accept error")).Once() ml.On("Accept").Return(nil, errors.New("accept error")).Once()
@@ -79,17 +89,23 @@ func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
func TestHTTPSServer_Serve_Success(t *testing.T) { func TestHTTPSServer_Serve_Success(t *testing.T) {
msr := new(MockSessionRegistry) msr := new(MockSessionRegistry)
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{}) mockConfig := &MockConfig{}
port := "0"
mockConfig.On("Domain").Return(mockConfig)
mockConfig.On("HTTPSPort").Return(port)
mockConfig.On("HeaderSize").Return(4096)
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
assert.NoError(t, err) assert.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port listenerport := listener.Addr().(*net.TCPAddr).Port
go func() { go func() {
_ = srv.Serve(listener) _ = srv.Serve(listener)
}() }()
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
assert.NoError(t, err) assert.NoError(t, err)
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n")) _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
+1
View File
@@ -36,6 +36,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
+1
View File
@@ -45,6 +45,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
func (m *MockConfig) PprofPort() string { return m.Called().String(0) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
func (m *MockConfig) Mode() types.ServerMode { func (m *MockConfig) Mode() types.ServerMode {