Separate database initialization into NewMYSQLdb function
This commit is contained in:
@ -9,15 +9,19 @@ import (
|
||||
"github.com/fossyy/filekeeper/middleware"
|
||||
"github.com/fossyy/filekeeper/session"
|
||||
"github.com/fossyy/filekeeper/types"
|
||||
"github.com/fossyy/filekeeper/types/models"
|
||||
"github.com/fossyy/filekeeper/utils"
|
||||
downloadView "github.com/fossyy/filekeeper/view/download"
|
||||
)
|
||||
|
||||
var log *logger.AggregatedLogger
|
||||
|
||||
// TESTTING VAR
|
||||
var database db.Database
|
||||
|
||||
func init() {
|
||||
log = logger.Logger()
|
||||
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
|
||||
|
||||
}
|
||||
|
||||
func GET(w http.ResponseWriter, r *http.Request) {
|
||||
@ -41,8 +45,12 @@ func GET(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
userSession := middleware.GetUser(storeSession)
|
||||
|
||||
var files []models.File
|
||||
db.DB.Table("files").Where("owner_id = ?", userSession.UserID).Find(&files)
|
||||
files, err := database.GetFiles(userSession.UserID.String())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var filesData []types.FileData
|
||||
for i := 0; i < len(files); i++ {
|
||||
filesData = append(filesData, types.FileData{
|
||||
|
@ -1,29 +1,33 @@
|
||||
package downloadFileHandler
|
||||
|
||||
import (
|
||||
"github.com/fossyy/filekeeper/utils"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/fossyy/filekeeper/db"
|
||||
"github.com/fossyy/filekeeper/logger"
|
||||
"github.com/fossyy/filekeeper/types/models"
|
||||
)
|
||||
|
||||
var log *logger.AggregatedLogger
|
||||
|
||||
// TESTTING VAR
|
||||
var database db.Database
|
||||
|
||||
func init() {
|
||||
log = logger.Logger()
|
||||
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
|
||||
|
||||
}
|
||||
|
||||
func GET(w http.ResponseWriter, r *http.Request) {
|
||||
fileID := r.PathValue("id")
|
||||
|
||||
var file models.File
|
||||
err := db.DB.Table("files").Where("id = ?", fileID).First(&file).Error
|
||||
file, err := database.GetFile(fileID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
log.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
uploadDir := "uploads"
|
||||
@ -42,6 +46,7 @@ func GET(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
log.Error(err.Error())
|
||||
return
|
||||
}
|
||||
defer openFile.Close()
|
||||
|
||||
@ -49,6 +54,7 @@ func GET(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
log.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Disposition", "attachment; filename="+stat.Name())
|
||||
|
@ -33,12 +33,17 @@ var mailServer *email.SmtpServer
|
||||
var ListForgotPassword map[string]*ForgotPassword
|
||||
var UserForgotPassword = make(map[string]string)
|
||||
|
||||
// TESTTING VAR
|
||||
var database db.Database
|
||||
|
||||
func init() {
|
||||
log = logger.Logger()
|
||||
ListForgotPassword = make(map[string]*ForgotPassword)
|
||||
smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT"))
|
||||
mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD"))
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
//TESTING
|
||||
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
|
||||
go func() {
|
||||
for {
|
||||
<-ticker.C
|
||||
@ -84,8 +89,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
emailForm := r.Form.Get("email")
|
||||
|
||||
var user models.User
|
||||
err = db.DB.Table("users").Where("email = ?", emailForm).First(&user).Error
|
||||
user, err := database.GetUser(emailForm)
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
component := forgotPasswordView.Main(fmt.Sprintf("Account with this email address %s is not found", emailForm), types.Message{
|
||||
Code: 0,
|
||||
@ -100,7 +104,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = verifyForgot(&user)
|
||||
err = verifyForgot(user)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
log.Error(err.Error())
|
||||
|
@ -16,8 +16,14 @@ import (
|
||||
|
||||
var log *logger.AggregatedLogger
|
||||
|
||||
// TESTTING VAR
|
||||
var database db.Database
|
||||
|
||||
func init() {
|
||||
log = logger.Logger()
|
||||
//TESTING
|
||||
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
|
||||
|
||||
}
|
||||
|
||||
func GET(w http.ResponseWriter, r *http.Request) {
|
||||
@ -84,7 +90,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = db.DB.Table("users").Where("email = ?", data.User.Email).Update("password", hashedPassword).Error
|
||||
err = database.UpdateUserPassword(data.User.Email, hashedPassword)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
log.Error(err.Error())
|
||||
|
@ -3,7 +3,6 @@ package signupHandler
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@ -19,7 +18,6 @@ import (
|
||||
emailView "github.com/fossyy/filekeeper/view/email"
|
||||
signupView "github.com/fossyy/filekeeper/view/signup"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UnverifiedUser struct {
|
||||
@ -34,12 +32,16 @@ var mailServer *email.SmtpServer
|
||||
var VerifyUser map[string]*UnverifiedUser
|
||||
var VerifyEmail map[string]string
|
||||
|
||||
// TESTTING VAR
|
||||
var database db.Database
|
||||
|
||||
func init() {
|
||||
log = logger.Logger()
|
||||
smtpPort, _ := strconv.Atoi(utils.Getenv("SMTP_PORT"))
|
||||
mailServer = email.NewSmtpServer(utils.Getenv("SMTP_HOST"), smtpPort, utils.Getenv("SMTP_USER"), utils.Getenv("SMTP_PASSWORD"))
|
||||
VerifyUser = make(map[string]*UnverifiedUser)
|
||||
VerifyEmail = make(map[string]string)
|
||||
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
go func() {
|
||||
@ -110,34 +112,28 @@ func POST(w http.ResponseWriter, r *http.Request) {
|
||||
Password: hashedPassword,
|
||||
}
|
||||
|
||||
var data models.User
|
||||
err = db.DB.Table("users").Where("email = ? OR username = ?", userEmail, username).First(&data).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
err = verifyEmail(&newUser)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
log.Error(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
component := signupView.EmailSend("Sign up Page")
|
||||
err = component.Render(r.Context(), w)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
log.Error(err.Error())
|
||||
return
|
||||
}
|
||||
if registered := database.IsUserRegistered(userEmail, username); registered {
|
||||
component := signupView.Main("Sign up Page", types.Message{
|
||||
Code: 0,
|
||||
Message: "Email or Username has been registered",
|
||||
})
|
||||
err = component.Render(r.Context(), w)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
log.Error(err.Error())
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err = verifyEmail(&newUser)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
log.Error(err.Error())
|
||||
return
|
||||
}
|
||||
component := signupView.Main("Sign up Page", types.Message{
|
||||
Code: 0,
|
||||
Message: "Email or Username has been registered",
|
||||
})
|
||||
|
||||
component := signupView.EmailSend("Sign up Page")
|
||||
err = component.Render(r.Context(), w)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package signupVerifyHandler
|
||||
|
||||
import (
|
||||
"github.com/fossyy/filekeeper/utils"
|
||||
"net/http"
|
||||
|
||||
"github.com/fossyy/filekeeper/db"
|
||||
@ -12,8 +13,13 @@ import (
|
||||
|
||||
var log *logger.AggregatedLogger
|
||||
|
||||
// TESTTING VAR
|
||||
var database db.Database
|
||||
|
||||
func init() {
|
||||
log = logger.Logger()
|
||||
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
|
||||
|
||||
}
|
||||
|
||||
func GET(w http.ResponseWriter, r *http.Request) {
|
||||
@ -25,8 +31,7 @@ func GET(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err := db.DB.Create(&data.User).Error
|
||||
|
||||
err := database.CreateUser(data.User)
|
||||
if err != nil {
|
||||
component := signupView.Main("Sign up Page", types.Message{
|
||||
Code: 0,
|
||||
|
@ -3,6 +3,7 @@ package initialisation
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/fossyy/filekeeper/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
@ -20,8 +21,13 @@ import (
|
||||
|
||||
var log *logger.AggregatedLogger
|
||||
|
||||
// TESTTING VAR
|
||||
var database db.Database
|
||||
|
||||
func init() {
|
||||
log = logger.Logger()
|
||||
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
|
||||
|
||||
}
|
||||
|
||||
func POST(w http.ResponseWriter, r *http.Request) {
|
||||
@ -53,7 +59,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
fileData, err := getFile(fileInfo.Name, userSession.UserID)
|
||||
fileData, err := database.GetUserFile(fileInfo.Name, userSession.UserID.String())
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
upload, err := handleNewUpload(userSession, fileInfo)
|
||||
@ -68,7 +74,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
info, err := GetUploadInfo(fileData.ID.String())
|
||||
info, err := database.GetUploadInfo(fileData.ID.String())
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return
|
||||
@ -81,15 +87,6 @@ func POST(w http.ResponseWriter, r *http.Request) {
|
||||
respondJSON(w, info)
|
||||
}
|
||||
|
||||
func getFile(name string, ownerID uuid.UUID) (models.File, error) {
|
||||
var data models.File
|
||||
err := db.DB.Table("files").Where("name = ? AND owner_id = ?", name, ownerID).First(&data).Error
|
||||
if err != nil {
|
||||
return data, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded, error) {
|
||||
uploadDir := "uploads"
|
||||
if _, err := os.Stat(uploadDir); os.IsNotExist(err) {
|
||||
@ -124,7 +121,8 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded
|
||||
Size: file.Size,
|
||||
Downloaded: 0,
|
||||
}
|
||||
err = db.DB.Create(&newFile).Error
|
||||
|
||||
err = database.CreateFile(&newFile)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return models.FilesUploaded{}, err
|
||||
@ -140,7 +138,7 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded
|
||||
Done: false,
|
||||
}
|
||||
|
||||
err = db.DB.Create(&filesUploaded).Error
|
||||
err = database.CreateUploadInfo(filesUploaded)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return models.FilesUploaded{}, err
|
||||
@ -148,15 +146,6 @@ func handleNewUpload(user types.User, file types.FileInfo) (models.FilesUploaded
|
||||
return filesUploaded, nil
|
||||
}
|
||||
|
||||
func GetUploadInfo(fileID string) (*models.FilesUploaded, error) {
|
||||
var data *models.FilesUploaded
|
||||
err := db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).First(&data).Error
|
||||
if err != nil {
|
||||
return data, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func respondJSON(w http.ResponseWriter, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
|
@ -2,6 +2,8 @@ package uploadHandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/fossyy/filekeeper/db"
|
||||
"github.com/fossyy/filekeeper/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
@ -9,8 +11,6 @@ import (
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/fossyy/filekeeper/db"
|
||||
"github.com/fossyy/filekeeper/handler/upload/initialisation"
|
||||
"github.com/fossyy/filekeeper/logger"
|
||||
"github.com/fossyy/filekeeper/middleware"
|
||||
"github.com/fossyy/filekeeper/session"
|
||||
@ -20,8 +20,13 @@ import (
|
||||
var log *logger.AggregatedLogger
|
||||
var mu sync.Mutex
|
||||
|
||||
// TESTTING VAR
|
||||
var database db.Database
|
||||
|
||||
func init() {
|
||||
log = logger.Logger()
|
||||
database = db.NewMYSQLdb(utils.Getenv("DB_USERNAME"), utils.Getenv("DB_PASSWORD"), utils.Getenv("DB_HOST"), utils.Getenv("DB_PORT"), utils.Getenv("DB_NAME"))
|
||||
|
||||
}
|
||||
|
||||
func GET(w http.ResponseWriter, r *http.Request) {
|
||||
@ -57,7 +62,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
|
||||
userSession := middleware.GetUser(storeSession)
|
||||
|
||||
if r.FormValue("done") == "true" {
|
||||
finalizeFileUpload(fileID)
|
||||
database.FinalizeFileUpload(fileID)
|
||||
return
|
||||
}
|
||||
|
||||
@ -67,7 +72,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
file, err := initialisation.GetUploadInfo(fileID)
|
||||
file, err := database.GetUploadInfo(fileID)
|
||||
if err != nil {
|
||||
log.Error("error getting upload info: " + err.Error())
|
||||
return
|
||||
@ -105,13 +110,7 @@ func POST(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
updateIndex(index, fileID)
|
||||
}
|
||||
|
||||
func finalizeFileUpload(fileID string) {
|
||||
db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{
|
||||
"Done": true,
|
||||
})
|
||||
database.UpdateUpdateIndex(index, fileID)
|
||||
}
|
||||
|
||||
func createUploadDirectory(uploadDir string) error {
|
||||
@ -123,12 +122,6 @@ func createUploadDirectory(uploadDir string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateIndex(index int, fileID string) {
|
||||
db.DB.Table("files_uploadeds").Where("file_id = ?", fileID).Updates(map[string]interface{}{
|
||||
"Uploaded": index,
|
||||
})
|
||||
}
|
||||
|
||||
func handleCookieError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if errors.Is(err, http.ErrNoCookie) {
|
||||
http.Redirect(w, r, "/signin", http.StatusSeeOther)
|
||||
|
Reference in New Issue
Block a user