Add Redis support for session and user management
This commit is contained in:
@ -1,12 +1,15 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/fossyy/filekeeper/app"
|
||||
"github.com/fossyy/filekeeper/types"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fossyy/filekeeper/utils"
|
||||
@ -14,9 +17,8 @@ import (
|
||||
|
||||
type Session struct {
|
||||
ID string
|
||||
Values map[string]interface{}
|
||||
Values types.User
|
||||
CreateTime time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type SessionInfo struct {
|
||||
@ -27,7 +29,6 @@ type SessionInfo struct {
|
||||
OSVersion string
|
||||
IP string
|
||||
Location string
|
||||
AccessAt string
|
||||
}
|
||||
|
||||
type UserStatus string
|
||||
@ -39,61 +40,65 @@ const (
|
||||
InvalidSession UserStatus = "invalid_session"
|
||||
)
|
||||
|
||||
var GlobalSessionStore = make(map[string]*Session)
|
||||
var UserSessionInfoList = make(map[string]map[string]*SessionInfo)
|
||||
|
||||
func init() {
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
go func() {
|
||||
for {
|
||||
<-ticker.C
|
||||
currentTime := time.Now()
|
||||
cacheClean := 0
|
||||
cleanID := utils.GenerateRandomString(10)
|
||||
app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [Session] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second()))
|
||||
|
||||
for _, data := range GlobalSessionStore {
|
||||
data.mu.Lock()
|
||||
if currentTime.Sub(data.CreateTime) > time.Hour*24*7 {
|
||||
RemoveSessionInfo(data.Values["user"].(types.User).Email, data.ID)
|
||||
delete(GlobalSessionStore, data.ID)
|
||||
cacheClean++
|
||||
}
|
||||
data.mu.Unlock()
|
||||
}
|
||||
|
||||
app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [Session] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime)))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (e *SessionNotFoundError) Error() string {
|
||||
return "session not found"
|
||||
}
|
||||
|
||||
func Get(id string) (*Session, error) {
|
||||
if session, ok := GlobalSessionStore[id]; ok {
|
||||
return session, nil
|
||||
session, err := app.Server.Cache.GetCache(context.Background(), "Session:"+id)
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, &SessionNotFoundError{}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return nil, &SessionNotFoundError{}
|
||||
var userSession Session
|
||||
err = json.Unmarshal([]byte(session), &userSession)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &userSession, nil
|
||||
}
|
||||
|
||||
func Create() *Session {
|
||||
func Create(values types.User) (*Session, error) {
|
||||
id := utils.GenerateRandomString(128)
|
||||
session := &Session{
|
||||
ID: id,
|
||||
CreateTime: time.Now(),
|
||||
Values: make(map[string]interface{}),
|
||||
Values: values,
|
||||
}
|
||||
GlobalSessionStore[id] = session
|
||||
return session
|
||||
|
||||
sessionData, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = app.Server.Cache.SetCache(context.Background(), "Session:"+id, string(sessionData), time.Hour*24) // Set expiration time as needed
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *Session) Delete() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(GlobalSessionStore, s.ID)
|
||||
func (s *Session) Change(user types.User) error {
|
||||
newSessionValue, err := json.Marshal(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = app.Server.Cache.SetCache(context.Background(), "Session:"+s.ID, newSessionValue, time.Hour*24*7)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) Delete() error {
|
||||
err := app.Server.Cache.DeleteCache(context.Background(), "Session:"+s.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) Save(w http.ResponseWriter) {
|
||||
@ -114,46 +119,71 @@ func (s *Session) Destroy(w http.ResponseWriter) {
|
||||
})
|
||||
}
|
||||
|
||||
func AddSessionInfo(email string, sessionInfo *SessionInfo) {
|
||||
if _, ok := UserSessionInfoList[email]; !ok {
|
||||
UserSessionInfoList[email] = make(map[string]*SessionInfo)
|
||||
func AddSessionInfo(email string, sessionInfo *SessionInfo) error {
|
||||
sessionInfoData, err := json.Marshal(sessionInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
UserSessionInfoList[email][sessionInfo.SessionID] = sessionInfo
|
||||
}
|
||||
|
||||
func RemoveSessionInfo(email string, id string) {
|
||||
if userSessions, ok := UserSessionInfoList[email]; ok {
|
||||
if _, ok := userSessions[id]; ok {
|
||||
delete(userSessions, id)
|
||||
if len(userSessions) == 0 {
|
||||
delete(UserSessionInfoList, email)
|
||||
}
|
||||
}
|
||||
err = app.Server.Cache.SetCache(context.Background(), "UserSessionInfo:"+email+":"+sessionInfo.SessionID, string(sessionInfoData), 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveAllSessions(email string) {
|
||||
sessionInfos := UserSessionInfoList[email]
|
||||
for _, sessionInfo := range sessionInfos {
|
||||
delete(GlobalSessionStore, sessionInfo.SessionID)
|
||||
}
|
||||
delete(UserSessionInfoList, email)
|
||||
}
|
||||
|
||||
func GetSessionInfo(email string, id string) *SessionInfo {
|
||||
if userSession, ok := UserSessionInfoList[email]; ok {
|
||||
if sessionInfo, ok := userSession[id]; ok {
|
||||
return sessionInfo
|
||||
}
|
||||
func RemoveSessionInfo(email string, id string) error {
|
||||
key := "UserSessionInfo:" + email + ":" + id
|
||||
err := app.Server.Cache.DeleteCache(context.Background(), key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sessionInfo *SessionInfo) UpdateAccessTime() {
|
||||
currentTime := time.Now()
|
||||
formattedTime := currentTime.Format("01-02-2006")
|
||||
sessionInfo.AccessAt = formattedTime
|
||||
func RemoveAllSessions(email string) error {
|
||||
pattern := "UserSessionInfo:" + email + ":*"
|
||||
keys, err := app.Server.Cache.GetKeys(context.Background(), pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
parts := strings.Split(key, ":")
|
||||
sessionID := parts[2]
|
||||
|
||||
err = app.Server.Cache.DeleteCache(context.Background(), "Session:"+sessionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = app.Server.Cache.DeleteCache(context.Background(), key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSessionInfo(email string, id string) (*SessionInfo, error) {
|
||||
key := "UserSessionInfo:" + email + ":" + id
|
||||
|
||||
sessionInfoData, err := app.Server.Cache.GetCache(context.Background(), key)
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sessionInfo SessionInfo
|
||||
err = json.Unmarshal([]byte(sessionInfoData), &sessionInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sessionInfo, nil
|
||||
}
|
||||
|
||||
func GetSession(r *http.Request) (UserStatus, types.User, string) {
|
||||
@ -162,18 +192,22 @@ func GetSession(r *http.Request) (UserStatus, types.User, string) {
|
||||
return Unauthorized, types.User{}, ""
|
||||
}
|
||||
|
||||
storeSession, ok := GlobalSessionStore[cookie.Value]
|
||||
if !ok {
|
||||
return InvalidSession, types.User{}, ""
|
||||
}
|
||||
|
||||
val := storeSession.Values["user"]
|
||||
var userSession = types.User{}
|
||||
userSession, ok = val.(types.User)
|
||||
if !ok {
|
||||
sessionData, err := app.Server.Cache.GetCache(context.Background(), "Session:"+cookie.Value)
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return InvalidSession, types.User{}, ""
|
||||
}
|
||||
return Unauthorized, types.User{}, ""
|
||||
}
|
||||
|
||||
var storeSession Session
|
||||
err = json.Unmarshal([]byte(sessionData), &storeSession)
|
||||
if err != nil {
|
||||
return Unauthorized, types.User{}, ""
|
||||
}
|
||||
|
||||
userSession := storeSession.Values
|
||||
|
||||
if !userSession.Authenticated && userSession.Totp != "" {
|
||||
return Unauthorized, userSession, cookie.Value
|
||||
}
|
||||
@ -185,13 +219,27 @@ func GetSession(r *http.Request) (UserStatus, types.User, string) {
|
||||
return Authorized, userSession, cookie.Value
|
||||
}
|
||||
|
||||
func GetSessions(email string) []*SessionInfo {
|
||||
if sessions, ok := UserSessionInfoList[email]; ok {
|
||||
result := make([]*SessionInfo, 0, len(sessions))
|
||||
for _, sessionInfo := range sessions {
|
||||
result = append(result, sessionInfo)
|
||||
}
|
||||
return result
|
||||
func GetSessions(email string) ([]*SessionInfo, error) {
|
||||
pattern := "UserSessionInfo:" + email + ":*"
|
||||
keys, err := app.Server.Cache.GetKeys(context.Background(), pattern)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil
|
||||
|
||||
var sessions []*SessionInfo
|
||||
for _, key := range keys {
|
||||
sessionData, err := app.Server.Cache.GetCache(context.Background(), key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sessionInfo SessionInfo
|
||||
err = json.Unmarshal([]byte(sessionData), &sessionInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sessions = append(sessions, &sessionInfo)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user