From b967619a3a3eaa72a90488a98a2620f28517c462 Mon Sep 17 00:00:00 2001 From: bagas Date: Tue, 2 Dec 2025 19:17:20 +0700 Subject: [PATCH] fix: chunk request not sent properly --- server/http.go | 83 +++++--------------------------------------------- 1 file changed, 7 insertions(+), 76 deletions(-) diff --git a/server/http.go b/server/http.go index 1088f84..319e906 100644 --- a/server/http.go +++ b/server/http.go @@ -8,7 +8,6 @@ import ( "io" "log" "net" - "net/http" "regexp" "strconv" "strings" @@ -16,7 +15,6 @@ import ( "tunnel_pls/session" "tunnel_pls/utils" - "github.com/gorilla/websocket" "golang.org/x/crypto/ssh" ) @@ -25,61 +23,6 @@ var BAD_GATEWAY_RESPONSE = []byte("HTTP/1.1 502 Bad Gateway\r\n" + "Content-Type: text/plain\r\n\r\n" + "Bad Gateway") -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} - -type connResponseWriter struct { - conn net.Conn - header http.Header - wrote bool -} - -func (w *connResponseWriter) Header() http.Header { - if w.header == nil { - w.header = make(http.Header) - } - return w.header -} - -func (w *connResponseWriter) WriteHeader(statusCode int) { - if w.wrote { - return - } - w.wrote = true - _, err := fmt.Fprintf(w.conn, "HTTP/1.1 %d %s\r\n", statusCode, http.StatusText(statusCode)) - if err != nil { - log.Printf("Error writing HTTP response: %v", err) - return - } - err = w.header.Write(w.conn) - if err != nil { - log.Printf("Error writing HTTP header: %v", err) - return - } - _, err = fmt.Fprint(w.conn, "\r\n") - if err != nil { - log.Printf("Error writing HTTP header: %v", err) - return - } -} - -func (w *connResponseWriter) Write(b []byte) (int, error) { - if !w.wrote { - w.WriteHeader(http.StatusOK) - } - return w.conn.Write(b) -} - -func (w *connResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - rw := bufio.NewReadWriter( - bufio.NewReader(w.conn), - bufio.NewWriter(w.conn), - ) - return w.conn, rw, nil -} - type CustomWriter struct { RemoteAddr net.Addr writer io.Writer @@ -198,10 +141,9 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { } req := cw.Requests[0] req.Written += len(body) - if req.Chunked { req.Tail = append(req.Tail, p[len(p)-5:]...) - if bytes.Equal(req.Tail, []byte("0\r\n\r\n")) { + if bytes.Contains(p, []byte("0\r\n\r\n")) { cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) } } else if req.ContentSize != -1 { @@ -211,6 +153,7 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { } } else { cw.Requests = cw.Requests[1:] + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) } cw.buf = nil return len(p), nil @@ -223,10 +166,10 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { } req := cw.Requests[0] - req.Written += len(p) + req.Written += n if req.Chunked { req.Tail = append(req.Tail, p[len(p)-5:]...) - if bytes.Equal(req.Tail, []byte("0\r\n\r\n")) { + if bytes.Contains(p, []byte("0\r\n\r\n")) { cw.Requests = cw.Requests[1:] cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) } @@ -235,10 +178,11 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { cw.Requests = cw.Requests[1:] cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) } + } else { cw.Requests = cw.Requests[1:] + cw.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", timestamp, cw.RemoteAddr.String(), req.Method, req.Path)) } - return n, nil } @@ -247,19 +191,6 @@ func (cw *CustomWriter) AddInteraction(interaction *session.Interaction) { } var redirectTLS = false -var allowedCors = make(map[string]bool) -var isAllowedAllCors = false - -func init() { - corsList := utils.Getenv("cors_list") - if corsList == "*" { - isAllowedAllCors = true - } else { - for _, allowedOrigin := range strings.Split(corsList, ",") { - allowedCors[allowedOrigin] = true - } - } -} func NewHTTPServer() error { listener, err := net.Listen("tcp", ":80") @@ -416,7 +347,7 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS _, err = channel.Write(initialRequest.Finalize()) if err != nil { - log.Printf("Failed to write forwarded-tcpip:", err) + log.Printf("Failed to forward request: %v", err) return }