641 lines
16 KiB
Go
641 lines
16 KiB
Go
package db
|
|
|
|
import (
	"errors"
	"fmt"
	"strings"

	"github.com/fossyy/filekeeper/types"
	"github.com/fossyy/filekeeper/types/models"
	"github.com/google/uuid"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	gormLogger "gorm.io/gorm/logger"
)
|
|
|
|
type mySQLdb struct {
|
|
*gorm.DB
|
|
}
|
|
|
|
type postgresDB struct {
|
|
*gorm.DB
|
|
}
|
|
|
|
// SSLMode is the value substituted into the "sslmode" parameter of the
// PostgreSQL connection string built by NewPostgresDB.
type SSLMode string

const (
	// DisableSSL connects without TLS.
	DisableSSL SSLMode = "disable"

	// EnableSSL requires a TLS connection. PostgreSQL clients accept only
	// disable, allow, prefer, require, verify-ca and verify-full as sslmode.
	// The previous value "enable" was rejected as invalid, so "require" is
	// used. Note that "require" alone does not verify the server certificate;
	// verify-full does.
	EnableSSL SSLMode = "require"
)
|
|
|
|
// NewMYSQLdb prepares a MySQL database and returns it as a types.Database.
//
// It first connects to the MySQL server at host:port as username and creates
// the dbName schema if information_schema does not list it yet. It then
// reconnects with dbName selected (utf8mb4, parseTime, local time zone) and
// auto-migrates models.MysqlUser, models.MysqlFile and models.MysqlAllowance,
// in that order.
//
// It panics on any connection, schema-check, schema-creation or migration
// failure rather than returning an error.
func NewMYSQLdb(username, password, host, port, dbName string) types.Database {
	// open applies identical driver and gorm settings to both connections.
	open := func(dsn string) (*gorm.DB, error) {
		return gorm.Open(mysql.New(mysql.Config{
			DSN:                       dsn,
			DefaultStringSize:         256,
			DisableDatetimePrecision:  true,
			DontSupportRenameIndex:    true,
			DontSupportRenameColumn:   true,
			SkipInitializeWithVersion: false,
		}), &gorm.Config{
			Logger: gormLogger.Default.LogMode(gormLogger.Silent),
		})
	}

	// Bootstrap connection without a default schema, so dbName can be
	// created on first run.
	initDB, err := open(fmt.Sprintf("%s:%s@tcp(%s:%s)/", username, password, host, port))
	if err != nil {
		panic("failed to connect database: " + err.Error())
	}
	// The bootstrap pool is not used after this function returns. Close it
	// instead of leaking it for the life of the process.
	if sqlDB, err := initDB.DB(); err == nil {
		defer sqlDB.Close()
	}

	var count int64
	if err := initDB.Raw("SELECT count(*) FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = ?", dbName).Scan(&count).Error; err != nil {
		panic("Error checking database: " + err.Error())
	}
	if count <= 0 {
		// Identifiers cannot be bound as query parameters. Quote the name with
		// backticks (doubling any embedded backtick) so names such as
		// "file-keeper" stay valid and dbName cannot inject SQL.
		quoted := "`" + strings.ReplaceAll(dbName, "`", "``") + "`"
		if err := initDB.Exec("CREATE DATABASE IF NOT EXISTS " + quoted).Error; err != nil {
			panic("Error creating database: " + err.Error())
		}
	}

	conn, err := open(fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", username, password, host, port, dbName))
	if err != nil {
		panic("failed to connect database: " + err.Error())
	}

	// Migrate one model at a time in a fixed order, stopping at the first failure.
	for _, model := range []interface{}{&models.MysqlUser{}, &models.MysqlFile{}, &models.MysqlAllowance{}} {
		if err := conn.AutoMigrate(model); err != nil {
			panic(err.Error())
		}
	}

	return &mySQLdb{conn}
}
|
|
|
|
// NewPostgresDB prepares a PostgreSQL database and returns it as a
// types.Database.
//
// It first connects to the server at host:port as username, using the given
// sslmode and the Asia/Jakarta session time zone, and creates dbName if
// pg_database does not list it yet. It then reconnects to dbName and
// auto-migrates models.User, models.File and models.Allowance, in that order.
//
// It panics on any connection, existence-check, creation or migration
// failure rather than returning an error.
func NewPostgresDB(username, password, host, port, dbName string, mode SSLMode) types.Database {
	// open applies identical driver and gorm settings to both connections.
	open := func(dsn string) (*gorm.DB, error) {
		return gorm.Open(postgres.New(postgres.Config{
			DSN: dsn,
		}), &gorm.Config{
			Logger: gormLogger.Default.LogMode(gormLogger.Silent),
		})
	}

	// Bootstrap connection without dbname, so dbName can be created on first
	// run. NOTE(review): PostgreSQL then uses a database named after the
	// user, which must exist for this connection to succeed.
	initDB, err := open(fmt.Sprintf("host=%s user=%s password=%s port=%s sslmode=%s TimeZone=Asia/Jakarta", host, username, password, port, mode))
	if err != nil {
		panic("failed to connect database: " + err.Error())
	}
	// The bootstrap pool is not used after this function returns. Close it
	// instead of leaking it for the life of the process.
	if sqlDB, err := initDB.DB(); err == nil {
		defer sqlDB.Close()
	}

	var count int64
	if err := initDB.Raw("SELECT count(*) FROM pg_database WHERE datname = ?", dbName).Scan(&count).Error; err != nil {
		panic("Error checking database: " + err.Error())
	}
	if count <= 0 {
		// Identifiers cannot be bound as query parameters. Double-quote the
		// name (doubling embedded quotes) so PostgreSQL creates exactly
		// dbName, case included, which is the name the second DSN connects
		// to. An unquoted name would be folded to lower case.
		quoted := "\"" + strings.ReplaceAll(dbName, "\"", "\"\"") + "\""
		if err := initDB.Exec("CREATE DATABASE " + quoted).Error; err != nil {
			panic("Error creating database: " + err.Error())
		}
	}

	conn, err := open(fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=Asia/Jakarta", host, username, password, dbName, port, mode))
	if err != nil {
		panic("failed to connect database: " + err.Error())
	}

	// Migrate one model at a time in a fixed order, stopping at the first failure.
	for _, model := range []interface{}{&models.User{}, &models.File{}, &models.Allowance{}} {
		if err := conn.AutoMigrate(model); err != nil {
			panic(err.Error())
		}
	}

	return &postgresDB{conn}
}
|
|
|
|
func UUIDToString(u uuid.UUID) string {
|
|
return u.String()
|
|
}
|
|
|
|
func StringToUUID(s string) (uuid.UUID, error) {
|
|
return uuid.Parse(s)
|
|
}
|
|
|
|
// IsUserRegistered reports whether any row in "users" matches email or
// username.
//
// Only gorm.ErrRecordNotFound yields false. Every other lookup error is
// reported as true, the same as a match.
func (db *mySQLdb) IsUserRegistered(email string, username string) bool {
	var existing models.MysqlUser
	lookupErr := db.DB.Table("users").Where("email = ? OR username = ?", email, username).First(&existing).Error
	return !errors.Is(lookupErr, gorm.ErrRecordNotFound)
}
|
|
|
|
// IsEmailRegistered reports whether any row in "users" has the given email.
//
// Only gorm.ErrRecordNotFound yields false. Every other lookup error is
// reported as true, the same as a match.
func (db *mySQLdb) IsEmailRegistered(email string) bool {
	var existing models.MysqlUser
	lookupErr := db.DB.Table("users").Where("email = ?", email).First(&existing).Error
	return !errors.Is(lookupErr, gorm.ErrRecordNotFound)
}
|
|
|
|
func (db *mySQLdb) CreateUser(user *models.User) error {
|
|
mysqlUser := models.MysqlUser{
|
|
UserID: UUIDToString(user.UserID),
|
|
Username: user.Username,
|
|
Email: user.Email,
|
|
Password: user.Password,
|
|
Totp: user.Totp,
|
|
}
|
|
err := db.DB.Create(&mysqlUser).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return db.CreateAllowance(user.UserID)
|
|
}
|
|
|
|
// GetUser loads the user with the given email and converts it to a
// models.User.
//
// It returns the lookup error unchanged (for example gorm.ErrRecordNotFound).
// A stored user_id that is not a valid UUID is reported as an error, in line
// with GetFile, instead of being silently replaced by uuid.Nil.
func (db *mySQLdb) GetUser(email string) (*models.User, error) {
	var mysqlUser models.MysqlUser
	if err := db.DB.Table("users").Where("email = ?", email).First(&mysqlUser).Error; err != nil {
		return nil, err
	}

	userID, err := StringToUUID(mysqlUser.UserID)
	if err != nil {
		return nil, fmt.Errorf("parsing user_id %q: %w", mysqlUser.UserID, err)
	}

	return &models.User{
		UserID:   userID,
		Username: mysqlUser.Username,
		Email:    mysqlUser.Email,
		Password: mysqlUser.Password,
		Totp:     mysqlUser.Totp,
	}, nil
}
|
|
|
|
// GetAllUsers returns every user. Only the user_id, username and email
// columns are selected, so Password and Totp are left empty.
//
// A stored user_id that is not a valid UUID is reported as an error, in line
// with GetFiles, instead of being silently replaced by uuid.Nil. With no
// users, the result is an empty, non-nil slice.
func (db *mySQLdb) GetAllUsers() ([]models.User, error) {
	var mysqlUsers []models.MysqlUser
	if err := db.DB.Table("users").Select("user_id, username, email").Find(&mysqlUsers).Error; err != nil {
		return nil, err
	}

	users := make([]models.User, 0, len(mysqlUsers))
	for _, u := range mysqlUsers {
		userID, err := StringToUUID(u.UserID)
		if err != nil {
			return nil, fmt.Errorf("parsing user_id %q: %w", u.UserID, err)
		}
		users = append(users, models.User{
			UserID:   userID,
			Username: u.Username,
			Email:    u.Email,
			Password: u.Password,
			Totp:     u.Totp,
		})
	}
	return users, nil
}
|
|
|
|
func (db *mySQLdb) UpdateUserPassword(email string, password string) error {
|
|
var mysqlUser models.MysqlUser
|
|
err := db.DB.Table("users").Where("email = ?", email).First(&mysqlUser).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mysqlUser.Password = password
|
|
return db.DB.Save(&mysqlUser).Error
|
|
}
|
|
|
|
func (db *mySQLdb) CreateAllowance(userID uuid.UUID) error {
|
|
userAllowance := &models.Allowance{
|
|
UserID: userID,
|
|
AllowanceByte: 1024 * 1024 * 1024 * 10,
|
|
AllowanceFile: 10,
|
|
}
|
|
return db.DB.Create(userAllowance).Error
|
|
}
|
|
|
|
func (db *mySQLdb) GetAllowance(userID uuid.UUID) (*models.Allowance, error) {
|
|
var allowance models.Allowance
|
|
err := db.DB.Table("allowances").Where("user_id = ?", userID).First(&allowance).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &allowance, nil
|
|
}
|
|
|
|
func (db *mySQLdb) CreateFile(file *models.File) error {
|
|
mysqlFile := models.MysqlFile{
|
|
ID: UUIDToString(file.ID),
|
|
OwnerID: UUIDToString(file.OwnerID),
|
|
Name: file.Name,
|
|
Size: file.Size,
|
|
TotalChunk: file.TotalChunk,
|
|
StartHash: file.StartHash,
|
|
EndHash: file.EndHash,
|
|
IsPrivate: file.IsPrivate,
|
|
Type: file.Type,
|
|
Downloaded: file.Downloaded,
|
|
}
|
|
return db.DB.Create(&mysqlFile).Error
|
|
}
|
|
|
|
func (db *mySQLdb) GetFile(fileID string) (*models.File, error) {
|
|
var mysqlFile models.MysqlFile
|
|
err := db.DB.Table("files").Where("id = ?", fileID).First(&mysqlFile).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
fileIDUUID, err := StringToUUID(mysqlFile.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ownerIDUUID, err := StringToUUID(mysqlFile.OwnerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &models.File{
|
|
ID: fileIDUUID,
|
|
OwnerID: ownerIDUUID,
|
|
Name: mysqlFile.Name,
|
|
Size: mysqlFile.Size,
|
|
TotalChunk: mysqlFile.TotalChunk,
|
|
StartHash: mysqlFile.StartHash,
|
|
EndHash: mysqlFile.EndHash,
|
|
IsPrivate: mysqlFile.IsPrivate,
|
|
Type: mysqlFile.Type,
|
|
Downloaded: mysqlFile.Downloaded,
|
|
}, nil
|
|
}
|
|
|
|
func (db *mySQLdb) RenameFile(fileID string, name string) (*models.File, error) {
|
|
var mysqlFile models.MysqlFile
|
|
err := db.DB.Table("files").Where("id = ?", fileID).First(&mysqlFile).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
mysqlFile.Name = name
|
|
err = db.DB.Save(&mysqlFile).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
fileIDUUID, err := StringToUUID(mysqlFile.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ownerIDUUID, err := StringToUUID(mysqlFile.OwnerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &models.File{
|
|
ID: fileIDUUID,
|
|
OwnerID: ownerIDUUID,
|
|
Name: mysqlFile.Name,
|
|
Size: mysqlFile.Size,
|
|
TotalChunk: mysqlFile.TotalChunk,
|
|
StartHash: mysqlFile.StartHash,
|
|
EndHash: mysqlFile.EndHash,
|
|
IsPrivate: mysqlFile.IsPrivate,
|
|
Type: mysqlFile.Type,
|
|
Downloaded: mysqlFile.Downloaded,
|
|
}, nil
|
|
}
|
|
|
|
func (db *mySQLdb) DeleteFile(fileID string) error {
|
|
return db.DB.Table("files").Where("id = ?", fileID).Delete(&models.MysqlFile{}).Error
|
|
}
|
|
|
|
func (db *mySQLdb) GetUserFile(name string, ownerID string) (*models.File, error) {
|
|
var mysqlFile models.MysqlFile
|
|
err := db.DB.Table("files").Where("name = ? AND owner_id = ?", name, ownerID).First(&mysqlFile).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
fileIDUUID, err := StringToUUID(mysqlFile.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ownerIDUUID, err := StringToUUID(mysqlFile.OwnerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &models.File{
|
|
ID: fileIDUUID,
|
|
OwnerID: ownerIDUUID,
|
|
Name: mysqlFile.Name,
|
|
Size: mysqlFile.Size,
|
|
TotalChunk: mysqlFile.TotalChunk,
|
|
StartHash: mysqlFile.StartHash,
|
|
EndHash: mysqlFile.EndHash,
|
|
IsPrivate: mysqlFile.IsPrivate,
|
|
Type: mysqlFile.Type,
|
|
Downloaded: mysqlFile.Downloaded,
|
|
}, nil
|
|
}
|
|
|
|
func (db *mySQLdb) GetFiles(ownerID string, query string, status types.FileStatus) ([]*models.File, error) {
|
|
var mysqlFiles []*models.MysqlFile
|
|
tx := db.DB.Table("files").Where("owner_id = ?", ownerID)
|
|
|
|
if query != "" {
|
|
tx = tx.Where("name LIKE ?", "%"+query+"%")
|
|
}
|
|
|
|
switch status {
|
|
case types.Private:
|
|
tx = tx.Where("is_private = ?", true)
|
|
case types.Public:
|
|
tx = tx.Where("is_private = ?", false)
|
|
}
|
|
|
|
err := tx.Find(&mysqlFiles).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
files := make([]*models.File, len(mysqlFiles))
|
|
for i, f := range mysqlFiles {
|
|
fileIDUUID, err := StringToUUID(f.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
ownerIDUUID, err := StringToUUID(f.OwnerID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
files[i] = &models.File{
|
|
ID: fileIDUUID,
|
|
OwnerID: ownerIDUUID,
|
|
Name: f.Name,
|
|
Size: f.Size,
|
|
TotalChunk: f.TotalChunk,
|
|
StartHash: f.StartHash,
|
|
EndHash: f.EndHash,
|
|
IsPrivate: f.IsPrivate,
|
|
Type: f.Type,
|
|
Downloaded: f.Downloaded,
|
|
}
|
|
}
|
|
return files, nil
|
|
}
|
|
|
|
func (db *mySQLdb) IncrementDownloadCount(fileID string) error {
|
|
var mysqlFile models.MysqlFile
|
|
err := db.DB.Table("files").Where("id = ?", fileID).First(&mysqlFile).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mysqlFile.Downloaded++
|
|
return db.DB.Save(&mysqlFile).Error
|
|
}
|
|
|
|
func (db *mySQLdb) ChangeFileVisibility(fileID string) error {
|
|
err := db.DB.Model(&models.MysqlFile{}).Where("id = ?", fileID).Update("is_private", gorm.Expr("NOT is_private")).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *mySQLdb) InitializeTotp(email string, secret string) error {
|
|
var mysqlUser models.MysqlUser
|
|
err := db.DB.Table("users").Where("email = ?", email).First(&mysqlUser).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
mysqlUser.Totp = secret
|
|
return db.DB.Save(&mysqlUser).Error
|
|
}
|
|
|
|
// POSTGRES FUNCTION

// IsUserRegistered reports whether any row in "users" matches email or
// username.
//
// Only gorm.ErrRecordNotFound yields false. Every other lookup error is
// reported as true, the same as a match.
func (db *postgresDB) IsUserRegistered(email string, username string) bool {
	var existing models.User
	lookupErr := db.DB.Table("users").Where("email = $1 OR username = $2", email, username).First(&existing).Error
	return !errors.Is(lookupErr, gorm.ErrRecordNotFound)
}
|
|
|
|
// IsEmailRegistered reports whether any row in "users" has the given email.
//
// Only gorm.ErrRecordNotFound yields false. Every other lookup error is
// reported as true, the same as a match.
func (db *postgresDB) IsEmailRegistered(email string) bool {
	var existing models.User
	lookupErr := db.DB.Table("users").Where("email = $1 ", email).First(&existing).Error
	return !errors.Is(lookupErr, gorm.ErrRecordNotFound)
}
|
|
|
|
func (db *postgresDB) CreateUser(user *models.User) error {
|
|
err := db.DB.Create(user).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = db.CreateAllowance(user.UserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *postgresDB) GetUser(email string) (*models.User, error) {
|
|
var user models.User
|
|
err := db.DB.Table("users").Where("email = $1", email).First(&user).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (db *postgresDB) GetAllUsers() ([]models.User, error) {
|
|
var users []models.User
|
|
err := db.DB.Table("users").Select("user_id, username, email").Find(&users).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (db *postgresDB) UpdateUserPassword(email string, password string) error {
|
|
var user models.User
|
|
err := db.DB.Table("users").Where("email = $1", email).First(&user).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.Password = password
|
|
db.Save(&user)
|
|
return nil
|
|
}
|
|
|
|
func (db *postgresDB) CreateAllowance(userID uuid.UUID) error {
|
|
userAllowance := &models.Allowance{
|
|
UserID: userID,
|
|
AllowanceByte: 1024 * 1024 * 1024 * 10,
|
|
AllowanceFile: 10,
|
|
}
|
|
err := db.DB.Create(userAllowance).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *postgresDB) GetAllowance(userID uuid.UUID) (*models.Allowance, error) {
|
|
var allowance models.Allowance
|
|
err := db.DB.Table("allowances").Where("user_id = $1", userID).First(&allowance).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &allowance, nil
|
|
}
|
|
|
|
func (db *postgresDB) CreateFile(file *models.File) error {
|
|
err := db.DB.Create(file).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *postgresDB) GetFile(fileID string) (*models.File, error) {
|
|
var file models.File
|
|
err := db.DB.Table("files").Where("id = $1", fileID).First(&file).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &file, nil
|
|
}
|
|
|
|
func (db *postgresDB) RenameFile(fileID string, name string) (*models.File, error) {
|
|
var file models.File
|
|
err := db.DB.Table("files").Where("id = $1", fileID).First(&file).Error
|
|
file.Name = name
|
|
err = db.DB.Save(&file).Error
|
|
if err != nil {
|
|
return &file, err
|
|
}
|
|
return &file, nil
|
|
}
|
|
|
|
func (db *postgresDB) DeleteFile(fileID string) error {
|
|
err := db.DB.Table("files").Where("id = $1", fileID).Delete(&models.File{}).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *postgresDB) GetUserFile(name string, ownerID string) (*models.File, error) {
|
|
var file models.File
|
|
err := db.DB.Table("files").Where("name = $1 AND owner_id = $2", name, ownerID).First(&file).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &file, nil
|
|
}
|
|
|
|
func (db *postgresDB) GetFiles(ownerID string, query string, status types.FileStatus) ([]*models.File, error) {
|
|
var files []*models.File
|
|
tx := db.DB.Table("files").Where("owner_id = $1", ownerID)
|
|
|
|
if query != "" {
|
|
tx = tx.Where("name LIKE $2", "%"+query+"%")
|
|
}
|
|
|
|
if query == "" {
|
|
switch status {
|
|
case types.Private:
|
|
tx = tx.Where("is_private = $2::boolean", true)
|
|
case types.Public:
|
|
tx = tx.Where("is_private = $2::boolean", false)
|
|
default:
|
|
}
|
|
} else {
|
|
switch status {
|
|
case types.Private:
|
|
tx = tx.Where("is_private = $3::boolean", true)
|
|
case types.Public:
|
|
tx = tx.Where("is_private = $3::boolean", false)
|
|
default:
|
|
}
|
|
}
|
|
|
|
err := tx.Find(&files).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return files, nil
|
|
}
|
|
|
|
func (db *postgresDB) IncrementDownloadCount(fileID string) error {
|
|
var file models.File
|
|
err := db.DB.Table("files").Where("id = $1", fileID).First(&file).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
file.Downloaded = file.Downloaded + 1
|
|
err = db.DB.Updates(file).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *postgresDB) ChangeFileVisibility(fileID string) error {
|
|
err := db.DB.Model(&models.File{}).Where("id = $1", fileID).Select("is_private").
|
|
Updates(map[string]interface{}{"is_private": gorm.Expr("NOT is_private")}).Error
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (db *postgresDB) InitializeTotp(email string, secret string) error {
|
|
var user models.User
|
|
err := db.DB.Table("users").Where("email = $1", email).First(&user).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
user.Totp = secret
|
|
err = db.Save(&user).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|