172 lines
4.4 KiB
Go
172 lines
4.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"github.com/fossyy/filekeeper/app"
|
|
errorView "github.com/fossyy/filekeeper/view/client/error"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/fossyy/filekeeper/session"
|
|
"github.com/fossyy/filekeeper/utils"
|
|
)
|
|
|
|
type wrapper struct {
|
|
http.ResponseWriter
|
|
request *http.Request
|
|
statusCode int
|
|
}
|
|
|
|
func (w *wrapper) WriteHeader(code int) {
|
|
w.statusCode = code
|
|
|
|
if code == http.StatusNotFound {
|
|
w.Header().Set("Content-Type", "text/html")
|
|
w.ResponseWriter.WriteHeader(code)
|
|
component := errorView.NotFound("Not Found")
|
|
err := component.Render(w.request.Context(), w)
|
|
if err != nil {
|
|
app.Server.Logger.Error(err.Error())
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
if code == http.StatusInternalServerError {
|
|
w.Header().Set("Content-Type", "text/html")
|
|
w.ResponseWriter.WriteHeader(code)
|
|
component := errorView.InternalServerError("Internal Server Error")
|
|
err := component.Render(w.request.Context(), w)
|
|
if err != nil {
|
|
app.Server.Logger.Error(err.Error())
|
|
return
|
|
}
|
|
return
|
|
}
|
|
w.ResponseWriter.WriteHeader(code)
|
|
|
|
return
|
|
}
|
|
|
|
func Handler(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
|
if request.Header.Get("upgrade") == "websocket" {
|
|
hijacker, ok := writer.(http.Hijacker)
|
|
if !ok {
|
|
http.Error(writer, "Hijacking not supported", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
hijack, _, err := hijacker.Hijack()
|
|
if err != nil {
|
|
http.Error(writer, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer hijack.Close()
|
|
rw := NewHijackWriter(hijack)
|
|
app.Server.Logger.Info(fmt.Sprintf("%s %s %s %v", utils.ClientIP(request), "WEBSOCKET", request.RequestURI, http.StatusSwitchingProtocols))
|
|
next.ServeHTTP(rw, request)
|
|
return
|
|
}
|
|
|
|
address := strings.Split(utils.Getenv("CORS_LIST"), ",")
|
|
|
|
for _, addr := range address {
|
|
if request.Host == addr {
|
|
writer.Header().Set("Access-Control-Allow-Origin", fmt.Sprintf("%s://%s", utils.Getenv("CORS_PROTO"), addr))
|
|
}
|
|
}
|
|
|
|
wrappedWriter := &wrapper{
|
|
ResponseWriter: writer,
|
|
request: request,
|
|
statusCode: http.StatusOK,
|
|
}
|
|
|
|
writer.Header().Set("Access-Control-Allow-Methods", fmt.Sprintf("%s, OPTIONS", utils.Getenv("CORS_METHODS")))
|
|
writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
|
next.ServeHTTP(wrappedWriter, request)
|
|
app.Server.Logger.Info(fmt.Sprintf("%s %s %s %v", utils.ClientIP(request), request.Method, request.RequestURI, wrappedWriter.statusCode))
|
|
})
|
|
}
|
|
|
|
func Auth(next http.HandlerFunc) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
status, user, sessionID := session.GetSession(r)
|
|
|
|
switch status {
|
|
case session.Authorized:
|
|
ctx := context.WithValue(r.Context(), "user", user)
|
|
ctx = context.WithValue(ctx, "sessionID", sessionID)
|
|
req := r.WithContext(ctx)
|
|
next.ServeHTTP(w, req)
|
|
return
|
|
case session.Unauthorized:
|
|
if r.RequestURI != "/logout" {
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "redirect",
|
|
Value: r.RequestURI,
|
|
Path: "/",
|
|
})
|
|
}
|
|
http.Redirect(w, r, "/signin", http.StatusSeeOther)
|
|
return
|
|
case session.InvalidSession:
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "Session",
|
|
Value: "",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
})
|
|
http.Redirect(w, r, "/signin", http.StatusSeeOther)
|
|
return
|
|
case session.Suspicious:
|
|
userSession := session.Get(sessionID)
|
|
err := userSession.Delete()
|
|
if err != nil {
|
|
app.Server.Logger.Error(err)
|
|
}
|
|
err = session.RemoveSessionInfo(user.Email, sessionID)
|
|
if err != nil {
|
|
app.Server.Logger.Error(err)
|
|
return
|
|
}
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "Session",
|
|
Value: "",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
})
|
|
http.Redirect(w, r, "/signin?error=suspicious_session", http.StatusSeeOther)
|
|
return
|
|
default:
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
|
|
func Guest(next http.HandlerFunc) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
status, _, _ := session.GetSession(r)
|
|
|
|
switch status {
|
|
case session.Authorized:
|
|
http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
return
|
|
case session.Unauthorized:
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
case session.InvalidSession:
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "Session",
|
|
Value: "",
|
|
Path: "/",
|
|
MaxAge: -1,
|
|
})
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
})
|
|
}
|