refactor(server): enhance HTTP handler modularity and fix resource leak
Docker Build and Push / build-and-push-tags (push) Has been skipped
Docker Build and Push / build-and-push-branches (push) Successful in 11m43s

- Rename customWriter struct to httpWriter for clarity
- Add closeWriter field to properly close write side of connections
- Update all cw variable references to hw
- Merge handlerTLS into handler function to reduce code duplication
- Extract handler into smaller, focused methods
- Split Read/Write/forwardRequest into composable functions

Fixes resource leak where connections weren't properly closed on the
write side, matching the forwarder's CloseWrite() pattern.
This commit is contained in:
2026-01-19 22:41:04 +07:00
parent adb0264bb5
commit 27f49879af
5 changed files with 428 additions and 432 deletions
+250
View File
@@ -0,0 +1,250 @@
package server
import (
"bytes"
"io"
"log"
"net"
"regexp"
)
type HTTPWriter interface {
io.ReadWriteCloser
CloseWrite() error
RemoteAddr() net.Addr
UseResponseMiddleware(mw ResponseMiddleware)
UseRequestMiddleware(mw RequestMiddleware)
SetRequestHeader(header RequestHeaderManager)
RequestMiddlewares() []RequestMiddleware
ResponseMiddlewares() []ResponseMiddleware
ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error
ApplyRequestMiddlewares(reqhf RequestHeaderManager) error
}
type httpWriter struct {
remoteAddr net.Addr
writer io.Writer
reader io.Reader
headerBuf []byte
buf []byte
respHeader ResponseHeaderManager
reqHeader RequestHeaderManager
respMW []ResponseMiddleware
reqMW []RequestMiddleware
}
var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}
var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)
var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`)
func (hw *httpWriter) RemoteAddr() net.Addr {
return hw.remoteAddr
}
func (hw *httpWriter) UseResponseMiddleware(mw ResponseMiddleware) {
hw.respMW = append(hw.respMW, mw)
}
func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) {
hw.reqMW = append(hw.reqMW, mw)
}
func (hw *httpWriter) SetRequestHeader(header RequestHeaderManager) {
hw.reqHeader = header
}
func (hw *httpWriter) RequestMiddlewares() []RequestMiddleware {
return hw.reqMW
}
func (hw *httpWriter) ResponseMiddlewares() []ResponseMiddleware {
return hw.respMW
}
func (hw *httpWriter) Close() error {
return hw.writer.(io.Closer).Close()
}
func (hw *httpWriter) CloseWrite() error {
if closer, ok := hw.writer.(interface{ CloseWrite() error }); ok {
return closer.CloseWrite()
}
return hw.Close()
}
func (hw *httpWriter) Read(p []byte) (int, error) {
tmp := make([]byte, len(p))
read, err := hw.reader.Read(tmp)
if read == 0 && err != nil {
return 0, err
}
tmp = tmp[:read]
headerEndIdx := bytes.Index(tmp, DELIMITER)
if headerEndIdx == -1 {
return hw.handleNoDelimiter(p, tmp, err)
}
header, body := hw.splitHeaderAndBody(tmp, headerEndIdx)
if !isHTTPHeader(header) {
copy(p, tmp)
return read, nil
}
return hw.processHTTPRequest(p, header, body)
}
func (hw *httpWriter) handleNoDelimiter(p, tmp []byte, err error) (int, error) {
copy(p, tmp)
return len(tmp), err
}
func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) {
header := data[:delimiterIdx+len(DELIMITER)]
body := data[delimiterIdx+len(DELIMITER):]
return header, body
}
func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) {
reqhf, err := NewRequestHeaderFactory(header)
if err != nil {
return 0, err
}
if err = hw.ApplyRequestMiddlewares(reqhf); err != nil {
return 0, err
}
hw.reqHeader = reqhf
combined := append(reqhf.Finalize(), body...)
return copy(p, combined), nil
}
func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeaderManager) error {
for _, m := range hw.RequestMiddlewares() {
if err := m.HandleRequest(reqhf); err != nil {
log.Printf("Error when applying request middleware: %v", err)
return err
}
}
return nil
}
func (hw *httpWriter) Write(p []byte) (int, error) {
if hw.shouldBypassBuffering(p) {
hw.respHeader = nil
}
if hw.respHeader != nil {
return hw.writer.Write(p)
}
hw.buf = append(hw.buf, p...)
headerEndIdx := bytes.Index(hw.buf, DELIMITER)
if headerEndIdx == -1 {
return len(p), nil
}
return hw.processBufferedResponse(p, headerEndIdx)
}
func (hw *httpWriter) shouldBypassBuffering(p []byte) bool {
return hw.respHeader != nil && len(hw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/"
}
func (hw *httpWriter) processBufferedResponse(p []byte, delimiterIdx int) (int, error) {
header, body := hw.splitHeaderAndBody(hw.buf, delimiterIdx)
if !isHTTPHeader(header) {
return hw.writeRawBuffer()
}
if err := hw.processHTTPResponse(header, body); err != nil {
return 0, err
}
hw.buf = nil
return len(p), nil
}
func (hw *httpWriter) writeRawBuffer() (int, error) {
_, err := hw.writer.Write(hw.buf)
length := len(hw.buf)
hw.buf = nil
if err != nil {
return 0, err
}
return length, nil
}
func (hw *httpWriter) processHTTPResponse(header, body []byte) error {
resphf := NewResponseHeaderFactory(header)
if err := hw.ApplyResponseMiddlewares(resphf, body); err != nil {
return err
}
hw.respHeader = resphf
finalHeader := resphf.Finalize()
if err := hw.writeHeaderAndBody(finalHeader, body); err != nil {
return err
}
return nil
}
func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error {
for _, m := range hw.ResponseMiddlewares() {
if err := m.HandleResponse(resphf, body); err != nil {
log.Printf("Cannot apply middleware: %s\n", err)
return err
}
}
return nil
}
func (hw *httpWriter) writeHeaderAndBody(header, body []byte) error {
if _, err := hw.writer.Write(header); err != nil {
return err
}
if len(body) > 0 {
if _, err := hw.writer.Write(body); err != nil {
return err
}
}
return nil
}
func NewHTTPWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
return &httpWriter{
remoteAddr: remoteAddr,
writer: writer,
reader: reader,
buf: make([]byte, 0, 4096),
}
}
func isHTTPHeader(buf []byte) bool {
lines := bytes.Split(buf, []byte("\r\n"))
startLine := string(lines[0])
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
return false
}
for _, line := range lines[1:] {
if len(line) == 0 {
break
}
colonIdx := bytes.IndexByte(line, ':')
if colonIdx <= 0 {
return false
}
}
return true
}