diff --git a/db/database.go b/db/database.go index 5de4fea..7eb91ce 100644 --- a/db/database.go +++ b/db/database.go @@ -6,6 +6,7 @@ import ( "github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/types/models" "gorm.io/driver/mysql" + "gorm.io/driver/postgres" "gorm.io/gorm" gormLogger "gorm.io/gorm/logger" "os" @@ -13,12 +14,15 @@ import ( ) var log *logger.AggregatedLogger -var DB *gorm.DB type mySQLdb struct { *gorm.DB } +type postgresDB struct { + *gorm.DB +} + type Database interface { IsUserRegistered(email string, username string) bool @@ -40,7 +44,7 @@ type Database interface { func NewMYSQLdb(username, password, host, port, dbName string) Database { var err error connection := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", username, password, host, port, dbName) - DB, err = gorm.Open(mysql.New(mysql.Config{ + DB, err := gorm.Open(mysql.New(mysql.Config{ DSN: connection, DefaultStringSize: 256, DisableDatetimePrecision: true, @@ -75,6 +79,39 @@ func NewMYSQLdb(username, password, host, port, dbName string) Database { return &mySQLdb{DB} } +func NewPostgresDB(username, password, host, port, dbName string) Database { + var err error + connection := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=Asia/Jakarta", host, username, password, dbName, port, "disable") + DB, err := gorm.Open(postgres.New(postgres.Config{ + DSN: connection, + }), &gorm.Config{ + Logger: gormLogger.Default.LogMode(gormLogger.Silent), + }) + + if err != nil { + panic("failed to connect database: " + err.Error()) + } + + file, err := os.ReadFile("schema.sql") + if err != nil { + panic("Error opening file: " + err.Error()) + } + + queries := strings.Split(string(file), ";") + for _, query := range queries { + query = strings.TrimSpace(query) + if query == "" { + continue + } + err := DB.Exec(query).Error + if err != nil { + panic("Error executing query: " + err.Error()) + } + } + + return &postgresDB{DB} +} + func (db *mySQLdb) IsUserRegistered(email string, username string) bool { var data models.User err := db.DB.Table("users").Where("email = ? OR username = ?", email, username).First(&data).Error @@ -166,13 +203,116 @@ func (db *mySQLdb) GetUploadInfo(fileID string) (*models.FilesUploaded, error) { } func (db *mySQLdb) UpdateUpdateIndex(index int, fileID string) { - db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ + db.DB.Table("files_uploadeds").Where("file_id = $1", fileID).Updates(map[string]interface{}{ "Uploaded": index, }) } func (db *mySQLdb) FinalizeFileUpload(fileID string) { - db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{ + db.DB.Table("files_uploadeds").Where("file_id = $1", fileID).Updates(map[string]interface{}{ + "Done": true, + }) +} + +// POSTGRES FUNCTION +func (db *postgresDB) IsUserRegistered(email string, username string) bool { + var data models.User + err := db.DB.Table("users").Where("email = $1 OR username = $2", email, username).First(&data).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return false + } + return true + } + return true +} + +func (db *postgresDB) CreateUser(user *models.User) error { + err := db.DB.Create(user).Error + if err != nil { + return err + } + return nil +} + +func (db *postgresDB) GetUser(email string) (*models.User, error) { + var user models.User + err := db.DB.Table("users").Where("email = $1", email).First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + +func (db *postgresDB) UpdateUserPassword(email string, password string) error { + err := db.DB.Table("users").Where("email = $1", email).Update("password", password).Error + if err != nil { + return err + } + return nil +} + +func (db *postgresDB) CreateFile(file *models.File) error { + err := db.DB.Create(file).Error + if err != nil { + return err + } + return nil +} + +func (db *postgresDB) GetFile(fileID string) (*models.File, error) { + var file models.File + err := db.DB.Table("files").Where("id = $1", fileID).First(&file).Error + if err != nil { + return nil, err + } + return &file, nil +} + +func (db *postgresDB) GetUserFile(name string, ownerID string) (*models.File, error) { + var file models.File + err := db.DB.Table("files").Where("name = $1 AND owner_id = $2", name, ownerID).First(&file).Error + if err != nil { + return nil, err + } + return &file, nil +} + +func (db *postgresDB) GetFiles(ownerID string) ([]*models.File, error) { + var files []*models.File + err := db.DB.Table("files").Where("owner_id = $1", ownerID).Find(&files).Error + if err != nil { + return nil, err + } + return files, err +} + +// CreateUploadInfo It's not optimal, but it's okay for now. Consider implementing caching instead of pushing all updates to the database for better performance in the future. +func (db *postgresDB) CreateUploadInfo(info models.FilesUploaded) error { + err := db.DB.Create(info).Error + if err != nil { + return err + } + return nil +} + +func (db *postgresDB) GetUploadInfo(fileID string) (*models.FilesUploaded, error) { + var info models.FilesUploaded + err := db.DB.Table("files_uploadeds").Where("file_id = $1", fileID).First(&info).Error + if err != nil { + return nil, err + } + return &info, nil +} + +func (db *postgresDB) UpdateUpdateIndex(index int, fileID string) { + db.DB.Table("files_uploadeds").Where("file_id = $1", fileID).Updates(map[string]interface{}{ + "Uploaded": index, + }) +} + +func (db *postgresDB) FinalizeFileUpload(fileID string) { + db.DB.Table("files_uploadeds").Where("file_id = $1", fileID).Updates(map[string]interface{}{ "Done": true, }) } diff --git a/db/model/user/user.go b/db/model/user/user.go index 367613a..571d003 100644 --- a/db/model/user/user.go +++ b/db/model/user/user.go @@ -32,7 +32,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) UserCache = &Cache{users: make(map[string]*UserWithExpired)} ticker := time.NewTicker(time.Hour * 8) diff --git a/go.mod b/go.mod index 9564b2c..683d919 100644 --- a/go.mod +++ b/go.mod @@ -9,12 +9,17 @@ require ( golang.org/x/crypto v0.21.0 gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df gorm.io/driver/mysql v1.5.6 + gorm.io/driver/postgres v1.5.7 gorm.io/gorm v1.25.8 ) require ( github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + golang.org/x/text v0.14.0 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect ) diff --git a/go.sum b/go.sum index be21266..3f53935 100644 --- a/go.sum +++ b/go.sum @@ -1,25 +1,49 @@ github.com/a-h/templ v0.2.648 h1:A1ggHGIE7AONOHrFaDTM8SrqgqHL6fWgWCijQ21Zy9I= github.com/a-h/templ v0.2.648/go.mod h1:SA7mtYwVEajbIXFRh3vKdYm/4FYyLQAtPH1+KxzGPA8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/mysql v1.5.6 h1:Ld4mkIickM+EliaQZQx3uOJDJHtrd70MxAUqWqlx3Y8= gorm.io/driver/mysql v1.5.6/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= +gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= +gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gorm.io/gorm v1.25.8 h1:WAGEZ/aEcznN4D03laj8DKnehe1e9gYQAjW8xyPRdeo= gorm.io/gorm v1.25.8/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/handler/download/download.go b/handler/download/download.go index f60c2a9..77ca38f 100644 --- a/handler/download/download.go +++ b/handler/download/download.go @@ -20,7 +20,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } @@ -50,7 +50,7 @@ func GET(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - + var filesData []types.FileData for i := 0; i < len(files); i++ { filesData = append(filesData, types.FileData{ diff --git a/handler/download/file/file.go b/handler/download/file/file.go index 8fb5021..c532d89 100644 --- a/handler/download/file/file.go +++ b/handler/download/file/file.go @@ -17,7 +17,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } diff --git a/handler/forgotPassword/forgotPassword.go b/handler/forgotPassword/forgotPassword.go index d6368af..4b16ee7 100644 --- a/handler/forgotPassword/forgotPassword.go +++ b/handler/forgotPassword/forgotPassword.go @@ -43,7 +43,7 @@ func init() { mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) ticker := time.NewTicker(time.Minute) //TESTING - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) go func() { for { <-ticker.C diff --git a/handler/forgotPassword/verify/verify.go b/handler/forgotPassword/verify/verify.go index 0700564..4265d05 100644 --- a/handler/forgotPassword/verify/verify.go +++ b/handler/forgotPassword/verify/verify.go @@ -22,7 +22,7 @@ var database db.Database func init() { log = logger.Logger() //TESTING - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } diff --git a/handler/signup/signup.go b/handler/signup/signup.go index e489612..58f4bca 100644 --- a/handler/signup/signup.go +++ b/handler/signup/signup.go @@ -41,7 +41,7 @@ func init() { mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) VerifyUser = make(map[string]*UnverifiedUser) VerifyEmail = make(map[string]string) - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) ticker := time.NewTicker(time.Minute) go func() { diff --git a/handler/signup/verify/verify.go b/handler/signup/verify/verify.go index 007825f..afcf477 100644 --- a/handler/signup/verify/verify.go +++ b/handler/signup/verify/verify.go @@ -18,7 +18,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } diff --git a/handler/upload/initialisation/initialisation.go b/handler/upload/initialisation/initialisation.go index d0f9ce8..a4e46a5 100644 --- a/handler/upload/initialisation/initialisation.go +++ b/handler/upload/initialisation/initialisation.go @@ -26,7 +26,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) } diff --git a/handler/upload/upload.go b/handler/upload/upload.go index e1df7b2..3a58894 100644 --- a/handler/upload/upload.go +++ b/handler/upload/upload.go @@ -25,7 +25,7 @@ var database db.Database func init() { log = logger.Logger() - database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) + database = db.NewPostgresDB(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) }