diff --git a/middleware/hijack.go b/middleware/hijack.go new file mode 100644 index 0000000..30e19e2 --- /dev/null +++ b/middleware/hijack.go @@ -0,0 +1,30 @@ +package middleware + +import ( + "bufio" + "net" + "net/http" +) + +type HijackWriter struct { + conn net.Conn +} + +func NewHijackWriter(conn net.Conn) http.ResponseWriter { + return &HijackWriter{conn: conn} +} + +func (rw *HijackWriter) Header() http.Header { + return http.Header{} +} + +func (rw *HijackWriter) Write(bytes []byte) (int, error) { + return rw.conn.Write(bytes) +} + +func (rw *HijackWriter) WriteHeader(statusCode int) { +} + +func (rw *HijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rw.conn, bufio.NewReadWriter(bufio.NewReader(rw.conn), bufio.NewWriter(rw.conn)), nil +} diff --git a/middleware/middleware.go b/middleware/middleware.go index 6f051dc..6e6d35b 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -41,6 +41,24 @@ func (w *wrapper) WriteHeader(code int) { func Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if request.Header.Get("upgrade") == "websocket" { + hijacker, ok := writer.(http.Hijacker) + if !ok { + http.Error(writer, "Hijacking not supported", http.StatusInternalServerError) + return + } + hijack, _, err := hijacker.Hijack() + if err != nil { + http.Error(writer, err.Error(), http.StatusInternalServerError) + return + } + defer hijack.Close() + rw := NewHijackWriter(hijack) + app.Server.Logger.Info(fmt.Sprintf("%s %s %s %v", utils.ClientIP(request), "WEBSOCKET", request.RequestURI, http.StatusSwitchingProtocols)) + next.ServeHTTP(rw, request) + return + } + address := strings.Split(utils.Getenv("CORS_LIST"), ",") for _, addr := range address {