- Rename customWriter struct to httpWriter for clarity - Add closeWriter field to properly close write side of connections - Update all cw variable references to hw - Merge handlerTLS into handler function to reduce code duplication - Extract handler into smaller, focused methods - Split Read/Write/forwardRequest into composable functions Fixes resource leak where connections weren't properly closed on the write side, matching the forwarder's CloseWrite() pattern.
251 lines
5.6 KiB
Go
251 lines
5.6 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"regexp"
|
|
)
|
|
|
|
type HTTPWriter interface {
|
|
io.ReadWriteCloser
|
|
CloseWrite() error
|
|
RemoteAddr() net.Addr
|
|
UseResponseMiddleware(mw ResponseMiddleware)
|
|
UseRequestMiddleware(mw RequestMiddleware)
|
|
SetRequestHeader(header RequestHeaderManager)
|
|
RequestMiddlewares() []RequestMiddleware
|
|
ResponseMiddlewares() []ResponseMiddleware
|
|
ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error
|
|
ApplyRequestMiddlewares(reqhf RequestHeaderManager) error
|
|
}
|
|
|
|
type httpWriter struct {
|
|
remoteAddr net.Addr
|
|
writer io.Writer
|
|
reader io.Reader
|
|
headerBuf []byte
|
|
buf []byte
|
|
respHeader ResponseHeaderManager
|
|
reqHeader RequestHeaderManager
|
|
respMW []ResponseMiddleware
|
|
reqMW []RequestMiddleware
|
|
}
|
|
|
|
var (
	// DELIMITER is CR LF CR LF, the byte sequence that terminates an HTTP
	// header block.
	DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A}

	// requestLine matches an HTTP/1.x request line such as
	// "GET /index.html HTTP/1.1".
	requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`)

	// responseLine matches an HTTP/1.x status line. RFC 9112 defines it as
	// HTTP-version SP status-code SP [ reason-phrase ], so the reason phrase
	// may be empty: "HTTP/1.1 204 " must be recognised as well as
	// "HTTP/1.1 200 OK". The separating space after the code is still
	// required.
	responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .*`)
)
|
|
|
|
func (hw *httpWriter) RemoteAddr() net.Addr {
|
|
return hw.remoteAddr
|
|
}
|
|
|
|
func (hw *httpWriter) UseResponseMiddleware(mw ResponseMiddleware) {
|
|
hw.respMW = append(hw.respMW, mw)
|
|
}
|
|
|
|
func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) {
|
|
hw.reqMW = append(hw.reqMW, mw)
|
|
}
|
|
|
|
func (hw *httpWriter) SetRequestHeader(header RequestHeaderManager) {
|
|
hw.reqHeader = header
|
|
}
|
|
|
|
func (hw *httpWriter) RequestMiddlewares() []RequestMiddleware {
|
|
return hw.reqMW
|
|
}
|
|
|
|
func (hw *httpWriter) ResponseMiddlewares() []ResponseMiddleware {
|
|
return hw.respMW
|
|
}
|
|
func (hw *httpWriter) Close() error {
|
|
return hw.writer.(io.Closer).Close()
|
|
}
|
|
|
|
func (hw *httpWriter) CloseWrite() error {
|
|
if closer, ok := hw.writer.(interface{ CloseWrite() error }); ok {
|
|
return closer.CloseWrite()
|
|
}
|
|
return hw.Close()
|
|
}
|
|
|
|
func (hw *httpWriter) Read(p []byte) (int, error) {
|
|
tmp := make([]byte, len(p))
|
|
read, err := hw.reader.Read(tmp)
|
|
if read == 0 && err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
tmp = tmp[:read]
|
|
|
|
headerEndIdx := bytes.Index(tmp, DELIMITER)
|
|
if headerEndIdx == -1 {
|
|
return hw.handleNoDelimiter(p, tmp, err)
|
|
}
|
|
|
|
header, body := hw.splitHeaderAndBody(tmp, headerEndIdx)
|
|
|
|
if !isHTTPHeader(header) {
|
|
copy(p, tmp)
|
|
return read, nil
|
|
}
|
|
|
|
return hw.processHTTPRequest(p, header, body)
|
|
}
|
|
|
|
func (hw *httpWriter) handleNoDelimiter(p, tmp []byte, err error) (int, error) {
|
|
copy(p, tmp)
|
|
return len(tmp), err
|
|
}
|
|
|
|
func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, []byte) {
|
|
header := data[:delimiterIdx+len(DELIMITER)]
|
|
body := data[delimiterIdx+len(DELIMITER):]
|
|
return header, body
|
|
}
|
|
|
|
func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) {
|
|
reqhf, err := NewRequestHeaderFactory(header)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if err = hw.ApplyRequestMiddlewares(reqhf); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
hw.reqHeader = reqhf
|
|
combined := append(reqhf.Finalize(), body...)
|
|
return copy(p, combined), nil
|
|
}
|
|
|
|
func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeaderManager) error {
|
|
for _, m := range hw.RequestMiddlewares() {
|
|
if err := m.HandleRequest(reqhf); err != nil {
|
|
log.Printf("Error when applying request middleware: %v", err)
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (hw *httpWriter) Write(p []byte) (int, error) {
|
|
if hw.shouldBypassBuffering(p) {
|
|
hw.respHeader = nil
|
|
}
|
|
|
|
if hw.respHeader != nil {
|
|
return hw.writer.Write(p)
|
|
}
|
|
|
|
hw.buf = append(hw.buf, p...)
|
|
|
|
headerEndIdx := bytes.Index(hw.buf, DELIMITER)
|
|
if headerEndIdx == -1 {
|
|
return len(p), nil
|
|
}
|
|
|
|
return hw.processBufferedResponse(p, headerEndIdx)
|
|
}
|
|
|
|
func (hw *httpWriter) shouldBypassBuffering(p []byte) bool {
|
|
return hw.respHeader != nil && len(hw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/"
|
|
}
|
|
|
|
func (hw *httpWriter) processBufferedResponse(p []byte, delimiterIdx int) (int, error) {
|
|
header, body := hw.splitHeaderAndBody(hw.buf, delimiterIdx)
|
|
|
|
if !isHTTPHeader(header) {
|
|
return hw.writeRawBuffer()
|
|
}
|
|
|
|
if err := hw.processHTTPResponse(header, body); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
hw.buf = nil
|
|
return len(p), nil
|
|
}
|
|
|
|
func (hw *httpWriter) writeRawBuffer() (int, error) {
|
|
_, err := hw.writer.Write(hw.buf)
|
|
length := len(hw.buf)
|
|
hw.buf = nil
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return length, nil
|
|
}
|
|
|
|
func (hw *httpWriter) processHTTPResponse(header, body []byte) error {
|
|
resphf := NewResponseHeaderFactory(header)
|
|
|
|
if err := hw.ApplyResponseMiddlewares(resphf, body); err != nil {
|
|
return err
|
|
}
|
|
|
|
hw.respHeader = resphf
|
|
finalHeader := resphf.Finalize()
|
|
|
|
if err := hw.writeHeaderAndBody(finalHeader, body); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error {
|
|
for _, m := range hw.ResponseMiddlewares() {
|
|
if err := m.HandleResponse(resphf, body); err != nil {
|
|
log.Printf("Cannot apply middleware: %s\n", err)
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (hw *httpWriter) writeHeaderAndBody(header, body []byte) error {
|
|
if _, err := hw.writer.Write(header); err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(body) > 0 {
|
|
if _, err := hw.writer.Write(body); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func NewHTTPWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) HTTPWriter {
|
|
return &httpWriter{
|
|
remoteAddr: remoteAddr,
|
|
writer: writer,
|
|
reader: reader,
|
|
buf: make([]byte, 0, 4096),
|
|
}
|
|
}
|
|
|
|
func isHTTPHeader(buf []byte) bool {
|
|
lines := bytes.Split(buf, []byte("\r\n"))
|
|
|
|
startLine := string(lines[0])
|
|
if !requestLine.MatchString(startLine) && !responseLine.MatchString(startLine) {
|
|
return false
|
|
}
|
|
|
|
for _, line := range lines[1:] {
|
|
if len(line) == 0 {
|
|
break
|
|
}
|
|
colonIdx := bytes.IndexByte(line, ':')
|
|
if colonIdx <= 0 {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|