From 21b551a66d6be24357da26f8a334092c8f756872 Mon Sep 17 00:00:00 2001 From: bagas Date: Sun, 25 Jan 2026 18:47:54 +0700 Subject: [PATCH] feat(http): add http header size limit for initial request --- internal/bootstrap/bootstrap.go | 4 +- internal/bootstrap/bootstrap_test.go | 1 + internal/grpc/client/client_test.go | 1 + internal/transport/http.go | 13 ++-- internal/transport/http_test.go | 43 +++++++++---- internal/transport/httphandler.go | 36 +++++++---- internal/transport/httphandler_test.go | 87 +++++++++++++++++++++++--- internal/transport/https.go | 15 +++-- internal/transport/https_test.go | 44 ++++++++----- internal/transport/tls_test.go | 1 + server/server_test.go | 1 + 11 files changed, 182 insertions(+), 64 deletions(-) diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index bd8645f..6c3a1f9 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -98,7 +98,7 @@ func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, err } func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) { - httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), registry, conf.TLSRedirect()) + httpserver := transport.NewHTTPServer(conf, registry) ln, err := httpserver.Listen() if err != nil { errChan <- fmt.Errorf("failed to start http server: %w", err) @@ -115,7 +115,7 @@ func startHTTPSServer(conf config.Config, registry registry.Registry, errChan ch errChan <- fmt.Errorf("failed to create TLS config: %w", err) return } - httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), registry, conf.TLSRedirect(), tlsCfg) + httpsServer := transport.NewHTTPSServer(conf, registry, tlsCfg) ln, err := httpsServer.Listen() if err != nil { errChan <- fmt.Errorf("failed to create TLS config: %w", err) diff --git a/internal/bootstrap/bootstrap_test.go b/internal/bootstrap/bootstrap_test.go index 13778c7..6586ea3 100644 --- a/internal/bootstrap/bootstrap_test.go +++ b/internal/bootstrap/bootstrap_test.go @@ -131,6 +131,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) } func (m *MockConfig) Mode() types.ServerMode { diff --git a/internal/grpc/client/client_test.go b/internal/grpc/client/client_test.go index 3964cf0..f19a0b9 100644 --- a/internal/grpc/client/client_test.go +++ b/internal/grpc/client/client_test.go @@ -887,6 +887,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) } func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } diff --git a/internal/transport/http.go b/internal/transport/http.go index 8ea5c4d..5c4648d 100644 --- a/internal/transport/http.go +++ b/internal/transport/http.go @@ -4,27 +4,28 @@ import ( "errors" "log" "net" + "tunnel_pls/internal/config" "tunnel_pls/internal/registry" ) type httpServer struct { handler *httpHandler - port string + config config.Config } -func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport { +func NewHTTPServer(config config.Config, sessionRegistry registry.Registry) Transport { return &httpServer{ - handler: newHTTPHandler(domain, sessionRegistry, redirectTLS), - port: port, + handler: newHTTPHandler(config, sessionRegistry), + config: config, } } func (ht *httpServer) Listen() (net.Listener, error) { - return net.Listen("tcp", ":"+ht.port) + return net.Listen("tcp", ":"+ht.config.HTTPPort()) } func (ht *httpServer) Serve(listener net.Listener) error { - log.Printf("HTTP server is starting on port %s", ht.port) + log.Printf("HTTP server is starting on port %s", ht.config.HTTPPort()) for { conn, err := listener.Accept() if err != nil { diff --git a/internal/transport/http_test.go b/internal/transport/http_test.go index 348dbb0..3922eb0 100644 --- a/internal/transport/http_test.go +++ b/internal/transport/http_test.go @@ -13,24 +13,27 @@ import ( func TestNewHTTPServer(t *testing.T) { msr := new(MockSessionRegistry) - domain := "example.com" - port := "8080" - redirectTLS := true + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) - srv := NewHTTPServer(domain, port, msr, redirectTLS) + srv := NewHTTPServer(mockConfig, msr) assert.NotNil(t, srv) httpSrv, ok := srv.(*httpServer) assert.True(t, ok) - assert.Equal(t, port, httpSrv.port) - assert.Equal(t, domain, httpSrv.handler.domain) assert.Equal(t, msr, httpSrv.handler.sessionRegistry) - assert.Equal(t, redirectTLS, httpSrv.handler.redirectTLS) + assert.NotNil(t, srv) } func TestHTTPServer_Listen(t *testing.T) { msr := new(MockSessionRegistry) - srv := NewHTTPServer("example.com", "0", msr, false) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + srv := NewHTTPServer(mockConfig, msr) listener, err := srv.Listen() assert.NoError(t, err) @@ -40,7 +43,11 @@ func TestHTTPServer_Listen(t *testing.T) { func TestHTTPServer_Serve(t *testing.T) { msr := new(MockSessionRegistry) - srv := NewHTTPServer("example.com", "0", msr, false) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + srv := NewHTTPServer(mockConfig, msr) listener, err := net.Listen("tcp", "127.0.0.1:0") assert.NoError(t, err) @@ -56,7 +63,11 @@ func TestHTTPServer_Serve(t *testing.T) { func TestHTTPServer_Serve_AcceptError(t *testing.T) { msr := new(MockSessionRegistry) - srv := NewHTTPServer("example.com", "0", msr, false) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + srv := NewHTTPServer(mockConfig, msr) ml := new(mockListener) ml.On("Accept").Return(nil, errors.New("accept error")).Once() @@ -69,17 +80,23 @@ func TestHTTPServer_Serve_AcceptError(t *testing.T) { func TestHTTPServer_Serve_Success(t *testing.T) { msr := new(MockSessionRegistry) - srv := NewHTTPServer("example.com", "0", msr, false) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + mockConfig.On("HeaderSize").Return(4096) + mockConfig.On("TLSRedirect").Return(false) + srv := NewHTTPServer(mockConfig, msr) listener, err := net.Listen("tcp", "127.0.0.1:0") assert.NoError(t, err) - port := listener.Addr().(*net.TCPAddr).Port + listenerport := listener.Addr().(*net.TCPAddr).Port go func() { _ = srv.Serve(listener) }() - conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport)) assert.NoError(t, err) _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n")) diff --git a/internal/transport/httphandler.go b/internal/transport/httphandler.go index 411a3f9..7e2135a 100644 --- a/internal/transport/httphandler.go +++ b/internal/transport/httphandler.go @@ -1,7 +1,7 @@ package transport import ( - "bufio" + "bytes" "context" "errors" "fmt" @@ -11,6 +11,7 @@ import ( "net/http" "strings" "time" + "tunnel_pls/internal/config" "tunnel_pls/internal/http/header" "tunnel_pls/internal/http/stream" "tunnel_pls/internal/middleware" @@ -21,16 +22,14 @@ import ( ) type httpHandler struct { - domain string + config config.Config sessionRegistry registry.Registry - redirectTLS bool } -func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler { +func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler { return &httpHandler{ - domain: domain, + config: config, sessionRegistry: sessionRegistry, - redirectTLS: redirectTLS, } } @@ -56,10 +55,25 @@ func (hh *httpHandler) badRequest(conn net.Conn) error { func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) { defer hh.closeConnection(conn) - dstReader := bufio.NewReader(conn) - reqhf, err := header.NewRequest(dstReader) + _ = conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + buf := make([]byte, hh.config.HeaderSize()) + n, err := conn.Read(buf) + if err != nil { + _ = hh.badRequest(conn) + return + } + + if idx := bytes.Index(buf[:n], []byte("\r\n\r\n")); idx == -1 { + _ = hh.badRequest(conn) + return + } + + _ = conn.SetReadDeadline(time.Time{}) + + reqhf, err := header.NewRequest(buf[:n]) if err != nil { log.Printf("Error creating request header: %v", err) + _ = hh.badRequest(conn) return } @@ -70,7 +84,7 @@ func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) { } if hh.shouldRedirectToTLS(isTLS) { - _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.domain)) + _ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.config.Domain())) return } @@ -87,7 +101,7 @@ func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) { return } - hw := stream.New(conn, dstReader, conn.RemoteAddr()) + hw := stream.New(conn, conn, conn.RemoteAddr()) defer func(hw stream.HTTP) { err = hw.Close() if err != nil { @@ -113,7 +127,7 @@ func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) { } func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool { - return !isTLS && hh.redirectTLS + return !isTLS && hh.config.TLSRedirect() } func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool { diff --git a/internal/transport/httphandler_test.go b/internal/transport/httphandler_test.go index b4e592a..b6df77f 100644 --- a/internal/transport/httphandler_test.go +++ b/internal/transport/httphandler_test.go @@ -230,11 +230,30 @@ func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, payload []byte } type MockConn struct { - net.Conn mock.Mock ReadBuffer *bytes.Buffer } +func (m *MockConn) LocalAddr() net.Addr { + args := m.Called() + return args.Get(0).(net.Addr) +} + +func (m *MockConn) SetDeadline(t time.Time) error { + args := m.Called(t) + return args.Error(0) +} + +func (m *MockConn) SetReadDeadline(t time.Time) error { + args := m.Called(t) + return args.Error(0) +} + +func (m *MockConn) SetWriteDeadline(t time.Time) error { + args := m.Called(t) + return args.Error(0) +} + func (m *MockConn) Read(b []byte) (n int, err error) { if m.ReadBuffer != nil { return m.ReadBuffer.Read(b) @@ -245,6 +264,9 @@ func (m *MockConn) Read(b []byte) (n int, err error) { func (m *MockConn) Write(b []byte) (n int, err error) { args := m.Called(b) + if args.Int(0) == -1 { + return len(b), args.Error(1) + } return args.Int(0), args.Error(1) } @@ -269,11 +291,12 @@ func (c *wrappedConn) RemoteAddr() net.Addr { func TestNewHTTPHandler(t *testing.T) { msr := new(MockSessionRegistry) - hh := newHTTPHandler("domain", msr, true) + mockConfig := &MockConfig{} + mockConfig.On("Domain").Return("domain") + mockConfig.On("TLSRedirect").Return(false) + hh := newHTTPHandler(mockConfig, msr) assert.NotNil(t, hh) - assert.Equal(t, "domain", hh.domain) assert.Equal(t, msr, hh.sessionRegistry) - assert.True(t, hh.redirectTLS) } func TestHandler(t *testing.T) { @@ -318,8 +341,8 @@ func TestHandler(t *testing.T) { name: "redirect to TLS", isTLS: false, redirectTLS: true, - request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"), - expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://example.domain/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"), + request: []byte("GET / HTTP/1.1\r\nHost: tunnel.example.com\r\n\r\n"), + expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnel.example.com/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"), setupMocks: func(msr *MockSessionRegistry) { }, }, @@ -350,7 +373,25 @@ func TestHandler(t *testing.T) { isTLS: false, redirectTLS: false, request: []byte("INVALID\r\n\r\n"), - expected: []byte(""), + expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "bad request - header too large", + isTLS: false, + redirectTLS: false, + request: []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: test.domain\r\n%s\r\n\r\n", strings.Repeat("test", 10000))), + expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"), + setupMocks: func(msr *MockSessionRegistry) { + }, + }, + { + name: "bad request - no request", + isTLS: false, + redirectTLS: false, + request: []byte(""), + expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"), setupMocks: func(msr *MockSessionRegistry) { }, }, @@ -453,7 +494,8 @@ func TestHandler(t *testing.T) { setupConn: func() (net.Conn, net.Conn) { mc := new(MockConn) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n")) - mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) + mc.On("Write", mock.Anything).Return(-1, fmt.Errorf("write error")) mc.On("Close").Return(nil) return mc, nil }, @@ -467,11 +509,27 @@ func TestHandler(t *testing.T) { setupConn: func() (net.Conn, net.Conn) { mc := new(MockConn) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) mc.On("Close").Return(nil) return mc, nil }, }, + { + name: "read error - connection failure", + isTLS: false, + redirectTLS: false, + request: []byte(""), + expected: []byte(""), + setupConn: func() (net.Conn, net.Conn) { + mc := new(MockConn) + mc.On("SetReadDeadline", mock.Anything).Return(nil) + mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) + mc.On("Read", mock.Anything).Return(0, fmt.Errorf("connection reset by peer")) + mc.On("Close").Return(nil) + return mc, nil + }, + }, { name: "handle ping request - write failure", isTLS: true, @@ -481,6 +539,7 @@ func TestHandler(t *testing.T) { setupConn: func() (net.Conn, net.Conn) { mc := new(MockConn) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error")) mc.On("Close").Return(nil) return mc, nil @@ -495,6 +554,7 @@ func TestHandler(t *testing.T) { setupConn: func() (net.Conn, net.Conn) { mc := new(MockConn) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) mc.On("Write", mock.Anything).Return(182, nil) mc.On("Close").Return(fmt.Errorf("close error")) return mc, nil @@ -527,6 +587,7 @@ func TestHandler(t *testing.T) { setupConn: func() (net.Conn, net.Conn) { mc := new(MockConn) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2) addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345") mc.On("RemoteAddr").Return(addr) @@ -557,6 +618,7 @@ func TestHandler(t *testing.T) { setupConn: func() (net.Conn, net.Conn) { mc := new(MockConn) mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n")) + mc.On("SetReadDeadline", mock.Anything).Return(nil) mc.On("Close").Return(nil).Times(2) mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")}) return mc, nil @@ -615,10 +677,15 @@ func TestHandler(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockSessionRegistry := new(MockSessionRegistry) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return("example.com") + mockConfig.On("HTTPPort").Return(port) + mockConfig.On("HeaderSize").Return(4096) + mockConfig.On("TLSRedirect").Return(true) hh := &httpHandler{ - domain: "domain", sessionRegistry: mockSessionRegistry, - redirectTLS: tt.redirectTLS, + config: mockConfig, } if tt.setupMocks != nil { diff --git a/internal/transport/https.go b/internal/transport/https.go index 10be94f..f1076bf 100644 --- a/internal/transport/https.go +++ b/internal/transport/https.go @@ -5,31 +5,30 @@ import ( "errors" "log" "net" + "tunnel_pls/internal/config" "tunnel_pls/internal/registry" ) type https struct { + config config.Config tlsConfig *tls.Config httpHandler *httpHandler - domain string - port string } -func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport { +func NewHTTPSServer(config config.Config, sessionRegistry registry.Registry, tlsConfig *tls.Config) Transport { return &https{ + config: config, tlsConfig: tlsConfig, - httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS), - domain: domain, - port: port, + httpHandler: newHTTPHandler(config, sessionRegistry), } } func (ht *https) Listen() (net.Listener, error) { - return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig) + return tls.Listen("tcp", ":"+ht.config.HTTPSPort(), ht.tlsConfig) } func (ht *https) Serve(listener net.Listener) error { - log.Printf("HTTPS server is starting on port %s", ht.port) + log.Printf("HTTPS server is starting on port %s", ht.config.HTTPSPort()) for { conn, err := listener.Accept() if err != nil { diff --git a/internal/transport/https_test.go b/internal/transport/https_test.go index e956430..cf09592 100644 --- a/internal/transport/https_test.go +++ b/internal/transport/https_test.go @@ -13,31 +13,32 @@ import ( func TestNewHTTPSServer(t *testing.T) { msr := new(MockSessionRegistry) - domain := "example.com" - port := "443" - redirectTLS := false + mockConfig := &MockConfig{} + port := "0" tlsConfig := &tls.Config{} - - srv := NewHTTPSServer(domain, port, msr, redirectTLS, tlsConfig) + mockConfig.On("Domain").Return(mockConfig) + mockConfig.On("HTTPSPort").Return(port) + srv := NewHTTPSServer(mockConfig, msr, tlsConfig) assert.NotNil(t, srv) httpsSrv, ok := srv.(*https) assert.True(t, ok) - assert.Equal(t, port, httpsSrv.port) - assert.Equal(t, domain, httpsSrv.domain) assert.Equal(t, tlsConfig, httpsSrv.tlsConfig) assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry) } func TestHTTPSServer_Listen(t *testing.T) { msr := new(MockSessionRegistry) - + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return(mockConfig) + mockConfig.On("HTTPSPort").Return(port) tlsConfig := &tls.Config{ GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { return nil, nil }, } - srv := NewHTTPSServer("example.com", "0", msr, false, tlsConfig) + srv := NewHTTPSServer(mockConfig, msr, tlsConfig) listener, err := srv.Listen() if err != nil { @@ -50,7 +51,11 @@ func TestHTTPSServer_Listen(t *testing.T) { func TestHTTPSServer_Serve(t *testing.T) { msr := new(MockSessionRegistry) - srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{}) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return(mockConfig) + mockConfig.On("HTTPSPort").Return(port) + srv := NewHTTPSServer(mockConfig, msr, &tls.Config{}) listener, err := net.Listen("tcp", "127.0.0.1:0") assert.NoError(t, err) @@ -66,7 +71,12 @@ func TestHTTPSServer_Serve(t *testing.T) { func TestHTTPSServer_Serve_AcceptError(t *testing.T) { msr := new(MockSessionRegistry) - srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{}) + + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return(mockConfig) + mockConfig.On("HTTPSPort").Return(port) + srv := NewHTTPSServer(mockConfig, msr, &tls.Config{}) ml := new(mockListener) ml.On("Accept").Return(nil, errors.New("accept error")).Once() @@ -79,17 +89,23 @@ func TestHTTPSServer_Serve_AcceptError(t *testing.T) { func TestHTTPSServer_Serve_Success(t *testing.T) { msr := new(MockSessionRegistry) - srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{}) + mockConfig := &MockConfig{} + port := "0" + mockConfig.On("Domain").Return(mockConfig) + mockConfig.On("HTTPSPort").Return(port) + mockConfig.On("HeaderSize").Return(4096) + + srv := NewHTTPSServer(mockConfig, msr, &tls.Config{}) listener, err := net.Listen("tcp", "127.0.0.1:0") assert.NoError(t, err) - port := listener.Addr().(*net.TCPAddr).Port + listenerport := listener.Addr().(*net.TCPAddr).Port go func() { _ = srv.Serve(listener) }() - conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport)) assert.NoError(t, err) _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n")) diff --git a/internal/transport/tls_test.go b/internal/transport/tls_test.go index 1518469..12e656d 100644 --- a/internal/transport/tls_test.go +++ b/internal/transport/tls_test.go @@ -36,6 +36,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) } func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) } diff --git a/server/server_test.go b/server/server_test.go index 572fa1f..9b7dd88 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -45,6 +45,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) } func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) } func (m *MockConfig) BufferSize() int { return m.Called().Int(0) } +func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) } func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) } func (m *MockConfig) PprofPort() string { return m.Called().String(0) } func (m *MockConfig) Mode() types.ServerMode {