diff --git a/cache/user.go b/cache/user.go index 99f23dd..4607eee 100644 --- a/cache/user.go +++ b/cache/user.go @@ -26,7 +26,7 @@ func init() { log = logger.Logger() userCache = make(map[string]*UserWithExpired) - ticker := time.NewTicker(time.Hour) + ticker := time.NewTicker(time.Minute) go func() { for { diff --git a/handler/auth/google/callback/callback.go b/handler/auth/google/callback/callback.go index d4a2cc3..cc937a9 100644 --- a/handler/auth/google/callback/callback.go +++ b/handler/auth/google/callback/callback.go @@ -167,7 +167,7 @@ func GET(w http.ResponseWriter, r *http.Request) { log.Error(err.Error()) return } - storeSession := session.GlobalSessionStore.Create() + storeSession := session.Create() storeSession.Values["user"] = types.User{ UserID: user.UserID, Email: oauthUser.Email, diff --git a/handler/auth/google/setup/setup.go b/handler/auth/google/setup/setup.go index 5ba6b3e..16aa7df 100644 --- a/handler/auth/google/setup/setup.go +++ b/handler/auth/google/setup/setup.go @@ -130,7 +130,7 @@ func POST(w http.ResponseWriter, r *http.Request) { delete(SetupUser, code) - storeSession := session.GlobalSessionStore.Create() + storeSession := session.Create() storeSession.Values["user"] = types.User{ UserID: userID, Email: unregisteredUser.Email, diff --git a/handler/logout/logout.go b/handler/logout/logout.go index 181310e..ab5c8cb 100644 --- a/handler/logout/logout.go +++ b/handler/logout/logout.go @@ -4,25 +4,18 @@ import ( "errors" "net/http" - "github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" ) -var log *logger.AggregatedLogger - -func init() { - log = logger.Logger() -} - func GET(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie("Session") if err != nil { return } - storeSession, err := session.GlobalSessionStore.Get(cookie.Value) + storeSession, err := session.Get(cookie.Value) if err != nil { if errors.Is(err, &session.SessionNotFoundError{}) { storeSession.Destroy(w) @@ -31,7 +24,7 @@ func GET(w http.ResponseWriter, r *http.Request) { return } - session.GlobalSessionStore.Delete(cookie.Value) + storeSession.Delete() session.RemoveSessionInfo(storeSession.Values["user"].(types.User).Email, cookie.Value) http.SetCookie(w, &http.Cookie{ diff --git a/handler/signin/signin.go b/handler/signin/signin.go index 2d410b0..fbdc88b 100644 --- a/handler/signin/signin.go +++ b/handler/signin/signin.go @@ -58,7 +58,7 @@ func POST(w http.ResponseWriter, r *http.Request) { } if email == userData.Email && utils.CheckPasswordHash(password, userData.Password) { - storeSession := session.GlobalSessionStore.Create() + storeSession := session.Create() storeSession.Values["user"] = types.User{ UserID: userData.UserID, Email: email, diff --git a/handler/signup/signup.go b/handler/signup/signup.go index 804a2a5..eb68f59 100644 --- a/handler/signup/signup.go +++ b/handler/signup/signup.go @@ -50,7 +50,7 @@ func init() { for _, data := range VerifyUser { data.mu.Lock() - if currentTime.Sub(data.CreateTime) > time.Minute*1 { + if currentTime.Sub(data.CreateTime) > time.Minute*10 { delete(VerifyUser, data.Code) delete(VerifyEmail, data.User.Email) cacheClean++ diff --git a/session/session.go b/session/session.go index a4a07f4..d0fc312 100644 --- a/session/session.go +++ b/session/session.go @@ -1,7 +1,8 @@ package session import ( - "errors" + "fmt" + "github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/types" "net/http" "strconv" @@ -12,13 +13,10 @@ import ( ) type Session struct { - ID string - Values map[string]interface{} -} - -type SessionStore struct { - Sessions map[string]*Session - mu sync.Mutex + ID string + Values map[string]interface{} + CreateTime time.Time + mu sync.Mutex } type SessionInfo struct { @@ -33,6 +31,7 @@ type SessionInfo struct { } type UserStatus string +type SessionNotFoundError struct{} const ( Authorized UserStatus = "authorized" @@ -40,38 +39,62 @@ const ( InvalidSession UserStatus = "invalid_session" ) -var GlobalSessionStore = SessionStore{Sessions: make(map[string]*Session)} +var GlobalSessionStore = make(map[string]*Session) var UserSessionInfoList = make(map[string]map[string]*SessionInfo) +var log *logger.AggregatedLogger -type SessionNotFoundError struct{} +func init() { + log = logger.Logger() + + ticker := time.NewTicker(time.Minute) + go func() { + for { + <-ticker.C + currentTime := time.Now() + cacheClean := 0 + cleanID := utils.GenerateRandomString(10) + log.Info(fmt.Sprintf("Cache cleanup [Session] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) + + for _, data := range GlobalSessionStore { + data.mu.Lock() + if currentTime.Sub(data.CreateTime) > time.Hour*24*7 { + RemoveSessionInfo(data.Values["user"].(types.User).Email, data.ID) + delete(GlobalSessionStore, data.ID) + cacheClean++ + } + data.mu.Unlock() + } + + log.Info(fmt.Sprintf("Cache cleanup [Session] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) + } + }() +} func (e *SessionNotFoundError) Error() string { return "session not found" } -func (s *SessionStore) Get(id string) (*Session, error) { - s.mu.Lock() - defer s.mu.Unlock() - if session, ok := s.Sessions[id]; ok { +func Get(id string) (*Session, error) { + if session, ok := GlobalSessionStore[id]; ok { return session, nil } return nil, &SessionNotFoundError{} } -func (s *SessionStore) Create() *Session { +func Create() *Session { id := utils.GenerateRandomString(128) session := &Session{ ID: id, Values: make(map[string]interface{}), } - s.Sessions[id] = session + GlobalSessionStore[id] = session return session } -func (s *SessionStore) Delete(id string) { +func (s *Session) Delete() { s.mu.Lock() defer s.mu.Unlock() - delete(s.Sessions, id) + delete(GlobalSessionStore, s.ID) } func (s *Session) Save(w http.ResponseWriter) { @@ -114,7 +137,7 @@ func RemoveSessionInfo(email string, id string) { func RemoveAllSessions(email string) { sessionInfos := UserSessionInfoList[email] for _, sessionInfo := range sessionInfos { - delete(GlobalSessionStore.Sessions, sessionInfo.SessionID) + delete(GlobalSessionStore, sessionInfo.SessionID) } delete(UserSessionInfoList, email) } @@ -140,17 +163,14 @@ func GetSession(r *http.Request) (UserStatus, types.User, string) { return Unauthorized, types.User{}, "" } - storeSession, err := GlobalSessionStore.Get(cookie.Value) - if err != nil { - if errors.Is(err, &SessionNotFoundError{}) { - return InvalidSession, types.User{}, "" - } - return Unauthorized, types.User{}, "" + storeSession, ok := GlobalSessionStore[cookie.Value] + if !ok { + return InvalidSession, types.User{}, "" } val := storeSession.Values["user"] var userSession = types.User{} - userSession, ok := val.(types.User) + userSession, ok = val.(types.User) if !ok { return Unauthorized, types.User{}, "" }