refactor(httpheader): extract header parsing into dedicated package
Moved HTTP header parsing and building logic from server package to internal/httpheader
This commit is contained in:
@@ -0,0 +1,30 @@
|
|||||||
|
package httpheader
|
||||||
|
|
||||||
|
type ResponseHeader interface {
|
||||||
|
Value(key string) string
|
||||||
|
Set(key string, value string)
|
||||||
|
Remove(key string)
|
||||||
|
Finalize() []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type responseHeader struct {
|
||||||
|
startLine []byte
|
||||||
|
headers map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type RequestHeader interface {
|
||||||
|
Value(key string) string
|
||||||
|
Set(key string, value string)
|
||||||
|
Remove(key string)
|
||||||
|
Finalize() []byte
|
||||||
|
GetMethod() string
|
||||||
|
GetPath() string
|
||||||
|
GetVersion() string
|
||||||
|
}
|
||||||
|
type requestHeader struct {
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
version string
|
||||||
|
startLine []byte
|
||||||
|
headers map[string]string
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package server
|
package httpheader
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
@@ -6,46 +6,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ResponseHeader interface {
|
|
||||||
Value(key string) string
|
|
||||||
Set(key string, value string)
|
|
||||||
Remove(key string)
|
|
||||||
Finalize() []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type responseHeader struct {
|
|
||||||
startLine []byte
|
|
||||||
headers map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
type RequestHeader interface {
|
|
||||||
Value(key string) string
|
|
||||||
Set(key string, value string)
|
|
||||||
Remove(key string)
|
|
||||||
Finalize() []byte
|
|
||||||
GetMethod() string
|
|
||||||
GetPath() string
|
|
||||||
GetVersion() string
|
|
||||||
}
|
|
||||||
type requestHeader struct {
|
|
||||||
method string
|
|
||||||
path string
|
|
||||||
version string
|
|
||||||
startLine []byte
|
|
||||||
headers map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRequestHeader(r interface{}) (RequestHeader, error) {
|
|
||||||
switch v := r.(type) {
|
|
||||||
case []byte:
|
|
||||||
return parseHeadersFromBytes(v)
|
|
||||||
case *bufio.Reader:
|
|
||||||
return parseHeadersFromReader(v)
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported type: %T", r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setRemainingHeaders(remaining []byte, header interface {
|
func setRemainingHeaders(remaining []byte, header interface {
|
||||||
Set(key string, value string)
|
Set(key string, value string)
|
||||||
}) {
|
}) {
|
||||||
@@ -165,36 +125,6 @@ func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) {
|
|||||||
return header, nil
|
return header, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewResponseHeader(headerData []byte) (ResponseHeader, error) {
|
|
||||||
header := &responseHeader{
|
|
||||||
startLine: nil,
|
|
||||||
headers: make(map[string]string, 16),
|
|
||||||
}
|
|
||||||
|
|
||||||
lineEnd := bytes.Index(headerData, []byte("\r\n"))
|
|
||||||
if lineEnd == -1 {
|
|
||||||
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
|
|
||||||
}
|
|
||||||
|
|
||||||
header.startLine = headerData[:lineEnd]
|
|
||||||
remaining := headerData[lineEnd+2:]
|
|
||||||
setRemainingHeaders(remaining, header)
|
|
||||||
|
|
||||||
return header, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (resp *responseHeader) Value(key string) string {
|
|
||||||
return resp.headers[key]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (resp *responseHeader) Set(key string, value string) {
|
|
||||||
resp.headers[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (resp *responseHeader) Remove(key string) {
|
|
||||||
delete(resp.headers, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func finalize(startLine []byte, headers map[string]string) []byte {
|
func finalize(startLine []byte, headers map[string]string) []byte {
|
||||||
size := len(startLine) + 2
|
size := len(startLine) + 2
|
||||||
for key, val := range headers {
|
for key, val := range headers {
|
||||||
@@ -216,39 +146,3 @@ func finalize(startLine []byte, headers map[string]string) []byte {
|
|||||||
buf = append(buf, '\r', '\n')
|
buf = append(buf, '\r', '\n')
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
func (resp *responseHeader) Finalize() []byte {
|
|
||||||
return finalize(resp.startLine, resp.headers)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Value(key string) string {
|
|
||||||
val, ok := req.headers[key]
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Set(key string, value string) {
|
|
||||||
req.headers[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Remove(key string) {
|
|
||||||
delete(req.headers, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) GetMethod() string {
|
|
||||||
return req.method
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) GetPath() string {
|
|
||||||
return req.path
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) GetVersion() string {
|
|
||||||
return req.version
|
|
||||||
}
|
|
||||||
|
|
||||||
func (req *requestHeader) Finalize() []byte {
|
|
||||||
return finalize(req.startLine, req.headers)
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package httpheader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewRequestHeader(r interface{}) (RequestHeader, error) {
|
||||||
|
switch v := r.(type) {
|
||||||
|
case []byte:
|
||||||
|
return parseHeadersFromBytes(v)
|
||||||
|
case *bufio.Reader:
|
||||||
|
return parseHeadersFromReader(v)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported type: %T", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeader) Value(key string) string {
|
||||||
|
val, ok := req.headers[key]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeader) Set(key string, value string) {
|
||||||
|
req.headers[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeader) Remove(key string) {
|
||||||
|
delete(req.headers, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeader) GetMethod() string {
|
||||||
|
return req.method
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeader) GetPath() string {
|
||||||
|
return req.path
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeader) GetVersion() string {
|
||||||
|
return req.version
|
||||||
|
}
|
||||||
|
|
||||||
|
func (req *requestHeader) Finalize() []byte {
|
||||||
|
return finalize(req.startLine, req.headers)
|
||||||
|
}
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package httpheader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewResponseHeader(headerData []byte) (ResponseHeader, error) {
|
||||||
|
header := &responseHeader{
|
||||||
|
startLine: nil,
|
||||||
|
headers: make(map[string]string, 16),
|
||||||
|
}
|
||||||
|
|
||||||
|
lineEnd := bytes.Index(headerData, []byte("\r\n"))
|
||||||
|
if lineEnd == -1 {
|
||||||
|
return nil, fmt.Errorf("invalid response: no CRLF found in start line")
|
||||||
|
}
|
||||||
|
|
||||||
|
header.startLine = headerData[:lineEnd]
|
||||||
|
remaining := headerData[lineEnd+2:]
|
||||||
|
setRemainingHeaders(remaining, header)
|
||||||
|
|
||||||
|
return header, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (resp *responseHeader) Value(key string) string {
|
||||||
|
return resp.headers[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (resp *responseHeader) Set(key string, value string) {
|
||||||
|
resp.headers[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
func (resp *responseHeader) Remove(key string) {
|
||||||
|
delete(resp.headers, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (resp *responseHeader) Finalize() []byte {
|
||||||
|
return finalize(resp.startLine, resp.headers)
|
||||||
|
}
|
||||||
+5
-4
@@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"tunnel_pls/internal/config"
|
"tunnel_pls/internal/config"
|
||||||
|
"tunnel_pls/internal/httpheader"
|
||||||
"tunnel_pls/session"
|
"tunnel_pls/session"
|
||||||
"tunnel_pls/types"
|
"tunnel_pls/types"
|
||||||
|
|
||||||
@@ -112,7 +113,7 @@ func (hs *httpServer) handler(conn net.Conn, isTLS bool) {
|
|||||||
defer hs.closeConnection(conn)
|
defer hs.closeConnection(conn)
|
||||||
|
|
||||||
dstReader := bufio.NewReader(conn)
|
dstReader := bufio.NewReader(conn)
|
||||||
reqhf, err := NewRequestHeader(dstReader)
|
reqhf, err := httpheader.NewRequestHeader(dstReader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error creating request header: %v", err)
|
log.Printf("Error creating request header: %v", err)
|
||||||
return
|
return
|
||||||
@@ -150,7 +151,7 @@ func (hs *httpServer) closeConnection(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hs *httpServer) extractSlug(reqhf RequestHeader) (string, error) {
|
func (hs *httpServer) extractSlug(reqhf httpheader.RequestHeader) (string, error) {
|
||||||
host := strings.Split(reqhf.Value("Host"), ".")
|
host := strings.Split(reqhf.Value("Host"), ".")
|
||||||
if len(host) < 1 {
|
if len(host) < 1 {
|
||||||
return "", errors.New("invalid host")
|
return "", errors.New("invalid host")
|
||||||
@@ -193,7 +194,7 @@ func (hs *httpServer) getSession(slug string) (session.Session, error) {
|
|||||||
return sshSession, nil
|
return sshSession, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeader, sshSession session.Session) {
|
func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, sshSession session.Session) {
|
||||||
channel, err := hs.openForwardedChannel(hw, sshSession)
|
channel, err := hs.openForwardedChannel(hw, sshSession)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to establish channel: %v", err)
|
log.Printf("Failed to establish channel: %v", err)
|
||||||
@@ -260,7 +261,7 @@ func (hs *httpServer) setupMiddlewares(hw HTTPWriter) {
|
|||||||
hw.UseRequestMiddleware(forwardedForMiddleware)
|
hw.UseRequestMiddleware(forwardedForMiddleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeader, channel ssh.Channel) error {
|
func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, channel ssh.Channel) error {
|
||||||
hw.SetRequestHeader(initialRequest)
|
hw.SetRequestHeader(initialRequest)
|
||||||
|
|
||||||
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
|
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
|
||||||
|
|||||||
+11
-10
@@ -6,6 +6,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"tunnel_pls/internal/httpheader"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HTTPWriter interface {
|
type HTTPWriter interface {
|
||||||
@@ -14,11 +15,11 @@ type HTTPWriter interface {
|
|||||||
RemoteAddr() net.Addr
|
RemoteAddr() net.Addr
|
||||||
UseResponseMiddleware(mw ResponseMiddleware)
|
UseResponseMiddleware(mw ResponseMiddleware)
|
||||||
UseRequestMiddleware(mw RequestMiddleware)
|
UseRequestMiddleware(mw RequestMiddleware)
|
||||||
SetRequestHeader(header RequestHeader)
|
SetRequestHeader(header httpheader.RequestHeader)
|
||||||
RequestMiddlewares() []RequestMiddleware
|
RequestMiddlewares() []RequestMiddleware
|
||||||
ResponseMiddlewares() []ResponseMiddleware
|
ResponseMiddlewares() []ResponseMiddleware
|
||||||
ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error
|
ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error
|
||||||
ApplyRequestMiddlewares(reqhf RequestHeader) error
|
ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type httpWriter struct {
|
type httpWriter struct {
|
||||||
@@ -27,8 +28,8 @@ type httpWriter struct {
|
|||||||
reader io.Reader
|
reader io.Reader
|
||||||
headerBuf []byte
|
headerBuf []byte
|
||||||
buf []byte
|
buf []byte
|
||||||
respHeader ResponseHeader
|
respHeader httpheader.ResponseHeader
|
||||||
reqHeader RequestHeader
|
reqHeader httpheader.RequestHeader
|
||||||
respMW []ResponseMiddleware
|
respMW []ResponseMiddleware
|
||||||
reqMW []RequestMiddleware
|
reqMW []RequestMiddleware
|
||||||
}
|
}
|
||||||
@@ -49,7 +50,7 @@ func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) {
|
|||||||
hw.reqMW = append(hw.reqMW, mw)
|
hw.reqMW = append(hw.reqMW, mw)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hw *httpWriter) SetRequestHeader(header RequestHeader) {
|
func (hw *httpWriter) SetRequestHeader(header httpheader.RequestHeader) {
|
||||||
hw.reqHeader = header
|
hw.reqHeader = header
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,7 +108,7 @@ func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) {
|
func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) {
|
||||||
reqhf, err := NewRequestHeader(header)
|
reqhf, err := httpheader.NewRequestHeader(header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -121,7 +122,7 @@ func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) {
|
|||||||
return copy(p, combined), nil
|
return copy(p, combined), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeader) error {
|
func (hw *httpWriter) ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error {
|
||||||
for _, m := range hw.RequestMiddlewares() {
|
for _, m := range hw.RequestMiddlewares() {
|
||||||
if err := m.HandleRequest(reqhf); err != nil {
|
if err := m.HandleRequest(reqhf); err != nil {
|
||||||
log.Printf("Error when applying request middleware: %v", err)
|
log.Printf("Error when applying request middleware: %v", err)
|
||||||
@@ -180,7 +181,7 @@ func (hw *httpWriter) writeRawBuffer() (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hw *httpWriter) processHTTPResponse(header, body []byte) error {
|
func (hw *httpWriter) processHTTPResponse(header, body []byte) error {
|
||||||
resphf, err := NewResponseHeader(header)
|
resphf, err := httpheader.NewResponseHeader(header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -199,7 +200,7 @@ func (hw *httpWriter) processHTTPResponse(header, body []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error {
|
func (hw *httpWriter) ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error {
|
||||||
for _, m := range hw.ResponseMiddlewares() {
|
for _, m := range hw.ResponseMiddlewares() {
|
||||||
if err := m.HandleResponse(resphf, body); err != nil {
|
if err := m.HandleResponse(resphf, body); err != nil {
|
||||||
log.Printf("Cannot apply middleware: %s\n", err)
|
log.Printf("Cannot apply middleware: %s\n", err)
|
||||||
|
|||||||
@@ -2,14 +2,15 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"tunnel_pls/internal/httpheader"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RequestMiddleware interface {
|
type RequestMiddleware interface {
|
||||||
HandleRequest(header RequestHeader) error
|
HandleRequest(header httpheader.RequestHeader) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type ResponseMiddleware interface {
|
type ResponseMiddleware interface {
|
||||||
HandleResponse(header ResponseHeader, body []byte) error
|
HandleResponse(header httpheader.ResponseHeader, body []byte) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type TunnelFingerprint struct{}
|
type TunnelFingerprint struct{}
|
||||||
@@ -18,7 +19,7 @@ func NewTunnelFingerprint() *TunnelFingerprint {
|
|||||||
return &TunnelFingerprint{}
|
return &TunnelFingerprint{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *TunnelFingerprint) HandleResponse(header ResponseHeader, body []byte) error {
|
func (h *TunnelFingerprint) HandleResponse(header httpheader.ResponseHeader, body []byte) error {
|
||||||
header.Set("Server", "Tunnel Please")
|
header.Set("Server", "Tunnel Please")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -31,7 +32,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor {
|
|||||||
return &ForwardedFor{addr: addr}
|
return &ForwardedFor{addr: addr}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ff *ForwardedFor) HandleRequest(header RequestHeader) error {
|
func (ff *ForwardedFor) HandleRequest(header httpheader.RequestHeader) error {
|
||||||
host, _, err := net.SplitHostPort(ff.addr.String())
|
host, _, err := net.SplitHostPort(ff.addr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
Reference in New Issue
Block a user