From 9fca425f8c5e9428c440af99faeb67736f384543 Mon Sep 17 00:00:00 2001 From: Sam Lucidi Date: Thu, 18 Jul 2024 12:39:20 -0400 Subject: [PATCH] Omit associations during Create operations When inserting a new record, GORM will also attempt to insert records into tables refered to by many-to-many relationships on the inserted record. This commit attempts to ensure that associations are omitted when inserting records, and then the associations are added to the join tables separately. Signed-off-by: Sam Lucidi --- api/application.go | 38 ++++++++++++++++++++++++++++++-------- api/archetype.go | 35 +++++++++++++++++++++++++++++++++-- api/group.go | 12 +++++++++++- api/identity.go | 2 +- api/migrationwave.go | 17 ++++++++++++++++- api/stakeholder.go | 22 +++++++++++++++++++++- 6 files changed, 112 insertions(+), 14 deletions(-) diff --git a/api/application.go b/api/application.go index 7169d1c2e..f0dea0a07 100644 --- a/api/application.go +++ b/api/application.go @@ -58,22 +58,22 @@ func (h ApplicationHandler) AddRoutes(e *gin.Engine) { routeGroup.DELETE(ApplicationRoot, h.Delete) // Tags routeGroup = e.Group("/") - routeGroup.Use(Required("applications")) + routeGroup.Use(Required("applications"), Transaction) routeGroup.GET(ApplicationTagsRoot, h.TagList) routeGroup.GET(ApplicationTagsRoot+"/", h.TagList) routeGroup.POST(ApplicationTagsRoot, h.TagAdd) routeGroup.DELETE(ApplicationTagRoot, h.TagDelete) - routeGroup.PUT(ApplicationTagsRoot, h.TagReplace, Transaction) + routeGroup.PUT(ApplicationTagsRoot, h.TagReplace) // Facts routeGroup = e.Group("/") - routeGroup.Use(Required("applications.facts")) + routeGroup.Use(Required("applications.facts"), Transaction) routeGroup.GET(ApplicationFactsRoot, h.FactGet) routeGroup.GET(ApplicationFactsRoot+"/", h.FactGet) routeGroup.POST(ApplicationFactsRoot, h.FactCreate) routeGroup.GET(ApplicationFactRoot, h.FactGet) routeGroup.PUT(ApplicationFactRoot, h.FactPut) routeGroup.DELETE(ApplicationFactRoot, h.FactDelete) - routeGroup.PUT(ApplicationFactsRoot, h.FactPut, Transaction) + routeGroup.PUT(ApplicationFactsRoot, h.FactPut) // Bucket routeGroup = e.Group("/") routeGroup.Use(Required("applications.bucket")) @@ -84,11 +84,11 @@ func (h ApplicationHandler) AddRoutes(e *gin.Engine) { routeGroup.DELETE(AppBucketContentRoot, h.BucketDelete) // Stakeholders routeGroup = e.Group("/") - routeGroup.Use(Required("applications.stakeholders")) + routeGroup.Use(Required("applications.stakeholders"), Transaction) routeGroup.PUT(AppStakeholdersRoot, h.StakeholdersUpdate) // Assessments routeGroup = e.Group("/") - routeGroup.Use(Required("applications.assessments")) + routeGroup.Use(Required("applications.assessments"), Transaction) routeGroup.GET(AppAssessmentsRoot, h.AssessmentList) routeGroup.POST(AppAssessmentsRoot, h.AssessmentCreate) } @@ -210,11 +210,23 @@ func (h ApplicationHandler) Create(ctx *gin.Context) { } m := r.Model() m.CreateUser = h.BaseHandler.CurrentUser(ctx) - result := h.DB(ctx).Omit("Tags").Create(m) + result := h.DB(ctx).Omit(clause.Associations).Create(m) if result.Error != nil { _ = ctx.Error(result.Error) return } + db := h.DB(ctx).Model(m) + err = db.Association("Identities").Replace(m.Identities) + if err != nil { + _ = ctx.Error(err) + return + } + db = h.DB(ctx).Model(m) + err = db.Association("Contributors").Replace(m.Contributors) + if err != nil { + _ = ctx.Error(err) + return + } tags := []model.ApplicationTag{} if len(r.Tags) > 0 { @@ -1078,11 +1090,21 @@ func (h ApplicationHandler) AssessmentCreate(ctx *gin.Context) { assessment.PrepareForApplication(resolver, application, m) newAssessment = true } - result = h.DB(ctx).Create(m) + result = h.DB(ctx).Omit(clause.Associations).Create(m) if result.Error != nil { _ = ctx.Error(result.Error) return } + err = h.DB(ctx).Model(m).Association("Stakeholders").Replace("Stakeholders", m.Stakeholders) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("StakeholderGroups").Replace("StakeholderGroups", m.StakeholderGroups) + if err != nil { + _ = ctx.Error(err) + return + } if newAssessment { metrics.AssessmentsInitiated.Inc() } diff --git a/api/archetype.go b/api/archetype.go index d62d88a13..bcadf05cd 100644 --- a/api/archetype.go +++ b/api/archetype.go @@ -136,12 +136,33 @@ func (h ArchetypeHandler) Create(ctx *gin.Context) { } m := r.Model() m.CreateUser = h.CurrentUser(ctx) - result := h.DB(ctx).Create(m) + result := h.DB(ctx).Omit(clause.Associations).Create(m) if result.Error != nil { _ = ctx.Error(result.Error) return } + err = h.DB(ctx).Model(m).Association("Stakeholders").Replace("Stakeholders", m.Stakeholders) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("StakeholderGroups").Replace("StakeholderGroups", m.StakeholderGroups) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("CriteriaTags").Replace("CriteriaTags", m.CriteriaTags) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("Tags").Replace("Tags", m.Tags) + if err != nil { + _ = ctx.Error(err) + return + } + archetypes := []model.Archetype{} db := h.preLoad(h.DB(ctx), "Tags", "CriteriaTags") result = db.Find(&archetypes) @@ -319,11 +340,21 @@ func (h ArchetypeHandler) AssessmentCreate(ctx *gin.Context) { assessment.PrepareForArchetype(resolver, archetype, m) newAssessment = true } - result = h.DB(ctx).Create(m) + result = h.DB(ctx).Omit(clause.Associations).Create(m) if result.Error != nil { _ = ctx.Error(result.Error) return } + err = h.DB(ctx).Model(m).Association("Stakeholders").Replace("Stakeholders", m.Stakeholders) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("StakeholderGroups").Replace("StakeholderGroups", m.StakeholderGroups) + if err != nil { + _ = ctx.Error(err) + return + } if newAssessment { metrics.AssessmentsInitiated.Inc() } diff --git a/api/group.go b/api/group.go index 4e92337b3..dee59ff97 100644 --- a/api/group.go +++ b/api/group.go @@ -97,11 +97,21 @@ func (h StakeholderGroupHandler) Create(ctx *gin.Context) { } m := r.Model() m.CreateUser = h.BaseHandler.CurrentUser(ctx) - result := h.DB(ctx).Create(m) + result := h.DB(ctx).Omit(clause.Associations).Create(m) if result.Error != nil { _ = ctx.Error(result.Error) return } + err = h.DB(ctx).Model(m).Association("Stakeholders").Replace(m.Stakeholders) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("MigrationWaves").Replace(m.MigrationWaves) + if err != nil { + _ = ctx.Error(err) + return + } r.With(m) h.Respond(ctx, http.StatusCreated, r) diff --git a/api/identity.go b/api/identity.go index dbaf606dd..8c631c6f7 100644 --- a/api/identity.go +++ b/api/identity.go @@ -34,7 +34,7 @@ func (h IdentityHandler) AddRoutes(e *gin.Engine) { routeGroup.GET(IdentitiesRoot+"/", h.setDecrypted, h.List) routeGroup.POST(IdentitiesRoot, h.Create) routeGroup.GET(IdentityRoot, h.setDecrypted, h.Get) - routeGroup.PUT(IdentityRoot, h.Update, Transaction) + routeGroup.PUT(IdentityRoot, Transaction, h.Update) routeGroup.DELETE(IdentityRoot, h.Delete) } diff --git a/api/migrationwave.go b/api/migrationwave.go index 8651a369e..128bb485f 100644 --- a/api/migrationwave.go +++ b/api/migrationwave.go @@ -98,11 +98,26 @@ func (h MigrationWaveHandler) Create(ctx *gin.Context) { } m := r.Model() m.CreateUser = h.CurrentUser(ctx) - result := h.DB(ctx).Create(m) + result := h.DB(ctx).Omit(clause.Associations).Create(m) if result.Error != nil { _ = ctx.Error(result.Error) return } + err = h.DB(ctx).Model(m).Association("Applications").Replace("Applications", m.Applications) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("Stakeholders").Replace("Stakeholders", m.Stakeholders) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("StakeholderGroups").Replace("StakeholderGroups", m.StakeholderGroups) + if err != nil { + _ = ctx.Error(err) + return + } r.With(m) h.Respond(ctx, http.StatusCreated, r) diff --git a/api/stakeholder.go b/api/stakeholder.go index 4cb659310..1a551cf4d 100644 --- a/api/stakeholder.go +++ b/api/stakeholder.go @@ -97,11 +97,31 @@ func (h StakeholderHandler) Create(ctx *gin.Context) { } m := r.Model() m.CreateUser = h.BaseHandler.CurrentUser(ctx) - result := h.DB(ctx).Create(m) + result := h.DB(ctx).Omit(clause.Associations).Create(m) if result.Error != nil { _ = ctx.Error(result.Error) return } + err = h.DB(ctx).Model(m).Association("Groups").Replace(m.Groups) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("Owns").Replace(m.Owns) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("Contributes").Replace(m.Contributes) + if err != nil { + _ = ctx.Error(err) + return + } + err = h.DB(ctx).Model(m).Association("MigrationWaves").Replace(m.MigrationWaves) + if err != nil { + _ = ctx.Error(err) + return + } r.With(m) h.Respond(ctx, http.StatusCreated, r)