feat(testing): add comprehensive test coverage and code quality improvements #76
@@ -98,7 +98,7 @@ func (b *Bootstrap) startGRPCClient(ctx context.Context, conf config.Config, err
|
||||
}
|
||||
|
||||
func startHTTPServer(conf config.Config, registry registry.Registry, errChan chan<- error) {
|
||||
httpserver := transport.NewHTTPServer(conf.Domain(), conf.HTTPPort(), registry, conf.TLSRedirect())
|
||||
httpserver := transport.NewHTTPServer(conf, registry)
|
||||
ln, err := httpserver.Listen()
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to start http server: %w", err)
|
||||
@@ -115,7 +115,7 @@ func startHTTPSServer(conf config.Config, registry registry.Registry, errChan ch
|
||||
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
|
||||
return
|
||||
}
|
||||
httpsServer := transport.NewHTTPSServer(conf.Domain(), conf.HTTPSPort(), registry, conf.TLSRedirect(), tlsCfg)
|
||||
httpsServer := transport.NewHTTPSServer(conf, registry, tlsCfg)
|
||||
ln, err := httpsServer.Listen()
|
||||
if err != nil {
|
||||
errChan <- fmt.Errorf("failed to create TLS config: %w", err)
|
||||
|
||||
@@ -131,6 +131,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) Mode() types.ServerMode {
|
||||
|
||||
@@ -887,6 +887,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
|
||||
|
||||
@@ -4,27 +4,28 @@ import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/registry"
|
||||
)
|
||||
|
||||
type httpServer struct {
|
||||
handler *httpHandler
|
||||
port string
|
||||
config config.Config
|
||||
}
|
||||
|
||||
func NewHTTPServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool) Transport {
|
||||
func NewHTTPServer(config config.Config, sessionRegistry registry.Registry) Transport {
|
||||
return &httpServer{
|
||||
handler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
|
||||
port: port,
|
||||
handler: newHTTPHandler(config, sessionRegistry),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (ht *httpServer) Listen() (net.Listener, error) {
|
||||
return net.Listen("tcp", ":"+ht.port)
|
||||
return net.Listen("tcp", ":"+ht.config.HTTPPort())
|
||||
}
|
||||
|
||||
func (ht *httpServer) Serve(listener net.Listener) error {
|
||||
log.Printf("HTTP server is starting on port %s", ht.port)
|
||||
log.Printf("HTTP server is starting on port %s", ht.config.HTTPPort())
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
|
||||
@@ -13,24 +13,27 @@ import (
|
||||
|
||||
func TestNewHTTPServer(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
domain := "example.com"
|
||||
port := "8080"
|
||||
redirectTLS := true
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
|
||||
srv := NewHTTPServer(domain, port, msr, redirectTLS)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
httpSrv, ok := srv.(*httpServer)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, port, httpSrv.port)
|
||||
assert.Equal(t, domain, httpSrv.handler.domain)
|
||||
assert.Equal(t, msr, httpSrv.handler.sessionRegistry)
|
||||
assert.Equal(t, redirectTLS, httpSrv.handler.redirectTLS)
|
||||
assert.NotNil(t, srv)
|
||||
}
|
||||
|
||||
func TestHTTPServer_Listen(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPServer("example.com", "0", msr, false)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
assert.NoError(t, err)
|
||||
@@ -40,7 +43,11 @@ func TestHTTPServer_Listen(t *testing.T) {
|
||||
|
||||
func TestHTTPServer_Serve(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPServer("example.com", "0", msr, false)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
@@ -56,7 +63,11 @@ func TestHTTPServer_Serve(t *testing.T) {
|
||||
|
||||
func TestHTTPServer_Serve_AcceptError(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPServer("example.com", "0", msr, false)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
@@ -69,17 +80,23 @@ func TestHTTPServer_Serve_AcceptError(t *testing.T) {
|
||||
|
||||
func TestHTTPServer_Serve_Success(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPServer("example.com", "0", msr, false)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
mockConfig.On("HeaderSize").Return(4096)
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
srv := NewHTTPServer(mockConfig, msr)
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listenerport := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/http/header"
|
||||
"tunnel_pls/internal/http/stream"
|
||||
"tunnel_pls/internal/middleware"
|
||||
@@ -21,16 +22,14 @@ import (
|
||||
)
|
||||
|
||||
type httpHandler struct {
|
||||
domain string
|
||||
config config.Config
|
||||
sessionRegistry registry.Registry
|
||||
redirectTLS bool
|
||||
}
|
||||
|
||||
func newHTTPHandler(domain string, sessionRegistry registry.Registry, redirectTLS bool) *httpHandler {
|
||||
func newHTTPHandler(config config.Config, sessionRegistry registry.Registry) *httpHandler {
|
||||
return &httpHandler{
|
||||
domain: domain,
|
||||
config: config,
|
||||
sessionRegistry: sessionRegistry,
|
||||
redirectTLS: redirectTLS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,10 +55,25 @@ func (hh *httpHandler) badRequest(conn net.Conn) error {
|
||||
func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
|
||||
defer hh.closeConnection(conn)
|
||||
|
||||
dstReader := bufio.NewReader(conn)
|
||||
reqhf, err := header.NewRequest(dstReader)
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
buf := make([]byte, hh.config.HeaderSize())
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
_ = hh.badRequest(conn)
|
||||
return
|
||||
}
|
||||
|
||||
if idx := bytes.Index(buf[:n], []byte("\r\n\r\n")); idx == -1 {
|
||||
_ = hh.badRequest(conn)
|
||||
return
|
||||
}
|
||||
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
|
||||
reqhf, err := header.NewRequest(buf[:n])
|
||||
if err != nil {
|
||||
log.Printf("Error creating request header: %v", err)
|
||||
_ = hh.badRequest(conn)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -70,7 +84,7 @@ func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
|
||||
}
|
||||
|
||||
if hh.shouldRedirectToTLS(isTLS) {
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.domain))
|
||||
_ = hh.redirect(conn, http.StatusMovedPermanently, fmt.Sprintf("https://%s.%s/\r\n", slug, hh.config.Domain()))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -87,7 +101,7 @@ func (hh *httpHandler) Handler(conn net.Conn, isTLS bool) {
|
||||
return
|
||||
}
|
||||
|
||||
hw := stream.New(conn, dstReader, conn.RemoteAddr())
|
||||
hw := stream.New(conn, conn, conn.RemoteAddr())
|
||||
defer func(hw stream.HTTP) {
|
||||
err = hw.Close()
|
||||
if err != nil {
|
||||
@@ -113,7 +127,7 @@ func (hh *httpHandler) extractSlug(reqhf header.RequestHeader) (string, error) {
|
||||
}
|
||||
|
||||
func (hh *httpHandler) shouldRedirectToTLS(isTLS bool) bool {
|
||||
return !isTLS && hh.redirectTLS
|
||||
return !isTLS && hh.config.TLSRedirect()
|
||||
}
|
||||
|
||||
func (hh *httpHandler) handlePingRequest(slug string, conn net.Conn) bool {
|
||||
|
||||
@@ -230,11 +230,30 @@ func (m *MockForwarder) OpenForwardedChannel(ctx context.Context, payload []byte
|
||||
}
|
||||
|
||||
type MockConn struct {
|
||||
net.Conn
|
||||
mock.Mock
|
||||
ReadBuffer *bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
args := m.Called()
|
||||
return args.Get(0).(net.Addr)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
args := m.Called(t)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
if m.ReadBuffer != nil {
|
||||
return m.ReadBuffer.Read(b)
|
||||
@@ -245,6 +264,9 @@ func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
args := m.Called(b)
|
||||
if args.Int(0) == -1 {
|
||||
return len(b), args.Error(1)
|
||||
}
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
@@ -269,11 +291,12 @@ func (c *wrappedConn) RemoteAddr() net.Addr {
|
||||
|
||||
func TestNewHTTPHandler(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
hh := newHTTPHandler("domain", msr, true)
|
||||
mockConfig := &MockConfig{}
|
||||
mockConfig.On("Domain").Return("domain")
|
||||
mockConfig.On("TLSRedirect").Return(false)
|
||||
hh := newHTTPHandler(mockConfig, msr)
|
||||
assert.NotNil(t, hh)
|
||||
assert.Equal(t, "domain", hh.domain)
|
||||
assert.Equal(t, msr, hh.sessionRegistry)
|
||||
assert.True(t, hh.redirectTLS)
|
||||
}
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
@@ -318,8 +341,8 @@ func TestHandler(t *testing.T) {
|
||||
name: "redirect to TLS",
|
||||
isTLS: false,
|
||||
redirectTLS: true,
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://example.domain/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
||||
request: []byte("GET / HTTP/1.1\r\nHost: tunnel.example.com\r\n\r\n"),
|
||||
expected: []byte("HTTP/1.1 301 Moved Permanently\r\nLocation: https://tunnel.example.com/\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
@@ -350,7 +373,25 @@ func TestHandler(t *testing.T) {
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte("INVALID\r\n\r\n"),
|
||||
expected: []byte(""),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - header too large",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: test.domain\r\n%s\r\n\r\n", strings.Repeat("test", 10000))),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bad request - no request",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte(""),
|
||||
expected: []byte("HTTP/1.1 400 Bad Request\r\n\r\n"),
|
||||
setupMocks: func(msr *MockSessionRegistry) {
|
||||
},
|
||||
},
|
||||
@@ -453,7 +494,8 @@ func TestHandler(t *testing.T) {
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: example.domain\r\n\r\n"))
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(-1, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
@@ -467,11 +509,27 @@ func TestHandler(t *testing.T) {
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "read error - connection failure",
|
||||
isTLS: false,
|
||||
redirectTLS: false,
|
||||
request: []byte(""),
|
||||
expected: []byte(""),
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Read", mock.Anything).Return(0, fmt.Errorf("connection reset by peer"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "handle ping request - write failure",
|
||||
isTLS: true,
|
||||
@@ -481,6 +539,7 @@ func TestHandler(t *testing.T) {
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(0, fmt.Errorf("write error"))
|
||||
mc.On("Close").Return(nil)
|
||||
return mc, nil
|
||||
@@ -495,6 +554,7 @@ func TestHandler(t *testing.T) {
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: ping.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Write", mock.Anything).Return(182, nil)
|
||||
mc.On("Close").Return(fmt.Errorf("close error"))
|
||||
return mc, nil
|
||||
@@ -527,6 +587,7 @@ func TestHandler(t *testing.T) {
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Close").Return(fmt.Errorf("stream close error")).Times(2)
|
||||
addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:12345")
|
||||
mc.On("RemoteAddr").Return(addr)
|
||||
@@ -557,6 +618,7 @@ func TestHandler(t *testing.T) {
|
||||
setupConn: func() (net.Conn, net.Conn) {
|
||||
mc := new(MockConn)
|
||||
mc.ReadBuffer = bytes.NewBuffer([]byte("GET / HTTP/1.1\r\nHost: test.domain\r\n\r\n"))
|
||||
mc.On("SetReadDeadline", mock.Anything).Return(nil)
|
||||
mc.On("Close").Return(nil).Times(2)
|
||||
mc.On("RemoteAddr").Return(&net.IPAddr{IP: net.ParseIP("127.0.0.1")})
|
||||
return mc, nil
|
||||
@@ -615,10 +677,15 @@ func TestHandler(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockSessionRegistry := new(MockSessionRegistry)
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return("example.com")
|
||||
mockConfig.On("HTTPPort").Return(port)
|
||||
mockConfig.On("HeaderSize").Return(4096)
|
||||
mockConfig.On("TLSRedirect").Return(true)
|
||||
hh := &httpHandler{
|
||||
domain: "domain",
|
||||
sessionRegistry: mockSessionRegistry,
|
||||
redirectTLS: tt.redirectTLS,
|
||||
config: mockConfig,
|
||||
}
|
||||
|
||||
if tt.setupMocks != nil {
|
||||
|
||||
@@ -5,31 +5,30 @@ import (
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"tunnel_pls/internal/config"
|
||||
"tunnel_pls/internal/registry"
|
||||
)
|
||||
|
||||
type https struct {
|
||||
config config.Config
|
||||
tlsConfig *tls.Config
|
||||
httpHandler *httpHandler
|
||||
domain string
|
||||
port string
|
||||
}
|
||||
|
||||
func NewHTTPSServer(domain, port string, sessionRegistry registry.Registry, redirectTLS bool, tlsConfig *tls.Config) Transport {
|
||||
func NewHTTPSServer(config config.Config, sessionRegistry registry.Registry, tlsConfig *tls.Config) Transport {
|
||||
return &https{
|
||||
config: config,
|
||||
tlsConfig: tlsConfig,
|
||||
httpHandler: newHTTPHandler(domain, sessionRegistry, redirectTLS),
|
||||
domain: domain,
|
||||
port: port,
|
||||
httpHandler: newHTTPHandler(config, sessionRegistry),
|
||||
}
|
||||
}
|
||||
|
||||
func (ht *https) Listen() (net.Listener, error) {
|
||||
return tls.Listen("tcp", ":"+ht.port, ht.tlsConfig)
|
||||
return tls.Listen("tcp", ":"+ht.config.HTTPSPort(), ht.tlsConfig)
|
||||
}
|
||||
|
||||
func (ht *https) Serve(listener net.Listener) error {
|
||||
log.Printf("HTTPS server is starting on port %s", ht.port)
|
||||
log.Printf("HTTPS server is starting on port %s", ht.config.HTTPSPort())
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
|
||||
@@ -13,31 +13,32 @@ import (
|
||||
|
||||
func TestNewHTTPSServer(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
domain := "example.com"
|
||||
port := "443"
|
||||
redirectTLS := false
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
srv := NewHTTPSServer(domain, port, msr, redirectTLS, tlsConfig)
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
|
||||
assert.NotNil(t, srv)
|
||||
|
||||
httpsSrv, ok := srv.(*https)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, port, httpsSrv.port)
|
||||
assert.Equal(t, domain, httpsSrv.domain)
|
||||
assert.Equal(t, tlsConfig, httpsSrv.tlsConfig)
|
||||
assert.Equal(t, msr, httpsSrv.httpHandler.sessionRegistry)
|
||||
}
|
||||
|
||||
func TestHTTPSServer_Listen(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
srv := NewHTTPSServer("example.com", "0", msr, false, tlsConfig)
|
||||
srv := NewHTTPSServer(mockConfig, msr, tlsConfig)
|
||||
|
||||
listener, err := srv.Listen()
|
||||
if err != nil {
|
||||
@@ -50,7 +51,11 @@ func TestHTTPSServer_Listen(t *testing.T) {
|
||||
|
||||
func TestHTTPSServer_Serve(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{})
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
@@ -66,7 +71,12 @@ func TestHTTPSServer_Serve(t *testing.T) {
|
||||
|
||||
func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{})
|
||||
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
||||
|
||||
ml := new(mockListener)
|
||||
ml.On("Accept").Return(nil, errors.New("accept error")).Once()
|
||||
@@ -79,17 +89,23 @@ func TestHTTPSServer_Serve_AcceptError(t *testing.T) {
|
||||
|
||||
func TestHTTPSServer_Serve_Success(t *testing.T) {
|
||||
msr := new(MockSessionRegistry)
|
||||
srv := NewHTTPSServer("example.com", "0", msr, false, &tls.Config{})
|
||||
mockConfig := &MockConfig{}
|
||||
port := "0"
|
||||
mockConfig.On("Domain").Return(mockConfig)
|
||||
mockConfig.On("HTTPSPort").Return(port)
|
||||
mockConfig.On("HeaderSize").Return(4096)
|
||||
|
||||
srv := NewHTTPSServer(mockConfig, msr, &tls.Config{})
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
assert.NoError(t, err)
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
listenerport := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(listener)
|
||||
}()
|
||||
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", listenerport))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: ping.example.com\r\n\r\n"))
|
||||
|
||||
@@ -36,6 +36,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) Mode() types.ServerMode { return m.Called().Get(0).(types.ServerMode) }
|
||||
|
||||
@@ -45,6 +45,7 @@ func (m *MockConfig) ACMEStaging() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) AllowedPortsStart() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) AllowedPortsEnd() uint16 { return uint16(m.Called().Int(0)) }
|
||||
func (m *MockConfig) BufferSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) HeaderSize() int { return m.Called().Int(0) }
|
||||
func (m *MockConfig) PprofEnabled() bool { return m.Called().Bool(0) }
|
||||
func (m *MockConfig) PprofPort() string { return m.Called().String(0) }
|
||||
func (m *MockConfig) Mode() types.ServerMode {
|
||||
|
||||
Reference in New Issue
Block a user