diff --git a/app/app.go b/app/app.go index bc1a906..2b1e150 100644 --- a/app/app.go +++ b/app/app.go @@ -1,9 +1,9 @@ package app import ( - "github.com/fossyy/filekeeper/db" "github.com/fossyy/filekeeper/email" "github.com/fossyy/filekeeper/logger" + "github.com/fossyy/filekeeper/types" "net/http" ) @@ -12,12 +12,14 @@ var Admin App type App struct { http.Server - Database db.Database + Database types.Database + Cache types.CachingServer + Service types.Services Logger *logger.AggregatedLogger Mail *email.SmtpServer } -func NewClientServer(addr string, handler http.Handler, logger logger.AggregatedLogger, database db.Database, mail email.SmtpServer) App { +func NewClientServer(addr string, handler http.Handler, logger logger.AggregatedLogger, database types.Database, cache types.CachingServer, service types.Services, mail email.SmtpServer) App { return App{ Server: http.Server{ Addr: addr, @@ -25,11 +27,13 @@ func NewClientServer(addr string, handler http.Handler, logger logger.Aggregated }, Logger: &logger, Database: database, + Cache: cache, + Service: service, Mail: &mail, } } -func NewAdminServer(addr string, handler http.Handler, database db.Database) App { +func NewAdminServer(addr string, handler http.Handler, database types.Database) App { return App{ Server: http.Server{ Addr: addr, diff --git a/cache/cache.go b/cache/cache.go index cc4dee5..49c8bb9 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -1,177 +1,67 @@ package cache import ( - "fmt" - "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/utils" - "github.com/google/uuid" - "sync" + "context" + "github.com/fossyy/filekeeper/types" + "github.com/redis/go-redis/v9" "time" ) -type UserWithExpired struct { - UserID uuid.UUID - Username string - Email string - Password string - Totp string - AccessAt time.Time - mu sync.Mutex +type RedisServer struct { + client *redis.Client + database types.Database } -type FileWithExpired struct { - ID uuid.UUID - OwnerID uuid.UUID - Name string - Size int64 - Downloaded int64 - UploadedByte int64 - UploadedChunk int64 - Done bool - AccessAt time.Time - mu sync.Mutex +func NewRedisServer(db types.Database) types.CachingServer { + client := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "Password123", + DB: 0, + }) + return &RedisServer{client: client, database: db} } -var userCache map[string]*UserWithExpired -var fileCache map[string]*FileWithExpired +func (r *RedisServer) GetCache(ctx context.Context, key string) (string, error) { + val, err := r.client.Get(ctx, key).Result() + if err != nil { + return "", err + } + return val, nil +} -func init() { +func (r *RedisServer) SetCache(ctx context.Context, key string, value interface{}, expiration time.Duration) error { + err := r.client.Set(ctx, key, value, expiration).Err() + if err != nil { + return err + } + return nil +} - userCache = make(map[string]*UserWithExpired) - fileCache = make(map[string]*FileWithExpired) - ticker := time.NewTicker(time.Minute) +func (r *RedisServer) DeleteCache(ctx context.Context, key string) error { + err := r.client.Del(ctx, key).Err() + if err != nil { + return err + } + return nil +} - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [user] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) +func (r *RedisServer) GetKeys(ctx context.Context, pattern string) ([]string, error) { + var cursor uint64 + var keys []string + for { + var newKeys []string + var err error - for _, user := range userCache { - user.mu.Lock() - if currentTime.Sub(user.AccessAt) > time.Hour*8 { - delete(userCache, user.Email) - cacheClean++ - } - user.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [user] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) + newKeys, cursor, err = r.client.Scan(ctx, cursor, pattern, 0).Result() + if err != nil { + return nil, err } - }() - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [files] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) + keys = append(keys, newKeys...) - for _, file := range fileCache { - file.mu.Lock() - if currentTime.Sub(file.AccessAt) > time.Minute*1 { - app.Server.Database.UpdateUploadedByte(file.UploadedByte, file.ID.String()) - app.Server.Database.UpdateUploadedChunk(file.UploadedChunk, file.ID.String()) - delete(fileCache, file.ID.String()) - cacheClean++ - } - file.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [files] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) + if cursor == 0 { + break } - }() + } + return keys, nil } - -func GetUser(email string) (*UserWithExpired, error) { - if user, ok := userCache[email]; ok { - return user, nil - } - - userData, err := app.Server.Database.GetUser(email) - if err != nil { - return nil, err - } - - userCache[email] = &UserWithExpired{ - UserID: userData.UserID, - Username: userData.Username, - Email: userData.Email, - Password: userData.Password, - Totp: userData.Totp, - AccessAt: time.Now(), - } - - return userCache[email], nil -} - -func DeleteUser(email string) { - userCache[email].mu.Lock() - defer userCache[email].mu.Unlock() - - delete(userCache, email) -} - -func GetFile(id string) (*FileWithExpired, error) { - if file, ok := fileCache[id]; ok { - file.AccessAt = time.Now() - return file, nil - } - - uploadData, err := app.Server.Database.GetFile(id) - if err != nil { - return nil, err - } - - fileCache[id] = &FileWithExpired{ - ID: uploadData.ID, - OwnerID: uploadData.OwnerID, - Name: uploadData.Name, - Size: uploadData.Size, - Downloaded: uploadData.Downloaded, - UploadedByte: uploadData.UploadedByte, - UploadedChunk: uploadData.UploadedChunk, - Done: uploadData.Done, - AccessAt: time.Now(), - } - - return fileCache[id], nil -} - -func (file *FileWithExpired) UpdateProgress(index int64, size int64) { - file.UploadedChunk = index - file.UploadedByte = size - file.AccessAt = time.Now() -} - -func GetUserFile(name, ownerID string) (*FileWithExpired, error) { - fileData, err := app.Server.Database.GetUserFile(name, ownerID) - if err != nil { - return nil, err - } - - file, err := GetFile(fileData.ID.String()) - if err != nil { - return nil, err - } - - return file, nil -} - -func (file *FileWithExpired) FinalizeFileUpload() { - app.Server.Database.UpdateUploadedByte(file.UploadedByte, file.ID.String()) - app.Server.Database.UpdateUploadedChunk(file.UploadedChunk, file.ID.String()) - app.Server.Database.FinalizeFileUpload(file.ID.String()) - delete(fileCache, file.ID.String()) - return -} - -//func DeleteUploadInfo(id string) { -// filesUploadedCache[id].mu.Lock() -// defer filesUploadedCache[id].mu.Unlock() -// -// delete(filesUploadedCache, id) -//} diff --git a/db/database.go b/db/database.go index feabbcf..0ea68aa 100644 --- a/db/database.go +++ b/db/database.go @@ -3,6 +3,7 @@ package db import ( "errors" "fmt" + "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/types/models" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -27,28 +28,7 @@ const ( EnableSSL SSLMode = "enable" ) -type Database interface { - IsUserRegistered(email string, username string) bool - IsEmailRegistered(email string) bool - - CreateUser(user *models.User) error - GetUser(email string) (*models.User, error) - GetAllUsers() ([]models.User, error) - UpdateUserPassword(email string, password string) error - - CreateFile(file *models.File) error - GetFile(fileID string) (*models.File, error) - GetUserFile(name string, ownerID string) (*models.File, error) - GetFiles(ownerID string) ([]*models.File, error) - - UpdateUploadedByte(index int64, fileID string) - UpdateUploadedChunk(index int64, fileID string) - FinalizeFileUpload(fileID string) - - InitializeTotp(email string, secret string) error -} - -func NewMYSQLdb(username, password, host, port, dbName string) Database { +func NewMYSQLdb(username, password, host, port, dbName string) types.Database { var err error var count int64 @@ -110,7 +90,7 @@ func NewMYSQLdb(username, password, host, port, dbName string) Database { return &mySQLdb{DB} } -func NewPostgresDB(username, password, host, port, dbName string, mode SSLMode) Database { +func NewPostgresDB(username, password, host, port, dbName string, mode SSLMode) types.Database { var err error var count int64 @@ -255,27 +235,6 @@ func (db *mySQLdb) GetFiles(ownerID string) ([]*models.File, error) { return files, err } -func (db *mySQLdb) UpdateUploadedByte(byte int64, fileID string) { - var file models.File - db.DB.Table("files").Where("id = ?", fileID).First(&file) - file.UploadedByte = byte - db.Save(&file) -} - -func (db *mySQLdb) UpdateUploadedChunk(index int64, fileID string) { - var file models.File - db.DB.Table("files").Where("id = ?", fileID).First(&file) - file.UploadedChunk = index - db.Save(&file) -} - -func (db *mySQLdb) FinalizeFileUpload(fileID string) { - var file models.File - db.DB.Table("files").Where("id = ?", fileID).First(&file) - file.Done = true - db.Save(&file) -} - func (db *mySQLdb) InitializeTotp(email string, secret string) error { var user models.User err := db.DB.Table("users").Where("email = ?", email).First(&user).Error @@ -336,7 +295,6 @@ func (db *postgresDB) GetAllUsers() ([]models.User, error) { var users []models.User err := db.DB.Table("users").Select("user_id, username, email").Find(&users).Error if err != nil { - fmt.Println(err) return nil, err } return users, nil @@ -388,26 +346,6 @@ func (db *postgresDB) GetFiles(ownerID string) ([]*models.File, error) { return files, err } -func (db *postgresDB) UpdateUploadedByte(byte int64, fileID string) { - var file models.File - db.DB.Table("files").Where("id = $1", fileID).First(&file) - file.UploadedByte = byte - db.Save(&file) -} -func (db *postgresDB) UpdateUploadedChunk(index int64, fileID string) { - var file models.File - db.DB.Table("files").Where("id = $1", fileID).First(&file) - file.UploadedChunk = index - db.Save(&file) -} - -func (db *postgresDB) FinalizeFileUpload(fileID string) { - var file models.File - db.DB.Table("files").Where("id = $1", fileID).First(&file) - file.Done = true - db.Save(&file) -} - func (db *postgresDB) InitializeTotp(email string, secret string) error { var user models.User err := db.DB.Table("users").Where("email = $1", email).First(&user).Error diff --git a/go.mod b/go.mod index 8858568..b01d232 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,8 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -29,6 +31,7 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/redis/go-redis/v9 v9.6.1 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect diff --git a/go.sum b/go.sum index 33922a0..3adffb8 100644 --- a/go.sum +++ b/go.sum @@ -2,9 +2,13 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/a-h/templ v0.2.707 h1:T1Gkd2ugbRglZ9rYw/VBchWOSZVKmetDbBkm4YubM7U= github.com/a-h/templ v0.2.707/go.mod h1:5cqsugkq9IerRNucNsI4DEamdHPsoGMQy99DzydLhM8= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= @@ -37,6 +41,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/redis/go-redis/v9 v9.6.1 h1:HHDteefn6ZkTtY5fGUE8tj8uy85AHk6zP7CpzIAM0y4= +github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA= github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= diff --git a/handler/auth/google/callback/callback.go b/handler/auth/google/callback/callback.go index c112c46..0b2a718 100644 --- a/handler/auth/google/callback/callback.go +++ b/handler/auth/google/callback/callback.go @@ -6,15 +6,14 @@ import ( "errors" "fmt" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" googleOauthSetupHandler "github.com/fossyy/filekeeper/handler/auth/google/setup" signinHandler "github.com/fossyy/filekeeper/handler/signin" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" + "github.com/redis/go-redis/v9" "net/http" "net/url" - "sync" "time" ) @@ -46,49 +45,22 @@ type OauthUser struct { VerifiedEmail bool `json:"verified_email"` } -type CsrfToken struct { - Token string - CreateTime time.Time - mu sync.Mutex -} - -var CsrfTokens map[string]*CsrfToken - -func init() { - - CsrfTokens = make(map[string]*CsrfToken) - - ticker := time.NewTicker(time.Minute) - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [csrf_token] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) - - for _, data := range CsrfTokens { - data.mu.Lock() - if currentTime.Sub(data.CreateTime) > time.Minute*10 { - delete(CsrfTokens, data.Token) - cacheClean++ - } - data.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [csrf_token] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) - } - }() -} - func GET(w http.ResponseWriter, r *http.Request) { - if _, ok := CsrfTokens[r.URL.Query().Get("state")]; !ok { - //csrf token mismatch error + _, err := app.Server.Cache.GetCache(r.Context(), "CsrfTokens:"+r.URL.Query().Get("state")) + if err != nil { + if errors.Is(err, redis.Nil) { + w.WriteHeader(http.StatusUnauthorized) + return + } w.WriteHeader(http.StatusInternalServerError) return } - delete(CsrfTokens, r.URL.Query().Get("state")) + err = app.Server.Cache.DeleteCache(r.Context(), "CsrfTokens:"+r.URL.Query().Get("state")) + if err != nil { + http.Redirect(w, r, fmt.Sprintf("/signin?error=%s", "csrf_token_error"), http.StatusFound) + return + } if err := r.URL.Query().Get("error"); err != "" { http.Redirect(w, r, fmt.Sprintf("/signin?error=%s", err), http.StatusFound) @@ -146,16 +118,23 @@ func GET(w http.ResponseWriter, r *http.Request) { if !app.Server.Database.IsEmailRegistered(oauthUser.Email) { code := utils.GenerateRandomString(64) - googleOauthSetupHandler.SetupUser[code] = &googleOauthSetupHandler.UnregisteredUser{ + + user := googleOauthSetupHandler.UnregisteredUser{ Code: code, Email: oauthUser.Email, CreateTime: time.Now(), } + newGoogleSetupJSON, _ := json.Marshal(user) + err = app.Server.Cache.SetCache(r.Context(), "GoogleSetup:"+code, newGoogleSetupJSON, time.Minute*15) + if err != nil { + fmt.Println("Error setting up Google Setup:", err) + return + } http.Redirect(w, r, fmt.Sprintf("/auth/google/setup/%s", code), http.StatusSeeOther) return } - user, err := cache.GetUser(oauthUser.Email) + user, err := app.Server.Service.GetUser(r.Context(), oauthUser.Email) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -163,12 +142,15 @@ func GET(w http.ResponseWriter, r *http.Request) { return } - storeSession := session.Create() - storeSession.Values["user"] = types.User{ + storeSession, err := session.Create(types.User{ UserID: user.UserID, Email: oauthUser.Email, Username: user.Username, Authenticated: true, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return } userAgent := r.Header.Get("User-Agent") diff --git a/handler/auth/google/google.go b/handler/auth/google/google.go index 5e50b1f..72ded9d 100644 --- a/handler/auth/google/google.go +++ b/handler/auth/google/google.go @@ -1,17 +1,31 @@ package googleOauthHandler import ( + "encoding/json" "fmt" "github.com/fossyy/filekeeper/app" - googleOauthCallbackHandler "github.com/fossyy/filekeeper/handler/auth/google/callback" "github.com/fossyy/filekeeper/utils" "net/http" "time" ) +type CsrfToken struct { + Token string + CreateTime time.Time +} + func GET(w http.ResponseWriter, r *http.Request) { token, err := utils.GenerateCSRFToken() - googleOauthCallbackHandler.CsrfTokens[token] = &googleOauthCallbackHandler.CsrfToken{Token: token, CreateTime: time.Now()} + csrfToken := CsrfToken{ + Token: token, + CreateTime: time.Now(), + } + newCsrfToken, err := json.Marshal(csrfToken) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + app.Server.Cache.SetCache(r.Context(), "CsrfTokens:"+token, newCsrfToken, time.Minute*15) if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) diff --git a/handler/auth/google/setup/setup.go b/handler/auth/google/setup/setup.go index 17823f6..f49ca72 100644 --- a/handler/auth/google/setup/setup.go +++ b/handler/auth/google/setup/setup.go @@ -1,7 +1,8 @@ package googleOauthSetupHandler import ( - "fmt" + "encoding/json" + "errors" "github.com/fossyy/filekeeper/app" signinHandler "github.com/fossyy/filekeeper/handler/signin" "github.com/fossyy/filekeeper/session" @@ -11,8 +12,8 @@ import ( "github.com/fossyy/filekeeper/view/client/auth" signupView "github.com/fossyy/filekeeper/view/client/signup" "github.com/google/uuid" + "github.com/redis/go-redis/v9" "net/http" - "sync" "time" ) @@ -20,50 +21,26 @@ type UnregisteredUser struct { Code string Email string CreateTime time.Time - mu sync.Mutex -} - -var SetupUser map[string]*UnregisteredUser - -func init() { - - SetupUser = make(map[string]*UnregisteredUser) - - ticker := time.NewTicker(time.Minute) - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [GoogleSetup] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) - - for _, data := range SetupUser { - data.mu.Lock() - if currentTime.Sub(data.CreateTime) > time.Minute*10 { - delete(SetupUser, data.Code) - cacheClean++ - } - - data.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [GoogleSetup] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) - } - }() } func GET(w http.ResponseWriter, r *http.Request) { code := r.PathValue("code") - if _, ok := SetupUser[code]; !ok { - http.Redirect(w, r, "/signup", http.StatusSeeOther) + _, err := app.Server.Cache.GetCache(r.Context(), "GoogleSetup:"+code) + if err != nil { + if errors.Is(err, redis.Nil) { + http.Redirect(w, r, "/signup", http.StatusSeeOther) + return + } + w.WriteHeader(http.StatusInternalServerError) + app.Server.Logger.Error(err.Error()) return } + component := authView.GoogleSetup("Filekeeper - Setup Page", types.Message{ Code: 3, Message: "", }) - err := component.Render(r.Context(), w) + err = component.Render(r.Context(), w) if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) @@ -73,12 +50,22 @@ func GET(w http.ResponseWriter, r *http.Request) { func POST(w http.ResponseWriter, r *http.Request) { code := r.PathValue("code") - unregisteredUser, ok := SetupUser[code] - if !ok { + cache, err := app.Server.Cache.GetCache(r.Context(), "GoogleSetup:"+code) + + if errors.Is(err, redis.Nil) { http.Error(w, "Unauthorized Action", http.StatusUnauthorized) return } - err := r.ParseForm() + + var unregisteredUser UnregisteredUser + err = json.Unmarshal([]byte(cache), &unregisteredUser) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + app.Server.Logger.Error(err.Error()) + return + } + + err = r.ParseForm() if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) @@ -126,14 +113,15 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - delete(SetupUser, code) - - storeSession := session.Create() - storeSession.Values["user"] = types.User{ + storeSession, err := session.Create(types.User{ UserID: userID, Email: unregisteredUser.Email, Username: username, Authenticated: true, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return } userAgent := r.Header.Get("User-Agent") diff --git a/handler/auth/totp/totp.go b/handler/auth/totp/totp.go index 7f3ad23..29cecc1 100644 --- a/handler/auth/totp/totp.go +++ b/handler/auth/totp/totp.go @@ -37,15 +37,19 @@ func POST(w http.ResponseWriter, r *http.Request) { totp := gotp.NewDefaultTOTP(user.Totp) if totp.Verify(code, time.Now().Unix()) { - storeSession, err := session.Get(key) - if err != nil { - return - } - storeSession.Values["user"] = types.User{ + storeSession := session.Get(key) + err := storeSession.Change(types.User{ UserID: user.UserID, Email: user.Email, Username: user.Username, Authenticated: true, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + if err != nil { + return } userAgent := r.Header.Get("User-Agent") browserInfo, osInfo := ParseUserAgent(userAgent) diff --git a/handler/download/file/file.go b/handler/download/file/file.go index e80c883..cf0c674 100644 --- a/handler/download/file/file.go +++ b/handler/download/file/file.go @@ -1,7 +1,9 @@ package downloadFileHandler import ( + "fmt" "github.com/fossyy/filekeeper/app" + "io" "net/http" "os" "path/filepath" @@ -28,22 +30,21 @@ func GET(w http.ResponseWriter, r *http.Request) { return } - openFile, err := os.OpenFile(filepath.Join(saveFolder, file.Name), os.O_RDONLY, 0) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - app.Server.Logger.Error(err.Error()) - return - } - defer openFile.Close() + w.Header().Set("Content-Disposition", "attachment; filename="+file.Name) + w.Header().Set("Content-Type", "application/octet-stream") + for i := 0; i <= int(file.TotalChunk); i++ { + chunkPath := filepath.Join(saveFolder, file.Name, fmt.Sprintf("chunk_%d", i)) - stat, err := openFile.Stat() - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - app.Server.Logger.Error(err.Error()) - return + chunkFile, err := os.Open(chunkPath) + if err != nil { + http.Error(w, fmt.Sprintf("Error opening chunk: %v", err), http.StatusInternalServerError) + return + } + _, err = io.Copy(w, chunkFile) + chunkFile.Close() + if err != nil { + http.Error(w, fmt.Sprintf("Error writing chunk: %v", err), http.StatusInternalServerError) + return + } } - - w.Header().Set("Content-Disposition", "attachment; filename="+stat.Name()) - http.ServeContent(w, r, stat.Name(), stat.ModTime(), openFile) - return } diff --git a/handler/forgotPassword/forgotPassword.go b/handler/forgotPassword/forgotPassword.go index 2af2053..e3eb791 100644 --- a/handler/forgotPassword/forgotPassword.go +++ b/handler/forgotPassword/forgotPassword.go @@ -3,13 +3,14 @@ package forgotPasswordHandler import ( "bytes" "context" + "encoding/json" "errors" "fmt" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/view/client/email" "github.com/fossyy/filekeeper/view/client/forgotPassword" "github.com/google/uuid" + "github.com/redis/go-redis/v9" "net/http" "sync" "time" @@ -27,35 +28,6 @@ type ForgotPassword struct { CreateTime time.Time } -var ListForgotPassword map[string]*ForgotPassword -var UserForgotPassword = make(map[string]string) - -func init() { - ListForgotPassword = make(map[string]*ForgotPassword) - ticker := time.NewTicker(time.Minute) - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [Forgot Password] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) - - for _, data := range ListForgotPassword { - data.mu.Lock() - if currentTime.Sub(data.CreateTime) > time.Minute*10 { - delete(ListForgotPassword, data.User.Email) - delete(UserForgotPassword, data.Code) - cacheClean++ - } - data.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [Forgot Password] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) - } - }() -} - func GET(w http.ResponseWriter, r *http.Request) { component := forgotPasswordView.Main("Filekeeper - Forgot Password Page", types.Message{ Code: 3, @@ -79,7 +51,7 @@ func POST(w http.ResponseWriter, r *http.Request) { emailForm := r.Form.Get("email") - user, err := cache.GetUser(emailForm) + user, err := app.Server.Service.GetUser(r.Context(), emailForm) if errors.Is(err, gorm.ErrRecordNotFound) { component := forgotPasswordView.Main("Filekeeper - Forgot Password Page", types.Message{ Code: 0, @@ -119,31 +91,48 @@ func POST(w http.ResponseWriter, r *http.Request) { } func verifyForgot(user *models.User) error { + var userData *ForgotPassword var code string + var err error + code, err = app.Server.Cache.GetCache(context.Background(), "ForgotPasswordCode:"+user.Email) + if err != nil { + if errors.Is(err, redis.Nil) { + code = utils.GenerateRandomString(64) + userData = &ForgotPassword{ + User: user, + Code: code, + CreateTime: time.Now(), + } - var buffer bytes.Buffer - data, ok := ListForgotPassword[user.Email] - - if !ok { - code = utils.GenerateRandomString(64) + newForgotUser, err := json.Marshal(userData) + if err != nil { + return err + } + err = app.Server.Cache.SetCache(context.Background(), "ForgotPasswordCode:"+user.Email, code, time.Minute*15) + if err != nil { + return err + } + err = app.Server.Cache.SetCache(context.Background(), "ForgotPassword:"+userData.Code, newForgotUser, time.Minute*15) + if err != nil { + return err + } + } else { + return err + } } else { - code = data.Code + storedCode, err := app.Server.Cache.GetCache(context.Background(), "ForgotPassword:"+code) + err = json.Unmarshal([]byte(storedCode), &userData) + if err != nil { + return err + } } - err := emailView.ForgotPassword(user.Username, fmt.Sprintf("https://%s/forgot-password/verify/%s", utils.Getenv("DOMAIN"), code)).Render(context.Background(), &buffer) + var buffer bytes.Buffer + err = emailView.ForgotPassword(user.Username, fmt.Sprintf("https://%s/forgot-password/verify/%s", utils.Getenv("DOMAIN"), code)).Render(context.Background(), &buffer) if err != nil { return err } - userData := &ForgotPassword{ - User: user, - Code: code, - CreateTime: time.Now(), - } - - UserForgotPassword[code] = user.Email - ListForgotPassword[user.Email] = userData - err = app.Server.Mail.Send(user.Email, "Password Change Request", buffer.String()) if err != nil { return err diff --git a/handler/forgotPassword/verify/verify.go b/handler/forgotPassword/verify/verify.go index 8deda5e..c329deb 100644 --- a/handler/forgotPassword/verify/verify.go +++ b/handler/forgotPassword/verify/verify.go @@ -1,14 +1,16 @@ package forgotPasswordVerifyHandler import ( + "encoding/json" + "errors" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" forgotPasswordHandler "github.com/fossyy/filekeeper/handler/forgotPassword" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" "github.com/fossyy/filekeeper/view/client/forgotPassword" signupView "github.com/fossyy/filekeeper/view/client/signup" + "github.com/redis/go-redis/v9" "net/http" ) @@ -21,11 +23,13 @@ func init() { func GET(w http.ResponseWriter, r *http.Request) { code := r.PathValue("code") - email := forgotPasswordHandler.UserForgotPassword[code] - _, ok := forgotPasswordHandler.ListForgotPassword[email] - - if !ok { - w.WriteHeader(http.StatusNotFound) + _, err := app.Server.Cache.GetCache(r.Context(), "ForgotPassword:"+code) + if err != nil { + if errors.Is(err, redis.Nil) { + w.WriteHeader(http.StatusNotFound) + return + } + w.WriteHeader(http.StatusInternalServerError) return } @@ -33,7 +37,7 @@ func GET(w http.ResponseWriter, r *http.Request) { Code: 3, Message: "", }) - err := component.Render(r.Context(), w) + err = component.Render(r.Context(), w) if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) @@ -44,15 +48,20 @@ func GET(w http.ResponseWriter, r *http.Request) { func POST(w http.ResponseWriter, r *http.Request) { code := r.PathValue("code") - email := forgotPasswordHandler.UserForgotPassword[code] - data, ok := forgotPasswordHandler.ListForgotPassword[email] - - if !ok { + data, err := app.Server.Cache.GetCache(r.Context(), "ForgotPassword:"+code) + if err != nil { w.WriteHeader(http.StatusNotFound) return } + var userData *forgotPasswordHandler.ForgotPassword - err := r.ParseForm() + err = json.Unmarshal([]byte(data), &userData) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + err = r.ParseForm() if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) @@ -82,19 +91,19 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - err = app.Server.Database.UpdateUserPassword(data.User.Email, hashedPassword) + err = app.Server.Database.UpdateUserPassword(userData.User.Email, hashedPassword) if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) return } - delete(forgotPasswordHandler.ListForgotPassword, data.User.Email) - delete(forgotPasswordHandler.UserForgotPassword, data.Code) + app.Server.Cache.DeleteCache(r.Context(), "ForgotPasswordCode:"+userData.User.Email) + app.Server.Cache.DeleteCache(r.Context(), "ForgotPassword:"+code) - session.RemoveAllSessions(data.User.Email) + session.RemoveAllSessions(userData.User.Email) - cache.DeleteUser(data.User.Email) + app.Server.Service.DeleteUser(userData.User.Email) component := forgotPasswordView.ChangeSuccess("Filekeeper - Forgot Password Page") err = component.Render(r.Context(), w) diff --git a/handler/logout/logout.go b/handler/logout/logout.go index 85cb180..dfdcb0b 100644 --- a/handler/logout/logout.go +++ b/handler/logout/logout.go @@ -2,20 +2,20 @@ package logoutHandler import ( "errors" + "github.com/fossyy/filekeeper/types" "net/http" "github.com/fossyy/filekeeper/session" - "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" ) func GET(w http.ResponseWriter, r *http.Request) { + userSession := r.Context().Value("user").(types.User) cookie, err := r.Cookie("Session") if err != nil { return } - - storeSession, err := session.Get(cookie.Value) + storeSession := session.Get(cookie.Value) if err != nil { if errors.Is(err, &session.SessionNotFoundError{}) { storeSession.Destroy(w) @@ -24,8 +24,16 @@ func GET(w http.ResponseWriter, r *http.Request) { return } - storeSession.Delete() - session.RemoveSessionInfo(storeSession.Values["user"].(types.User).Email, cookie.Value) + err = storeSession.Delete() + if err != nil { + panic(err) + return + } + err = session.RemoveSessionInfo(userSession.Email, cookie.Value) + if err != nil { + panic(err) + return + } http.SetCookie(w, &http.Cookie{ Name: utils.Getenv("SESSION_NAME"), diff --git a/handler/signin/signin.go b/handler/signin/signin.go index f1fc30c..8c01347 100644 --- a/handler/signin/signin.go +++ b/handler/signin/signin.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/a-h/templ" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" @@ -37,8 +36,8 @@ func init() { "login_required": "You need to log in again to proceed. Please try logging in again.", "account_selection_required": "Please select an account to proceed with the request.", "consent_required": "Consent is required to proceed. Please provide consent to continue.", + "csrf_token_error": "The CSRF token is missing or invalid. Please refresh the page and try again.", } - } func GET(w http.ResponseWriter, r *http.Request) { @@ -77,7 +76,7 @@ func POST(w http.ResponseWriter, r *http.Request) { } email := r.Form.Get("email") password := r.Form.Get("password") - userData, err := cache.GetUser(email) + userData, err := app.Server.Service.GetUser(r.Context(), email) if err != nil { component := signinView.Main("Filekeeper - Sign in Page", types.Message{ Code: 0, @@ -95,27 +94,33 @@ func POST(w http.ResponseWriter, r *http.Request) { if email == userData.Email && utils.CheckPasswordHash(password, userData.Password) { if userData.Totp != "" { - storeSession := session.Create() - storeSession.Values["user"] = types.User{ + + storeSession, err := session.Create(types.User{ UserID: userData.UserID, Email: email, Username: userData.Username, Totp: userData.Totp, Authenticated: false, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return } storeSession.Save(w) http.Redirect(w, r, "/auth/totp", http.StatusSeeOther) return } - storeSession := session.Create() - storeSession.Values["user"] = types.User{ + storeSession, err := session.Create(types.User{ UserID: userData.UserID, Email: email, Username: userData.Username, Authenticated: true, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return } - userAgent := r.Header.Get("User-Agent") browserInfo, osInfo := ParseUserAgent(userAgent) diff --git a/handler/signup/signup.go b/handler/signup/signup.go index 86d43ba..a08007d 100644 --- a/handler/signup/signup.go +++ b/handler/signup/signup.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "github.com/fossyy/filekeeper/app" + "github.com/fossyy/filekeeper/utils" "github.com/fossyy/filekeeper/view/client/email" signupView "github.com/fossyy/filekeeper/view/client/signup" "net/http" @@ -13,7 +14,6 @@ import ( "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/types/models" - "github.com/fossyy/filekeeper/utils" "github.com/google/uuid" ) diff --git a/handler/upload/initialisation/initialisation.go b/handler/upload/initialisation/initialisation.go index 9c9158f..15560b4 100644 --- a/handler/upload/initialisation/initialisation.go +++ b/handler/upload/initialisation/initialisation.go @@ -3,8 +3,8 @@ package initialisation import ( "encoding/json" "errors" + "fmt" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "io" "net/http" "os" @@ -31,7 +31,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - fileData, err := cache.GetUserFile(fileInfo.Name, userSession.UserID.String()) + fileData, err := app.Server.Service.GetUserFile(fileInfo.Name, userSession.UserID.String()) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { upload, err := handleNewUpload(userSession, fileInfo) @@ -39,6 +39,26 @@ func POST(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } + fileData = &types.FileWithDetail{ + ID: fileData.ID, + OwnerID: fileData.OwnerID, + Name: fileData.Name, + Size: fileData.Size, + Downloaded: fileData.Downloaded, + } + fileData.Chunk = make(map[string]bool) + fileData.Done = true + saveFolder := filepath.Join("uploads", userSession.UserID.String(), fileData.ID.String(), fileData.Name) + for i := 0; i <= int(fileInfo.Chunk-1); i++ { + fileName := fmt.Sprintf("%s/chunk_%d", saveFolder, i) + + if _, err := os.Stat(fileName); os.IsNotExist(err) { + fileData.Chunk[fmt.Sprintf("chunk_%d", i)] = false + fileData.Done = false + } else { + fileData.Chunk[fmt.Sprintf("chunk_%d", i)] = true + } + } respondJSON(w, upload) return } @@ -46,11 +66,19 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - if fileData.Done { - respondJSON(w, map[string]bool{"Done": true}) - return - } + fileData.Chunk = make(map[string]bool) + fileData.Done = true + saveFolder := filepath.Join("uploads", userSession.UserID.String(), fileData.ID.String(), fileData.Name) + for i := 0; i <= int(fileInfo.Chunk-1); i++ { + fileName := fmt.Sprintf("%s/chunk_%d", saveFolder, i) + if _, err := os.Stat(fileName); os.IsNotExist(err) { + fileData.Chunk[fmt.Sprintf("chunk_%d", i)] = false + fileData.Done = false + } else { + fileData.Chunk[fmt.Sprintf("chunk_%d", i)] = true + } + } respondJSON(w, fileData) } @@ -82,14 +110,12 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.File, error) } newFile := models.File{ - ID: fileID, - OwnerID: ownerID, - Name: file.Name, - Size: file.Size, - Downloaded: 0, - UploadedByte: 0, - UploadedChunk: -1, - Done: false, + ID: fileID, + OwnerID: ownerID, + Name: file.Name, + Size: file.Size, + TotalChunk: file.Chunk - 1, + Downloaded: 0, } err = app.Server.Database.CreateFile(&newFile) @@ -97,7 +123,6 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.File, error) app.Server.Logger.Error(err.Error()) return models.File{}, err } - return newFile, nil } diff --git a/handler/upload/upload.go b/handler/upload/upload.go index b57f2b9..97b5ade 100644 --- a/handler/upload/upload.go +++ b/handler/upload/upload.go @@ -1,8 +1,8 @@ package uploadHandler import ( + "fmt" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/types" filesView "github.com/fossyy/filekeeper/view/client/upload" "io" @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "strconv" + "strings" ) func GET(w http.ResponseWriter, r *http.Request) { @@ -38,40 +39,56 @@ func POST(w http.ResponseWriter, r *http.Request) { } } - file, err := cache.GetFile(fileID) + file, err := app.Server.Service.GetFile(fileID) if err != nil { app.Server.Logger.Error("error getting upload info: " + err.Error()) w.WriteHeader(http.StatusInternalServerError) return } - currentDir, _ := os.Getwd() - basePath := filepath.Join(currentDir, uploadDir) - saveFolder := filepath.Join(basePath, userSession.UserID.String(), file.ID.String()) - - if filepath.Dir(saveFolder) != filepath.Join(basePath, userSession.UserID.String()) { - app.Server.Logger.Error("invalid path") - w.WriteHeader(http.StatusInternalServerError) - return - } - - fileByte, fileHeader, err := r.FormFile("chunk") - if err != nil { - app.Server.Logger.Error("error getting upload info: " + err.Error()) - w.WriteHeader(http.StatusInternalServerError) - return - } - defer fileByte.Close() - rawIndex := r.FormValue("index") index, err := strconv.Atoi(rawIndex) if err != nil { return } - file.UpdateProgress(int64(index), file.UploadedByte+int64(fileHeader.Size)) + currentDir, err := os.Getwd() + if err != nil { + app.Server.Logger.Error("unable to get current directory") + w.WriteHeader(http.StatusInternalServerError) + return + } - dst, err := os.OpenFile(filepath.Join(saveFolder, file.Name), os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666) + basePath := filepath.Join(currentDir, uploadDir) + cleanBasePath := filepath.Clean(basePath) + + saveFolder := filepath.Join(cleanBasePath, userSession.UserID.String(), file.ID.String(), file.Name) + + cleanSaveFolder := filepath.Clean(saveFolder) + + if !strings.HasPrefix(cleanSaveFolder, cleanBasePath) { + app.Server.Logger.Error("invalid path") + w.WriteHeader(http.StatusInternalServerError) + return + } + + if _, err := os.Stat(saveFolder); os.IsNotExist(err) { + if err := os.MkdirAll(saveFolder, os.ModePerm); err != nil { + app.Server.Logger.Error("error creating save folder: " + err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + } + + fileByte, _, err := r.FormFile("chunk") + if err != nil { + app.Server.Logger.Error("error getting upload info: " + err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + defer fileByte.Close() + + dst, err := os.OpenFile(filepath.Join(saveFolder, fmt.Sprintf("chunk_%d", index)), os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0666) if err != nil { app.Server.Logger.Error("error making upload folder: " + err.Error()) w.WriteHeader(http.StatusInternalServerError) @@ -85,9 +102,5 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - if file.UploadedByte >= file.Size { - file.FinalizeFileUpload() - return - } return } diff --git a/handler/user/ResetPassword/ResetPassword.go b/handler/user/ResetPassword/ResetPassword.go index 0f0f934..5d76aa7 100644 --- a/handler/user/ResetPassword/ResetPassword.go +++ b/handler/user/ResetPassword/ResetPassword.go @@ -2,7 +2,6 @@ package userHandlerResetPassword import ( "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/utils" @@ -14,7 +13,7 @@ func POST(w http.ResponseWriter, r *http.Request) { userSession := r.Context().Value("user").(types.User) currentPassword := r.Form.Get("currentPassword") password := r.Form.Get("password") - user, err := cache.GetUser(userSession.Email) + user, err := app.Server.Service.GetUser(r.Context(), userSession.Email) if err != nil { w.WriteHeader(http.StatusInternalServerError) return @@ -38,7 +37,7 @@ func POST(w http.ResponseWriter, r *http.Request) { } session.RemoveAllSessions(userSession.Email) - cache.DeleteUser(userSession.Email) + app.Server.Service.DeleteUser(userSession.Email) http.Redirect(w, r, "/signin", http.StatusSeeOther) return diff --git a/handler/user/session/terminate/terminate.go b/handler/user/session/terminate/terminate.go index f1dd852..2c005cc 100644 --- a/handler/user/session/terminate/terminate.go +++ b/handler/user/session/terminate/terminate.go @@ -10,16 +10,22 @@ func DELETE(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") _, mySession, _ := session.GetSession(r) - otherSession, _ := session.Get(id) - if session.GetSessionInfo(mySession.Email, otherSession.ID) == nil { + otherSession := session.Get(id) + if _, err := session.GetSessionInfo(mySession.Email, otherSession.ID); err != nil { w.WriteHeader(http.StatusUnauthorized) return } + otherSession.Delete() session.RemoveSessionInfo(mySession.Email, otherSession.ID) - component := userView.SessionTable(session.GetSessions(mySession.Email)) + sessions, err := session.GetSessions(mySession.Email) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + component := userView.SessionTable(sessions) - err := component.Render(r.Context(), w) + err = component.Render(r.Context(), w) if err != nil { w.WriteHeader(http.StatusInternalServerError) return diff --git a/handler/user/totp/setup.go b/handler/user/totp/setup.go index 600e917..89b2817 100644 --- a/handler/user/totp/setup.go +++ b/handler/user/totp/setup.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "fmt" "github.com/fossyy/filekeeper/app" - "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/view/client/user/totp" "image/png" "net/http" @@ -75,7 +74,7 @@ func POST(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } - cache.DeleteUser(userSession.Email) + app.Server.Service.DeleteUser(userSession.Email) component := userTotpSetupView.Main("Filekeeper - 2FA Setup Page", base64Str, secret, userSession, types.Message{ Code: 1, Message: "Your TOTP setup is complete! Your account is now more secure.", diff --git a/handler/user/user.go b/handler/user/user.go index fc57ce2..463fcb5 100644 --- a/handler/user/user.go +++ b/handler/user/user.go @@ -17,24 +17,28 @@ var errorMessages = map[string]string{ func GET(w http.ResponseWriter, r *http.Request) { var component templ.Component userSession := r.Context().Value("user").(types.User) - + sessions, err := session.GetSessions(userSession.Email) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } if err := r.URL.Query().Get("error"); err != "" { message, ok := errorMessages[err] if !ok { message = "Unknown error occurred. Please contact support at bagas@fossy.my.id for assistance." } - component = userView.Main("Filekeeper - User Page", userSession, session.GetSessions(userSession.Email), types.Message{ + component = userView.Main("Filekeeper - User Page", userSession, sessions, types.Message{ Code: 0, Message: message, }) } else { - component = userView.Main("Filekeeper - User Page", userSession, session.GetSessions(userSession.Email), types.Message{ + component = userView.Main("Filekeeper - User Page", userSession, sessions, types.Message{ Code: 1, Message: "", }) } - err := component.Render(r.Context(), w) + err = component.Render(r.Context(), w) if err != nil { w.WriteHeader(http.StatusInternalServerError) app.Server.Logger.Error(err.Error()) diff --git a/main.go b/main.go index 9f5eb97..addb202 100644 --- a/main.go +++ b/main.go @@ -3,12 +3,14 @@ package main import ( "fmt" "github.com/fossyy/filekeeper/app" + "github.com/fossyy/filekeeper/cache" "github.com/fossyy/filekeeper/db" "github.com/fossyy/filekeeper/email" "github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/middleware" "github.com/fossyy/filekeeper/routes/admin" "github.com/fossyy/filekeeper/routes/client" + "github.com/fossyy/filekeeper/service" "github.com/fossyy/filekeeper/utils" "strconv" ) @@ -24,11 +26,13 @@ func main() { dbName := utils.Getenv("DB_NAME") database := db.NewPostgresDB(dbUser, dbPass, dbHost, dbPort, dbName, db.DisableSSL) + cacheServer := cache.NewRedisServer(database) + services := service.NewService(database, cacheServer) smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT")) mailServer := email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) - app.Server = app.NewClientServer(clientAddr, middleware.Handler(client.SetupRoutes()), *logger.Logger(), database, mailServer) + app.Server = app.NewClientServer(clientAddr, middleware.Handler(client.SetupRoutes()), *logger.Logger(), database, cacheServer, services, mailServer) app.Admin = app.NewAdminServer(adminAddr, middleware.Handler(admin.SetupRoutes()), database) go func() { diff --git a/middleware/middleware.go b/middleware/middleware.go index 6e6d35b..f6132b6 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -81,13 +81,10 @@ func Handler(next http.Handler) http.Handler { } func Auth(next http.HandlerFunc, w http.ResponseWriter, r *http.Request) { - status, user, sessionID := session.GetSession(r) + status, user, _ := session.GetSession(r) switch status { case session.Authorized: - userSession := session.GetSessionInfo(user.Email, sessionID) - userSession.UpdateAccessTime() - ctx := context.WithValue(r.Context(), "user", user) req := r.WithContext(ctx) r.Context().Value("user") diff --git a/public/upload.js b/public/upload.js index 94ae924..03e4241 100644 --- a/public/upload.js +++ b/public/upload.js @@ -32,7 +32,7 @@ async function handleFile(file){ if (responseData.Done === false) { addNewUploadElement(file) const fileChunks = await splitFile(file, chunkSize); - await uploadChunks(file.name,file.size, fileChunks, responseData.UploadedChunk, responseData.ID); + await uploadChunks(file.name,file.size, fileChunks, responseData.Chunk, responseData.ID); } else { alert("file already uploaded") } @@ -123,7 +123,7 @@ async function splitFile(file, chunkSize) { return fileChunks; } -async function uploadChunks(name, size, chunks, uploadedChunk= -1, FileID) { +async function uploadChunks(name, size, chunks, chunkArray, FileID) { let byteUploaded = 0 let progress1 = document.getElementById(`progress-${name}-1`); let progress2 = document.getElementById(`progress-${name}-2`); @@ -132,7 +132,7 @@ async function uploadChunks(name, size, chunks, uploadedChunk= -1, FileID) { for (let index = 0; index < chunks.length; index++) { const percentComplete = Math.round((index + 1) / chunks.length * 100); const chunk = chunks[index]; - if (!(index <= uploadedChunk)) { + if (!(chunkArray["chunk_"+index])) { const formData = new FormData(); formData.append('name', name); formData.append('chunk', chunk); @@ -152,13 +152,19 @@ async function uploadChunks(name, size, chunks, uploadedChunk= -1, FileID) { const totalTime = (endTime - startTime) / 1000; const uploadSpeed = chunk.size / totalTime / 1024 / 1024; byteUploaded += chunk.size - console.log(byteUploaded) progress3.innerText = `${uploadSpeed.toFixed(2)} MB/s`; progress4.innerText = `Uploading ${percentComplete}% - ${convertFileSize(byteUploaded)} of ${ convertFileSize(size)}`; } else { progress1.setAttribute("aria-valuenow", percentComplete); progress2.style.width = `${percentComplete}%`; + progress3.innerText = `Fixing Missing Byte`; + progress4.innerText = `Uploading Missing Byte ${percentComplete}% - ${convertFileSize(byteUploaded)} of ${ convertFileSize(size)}`; byteUploaded += chunk.size } } + console.log(chunks) + console.log(chunkArray) + + progress3.innerText = `Done`; + progress4.innerText = `File Uploaded 100% - ${convertFileSize(byteUploaded)} of ${ convertFileSize(size)}`; } \ No newline at end of file diff --git a/schema.sql b/schema.sql index d705a30..64f7359 100644 --- a/schema.sql +++ b/schema.sql @@ -1,2 +1,2 @@ -CREATE TABLE IF NOT EXISTS users (user_id VARCHAR(255) PRIMARY KEY NOT NULL,username VARCHAR(255) UNIQUE NOT NULL,email VARCHAR(255) UNIQUE NOT NULL,password TEXT NOT NULL,totp VARCHAR(255) DEFAULT NULL); -CREATE TABLE IF NOT EXISTS files (id VARCHAR(255) PRIMARY KEY NOT NULL,owner_id VARCHAR(255) NOT NULL,name TEXT NOT NULL,size BIGINT NOT NULL,downloaded BIGINT NOT NULL,uploaded_byte BIGINT NOT NULL DEFAULT 0, uploaded_chunk BIGINT NOT NULL DEFAULT -1,done BOOLEAN NOT NULL DEFAULT FALSE,FOREIGN KEY (owner_id) REFERENCES users(user_id)); \ No newline at end of file +CREATE TABLE IF NOT EXISTS users (user_id UUID PRIMARY KEY NOT NULL, username VARCHAR(255) UNIQUE NOT NULL, email VARCHAR(255) UNIQUE NOT NULL, password TEXT NOT NULL, totp VARCHAR(255) NOT NULL); +CREATE TABLE IF NOT EXISTS files (id UUID PRIMARY KEY NOT NULL, owner_id UUID NOT NULL, name TEXT NOT NULL, size BIGINT NOT NULL, total_chunk BIGINT NOT NULL, downloaded BIGINT NOT NULL DEFAULT 0, FOREIGN KEY (owner_id) REFERENCES users(user_id)); diff --git a/service/service.go b/service/service.go new file mode 100644 index 0000000..cde1412 --- /dev/null +++ b/service/service.go @@ -0,0 +1,110 @@ +package service + +import ( + "context" + "encoding/json" + "github.com/fossyy/filekeeper/app" + "github.com/fossyy/filekeeper/types" + "github.com/fossyy/filekeeper/types/models" + "github.com/redis/go-redis/v9" + "time" +) + +type Service struct { + db types.Database + cache types.CachingServer +} + +func NewService(db types.Database, cache types.CachingServer) *Service { + return &Service{ + db: db, + cache: cache, + } +} + +func (r *Service) GetUser(ctx context.Context, email string) (*models.User, error) { + userJSON, err := app.Server.Cache.GetCache(ctx, "UserCache:"+email) + if err == redis.Nil { + userData, err := r.db.GetUser(email) + if err != nil { + return nil, err + } + + user := &models.User{ + UserID: userData.UserID, + Username: userData.Username, + Email: userData.Email, + Password: userData.Password, + Totp: userData.Totp, + } + + newUserJSON, _ := json.Marshal(user) + err = r.cache.SetCache(ctx, email, newUserJSON, time.Hour*24) + if err != nil { + return nil, err + } + + return user, nil + } + if err != nil { + return nil, err + } + + var user models.User + err = json.Unmarshal([]byte(userJSON), &user) + if err != nil { + return nil, err + } + + return &user, nil +} + +func (r *Service) DeleteUser(email string) { + err := r.cache.DeleteCache(context.Background(), "UserCache:"+email) + if err != nil { + return + } +} + +func (r *Service) GetFile(id string) (*models.File, error) { + fileJSON, err := r.cache.GetCache(context.Background(), "FileCache:"+id) + if err == redis.Nil { + uploadData, err := r.db.GetFile(id) + if err != nil { + return nil, err + } + + newFileJSON, _ := json.Marshal(uploadData) + err = r.cache.SetCache(context.Background(), "FileCache:"+id, newFileJSON, time.Hour*24) + if err != nil { + return nil, err + } + return uploadData, nil + } + if err != nil { + return nil, err + } + + var fileCache models.File + err = json.Unmarshal([]byte(fileJSON), &fileCache) + if err != nil { + return nil, err + } + return &fileCache, nil +} + +func (r *Service) GetUserFile(name, ownerID string) (*types.FileWithDetail, error) { + fileData, err := r.db.GetUserFile(name, ownerID) + if err != nil { + return nil, err + } + + dada := &types.FileWithDetail{ + ID: fileData.ID, + OwnerID: fileData.OwnerID, + Name: fileData.Name, + Size: fileData.Size, + Downloaded: fileData.Downloaded, + } + return dada, nil +} diff --git a/session/session.go b/session/session.go index 61a7fba..0c94588 100644 --- a/session/session.go +++ b/session/session.go @@ -1,22 +1,22 @@ package session import ( - "fmt" + "context" + "encoding/json" + "errors" "github.com/fossyy/filekeeper/app" "github.com/fossyy/filekeeper/types" + "github.com/redis/go-redis/v9" "net/http" "strconv" - "sync" + "strings" "time" "github.com/fossyy/filekeeper/utils" ) type Session struct { - ID string - Values map[string]interface{} - CreateTime time.Time - mu sync.Mutex + ID string } type SessionInfo struct { @@ -27,7 +27,6 @@ type SessionInfo struct { OSVersion string IP string Location string - AccessAt string } type UserStatus string @@ -39,61 +38,48 @@ const ( InvalidSession UserStatus = "invalid_session" ) -var GlobalSessionStore = make(map[string]*Session) -var UserSessionInfoList = make(map[string]map[string]*SessionInfo) - -func init() { - - ticker := time.NewTicker(time.Minute) - go func() { - for { - <-ticker.C - currentTime := time.Now() - cacheClean := 0 - cleanID := utils.GenerateRandomString(10) - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [Session] [%s] initiated at %02d:%02d:%02d", cleanID, currentTime.Hour(), currentTime.Minute(), currentTime.Second())) - - for _, data := range GlobalSessionStore { - data.mu.Lock() - if currentTime.Sub(data.CreateTime) > time.Hour*24*7 { - RemoveSessionInfo(data.Values["user"].(types.User).Email, data.ID) - delete(GlobalSessionStore, data.ID) - cacheClean++ - } - data.mu.Unlock() - } - - app.Server.Logger.Info(fmt.Sprintf("Cache cleanup [Session] [%s] completed: %d entries removed. Finished at %s", cleanID, cacheClean, time.Since(currentTime))) - } - }() -} - func (e *SessionNotFoundError) Error() string { return "session not found" } -func Get(id string) (*Session, error) { - if session, ok := GlobalSessionStore[id]; ok { - return session, nil - } - return nil, &SessionNotFoundError{} +func Get(id string) *Session { + return &Session{ID: id} } -func Create() *Session { +func Create(values types.User) (*Session, error) { id := utils.GenerateRandomString(128) - session := &Session{ - ID: id, - CreateTime: time.Now(), - Values: make(map[string]interface{}), + + sessionData, err := json.Marshal(values) + if err != nil { + return nil, err } - GlobalSessionStore[id] = session - return session + + err = app.Server.Cache.SetCache(context.Background(), "Session:"+id, string(sessionData), time.Hour*24) + if err != nil { + return nil, err + } + + return &Session{ID: id}, nil } -func (s *Session) Delete() { - s.mu.Lock() - defer s.mu.Unlock() - delete(GlobalSessionStore, s.ID) +func (s *Session) Change(user types.User) error { + newSessionValue, err := json.Marshal(user) + if err != nil { + return err + } + err = app.Server.Cache.SetCache(context.Background(), "Session:"+s.ID, newSessionValue, time.Hour*24*7) + if err != nil { + return err + } + return nil +} + +func (s *Session) Delete() error { + err := app.Server.Cache.DeleteCache(context.Background(), "Session:"+s.ID) + if err != nil { + return err + } + return nil } func (s *Session) Save(w http.ResponseWriter) { @@ -114,46 +100,71 @@ func (s *Session) Destroy(w http.ResponseWriter) { }) } -func AddSessionInfo(email string, sessionInfo *SessionInfo) { - if _, ok := UserSessionInfoList[email]; !ok { - UserSessionInfoList[email] = make(map[string]*SessionInfo) +func AddSessionInfo(email string, sessionInfo *SessionInfo) error { + sessionInfoData, err := json.Marshal(sessionInfo) + if err != nil { + return err } - UserSessionInfoList[email][sessionInfo.SessionID] = sessionInfo -} - -func RemoveSessionInfo(email string, id string) { - if userSessions, ok := UserSessionInfoList[email]; ok { - if _, ok := userSessions[id]; ok { - delete(userSessions, id) - if len(userSessions) == 0 { - delete(UserSessionInfoList, email) - } - } + err = app.Server.Cache.SetCache(context.Background(), "UserSessionInfo:"+email+":"+sessionInfo.SessionID, string(sessionInfoData), 0) + if err != nil { + return err } + + return nil } -func RemoveAllSessions(email string) { - sessionInfos := UserSessionInfoList[email] - for _, sessionInfo := range sessionInfos { - delete(GlobalSessionStore, sessionInfo.SessionID) - } - delete(UserSessionInfoList, email) -} - -func GetSessionInfo(email string, id string) *SessionInfo { - if userSession, ok := UserSessionInfoList[email]; ok { - if sessionInfo, ok := userSession[id]; ok { - return sessionInfo - } +func RemoveSessionInfo(email string, id string) error { + key := "UserSessionInfo:" + email + ":" + id + err := app.Server.Cache.DeleteCache(context.Background(), key) + if err != nil { + return err } return nil } -func (sessionInfo *SessionInfo) UpdateAccessTime() { - currentTime := time.Now() - formattedTime := currentTime.Format("01-02-2006") - sessionInfo.AccessAt = formattedTime +func RemoveAllSessions(email string) error { + pattern := "UserSessionInfo:" + email + ":*" + keys, err := app.Server.Cache.GetKeys(context.Background(), pattern) + if err != nil { + return err + } + + for _, key := range keys { + parts := strings.Split(key, ":") + sessionID := parts[2] + + err = app.Server.Cache.DeleteCache(context.Background(), "Session:"+sessionID) + if err != nil { + return err + } + err = app.Server.Cache.DeleteCache(context.Background(), key) + if err != nil { + return err + } + } + + return nil +} + +func GetSessionInfo(email string, id string) (*SessionInfo, error) { + key := "UserSessionInfo:" + email + ":" + id + + sessionInfoData, err := app.Server.Cache.GetCache(context.Background(), key) + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, nil + } + return nil, err + } + + var sessionInfo SessionInfo + err = json.Unmarshal([]byte(sessionInfoData), &sessionInfo) + if err != nil { + return nil, err + } + + return &sessionInfo, nil } func GetSession(r *http.Request) (UserStatus, types.User, string) { @@ -162,36 +173,52 @@ func GetSession(r *http.Request) (UserStatus, types.User, string) { return Unauthorized, types.User{}, "" } - storeSession, ok := GlobalSessionStore[cookie.Value] - if !ok { - return InvalidSession, types.User{}, "" - } - - val := storeSession.Values["user"] - var userSession = types.User{} - userSession, ok = val.(types.User) - if !ok { - return Unauthorized, types.User{}, "" - } - - if !userSession.Authenticated && userSession.Totp != "" { - return Unauthorized, userSession, cookie.Value - } - - if !userSession.Authenticated { - return Unauthorized, types.User{}, "" - } - - return Authorized, userSession, cookie.Value -} - -func GetSessions(email string) []*SessionInfo { - if sessions, ok := UserSessionInfoList[email]; ok { - result := make([]*SessionInfo, 0, len(sessions)) - for _, sessionInfo := range sessions { - result = append(result, sessionInfo) + sessionData, err := app.Server.Cache.GetCache(context.Background(), "Session:"+cookie.Value) + if err != nil { + if errors.Is(err, redis.Nil) { + return InvalidSession, types.User{}, "" } - return result + return Unauthorized, types.User{}, "" } - return nil + + var storeSession types.User + err = json.Unmarshal([]byte(sessionData), &storeSession) + + if err != nil { + return Unauthorized, types.User{}, "" + } + + if !storeSession.Authenticated && storeSession.Totp != "" { + return Unauthorized, storeSession, cookie.Value + } + + if !storeSession.Authenticated { + return Unauthorized, types.User{}, "" + } + return Authorized, storeSession, cookie.Value +} + +func GetSessions(email string) ([]*SessionInfo, error) { + pattern := "UserSessionInfo:" + email + ":*" + keys, err := app.Server.Cache.GetKeys(context.Background(), pattern) + if err != nil { + return nil, err + } + + var sessions []*SessionInfo + for _, key := range keys { + sessionData, err := app.Server.Cache.GetCache(context.Background(), key) + if err != nil { + return nil, err + } + + var sessionInfo SessionInfo + err = json.Unmarshal([]byte(sessionData), &sessionInfo) + if err != nil { + return nil, err + } + sessions = append(sessions, &sessionInfo) + } + + return sessions, nil } diff --git a/types/models/models.go b/types/models/models.go index 804a953..d1dd8aa 100644 --- a/types/models/models.go +++ b/types/models/models.go @@ -11,12 +11,10 @@ type User struct { } type File struct { - ID uuid.UUID `gorm:"primaryKey;not null;unique"` - OwnerID uuid.UUID `gorm:"not null"` - Name string `gorm:"not null"` - Size int64 `gorm:"not null"` - Downloaded int64 `gorm:"not null;default=0"` - UploadedByte int64 `gorm:"not null;default=0"` - UploadedChunk int64 `gorm:"not null;default=0"` - Done bool `gorm:"not null;default=false"` + ID uuid.UUID `gorm:"primaryKey;not null;unique"` + OwnerID uuid.UUID `gorm:"not null"` + Name string `gorm:"not null"` + Size int64 `gorm:"not null"` + TotalChunk int64 `gorm:"not null"` + Downloaded int64 `gorm:"not null;default=0"` } diff --git a/types/types.go b/types/types.go index 7eda14f..b16ca03 100644 --- a/types/types.go +++ b/types/types.go @@ -1,7 +1,10 @@ package types import ( + "context" + "github.com/fossyy/filekeeper/types/models" "github.com/google/uuid" + "time" ) type Message struct { @@ -29,3 +32,44 @@ type FileData struct { Size string Downloaded int64 } + +type FileWithDetail struct { + ID uuid.UUID + OwnerID uuid.UUID + Name string + Size int64 + Downloaded int64 + Chunk map[string]bool + Done bool +} + +type Database interface { + IsUserRegistered(email string, username string) bool + IsEmailRegistered(email string) bool + + CreateUser(user *models.User) error + GetUser(email string) (*models.User, error) + GetAllUsers() ([]models.User, error) + UpdateUserPassword(email string, password string) error + + CreateFile(file *models.File) error + GetFile(fileID string) (*models.File, error) + GetUserFile(name string, ownerID string) (*models.File, error) + GetFiles(ownerID string) ([]*models.File, error) + + InitializeTotp(email string, secret string) error +} + +type CachingServer interface { + GetCache(ctx context.Context, key string) (string, error) + SetCache(ctx context.Context, key string, value interface{}, expiration time.Duration) error + DeleteCache(ctx context.Context, key string) error + GetKeys(ctx context.Context, pattern string) ([]string, error) +} + +type Services interface { + GetUser(ctx context.Context, email string) (*models.User, error) + DeleteUser(email string) + GetFile(id string) (*models.File, error) + GetUserFile(name, ownerID string) (*FileWithDetail, error) +} diff --git a/view/client/user/user.templ b/view/client/user/user.templ index f9e1417..fd1b597 100644 --- a/view/client/user/user.templ +++ b/view/client/user/user.templ @@ -99,10 +99,6 @@ templ content(message types.Message, title string, user types.User, ListSession class="h-12 px-4 text-left align-middle font-medium text-muted-foreground [&:has([role=checkbox])]:pr-0"> Device -