staging #74

Merged
bagas merged 12 commits from staging into main 2026-01-22 22:16:34 +07:00
7 changed files with 141 additions and 125 deletions
Showing only changes of commit 9a4539cc02 - Show all commits
+30
View File
@@ -0,0 +1,30 @@
package httpheader
type ResponseHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type responseHeader struct {
startLine []byte
headers map[string]string
}
type RequestHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
}
type requestHeader struct {
method string
path string
version string
startLine []byte
headers map[string]string
}
@@ -1,4 +1,4 @@
package server package httpheader
import ( import (
"bufio" "bufio"
@@ -6,46 +6,6 @@ import (
"fmt" "fmt"
) )
type ResponseHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
}
type responseHeader struct {
startLine []byte
headers map[string]string
}
type RequestHeader interface {
Value(key string) string
Set(key string, value string)
Remove(key string)
Finalize() []byte
GetMethod() string
GetPath() string
GetVersion() string
}
type requestHeader struct {
method string
path string
version string
startLine []byte
headers map[string]string
}
func NewRequestHeader(r interface{}) (RequestHeader, error) {
switch v := r.(type) {
case []byte:
return parseHeadersFromBytes(v)
case *bufio.Reader:
return parseHeadersFromReader(v)
default:
return nil, fmt.Errorf("unsupported type: %T", r)
}
}
func setRemainingHeaders(remaining []byte, header interface { func setRemainingHeaders(remaining []byte, header interface {
Set(key string, value string) Set(key string, value string)
}) { }) {
@@ -165,36 +125,6 @@ func parseHeadersFromReader(br *bufio.Reader) (RequestHeader, error) {
return header, nil return header, nil
} }
func NewResponseHeader(headerData []byte) (ResponseHeader, error) {
header := &responseHeader{
startLine: nil,
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid request: no CRLF found in start line")
}
header.startLine = headerData[:lineEnd]
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func (resp *responseHeader) Value(key string) string {
return resp.headers[key]
}
func (resp *responseHeader) Set(key string, value string) {
resp.headers[key] = value
}
func (resp *responseHeader) Remove(key string) {
delete(resp.headers, key)
}
func finalize(startLine []byte, headers map[string]string) []byte { func finalize(startLine []byte, headers map[string]string) []byte {
size := len(startLine) + 2 size := len(startLine) + 2
for key, val := range headers { for key, val := range headers {
@@ -216,39 +146,3 @@ func finalize(startLine []byte, headers map[string]string) []byte {
buf = append(buf, '\r', '\n') buf = append(buf, '\r', '\n')
return buf return buf
} }
func (resp *responseHeader) Finalize() []byte {
return finalize(resp.startLine, resp.headers)
}
func (req *requestHeader) Value(key string) string {
val, ok := req.headers[key]
if !ok {
return ""
}
return val
}
func (req *requestHeader) Set(key string, value string) {
req.headers[key] = value
}
func (req *requestHeader) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeader) GetMethod() string {
return req.method
}
func (req *requestHeader) GetPath() string {
return req.path
}
func (req *requestHeader) GetVersion() string {
return req.version
}
func (req *requestHeader) Finalize() []byte {
return finalize(req.startLine, req.headers)
}
+49
View File
@@ -0,0 +1,49 @@
package httpheader
import (
"bufio"
"fmt"
)
func NewRequestHeader(r interface{}) (RequestHeader, error) {
switch v := r.(type) {
case []byte:
return parseHeadersFromBytes(v)
case *bufio.Reader:
return parseHeadersFromReader(v)
default:
return nil, fmt.Errorf("unsupported type: %T", r)
}
}
func (req *requestHeader) Value(key string) string {
val, ok := req.headers[key]
if !ok {
return ""
}
return val
}
func (req *requestHeader) Set(key string, value string) {
req.headers[key] = value
}
func (req *requestHeader) Remove(key string) {
delete(req.headers, key)
}
func (req *requestHeader) GetMethod() string {
return req.method
}
func (req *requestHeader) GetPath() string {
return req.path
}
func (req *requestHeader) GetVersion() string {
return req.version
}
func (req *requestHeader) Finalize() []byte {
return finalize(req.startLine, req.headers)
}
+40
View File
@@ -0,0 +1,40 @@
package httpheader
import (
"bytes"
"fmt"
)
func NewResponseHeader(headerData []byte) (ResponseHeader, error) {
header := &responseHeader{
startLine: nil,
headers: make(map[string]string, 16),
}
lineEnd := bytes.Index(headerData, []byte("\r\n"))
if lineEnd == -1 {
return nil, fmt.Errorf("invalid response: no CRLF found in start line")
}
header.startLine = headerData[:lineEnd]
remaining := headerData[lineEnd+2:]
setRemainingHeaders(remaining, header)
return header, nil
}
func (resp *responseHeader) Value(key string) string {
return resp.headers[key]
}
func (resp *responseHeader) Set(key string, value string) {
resp.headers[key] = value
}
func (resp *responseHeader) Remove(key string) {
delete(resp.headers, key)
}
func (resp *responseHeader) Finalize() []byte {
return finalize(resp.startLine, resp.headers)
}
+5 -4
View File
@@ -11,6 +11,7 @@ import (
"strings" "strings"
"time" "time"
"tunnel_pls/internal/config" "tunnel_pls/internal/config"
"tunnel_pls/internal/httpheader"
"tunnel_pls/session" "tunnel_pls/session"
"tunnel_pls/types" "tunnel_pls/types"
@@ -112,7 +113,7 @@ func (hs *httpServer) handler(conn net.Conn, isTLS bool) {
defer hs.closeConnection(conn) defer hs.closeConnection(conn)
dstReader := bufio.NewReader(conn) dstReader := bufio.NewReader(conn)
reqhf, err := NewRequestHeader(dstReader) reqhf, err := httpheader.NewRequestHeader(dstReader)
if err != nil { if err != nil {
log.Printf("Error creating request header: %v", err) log.Printf("Error creating request header: %v", err)
return return
@@ -150,7 +151,7 @@ func (hs *httpServer) closeConnection(conn net.Conn) {
} }
} }
func (hs *httpServer) extractSlug(reqhf RequestHeader) (string, error) { func (hs *httpServer) extractSlug(reqhf httpheader.RequestHeader) (string, error) {
host := strings.Split(reqhf.Value("Host"), ".") host := strings.Split(reqhf.Value("Host"), ".")
if len(host) < 1 { if len(host) < 1 {
return "", errors.New("invalid host") return "", errors.New("invalid host")
@@ -193,7 +194,7 @@ func (hs *httpServer) getSession(slug string) (session.Session, error) {
return sshSession, nil return sshSession, nil
} }
func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest RequestHeader, sshSession session.Session) { func (hs *httpServer) forwardRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, sshSession session.Session) {
channel, err := hs.openForwardedChannel(hw, sshSession) channel, err := hs.openForwardedChannel(hw, sshSession)
if err != nil { if err != nil {
log.Printf("Failed to establish channel: %v", err) log.Printf("Failed to establish channel: %v", err)
@@ -260,7 +261,7 @@ func (hs *httpServer) setupMiddlewares(hw HTTPWriter) {
hw.UseRequestMiddleware(forwardedForMiddleware) hw.UseRequestMiddleware(forwardedForMiddleware)
} }
func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest RequestHeader, channel ssh.Channel) error { func (hs *httpServer) sendInitialRequest(hw HTTPWriter, initialRequest httpheader.RequestHeader, channel ssh.Channel) error {
hw.SetRequestHeader(initialRequest) hw.SetRequestHeader(initialRequest)
if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil { if err := hw.ApplyRequestMiddlewares(initialRequest); err != nil {
+11 -10
View File
@@ -6,6 +6,7 @@ import (
"log" "log"
"net" "net"
"regexp" "regexp"
"tunnel_pls/internal/httpheader"
) )
type HTTPWriter interface { type HTTPWriter interface {
@@ -14,11 +15,11 @@ type HTTPWriter interface {
RemoteAddr() net.Addr RemoteAddr() net.Addr
UseResponseMiddleware(mw ResponseMiddleware) UseResponseMiddleware(mw ResponseMiddleware)
UseRequestMiddleware(mw RequestMiddleware) UseRequestMiddleware(mw RequestMiddleware)
SetRequestHeader(header RequestHeader) SetRequestHeader(header httpheader.RequestHeader)
RequestMiddlewares() []RequestMiddleware RequestMiddlewares() []RequestMiddleware
ResponseMiddlewares() []ResponseMiddleware ResponseMiddlewares() []ResponseMiddleware
ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error
ApplyRequestMiddlewares(reqhf RequestHeader) error ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error
} }
type httpWriter struct { type httpWriter struct {
@@ -27,8 +28,8 @@ type httpWriter struct {
reader io.Reader reader io.Reader
headerBuf []byte headerBuf []byte
buf []byte buf []byte
respHeader ResponseHeader respHeader httpheader.ResponseHeader
reqHeader RequestHeader reqHeader httpheader.RequestHeader
respMW []ResponseMiddleware respMW []ResponseMiddleware
reqMW []RequestMiddleware reqMW []RequestMiddleware
} }
@@ -49,7 +50,7 @@ func (hw *httpWriter) UseRequestMiddleware(mw RequestMiddleware) {
hw.reqMW = append(hw.reqMW, mw) hw.reqMW = append(hw.reqMW, mw)
} }
func (hw *httpWriter) SetRequestHeader(header RequestHeader) { func (hw *httpWriter) SetRequestHeader(header httpheader.RequestHeader) {
hw.reqHeader = header hw.reqHeader = header
} }
@@ -107,7 +108,7 @@ func (hw *httpWriter) splitHeaderAndBody(data []byte, delimiterIdx int) ([]byte,
} }
func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) { func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) {
reqhf, err := NewRequestHeader(header) reqhf, err := httpheader.NewRequestHeader(header)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -121,7 +122,7 @@ func (hw *httpWriter) processHTTPRequest(p, header, body []byte) (int, error) {
return copy(p, combined), nil return copy(p, combined), nil
} }
func (hw *httpWriter) ApplyRequestMiddlewares(reqhf RequestHeader) error { func (hw *httpWriter) ApplyRequestMiddlewares(reqhf httpheader.RequestHeader) error {
for _, m := range hw.RequestMiddlewares() { for _, m := range hw.RequestMiddlewares() {
if err := m.HandleRequest(reqhf); err != nil { if err := m.HandleRequest(reqhf); err != nil {
log.Printf("Error when applying request middleware: %v", err) log.Printf("Error when applying request middleware: %v", err)
@@ -180,7 +181,7 @@ func (hw *httpWriter) writeRawBuffer() (int, error) {
} }
func (hw *httpWriter) processHTTPResponse(header, body []byte) error { func (hw *httpWriter) processHTTPResponse(header, body []byte) error {
resphf, err := NewResponseHeader(header) resphf, err := httpheader.NewResponseHeader(header)
if err != nil { if err != nil {
return err return err
} }
@@ -199,7 +200,7 @@ func (hw *httpWriter) processHTTPResponse(header, body []byte) error {
return nil return nil
} }
func (hw *httpWriter) ApplyResponseMiddlewares(resphf ResponseHeader, body []byte) error { func (hw *httpWriter) ApplyResponseMiddlewares(resphf httpheader.ResponseHeader, body []byte) error {
for _, m := range hw.ResponseMiddlewares() { for _, m := range hw.ResponseMiddlewares() {
if err := m.HandleResponse(resphf, body); err != nil { if err := m.HandleResponse(resphf, body); err != nil {
log.Printf("Cannot apply middleware: %s\n", err) log.Printf("Cannot apply middleware: %s\n", err)
+5 -4
View File
@@ -2,14 +2,15 @@ package server
import ( import (
"net" "net"
"tunnel_pls/internal/httpheader"
) )
type RequestMiddleware interface { type RequestMiddleware interface {
HandleRequest(header RequestHeader) error HandleRequest(header httpheader.RequestHeader) error
} }
type ResponseMiddleware interface { type ResponseMiddleware interface {
HandleResponse(header ResponseHeader, body []byte) error HandleResponse(header httpheader.ResponseHeader, body []byte) error
} }
type TunnelFingerprint struct{} type TunnelFingerprint struct{}
@@ -18,7 +19,7 @@ func NewTunnelFingerprint() *TunnelFingerprint {
return &TunnelFingerprint{} return &TunnelFingerprint{}
} }
func (h *TunnelFingerprint) HandleResponse(header ResponseHeader, body []byte) error { func (h *TunnelFingerprint) HandleResponse(header httpheader.ResponseHeader, body []byte) error {
header.Set("Server", "Tunnel Please") header.Set("Server", "Tunnel Please")
return nil return nil
} }
@@ -31,7 +32,7 @@ func NewForwardedFor(addr net.Addr) *ForwardedFor {
return &ForwardedFor{addr: addr} return &ForwardedFor{addr: addr}
} }
func (ff *ForwardedFor) HandleRequest(header RequestHeader) error { func (ff *ForwardedFor) HandleRequest(header httpheader.RequestHeader) error {
host, _, err := net.SplitHostPort(ff.addr.String()) host, _, err := net.SplitHostPort(ff.addr.String())
if err != nil { if err != nil {
return err return err