diff --git a/internal/httpheader/header.go b/internal/httpheader/header.go new file mode 100644 index 0000000..ccd1bed --- /dev/null +++ b/internal/httpheader/header.go @@ -0,0 +1,30 @@ +package httpheader + +type ResponseHeader interface { + Value(key string) string + Set(key string, value string) + Remove(key string) + Finalize() []byte +} + +type responseHeader struct { + startLine []byte + headers map[string]string +} + +type RequestHeader interface { + Value(key string) string + Set(key string, value string) + Remove(key string) + Finalize() []byte + GetMethod() string + GetPath() string + GetVersion() string +} +type requestHeader struct { + method string + path string + version string + startLine []byte + headers map[string]string +} diff --git a/server/header.go b/internal/httpheader/parser.go similarity index 59% rename from server/header.go rename to internal/httpheader/parser.go index bc3ce73..3325ae5 100644 --- a/server/header.go +++ b/internal/httpheader/parser.go @@ -1,4 +1,4 @@ -package server +package httpheader import ( "bufio" @@ -6,46 +6,6 @@ import ( "fmt" ) -type ResponseHeader interface { - Value(key string) string - Set(key string, value string) - Remove(key string) - Finalize() []byte -} - -type responseHeader struct { - startLine []byte - headers map[string]string -} - -type RequestHeader interface { - Value(key string) string - Set(key string, value string) - Remove(key string) - Finalize() []byte - GetMethod() string - GetPath() string - GetVersion() string -} -type requestHeader struct { - method string - path string - version string - startLine []byte - headers map[string]string -} - -func NewRequestHeader(r interface{}) (RequestHeader, error) { - switch v := r.(type) { - case []byte: - return parseHeadersFromBytes(v) - case *bufio.Reader: - return parseHeadersFromReader(v) - default: - return nil, fmt.Errorf("unsupported type: %T", r) - } -} - func setRemainingHeaders(remaining []byte, header interface { Set(key string, value string) }) { @@ -165,36 +125,6 @@ func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) { return header, nil } -func NewResponseHeader(headerData []byte) (ResponseHeader, error) { - header := &responseHeader{ - startLine: nil, - headers: make(map[string]string, 16), - } - - lineEnd := bytes.Index(headerData, []byte("\r\n")) - if lineEnd == -1 { - return nil, fmt.Errorf("invalid request: no CRLF found in start line") - } - - header.startLine = headerData[:lineEnd] - remaining := headerData[lineEnd+2:] - setRemainingHeaders(remaining, header) - - return header, nil -} - -func (resp *responseHeader) Value(key string) string { - return resp.headers[key] -} - -func (resp *responseHeader) Set(key string, value string) { - resp.headers[key] = value -} - -func (resp *responseHeader) Remove(key string) { - delete(resp.headers, key) -} - func finalize(startLine []byte, headers map[string]string) []byte { size := len(startLine) + 2 for key, val := range headers { @@ -216,39 +146,3 @@ func finalize(startLine []byte, headers map[string]string) []byte { buf = append(buf, '\r', '\n') return buf } - -func (resp *responseHeader) Finalize() []byte { - return finalize(resp.startLine, resp.headers) -} - -func (req *requestHeader) Value(key string) string { - val, ok := req.headers[key] - if !ok { - return "" - } - return val -} - -func (req *requestHeader) Set(key string, value string) { - req.headers[key] = value -} - -func (req *requestHeader) Remove(key string) { - delete(req.headers, key) -} - -func (req *requestHeader) GetMethod() string { - return req.method -} - -func (req *requestHeader) GetPath() string { - return req.path -} - -func (req *requestHeader) GetVersion() string { - return req.version -} - -func (req *requestHeader) Finalize() []byte { - return finalize(req.startLine, req.headers) -} diff --git a/internal/httpheader/request.go b/internal/httpheader/request.go new file mode 100644 index 0000000..ae63340 --- /dev/null +++ b/internal/httpheader/request.go @@ -0,0 +1,49 @@ +package httpheader + +import ( + "bufio" + "fmt" +) + +func NewRequestHeader(r interface{}) (RequestHeader, error) { + switch v := r.(type) { + case []byte: + return parseHeadersFromBytes(v) + case *bufio.Reader: + return parseHeadersFromReader(v) + default: + return nil, fmt.Errorf("unsupported type: %T", r) + } +} + +func (req *requestHeader) Value(key string) string { + val, ok := req.headers[key] + if !ok { + return "" + } + return val +} + +func (req *requestHeader) Set(key string, value string) { + req.headers[key] = value +} + +func (req *requestHeader) Remove(key string) { + delete(req.headers, key) +} + +func (req *requestHeader) GetMethod() string { + return req.method +} + +func (req *requestHeader) GetPath() string { + return req.path +} + +func (req *requestHeader) GetVersion() string { + return req.version +} + +func (req *requestHeader) Finalize() []byte { + return finalize(req.startLine, req.headers) +} diff --git a/internal/httpheader/response.go b/internal/httpheader/response.go new file mode 100644 index 0000000..63ad352 --- /dev/null +++ b/internal/httpheader/response.go @@ -0,0 +1,40 @@ +package httpheader + +import ( + "bytes" + "fmt" +) + +func NewResponseHeader(headerData []byte) (ResponseHeader, error) { + header := &responseHeader{ + startLine: nil, + headers: make(map[string]string, 16), + } + + lineEnd := bytes.Index(headerData, []byte("\r\n")) + if lineEnd == -1 { + return nil, fmt.Errorf("invalid response: no CRLF found in start line") + } + + header.startLine = headerData[:lineEnd] + remaining := headerData[lineEnd+2:] + setRemainingHeaders(remaining, header) + + return header, nil +} + +func (resp *responseHeader) Value(key string) string { + return resp.headers[key] +} + +func (resp *responseHeader) Set(key string, value string) { + resp.headers[key] = value +} + +func (resp *responseHeader) Remove(key string) { + delete(resp.headers, key) +} + +func (resp *responseHeader) Finalize() []byte { + return finalize(resp.startLine, resp.headers) +} diff --git a/server/http.go b/server/http.go index e8da8a6..a36b6b1 100644 --- a/server/http.go +++ b/server/http.go @@ -11,6 +11,7 @@ import ( "strings" "time" "tunnel_pls/internal/config" + "tunnel_pls/internal/httpheader" "tunnel_pls/session" "tunnel_pls/types" @@ -112,7 +113,7 @@ func (hs *httpServer) handler(conn net.Conn, isTLS bool) { defer hs.closeConnection(conn) dstReader := bufio.NewReader(conn) - reqhf, err := NewRequestHeader(dstReader) + reqhf, err := httpheader.NewRequestHeader(dstReader) if err != nil { log.Printf("Error creating request header: %v", err) return @@ -150,7 +151,7 @@ func (hs *httpServer) closeConnection(conn net.Conn) { } } -func (hs *httpServer) extractSlug(reqhf RequestHeader) (string, error) { +func (hs *httpServer) extractSlug(reqhf httpheader.RequestHeader) (string, error) { host := strings.Split(reqhf.Value("Host"), ".") if len(host) < 1 { return "", errors.New("invalid host") @@ -193,7 +194,7 @@ func (hs *httpServer) getSession(slug string) (session.Session, error) { return sshSession, nil } -func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeader, sshSession session.Session) { +func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, sshSession session.Session) { channel, err := hs.openForwardedChannel(hw, sshSession) if err != nil { log.Printf("Failed to establish channel: %v", err) @@ -260,7 +261,7 @@ func (hs *httpServer) setupMiddlewares(hw HTTPWriter) { hw.UseRequestMiddleware(forwardedForMiddleware) } -func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeader, channel ssh.Channel) error { +func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, channel ssh.Channel) error { hw.SetRequestHeader(initialRequest) if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { diff --git a/server/httpwritter.go b/server/httpwritter.go index 64154d0..9d52f24 100644 --- a/server/httpwritter.go +++ b/server/httpwritter.go @@ -6,6 +6,7 @@ import ( "log" "net" "regexp" + "tunnel_pls/internal/httpheader" ) type HTTPWriter interface { @@ -14,11 +15,11 @@ type HTTPWriter interface { RemoteAddr() net.Addr UseResponseMiddleware(mw ResponseMiddleware) UseRequestMiddleware(mw RequestMiddleware) - SetRequestHeader(header RequestHeader) + SetRequestHeader(header httpheader.RequestHeader) RequestMiddlewares() []RequestMiddleware ResponseMiddlewares() []ResponseMiddleware - ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error - ApplyRequestMiddlewares(reqhf RequestHeader) error + ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error + ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error } type httpWriter struct { @@ -27,8 +28,8 @@ type httpWriter struct { reader io.Reader headerBuf []byte buf []byte - respHeader ResponseHeader - reqHeader RequestHeader + respHeader httpheader.ResponseHeader + reqHeader httpheader.RequestHeader respMW []ResponseMiddleware reqMW []RequestMiddleware } @@ -49,7 +50,7 @@ func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) { hw.reqMW = append(hw.reqMW, mw) } -func (hw *httpWriter) SetRequestHeader(header RequestHeader) { +func (hw *httpWriter) SetRequestHeader(header httpheader.RequestHeader) { hw.reqHeader = header } @@ -107,7 +108,7 @@ func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte, } func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { - reqhf, err := NewRequestHeader(header) + reqhf, err := httpheader.NewRequestHeader(header) if err != nil { return 0, err } @@ -121,7 +122,7 @@ func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { return copy(p, combined), nil } -func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeader) error { +func (hw *httpWriter) ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error { for _, m := range hw.RequestMiddlewares() { if err := m.HandleRequest(reqhf); err != nil { log.Printf("Error when applying request middleware: %v", err) @@ -180,7 +181,7 @@ func (hw *httpWriter) writeRawBuffer() (int, error) { } func (hw *httpWriter) processHTTPResponse(header, body []byte) error { - resphf, err := NewResponseHeader(header) + resphf, err := httpheader.NewResponseHeader(header) if err != nil { return err } @@ -199,7 +200,7 @@ func (hw *httpWriter) processHTTPResponse(header, body []byte) error { return nil } -func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error { +func (hw *httpWriter) ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error { for _, m := range hw.ResponseMiddlewares() { if err := m.HandleResponse(resphf, body); err != nil { log.Printf("Cannot apply middleware: %s\n", err) diff --git a/server/middleware.go b/server/middleware.go index 63b2467..6f50c4c 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -2,14 +2,15 @@ package server import ( "net" + "tunnel_pls/internal/httpheader" ) type RequestMiddleware interface { - HandleRequest(header RequestHeader) error + HandleRequest(header httpheader.RequestHeader) error } type ResponseMiddleware interface { - HandleResponse(header ResponseHeader, body []byte) error + HandleResponse(header httpheader.ResponseHeader, body []byte) error } type TunnelFingerprint struct{} @@ -18,7 +19,7 @@ func NewTunnelFingerprint() *TunnelFingerprint { return &TunnelFingerprint{} } -func (h *TunnelFingerprint) HandleResponse(header ResponseHeader, body []byte) error { +func (h *TunnelFingerprint) HandleResponse(header httpheader.ResponseHeader, body []byte) error { header.Set("Server", "Tunnel Please") return nil } @@ -31,7 +32,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor { return &ForwardedFor{addr: addr} } -func (ff *ForwardedFor) HandleRequest(header RequestHeader) error { +func (ff *ForwardedFor) HandleRequest(header httpheader.RequestHeader) error { host, _, err := net.SplitHostPort(ff.addr.String()) if err != nil { return err