diff --git a/app/app.go b/app/app.go index bc1a906..2b1e150 100644 --- a/app/app.go +++ b/app/app.go @@ -1,9 +1,9 @@ package app import ( - "github.com/fossyy/filekeeper/db" "github.com/fossyy/filekeeper/email" "github.com/fossyy/filekeeper/logger" + "github.com/fossyy/filekeeper/types" "net/http" ) @@ -12,12 +12,14 @@ var Admin App type App struct { http.Server - Database db.Database + Database types.Database + Cache types.CachingServer + Service types.Services Logger *logger.AggregatedLogger Mail *email.SmtpServer } -func NewClientServer(addr string, handler http.Handler, logger logger.AggregatedLogger, database db.Database, mail email.SmtpServer) App { +func NewClientServer(addr string, handler http.Handler, logger logger.AggregatedLogger, database types.Database, cache types.CachingServer, service types.Services, mail email.SmtpServer) App { return App{ Server: http.Server{ Addr: addr, @@ -25,11 +27,13 @@ func NewClientServer(addr string, handler http.Handler, logger logger.Aggregated }, Logger: &logger, Database: database, + Cache: cache, + Service: service, Mail: &mail, } } -func NewAdminServer(addr string, handler http.Handler, database db.Database) App { +func NewAdminServer(addr string, handler http.Handler, database types.Database) App { return App{ Server: http.Server{ Addr: addr, diff --git a/cache/cache.go b/cache/cache.go index cc4dee5..49c8bb9 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -1,177 +1,67 @@ package cache import ( - "fmt" - "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/utils" - "github.com/google/uuid" - "sync" + "context" + "github.com/fossyy/filekeeper/types" + "github.com/redis/go-redis/v9" "time" ) -type UserWithExpired struct { - UserID uuid.UUID - Username string - Email string - Password string - Totp string - AccessAt time.Time - mu sync.Mutex +type RedisServer struct { + client *redis.Client + database types.Database } -type FileWithExpired struct { - ID uuid.UUID - OwnerID uuid.UUID - Name string - Size int64 - Downloaded int64 - UploadedByte int64 - UploadedChunk int64 - Done bool - AccessAt time.Time - mu sync.Mutex +func NewRedisServer(db types.Database) types.CachingServer { + client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "Password123", + DB: 0, + }) + return &RedisServer{client: client, database: db} } -var userCache map[string]*UserWithExpired -var fileCache map[string]*FileWithExpired +func (r *RedisServer) GetCache(ctx context.Context, key string) (string, error) { + val, err := r.client.Get(ctx, key).Result() + if err != nil { + return "", err + } + return val, nil +} -func init() { +func (r *RedisServer) SetCache(ctx context.Context, key string, value interface{}, expiration time.Duration) error { + err := r.client.Set(ctx, key, value, expiration).Err() + if err != nil { + return err + } + return nil +} - userCache = make(map[string]*UserWithExpired) - fileCache = make(map[string]*FileWithExpired) - ticker := time.NewTicker(time.Minute) +func (r *RedisServer) DeleteCache(ctx context.Context, key string) error { + err := r.client.Del(ctx, key).Err() + if err != nil { + return err + } + return nil +} - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [user] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) +func (r *RedisServer) GetKeys(ctx context.Context, pattern string) ([]string, error) { + var cursor uint64 + var keys []string + for { + var newKeys []string + var err error - for _, user := range userCache { - user.mu.Lock() - if currentTime.Sub(user.AccessAt) > time.Hour*8 { - delete(userCache, user.Email) - cacheClean++ - } - user.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [user] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) + newKeys, cursor, err = r.client.Scan(ctx, cursor, pattern, 0).Result() + if err != nil { + return nil, err } - }() - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [files] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) + keys = append(keys, newKeys...) - for _, file := range fileCache { - file.mu.Lock() - if currentTime.Sub(file.AccessAt) > time.Minute*1 { - app.Server.Database.UpdateUploadedByte(file.UploadedByte, file.ID.String()) - app.Server.Database.UpdateUploadedChunk(file.UploadedChunk, file.ID.String()) - delete(fileCache, file.ID.String()) - cacheClean++ - } - file.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [files] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) + if cursor == 0 { + break } - }() + } + return keys, nil } - -func GetUser(email string) (*UserWithExpired, error) { - if user, ok := userCache[email]; ok { - return user, nil - } - - userData, err := app.Server.Database.GetUser(email) - if err != nil { - return nil, err - } - - userCache[email] = &UserWithExpired{ - UserID: userData.UserID, - Username: userData.Username, - Email: userData.Email, - Password: userData.Password, - Totp: userData.Totp, - AccessAt: time.Now(), - } - - return userCache[email], nil -} - -func DeleteUser(email string) { - userCache[email].mu.Lock() - defer userCache[email].mu.Unlock() - - delete(userCache, email) -} - -func GetFile(id string) (*FileWithExpired, error) { - if file, ok := fileCache[id]; ok { - file.AccessAt = time.Now() - return file, nil - } - - uploadData, err := app.Server.Database.GetFile(id) - if err != nil { - return nil, err - } - - fileCache[id] = &FileWithExpired{ - ID: uploadData.ID, - OwnerID: uploadData.OwnerID, - Name: uploadData.Name, - Size: uploadData.Size, - Downloaded: uploadData.Downloaded, - UploadedByte: uploadData.UploadedByte, - UploadedChunk: uploadData.UploadedChunk, - Done: uploadData.Done, - AccessAt: time.Now(), - } - - return fileCache[id], nil -} - -func (file *FileWithExpired) UpdateProgress(index int64, size int64) { - file.UploadedChunk = index - file.UploadedByte = size - file.AccessAt = time.Now() -} - -func GetUserFile(name, ownerID string) (*FileWithExpired, error) { - fileData, err := app.Server.Database.GetUserFile(name, ownerID) - if err != nil { - return nil, err - } - - file, err := GetFile(fileData.ID.String()) - if err != nil { - return nil, err - } - - return file, nil -} - -func (file *FileWithExpired) FinalizeFileUpload() { - app.Server.Database.UpdateUploadedByte(file.UploadedByte, file.ID.String()) - app.Server.Database.UpdateUploadedChunk(file.UploadedChunk, file.ID.String()) - app.Server.Database.FinalizeFileUpload(file.ID.String()) - delete(fileCache, file.ID.String()) - return -} - -//func DeleteUploadInfo(id string) { -// filesUploadedCache[id].mu.Lock() -// defer filesUploadedCache[id].mu.Unlock() -// -// delete(filesUploadedCache, id) -//} diff --git a/db/database.go b/db/database.go index feabbcf..9e76362 100644 --- a/db/database.go +++ b/db/database.go @@ -3,6 +3,7 @@ package db import ( "errors" "fmt" + "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/types/models" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -27,28 +28,7 @@ const ( EnableSSL SSLMode = "enable" ) -type Database interface { - IsUserRegistered(email string, username string) bool - IsEmailRegistered(email string) bool - - CreateUser(user *models.User) error - GetUser(email string) (*models.User, error) - GetAllUsers() ([]models.User, error) - UpdateUserPassword(email string, password string) error - - CreateFile(file *models.File) error - GetFile(fileID string) (*models.File, error) - GetUserFile(name string, ownerID string) (*models.File, error) - GetFiles(ownerID string) ([]*models.File, error) - - UpdateUploadedByte(index int64, fileID string) - UpdateUploadedChunk(index int64, fileID string) - FinalizeFileUpload(fileID string) - - InitializeTotp(email string, secret string) error -} - -func NewMYSQLdb(username, password, host, port, dbName string) Database { +func NewMYSQLdb(username, password, host, port, dbName string) types.Database { var err error var count int64 @@ -110,7 +90,7 @@ func NewMYSQLdb(username, password, host, port, dbName string) Database { return &mySQLdb{DB} } -func NewPostgresDB(username, password, host, port, dbName string, mode SSLMode) Database { +func NewPostgresDB(username, password, host, port, dbName string, mode SSLMode) types.Database { var err error var count int64 diff --git a/go.mod b/go.mod index 8858568..b01d232 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,8 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -29,6 +31,7 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/redis/go-redis/v9 v9.6.1 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect diff --git a/go.sum b/go.sum index 33922a0..3adffb8 100644 --- a/go.sum +++ b/go.sum @@ -2,9 +2,13 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/a-h/templ v0.2.707 h1:T1Gkd2ugbRglZ9rYw/VBchWOSZVKmetDbBkm4YubM7U= github.com/a-h/templ v0.2.707/go.mod h1:5cqsugkq9IerRNucNsI4DEamdHPsoGMQy99DzydLhM8= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= @@ -37,6 +41,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/redis/go-redis/v9 v9.6.1 h1:HHDteefn6ZkTtY5fGUE8tj8uy85AHk6zP7CpzIAM0y4= +github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA= github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= diff --git a/handler/auth/google/callback/callback.go b/handler/auth/google/callback/callback.go index c112c46..0b2a718 100644 --- a/handler/auth/google/callback/callback.go +++ b/handler/auth/google/callback/callback.go @@ -6,15 +6,14 @@ import ( "errors" "fmt" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" googleOauthSetupHandler "github.com/fossyy/filekeeper/handler/auth/google/setup" signinHandler "github.com/fossyy/filekeeper/handler/signin" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" + "github.com/redis/go-redis/v9" "net/http" "net/url" - "sync" "time" ) @@ -46,49 +45,22 @@ type OauthUser struct { VerifiedEmail bool `json:"verified_email"` } -type CsrfToken struct { - Token string - CreateTime time.Time - mu sync.Mutex -} - -var CsrfTokens map[string]*CsrfToken - -func init() { - - CsrfTokens = make(map[string]*CsrfToken) - - ticker := time.NewTicker(time.Minute) - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [csrf_token] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) - - for _, data := range CsrfTokens { - data.mu.Lock() - if currentTime.Sub(data.CreateTime) > time.Minute*10 { - delete(CsrfTokens, data.Token) - cacheClean++ - } - data.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [csrf_token] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) - } - }() -} - func GET(w http.ResponseWriter, r *http.Request) { - if _, ok := CsrfTokens[r.URL.Query().Get("state")]; !ok { - //csrf token mismatch error + _, err := app.Server.Cache.GetCache(r.Context(), "CsrfTokens:"+r.URL.Query().Get("state")) + if err != nil { + if errors.Is(err, redis.Nil) { + w.WriteHeader(http.StatusUnauthorized) + return + } w.WriteHeader(http.StatusInternalServerError) return } - delete(CsrfTokens, r.URL.Query().Get("state")) + err = app.Server.Cache.DeleteCache(r.Context(), "CsrfTokens:"+r.URL.Query().Get("state")) + if err != nil { + http.Redirect(w, r, fmt.Sprintf("/signin?error=%s", "csrf_token_error"), http.StatusFound) + return + } if err := r.URL.Query().Get("error"); err != "" { http.Redirect(w, r, fmt.Sprintf("/signin?error=%s", err), http.StatusFound) @@ -146,16 +118,23 @@ func GET(w http.ResponseWriter, r *http.Request) { if !app.Server.Database.IsEmailRegistered(oauthUser.Email) { code := utils.GenerateRandomString(64) - googleOauthSetupHandler.SetupUser[code] = &googleOauthSetupHandler.UnregisteredUser{ + + user := googleOauthSetupHandler.UnregisteredUser{ Code: code, Email: oauthUser.Email, CreateTime: time.Now(), } + newGoogleSetupJSON, _ := json.Marshal(user) + err = app.Server.Cache.SetCache(r.Context(), "GoogleSetup:"+code, newGoogleSetupJSON, time.Minute*15) + if err != nil { + fmt.Println("Error setting up Google Setup:", err) + return + } http.Redirect(w, r, fmt.Sprintf("/auth/google/setup/%s", code), http.StatusSeeOther) return } - user, err := cache.GetUser(oauthUser.Email) + user, err := app.Server.Service.GetUser(r.Context(), oauthUser.Email) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -163,12 +142,15 @@ func GET(w http.ResponseWriter, r *http.Request) { return } - storeSession := session.Create() - storeSession.Values["user"] = types.User{ + storeSession, err := session.Create(types.User{ UserID: user.UserID, Email: oauthUser.Email, Username: user.Username, Authenticated: true, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return } userAgent := r.Header.Get("User-Agent") diff --git a/handler/auth/google/google.go b/handler/auth/google/google.go index 5e50b1f..72ded9d 100644 --- a/handler/auth/google/google.go +++ b/handler/auth/google/google.go @@ -1,17 +1,31 @@ package googleOauthHandler import ( + "encoding/json" "fmt" "github.com/fossyy/filekeeper/app" - googleOauthCallbackHandler "github.com/fossyy/filekeeper/handler/auth/google/callback" "github.com/fossyy/filekeeper/utils" "net/http" "time" ) +type CsrfToken struct { + Token string + CreateTime time.Time +} + func GET(w http.ResponseWriter, r *http.Request) { token, err := utils.GenerateCSRFToken() - googleOauthCallbackHandler.CsrfTokens[token] = &googleOauthCallbackHandler.CsrfToken{Token: token, CreateTime: time.Now()} + csrfToken := CsrfToken{ + Token: token, + CreateTime: time.Now(), + } + newCsrfToken, err := json.Marshal(csrfToken) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + app.Server.Cache.SetCache(r.Context(), "CsrfTokens:"+token, newCsrfToken, time.Minute*15) if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) diff --git a/handler/auth/google/setup/setup.go b/handler/auth/google/setup/setup.go index 17823f6..f49ca72 100644 --- a/handler/auth/google/setup/setup.go +++ b/handler/auth/google/setup/setup.go @@ -1,7 +1,8 @@ package googleOauthSetupHandler import ( - "fmt" + "encoding/json" + "errors" "github.com/fossyy/filekeeper/app" signinHandler "github.com/fossyy/filekeeper/handler/signin" "github.com/fossyy/filekeeper/session" @@ -11,8 +12,8 @@ import ( "github.com/fossyy/filekeeper/view/client/auth" signupView "github.com/fossyy/filekeeper/view/client/signup" "github.com/google/uuid" + "github.com/redis/go-redis/v9" "net/http" - "sync" "time" ) @@ -20,50 +21,26 @@ type UnregisteredUser struct { Code string Email string CreateTime time.Time - mu sync.Mutex -} - -var SetupUser map[string]*UnregisteredUser - -func init() { - - SetupUser = make(map[string]*UnregisteredUser) - - ticker := time.NewTicker(time.Minute) - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [GoogleSetup] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) - - for _, data := range SetupUser { - data.mu.Lock() - if currentTime.Sub(data.CreateTime) > time.Minute*10 { - delete(SetupUser, data.Code) - cacheClean++ - } - - data.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [GoogleSetup] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) - } - }() } func GET(w http.ResponseWriter, r *http.Request) { code := r.PathValue("code") - if _, ok := SetupUser[code]; !ok { - http.Redirect(w, r, "/signup", http.StatusSeeOther) + _, err := app.Server.Cache.GetCache(r.Context(), "GoogleSetup:"+code) + if err != nil { + if errors.Is(err, redis.Nil) { + http.Redirect(w, r, "/signup", http.StatusSeeOther) + return + } + w.WriteHeader(http.StatusInternalServerError) + app.Server.Logger.Error(err.Error()) return } + component := authView.GoogleSetup("Filekeeper - Setup Page", types.Message{ Code: 3, Message: "", }) - err := component.Render(r.Context(), w) + err = component.Render(r.Context(), w) if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) @@ -73,12 +50,22 @@ func GET(w http.ResponseWriter, r *http.Request) { func POST(w http.ResponseWriter, r *http.Request) { code := r.PathValue("code") - unregisteredUser, ok := SetupUser[code] - if !ok { + cache, err := app.Server.Cache.GetCache(r.Context(), "GoogleSetup:"+code) + + if errors.Is(err, redis.Nil) { http.Error(w, "Unauthorized Action", http.StatusUnauthorized) return } - err := r.ParseForm() + + var unregisteredUser UnregisteredUser + err = json.Unmarshal([]byte(cache), &unregisteredUser) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + app.Server.Logger.Error(err.Error()) + return + } + + err = r.ParseForm() if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) @@ -126,14 +113,15 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - delete(SetupUser, code) - - storeSession := session.Create() - storeSession.Values["user"] = types.User{ + storeSession, err := session.Create(types.User{ UserID: userID, Email: unregisteredUser.Email, Username: username, Authenticated: true, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return } userAgent := r.Header.Get("User-Agent") diff --git a/handler/auth/totp/totp.go b/handler/auth/totp/totp.go index 7f3ad23..386d085 100644 --- a/handler/auth/totp/totp.go +++ b/handler/auth/totp/totp.go @@ -38,14 +38,18 @@ func POST(w http.ResponseWriter, r *http.Request) { if totp.Verify(code, time.Now().Unix()) { storeSession, err := session.Get(key) - if err != nil { - return - } - storeSession.Values["user"] = types.User{ + err = storeSession.Change(types.User{ UserID: user.UserID, Email: user.Email, Username: user.Username, Authenticated: true, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + if err != nil { + return } userAgent := r.Header.Get("User-Agent") browserInfo, osInfo := ParseUserAgent(userAgent) diff --git a/handler/forgotPassword/forgotPassword.go b/handler/forgotPassword/forgotPassword.go index 2af2053..e3eb791 100644 --- a/handler/forgotPassword/forgotPassword.go +++ b/handler/forgotPassword/forgotPassword.go @@ -3,13 +3,14 @@ package forgotPasswordHandler import ( "bytes" "context" + "encoding/json" "errors" "fmt" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/view/client/email" "github.com/fossyy/filekeeper/view/client/forgotPassword" "github.com/google/uuid" + "github.com/redis/go-redis/v9" "net/http" "sync" "time" @@ -27,35 +28,6 @@ type ForgotPassword struct { CreateTime time.Time } -var ListForgotPassword map[string]*ForgotPassword -var UserForgotPassword = make(map[string]string) - -func init() { - ListForgotPassword = make(map[string]*ForgotPassword) - ticker := time.NewTicker(time.Minute) - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [Forgot Password] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) - - for _, data := range ListForgotPassword { - data.mu.Lock() - if currentTime.Sub(data.CreateTime) > time.Minute*10 { - delete(ListForgotPassword, data.User.Email) - delete(UserForgotPassword, data.Code) - cacheClean++ - } - data.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [Forgot Password] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) - } - }() -} - func GET(w http.ResponseWriter, r *http.Request) { component := forgotPasswordView.Main("Filekeeper - Forgot Password Page", types.Message{ Code: 3, @@ -79,7 +51,7 @@ func POST(w http.ResponseWriter, r *http.Request) { emailForm := r.Form.Get("email") - user, err := cache.GetUser(emailForm) + user, err := app.Server.Service.GetUser(r.Context(), emailForm) if errors.Is(err, gorm.ErrRecordNotFound) { component := forgotPasswordView.Main("Filekeeper - Forgot Password Page", types.Message{ Code: 0, @@ -119,31 +91,48 @@ func POST(w http.ResponseWriter, r *http.Request) { } func verifyForgot(user *models.User) error { + var userData *ForgotPassword var code string + var err error + code, err = app.Server.Cache.GetCache(context.Background(), "ForgotPasswordCode:"+user.Email) + if err != nil { + if errors.Is(err, redis.Nil) { + code = utils.GenerateRandomString(64) + userData = &ForgotPassword{ + User: user, + Code: code, + CreateTime: time.Now(), + } - var buffer bytes.Buffer - data, ok := ListForgotPassword[user.Email] - - if !ok { - code = utils.GenerateRandomString(64) + newForgotUser, err := json.Marshal(userData) + if err != nil { + return err + } + err = app.Server.Cache.SetCache(context.Background(), "ForgotPasswordCode:"+user.Email, code, time.Minute*15) + if err != nil { + return err + } + err = app.Server.Cache.SetCache(context.Background(), "ForgotPassword:"+userData.Code, newForgotUser, time.Minute*15) + if err != nil { + return err + } + } else { + return err + } } else { - code = data.Code + storedCode, err := app.Server.Cache.GetCache(context.Background(), "ForgotPassword:"+code) + err = json.Unmarshal([]byte(storedCode), &userData) + if err != nil { + return err + } } - err := emailView.ForgotPassword(user.Username, fmt.Sprintf("https://%s/forgot-password/verify/%s", utils.Getenv("DOMAIN"), code)).Render(context.Background(), &buffer) + var buffer bytes.Buffer + err = emailView.ForgotPassword(user.Username, fmt.Sprintf("https://%s/forgot-password/verify/%s", utils.Getenv("DOMAIN"), code)).Render(context.Background(), &buffer) if err != nil { return err } - userData := &ForgotPassword{ - User: user, - Code: code, - CreateTime: time.Now(), - } - - UserForgotPassword[code] = user.Email - ListForgotPassword[user.Email] = userData - err = app.Server.Mail.Send(user.Email, "Password Change Request", buffer.String()) if err != nil { return err diff --git a/handler/forgotPassword/verify/verify.go b/handler/forgotPassword/verify/verify.go index 8deda5e..c329deb 100644 --- a/handler/forgotPassword/verify/verify.go +++ b/handler/forgotPassword/verify/verify.go @@ -1,14 +1,16 @@ package forgotPasswordVerifyHandler import ( + "encoding/json" + "errors" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" forgotPasswordHandler "github.com/fossyy/filekeeper/handler/forgotPassword" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" "github.com/fossyy/filekeeper/view/client/forgotPassword" signupView "github.com/fossyy/filekeeper/view/client/signup" + "github.com/redis/go-redis/v9" "net/http" ) @@ -21,11 +23,13 @@ func init() { func GET(w http.ResponseWriter, r *http.Request) { code := r.PathValue("code") - email := forgotPasswordHandler.UserForgotPassword[code] - _, ok := forgotPasswordHandler.ListForgotPassword[email] - - if !ok { - w.WriteHeader(http.StatusNotFound) + _, err := app.Server.Cache.GetCache(r.Context(), "ForgotPassword:"+code) + if err != nil { + if errors.Is(err, redis.Nil) { + w.WriteHeader(http.StatusNotFound) + return + } + w.WriteHeader(http.StatusInternalServerError) return } @@ -33,7 +37,7 @@ func GET(w http.ResponseWriter, r *http.Request) { Code: 3, Message: "", }) - err := component.Render(r.Context(), w) + err = component.Render(r.Context(), w) if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) @@ -44,15 +48,20 @@ func GET(w http.ResponseWriter, r *http.Request) { func POST(w http.ResponseWriter, r *http.Request) { code := r.PathValue("code") - email := forgotPasswordHandler.UserForgotPassword[code] - data, ok := forgotPasswordHandler.ListForgotPassword[email] - - if !ok { + data, err := app.Server.Cache.GetCache(r.Context(), "ForgotPassword:"+code) + if err != nil { w.WriteHeader(http.StatusNotFound) return } + var userData *forgotPasswordHandler.ForgotPassword - err := r.ParseForm() + err = json.Unmarshal([]byte(data), &userData) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + err = r.ParseForm() if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) @@ -82,19 +91,19 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - err = app.Server.Database.UpdateUserPassword(data.User.Email, hashedPassword) + err = app.Server.Database.UpdateUserPassword(userData.User.Email, hashedPassword) if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) return } - delete(forgotPasswordHandler.ListForgotPassword, data.User.Email) - delete(forgotPasswordHandler.UserForgotPassword, data.Code) + app.Server.Cache.DeleteCache(r.Context(), "ForgotPasswordCode:"+userData.User.Email) + app.Server.Cache.DeleteCache(r.Context(), "ForgotPassword:"+code) - session.RemoveAllSessions(data.User.Email) + session.RemoveAllSessions(userData.User.Email) - cache.DeleteUser(data.User.Email) + app.Server.Service.DeleteUser(userData.User.Email) component := forgotPasswordView.ChangeSuccess("Filekeeper - Forgot Password Page") err = component.Render(r.Context(), w) diff --git a/handler/logout/logout.go b/handler/logout/logout.go index 85cb180..5084c84 100644 --- a/handler/logout/logout.go +++ b/handler/logout/logout.go @@ -5,7 +5,6 @@ import ( "net/http" "github.com/fossyy/filekeeper/session" - "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" ) @@ -25,7 +24,7 @@ func GET(w http.ResponseWriter, r *http.Request) { } storeSession.Delete() - session.RemoveSessionInfo(storeSession.Values["user"].(types.User).Email, cookie.Value) + session.RemoveSessionInfo(storeSession.Values.Email, cookie.Value) http.SetCookie(w, &http.Cookie{ Name: utils.Getenv("SESSION_NAME"), diff --git a/handler/signin/signin.go b/handler/signin/signin.go index f1fc30c..8c01347 100644 --- a/handler/signin/signin.go +++ b/handler/signin/signin.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/a-h/templ" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" @@ -37,8 +36,8 @@ func init() { "login_required": "You need to log in again to proceed. Please try logging in again.", "account_selection_required": "Please select an account to proceed with the request.", "consent_required": "Consent is required to proceed. Please provide consent to continue.", + "csrf_token_error": "The CSRF token is missing or invalid. Please refresh the page and try again.", } - } func GET(w http.ResponseWriter, r *http.Request) { @@ -77,7 +76,7 @@ func POST(w http.ResponseWriter, r *http.Request) { } email := r.Form.Get("email") password := r.Form.Get("password") - userData, err := cache.GetUser(email) + userData, err := app.Server.Service.GetUser(r.Context(), email) if err != nil { component := signinView.Main("Filekeeper - Sign in Page", types.Message{ Code: 0, @@ -95,27 +94,33 @@ func POST(w http.ResponseWriter, r *http.Request) { if email == userData.Email && utils.CheckPasswordHash(password, userData.Password) { if userData.Totp != "" { - storeSession := session.Create() - storeSession.Values["user"] = types.User{ + + storeSession, err := session.Create(types.User{ UserID: userData.UserID, Email: email, Username: userData.Username, Totp: userData.Totp, Authenticated: false, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return } storeSession.Save(w) http.Redirect(w, r, "/auth/totp", http.StatusSeeOther) return } - storeSession := session.Create() - storeSession.Values["user"] = types.User{ + storeSession, err := session.Create(types.User{ UserID: userData.UserID, Email: email, Username: userData.Username, Authenticated: true, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return } - userAgent := r.Header.Get("User-Agent") browserInfo, osInfo := ParseUserAgent(userAgent) diff --git a/handler/signup/signup.go b/handler/signup/signup.go index 86d43ba..a08007d 100644 --- a/handler/signup/signup.go +++ b/handler/signup/signup.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "github.com/fossyy/filekeeper/app" + "github.com/fossyy/filekeeper/utils" "github.com/fossyy/filekeeper/view/client/email" signupView "github.com/fossyy/filekeeper/view/client/signup" "net/http" @@ -13,7 +14,6 @@ import ( "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/types/models" - "github.com/fossyy/filekeeper/utils" "github.com/google/uuid" ) diff --git a/handler/upload/initialisation/initialisation.go b/handler/upload/initialisation/initialisation.go index 9c9158f..bd302a7 100644 --- a/handler/upload/initialisation/initialisation.go +++ b/handler/upload/initialisation/initialisation.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "io" "net/http" "os" @@ -31,7 +30,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - fileData, err := cache.GetUserFile(fileInfo.Name, userSession.UserID.String()) + fileData, err := app.Server.Service.GetUserFile(fileInfo.Name, userSession.UserID.String()) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { upload, err := handleNewUpload(userSession, fileInfo) diff --git a/handler/upload/upload.go b/handler/upload/upload.go index b57f2b9..00304eb 100644 --- a/handler/upload/upload.go +++ b/handler/upload/upload.go @@ -1,15 +1,8 @@ package uploadHandler import ( - "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" - "github.com/fossyy/filekeeper/types" filesView "github.com/fossyy/filekeeper/view/client/upload" - "io" "net/http" - "os" - "path/filepath" - "strconv" ) func GET(w http.ResponseWriter, r *http.Request) { @@ -21,73 +14,74 @@ func GET(w http.ResponseWriter, r *http.Request) { } func POST(w http.ResponseWriter, r *http.Request) { - fileID := r.PathValue("id") - if err := r.ParseMultipartForm(32 << 20); err != nil { - w.WriteHeader(http.StatusInternalServerError) - return - } - - userSession := r.Context().Value("user").(types.User) - - uploadDir := "uploads" - if _, err := os.Stat(uploadDir); os.IsNotExist(err) { - if err := os.Mkdir(uploadDir, os.ModePerm); err != nil { - app.Server.Logger.Error("error getting upload info: " + err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - } - - file, err := cache.GetFile(fileID) - if err != nil { - app.Server.Logger.Error("error getting upload info: " + err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - - currentDir, _ := os.Getwd() - basePath := filepath.Join(currentDir, uploadDir) - saveFolder := filepath.Join(basePath, userSession.UserID.String(), file.ID.String()) - - if filepath.Dir(saveFolder) != filepath.Join(basePath, userSession.UserID.String()) { - app.Server.Logger.Error("invalid path") - w.WriteHeader(http.StatusInternalServerError) - return - } - - fileByte, fileHeader, err := r.FormFile("chunk") - if err != nil { - app.Server.Logger.Error("error getting upload info: " + err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - defer fileByte.Close() - - rawIndex := r.FormValue("index") - index, err := strconv.Atoi(rawIndex) - if err != nil { - return - } - - file.UpdateProgress(int64(index), file.UploadedByte+int64(fileHeader.Size)) - - dst, err := os.OpenFile(filepath.Join(saveFolder, file.Name), os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666) - if err != nil { - app.Server.Logger.Error("error making upload folder: " + err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - - defer dst.Close() - if _, err := io.Copy(dst, fileByte); err != nil { - app.Server.Logger.Error("error copying byte to file dst: " + err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - - if file.UploadedByte >= file.Size { - file.FinalizeFileUpload() - return - } return + //fileID := r.PathValue("id") + //if err := r.ParseMultipartForm(32 << 20); err != nil { + // w.WriteHeader(http.StatusInternalServerError) + // return + //} + // + //userSession := r.Context().Value("user").(types.User) + // + //uploadDir := "uploads" + //if _, err := os.Stat(uploadDir); os.IsNotExist(err) { + // if err := os.Mkdir(uploadDir, os.ModePerm); err != nil { + // app.Server.Logger.Error("error getting upload info: " + err.Error()) + // w.WriteHeader(http.StatusInternalServerError) + // return + // } + //} + // + //file, err := app.Server.Service.GetFile(fileID) + //if err != nil { + // app.Server.Logger.Error("error getting upload info: " + err.Error()) + // w.WriteHeader(http.StatusInternalServerError) + // return + //} + // + //currentDir, _ := os.Getwd() + //basePath := filepath.Join(currentDir, uploadDir) + //saveFolder := filepath.Join(basePath, userSession.UserID.String(), file.ID.String()) + // + //if filepath.Dir(saveFolder) != filepath.Join(basePath, userSession.UserID.String()) { + // app.Server.Logger.Error("invalid path") + // w.WriteHeader(http.StatusInternalServerError) + // return + //} + // + //fileByte, fileHeader, err := r.FormFile("chunk") + //if err != nil { + // app.Server.Logger.Error("error getting upload info: " + err.Error()) + // w.WriteHeader(http.StatusInternalServerError) + // return + //} + //defer fileByte.Close() + // + //rawIndex := r.FormValue("index") + //index, err := strconv.Atoi(rawIndex) + //if err != nil { + // return + //} + // + //file.UpdateProgress(int64(index), file.UploadedByte+int64(fileHeader.Size)) + // + //dst, err := os.OpenFile(filepath.Join(saveFolder, file.Name), os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666) + //if err != nil { + // app.Server.Logger.Error("error making upload folder: " + err.Error()) + // w.WriteHeader(http.StatusInternalServerError) + // return + //} + // + //defer dst.Close() + //if _, err := io.Copy(dst, fileByte); err != nil { + // app.Server.Logger.Error("error copying byte to file dst: " + err.Error()) + // w.WriteHeader(http.StatusInternalServerError) + // return + //} + // + //if file.UploadedByte >= file.Size { + // file.FinalizeFileUpload() + // return + //} + //return } diff --git a/handler/user/ResetPassword/ResetPassword.go b/handler/user/ResetPassword/ResetPassword.go index 0f0f934..5d76aa7 100644 --- a/handler/user/ResetPassword/ResetPassword.go +++ b/handler/user/ResetPassword/ResetPassword.go @@ -2,7 +2,6 @@ package userHandlerResetPassword import ( "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" @@ -14,7 +13,7 @@ func POST(w http.ResponseWriter, r *http.Request) { userSession := r.Context().Value("user").(types.User) currentPassword := r.Form.Get("currentPassword") password := r.Form.Get("password") - user, err := cache.GetUser(userSession.Email) + user, err := app.Server.Service.GetUser(r.Context(), userSession.Email) if err != nil { w.WriteHeader(http.StatusInternalServerError) return @@ -38,7 +37,7 @@ func POST(w http.ResponseWriter, r *http.Request) { } session.RemoveAllSessions(userSession.Email) - cache.DeleteUser(userSession.Email) + app.Server.Service.DeleteUser(userSession.Email) http.Redirect(w, r, "/signin", http.StatusSeeOther) return diff --git a/handler/user/session/terminate/terminate.go b/handler/user/session/terminate/terminate.go index f1dd852..eac09e0 100644 --- a/handler/user/session/terminate/terminate.go +++ b/handler/user/session/terminate/terminate.go @@ -11,15 +11,21 @@ func DELETE(w http.ResponseWriter, r *http.Request) { _, mySession, _ := session.GetSession(r) otherSession, _ := session.Get(id) - if session.GetSessionInfo(mySession.Email, otherSession.ID) == nil { + if _, err := session.GetSessionInfo(mySession.Email, otherSession.ID); err != nil { w.WriteHeader(http.StatusUnauthorized) return } + otherSession.Delete() session.RemoveSessionInfo(mySession.Email, otherSession.ID) - component := userView.SessionTable(session.GetSessions(mySession.Email)) + sessions, err := session.GetSessions(mySession.Email) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + component := userView.SessionTable(sessions) - err := component.Render(r.Context(), w) + err = component.Render(r.Context(), w) if err != nil { w.WriteHeader(http.StatusInternalServerError) return diff --git a/handler/user/totp/setup.go b/handler/user/totp/setup.go index 600e917..89b2817 100644 --- a/handler/user/totp/setup.go +++ b/handler/user/totp/setup.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "fmt" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/view/client/user/totp" "image/png" "net/http" @@ -75,7 +74,7 @@ func POST(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } - cache.DeleteUser(userSession.Email) + app.Server.Service.DeleteUser(userSession.Email) component := userTotpSetupView.Main("Filekeeper - 2FA Setup Page", base64Str, secret, userSession, types.Message{ Code: 1, Message: "Your TOTP setup is complete! Your account is now more secure.", diff --git a/handler/user/user.go b/handler/user/user.go index fc57ce2..463fcb5 100644 --- a/handler/user/user.go +++ b/handler/user/user.go @@ -17,24 +17,28 @@ var errorMessages = map[string]string{ func GET(w http.ResponseWriter, r *http.Request) { var component templ.Component userSession := r.Context().Value("user").(types.User) - + sessions, err := session.GetSessions(userSession.Email) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } if err := r.URL.Query().Get("error"); err != "" { message, ok := errorMessages[err] if !ok { message = "Unknown error occurred. Please contact support at bagas@fossy.my.id for assistance." } - component = userView.Main("Filekeeper - User Page", userSession, session.GetSessions(userSession.Email), types.Message{ + component = userView.Main("Filekeeper - User Page", userSession, sessions, types.Message{ Code: 0, Message: message, }) } else { - component = userView.Main("Filekeeper - User Page", userSession, session.GetSessions(userSession.Email), types.Message{ + component = userView.Main("Filekeeper - User Page", userSession, sessions, types.Message{ Code: 1, Message: "", }) } - err := component.Render(r.Context(), w) + err = component.Render(r.Context(), w) if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) diff --git a/main.go b/main.go index 9f5eb97..addb202 100644 --- a/main.go +++ b/main.go @@ -3,12 +3,14 @@ package main import ( "fmt" "github.com/fossyy/filekeeper/app" + "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/db" "github.com/fossyy/filekeeper/email" "github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/middleware" "github.com/fossyy/filekeeper/routes/admin" "github.com/fossyy/filekeeper/routes/client" + "github.com/fossyy/filekeeper/service" "github.com/fossyy/filekeeper/utils" "strconv" ) @@ -24,11 +26,13 @@ func main() { dbName := utils.Getenv("DB_NAME") database := db.NewPostgresDB(dbUser, dbPass, dbHost, dbPort, dbName, db.DisableSSL) + cacheServer := cache.NewRedisServer(database) + services := service.NewService(database, cacheServer) smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT")) mailServer := email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) - app.Server = app.NewClientServer(clientAddr, middleware.Handler(client.SetupRoutes()), *logger.Logger(), database, mailServer) + app.Server = app.NewClientServer(clientAddr, middleware.Handler(client.SetupRoutes()), *logger.Logger(), database, cacheServer, services, mailServer) app.Admin = app.NewAdminServer(adminAddr, middleware.Handler(admin.SetupRoutes()), database) go func() { diff --git a/middleware/middleware.go b/middleware/middleware.go index 6e6d35b..f6132b6 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -81,13 +81,10 @@ func Handler(next http.Handler) http.Handler { } func Auth(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) { - status, user, sessionID := session.GetSession(r) + status, user, _ := session.GetSession(r) switch status { case session.Authorized: - userSession := session.GetSessionInfo(user.Email, sessionID) - userSession.UpdateAccessTime() - ctx := context.WithValue(r.Context(), "user", user) req := r.WithContext(ctx) r.Context().Value("user") diff --git a/service/service.go b/service/service.go new file mode 100644 index 0000000..39deec4 --- /dev/null +++ b/service/service.go @@ -0,0 +1,120 @@ +package service + +import ( + "context" + "encoding/json" + "github.com/fossyy/filekeeper/app" + "github.com/fossyy/filekeeper/types" + "github.com/redis/go-redis/v9" + "time" +) + +type Service struct { + db types.Database + cache types.CachingServer +} + +func NewService(db types.Database, cache types.CachingServer) *Service { + return &Service{ + db: db, + cache: cache, + } +} + +func (r *Service) GetUser(ctx context.Context, email string) (*types.UserWithExpired, error) { + userJSON, err := app.Server.Cache.GetCache(ctx, "UserCache:"+email) + if err == redis.Nil { + userData, err := r.db.GetUser(email) + if err != nil { + return nil, err + } + + user := &types.UserWithExpired{ + UserID: userData.UserID, + Username: userData.Username, + Email: userData.Email, + Password: userData.Password, + Totp: userData.Totp, + AccessAt: time.Now(), + } + + newUserJSON, _ := json.Marshal(user) + err = r.cache.SetCache(ctx, email, newUserJSON, time.Hour*24) + if err != nil { + return nil, err + } + + return user, nil + } + if err != nil { + return nil, err + } + + var user types.UserWithExpired + err = json.Unmarshal([]byte(userJSON), &user) + if err != nil { + return nil, err + } + + return &user, nil +} + +func (r *Service) DeleteUser(email string) { + err := r.cache.DeleteCache(context.Background(), "UserCache:"+email) + if err != nil { + return + } +} + +func (r *Service) GetFile(id string) (*types.FileWithExpired, error) { + fileJSON, err := r.cache.GetCache(context.Background(), "FileCache:"+id) + if err == redis.Nil { + uploadData, err := r.db.GetFile(id) + if err != nil { + return nil, err + } + + fileCache := &types.FileWithExpired{ + ID: uploadData.ID, + OwnerID: uploadData.OwnerID, + Name: uploadData.Name, + Size: uploadData.Size, + Downloaded: uploadData.Downloaded, + UploadedByte: uploadData.UploadedByte, + UploadedChunk: uploadData.UploadedChunk, + Done: uploadData.Done, + AccessAt: time.Now(), + } + + newFileJSON, _ := json.Marshal(fileCache) + err = r.cache.SetCache(context.Background(), "FileCache:"+id, newFileJSON, time.Hour*24) + if err != nil { + return nil, err + } + return fileCache, nil + } + if err != nil { + return nil, err + } + + var fileCache types.FileWithExpired + err = json.Unmarshal([]byte(fileJSON), &fileCache) + if err != nil { + return nil, err + } + return &fileCache, nil +} + +func (r *Service) GetUserFile(name, ownerID string) (*types.FileWithExpired, error) { + fileData, err := r.db.GetUserFile(name, ownerID) + if err != nil { + return nil, err + } + + file, err := r.GetFile(fileData.ID.String()) + if err != nil { + return nil, err + } + + return file, nil +} diff --git a/session/session.go b/session/session.go index 61a7fba..565c4a9 100644 --- a/session/session.go +++ b/session/session.go @@ -1,12 +1,15 @@ package session import ( - "fmt" + "context" + "encoding/json" + "errors" "github.com/fossyy/filekeeper/app" "github.com/fossyy/filekeeper/types" + "github.com/redis/go-redis/v9" "net/http" "strconv" - "sync" + "strings" "time" "github.com/fossyy/filekeeper/utils" @@ -14,9 +17,8 @@ import ( type Session struct { ID string - Values map[string]interface{} + Values types.User CreateTime time.Time - mu sync.Mutex } type SessionInfo struct { @@ -27,7 +29,6 @@ type SessionInfo struct { OSVersion string IP string Location string - AccessAt string } type UserStatus string @@ -39,61 +40,65 @@ const ( InvalidSession UserStatus = "invalid_session" ) -var GlobalSessionStore = make(map[string]*Session) -var UserSessionInfoList = make(map[string]map[string]*SessionInfo) - -func init() { - - ticker := time.NewTicker(time.Minute) - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [Session] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) - - for _, data := range GlobalSessionStore { - data.mu.Lock() - if currentTime.Sub(data.CreateTime) > time.Hour*24*7 { - RemoveSessionInfo(data.Values["user"].(types.User).Email, data.ID) - delete(GlobalSessionStore, data.ID) - cacheClean++ - } - data.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [Session] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) - } - }() -} - func (e *SessionNotFoundError) Error() string { return "session not found" } func Get(id string) (*Session, error) { - if session, ok := GlobalSessionStore[id]; ok { - return session, nil + session, err := app.Server.Cache.GetCache(context.Background(), "Session:"+id) + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, &SessionNotFoundError{} + } + return nil, err } - return nil, &SessionNotFoundError{} + var userSession Session + err = json.Unmarshal([]byte(session), &userSession) + if err != nil { + return nil, err + } + return &userSession, nil } -func Create() *Session { +func Create(values types.User) (*Session, error) { id := utils.GenerateRandomString(128) session := &Session{ ID: id, CreateTime: time.Now(), - Values: make(map[string]interface{}), + Values: values, } - GlobalSessionStore[id] = session - return session + + sessionData, err := json.Marshal(session) + if err != nil { + return nil, err + } + + err = app.Server.Cache.SetCache(context.Background(), "Session:"+id, string(sessionData), time.Hour*24) // Set expiration time as needed + if err != nil { + return nil, err + } + + return session, nil } -func (s *Session) Delete() { - s.mu.Lock() - defer s.mu.Unlock() - delete(GlobalSessionStore, s.ID) +func (s *Session) Change(user types.User) error { + newSessionValue, err := json.Marshal(user) + if err != nil { + return err + } + err = app.Server.Cache.SetCache(context.Background(), "Session:"+s.ID, newSessionValue, time.Hour*24*7) + if err != nil { + return err + } + return nil +} + +func (s *Session) Delete() error { + err := app.Server.Cache.DeleteCache(context.Background(), "Session:"+s.ID) + if err != nil { + return err + } + return nil } func (s *Session) Save(w http.ResponseWriter) { @@ -114,46 +119,71 @@ func (s *Session) Destroy(w http.ResponseWriter) { }) } -func AddSessionInfo(email string, sessionInfo *SessionInfo) { - if _, ok := UserSessionInfoList[email]; !ok { - UserSessionInfoList[email] = make(map[string]*SessionInfo) +func AddSessionInfo(email string, sessionInfo *SessionInfo) error { + sessionInfoData, err := json.Marshal(sessionInfo) + if err != nil { + return err } - UserSessionInfoList[email][sessionInfo.SessionID] = sessionInfo -} - -func RemoveSessionInfo(email string, id string) { - if userSessions, ok := UserSessionInfoList[email]; ok { - if _, ok := userSessions[id]; ok { - delete(userSessions, id) - if len(userSessions) == 0 { - delete(UserSessionInfoList, email) - } - } + err = app.Server.Cache.SetCache(context.Background(), "UserSessionInfo:"+email+":"+sessionInfo.SessionID, string(sessionInfoData), 0) + if err != nil { + return err } + + return nil } -func RemoveAllSessions(email string) { - sessionInfos := UserSessionInfoList[email] - for _, sessionInfo := range sessionInfos { - delete(GlobalSessionStore, sessionInfo.SessionID) - } - delete(UserSessionInfoList, email) -} - -func GetSessionInfo(email string, id string) *SessionInfo { - if userSession, ok := UserSessionInfoList[email]; ok { - if sessionInfo, ok := userSession[id]; ok { - return sessionInfo - } +func RemoveSessionInfo(email string, id string) error { + key := "UserSessionInfo:" + email + ":" + id + err := app.Server.Cache.DeleteCache(context.Background(), key) + if err != nil { + return err } return nil } -func (sessionInfo *SessionInfo) UpdateAccessTime() { - currentTime := time.Now() - formattedTime := currentTime.Format("01-02-2006") - sessionInfo.AccessAt = formattedTime +func RemoveAllSessions(email string) error { + pattern := "UserSessionInfo:" + email + ":*" + keys, err := app.Server.Cache.GetKeys(context.Background(), pattern) + if err != nil { + return err + } + + for _, key := range keys { + parts := strings.Split(key, ":") + sessionID := parts[2] + + err = app.Server.Cache.DeleteCache(context.Background(), "Session:"+sessionID) + if err != nil { + return err + } + err = app.Server.Cache.DeleteCache(context.Background(), key) + if err != nil { + return err + } + } + + return nil +} + +func GetSessionInfo(email string, id string) (*SessionInfo, error) { + key := "UserSessionInfo:" + email + ":" + id + + sessionInfoData, err := app.Server.Cache.GetCache(context.Background(), key) + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, nil + } + return nil, err + } + + var sessionInfo SessionInfo + err = json.Unmarshal([]byte(sessionInfoData), &sessionInfo) + if err != nil { + return nil, err + } + + return &sessionInfo, nil } func GetSession(r *http.Request) (UserStatus, types.User, string) { @@ -162,18 +192,22 @@ func GetSession(r *http.Request) (UserStatus, types.User, string) { return Unauthorized, types.User{}, "" } - storeSession, ok := GlobalSessionStore[cookie.Value] - if !ok { - return InvalidSession, types.User{}, "" - } - - val := storeSession.Values["user"] - var userSession = types.User{} - userSession, ok = val.(types.User) - if !ok { + sessionData, err := app.Server.Cache.GetCache(context.Background(), "Session:"+cookie.Value) + if err != nil { + if errors.Is(err, redis.Nil) { + return InvalidSession, types.User{}, "" + } return Unauthorized, types.User{}, "" } + var storeSession Session + err = json.Unmarshal([]byte(sessionData), &storeSession) + if err != nil { + return Unauthorized, types.User{}, "" + } + + userSession := storeSession.Values + if !userSession.Authenticated && userSession.Totp != "" { return Unauthorized, userSession, cookie.Value } @@ -185,13 +219,27 @@ func GetSession(r *http.Request) (UserStatus, types.User, string) { return Authorized, userSession, cookie.Value } -func GetSessions(email string) []*SessionInfo { - if sessions, ok := UserSessionInfoList[email]; ok { - result := make([]*SessionInfo, 0, len(sessions)) - for _, sessionInfo := range sessions { - result = append(result, sessionInfo) - } - return result +func GetSessions(email string) ([]*SessionInfo, error) { + pattern := "UserSessionInfo:" + email + ":*" + keys, err := app.Server.Cache.GetKeys(context.Background(), pattern) + if err != nil { + return nil, err } - return nil + + var sessions []*SessionInfo + for _, key := range keys { + sessionData, err := app.Server.Cache.GetCache(context.Background(), key) + if err != nil { + return nil, err + } + + var sessionInfo SessionInfo + err = json.Unmarshal([]byte(sessionData), &sessionInfo) + if err != nil { + return nil, err + } + sessions = append(sessions, &sessionInfo) + } + + return sessions, nil } diff --git a/types/types.go b/types/types.go index 7eda14f..1d71e8f 100644 --- a/types/types.go +++ b/types/types.go @@ -1,7 +1,10 @@ package types import ( + "context" + "github.com/fossyy/filekeeper/types/models" "github.com/google/uuid" + "time" ) type Message struct { @@ -29,3 +32,59 @@ type FileData struct { Size string Downloaded int64 } + +type UserWithExpired struct { + UserID uuid.UUID + Username string + Email string + Password string + Totp string + AccessAt time.Time +} + +type FileWithExpired struct { + ID uuid.UUID + OwnerID uuid.UUID + Name string + Size int64 + Downloaded int64 + UploadedByte int64 + UploadedChunk int64 + Done bool + AccessAt time.Time +} + +type Database interface { + IsUserRegistered(email string, username string) bool + IsEmailRegistered(email string) bool + + CreateUser(user *models.User) error + GetUser(email string) (*models.User, error) + GetAllUsers() ([]models.User, error) + UpdateUserPassword(email string, password string) error + + CreateFile(file *models.File) error + GetFile(fileID string) (*models.File, error) + GetUserFile(name string, ownerID string) (*models.File, error) + GetFiles(ownerID string) ([]*models.File, error) + + UpdateUploadedByte(index int64, fileID string) + UpdateUploadedChunk(index int64, fileID string) + FinalizeFileUpload(fileID string) + + InitializeTotp(email string, secret string) error +} + +type CachingServer interface { + GetCache(ctx context.Context, key string) (string, error) + SetCache(ctx context.Context, key string, value interface{}, expiration time.Duration) error + DeleteCache(ctx context.Context, key string) error + GetKeys(ctx context.Context, pattern string) ([]string, error) +} + +type Services interface { + GetUser(ctx context.Context, email string) (*UserWithExpired, error) + DeleteUser(email string) + GetFile(id string) (*FileWithExpired, error) + GetUserFile(name, ownerID string) (*FileWithExpired, error) +} diff --git a/view/client/user/user.templ b/view/client/user/user.templ index f9e1417..fd1b597 100644 --- a/view/client/user/user.templ +++ b/view/client/user/user.templ @@ -99,10 +99,6 @@ templ content(message types.Message, title string, user types.User, ListSession class="h-12 px-4 text-left align-middle font-medium text-muted-foreground [&:has([role=checkbox])]:pr-0"> Device - - Last Activity - Actions @@ -119,8 +115,6 @@ templ content(message types.Message, title string, user types.User, ListSession {ses.OS + ses.OSVersion} - {ses.AccessAt} -