diff --git a/server/http.go b/server/http.go index cc39d46..6ae0454 100644 --- a/server/http.go +++ b/server/http.go @@ -11,7 +11,6 @@ import ( "regexp" "strings" "tunnel_pls/session" - "tunnel_pls/types" "tunnel_pls/utils" ) @@ -30,7 +29,6 @@ type CustomWriter struct { respMW []ResponseMiddleware reqStartMW []RequestMiddleware reqEndMW []RequestMiddleware - overflow []byte } func (cw *CustomWriter) SetInteraction(interaction Interaction) { @@ -38,14 +36,6 @@ func (cw *CustomWriter) SetInteraction(interaction Interaction) { } func (cw *CustomWriter) Read(p []byte) (int, error) { - if len(cw.overflow) > 0 { - n := copy(p, cw.overflow) - cw.overflow = cw.overflow[n:] - if len(cw.overflow) == 0 { - cw.overflow = nil - } - return n, nil - } tmp := make([]byte, len(p)) read, err := cw.reader.Read(tmp) if read == 0 && err != nil { @@ -99,12 +89,6 @@ func (cw *CustomWriter) Read(p []byte) (int, error) { n := copy(p, combined) - if n > len(p) { - cw.overflow = make([]byte, len(combined)-n) - copy(cw.overflow, combined[n:]) - log.Printf("output buffer too small (%d vs %d)", len(p), n) - } - return n, nil } @@ -118,7 +102,7 @@ func NewCustomWriter(writer io.Writer, reader io.Reader, remoteAddr net.Addr) *C } } -var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} // HTTP HEADER DELIMITER `\r\n\r\n` +var DELIMITER = []byte{0x0D, 0x0A, 0x0D, 0x0A} var requestLine = regexp.MustCompile(`^(GET|POST|PUT|DELETE|HEAD|OPTIONS|PATCH|TRACE|CONNECT) \S+ HTTP/\d\.\d$`) var responseLine = regexp.MustCompile(`^HTTP/\d\.\d \d{3} .+`) @@ -143,8 +127,8 @@ func isHTTPHeader(buf []byte) bool { } func (cw *CustomWriter) Write(p []byte) (int, error) { - if len(p) == len(types.BadGatewayResponse) && bytes.Equal(p, types.BadGatewayResponse) { - return cw.writer.Write(p) + if cw.respHeader != nil && len(cw.buf) == 0 && len(p) >= 5 && string(p[0:5]) == "HTTP/" { + cw.respHeader = nil } if cw.respHeader != nil { @@ -166,9 +150,12 @@ func (cw *CustomWriter) Write(p []byte) (int, error) { body := cw.buf[idx+len(DELIMITER):] if !isHTTPHeader(header) { - n, err := cw.writer.Write(cw.buf) + _, err := cw.writer.Write(cw.buf) cw.buf = nil - return n, err + if err != nil { + return 0, err + } + return len(p), nil } resphf := NewResponseHeaderFactory(header) @@ -323,20 +310,12 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS } } }() - _, err = channel.Write(initialRequest.Finalize()) - if err != nil { - log.Printf("Failed to forward request: %v", err) - return - } - //TODO: Implement wrapper func buat add/remove middleware + fingerprintMiddleware := NewTunnelFingerprint() - loggerMiddleware := NewRequestLogger(cw.interaction, cw.RemoteAddr) forwardedForMiddleware := NewForwardedFor(cw.RemoteAddr) cw.respMW = append(cw.respMW, fingerprintMiddleware) - cw.reqStartMW = append(cw.reqStartMW, loggerMiddleware) cw.reqStartMW = append(cw.reqStartMW, forwardedForMiddleware) - //TODO: Tambah req Middleware cw.reqEndMW = nil cw.reqHeader = initialRequest @@ -348,6 +327,12 @@ func forwardRequest(cw *CustomWriter, initialRequest *RequestHeaderFactory, sshS } } + _, err = channel.Write(initialRequest.Finalize()) + if err != nil { + log.Printf("Failed to forward request: %v", err) + return + } + sshSession.Forwarder.HandleConnection(cw, channel, cw.RemoteAddr) return } diff --git a/server/middleware.go b/server/middleware.go index f26504c..a0e3c2b 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -1,9 +1,7 @@ package server import ( - "fmt" "net" - "time" ) type RequestMiddleware interface { @@ -30,18 +28,6 @@ type RequestLogger struct { remoteAddr net.Addr } -func NewRequestLogger(interaction Interaction, remoteAddr net.Addr) *RequestLogger { - return &RequestLogger{ - interaction: interaction, - remoteAddr: remoteAddr, - } -} - -func (rl *RequestLogger) HandleRequest(header *RequestHeaderFactory) error { - rl.interaction.SendMessage(fmt.Sprintf("\033[32m%s %s -> %s %s \033[0m\r\n", time.Now().UTC().Format(time.RFC3339), rl.remoteAddr.String(), header.Method, header.Path)) - return nil -} - type ForwardedFor struct { addr net.Addr }