diff --git a/server/header.go b/server/header.go index ec0c224..584394b 100644 --- a/server/header.go +++ b/server/header.go @@ -4,7 +4,6 @@ import ( "bufio" "bytes" "fmt" - "strings" ) type HeaderManager interface { @@ -44,43 +43,132 @@ type requestHeaderFactory struct { headers map[string]string } -func NewRequestHeaderFactory(br *bufio.Reader) (RequestHeaderManager, error) { +func NewRequestHeaderFactory(r interface{}) (RequestHeaderManager, error) { + switch v := r.(type) { + case []byte: + return parseHeadersFromBytes(v) + case *bufio.Reader: + return parseHeadersFromReader(v) + default: + return nil, fmt.Errorf("unsupported type: %T", r) + } +} + +func parseHeadersFromBytes(headerData []byte) (RequestHeaderManager, error) { header := &requestHeaderFactory{ - headers: make(map[string]string), + headers: make(map[string]string, 16), } - startLine, err := br.ReadString('\n') - if err != nil { - return nil, err + lineEnd := bytes.IndexByte(headerData, '\n') + if lineEnd == -1 { + return nil, fmt.Errorf("invalid request: no newline found") } - startLine = strings.TrimRight(startLine, "\r\n") - header.startLine = []byte(startLine) - parts := strings.Split(startLine, " ") + startLine := bytes.TrimRight(headerData[:lineEnd], "\r\n") + header.startLine = make([]byte, len(startLine)) + copy(header.startLine, startLine) + + parts := bytes.Split(startLine, []byte{' '}) if len(parts) < 3 { return nil, fmt.Errorf("invalid request line") } - header.method = parts[0] - header.path = parts[1] - header.version = parts[2] + header.method = string(parts[0]) + header.path = string(parts[1]) + header.version = string(parts[2]) - for { - line, err := br.ReadString('\n') - if err != nil { - return nil, err + remaining := headerData[lineEnd+1:] + + for len(remaining) > 0 { + lineEnd = bytes.IndexByte(remaining, '\n') + if lineEnd == -1 { + lineEnd = len(remaining) } - line = strings.TrimRight(line, "\r\n") - if line == "" { + line := bytes.TrimRight(remaining[:lineEnd], "\r\n") + + if len(line) == 0 { break } - kv := strings.SplitN(line, ":", 2) - if len(kv) != 2 { + colonIdx := bytes.IndexByte(line, ':') + if colonIdx != -1 { + key := bytes.TrimSpace(line[:colonIdx]) + value := bytes.TrimSpace(line[colonIdx+1:]) + header.headers[string(key)] = string(value) + } + + if lineEnd == len(remaining) { + break + } + remaining = remaining[lineEnd+1:] + } + + return header, nil +} + +func parseHeadersFromReader(br *bufio.Reader) (RequestHeaderManager, error) { + header := &requestHeaderFactory{ + headers: make(map[string]string, 16), + } + + startLineBytes, err := br.ReadSlice('\n') + if err != nil { + if err == bufio.ErrBufferFull { + var startLine string + startLine, err = br.ReadString('\n') + if err != nil { + return nil, err + } + startLineBytes = []byte(startLine) + } else { + return nil, err + } + } + + startLineBytes = bytes.TrimRight(startLineBytes, "\r\n") + header.startLine = make([]byte, len(startLineBytes)) + copy(header.startLine, startLineBytes) + + parts := bytes.Split(startLineBytes, []byte{' '}) + if len(parts) < 3 { + return nil, fmt.Errorf("invalid request line") + } + + header.method = string(parts[0]) + header.path = string(parts[1]) + header.version = string(parts[2]) + + for { + lineBytes, err := br.ReadSlice('\n') + if err != nil { + if err == bufio.ErrBufferFull { + var line string + line, err = br.ReadString('\n') + if err != nil { + return nil, err + } + lineBytes = []byte(line) + } else { + return nil, err + } + } + + lineBytes = bytes.TrimRight(lineBytes, "\r\n") + + if len(lineBytes) == 0 { + break + } + + colonIdx := bytes.IndexByte(lineBytes, ':') + if colonIdx == -1 { continue } - header.headers[strings.TrimSpace(kv[0])] = strings.TrimSpace(kv[1]) + + key := bytes.TrimSpace(lineBytes[:colonIdx]) + value := bytes.TrimSpace(lineBytes[colonIdx+1:]) + + header.headers[string(key)] = string(value) } return header, nil diff --git a/server/http.go b/server/http.go index 7b58d8c..433b9a0 100644 --- a/server/http.go +++ b/server/http.go @@ -99,8 +99,7 @@ func (cw *customWriter) Read(p []byte) (int, error) { } } - headerReader := bufio.NewReader(bytes.NewReader(header)) - reqhf, err := NewRequestHeaderFactory(headerReader) + reqhf, err := NewRequestHeaderFactory(header) if err != nil { return 0, err }