From b4f303463dccd8d493020751a776aae209843721 Mon Sep 17 00:00:00 2001 From: Bagas Aulia Rezki Date: Sun, 28 Apr 2024 20:49:41 +0700 Subject: [PATCH] Separate database initialization into NewMYSQLdb function --- db/database.go | 165 ++++++++++++++++-- db/model/user/user.go | 10 +- handler/download/download.go | 14 +- handler/download/file/file.go | 14 +- handler/forgotPassword/forgotPassword.go | 10 +- handler/forgotPassword/verify/verify.go | 8 +- handler/signup/signup.go | 44 +++-- handler/signup/verify/verify.go | 9 +- .../upload/initialisation/initialisation.go | 33 ++-- handler/upload/upload.go | 27 ++- 10 files changed, 242 insertions(+), 92 deletions(-) diff --git a/db/database.go b/db/database.go index 86510e1..5de4fea 100644 --- a/db/database.go +++ b/db/database.go @@ -1,39 +1,178 @@ package db import ( + "errors" "fmt" - "os" - "strings" - "github.com/fossyy/filekeeper/logger" - "github.com/fossyy/filekeeper/utils" + "github.com/fossyy/filekeeper/types/models" "gorm.io/driver/mysql" "gorm.io/gorm" gormLogger "gorm.io/gorm/logger" + "os" + "strings" ) +var log *logger.AggregatedLogger var DB *gorm.DB -var log *logger.AggregatedLogger +type mySQLdb struct { + *gorm.DB +} -func init() { +type Database interface { + IsUserRegistered(email string, username string) bool + + CreateUser(user *models.User) error + GetUser(email string) (*models.User, error) + UpdateUserPassword(email string, password string) error + + CreateFile(file *models.File) error + GetFile(fileID string) (*models.File, error) + GetUserFile(name string, ownerID string) (*models.File, error) + GetFiles(ownerID string) ([]*models.File, error) + + CreateUploadInfo(info models.FilesUploaded) error + GetUploadInfo(uploadID string) (*models.FilesUploaded, error) + UpdateUpdateIndex(index int, fileID string) + FinalizeFileUpload(fileID string) +} + +func NewMYSQLdb(username, password, host, port, dbName string) Database { var err error - connection := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) - DB, err = gorm.Open(mysql.Open(connection), &gorm.Config{}, &gorm.Config{ + connection := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", username, password, host, port, dbName) + DB, err = gorm.Open(mysql.New(mysql.Config{ + DSN: connection, + DefaultStringSize: 256, + DisableDatetimePrecision: true, + DontSupportRenameIndex: true, + DontSupportRenameColumn: true, + SkipInitializeWithVersion: false, + }), &gorm.Config{ Logger: gormLogger.Default.LogMode(gormLogger.Silent), }) + if err != nil { - panic("failed to connect database" + err.Error()) + panic("failed to connect database: " + err.Error()) } + file, err := os.ReadFile("schema.sql") if err != nil { - log.Error("Error opening file: %s", err.Error()) + panic("Error opening file: " + err.Error()) } - querys := strings.Split(string(file), "\n") - for _, query := range querys { + + queries := strings.Split(string(file), ";") + for _, query := range queries { + query = strings.TrimSpace(query) + if query == "" { + continue + } err := DB.Exec(query).Error if err != nil { - panic(err.Error()) + panic("Error executing query: " + err.Error()) } } + + return &mySQLdb{DB} +} + +func (db *mySQLdb) IsUserRegistered(email string, username string) bool { + var data models.User + err := db.DB.Table("users").Where("email = ? OR username = ?", email, username).First(&data).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return false + } + return true + } + return true +} + +func (db *mySQLdb) CreateUser(user *models.User) error { + err := db.DB.Create(user).Error + if err != nil { + return err + } + return nil +} + +func (db *mySQLdb) GetUser(email string) (*models.User, error) { + var user models.User + err := db.DB.Table("users").Where("email = ?", email).First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + +func (db *mySQLdb) UpdateUserPassword(email string, password string) error { + err := db.DB.Table("users").Where("email = ?", email).Update("password", password).Error + if err != nil { + return err + } + return nil +} + +func (db *mySQLdb) CreateFile(file *models.File) error { + err := db.DB.Create(file).Error + if err != nil { + return err + } + return nil +} + +func (db *mySQLdb) GetFile(fileID string) (*models.File, error) { + var file models.File + err := db.DB.Table("files").Where("id = ?", fileID).First(&file).Error + if err != nil { + return nil, err + } + return &file, nil +} + +func (db *mySQLdb) GetUserFile(name string, ownerID string) (*models.File, error) { + var file models.File + err := db.DB.Table("files").Where("name = ? AND owner_id = ?", name, ownerID).First(&file).Error + if err != nil { + return nil, err + } + return &file, nil +} + +func (db *mySQLdb) GetFiles(ownerID string) ([]*models.File, error) { + var files []*models.File + err := db.DB.Table("files").Where("owner_id = ?", ownerID).Find(&files).Error + if err != nil { + return nil, err + } + return files, err +} + +// CreateUploadInfo It's not optimal, but it's okay for now. Consider implementing caching instead of pushing all updates to the database for better performance in the future. +func (db *mySQLdb) CreateUploadInfo(info models.FilesUploaded) error { + err := db.DB.Create(info).Error + if err != nil { + return err + } + return nil +} + +func (db *mySQLdb) GetUploadInfo(fileID string) (*models.FilesUploaded, error) { + var info models.FilesUploaded + err := db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).First(&info).Error + if err != nil { + return nil, err + } + return &info, nil +} + +func (db *mySQLdb) UpdateUpdateIndex(index int, fileID string) { + db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ + "Uploaded": index, + }) +} + +func (db *mySQLdb) FinalizeFileUpload(fileID string) { + db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ + "Done": true, + }) } diff --git a/db/model/user/user.go b/db/model/user/user.go index 6d72a1f..367613a 100644 --- a/db/model/user/user.go +++ b/db/model/user/user.go @@ -2,6 +2,7 @@ package user import ( "fmt" + "github.com/fossyy/filekeeper/utils" "sync" "time" @@ -26,8 +27,12 @@ type UserWithExpired struct { var log *logger.AggregatedLogger var UserCache *Cache +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) UserCache = &Cache{users: make(map[string]*UserWithExpired)} ticker := time.NewTicker(time.Hour * 8) @@ -61,8 +66,7 @@ func Get(email string) (*UserWithExpired, error) { return user, nil } - var userData UserWithExpired - err := db.DB.Table("users").Where("email = ?", email).First(&userData).Error + userData, err := database.GetUser(email) if err != nil { return nil, err } @@ -75,7 +79,7 @@ func Get(email string) (*UserWithExpired, error) { AccessAt: time.Now(), } - return &userData, nil + return UserCache.users[email], nil } func DeleteCache(email string) { diff --git a/handler/download/download.go b/handler/download/download.go index f53ec36..f60c2a9 100644 --- a/handler/download/download.go +++ b/handler/download/download.go @@ -9,15 +9,19 @@ import ( "github.com/fossyy/filekeeper/middleware" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" - "github.com/fossyy/filekeeper/types/models" "github.com/fossyy/filekeeper/utils" downloadView "github.com/fossyy/filekeeper/view/download" ) var log *logger.AggregatedLogger +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func GET(w http.ResponseWriter, r *http.Request) { @@ -41,8 +45,12 @@ func GET(w http.ResponseWriter, r *http.Request) { } userSession := middleware.GetUser(storeSession) - var files []models.File - db.DB.Table("files").Where("owner_id = ?", userSession.UserID).Find(&files) + files, err := database.GetFiles(userSession.UserID.String()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + var filesData []types.FileData for i := 0; i < len(files); i++ { filesData = append(filesData, types.FileData{ diff --git a/handler/download/file/file.go b/handler/download/file/file.go index f258c44..8fb5021 100644 --- a/handler/download/file/file.go +++ b/handler/download/file/file.go @@ -1,29 +1,33 @@ package downloadFileHandler import ( + "github.com/fossyy/filekeeper/utils" "net/http" "os" "path/filepath" "github.com/fossyy/filekeeper/db" "github.com/fossyy/filekeeper/logger" - "github.com/fossyy/filekeeper/types/models" ) var log *logger.AggregatedLogger +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func GET(w http.ResponseWriter, r *http.Request) { fileID := r.PathValue("id") - - var file models.File - err := db.DB.Table("files").Where("id = ?", fileID).First(&file).Error + file, err := database.GetFile(fileID) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) + return } uploadDir := "uploads" @@ -42,6 +46,7 @@ func GET(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) + return } defer openFile.Close() @@ -49,6 +54,7 @@ func GET(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) + return } w.Header().Set("Content-Disposition", "attachment; filename="+stat.Name()) diff --git a/handler/forgotPassword/forgotPassword.go b/handler/forgotPassword/forgotPassword.go index f11c233..d6368af 100644 --- a/handler/forgotPassword/forgotPassword.go +++ b/handler/forgotPassword/forgotPassword.go @@ -33,12 +33,17 @@ var mailServer *email.SmtpServer var ListForgotPassword map[string]*ForgotPassword var UserForgotPassword = make(map[string]string) +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() ListForgotPassword = make(map[string]*ForgotPassword) smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT")) mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) ticker := time.NewTicker(time.Minute) + //TESTING + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) go func() { for { <-ticker.C @@ -84,8 +89,7 @@ func POST(w http.ResponseWriter, r *http.Request) { emailForm := r.Form.Get("email") - var user models.User - err = db.DB.Table("users").Where("email = ?", emailForm).First(&user).Error + user, err := database.GetUser(emailForm) if errors.Is(err, gorm.ErrRecordNotFound) { component := forgotPasswordView.Main(fmt.Sprintf("Account with this email address %s is not found", emailForm), types.Message{ Code: 0, @@ -100,7 +104,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - err = verifyForgot(&user) + err = verifyForgot(user) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) diff --git a/handler/forgotPassword/verify/verify.go b/handler/forgotPassword/verify/verify.go index e5b3a9c..0700564 100644 --- a/handler/forgotPassword/verify/verify.go +++ b/handler/forgotPassword/verify/verify.go @@ -16,8 +16,14 @@ import ( var log *logger.AggregatedLogger +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + //TESTING + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func GET(w http.ResponseWriter, r *http.Request) { @@ -84,7 +90,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - err = db.DB.Table("users").Where("email = ?", data.User.Email).Update("password", hashedPassword).Error + err = database.UpdateUserPassword(data.User.Email, hashedPassword) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) diff --git a/handler/signup/signup.go b/handler/signup/signup.go index ade664f..e489612 100644 --- a/handler/signup/signup.go +++ b/handler/signup/signup.go @@ -3,7 +3,6 @@ package signupHandler import ( "bytes" "context" - "errors" "fmt" "net/http" "strconv" @@ -19,7 +18,6 @@ import ( emailView "github.com/fossyy/filekeeper/view/email" signupView "github.com/fossyy/filekeeper/view/signup" "github.com/google/uuid" - "gorm.io/gorm" ) type UnverifiedUser struct { @@ -34,12 +32,16 @@ var mailServer *email.SmtpServer var VerifyUser map[string]*UnverifiedUser var VerifyEmail map[string]string +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT")) mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) VerifyUser = make(map[string]*UnverifiedUser) VerifyEmail = make(map[string]string) + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) ticker := time.NewTicker(time.Minute) go func() { @@ -110,34 +112,28 @@ func POST(w http.ResponseWriter, r *http.Request) { Password: hashedPassword, } - var data models.User - err = db.DB.Table("users").Where("email = ? OR username = ?", userEmail, username).First(&data).Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - err = verifyEmail(&newUser) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - log.Error(err.Error()) - return - } - - component := signupView.EmailSend("Sign up Page") - err = component.Render(r.Context(), w) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - log.Error(err.Error()) - return - } + if registered := database.IsUserRegistered(userEmail, username); registered { + component := signupView.Main("Sign up Page", types.Message{ + Code: 0, + Message: "Email or Username has been registered", + }) + err = component.Render(r.Context(), w) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + log.Error(err.Error()) return } + return + } + + err = verifyEmail(&newUser) + if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) return } - component := signupView.Main("Sign up Page", types.Message{ - Code: 0, - Message: "Email or Username has been registered", - }) + + component := signupView.EmailSend("Sign up Page") err = component.Render(r.Context(), w) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/handler/signup/verify/verify.go b/handler/signup/verify/verify.go index 01d0af2..007825f 100644 --- a/handler/signup/verify/verify.go +++ b/handler/signup/verify/verify.go @@ -1,6 +1,7 @@ package signupVerifyHandler import ( + "github.com/fossyy/filekeeper/utils" "net/http" "github.com/fossyy/filekeeper/db" @@ -12,8 +13,13 @@ import ( var log *logger.AggregatedLogger +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func GET(w http.ResponseWriter, r *http.Request) { @@ -25,8 +31,7 @@ func GET(w http.ResponseWriter, r *http.Request) { return } - err := db.DB.Create(&data.User).Error - + err := database.CreateUser(data.User) if err != nil { component := signupView.Main("Sign up Page", types.Message{ Code: 0, diff --git a/handler/upload/initialisation/initialisation.go b/handler/upload/initialisation/initialisation.go index 91d6ac6..d0f9ce8 100644 --- a/handler/upload/initialisation/initialisation.go +++ b/handler/upload/initialisation/initialisation.go @@ -3,6 +3,7 @@ package initialisation import ( "encoding/json" "errors" + "github.com/fossyy/filekeeper/utils" "io" "net/http" "os" @@ -20,8 +21,13 @@ import ( var log *logger.AggregatedLogger +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func POST(w http.ResponseWriter, r *http.Request) { @@ -53,7 +59,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - fileData, err := getFile(fileInfo.Name, userSession.UserID) + fileData, err := database.GetUserFile(fileInfo.Name, userSession.UserID.String()) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { upload, err := handleNewUpload(userSession, fileInfo) @@ -68,7 +74,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - info, err := GetUploadInfo(fileData.ID.String()) + info, err := database.GetUploadInfo(fileData.ID.String()) if err != nil { log.Error(err.Error()) return @@ -81,15 +87,6 @@ func POST(w http.ResponseWriter, r *http.Request) { respondJSON(w, info) } -func getFile(name string, ownerID uuid.UUID) (models.File, error) { - var data models.File - err := db.DB.Table("files").Where("name = ? AND owner_id = ?", name, ownerID).First(&data).Error - if err != nil { - return data, err - } - return data, nil -} - func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded, error) { uploadDir := "uploads" if _, err := os.Stat(uploadDir); os.IsNotExist(err) { @@ -124,7 +121,8 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded Size: file.Size, Downloaded: 0, } - err = db.DB.Create(&newFile).Error + + err = database.CreateFile(&newFile) if err != nil { log.Error(err.Error()) return models.FilesUploaded{}, err @@ -140,7 +138,7 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded Done: false, } - err = db.DB.Create(&filesUploaded).Error + err = database.CreateUploadInfo(filesUploaded) if err != nil { log.Error(err.Error()) return models.FilesUploaded{}, err @@ -148,15 +146,6 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded return filesUploaded, nil } -func GetUploadInfo(fileID string) (*models.FilesUploaded, error) { - var data *models.FilesUploaded - err := db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).First(&data).Error - if err != nil { - return data, err - } - return data, nil -} - func respondJSON(w http.ResponseWriter, data interface{}) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(data); err != nil { diff --git a/handler/upload/upload.go b/handler/upload/upload.go index 644adea..e1df7b2 100644 --- a/handler/upload/upload.go +++ b/handler/upload/upload.go @@ -2,6 +2,8 @@ package uploadHandler import ( "errors" + "github.com/fossyy/filekeeper/db" + "github.com/fossyy/filekeeper/utils" "io" "net/http" "os" @@ -9,8 +11,6 @@ import ( "strconv" "sync" - "github.com/fossyy/filekeeper/db" - "github.com/fossyy/filekeeper/handler/upload/initialisation" "github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/middleware" "github.com/fossyy/filekeeper/session" @@ -20,8 +20,13 @@ import ( var log *logger.AggregatedLogger var mu sync.Mutex +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func GET(w http.ResponseWriter, r *http.Request) { @@ -57,7 +62,7 @@ func POST(w http.ResponseWriter, r *http.Request) { userSession := middleware.GetUser(storeSession) if r.FormValue("done") == "true" { - finalizeFileUpload(fileID) + database.FinalizeFileUpload(fileID) return } @@ -67,7 +72,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - file, err := initialisation.GetUploadInfo(fileID) + file, err := database.GetUploadInfo(fileID) if err != nil { log.Error("error getting upload info: " + err.Error()) return @@ -105,13 +110,7 @@ func POST(w http.ResponseWriter, r *http.Request) { if err != nil { return } - updateIndex(index, fileID) -} - -func finalizeFileUpload(fileID string) { - db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ - "Done": true, - }) + database.UpdateUpdateIndex(index, fileID) } func createUploadDirectory(uploadDir string) error { @@ -123,12 +122,6 @@ func createUploadDirectory(uploadDir string) error { return nil } -func updateIndex(index int, fileID string) { - db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ - "Uploaded": index, - }) -} - func handleCookieError(w http.ResponseWriter, r *http.Request, err error) { if errors.Is(err, http.ErrNoCookie) { http.Redirect(w, r, "/signin", http.StatusSeeOther)