From db1ec0a771e2fcb94608e280dad8ae005553c20e Mon Sep 17 00:00:00 2001 From: deo002 <oberoidearsh@gmail.com> Date: Tue, 21 May 2024 16:17:47 +0530 Subject: [PATCH] refactor(obligation_maps): Add transactions and refactor code --- pkg/api/obligationmap.go | 267 ++++++++++++++++++++++++++------------- 1 file changed, 181 insertions(+), 86 deletions(-) diff --git a/pkg/api/obligationmap.go b/pkg/api/obligationmap.go index b982f72..a9982a6 100644 --- a/pkg/api/obligationmap.go +++ b/pkg/api/obligationmap.go @@ -6,6 +6,7 @@ package api import ( + "errors" "fmt" "net/http" "strconv" @@ -40,9 +41,8 @@ func GetObligationMapByTopic(c *gin.Context) { var shortnameList []string topic := c.Param("topic") - query := db.DB.Model(&obligation) - if err := query.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil { + if err := db.DB.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil { er := models.LicenseError{ Status: http.StatusNotFound, Message: fmt.Sprintf("obligation with topic '%s' not found", topic), @@ -54,7 +54,7 @@ func GetObligationMapByTopic(c *gin.Context) { return } - if err := getObligationMapsForObligation(obligation.Id, &obMap).Error; err != nil { + if err := db.DB.Where(models.ObligationMap{ObligationPk: obligation.Id}).Find(&obMap).Error; err != nil { er := models.LicenseError{ Status: http.StatusNotFound, Message: fmt.Sprintf("Obligation map not found for topic '%s'", topic), @@ -68,8 +68,17 @@ func GetObligationMapByTopic(c *gin.Context) { for i := 0; i < len(obMap); i++ { var license models.LicenseDB - licenseQuery := db.DB.Model(&license) - licenseQuery.Where(models.LicenseDB{Id: obMap[i].RfPk}).First(&license) + if err := db.DB.Where(models.LicenseDB{Id: obMap[i].RfPk}).First(&license).Error; err != nil { + er := models.LicenseError{ + Status: http.StatusNotFound, + Message: "Unable to fetch license shortnames", + Error: err.Error(), + Path: c.Request.URL.Path, + Timestamp: time.Now().Format(time.RFC3339), + } + c.JSON(http.StatusNotFound, er) + return + } shortnameList = append(shortnameList, license.Shortname) } @@ -88,10 +97,6 @@ func GetObligationMapByTopic(c *gin.Context) { c.JSON(http.StatusOK, res) } -func getObligationMapsForObligation(obligationId int64, obMap *[]models.ObligationMap) *gorm.DB { - return db.DB.Model(&obMap).Where(models.ObligationMap{ObligationPk: obligationId}).Find(&obMap) -} - // GetObligationMapByLicense retrieves obligation maps for given license shortname // // @Summary Get maps for a license @@ -111,9 +116,8 @@ func GetObligationMapByLicense(c *gin.Context) { var resObMapList []models.ObligationMapUser licenseShortName := c.Param("license") - query := db.DB.Model(&license) - if err := query.Where(models.LicenseDB{Shortname: licenseShortName}).First(&license).Error; err != nil { + if err := db.DB.Where(models.LicenseDB{Shortname: licenseShortName}).First(&license).Error; err != nil { er := models.LicenseError{ Status: http.StatusNotFound, Message: fmt.Sprintf("license with shortname '%s' not found", licenseShortName), @@ -125,9 +129,7 @@ func GetObligationMapByLicense(c *gin.Context) { return } - query = db.DB.Model(&obMap) - - if err := query.Where(models.ObligationMap{RfPk: license.Id}).Find(&obMap).Error; err != nil { + if err := db.DB.Where(models.ObligationMap{RfPk: license.Id}).Find(&obMap).Error; err != nil { er := models.LicenseError{ Status: http.StatusNotFound, Message: fmt.Sprintf("Obligation map not found for license '%s'", licenseShortName), @@ -141,8 +143,17 @@ func GetObligationMapByLicense(c *gin.Context) { for i := 0; i < len(obMap); i++ { var obligation models.Obligation - obligationQuery := db.DB.Model(&obligation) - obligationQuery.Where(models.Obligation{Id: obMap[i].ObligationPk}).First(&obligation) + if err := db.DB.Where(models.Obligation{Id: obMap[i].ObligationPk}).First(&obligation).Error; err != nil { + er := models.LicenseError{ + Status: http.StatusNotFound, + Message: fmt.Sprintf("Unable to fetch obligations linked with license '%s'", licenseShortName), + Error: err.Error(), + Path: c.Request.URL.Path, + Timestamp: time.Now().Format(time.RFC3339), + } + c.JSON(http.StatusNotFound, er) + return + } resObMapList = append(resObMapList, models.ObligationMapUser{ Topic: obligation.Topic, Shortnames: []string{licenseShortName}, @@ -195,9 +206,7 @@ func PatchObligationMap(c *gin.Context) { return } - query := db.DB.Model(&obligation) - - if err := query.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil { + if err := db.DB.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil { er := models.LicenseError{ Status: http.StatusNotFound, Message: fmt.Sprintf("obligation with topic '%s' not found", topic), @@ -212,9 +221,7 @@ func PatchObligationMap(c *gin.Context) { for i := 0; i < len(obMapInput.MapInput); i++ { var license models.LicenseDB var obligationMap models.ObligationMap - err := db.DB.Model(&license).Where(&models.LicenseDB{Shortname: obMapInput.MapInput[i].Shortname}).First( - &license).Error - if err != nil { + if err := db.DB.Where(&models.LicenseDB{Shortname: obMapInput.MapInput[i].Shortname}).First(&license).Error; err != nil { er := models.LicenseError{ Status: http.StatusNotFound, Message: fmt.Sprintf("license with shortname '%s' not found", obMapInput.MapInput[i].Shortname), @@ -225,13 +232,25 @@ func PatchObligationMap(c *gin.Context) { c.JSON(http.StatusNotFound, er) return } - if err := db.DB.Model(&obligationMap).Where(&models.ObligationMap{ObligationPk: obligation.Id, - RfPk: license.Id}).First(&obligationMap).Error; err != nil { + if err := db.DB.Where(&models.ObligationMap{ObligationPk: obligation.Id, RfPk: license.Id}).First(&obligationMap).Error; err != nil { // License not in map - if obMapInput.MapInput[i].Add { - // Add to insert slice - insertLicenseIds = append(insertLicenseIds, license.Id) + if errors.Is(err, gorm.ErrRecordNotFound) { + if obMapInput.MapInput[i].Add { + // Add to insert slice + insertLicenseIds = append(insertLicenseIds, license.Id) + } + } else { + er := models.LicenseError{ + Status: http.StatusInternalServerError, + Message: fmt.Sprintf("unable to fetch obligation maps for obligation with topic '%s'", obligation.Topic), + Error: err.Error(), + Path: c.Request.URL.Path, + Timestamp: time.Now().Format(time.RFC3339), + } + c.JSON(http.StatusInternalServerError, er) + return } + } else { // License in map if !obMapInput.MapInput[i].Add { @@ -243,7 +262,18 @@ func PatchObligationMap(c *gin.Context) { username := c.GetString("username") - res := performObligationMapActions(username, obligation, removeLicenseIds, insertLicenseIds) + res, err := PerformObligationMapActions(username, obligation, removeLicenseIds, insertLicenseIds) + if err != nil { + er := models.LicenseError{ + Status: http.StatusInternalServerError, + Message: "Something went wrong", + Error: err.Error(), + Path: c.Request.URL.Path, + Timestamp: time.Now().Format(time.RFC3339), + } + c.JSON(http.StatusInternalServerError, er) + return + } c.JSON(http.StatusOK, res) } @@ -266,12 +296,23 @@ func PatchObligationMap(c *gin.Context) { func UpdateLicenseInObligationMap(c *gin.Context) { var obligation models.Obligation var obMapInput models.LicenseShortnamesInput - var oldObMaps []models.ObligationMap var removeLicenseIds []int64 var insertLicenseIds []int64 topic := c.Param("topic") + if err := db.DB.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil { + er := models.LicenseError{ + Status: http.StatusNotFound, + Message: fmt.Sprintf("obligation with topic '%s' not found", topic), + Error: err.Error(), + Path: c.Request.URL.Path, + Timestamp: time.Now().Format(time.RFC3339), + } + c.JSON(http.StatusNotFound, er) + return + } + if err := c.ShouldBindJSON(&obMapInput); err != nil { er := models.LicenseError{ Status: http.StatusBadRequest, @@ -286,12 +327,17 @@ func UpdateLicenseInObligationMap(c *gin.Context) { obMapInput.Shortnames = slices.Compact(obMapInput.Shortnames) - query := db.DB.Model(&obligation) + username := c.GetString("username") - if err := query.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil { + if err := GenerateDiffOfLicenses(c, &obligation, obMapInput.Shortnames, &removeLicenseIds, &insertLicenseIds); err != nil { + return + } + + res, err := PerformObligationMapActions(username, obligation, removeLicenseIds, insertLicenseIds) + if err != nil { er := models.LicenseError{ - Status: http.StatusNotFound, - Message: fmt.Sprintf("obligation with topic '%s' not found", topic), + Status: http.StatusInternalServerError, + Message: "Something went wrong", Error: err.Error(), Path: c.Request.URL.Path, Timestamp: time.Now().Format(time.RFC3339), @@ -300,54 +346,79 @@ func UpdateLicenseInObligationMap(c *gin.Context) { return } - getObligationMapsForObligation(obligation.Id, &oldObMaps) + c.JSON(http.StatusOK, res) +} + +// GenerateDiffOfLicenses calculates diff from the obligation maps list in database and the list provided by the user to determine the licenses to be +// inserted and the licenses to be removed. Basically, it replaces the list present in database by the list given by the user. +func GenerateDiffOfLicenses(c *gin.Context, obligation *models.Obligation, inputShortnames []string, removeLicenseIds, insertLicenseIds *[]int64) error { + var oldObMaps []models.ObligationMap + if err := db.DB.Where(models.ObligationMap{ObligationPk: obligation.Id}).Find(&oldObMaps).Error; err != nil { + er := models.LicenseError{ + Status: http.StatusNotFound, + Message: fmt.Sprintf("obligation maps for obligation with topic '%s' not found", obligation.Topic), + Error: err.Error(), + Path: c.Request.URL.Path, + Timestamp: time.Now().Format(time.RFC3339), + } + c.JSON(http.StatusNotFound, er) + return err + } for i := 0; i < len(oldObMaps); i++ { - removeLicenseIds = append(removeLicenseIds, oldObMaps[i].RfPk) + *removeLicenseIds = append(*removeLicenseIds, oldObMaps[i].RfPk) } - for i := 0; i < len(obMapInput.Shortnames); i++ { + for i := 0; i < len(inputShortnames); i++ { var license models.LicenseDB var obligationMap models.ObligationMap - err := db.DB.Model(&license).Where(&models.LicenseDB{Shortname: obMapInput.Shortnames[i]}).First(&license).Error - if err != nil { + if err := db.DB.Where(&models.LicenseDB{Shortname: inputShortnames[i]}).First(&license).Error; err != nil { er := models.LicenseError{ Status: http.StatusNotFound, - Message: fmt.Sprintf("license with shortname '%s' not found", obMapInput.Shortnames[i]), + Message: fmt.Sprintf("license with shortname '%s' not found", inputShortnames[i]), Error: err.Error(), Path: c.Request.URL.Path, Timestamp: time.Now().Format(time.RFC3339), } c.JSON(http.StatusNotFound, er) - return + return err } - if err := db.DB.Model(&obligationMap).Where(&models.ObligationMap{ObligationPk: obligation.Id, - RfPk: license.Id}).First(&obligationMap).Error; err != nil { + if err := db.DB.Where(&models.ObligationMap{ObligationPk: obligation.Id, RfPk: license.Id}).First(&obligationMap).Error; err != nil { // License not in map, add to insert slice - insertLicenseIds = append(insertLicenseIds, license.Id) + if errors.Is(err, gorm.ErrRecordNotFound) { + *insertLicenseIds = append(*insertLicenseIds, license.Id) + } else { + er := models.LicenseError{ + Status: http.StatusInternalServerError, + Message: fmt.Sprintf("unable to fetch obligation maps for obligation with topic '%s'", obligation.Topic), + Error: err.Error(), + Path: c.Request.URL.Path, + Timestamp: time.Now().Format(time.RFC3339), + } + c.JSON(http.StatusInternalServerError, er) + return err + } } // License in request, remove from removal slice - removeLicenseIds = removeFromSlice(removeLicenseIds, license.Id) + *removeLicenseIds = removeFromSlice(*removeLicenseIds, license.Id) } - username := c.GetString("username") - - res := performObligationMapActions(username, obligation, removeLicenseIds, insertLicenseIds) - - c.JSON(http.StatusOK, res) + return nil } -// performObligationMapActions performs the actions for ObligationMap endpoint PATCH and PUT calls. +// PerformObligationMapActions performs the actions for ObligationMap endpoint PATCH and PUT calls. // It takes the input of obligation which is being modified, list of licenses to be removed and added, // and the user making the changes. The function computes the changelog and returns the response. -func performObligationMapActions(username string, obligation models.Obligation, removeLicenseIds []int64, - insertLicenseIds []int64) models.ObligationMapResponse { +func PerformObligationMapActions(username string, obligation models.Obligation, removeLicenseIds []int64, + insertLicenseIds []int64) (*models.ObligationMapResponse, error) { var oldObMaps []models.ObligationMap var newObMaps []models.ObligationMap var removeObMaps []models.ObligationMap var insertObMaps []models.ObligationMap - getObligationMapsForObligation(obligation.Id, &oldObMaps) + if err := db.DB.Where(models.ObligationMap{ObligationPk: obligation.Id}).Find(&oldObMaps).Error; err != nil { + return nil, err + } for i := 0; i < len(removeLicenseIds); i++ { deleteItem := models.ObligationMap{ @@ -355,7 +426,9 @@ func performObligationMapActions(username string, obligation models.Obligation, RfPk: removeLicenseIds[i], } // Find the primary key to make delete faster - db.DB.Where(&deleteItem).First(&deleteItem) + if err := db.DB.Where(&deleteItem).First(&deleteItem).Error; err != nil { + return nil, err + } removeObMaps = append(removeObMaps, deleteItem) } for i := 0; i < len(insertLicenseIds); i++ { @@ -365,44 +438,48 @@ func performObligationMapActions(username string, obligation models.Obligation, }) } - if len(removeObMaps) > 0 { - // Bulk delete removeObMaps from DB - db.DB.Delete(&removeObMaps) - } - if len(insertObMaps) > 0 { - // Bulk create insertObMaps in DB - db.DB.Create(&insertObMaps) + if err := db.DB.Transaction(func(tx *gorm.DB) error { + if len(removeObMaps) > 0 { + // Bulk delete removeObMaps from DB + if err := tx.Delete(&removeObMaps).Error; err != nil { + return err + } + } + if len(insertObMaps) > 0 { + // Bulk create insertObMaps in DB + if err := tx.Create(&insertObMaps).Error; err != nil { + return err + } + } + + if err := tx.Where(models.ObligationMap{ObligationPk: obligation.Id}).Find(&newObMaps).Error; err != nil { + return err + } + + return createObligationMapChangelog(tx, username, oldObMaps, newObMaps, &obligation) + + }); err != nil { + return nil, err } - getObligationMapsForObligation(obligation.Id, &newObMaps) + obMap, err := createObligationMapUser(obligation, newObMaps) + if err != nil { + return nil, err + } res := models.ObligationMapResponse{ - Data: []models.ObligationMapUser{createObligationMapUser(obligation, newObMaps)}, + Data: []models.ObligationMapUser{*obMap}, Status: http.StatusOK, Meta: models.PaginationMeta{ ResourceCount: 1, }, } - var user models.User - db.DB.Where(models.User{Username: username}).First(&user) - audit := models.Audit{ - UserId: user.Id, - TypeId: obligation.Id, - Timestamp: time.Now(), - Type: "obligation_map", - } - - db.DB.Create(&audit) - - change := createObligationMapChangelog(oldObMaps, newObMaps, audit) - db.DB.Create(&change) - return res + return &res, nil } // createObligationMapChangelog creates the changelog for the obligation map changes. -func createObligationMapChangelog(oldObMaps []models.ObligationMap, newObMaps []models.ObligationMap, - audit models.Audit) models.ChangeLog { +func createObligationMapChangelog(tx *gorm.DB, username string, oldObMaps, newObMaps []models.ObligationMap, obligation *models.Obligation) error { var oldLicenses []string var newLicenses []string @@ -416,12 +493,29 @@ func createObligationMapChangelog(oldObMaps []models.ObligationMap, newObMaps [] oldVal := strings.Join(oldLicenses, ",") newVal := strings.Join(newLicenses, ",") change := models.ChangeLog{ - AuditId: audit.Id, Field: "RfPk", OldValue: &oldVal, UpdatedValue: &newVal, } - return change + + var user models.User + if err := tx.Where(models.User{Username: username}).First(&user).Error; err != nil { + return err + } + + audit := models.Audit{ + UserId: user.Id, + TypeId: obligation.Id, + Timestamp: time.Now(), + Type: "license", + ChangeLogs: []models.ChangeLog{change}, + } + + if err := tx.Create(&audit).Error; err != nil { + return err + } + + return nil } // removeFromSlice removes the item from the slice if it exists. @@ -433,16 +527,17 @@ func removeFromSlice[E string | int64](slice []E, item E) []E { } // createObligationMapUser creates the response data for the obligation map endpoint. -func createObligationMapUser(obligation models.Obligation, obMaps []models.ObligationMap) models.ObligationMapUser { +func createObligationMapUser(obligation models.Obligation, obMaps []models.ObligationMap) (*models.ObligationMapUser, error) { var shortnameList []string for i := 0; i < len(obMaps); i++ { var license models.LicenseDB - licenseQuery := db.DB.Model(&license) - licenseQuery.Where(models.LicenseDB{Id: obMaps[i].RfPk}).First(&license) + if err := db.DB.Where(models.LicenseDB{Id: obMaps[i].RfPk}).First(&license).Error; err != nil { + return nil, err + } shortnameList = append(shortnameList, license.Shortname) } - return models.ObligationMapUser{ + return &models.ObligationMapUser{ Topic: obligation.Topic, Shortnames: shortnameList, - } + }, nil }