Separate database initialization into NewMYSQLdb function

This commit is contained in:
2024-04-28 20:49:41 +07:00
parent 1a7ac48330
commit b4f303463d
10 changed files with 242 additions and 92 deletions

View File

@ -1,39 +1,178 @@
package db package db
import ( import (
"errors"
"fmt" "fmt"
"os"
"strings"
"github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/logger"
"github.com/fossyy/filekeeper/utils" "github.com/fossyy/filekeeper/types/models"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/gorm" "gorm.io/gorm"
gormLogger "gorm.io/gorm/logger" gormLogger "gorm.io/gorm/logger"
"os"
"strings"
) )
var log *logger.AggregatedLogger
var DB *gorm.DB var DB *gorm.DB
var log *logger.AggregatedLogger type mySQLdb struct {
*gorm.DB
}
func init() { type Database interface {
IsUserRegistered(email string, username string) bool
CreateUser(user *models.User) error
GetUser(email string) (*models.User, error)
UpdateUserPassword(email string, password string) error
CreateFile(file *models.File) error
GetFile(fileID string) (*models.File, error)
GetUserFile(name string, ownerID string) (*models.File, error)
GetFiles(ownerID string) ([]*models.File, error)
CreateUploadInfo(info models.FilesUploaded) error
GetUploadInfo(uploadID string) (*models.FilesUploaded, error)
UpdateUpdateIndex(index int, fileID string)
FinalizeFileUpload(fileID string)
}
func NewMYSQLdb(username, password, host, port, dbName string) Database {
var err error var err error
connection := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME")) connection := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", username, password, host, port, dbName)
DB, err = gorm.Open(mysql.Open(connection), &gorm.Config{}, &gorm.Config{ DB, err = gorm.Open(mysql.New(mysql.Config{
DSN: connection,
DefaultStringSize: 256,
DisableDatetimePrecision: true,
DontSupportRenameIndex: true,
DontSupportRenameColumn: true,
SkipInitializeWithVersion: false,
}), &gorm.Config{
Logger: gormLogger.Default.LogMode(gormLogger.Silent), Logger: gormLogger.Default.LogMode(gormLogger.Silent),
}) })
if err != nil { if err != nil {
panic("failed to connect database" + err.Error()) panic("failed to connect database: " + err.Error())
} }
file, err := os.ReadFile("schema.sql") file, err := os.ReadFile("schema.sql")
if err != nil { if err != nil {
log.Error("Error opening file: %s", err.Error()) panic("Error opening file: " + err.Error())
} }
querys := strings.Split(string(file), "\n")
for _, query := range querys { queries := strings.Split(string(file), ";")
for _, query := range queries {
query = strings.TrimSpace(query)
if query == "" {
continue
}
err := DB.Exec(query).Error err := DB.Exec(query).Error
if err != nil { if err != nil {
panic(err.Error()) panic("Error executing query: " + err.Error())
} }
} }
return &mySQLdb{DB}
}
func (db *mySQLdb) IsUserRegistered(email string, username string) bool {
var data models.User
err := db.DB.Table("users").Where("email = ? OR username = ?", email, username).First(&data).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false
}
return true
}
return true
}
func (db *mySQLdb) CreateUser(user *models.User) error {
err := db.DB.Create(user).Error
if err != nil {
return err
}
return nil
}
func (db *mySQLdb) GetUser(email string) (*models.User, error) {
var user models.User
err := db.DB.Table("users").Where("email = ?", email).First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
func (db *mySQLdb) UpdateUserPassword(email string, password string) error {
err := db.DB.Table("users").Where("email = ?", email).Update("password", password).Error
if err != nil {
return err
}
return nil
}
func (db *mySQLdb) CreateFile(file *models.File) error {
err := db.DB.Create(file).Error
if err != nil {
return err
}
return nil
}
func (db *mySQLdb) GetFile(fileID string) (*models.File, error) {
var file models.File
err := db.DB.Table("files").Where("id = ?", fileID).First(&file).Error
if err != nil {
return nil, err
}
return &file, nil
}
func (db *mySQLdb) GetUserFile(name string, ownerID string) (*models.File, error) {
var file models.File
err := db.DB.Table("files").Where("name = ? AND owner_id = ?", name, ownerID).First(&file).Error
if err != nil {
return nil, err
}
return &file, nil
}
func (db *mySQLdb) GetFiles(ownerID string) ([]*models.File, error) {
var files []*models.File
err := db.DB.Table("files").Where("owner_id = ?", ownerID).Find(&files).Error
if err != nil {
return nil, err
}
return files, err
}
// CreateUploadInfo It's not optimal, but it's okay for now. Consider implementing caching instead of pushing all updates to the database for better performance in the future.
func (db *mySQLdb) CreateUploadInfo(info models.FilesUploaded) error {
err := db.DB.Create(info).Error
if err != nil {
return err
}
return nil
}
func (db *mySQLdb) GetUploadInfo(fileID string) (*models.FilesUploaded, error) {
var info models.FilesUploaded
err := db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).First(&info).Error
if err != nil {
return nil, err
}
return &info, nil
}
func (db *mySQLdb) UpdateUpdateIndex(index int, fileID string) {
db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{
"Uploaded": index,
})
}
func (db *mySQLdb) FinalizeFileUpload(fileID string) {
db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{
"Done": true,
})
} }

View File

@ -2,6 +2,7 @@ package user
import ( import (
"fmt" "fmt"
"github.com/fossyy/filekeeper/utils"
"sync" "sync"
"time" "time"
@ -26,8 +27,12 @@ type UserWithExpired struct {
var log *logger.AggregatedLogger var log *logger.AggregatedLogger
var UserCache *Cache var UserCache *Cache
// TESTTING VAR
var database db.Database
func init() { func init() {
log = logger.Logger() log = logger.Logger()
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
UserCache = &Cache{users: make(map[string]*UserWithExpired)} UserCache = &Cache{users: make(map[string]*UserWithExpired)}
ticker := time.NewTicker(time.Hour * 8) ticker := time.NewTicker(time.Hour * 8)
@ -61,8 +66,7 @@ func Get(email string) (*UserWithExpired, error) {
return user, nil return user, nil
} }
var userData UserWithExpired userData, err := database.GetUser(email)
err := db.DB.Table("users").Where("email = ?", email).First(&userData).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -75,7 +79,7 @@ func Get(email string) (*UserWithExpired, error) {
AccessAt: time.Now(), AccessAt: time.Now(),
} }
return &userData, nil return UserCache.users[email], nil
} }
func DeleteCache(email string) { func DeleteCache(email string) {

View File

@ -9,15 +9,19 @@ import (
"github.com/fossyy/filekeeper/middleware" "github.com/fossyy/filekeeper/middleware"
"github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/session"
"github.com/fossyy/filekeeper/types" "github.com/fossyy/filekeeper/types"
"github.com/fossyy/filekeeper/types/models"
"github.com/fossyy/filekeeper/utils" "github.com/fossyy/filekeeper/utils"
downloadView "github.com/fossyy/filekeeper/view/download" downloadView "github.com/fossyy/filekeeper/view/download"
) )
var log *logger.AggregatedLogger var log *logger.AggregatedLogger
// TESTTING VAR
var database db.Database
func init() { func init() {
log = logger.Logger() log = logger.Logger()
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
} }
func GET(w http.ResponseWriter, r *http.Request) { func GET(w http.ResponseWriter, r *http.Request) {
@ -41,8 +45,12 @@ func GET(w http.ResponseWriter, r *http.Request) {
} }
userSession := middleware.GetUser(storeSession) userSession := middleware.GetUser(storeSession)
var files []models.File files, err := database.GetFiles(userSession.UserID.String())
db.DB.Table("files").Where("owner_id = ?", userSession.UserID).Find(&files) if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var filesData []types.FileData var filesData []types.FileData
for i := 0; i < len(files); i++ { for i := 0; i < len(files); i++ {
filesData = append(filesData, types.FileData{ filesData = append(filesData, types.FileData{

View File

@ -1,29 +1,33 @@
package downloadFileHandler package downloadFileHandler
import ( import (
"github.com/fossyy/filekeeper/utils"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"github.com/fossyy/filekeeper/db" "github.com/fossyy/filekeeper/db"
"github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/logger"
"github.com/fossyy/filekeeper/types/models"
) )
var log *logger.AggregatedLogger var log *logger.AggregatedLogger
// TESTTING VAR
var database db.Database
func init() { func init() {
log = logger.Logger() log = logger.Logger()
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
} }
func GET(w http.ResponseWriter, r *http.Request) { func GET(w http.ResponseWriter, r *http.Request) {
fileID := r.PathValue("id") fileID := r.PathValue("id")
file, err := database.GetFile(fileID)
var file models.File
err := db.DB.Table("files").Where("id = ?", fileID).First(&file).Error
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
log.Error(err.Error()) log.Error(err.Error())
return
} }
uploadDir := "uploads" uploadDir := "uploads"
@ -42,6 +46,7 @@ func GET(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
log.Error(err.Error()) log.Error(err.Error())
return
} }
defer openFile.Close() defer openFile.Close()
@ -49,6 +54,7 @@ func GET(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
log.Error(err.Error()) log.Error(err.Error())
return
} }
w.Header().Set("Content-Disposition", "attachment; filename="+stat.Name()) w.Header().Set("Content-Disposition", "attachment; filename="+stat.Name())

View File

@ -33,12 +33,17 @@ var mailServer *email.SmtpServer
var ListForgotPassword map[string]*ForgotPassword var ListForgotPassword map[string]*ForgotPassword
var UserForgotPassword = make(map[string]string) var UserForgotPassword = make(map[string]string)
// TESTTING VAR
var database db.Database
func init() { func init() {
log = logger.Logger() log = logger.Logger()
ListForgotPassword = make(map[string]*ForgotPassword) ListForgotPassword = make(map[string]*ForgotPassword)
smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT")) smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT"))
mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD"))
ticker := time.NewTicker(time.Minute) ticker := time.NewTicker(time.Minute)
//TESTING
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
go func() { go func() {
for { for {
<-ticker.C <-ticker.C
@ -84,8 +89,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
emailForm := r.Form.Get("email") emailForm := r.Form.Get("email")
var user models.User user, err := database.GetUser(emailForm)
err = db.DB.Table("users").Where("email = ?", emailForm).First(&user).Error
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
component := forgotPasswordView.Main(fmt.Sprintf("Account with this email address %s is not found", emailForm), types.Message{ component := forgotPasswordView.Main(fmt.Sprintf("Account with this email address %s is not found", emailForm), types.Message{
Code: 0, Code: 0,
@ -100,7 +104,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
return return
} }
err = verifyForgot(&user) err = verifyForgot(user)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
log.Error(err.Error()) log.Error(err.Error())

View File

@ -16,8 +16,14 @@ import (
var log *logger.AggregatedLogger var log *logger.AggregatedLogger
// TESTTING VAR
var database db.Database
func init() { func init() {
log = logger.Logger() log = logger.Logger()
//TESTING
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
} }
func GET(w http.ResponseWriter, r *http.Request) { func GET(w http.ResponseWriter, r *http.Request) {
@ -84,7 +90,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
return return
} }
err = db.DB.Table("users").Where("email = ?", data.User.Email).Update("password", hashedPassword).Error err = database.UpdateUserPassword(data.User.Email, hashedPassword)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
log.Error(err.Error()) log.Error(err.Error())

View File

@ -3,7 +3,6 @@ package signupHandler
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"strconv" "strconv"
@ -19,7 +18,6 @@ import (
emailView "github.com/fossyy/filekeeper/view/email" emailView "github.com/fossyy/filekeeper/view/email"
signupView "github.com/fossyy/filekeeper/view/signup" signupView "github.com/fossyy/filekeeper/view/signup"
"github.com/google/uuid" "github.com/google/uuid"
"gorm.io/gorm"
) )
type UnverifiedUser struct { type UnverifiedUser struct {
@ -34,12 +32,16 @@ var mailServer *email.SmtpServer
var VerifyUser map[string]*UnverifiedUser var VerifyUser map[string]*UnverifiedUser
var VerifyEmail map[string]string var VerifyEmail map[string]string
// TESTTING VAR
var database db.Database
func init() { func init() {
log = logger.Logger() log = logger.Logger()
smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT")) smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT"))
mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD")) mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD"))
VerifyUser = make(map[string]*UnverifiedUser) VerifyUser = make(map[string]*UnverifiedUser)
VerifyEmail = make(map[string]string) VerifyEmail = make(map[string]string)
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
ticker := time.NewTicker(time.Minute) ticker := time.NewTicker(time.Minute)
go func() { go func() {
@ -110,34 +112,28 @@ func POST(w http.ResponseWriter, r *http.Request) {
Password: hashedPassword, Password: hashedPassword,
} }
var data models.User if registered := database.IsUserRegistered(userEmail, username); registered {
err = db.DB.Table("users").Where("email = ? OR username = ?", userEmail, username).First(&data).Error component := signupView.Main("Sign up Page", types.Message{
if err != nil { Code: 0,
if errors.Is(err, gorm.ErrRecordNotFound) { Message: "Email or Username has been registered",
err = verifyEmail(&newUser) })
if err != nil { err = component.Render(r.Context(), w)
http.Error(w, err.Error(), http.StatusInternalServerError) if err != nil {
log.Error(err.Error()) http.Error(w, err.Error(), http.StatusInternalServerError)
return log.Error(err.Error())
}
component := signupView.EmailSend("Sign up Page")
err = component.Render(r.Context(), w)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
log.Error(err.Error())
return
}
return return
} }
return
}
err = verifyEmail(&newUser)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
log.Error(err.Error()) log.Error(err.Error())
return return
} }
component := signupView.Main("Sign up Page", types.Message{
Code: 0, component := signupView.EmailSend("Sign up Page")
Message: "Email or Username has been registered",
})
err = component.Render(r.Context(), w) err = component.Render(r.Context(), w)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)

View File

@ -1,6 +1,7 @@
package signupVerifyHandler package signupVerifyHandler
import ( import (
"github.com/fossyy/filekeeper/utils"
"net/http" "net/http"
"github.com/fossyy/filekeeper/db" "github.com/fossyy/filekeeper/db"
@ -12,8 +13,13 @@ import (
var log *logger.AggregatedLogger var log *logger.AggregatedLogger
// TESTTING VAR
var database db.Database
func init() { func init() {
log = logger.Logger() log = logger.Logger()
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
} }
func GET(w http.ResponseWriter, r *http.Request) { func GET(w http.ResponseWriter, r *http.Request) {
@ -25,8 +31,7 @@ func GET(w http.ResponseWriter, r *http.Request) {
return return
} }
err := db.DB.Create(&data.User).Error err := database.CreateUser(data.User)
if err != nil { if err != nil {
component := signupView.Main("Sign up Page", types.Message{ component := signupView.Main("Sign up Page", types.Message{
Code: 0, Code: 0,

View File

@ -3,6 +3,7 @@ package initialisation
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"github.com/fossyy/filekeeper/utils"
"io" "io"
"net/http" "net/http"
"os" "os"
@ -20,8 +21,13 @@ import (
var log *logger.AggregatedLogger var log *logger.AggregatedLogger
// TESTTING VAR
var database db.Database
func init() { func init() {
log = logger.Logger() log = logger.Logger()
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
} }
func POST(w http.ResponseWriter, r *http.Request) { func POST(w http.ResponseWriter, r *http.Request) {
@ -53,7 +59,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
return return
} }
fileData, err := getFile(fileInfo.Name, userSession.UserID) fileData, err := database.GetUserFile(fileInfo.Name, userSession.UserID.String())
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
upload, err := handleNewUpload(userSession, fileInfo) upload, err := handleNewUpload(userSession, fileInfo)
@ -68,7 +74,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
return return
} }
info, err := GetUploadInfo(fileData.ID.String()) info, err := database.GetUploadInfo(fileData.ID.String())
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
return return
@ -81,15 +87,6 @@ func POST(w http.ResponseWriter, r *http.Request) {
respondJSON(w, info) respondJSON(w, info)
} }
func getFile(name string, ownerID uuid.UUID) (models.File, error) {
var data models.File
err := db.DB.Table("files").Where("name = ? AND owner_id = ?", name, ownerID).First(&data).Error
if err != nil {
return data, err
}
return data, nil
}
func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded, error) { func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded, error) {
uploadDir := "uploads" uploadDir := "uploads"
if _, err := os.Stat(uploadDir); os.IsNotExist(err) { if _, err := os.Stat(uploadDir); os.IsNotExist(err) {
@ -124,7 +121,8 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded
Size: file.Size, Size: file.Size,
Downloaded: 0, Downloaded: 0,
} }
err = db.DB.Create(&newFile).Error
err = database.CreateFile(&newFile)
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
return models.FilesUploaded{}, err return models.FilesUploaded{}, err
@ -140,7 +138,7 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded
Done: false, Done: false,
} }
err = db.DB.Create(&filesUploaded).Error err = database.CreateUploadInfo(filesUploaded)
if err != nil { if err != nil {
log.Error(err.Error()) log.Error(err.Error())
return models.FilesUploaded{}, err return models.FilesUploaded{}, err
@ -148,15 +146,6 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded
return filesUploaded, nil return filesUploaded, nil
} }
func GetUploadInfo(fileID string) (*models.FilesUploaded, error) {
var data *models.FilesUploaded
err := db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).First(&data).Error
if err != nil {
return data, err
}
return data, nil
}
func respondJSON(w http.ResponseWriter, data interface{}) { func respondJSON(w http.ResponseWriter, data interface{}) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(data); err != nil { if err := json.NewEncoder(w).Encode(data); err != nil {

View File

@ -2,6 +2,8 @@ package uploadHandler
import ( import (
"errors" "errors"
"github.com/fossyy/filekeeper/db"
"github.com/fossyy/filekeeper/utils"
"io" "io"
"net/http" "net/http"
"os" "os"
@ -9,8 +11,6 @@ import (
"strconv" "strconv"
"sync" "sync"
"github.com/fossyy/filekeeper/db"
"github.com/fossyy/filekeeper/handler/upload/initialisation"
"github.com/fossyy/filekeeper/logger" "github.com/fossyy/filekeeper/logger"
"github.com/fossyy/filekeeper/middleware" "github.com/fossyy/filekeeper/middleware"
"github.com/fossyy/filekeeper/session" "github.com/fossyy/filekeeper/session"
@ -20,8 +20,13 @@ import (
var log *logger.AggregatedLogger var log *logger.AggregatedLogger
var mu sync.Mutex var mu sync.Mutex
// TESTTING VAR
var database db.Database
func init() { func init() {
log = logger.Logger() log = logger.Logger()
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
} }
func GET(w http.ResponseWriter, r *http.Request) { func GET(w http.ResponseWriter, r *http.Request) {
@ -57,7 +62,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
userSession := middleware.GetUser(storeSession) userSession := middleware.GetUser(storeSession)
if r.FormValue("done") == "true" { if r.FormValue("done") == "true" {
finalizeFileUpload(fileID) database.FinalizeFileUpload(fileID)
return return
} }
@ -67,7 +72,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
return return
} }
file, err := initialisation.GetUploadInfo(fileID) file, err := database.GetUploadInfo(fileID)
if err != nil { if err != nil {
log.Error("error getting upload info: " + err.Error()) log.Error("error getting upload info: " + err.Error())
return return
@ -105,13 +110,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
return return
} }
updateIndex(index, fileID) database.UpdateUpdateIndex(index, fileID)
}
func finalizeFileUpload(fileID string) {
db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{
"Done": true,
})
} }
func createUploadDirectory(uploadDir string) error { func createUploadDirectory(uploadDir string) error {
@ -123,12 +122,6 @@ func createUploadDirectory(uploadDir string) error {
return nil return nil
} }
func updateIndex(index int, fileID string) {
db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{
"Uploaded": index,
})
}
func handleCookieError(w http.ResponseWriter, r *http.Request, err error) { func handleCookieError(w http.ResponseWriter, r *http.Request, err error) {
if errors.Is(err, http.ErrNoCookie) { if errors.Is(err, http.ErrNoCookie) {
http.Redirect(w, r, "/signin", http.StatusSeeOther) http.Redirect(w, r, "/signin", http.StatusSeeOther)