refactor: optimize header parsing and remove factory naming
- Remove factory naming - Use direct byte indexing instead of bytes.TrimRight - Extract parseStartLine and setRemainingHeaders helpers
This commit is contained in:
+111
-133
@@ -6,22 +6,20 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type HeaderManager interface {
|
||||
Get(key string) []byte
|
||||
Set(key string, value []byte)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
}
|
||||
|
||||
type ResponseHeaderManager interface {
|
||||
Get(key string) string
|
||||
type ResponseHeader interface {
|
||||
Value(key string) string
|
||||
Set(key string, value string)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
}
|
||||
|
||||
type RequestHeaderManager interface {
|
||||
Get(key string) string
|
||||
type responseHeader struct {
|
||||
startLine []byte
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
type RequestHeader interface {
|
||||
Value(key string) string
|
||||
Set(key string, value string)
|
||||
Remove(key string)
|
||||
Finalize() []byte
|
||||
@@ -29,13 +27,7 @@ type RequestHeaderManager interface {
|
||||
GetPath() string
|
||||
GetVersion() string
|
||||
}
|
||||
|
||||
type responseHeaderFactory struct {
|
||||
startLine []byte
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
type requestHeaderFactory struct {
|
||||
type requestHeader struct {
|
||||
method string
|
||||
path string
|
||||
version string
|
||||
@@ -43,7 +35,7 @@ type requestHeaderFactory struct {
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) {
|
||||
func NewRequestHeader(r interface{}) (RequestHeader, error) {
|
||||
switch v := r.(type) {
|
||||
case []byte:
|
||||
return parseHeadersFromBytes(v)
|
||||
@@ -54,38 +46,16 @@ func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) {
|
||||
header := &requestHeaderFactory{
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
|
||||
lineEnd := bytes.IndexByte(headerData, '\n')
|
||||
if lineEnd == -1 {
|
||||
return nil, fmt.Errorf("invalid request: no newline found")
|
||||
}
|
||||
|
||||
startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n")
|
||||
header.startLine = make([]byte, len(startLine))
|
||||
copy(header.startLine, startLine)
|
||||
|
||||
parts := bytes.Split(startLine, []byte{' '})
|
||||
if len(parts) < 3 {
|
||||
return nil, fmt.Errorf("invalid request line")
|
||||
}
|
||||
|
||||
header.method = string(parts[0])
|
||||
header.path = string(parts[1])
|
||||
header.version = string(parts[2])
|
||||
|
||||
remaining := headerData[lineEnd+1:]
|
||||
|
||||
func setRemainingHeaders(remaining []byte, header interface {
|
||||
Set(key string, value string)
|
||||
}) {
|
||||
for len(remaining) > 0 {
|
||||
lineEnd = bytes.IndexByte(remaining, '\n')
|
||||
lineEnd := bytes.Index(remaining, []byte("\r\n"))
|
||||
if lineEnd == -1 {
|
||||
lineEnd = len(remaining)
|
||||
}
|
||||
|
||||
line := bytes.TrimRight(remaining[:lineEnd], "\r\n")
|
||||
line := remaining[:lineEnd]
|
||||
|
||||
if len(line) == 0 {
|
||||
break
|
||||
@@ -95,63 +65,84 @@ func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) {
|
||||
if colonIdx != -1 {
|
||||
key := bytes.TrimSpace(line[:colonIdx])
|
||||
value := bytes.TrimSpace(line[colonIdx+1:])
|
||||
header.headers[string(key)] = string(value)
|
||||
header.Set(string(key), string(value))
|
||||
}
|
||||
|
||||
if lineEnd == len(remaining) {
|
||||
break
|
||||
}
|
||||
remaining = remaining[lineEnd+1:]
|
||||
|
||||
remaining = remaining[lineEnd+2:]
|
||||
}
|
||||
}
|
||||
|
||||
func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) {
|
||||
header := &requestHeader{
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
|
||||
lineEnd := bytes.Index(headerData, []byte("\r\n"))
|
||||
if lineEnd == -1 {
|
||||
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
|
||||
}
|
||||
|
||||
startLine := headerData[:lineEnd]
|
||||
header.startLine = startLine
|
||||
var err error
|
||||
header.method, header.path, header.version, err = parseStartLine(startLine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
remaining := headerData[lineEnd+2:]
|
||||
|
||||
setRemainingHeaders(remaining, header)
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) {
|
||||
header := &requestHeaderFactory{
|
||||
func parseStartLine(startLine []byte) (method, path, version string, err error) {
|
||||
firstSpace := bytes.IndexByte(startLine, ' ')
|
||||
if firstSpace == -1 {
|
||||
return "", "", "", fmt.Errorf("invalid start line: missing method")
|
||||
}
|
||||
|
||||
secondSpace := bytes.IndexByte(startLine[firstSpace+1:], ' ')
|
||||
if secondSpace == -1 {
|
||||
return "", "", "", fmt.Errorf("invalid start line: missing version")
|
||||
}
|
||||
secondSpace += firstSpace + 1
|
||||
|
||||
method = string(startLine[:firstSpace])
|
||||
path = string(startLine[firstSpace+1 : secondSpace])
|
||||
version = string(startLine[secondSpace+1:])
|
||||
|
||||
return method, path, version, nil
|
||||
}
|
||||
|
||||
func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) {
|
||||
header := &requestHeader{
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
|
||||
startLineBytes, err := br.ReadSlice('\n')
|
||||
if err != nil {
|
||||
if err == bufio.ErrBufferFull {
|
||||
var startLine string
|
||||
startLine, err = br.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
startLineBytes = []byte(startLine)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startLineBytes = bytes.TrimRight(startLineBytes, "\r\n")
|
||||
header.startLine = make([]byte, len(startLineBytes))
|
||||
copy(header.startLine, startLineBytes)
|
||||
|
||||
parts := bytes.Split(startLineBytes, []byte{' '})
|
||||
if len(parts) < 3 {
|
||||
return nil, fmt.Errorf("invalid request line")
|
||||
header.method, header.path, header.version, err = parseStartLine(header.startLine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header.method = string(parts[0])
|
||||
header.path = string(parts[1])
|
||||
header.version = string(parts[2])
|
||||
|
||||
for {
|
||||
lineBytes, err := br.ReadSlice('\n')
|
||||
if err != nil {
|
||||
if err == bufio.ErrBufferFull {
|
||||
var line string
|
||||
line, err = br.ReadString('\n')
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lineBytes = []byte(line)
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lineBytes = bytes.TrimRight(lineBytes, "\r\n")
|
||||
@@ -174,63 +165,63 @@ func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) {
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager {
|
||||
header := &responseHeaderFactory{
|
||||
func NewResponseHeader(headerData []byte) (ResponseHeader, error) {
|
||||
header := &responseHeader{
|
||||
startLine: nil,
|
||||
headers: make(map[string]string),
|
||||
headers: make(map[string]string, 16),
|
||||
}
|
||||
lines := bytes.Split(startLine, []byte("\r\n"))
|
||||
if len(lines) == 0 {
|
||||
return header
|
||||
}
|
||||
header.startLine = lines[0]
|
||||
for _, h := range lines[1:] {
|
||||
if len(h) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := bytes.SplitN(h, []byte(":"), 2)
|
||||
if len(parts) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := parts[0]
|
||||
val := bytes.TrimSpace(parts[1])
|
||||
header.headers[string(key)] = string(val)
|
||||
lineEnd := bytes.Index(headerData, []byte("\r\n"))
|
||||
if lineEnd == -1 {
|
||||
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
|
||||
}
|
||||
return header
|
||||
|
||||
header.startLine = headerData[:lineEnd]
|
||||
remaining := headerData[lineEnd+2:]
|
||||
setRemainingHeaders(remaining, header)
|
||||
|
||||
return header, nil
|
||||
}
|
||||
|
||||
func (resp *responseHeaderFactory) Get(key string) string {
|
||||
func (resp *responseHeader) Value(key string) string {
|
||||
return resp.headers[key]
|
||||
}
|
||||
|
||||
func (resp *responseHeaderFactory) Set(key string, value string) {
|
||||
func (resp *responseHeader) Set(key string, value string) {
|
||||
resp.headers[key] = value
|
||||
}
|
||||
|
||||
func (resp *responseHeaderFactory) Remove(key string) {
|
||||
func (resp *responseHeader) Remove(key string) {
|
||||
delete(resp.headers, key)
|
||||
}
|
||||
|
||||
func (resp *responseHeaderFactory) Finalize() []byte {
|
||||
var buf bytes.Buffer
|
||||
func finalize(startLine []byte, headers map[string]string) []byte {
|
||||
size := len(startLine) + 2
|
||||
for key, val := range headers {
|
||||
size += len(key) + 2 + len(val) + 2
|
||||
}
|
||||
size += 2
|
||||
|
||||
buf.Write(resp.startLine)
|
||||
buf.WriteString("\r\n")
|
||||
buf := make([]byte, 0, size)
|
||||
buf = append(buf, startLine...)
|
||||
buf = append(buf, '\r', '\n')
|
||||
|
||||
for key, val := range resp.headers {
|
||||
buf.WriteString(key)
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(val)
|
||||
buf.WriteString("\r\n")
|
||||
for key, val := range headers {
|
||||
buf = append(buf, key...)
|
||||
buf = append(buf, ':', ' ')
|
||||
buf = append(buf, val...)
|
||||
buf = append(buf, '\r', '\n')
|
||||
}
|
||||
|
||||
buf.WriteString("\r\n")
|
||||
return buf.Bytes()
|
||||
buf = append(buf, '\r', '\n')
|
||||
return buf
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) Get(key string) string {
|
||||
func (resp *responseHeader) Finalize() []byte {
|
||||
return finalize(resp.startLine, resp.headers)
|
||||
}
|
||||
|
||||
func (req *requestHeader) Value(key string) string {
|
||||
val, ok := req.headers[key]
|
||||
if !ok {
|
||||
return ""
|
||||
@@ -238,39 +229,26 @@ func (req *requestHeaderFactory) Get(key string) string {
|
||||
return val
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) Set(key string, value string) {
|
||||
func (req *requestHeader) Set(key string, value string) {
|
||||
req.headers[key] = value
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) Remove(key string) {
|
||||
func (req *requestHeader) Remove(key string) {
|
||||
delete(req.headers, key)
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) GetMethod() string {
|
||||
func (req *requestHeader) GetMethod() string {
|
||||
return req.method
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) GetPath() string {
|
||||
func (req *requestHeader) GetPath() string {
|
||||
return req.path
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) GetVersion() string {
|
||||
func (req *requestHeader) GetVersion() string {
|
||||
return req.version
|
||||
}
|
||||
|
||||
func (req *requestHeaderFactory) Finalize() []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
buf.Write(req.startLine)
|
||||
buf.WriteString("\r\n")
|
||||
|
||||
for key, val := range req.headers {
|
||||
buf.WriteString(key)
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(val)
|
||||
buf.WriteString("\r\n")
|
||||
}
|
||||
|
||||
buf.WriteString("\r\n")
|
||||
return buf.Bytes()
|
||||
func (req *requestHeader) Finalize() []byte {
|
||||
return finalize(req.startLine, req.headers)
|
||||
}
|
||||
|
||||
+5
-5
@@ -112,7 +112,7 @@ func (hs *httpServer) handler(conn net.Conn, isTLS bool) {
|
||||
defer hs.closeConnection(conn)
|
||||
|
||||
dstReader := bufio.NewReader(conn)
|
||||
reqhf, err := NewRequestHeaderFactory(dstReader)
|
||||
reqhf, err := NewRequestHeader(dstReader)
|
||||
if err != nil {
|
||||
log.Printf("Error creating request header: %v", err)
|
||||
return
|
||||
@@ -150,8 +150,8 @@ func (hs *httpServer) closeConnection(conn net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (hs *httpServer) extractSlug(reqhf RequestHeaderManager) (string, error) {
|
||||
host := strings.Split(reqhf.Get("Host"), ".")
|
||||
func (hs *httpServer) extractSlug(reqhf RequestHeader) (string, error) {
|
||||
host := strings.Split(reqhf.Value("Host"), ".")
|
||||
if len(host) < 1 {
|
||||
return "", errors.New("invalid host")
|
||||
}
|
||||
@@ -193,7 +193,7 @@ func (hs *httpServer) getSession(slug string) (session.Session, error) {
|
||||
return sshSession, nil
|
||||
}
|
||||
|
||||
func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) {
|
||||
func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeader, sshSession session.Session) {
|
||||
channel, err := hs.openForwardedChannel(hw, sshSession)
|
||||
if err != nil {
|
||||
log.Printf("Failed to establish channel: %v", err)
|
||||
@@ -260,7 +260,7 @@ func (hs *httpServer) setupMiddlewares(hw HTTPWriter) {
|
||||
hw.UseRequestMiddleware(forwardedForMiddleware)
|
||||
}
|
||||
|
||||
func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeaderManager, channel ssh.Channel) error {
|
||||
func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeader, channel ssh.Channel) error {
|
||||
hw.SetRequestHeader(initialRequest)
|
||||
|
||||
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
|
||||
|
||||
+15
-12
@@ -14,11 +14,11 @@ type HTTPWriter interface {
|
||||
RemoteAddr() net.Addr
|
||||
UseResponseMiddleware(mw ResponseMiddleware)
|
||||
UseRequestMiddleware(mw RequestMiddleware)
|
||||
SetRequestHeader(header RequestHeaderManager)
|
||||
SetRequestHeader(header RequestHeader)
|
||||
RequestMiddlewares() []RequestMiddleware
|
||||
ResponseMiddlewares() []ResponseMiddleware
|
||||
ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error
|
||||
ApplyRequestMiddlewares(reqhf RequestHeaderManager) error
|
||||
ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error
|
||||
ApplyRequestMiddlewares(reqhf RequestHeader) error
|
||||
}
|
||||
|
||||
type httpWriter struct {
|
||||
@@ -27,8 +27,8 @@ type httpWriter struct {
|
||||
reader io.Reader
|
||||
headerBuf []byte
|
||||
buf []byte
|
||||
respHeader ResponseHeaderManager
|
||||
reqHeader RequestHeaderManager
|
||||
respHeader ResponseHeader
|
||||
reqHeader RequestHeader
|
||||
respMW []ResponseMiddleware
|
||||
reqMW []RequestMiddleware
|
||||
}
|
||||
@@ -49,7 +49,7 @@ func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) {
|
||||
hw.reqMW = append(hw.reqMW, mw)
|
||||
}
|
||||
|
||||
func (hw *httpWriter) SetRequestHeader(header RequestHeaderManager) {
|
||||
func (hw *httpWriter) SetRequestHeader(header RequestHeader) {
|
||||
hw.reqHeader = header
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte,
|
||||
}
|
||||
|
||||
func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) {
|
||||
reqhf, err := NewRequestHeaderFactory(header)
|
||||
reqhf, err := NewRequestHeader(header)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -121,7 +121,7 @@ func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) {
|
||||
return copy(p, combined), nil
|
||||
}
|
||||
|
||||
func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeaderManager) error {
|
||||
func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeader) error {
|
||||
for _, m := range hw.RequestMiddlewares() {
|
||||
if err := m.HandleRequest(reqhf); err != nil {
|
||||
log.Printf("Error when applying request middleware: %v", err)
|
||||
@@ -180,23 +180,26 @@ func (hw *httpWriter) writeRawBuffer() (int, error) {
|
||||
}
|
||||
|
||||
func (hw *httpWriter) processHTTPResponse(header, body []byte) error {
|
||||
resphf := NewResponseHeaderFactory(header)
|
||||
resphf, err := NewResponseHeader(header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := hw.ApplyResponseMiddlewares(resphf, body); err != nil {
|
||||
if err = hw.ApplyResponseMiddlewares(resphf, body); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hw.respHeader = resphf
|
||||
finalHeader := resphf.Finalize()
|
||||
|
||||
if err := hw.writeHeaderAndBody(finalHeader, body); err != nil {
|
||||
if err = hw.writeHeaderAndBody(finalHeader, body); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error {
|
||||
func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error {
|
||||
for _, m := range hw.ResponseMiddlewares() {
|
||||
if err := m.HandleResponse(resphf, body); err != nil {
|
||||
log.Printf("Cannot apply middleware: %s\n", err)
|
||||
|
||||
@@ -5,11 +5,11 @@ import (
|
||||
)
|
||||
|
||||
type RequestMiddleware interface {
|
||||
HandleRequest(header RequestHeaderManager) error
|
||||
HandleRequest(header RequestHeader) error
|
||||
}
|
||||
|
||||
type ResponseMiddleware interface {
|
||||
HandleResponse(header ResponseHeaderManager, body []byte) error
|
||||
HandleResponse(header ResponseHeader, body []byte) error
|
||||
}
|
||||
|
||||
type TunnelFingerprint struct{}
|
||||
@@ -18,7 +18,7 @@ func NewTunnelFingerprint() *TunnelFingerprint {
|
||||
return &TunnelFingerprint{}
|
||||
}
|
||||
|
||||
func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error {
|
||||
func (h *TunnelFingerprint) HandleResponse(header ResponseHeader, body []byte) error {
|
||||
header.Set("Server", "Tunnel Please")
|
||||
return nil
|
||||
}
|
||||
@@ -31,7 +31,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor {
|
||||
return &ForwardedFor{addr: addr}
|
||||
}
|
||||
|
||||
func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error {
|
||||
func (ff *ForwardedFor) HandleRequest(header RequestHeader) error {
|
||||
host, _, err := net.SplitHostPort(ff.addr.String())
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
Reference in New Issue
Block a user