diff --git a/db/database.go b/db/database.go index c5870c0..4eeabe2 100644 --- a/db/database.go +++ b/db/database.go @@ -261,6 +261,20 @@ func (db *mySQLdb) GetFiles(ownerID string) ([]*models.File, error) { return files, err } +func (db *mySQLdb) IncrementDownloadCount(fileID string) error { + var file models.File + err := db.DB.Table("files").Where("id = ?", fileID).First(&file).Error + if err != nil { + return err + } + file.Downloaded = file.Downloaded + 1 + err = db.DB.Updates(file).Error + if err != nil { + return err + } + return nil +} + func (db *mySQLdb) InitializeTotp(email string, secret string) error { var user models.User err := db.DB.Table("users").Where("email = ?", email).First(&user).Error @@ -398,6 +412,20 @@ func (db *postgresDB) GetFiles(ownerID string) ([]*models.File, error) { return files, err } +func (db *postgresDB) IncrementDownloadCount(fileID string) error { + var file models.File + err := db.DB.Table("files").Where("id = $1", fileID).First(&file).Error + if err != nil { + return err + } + file.Downloaded = file.Downloaded + 1 + err = db.DB.Updates(file).Error + if err != nil { + return err + } + return nil +} + func (db *postgresDB) InitializeTotp(email string, secret string) error { var user models.User err := db.DB.Table("users").Where("email = $1", email).First(&user).Error diff --git a/handler/file/download/download.go b/handler/file/download/download.go index 22176cc..5d7a9cb 100644 --- a/handler/file/download/download.go +++ b/handler/file/download/download.go @@ -68,6 +68,7 @@ func GET(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", fmt.Sprintf("%d", file.Size)) sendFileChunk(w, saveFolder, file, 0, int64(file.Size-1)) + return } func sendFileChunk(w http.ResponseWriter, saveFolder string, file *models.File, start, end int64) { @@ -130,6 +131,14 @@ func sendFileChunk(w http.ResponseWriter, saveFolder string, file *models.File, return } toSend -= int64(n) + if i == int64(file.TotalChunk)-1 && toSend == 0 { + err := app.Server.Database.IncrementDownloadCount(file.ID.String()) + if err != nil { + http.Error(w, fmt.Sprintf("Error writing chunk: %v", err), http.StatusInternalServerError) + app.Server.Logger.Error(err.Error()) + return + } + } } } } diff --git a/types/types.go b/types/types.go index 2a15117..636e7c5 100644 --- a/types/types.go +++ b/types/types.go @@ -59,6 +59,7 @@ type Database interface { GetFile(fileID string) (*models.File, error) GetUserFile(name string, ownerID string) (*models.File, error) GetFiles(ownerID string) ([]*models.File, error) + IncrementDownloadCount(fileID string) error InitializeTotp(email string, secret string) error }