diff --git a/server/header.go b/server/header.go index 584394b..bc3ce73 100644 --- a/server/header.go +++ b/server/header.go @@ -6,22 +6,20 @@ import ( "fmt" ) -type HeaderManager interface { - Get(key string) []byte - Set(key string, value []byte) - Remove(key string) - Finalize() []byte -} - -type ResponseHeaderManager interface { - Get(key string) string +type ResponseHeader interface { + Value(key string) string Set(key string, value string) Remove(key string) Finalize() []byte } -type RequestHeaderManager interface { - Get(key string) string +type responseHeader struct { + startLine []byte + headers map[string]string +} + +type RequestHeader interface { + Value(key string) string Set(key string, value string) Remove(key string) Finalize() []byte @@ -29,13 +27,7 @@ type RequestHeaderManager interface { GetPath() string GetVersion() string } - -type responseHeaderFactory struct { - startLine []byte - headers map[string]string -} - -type requestHeaderFactory struct { +type requestHeader struct { method string path string version string @@ -43,7 +35,7 @@ type requestHeaderFactory struct { headers map[string]string } -func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) { +func NewRequestHeader(r interface{}) (RequestHeader, error) { switch v := r.(type) { case []byte: return parseHeadersFromBytes(v) @@ -54,38 +46,16 @@ func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) { } } -func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) { - header := &requestHeaderFactory{ - headers: make(map[string]string, 16), - } - - lineEnd := bytes.IndexByte(headerData, '\n') - if lineEnd == -1 { - return nil, fmt.Errorf("invalid request: no newline found") - } - - startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n") - header.startLine = make([]byte, len(startLine)) - copy(header.startLine, startLine) - - parts := bytes.Split(startLine, []byte{' '}) - if len(parts) < 3 { - return nil, fmt.Errorf("invalid request line") - } - - header.method = string(parts[0]) - header.path = string(parts[1]) - header.version = string(parts[2]) - - remaining := headerData[lineEnd+1:] - +func setRemainingHeaders(remaining []byte, header interface { + Set(key string, value string) +}) { for len(remaining) > 0 { - lineEnd = bytes.IndexByte(remaining, '\n') + lineEnd := bytes.Index(remaining, []byte("\r\n")) if lineEnd == -1 { lineEnd = len(remaining) } - line := bytes.TrimRight(remaining[:lineEnd], "\r\n") + line := remaining[:lineEnd] if len(line) == 0 { break @@ -95,63 +65,84 @@ func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) { if colonIdx != -1 { key := bytes.TrimSpace(line[:colonIdx]) value := bytes.TrimSpace(line[colonIdx+1:]) - header.headers[string(key)] = string(value) + header.Set(string(key), string(value)) } if lineEnd == len(remaining) { break } - remaining = remaining[lineEnd+1:] + + remaining = remaining[lineEnd+2:] } +} + +func parseHeadersFromBytes(headerData []byte) (RequestHeader, error) { + header := &requestHeader{ + headers: make(map[string]string, 16), + } + + lineEnd := bytes.Index(headerData, []byte("\r\n")) + if lineEnd == -1 { + return nil, fmt.Errorf("invalid request: no CRLF found in start line") + } + + startLine := headerData[:lineEnd] + header.startLine = startLine + var err error + header.method, header.path, header.version, err = parseStartLine(startLine) + if err != nil { + return nil, err + } + + remaining := headerData[lineEnd+2:] + + setRemainingHeaders(remaining, header) return header, nil } -func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) { - header := &requestHeaderFactory{ +func parseStartLine(startLine []byte) (method, path, version string, err error) { + firstSpace := bytes.IndexByte(startLine, ' ') + if firstSpace == -1 { + return "", "", "", fmt.Errorf("invalid start line: missing method") + } + + secondSpace := bytes.IndexByte(startLine[firstSpace+1:], ' ') + if secondSpace == -1 { + return "", "", "", fmt.Errorf("invalid start line: missing version") + } + secondSpace += firstSpace + 1 + + method = string(startLine[:firstSpace]) + path = string(startLine[firstSpace+1 : secondSpace]) + version = string(startLine[secondSpace+1:]) + + return method, path, version, nil +} + +func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) { + header := &requestHeader{ headers: make(map[string]string, 16), } startLineBytes, err := br.ReadSlice('\n') if err != nil { - if err == bufio.ErrBufferFull { - var startLine string - startLine, err = br.ReadString('\n') - if err != nil { - return nil, err - } - startLineBytes = []byte(startLine) - } else { - return nil, err - } + return nil, err } startLineBytes = bytes.TrimRight(startLineBytes, "\r\n") header.startLine = make([]byte, len(startLineBytes)) copy(header.startLine, startLineBytes) - parts := bytes.Split(startLineBytes, []byte{' '}) - if len(parts) < 3 { - return nil, fmt.Errorf("invalid request line") + header.method, header.path, header.version, err = parseStartLine(header.startLine) + if err != nil { + return nil, err } - header.method = string(parts[0]) - header.path = string(parts[1]) - header.version = string(parts[2]) - for { lineBytes, err := br.ReadSlice('\n') if err != nil { - if err == bufio.ErrBufferFull { - var line string - line, err = br.ReadString('\n') - if err != nil { - return nil, err - } - lineBytes = []byte(line) - } else { - return nil, err - } + return nil, err } lineBytes = bytes.TrimRight(lineBytes, "\r\n") @@ -174,63 +165,63 @@ func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) { return header, nil } -func NewResponseHeaderFactory(startLine []byte) ResponseHeaderManager { - header := &responseHeaderFactory{ +func NewResponseHeader(headerData []byte) (ResponseHeader, error) { + header := &responseHeader{ startLine: nil, - headers: make(map[string]string), + headers: make(map[string]string, 16), } - lines := bytes.Split(startLine, []byte("\r\n")) - if len(lines) == 0 { - return header - } - header.startLine = lines[0] - for _, h := range lines[1:] { - if len(h) == 0 { - continue - } - parts := bytes.SplitN(h, []byte(":"), 2) - if len(parts) < 2 { - continue - } - - key := parts[0] - val := bytes.TrimSpace(parts[1]) - header.headers[string(key)] = string(val) + lineEnd := bytes.Index(headerData, []byte("\r\n")) + if lineEnd == -1 { + return nil, fmt.Errorf("invalid request: no CRLF found in start line") } - return header + + header.startLine = headerData[:lineEnd] + remaining := headerData[lineEnd+2:] + setRemainingHeaders(remaining, header) + + return header, nil } -func (resp *responseHeaderFactory) Get(key string) string { +func (resp *responseHeader) Value(key string) string { return resp.headers[key] } -func (resp *responseHeaderFactory) Set(key string, value string) { +func (resp *responseHeader) Set(key string, value string) { resp.headers[key] = value } -func (resp *responseHeaderFactory) Remove(key string) { +func (resp *responseHeader) Remove(key string) { delete(resp.headers, key) } -func (resp *responseHeaderFactory) Finalize() []byte { - var buf bytes.Buffer +func finalize(startLine []byte, headers map[string]string) []byte { + size := len(startLine) + 2 + for key, val := range headers { + size += len(key) + 2 + len(val) + 2 + } + size += 2 - buf.Write(resp.startLine) - buf.WriteString("\r\n") + buf := make([]byte, 0, size) + buf = append(buf, startLine...) + buf = append(buf, '\r', '\n') - for key, val := range resp.headers { - buf.WriteString(key) - buf.WriteString(": ") - buf.WriteString(val) - buf.WriteString("\r\n") + for key, val := range headers { + buf = append(buf, key...) + buf = append(buf, ':', ' ') + buf = append(buf, val...) + buf = append(buf, '\r', '\n') } - buf.WriteString("\r\n") - return buf.Bytes() + buf = append(buf, '\r', '\n') + return buf } -func (req *requestHeaderFactory) Get(key string) string { +func (resp *responseHeader) Finalize() []byte { + return finalize(resp.startLine, resp.headers) +} + +func (req *requestHeader) Value(key string) string { val, ok := req.headers[key] if !ok { return "" @@ -238,39 +229,26 @@ func (req *requestHeaderFactory) Get(key string) string { return val } -func (req *requestHeaderFactory) Set(key string, value string) { +func (req *requestHeader) Set(key string, value string) { req.headers[key] = value } -func (req *requestHeaderFactory) Remove(key string) { +func (req *requestHeader) Remove(key string) { delete(req.headers, key) } -func (req *requestHeaderFactory) GetMethod() string { +func (req *requestHeader) GetMethod() string { return req.method } -func (req *requestHeaderFactory) GetPath() string { +func (req *requestHeader) GetPath() string { return req.path } -func (req *requestHeaderFactory) GetVersion() string { +func (req *requestHeader) GetVersion() string { return req.version } -func (req *requestHeaderFactory) Finalize() []byte { - var buf bytes.Buffer - - buf.Write(req.startLine) - buf.WriteString("\r\n") - - for key, val := range req.headers { - buf.WriteString(key) - buf.WriteString(": ") - buf.WriteString(val) - buf.WriteString("\r\n") - } - - buf.WriteString("\r\n") - return buf.Bytes() +func (req *requestHeader) Finalize() []byte { + return finalize(req.startLine, req.headers) } diff --git a/server/http.go b/server/http.go index f685d39..e8da8a6 100644 --- a/server/http.go +++ b/server/http.go @@ -112,7 +112,7 @@ func (hs *httpServer) handler(conn net.Conn, isTLS bool) { defer hs.closeConnection(conn) dstReader := bufio.NewReader(conn) - reqhf, err := NewRequestHeaderFactory(dstReader) + reqhf, err := NewRequestHeader(dstReader) if err != nil { log.Printf("Error creating request header: %v", err) return @@ -150,8 +150,8 @@ func (hs *httpServer) closeConnection(conn net.Conn) { } } -func (hs *httpServer) extractSlug(reqhf RequestHeaderManager) (string, error) { - host := strings.Split(reqhf.Get("Host"), ".") +func (hs *httpServer) extractSlug(reqhf RequestHeader) (string, error) { + host := strings.Split(reqhf.Value("Host"), ".") if len(host) < 1 { return "", errors.New("invalid host") } @@ -193,7 +193,7 @@ func (hs *httpServer) getSession(slug string) (session.Session, error) { return sshSession, nil } -func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeaderManager, sshSession session.Session) { +func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeader, sshSession session.Session) { channel, err := hs.openForwardedChannel(hw, sshSession) if err != nil { log.Printf("Failed to establish channel: %v", err) @@ -260,7 +260,7 @@ func (hs *httpServer) setupMiddlewares(hw HTTPWriter) { hw.UseRequestMiddleware(forwardedForMiddleware) } -func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeaderManager, channel ssh.Channel) error { +func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeader, channel ssh.Channel) error { hw.SetRequestHeader(initialRequest) if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { diff --git a/server/httpwritter.go b/server/httpwritter.go index bde7452..64154d0 100644 --- a/server/httpwritter.go +++ b/server/httpwritter.go @@ -14,11 +14,11 @@ type HTTPWriter interface { RemoteAddr() net.Addr UseResponseMiddleware(mw ResponseMiddleware) UseRequestMiddleware(mw RequestMiddleware) - SetRequestHeader(header RequestHeaderManager) + SetRequestHeader(header RequestHeader) RequestMiddlewares() []RequestMiddleware ResponseMiddlewares() []ResponseMiddleware - ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error - ApplyRequestMiddlewares(reqhf RequestHeaderManager) error + ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error + ApplyRequestMiddlewares(reqhf RequestHeader) error } type httpWriter struct { @@ -27,8 +27,8 @@ type httpWriter struct { reader io.Reader headerBuf []byte buf []byte - respHeader ResponseHeaderManager - reqHeader RequestHeaderManager + respHeader ResponseHeader + reqHeader RequestHeader respMW []ResponseMiddleware reqMW []RequestMiddleware } @@ -49,7 +49,7 @@ func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) { hw.reqMW = append(hw.reqMW, mw) } -func (hw *httpWriter) SetRequestHeader(header RequestHeaderManager) { +func (hw *httpWriter) SetRequestHeader(header RequestHeader) { hw.reqHeader = header } @@ -107,7 +107,7 @@ func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, } func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { - reqhf, err := NewRequestHeaderFactory(header) + reqhf, err := NewRequestHeader(header) if err != nil { return 0, err } @@ -121,7 +121,7 @@ func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { return copy(p, combined), nil } -func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeaderManager) error { +func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeader) error { for _, m := range hw.RequestMiddlewares() { if err := m.HandleRequest(reqhf); err != nil { log.Printf("Error when applying request middleware: %v", err) @@ -180,23 +180,26 @@ func (hw *httpWriter) writeRawBuffer() (int, error) { } func (hw *httpWriter) processHTTPResponse(header, body []byte) error { - resphf := NewResponseHeaderFactory(header) + resphf, err := NewResponseHeader(header) + if err != nil { + return err + } - if err := hw.ApplyResponseMiddlewares(resphf, body); err != nil { + if err = hw.ApplyResponseMiddlewares(resphf, body); err != nil { return err } hw.respHeader = resphf finalHeader := resphf.Finalize() - if err := hw.writeHeaderAndBody(finalHeader, body); err != nil { + if err = hw.writeHeaderAndBody(finalHeader, body); err != nil { return err } return nil } -func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeaderManager, body []byte) error { +func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error { for _, m := range hw.ResponseMiddlewares() { if err := m.HandleResponse(resphf, body); err != nil { log.Printf("Cannot apply middleware: %s\n", err) diff --git a/server/middleware.go b/server/middleware.go index ee6ca1a..63b2467 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -5,11 +5,11 @@ import ( ) type RequestMiddleware interface { - HandleRequest(header RequestHeaderManager) error + HandleRequest(header RequestHeader) error } type ResponseMiddleware interface { - HandleResponse(header ResponseHeaderManager, body []byte) error + HandleResponse(header ResponseHeader, body []byte) error } type TunnelFingerprint struct{} @@ -18,7 +18,7 @@ func NewTunnelFingerprint() *TunnelFingerprint { return &TunnelFingerprint{} } -func (h *TunnelFingerprint) HandleResponse(header ResponseHeaderManager, body []byte) error { +func (h *TunnelFingerprint) HandleResponse(header ResponseHeader, body []byte) error { header.Set("Server", "Tunnel Please") return nil } @@ -31,7 +31,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor { return &ForwardedFor{addr: addr} } -func (ff *ForwardedFor) HandleRequest(header RequestHeaderManager) error { +func (ff *ForwardedFor) HandleRequest(header RequestHeader) error { host, _, err := net.SplitHostPort(ff.addr.String()) if err != nil { return err