From b4f303463dccd8d493020751a776aae209843721 Mon Sep 17 00:00:00 2001 From: Bagas Aulia Rezki Date: Sun, 28 Apr 2024 20:49:41 +0700 Subject: [PATCH 1/3] Separate database initialization into NewMYSQLdb function --- db/database.go | 165 ++++++++++++++++-- db/model/user/user.go | 10 +- handler/download/download.go | 14 +- handler/download/file/file.go | 14 +- handler/forgotPassword/forgotPassword.go | 10 +- handler/forgotPassword/verify/verify.go | 8 +- handler/signup/signup.go | 44 +++-- handler/signup/verify/verify.go | 9 +- .../upload/initialisation/initialisation.go | 33 ++-- handler/upload/upload.go | 27 ++- 10 files changed, 242 insertions(+), 92 deletions(-) diff --git a/db/database.go b/db/database.go index 86510e1..5de4fea 100644 --- a/db/database.go +++ b/db/database.go @@ -1,39 +1,178 @@ package db import ( + "errors" "fmt" - "os" - "strings" - "github.com/fossyy/filekeeper/logger" - "github.com/fossyy/filekeeper/utils" + "github.com/fossyy/filekeeper/types/models" "gorm.io/driver/mysql" "gorm.io/gorm" gormLogger "gorm.io/gorm/logger" + "os" + "strings" ) +var log *logger.AggregatedLogger var DB *gorm.DB -var log *logger.AggregatedLogger +type mySQLdb struct { + *gorm.DB +} -func init() { +type Database interface { + IsUserRegistered(email string, username string) bool + + CreateUser(user *models.User) error + GetUser(email string) (*models.User, error) + UpdateUserPassword(email string, password string) error + + CreateFile(file *models.File) error + GetFile(fileID string) (*models.File, error) + GetUserFile(name string, ownerID string) (*models.File, error) + GetFiles(ownerID string) ([]*models.File, error) + + CreateUploadInfo(info models.FilesUploaded) error + GetUploadInfo(uploadID string) (*models.FilesUploaded, error) + UpdateUpdateIndex(index int, fileID string) + FinalizeFileUpload(fileID string) +} + +func NewMYSQLdb(username, password, host, port, dbName string) Database { var err error - connection := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) - DB, err = gorm.Open(mysql.Open(connection), &gorm.Config{}, &gorm.Config{ + connection := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", username, password, host, port, dbName) + DB, err = gorm.Open(mysql.New(mysql.Config{ + DSN: connection, + DefaultStringSize: 256, + DisableDatetimePrecision: true, + DontSupportRenameIndex: true, + DontSupportRenameColumn: true, + SkipInitializeWithVersion: false, + }), &gorm.Config{ Logger: gormLogger.Default.LogMode(gormLogger.Silent), }) + if err != nil { - panic("failed to connect database" + err.Error()) + panic("failed to connect database: " + err.Error()) } + file, err := os.ReadFile("schema.sql") if err != nil { - log.Error("Error opening file: %s", err.Error()) + panic("Error opening file: " + err.Error()) } - querys := strings.Split(string(file), "\n") - for _, query := range querys { + + queries := strings.Split(string(file), ";") + for _, query := range queries { + query = strings.TrimSpace(query) + if query == "" { + continue + } err := DB.Exec(query).Error if err != nil { - panic(err.Error()) + panic("Error executing query: " + err.Error()) } } + + return &mySQLdb{DB} +} + +func (db *mySQLdb) IsUserRegistered(email string, username string) bool { + var data models.User + err := db.DB.Table("users").Where("email = ? OR username = ?", email, username).First(&data).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return false + } + return true + } + return true +} + +func (db *mySQLdb) CreateUser(user *models.User) error { + err := db.DB.Create(user).Error + if err != nil { + return err + } + return nil +} + +func (db *mySQLdb) GetUser(email string) (*models.User, error) { + var user models.User + err := db.DB.Table("users").Where("email = ?", email).First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + +func (db *mySQLdb) UpdateUserPassword(email string, password string) error { + err := db.DB.Table("users").Where("email = ?", email).Update("password", password).Error + if err != nil { + return err + } + return nil +} + +func (db *mySQLdb) CreateFile(file *models.File) error { + err := db.DB.Create(file).Error + if err != nil { + return err + } + return nil +} + +func (db *mySQLdb) GetFile(fileID string) (*models.File, error) { + var file models.File + err := db.DB.Table("files").Where("id = ?", fileID).First(&file).Error + if err != nil { + return nil, err + } + return &file, nil +} + +func (db *mySQLdb) GetUserFile(name string, ownerID string) (*models.File, error) { + var file models.File + err := db.DB.Table("files").Where("name = ? AND owner_id = ?", name, ownerID).First(&file).Error + if err != nil { + return nil, err + } + return &file, nil +} + +func (db *mySQLdb) GetFiles(ownerID string) ([]*models.File, error) { + var files []*models.File + err := db.DB.Table("files").Where("owner_id = ?", ownerID).Find(&files).Error + if err != nil { + return nil, err + } + return files, err +} + +// CreateUploadInfo It's not optimal, but it's okay for now. Consider implementing caching instead of pushing all updates to the database for better performance in the future. +func (db *mySQLdb) CreateUploadInfo(info models.FilesUploaded) error { + err := db.DB.Create(info).Error + if err != nil { + return err + } + return nil +} + +func (db *mySQLdb) GetUploadInfo(fileID string) (*models.FilesUploaded, error) { + var info models.FilesUploaded + err := db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).First(&info).Error + if err != nil { + return nil, err + } + return &info, nil +} + +func (db *mySQLdb) UpdateUpdateIndex(index int, fileID string) { + db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ + "Uploaded": index, + }) +} + +func (db *mySQLdb) FinalizeFileUpload(fileID string) { + db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ + "Done": true, + }) } diff --git a/db/model/user/user.go b/db/model/user/user.go index 6d72a1f..367613a 100644 --- a/db/model/user/user.go +++ b/db/model/user/user.go @@ -2,6 +2,7 @@ package user import ( "fmt" + "github.com/fossyy/filekeeper/utils" "sync" "time" @@ -26,8 +27,12 @@ type UserWithExpired struct { var log *logger.AggregatedLogger var UserCache *Cache +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) UserCache = &Cache{users: make(map[string]*UserWithExpired)} ticker := time.NewTicker(time.Hour * 8) @@ -61,8 +66,7 @@ func Get(email string) (*UserWithExpired, error) { return user, nil } - var userData UserWithExpired - err := db.DB.Table("users").Where("email = ?", email).First(&userData).Error + userData, err := database.GetUser(email) if err != nil { return nil, err } @@ -75,7 +79,7 @@ func Get(email string) (*UserWithExpired, error) { AccessAt: time.Now(), } - return &userData, nil + return UserCache.users[email], nil } func DeleteCache(email string) { diff --git a/handler/download/download.go b/handler/download/download.go index f53ec36..f60c2a9 100644 --- a/handler/download/download.go +++ b/handler/download/download.go @@ -9,15 +9,19 @@ import ( "github.com/fossyy/filekeeper/middleware" "github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/types" - "github.com/fossyy/filekeeper/types/models" "github.com/fossyy/filekeeper/utils" downloadView "github.com/fossyy/filekeeper/view/download" ) var log *logger.AggregatedLogger +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func GET(w http.ResponseWriter, r *http.Request) { @@ -41,8 +45,12 @@ func GET(w http.ResponseWriter, r *http.Request) { } userSession := middleware.GetUser(storeSession) - var files []models.File - db.DB.Table("files").Where("owner_id = ?", userSession.UserID).Find(&files) + files, err := database.GetFiles(userSession.UserID.String()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + var filesData []types.FileData for i := 0; i < len(files); i++ { filesData = append(filesData, types.FileData{ diff --git a/handler/download/file/file.go b/handler/download/file/file.go index f258c44..8fb5021 100644 --- a/handler/download/file/file.go +++ b/handler/download/file/file.go @@ -1,29 +1,33 @@ package downloadFileHandler import ( + "github.com/fossyy/filekeeper/utils" "net/http" "os" "path/filepath" "github.com/fossyy/filekeeper/db" "github.com/fossyy/filekeeper/logger" - "github.com/fossyy/filekeeper/types/models" ) var log *logger.AggregatedLogger +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func GET(w http.ResponseWriter, r *http.Request) { fileID := r.PathValue("id") - - var file models.File - err := db.DB.Table("files").Where("id = ?", fileID).First(&file).Error + file, err := database.GetFile(fileID) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) + return } uploadDir := "uploads" @@ -42,6 +46,7 @@ func GET(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) + return } defer openFile.Close() @@ -49,6 +54,7 @@ func GET(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) + return } w.Header().Set("Content-Disposition", "attachment; filename="+stat.Name()) diff --git a/handler/forgotPassword/forgotPassword.go b/handler/forgotPassword/forgotPassword.go index f11c233..d6368af 100644 --- a/handler/forgotPassword/forgotPassword.go +++ b/handler/forgotPassword/forgotPassword.go @@ -33,12 +33,17 @@ var mailServer *email.SmtpServer var ListForgotPassword map[string]*ForgotPassword var UserForgotPassword = make(map[string]string) +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() ListForgotPassword = make(map[string]*ForgotPassword) smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT")) mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) ticker := time.NewTicker(time.Minute) + //TESTING + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) go func() { for { <-ticker.C @@ -84,8 +89,7 @@ func POST(w http.ResponseWriter, r *http.Request) { emailForm := r.Form.Get("email") - var user models.User - err = db.DB.Table("users").Where("email = ?", emailForm).First(&user).Error + user, err := database.GetUser(emailForm) if errors.Is(err, gorm.ErrRecordNotFound) { component := forgotPasswordView.Main(fmt.Sprintf("Account with this email address %s is not found", emailForm), types.Message{ Code: 0, @@ -100,7 +104,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - err = verifyForgot(&user) + err = verifyForgot(user) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) diff --git a/handler/forgotPassword/verify/verify.go b/handler/forgotPassword/verify/verify.go index e5b3a9c..0700564 100644 --- a/handler/forgotPassword/verify/verify.go +++ b/handler/forgotPassword/verify/verify.go @@ -16,8 +16,14 @@ import ( var log *logger.AggregatedLogger +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + //TESTING + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func GET(w http.ResponseWriter, r *http.Request) { @@ -84,7 +90,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - err = db.DB.Table("users").Where("email = ?", data.User.Email).Update("password", hashedPassword).Error + err = database.UpdateUserPassword(data.User.Email, hashedPassword) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) diff --git a/handler/signup/signup.go b/handler/signup/signup.go index ade664f..e489612 100644 --- a/handler/signup/signup.go +++ b/handler/signup/signup.go @@ -3,7 +3,6 @@ package signupHandler import ( "bytes" "context" - "errors" "fmt" "net/http" "strconv" @@ -19,7 +18,6 @@ import ( emailView "github.com/fossyy/filekeeper/view/email" signupView "github.com/fossyy/filekeeper/view/signup" "github.com/google/uuid" - "gorm.io/gorm" ) type UnverifiedUser struct { @@ -34,12 +32,16 @@ var mailServer *email.SmtpServer var VerifyUser map[string]*UnverifiedUser var VerifyEmail map[string]string +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT")) mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) VerifyUser = make(map[string]*UnverifiedUser) VerifyEmail = make(map[string]string) + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) ticker := time.NewTicker(time.Minute) go func() { @@ -110,34 +112,28 @@ func POST(w http.ResponseWriter, r *http.Request) { Password: hashedPassword, } - var data models.User - err = db.DB.Table("users").Where("email = ? OR username = ?", userEmail, username).First(&data).Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - err = verifyEmail(&newUser) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - log.Error(err.Error()) - return - } - - component := signupView.EmailSend("Sign up Page") - err = component.Render(r.Context(), w) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - log.Error(err.Error()) - return - } + if registered := database.IsUserRegistered(userEmail, username); registered { + component := signupView.Main("Sign up Page", types.Message{ + Code: 0, + Message: "Email or Username has been registered", + }) + err = component.Render(r.Context(), w) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + log.Error(err.Error()) return } + return + } + + err = verifyEmail(&newUser) + if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) return } - component := signupView.Main("Sign up Page", types.Message{ - Code: 0, - Message: "Email or Username has been registered", - }) + + component := signupView.EmailSend("Sign up Page") err = component.Render(r.Context(), w) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) diff --git a/handler/signup/verify/verify.go b/handler/signup/verify/verify.go index 01d0af2..007825f 100644 --- a/handler/signup/verify/verify.go +++ b/handler/signup/verify/verify.go @@ -1,6 +1,7 @@ package signupVerifyHandler import ( + "github.com/fossyy/filekeeper/utils" "net/http" "github.com/fossyy/filekeeper/db" @@ -12,8 +13,13 @@ import ( var log *logger.AggregatedLogger +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func GET(w http.ResponseWriter, r *http.Request) { @@ -25,8 +31,7 @@ func GET(w http.ResponseWriter, r *http.Request) { return } - err := db.DB.Create(&data.User).Error - + err := database.CreateUser(data.User) if err != nil { component := signupView.Main("Sign up Page", types.Message{ Code: 0, diff --git a/handler/upload/initialisation/initialisation.go b/handler/upload/initialisation/initialisation.go index 91d6ac6..d0f9ce8 100644 --- a/handler/upload/initialisation/initialisation.go +++ b/handler/upload/initialisation/initialisation.go @@ -3,6 +3,7 @@ package initialisation import ( "encoding/json" "errors" + "github.com/fossyy/filekeeper/utils" "io" "net/http" "os" @@ -20,8 +21,13 @@ import ( var log *logger.AggregatedLogger +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func POST(w http.ResponseWriter, r *http.Request) { @@ -53,7 +59,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - fileData, err := getFile(fileInfo.Name, userSession.UserID) + fileData, err := database.GetUserFile(fileInfo.Name, userSession.UserID.String()) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { upload, err := handleNewUpload(userSession, fileInfo) @@ -68,7 +74,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - info, err := GetUploadInfo(fileData.ID.String()) + info, err := database.GetUploadInfo(fileData.ID.String()) if err != nil { log.Error(err.Error()) return @@ -81,15 +87,6 @@ func POST(w http.ResponseWriter, r *http.Request) { respondJSON(w, info) } -func getFile(name string, ownerID uuid.UUID) (models.File, error) { - var data models.File - err := db.DB.Table("files").Where("name = ? AND owner_id = ?", name, ownerID).First(&data).Error - if err != nil { - return data, err - } - return data, nil -} - func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded, error) { uploadDir := "uploads" if _, err := os.Stat(uploadDir); os.IsNotExist(err) { @@ -124,7 +121,8 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded Size: file.Size, Downloaded: 0, } - err = db.DB.Create(&newFile).Error + + err = database.CreateFile(&newFile) if err != nil { log.Error(err.Error()) return models.FilesUploaded{}, err @@ -140,7 +138,7 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded Done: false, } - err = db.DB.Create(&filesUploaded).Error + err = database.CreateUploadInfo(filesUploaded) if err != nil { log.Error(err.Error()) return models.FilesUploaded{}, err @@ -148,15 +146,6 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded return filesUploaded, nil } -func GetUploadInfo(fileID string) (*models.FilesUploaded, error) { - var data *models.FilesUploaded - err := db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).First(&data).Error - if err != nil { - return data, err - } - return data, nil -} - func respondJSON(w http.ResponseWriter, data interface{}) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(data); err != nil { diff --git a/handler/upload/upload.go b/handler/upload/upload.go index 644adea..e1df7b2 100644 --- a/handler/upload/upload.go +++ b/handler/upload/upload.go @@ -2,6 +2,8 @@ package uploadHandler import ( "errors" + "github.com/fossyy/filekeeper/db" + "github.com/fossyy/filekeeper/utils" "io" "net/http" "os" @@ -9,8 +11,6 @@ import ( "strconv" "sync" - "github.com/fossyy/filekeeper/db" - "github.com/fossyy/filekeeper/handler/upload/initialisation" "github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/middleware" "github.com/fossyy/filekeeper/session" @@ -20,8 +20,13 @@ import ( var log *logger.AggregatedLogger var mu sync.Mutex +// TESTTING VAR +var database db.Database + func init() { log = logger.Logger() + database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + } func GET(w http.ResponseWriter, r *http.Request) { @@ -57,7 +62,7 @@ func POST(w http.ResponseWriter, r *http.Request) { userSession := middleware.GetUser(storeSession) if r.FormValue("done") == "true" { - finalizeFileUpload(fileID) + database.FinalizeFileUpload(fileID) return } @@ -67,7 +72,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - file, err := initialisation.GetUploadInfo(fileID) + file, err := database.GetUploadInfo(fileID) if err != nil { log.Error("error getting upload info: " + err.Error()) return @@ -105,13 +110,7 @@ func POST(w http.ResponseWriter, r *http.Request) { if err != nil { return } - updateIndex(index, fileID) -} - -func finalizeFileUpload(fileID string) { - db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ - "Done": true, - }) + database.UpdateUpdateIndex(index, fileID) } func createUploadDirectory(uploadDir string) error { @@ -123,12 +122,6 @@ func createUploadDirectory(uploadDir string) error { return nil } -func updateIndex(index int, fileID string) { - db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ - "Uploaded": index, - }) -} - func handleCookieError(w http.ResponseWriter, r *http.Request, err error) { if errors.Is(err, http.ErrNoCookie) { http.Redirect(w, r, "/signin", http.StatusSeeOther) From 5ca6cdd1293008184898bba280aa235cb7cb00b9 Mon Sep 17 00:00:00 2001 From: Bagas Aulia Rezki Date: Sun, 28 Apr 2024 21:52:33 +0700 Subject: [PATCH 2/3] Add PostgreSQL as a supported database --- db/database.go | 148 +++++++++++++++++- db/model/user/user.go | 2 +- go.mod | 5 + go.sum | 24 +++ handler/download/download.go | 4 +- handler/download/file/file.go | 2 +- handler/forgotPassword/forgotPassword.go | 2 +- handler/forgotPassword/verify/verify.go | 2 +- handler/signup/signup.go | 2 +- handler/signup/verify/verify.go | 2 +- .../upload/initialisation/initialisation.go | 2 +- handler/upload/upload.go | 2 +- 12 files changed, 183 insertions(+), 14 deletions(-) diff --git a/db/database.go b/db/database.go index 5de4fea..7eb91ce 100644 --- a/db/database.go +++ b/db/database.go @@ -6,6 +6,7 @@ import ( "github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/types/models" "gorm.io/driver/mysql" + "gorm.io/driver/postgres" "gorm.io/gorm" gormLogger "gorm.io/gorm/logger" "os" @@ -13,12 +14,15 @@ import ( ) var log *logger.AggregatedLogger -var DB *gorm.DB type mySQLdb struct { *gorm.DB } +type postgresDB struct { + *gorm.DB +} + type Database interface { IsUserRegistered(email string, username string) bool @@ -40,7 +44,7 @@ type Database interface { func NewMYSQLdb(username, password, host, port, dbName string) Database { var err error connection := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", username, password, host, port, dbName) - DB, err = gorm.Open(mysql.New(mysql.Config{ + DB, err := gorm.Open(mysql.New(mysql.Config{ DSN: connection, DefaultStringSize: 256, DisableDatetimePrecision: true, @@ -75,6 +79,39 @@ func NewMYSQLdb(username, password, host, port, dbName string) Database { return &mySQLdb{DB} } +func NewPostgresDB(username, password, host, port, dbName string) Database { + var err error + connection := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=Asia/Jakarta", host, username, password, dbName, port, "disable") + DB, err := gorm.Open(postgres.New(postgres.Config{ + DSN: connection, + }), &gorm.Config{ + Logger: gormLogger.Default.LogMode(gormLogger.Silent), + }) + + if err != nil { + panic("failed to connect database: " + err.Error()) + } + + file, err := os.ReadFile("schema.sql") + if err != nil { + panic("Error opening file: " + err.Error()) + } + + queries := strings.Split(string(file), ";") + for _, query := range queries { + query = strings.TrimSpace(query) + if query == "" { + continue + } + err := DB.Exec(query).Error + if err != nil { + panic("Error executing query: " + err.Error()) + } + } + + return &postgresDB{DB} +} + func (db *mySQLdb) IsUserRegistered(email string, username string) bool { var data models.User err := db.DB.Table("users").Where("email = ? OR username = ?", email, username).First(&data).Error @@ -166,13 +203,116 @@ func (db *mySQLdb) GetUploadInfo(fileID string) (*models.FilesUploaded, error) { } func (db *mySQLdb) UpdateUpdateIndex(index int, fileID string) { - db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ + db.DB.Table("files_uploadeds").Where("file_id = $1", fileID).Updates(map[string]interface{}{ "Uploaded": index, }) } func (db *mySQLdb) FinalizeFileUpload(fileID string) { - db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ + db.DB.Table("files_uploadeds").Where("file_id = $1", fileID).Updates(map[string]interface{}{ + "Done": true, + }) +} + +// POSTGRES FUNCTION +func (db *postgresDB) IsUserRegistered(email string, username string) bool { + var data models.User + err := db.DB.Table("users").Where("email = $1 OR username = $2", email, username).First(&data).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return false + } + return true + } + return true +} + +func (db *postgresDB) CreateUser(user *models.User) error { + err := db.DB.Create(user).Error + if err != nil { + return err + } + return nil +} + +func (db *postgresDB) GetUser(email string) (*models.User, error) { + var user models.User + err := db.DB.Table("users").Where("email = $1", email).First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + +func (db *postgresDB) UpdateUserPassword(email string, password string) error { + err := db.DB.Table("users").Where("email = $1", email).Update("password", password).Error + if err != nil { + return err + } + return nil +} + +func (db *postgresDB) CreateFile(file *models.File) error { + err := db.DB.Create(file).Error + if err != nil { + return err + } + return nil +} + +func (db *postgresDB) GetFile(fileID string) (*models.File, error) { + var file models.File + err := db.DB.Table("files").Where("id = $1", fileID).First(&file).Error + if err != nil { + return nil, err + } + return &file, nil +} + +func (db *postgresDB) GetUserFile(name string, ownerID string) (*models.File, error) { + var file models.File + err := db.DB.Table("files").Where("name = $1 AND owner_id = $2", name, ownerID).First(&file).Error + if err != nil { + return nil, err + } + return &file, nil +} + +func (db *postgresDB) GetFiles(ownerID string) ([]*models.File, error) { + var files []*models.File + err := db.DB.Table("files").Where("owner_id = $1", ownerID).Find(&files).Error + if err != nil { + return nil, err + } + return files, err +} + +// CreateUploadInfo It's not optimal, but it's okay for now. Consider implementing caching instead of pushing all updates to the database for better performance in the future. +func (db *postgresDB) CreateUploadInfo(info models.FilesUploaded) error { + err := db.DB.Create(info).Error + if err != nil { + return err + } + return nil +} + +func (db *postgresDB) GetUploadInfo(fileID string) (*models.FilesUploaded, error) { + var info models.FilesUploaded + err := db.DB.Table("files_uploadeds").Where("file_id = $1", fileID).First(&info).Error + if err != nil { + return nil, err + } + return &info, nil +} + +func (db *postgresDB) UpdateUpdateIndex(index int, fileID string) { + db.DB.Table("files_uploadeds").Where("file_id = $1", fileID).Updates(map[string]interface{}{ + "Uploaded": index, + }) +} + +func (db *postgresDB) FinalizeFileUpload(fileID string) { + db.DB.Table("files_uploadeds").Where("file_id = $1", fileID).Updates(map[string]interface{}{ "Done": true, }) } diff --git a/db/model/user/user.go b/db/model/user/user.go index 367613a..571d003 100644 --- a/db/model/user/user.go +++ b/db/model/user/user.go @@ -32,7 +32,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) UserCache = &Cache{users: make(map[string]*UserWithExpired)} ticker := time.NewTicker(time.Hour * 8) diff --git a/go.mod b/go.mod index 9564b2c..683d919 100644 --- a/go.mod +++ b/go.mod @@ -9,12 +9,17 @@ require ( golang.org/x/crypto v0.21.0 gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df gorm.io/driver/mysql v1.5.6 + gorm.io/driver/postgres v1.5.7 gorm.io/gorm v1.25.8 ) require ( github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + golang.org/x/text v0.14.0 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect ) diff --git a/go.sum b/go.sum index be21266..3f53935 100644 --- a/go.sum +++ b/go.sum @@ -1,25 +1,49 @@ github.com/a-h/templ v0.2.648 h1:A1ggHGIE7AONOHrFaDTM8SrqgqHL6fWgWCijQ21Zy9I= github.com/a-h/templ v0.2.648/go.mod h1:SA7mtYwVEajbIXFRh3vKdYm/4FYyLQAtPH1+KxzGPA8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= +gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.8 h1:WAGEZ/aEcznN4D03laj8DKnehe1e9gYQAjW8xyPRdeo= gorm.io/gorm v1.25.8/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/handler/download/download.go b/handler/download/download.go index f60c2a9..77ca38f 100644 --- a/handler/download/download.go +++ b/handler/download/download.go @@ -20,7 +20,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } @@ -50,7 +50,7 @@ func GET(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - + var filesData []types.FileData for i := 0; i < len(files); i++ { filesData = append(filesData, types.FileData{ diff --git a/handler/download/file/file.go b/handler/download/file/file.go index 8fb5021..c532d89 100644 --- a/handler/download/file/file.go +++ b/handler/download/file/file.go @@ -17,7 +17,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } diff --git a/handler/forgotPassword/forgotPassword.go b/handler/forgotPassword/forgotPassword.go index d6368af..4b16ee7 100644 --- a/handler/forgotPassword/forgotPassword.go +++ b/handler/forgotPassword/forgotPassword.go @@ -43,7 +43,7 @@ func init() { mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) ticker := time.NewTicker(time.Minute) //TESTING - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) go func() { for { <-ticker.C diff --git a/handler/forgotPassword/verify/verify.go b/handler/forgotPassword/verify/verify.go index 0700564..4265d05 100644 --- a/handler/forgotPassword/verify/verify.go +++ b/handler/forgotPassword/verify/verify.go @@ -22,7 +22,7 @@ var database db.Database func init() { log = logger.Logger() //TESTING - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } diff --git a/handler/signup/signup.go b/handler/signup/signup.go index e489612..58f4bca 100644 --- a/handler/signup/signup.go +++ b/handler/signup/signup.go @@ -41,7 +41,7 @@ func init() { mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) VerifyUser = make(map[string]*UnverifiedUser) VerifyEmail = make(map[string]string) - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) ticker := time.NewTicker(time.Minute) go func() { diff --git a/handler/signup/verify/verify.go b/handler/signup/verify/verify.go index 007825f..afcf477 100644 --- a/handler/signup/verify/verify.go +++ b/handler/signup/verify/verify.go @@ -18,7 +18,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } diff --git a/handler/upload/initialisation/initialisation.go b/handler/upload/initialisation/initialisation.go index d0f9ce8..a4e46a5 100644 --- a/handler/upload/initialisation/initialisation.go +++ b/handler/upload/initialisation/initialisation.go @@ -26,7 +26,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } diff --git a/handler/upload/upload.go b/handler/upload/upload.go index e1df7b2..3a58894 100644 --- a/handler/upload/upload.go +++ b/handler/upload/upload.go @@ -25,7 +25,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } From e6344743d55ac0140a4a718cdcb7f98bc7cfa14f Mon Sep 17 00:00:00 2001 From: Bagas Aulia Rezki Date: Sun, 28 Apr 2024 22:26:53 +0700 Subject: [PATCH 3/3] Separate run command from main function --- app/app.go | 48 +++++++++++++++++++ db/database.go | 12 ++++- db/model/user/user.go | 7 +-- handler/download/download.go | 7 +-- handler/download/file/file.go | 8 +--- handler/forgotPassword/forgotPassword.go | 6 +-- handler/forgotPassword/verify/verify.go | 6 +-- handler/signup/signup.go | 6 +-- handler/signup/verify/verify.go | 8 +--- .../upload/initialisation/initialisation.go | 14 ++---- handler/upload/upload.go | 12 ++--- main.go | 19 +------- 12 files changed, 74 insertions(+), 79 deletions(-) create mode 100644 app/app.go diff --git a/app/app.go b/app/app.go new file mode 100644 index 0000000..a38c406 --- /dev/null +++ b/app/app.go @@ -0,0 +1,48 @@ +package app + +import ( + "fmt" + "github.com/fossyy/filekeeper/db" + "github.com/fossyy/filekeeper/middleware" + "github.com/fossyy/filekeeper/routes" + "github.com/fossyy/filekeeper/utils" + "net/http" +) + +type App struct { + http.Server + DB db.Database +} + +var Server App + +func NewServer(addr string, handler http.Handler, database db.Database) App { + return App{ + Server: http.Server{ + Addr: addr, + Handler: handler, + }, + DB: database, + } +} + +func Start() { + serverAddr := fmt.Sprintf("%s:%s", utils.Getenv("SERVER_HOST"), utils.Getenv("SERVER_PORT")) + + dbUser := utils.Getenv("DB_USERNAME") + dbPass := utils.Getenv("DB_PASSWORD") + dbHost := utils.Getenv("DB_HOST") + dbPort := utils.Getenv("DB_PORT") + dbName := utils.Getenv("DB_NAME") + + database := db.NewPostgresDB(dbUser, dbPass, dbHost, dbPort, dbName, db.DisableSSL) + db.DB = database + + Server = NewServer(serverAddr, middleware.Handler(routes.SetupRoutes()), database) + fmt.Printf("Listening on http://%s\n", Server.Addr) + err := Server.ListenAndServe() + if err != nil { + panic(err) + return + } +} diff --git a/db/database.go b/db/database.go index 7eb91ce..8c04e42 100644 --- a/db/database.go +++ b/db/database.go @@ -14,6 +14,7 @@ import ( ) var log *logger.AggregatedLogger +var DB Database type mySQLdb struct { *gorm.DB @@ -23,6 +24,13 @@ type postgresDB struct { *gorm.DB } +type SSLMode string + +const ( + DisableSSL SSLMode = "disable" + EnableSSL SSLMode = "enable" +) + type Database interface { IsUserRegistered(email string, username string) bool @@ -79,9 +87,9 @@ func NewMYSQLdb(username, password, host, port, dbName string) Database { return &mySQLdb{DB} } -func NewPostgresDB(username, password, host, port, dbName string) Database { +func NewPostgresDB(username, password, host, port, dbName string, mode SSLMode) Database { var err error - connection := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=Asia/Jakarta", host, username, password, dbName, port, "disable") + connection := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=Asia/Jakarta", host, username, password, dbName, port, mode) DB, err := gorm.Open(postgres.New(postgres.Config{ DSN: connection, }), &gorm.Config{ diff --git a/db/model/user/user.go b/db/model/user/user.go index 571d003..16de14c 100644 --- a/db/model/user/user.go +++ b/db/model/user/user.go @@ -2,7 +2,6 @@ package user import ( "fmt" - "github.com/fossyy/filekeeper/utils" "sync" "time" @@ -27,12 +26,8 @@ type UserWithExpired struct { var log *logger.AggregatedLogger var UserCache *Cache -// TESTTING VAR -var database db.Database - func init() { log = logger.Logger() - database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) UserCache = &Cache{users: make(map[string]*UserWithExpired)} ticker := time.NewTicker(time.Hour * 8) @@ -66,7 +61,7 @@ func Get(email string) (*UserWithExpired, error) { return user, nil } - userData, err := database.GetUser(email) + userData, err := db.DB.GetUser(email) if err != nil { return nil, err } diff --git a/handler/download/download.go b/handler/download/download.go index 77ca38f..e9d8b89 100644 --- a/handler/download/download.go +++ b/handler/download/download.go @@ -15,13 +15,8 @@ import ( var log *logger.AggregatedLogger -// TESTTING VAR -var database db.Database - func init() { log = logger.Logger() - database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) - } func GET(w http.ResponseWriter, r *http.Request) { @@ -45,7 +40,7 @@ func GET(w http.ResponseWriter, r *http.Request) { } userSession := middleware.GetUser(storeSession) - files, err := database.GetFiles(userSession.UserID.String()) + files, err := db.DB.GetFiles(userSession.UserID.String()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/handler/download/file/file.go b/handler/download/file/file.go index c532d89..df61902 100644 --- a/handler/download/file/file.go +++ b/handler/download/file/file.go @@ -1,7 +1,6 @@ package downloadFileHandler import ( - "github.com/fossyy/filekeeper/utils" "net/http" "os" "path/filepath" @@ -12,18 +11,13 @@ import ( var log *logger.AggregatedLogger -// TESTTING VAR -var database db.Database - func init() { log = logger.Logger() - database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) - } func GET(w http.ResponseWriter, r *http.Request) { fileID := r.PathValue("id") - file, err := database.GetFile(fileID) + file, err := db.DB.GetFile(fileID) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) diff --git a/handler/forgotPassword/forgotPassword.go b/handler/forgotPassword/forgotPassword.go index 4b16ee7..d93aadf 100644 --- a/handler/forgotPassword/forgotPassword.go +++ b/handler/forgotPassword/forgotPassword.go @@ -33,9 +33,6 @@ var mailServer *email.SmtpServer var ListForgotPassword map[string]*ForgotPassword var UserForgotPassword = make(map[string]string) -// TESTTING VAR -var database db.Database - func init() { log = logger.Logger() ListForgotPassword = make(map[string]*ForgotPassword) @@ -43,7 +40,6 @@ func init() { mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) ticker := time.NewTicker(time.Minute) //TESTING - database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) go func() { for { <-ticker.C @@ -89,7 +85,7 @@ func POST(w http.ResponseWriter, r *http.Request) { emailForm := r.Form.Get("email") - user, err := database.GetUser(emailForm) + user, err := db.DB.GetUser(emailForm) if errors.Is(err, gorm.ErrRecordNotFound) { component := forgotPasswordView.Main(fmt.Sprintf("Account with this email address %s is not found", emailForm), types.Message{ Code: 0, diff --git a/handler/forgotPassword/verify/verify.go b/handler/forgotPassword/verify/verify.go index 4265d05..22a682d 100644 --- a/handler/forgotPassword/verify/verify.go +++ b/handler/forgotPassword/verify/verify.go @@ -16,13 +16,9 @@ import ( var log *logger.AggregatedLogger -// TESTTING VAR -var database db.Database - func init() { log = logger.Logger() //TESTING - database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } @@ -90,7 +86,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - err = database.UpdateUserPassword(data.User.Email, hashedPassword) + err = db.DB.UpdateUserPassword(data.User.Email, hashedPassword) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) log.Error(err.Error()) diff --git a/handler/signup/signup.go b/handler/signup/signup.go index 58f4bca..31375fb 100644 --- a/handler/signup/signup.go +++ b/handler/signup/signup.go @@ -32,16 +32,12 @@ var mailServer *email.SmtpServer var VerifyUser map[string]*UnverifiedUser var VerifyEmail map[string]string -// TESTTING VAR -var database db.Database - func init() { log = logger.Logger() smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT")) mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) VerifyUser = make(map[string]*UnverifiedUser) VerifyEmail = make(map[string]string) - database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) ticker := time.NewTicker(time.Minute) go func() { @@ -112,7 +108,7 @@ func POST(w http.ResponseWriter, r *http.Request) { Password: hashedPassword, } - if registered := database.IsUserRegistered(userEmail, username); registered { + if registered := db.DB.IsUserRegistered(userEmail, username); registered { component := signupView.Main("Sign up Page", types.Message{ Code: 0, Message: "Email or Username has been registered", diff --git a/handler/signup/verify/verify.go b/handler/signup/verify/verify.go index afcf477..f26bea5 100644 --- a/handler/signup/verify/verify.go +++ b/handler/signup/verify/verify.go @@ -1,7 +1,6 @@ package signupVerifyHandler import ( - "github.com/fossyy/filekeeper/utils" "net/http" "github.com/fossyy/filekeeper/db" @@ -13,13 +12,8 @@ import ( var log *logger.AggregatedLogger -// TESTTING VAR -var database db.Database - func init() { log = logger.Logger() - database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) - } func GET(w http.ResponseWriter, r *http.Request) { @@ -31,7 +25,7 @@ func GET(w http.ResponseWriter, r *http.Request) { return } - err := database.CreateUser(data.User) + err := db.DB.CreateUser(data.User) if err != nil { component := signupView.Main("Sign up Page", types.Message{ Code: 0, diff --git a/handler/upload/initialisation/initialisation.go b/handler/upload/initialisation/initialisation.go index a4e46a5..634101f 100644 --- a/handler/upload/initialisation/initialisation.go +++ b/handler/upload/initialisation/initialisation.go @@ -3,7 +3,6 @@ package initialisation import ( "encoding/json" "errors" - "github.com/fossyy/filekeeper/utils" "io" "net/http" "os" @@ -21,13 +20,8 @@ import ( var log *logger.AggregatedLogger -// TESTTING VAR -var database db.Database - func init() { log = logger.Logger() - database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) - } func POST(w http.ResponseWriter, r *http.Request) { @@ -59,7 +53,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - fileData, err := database.GetUserFile(fileInfo.Name, userSession.UserID.String()) + fileData, err := db.DB.GetUserFile(fileInfo.Name, userSession.UserID.String()) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { upload, err := handleNewUpload(userSession, fileInfo) @@ -74,7 +68,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - info, err := database.GetUploadInfo(fileData.ID.String()) + info, err := db.DB.GetUploadInfo(fileData.ID.String()) if err != nil { log.Error(err.Error()) return @@ -122,7 +116,7 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded Downloaded: 0, } - err = database.CreateFile(&newFile) + err = db.DB.CreateFile(&newFile) if err != nil { log.Error(err.Error()) return models.FilesUploaded{}, err @@ -138,7 +132,7 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded Done: false, } - err = database.CreateUploadInfo(filesUploaded) + err = db.DB.CreateUploadInfo(filesUploaded) if err != nil { log.Error(err.Error()) return models.FilesUploaded{}, err diff --git a/handler/upload/upload.go b/handler/upload/upload.go index 3a58894..b17b436 100644 --- a/handler/upload/upload.go +++ b/handler/upload/upload.go @@ -3,7 +3,6 @@ package uploadHandler import ( "errors" "github.com/fossyy/filekeeper/db" - "github.com/fossyy/filekeeper/utils" "io" "net/http" "os" @@ -20,13 +19,8 @@ import ( var log *logger.AggregatedLogger var mu sync.Mutex -// TESTTING VAR -var database db.Database - func init() { log = logger.Logger() - database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) - } func GET(w http.ResponseWriter, r *http.Request) { @@ -62,7 +56,7 @@ func POST(w http.ResponseWriter, r *http.Request) { userSession := middleware.GetUser(storeSession) if r.FormValue("done") == "true" { - database.FinalizeFileUpload(fileID) + db.DB.FinalizeFileUpload(fileID) return } @@ -72,7 +66,7 @@ func POST(w http.ResponseWriter, r *http.Request) { return } - file, err := database.GetUploadInfo(fileID) + file, err := db.DB.GetUploadInfo(fileID) if err != nil { log.Error("error getting upload info: " + err.Error()) return @@ -110,7 +104,7 @@ func POST(w http.ResponseWriter, r *http.Request) { if err != nil { return } - database.UpdateUpdateIndex(index, fileID) + db.DB.UpdateUpdateIndex(index, fileID) } func createUploadDirectory(uploadDir string) error { diff --git a/main.go b/main.go index 4d81e3d..41e794e 100644 --- a/main.go +++ b/main.go @@ -1,24 +1,9 @@ package main import ( - "fmt" - "net/http" - - "github.com/fossyy/filekeeper/middleware" - "github.com/fossyy/filekeeper/routes" - "github.com/fossyy/filekeeper/utils" + "github.com/fossyy/filekeeper/app" ) func main() { - serverAddr := fmt.Sprintf("%s:%s", utils.Getenv("SERVER_HOST"), utils.Getenv("SERVER_PORT")) - server := http.Server{ - Addr: serverAddr, - Handler: middleware.Handler(routes.SetupRoutes()), - } - - fmt.Printf("Listening on http://%s\n", serverAddr) - err := server.ListenAndServe() - if err != nil { - return - } + app.Start() }