From db1ec0a771e2fcb94608e280dad8ae005553c20e Mon Sep 17 00:00:00 2001
From: deo002 <oberoidearsh@gmail.com>
Date: Tue, 21 May 2024 16:17:47 +0530
Subject: [PATCH] refactor(obligation_maps): Add transactions and refactor code

---
 pkg/api/obligationmap.go | 267 ++++++++++++++++++++++++++-------------
 1 file changed, 181 insertions(+), 86 deletions(-)

diff --git a/pkg/api/obligationmap.go b/pkg/api/obligationmap.go
index b982f72..a9982a6 100644
--- a/pkg/api/obligationmap.go
+++ b/pkg/api/obligationmap.go
@@ -6,6 +6,7 @@
 package api
 
 import (
+	"errors"
 	"fmt"
 	"net/http"
 	"strconv"
@@ -40,9 +41,8 @@ func GetObligationMapByTopic(c *gin.Context) {
 	var shortnameList []string
 
 	topic := c.Param("topic")
-	query := db.DB.Model(&obligation)
 
-	if err := query.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil {
+	if err := db.DB.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil {
 		er := models.LicenseError{
 			Status:    http.StatusNotFound,
 			Message:   fmt.Sprintf("obligation with topic '%s' not found", topic),
@@ -54,7 +54,7 @@ func GetObligationMapByTopic(c *gin.Context) {
 		return
 	}
 
-	if err := getObligationMapsForObligation(obligation.Id, &obMap).Error; err != nil {
+	if err := db.DB.Where(models.ObligationMap{ObligationPk: obligation.Id}).Find(&obMap).Error; err != nil {
 		er := models.LicenseError{
 			Status:    http.StatusNotFound,
 			Message:   fmt.Sprintf("Obligation map not found for topic '%s'", topic),
@@ -68,8 +68,17 @@ func GetObligationMapByTopic(c *gin.Context) {
 
 	for i := 0; i < len(obMap); i++ {
 		var license models.LicenseDB
-		licenseQuery := db.DB.Model(&license)
-		licenseQuery.Where(models.LicenseDB{Id: obMap[i].RfPk}).First(&license)
+		if err := db.DB.Where(models.LicenseDB{Id: obMap[i].RfPk}).First(&license).Error; err != nil {
+			er := models.LicenseError{
+				Status:    http.StatusNotFound,
+				Message:   "Unable to fetch license shortnames",
+				Error:     err.Error(),
+				Path:      c.Request.URL.Path,
+				Timestamp: time.Now().Format(time.RFC3339),
+			}
+			c.JSON(http.StatusNotFound, er)
+			return
+		}
 		shortnameList = append(shortnameList, license.Shortname)
 	}
 
@@ -88,10 +97,6 @@ func GetObligationMapByTopic(c *gin.Context) {
 	c.JSON(http.StatusOK, res)
 }
 
-func getObligationMapsForObligation(obligationId int64, obMap *[]models.ObligationMap) *gorm.DB {
-	return db.DB.Model(&obMap).Where(models.ObligationMap{ObligationPk: obligationId}).Find(&obMap)
-}
-
 // GetObligationMapByLicense retrieves obligation maps for given license shortname
 //
 //	@Summary		Get maps for a license
@@ -111,9 +116,8 @@ func GetObligationMapByLicense(c *gin.Context) {
 	var resObMapList []models.ObligationMapUser
 
 	licenseShortName := c.Param("license")
-	query := db.DB.Model(&license)
 
-	if err := query.Where(models.LicenseDB{Shortname: licenseShortName}).First(&license).Error; err != nil {
+	if err := db.DB.Where(models.LicenseDB{Shortname: licenseShortName}).First(&license).Error; err != nil {
 		er := models.LicenseError{
 			Status:    http.StatusNotFound,
 			Message:   fmt.Sprintf("license with shortname '%s' not found", licenseShortName),
@@ -125,9 +129,7 @@ func GetObligationMapByLicense(c *gin.Context) {
 		return
 	}
 
-	query = db.DB.Model(&obMap)
-
-	if err := query.Where(models.ObligationMap{RfPk: license.Id}).Find(&obMap).Error; err != nil {
+	if err := db.DB.Where(models.ObligationMap{RfPk: license.Id}).Find(&obMap).Error; err != nil {
 		er := models.LicenseError{
 			Status:    http.StatusNotFound,
 			Message:   fmt.Sprintf("Obligation map not found for license '%s'", licenseShortName),
@@ -141,8 +143,17 @@ func GetObligationMapByLicense(c *gin.Context) {
 
 	for i := 0; i < len(obMap); i++ {
 		var obligation models.Obligation
-		obligationQuery := db.DB.Model(&obligation)
-		obligationQuery.Where(models.Obligation{Id: obMap[i].ObligationPk}).First(&obligation)
+		if err := db.DB.Where(models.Obligation{Id: obMap[i].ObligationPk}).First(&obligation).Error; err != nil {
+			er := models.LicenseError{
+				Status:    http.StatusNotFound,
+				Message:   fmt.Sprintf("Unable to fetch obligations linked with license '%s'", licenseShortName),
+				Error:     err.Error(),
+				Path:      c.Request.URL.Path,
+				Timestamp: time.Now().Format(time.RFC3339),
+			}
+			c.JSON(http.StatusNotFound, er)
+			return
+		}
 		resObMapList = append(resObMapList, models.ObligationMapUser{
 			Topic:      obligation.Topic,
 			Shortnames: []string{licenseShortName},
@@ -195,9 +206,7 @@ func PatchObligationMap(c *gin.Context) {
 		return
 	}
 
-	query := db.DB.Model(&obligation)
-
-	if err := query.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil {
+	if err := db.DB.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil {
 		er := models.LicenseError{
 			Status:    http.StatusNotFound,
 			Message:   fmt.Sprintf("obligation with topic '%s' not found", topic),
@@ -212,9 +221,7 @@ func PatchObligationMap(c *gin.Context) {
 	for i := 0; i < len(obMapInput.MapInput); i++ {
 		var license models.LicenseDB
 		var obligationMap models.ObligationMap
-		err := db.DB.Model(&license).Where(&models.LicenseDB{Shortname: obMapInput.MapInput[i].Shortname}).First(
-			&license).Error
-		if err != nil {
+		if err := db.DB.Where(&models.LicenseDB{Shortname: obMapInput.MapInput[i].Shortname}).First(&license).Error; err != nil {
 			er := models.LicenseError{
 				Status:    http.StatusNotFound,
 				Message:   fmt.Sprintf("license with shortname '%s' not found", obMapInput.MapInput[i].Shortname),
@@ -225,13 +232,25 @@ func PatchObligationMap(c *gin.Context) {
 			c.JSON(http.StatusNotFound, er)
 			return
 		}
-		if err := db.DB.Model(&obligationMap).Where(&models.ObligationMap{ObligationPk: obligation.Id,
-			RfPk: license.Id}).First(&obligationMap).Error; err != nil {
+		if err := db.DB.Where(&models.ObligationMap{ObligationPk: obligation.Id, RfPk: license.Id}).First(&obligationMap).Error; err != nil {
 			// License not in map
-			if obMapInput.MapInput[i].Add {
-				// Add to insert slice
-				insertLicenseIds = append(insertLicenseIds, license.Id)
+			if errors.Is(err, gorm.ErrRecordNotFound) {
+				if obMapInput.MapInput[i].Add {
+					// Add to insert slice
+					insertLicenseIds = append(insertLicenseIds, license.Id)
+				}
+			} else {
+				er := models.LicenseError{
+					Status:    http.StatusInternalServerError,
+					Message:   fmt.Sprintf("unable to fetch obligation maps for obligation with topic '%s'", obligation.Topic),
+					Error:     err.Error(),
+					Path:      c.Request.URL.Path,
+					Timestamp: time.Now().Format(time.RFC3339),
+				}
+				c.JSON(http.StatusInternalServerError, er)
+				return
 			}
+
 		} else {
 			// License in map
 			if !obMapInput.MapInput[i].Add {
@@ -243,7 +262,18 @@ func PatchObligationMap(c *gin.Context) {
 
 	username := c.GetString("username")
 
-	res := performObligationMapActions(username, obligation, removeLicenseIds, insertLicenseIds)
+	res, err := PerformObligationMapActions(username, obligation, removeLicenseIds, insertLicenseIds)
+	if err != nil {
+		er := models.LicenseError{
+			Status:    http.StatusInternalServerError,
+			Message:   "Something went wrong",
+			Error:     err.Error(),
+			Path:      c.Request.URL.Path,
+			Timestamp: time.Now().Format(time.RFC3339),
+		}
+		c.JSON(http.StatusInternalServerError, er)
+		return
+	}
 
 	c.JSON(http.StatusOK, res)
 }
@@ -266,12 +296,23 @@ func PatchObligationMap(c *gin.Context) {
 func UpdateLicenseInObligationMap(c *gin.Context) {
 	var obligation models.Obligation
 	var obMapInput models.LicenseShortnamesInput
-	var oldObMaps []models.ObligationMap
 	var removeLicenseIds []int64
 	var insertLicenseIds []int64
 
 	topic := c.Param("topic")
 
+	if err := db.DB.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil {
+		er := models.LicenseError{
+			Status:    http.StatusNotFound,
+			Message:   fmt.Sprintf("obligation with topic '%s' not found", topic),
+			Error:     err.Error(),
+			Path:      c.Request.URL.Path,
+			Timestamp: time.Now().Format(time.RFC3339),
+		}
+		c.JSON(http.StatusNotFound, er)
+		return
+	}
+
 	if err := c.ShouldBindJSON(&obMapInput); err != nil {
 		er := models.LicenseError{
 			Status:    http.StatusBadRequest,
@@ -286,12 +327,17 @@ func UpdateLicenseInObligationMap(c *gin.Context) {
 
 	obMapInput.Shortnames = slices.Compact(obMapInput.Shortnames)
 
-	query := db.DB.Model(&obligation)
+	username := c.GetString("username")
 
-	if err := query.Where(models.Obligation{Topic: topic}).First(&obligation).Error; err != nil {
+	if err := GenerateDiffOfLicenses(c, &obligation, obMapInput.Shortnames, &removeLicenseIds, &insertLicenseIds); err != nil {
+		return
+	}
+
+	res, err := PerformObligationMapActions(username, obligation, removeLicenseIds, insertLicenseIds)
+	if err != nil {
 		er := models.LicenseError{
-			Status:    http.StatusNotFound,
-			Message:   fmt.Sprintf("obligation with topic '%s' not found", topic),
+			Status:    http.StatusInternalServerError,
+			Message:   "Something went wrong",
 			Error:     err.Error(),
 			Path:      c.Request.URL.Path,
 			Timestamp: time.Now().Format(time.RFC3339),
@@ -300,54 +346,79 @@ func UpdateLicenseInObligationMap(c *gin.Context) {
 		return
 	}
 
-	getObligationMapsForObligation(obligation.Id, &oldObMaps)
+	c.JSON(http.StatusOK, res)
+}
+
+// GenerateDiffOfLicenses calculates diff from the obligation maps list in database and the list provided by the user to determine the licenses to be
+// inserted and the licenses to be removed. Basically, it replaces the list present in database by the list given by the user.
+func GenerateDiffOfLicenses(c *gin.Context, obligation *models.Obligation, inputShortnames []string, removeLicenseIds, insertLicenseIds *[]int64) error {
+	var oldObMaps []models.ObligationMap
+	if err := db.DB.Where(models.ObligationMap{ObligationPk: obligation.Id}).Find(&oldObMaps).Error; err != nil {
+		er := models.LicenseError{
+			Status:    http.StatusNotFound,
+			Message:   fmt.Sprintf("obligation maps for obligation with topic '%s' not found", obligation.Topic),
+			Error:     err.Error(),
+			Path:      c.Request.URL.Path,
+			Timestamp: time.Now().Format(time.RFC3339),
+		}
+		c.JSON(http.StatusNotFound, er)
+		return err
+	}
 
 	for i := 0; i < len(oldObMaps); i++ {
-		removeLicenseIds = append(removeLicenseIds, oldObMaps[i].RfPk)
+		*removeLicenseIds = append(*removeLicenseIds, oldObMaps[i].RfPk)
 	}
 
-	for i := 0; i < len(obMapInput.Shortnames); i++ {
+	for i := 0; i < len(inputShortnames); i++ {
 		var license models.LicenseDB
 		var obligationMap models.ObligationMap
-		err := db.DB.Model(&license).Where(&models.LicenseDB{Shortname: obMapInput.Shortnames[i]}).First(&license).Error
-		if err != nil {
+		if err := db.DB.Where(&models.LicenseDB{Shortname: inputShortnames[i]}).First(&license).Error; err != nil {
 			er := models.LicenseError{
 				Status:    http.StatusNotFound,
-				Message:   fmt.Sprintf("license with shortname '%s' not found", obMapInput.Shortnames[i]),
+				Message:   fmt.Sprintf("license with shortname '%s' not found", inputShortnames[i]),
 				Error:     err.Error(),
 				Path:      c.Request.URL.Path,
 				Timestamp: time.Now().Format(time.RFC3339),
 			}
 			c.JSON(http.StatusNotFound, er)
-			return
+			return err
 		}
-		if err := db.DB.Model(&obligationMap).Where(&models.ObligationMap{ObligationPk: obligation.Id,
-			RfPk: license.Id}).First(&obligationMap).Error; err != nil {
+		if err := db.DB.Where(&models.ObligationMap{ObligationPk: obligation.Id, RfPk: license.Id}).First(&obligationMap).Error; err != nil {
 			// License not in map, add to insert slice
-			insertLicenseIds = append(insertLicenseIds, license.Id)
+			if errors.Is(err, gorm.ErrRecordNotFound) {
+				*insertLicenseIds = append(*insertLicenseIds, license.Id)
+			} else {
+				er := models.LicenseError{
+					Status:    http.StatusInternalServerError,
+					Message:   fmt.Sprintf("unable to fetch obligation maps for obligation with topic '%s'", obligation.Topic),
+					Error:     err.Error(),
+					Path:      c.Request.URL.Path,
+					Timestamp: time.Now().Format(time.RFC3339),
+				}
+				c.JSON(http.StatusInternalServerError, er)
+				return err
+			}
 		}
 		// License in request, remove from removal slice
-		removeLicenseIds = removeFromSlice(removeLicenseIds, license.Id)
+		*removeLicenseIds = removeFromSlice(*removeLicenseIds, license.Id)
 	}
 
-	username := c.GetString("username")
-
-	res := performObligationMapActions(username, obligation, removeLicenseIds, insertLicenseIds)
-
-	c.JSON(http.StatusOK, res)
+	return nil
 }
 
-// performObligationMapActions performs the actions for ObligationMap endpoint PATCH and PUT calls.
+// PerformObligationMapActions performs the actions for ObligationMap endpoint PATCH and PUT calls.
 // It takes the input of obligation which is being modified, list of licenses to be removed and added,
 // and the user making the changes. The function computes the changelog and returns the response.
-func performObligationMapActions(username string, obligation models.Obligation, removeLicenseIds []int64,
-	insertLicenseIds []int64) models.ObligationMapResponse {
+func PerformObligationMapActions(username string, obligation models.Obligation, removeLicenseIds []int64,
+	insertLicenseIds []int64) (*models.ObligationMapResponse, error) {
 	var oldObMaps []models.ObligationMap
 	var newObMaps []models.ObligationMap
 	var removeObMaps []models.ObligationMap
 	var insertObMaps []models.ObligationMap
 
-	getObligationMapsForObligation(obligation.Id, &oldObMaps)
+	if err := db.DB.Where(models.ObligationMap{ObligationPk: obligation.Id}).Find(&oldObMaps).Error; err != nil {
+		return nil, err
+	}
 
 	for i := 0; i < len(removeLicenseIds); i++ {
 		deleteItem := models.ObligationMap{
@@ -355,7 +426,9 @@ func performObligationMapActions(username string, obligation models.Obligation,
 			RfPk:         removeLicenseIds[i],
 		}
 		// Find the primary key to make delete faster
-		db.DB.Where(&deleteItem).First(&deleteItem)
+		if err := db.DB.Where(&deleteItem).First(&deleteItem).Error; err != nil {
+			return nil, err
+		}
 		removeObMaps = append(removeObMaps, deleteItem)
 	}
 	for i := 0; i < len(insertLicenseIds); i++ {
@@ -365,44 +438,48 @@ func performObligationMapActions(username string, obligation models.Obligation,
 		})
 	}
 
-	if len(removeObMaps) > 0 {
-		// Bulk delete removeObMaps from DB
-		db.DB.Delete(&removeObMaps)
-	}
-	if len(insertObMaps) > 0 {
-		// Bulk create insertObMaps in DB
-		db.DB.Create(&insertObMaps)
+	if err := db.DB.Transaction(func(tx *gorm.DB) error {
+		if len(removeObMaps) > 0 {
+			// Bulk delete removeObMaps from DB
+			if err := tx.Delete(&removeObMaps).Error; err != nil {
+				return err
+			}
+		}
+		if len(insertObMaps) > 0 {
+			// Bulk create insertObMaps in DB
+			if err := tx.Create(&insertObMaps).Error; err != nil {
+				return err
+			}
+		}
+
+		if err := tx.Where(models.ObligationMap{ObligationPk: obligation.Id}).Find(&newObMaps).Error; err != nil {
+			return err
+		}
+
+		return createObligationMapChangelog(tx, username, oldObMaps, newObMaps, &obligation)
+
+	}); err != nil {
+		return nil, err
 	}
 
-	getObligationMapsForObligation(obligation.Id, &newObMaps)
+	obMap, err := createObligationMapUser(obligation, newObMaps)
+	if err != nil {
+		return nil, err
+	}
 
 	res := models.ObligationMapResponse{
-		Data:   []models.ObligationMapUser{createObligationMapUser(obligation, newObMaps)},
+		Data:   []models.ObligationMapUser{*obMap},
 		Status: http.StatusOK,
 		Meta: models.PaginationMeta{
 			ResourceCount: 1,
 		},
 	}
 
-	var user models.User
-	db.DB.Where(models.User{Username: username}).First(&user)
-	audit := models.Audit{
-		UserId:    user.Id,
-		TypeId:    obligation.Id,
-		Timestamp: time.Now(),
-		Type:      "obligation_map",
-	}
-
-	db.DB.Create(&audit)
-
-	change := createObligationMapChangelog(oldObMaps, newObMaps, audit)
-	db.DB.Create(&change)
-	return res
+	return &res, nil
 }
 
 // createObligationMapChangelog creates the changelog for the obligation map changes.
-func createObligationMapChangelog(oldObMaps []models.ObligationMap, newObMaps []models.ObligationMap,
-	audit models.Audit) models.ChangeLog {
+func createObligationMapChangelog(tx *gorm.DB, username string, oldObMaps, newObMaps []models.ObligationMap, obligation *models.Obligation) error {
 	var oldLicenses []string
 	var newLicenses []string
 
@@ -416,12 +493,29 @@ func createObligationMapChangelog(oldObMaps []models.ObligationMap, newObMaps []
 	oldVal := strings.Join(oldLicenses, ",")
 	newVal := strings.Join(newLicenses, ",")
 	change := models.ChangeLog{
-		AuditId:      audit.Id,
 		Field:        "RfPk",
 		OldValue:     &oldVal,
 		UpdatedValue: &newVal,
 	}
-	return change
+
+	var user models.User
+	if err := tx.Where(models.User{Username: username}).First(&user).Error; err != nil {
+		return err
+	}
+
+	audit := models.Audit{
+		UserId:     user.Id,
+		TypeId:     obligation.Id,
+		Timestamp:  time.Now(),
+		Type:       "license",
+		ChangeLogs: []models.ChangeLog{change},
+	}
+
+	if err := tx.Create(&audit).Error; err != nil {
+		return err
+	}
+
+	return nil
 }
 
 // removeFromSlice removes the item from the slice if it exists.
@@ -433,16 +527,17 @@ func removeFromSlice[E string | int64](slice []E, item E) []E {
 }
 
 // createObligationMapUser creates the response data for the obligation map endpoint.
-func createObligationMapUser(obligation models.Obligation, obMaps []models.ObligationMap) models.ObligationMapUser {
+func createObligationMapUser(obligation models.Obligation, obMaps []models.ObligationMap) (*models.ObligationMapUser, error) {
 	var shortnameList []string
 	for i := 0; i < len(obMaps); i++ {
 		var license models.LicenseDB
-		licenseQuery := db.DB.Model(&license)
-		licenseQuery.Where(models.LicenseDB{Id: obMaps[i].RfPk}).First(&license)
+		if err := db.DB.Where(models.LicenseDB{Id: obMaps[i].RfPk}).First(&license).Error; err != nil {
+			return nil, err
+		}
 		shortnameList = append(shortnameList, license.Shortname)
 	}
-	return models.ObligationMapUser{
+	return &models.ObligationMapUser{
 		Topic:      obligation.Topic,
 		Shortnames: shortnameList,
-	}
+	}, nil
 }