diff --git a/pkg/dbsql/crud.go b/pkg/dbsql/crud.go index 2c951d3..f47a5d1 100644 --- a/pkg/dbsql/crud.go +++ b/pkg/dbsql/crud.go @@ -129,6 +129,7 @@ type CrudBase[T Resource] struct { TimesDisabled bool // no management of the time columns PatchDisabled bool // allows non-pointer fields, but prevents UpdateSparse function ImmutableColumns []string + IDField string // override default ID field NameField string // If supporting name semantics QueryFactory ffapi.QueryFactory // Must be set when name is set DefaultSort func() []interface{} // optionally override the default sort - array of *ffapi.SortField or string @@ -144,6 +145,7 @@ type CrudBase[T Resource] struct { ReadTableAlias string ReadOnlyColumns []string ReadQueryModifier QueryModifier + AfterLoad func(ctx context.Context, inst T) error // perform final validation/formatting after an instance is loaded from db } func (c *CrudBase[T]) Scoped(scope sq.Eq) CRUD[T] { @@ -159,6 +161,13 @@ func (c *CrudBase[T]) TableAlias() string { return c.Table } +func (c *CrudBase[T]) GetIDField() string { + if c.IDField != "" { + return c.IDField + } + return ColumnID +} + func (c *CrudBase[T]) GetQueryFactory() ffapi.QueryFactory { return c.QueryFactory } @@ -212,7 +221,7 @@ func (c *CrudBase[T]) Validate() { ptrs := map[string]interface{}{} fieldMap := map[string]bool{ // Mandatory column checks - ColumnID: false, + c.GetIDField(): false, } if !c.TimesDisabled { fieldMap[ColumnCreated] = false @@ -260,7 +269,7 @@ func (c *CrudBase[T]) idFilter(id string) sq.Eq { filter = sq.Eq{} } if c.ReadTableAlias != "" { - filter[fmt.Sprintf("%s.id", c.ReadTableAlias)] = id + filter[fmt.Sprintf("%s.%s", c.ReadTableAlias, c.GetIDField())] = id } else { filter["id"] = id } @@ -270,7 +279,7 @@ func (c *CrudBase[T]) idFilter(id string) sq.Eq { func (c *CrudBase[T]) buildUpdateList(_ context.Context, update sq.UpdateBuilder, inst T, includeNil bool) sq.UpdateBuilder { colLoop: for _, col := range c.Columns { - for _, immutable := range append(c.ImmutableColumns, ColumnID, ColumnCreated, ColumnUpdated, c.DB.sequenceColumn) { + for _, immutable := range append(c.ImmutableColumns, c.GetIDField(), ColumnCreated, ColumnUpdated, c.DB.sequenceColumn) { if col == immutable { continue colLoop } @@ -626,6 +635,9 @@ func (c *CrudBase[T]) GetByID(ctx context.Context, id string, getOpts ...GetOpti if err != nil { return c.NilValue(), err } + if c.AfterLoad != nil { + return inst, c.AfterLoad(ctx, inst) + } return inst, nil } @@ -728,6 +740,12 @@ func (c *CrudBase[T]) getManyScoped(ctx context.Context, tableFrom string, fi *f if err != nil { return nil, nil, err } + if c.AfterLoad != nil { + err = c.AfterLoad(ctx, inst) + if err != nil { + return nil, nil, err + } + } instances = append(instances, inst) } log.L(ctx).Debugf("SQL<- GetMany(%s): %d", c.Table, len(instances)) diff --git a/pkg/dbsql/crud_test.go b/pkg/dbsql/crud_test.go index a07998f..cc76393 100644 --- a/pkg/dbsql/crud_test.go +++ b/pkg/dbsql/crud_test.go @@ -377,10 +377,19 @@ func TestCRUDWithDBEnd2End(t *testing.T) { assert.Equal(t, Created, collection.events[0]) collection.events = nil + // Install an AfterLoad handler + afterLoadCalled := false + collection.AfterLoad = func(ctx context.Context, inst *TestCRUDable) error { + afterLoadCalled = true + return nil + } + // Check we get it back c1copy, err := iCrud.GetByID(ctx, c1.ID.String()) assert.NoError(t, err) checkEqualExceptTimes(t, *c1, *c1copy) + assert.True(t, afterLoadCalled) + collection.AfterLoad = nil // Check we get it back by name c1copy, err = iCrud.GetByName(ctx, *c1.Name) @@ -412,6 +421,14 @@ func TestCRUDWithDBEnd2End(t *testing.T) { assert.NoError(t, err) checkEqualExceptTimes(t, *c1, *c1copy) + // Check AfterLoad error behavior + collection.AfterLoad = func(ctx context.Context, inst *TestCRUDable) error { + return fmt.Errorf("pop") + } + _, _, err = iCrud.GetMany(ctx, CRUDableQueryFactory.NewFilter(ctx).And()) + assert.EqualError(t, err, "pop") + collection.AfterLoad = nil + // Upsert the existing row optimized c1copy.Field1 = ptrTo("hello again - 1") created, err := iCrud.Upsert(ctx, c1copy, UpsertOptimizationExisting) @@ -1294,3 +1311,24 @@ func TestValidateNameSemanticsWithoutQueryFactory(t *testing.T) { tc.Validate() }) } + +func TestCustomIDColumn(t *testing.T) { + db, _ := NewMockProvider().UTInit() + tc := &CrudBase[*TestCRUDable]{ + DB: &db.Database, + NewInstance: func() *TestCRUDable { return &TestCRUDable{} }, + NilValue: func() *TestCRUDable { return nil }, + IDField: "f1", + Columns: []string{"f1"}, + TimesDisabled: true, + PatchDisabled: true, + GetFieldPtr: func(inst *TestCRUDable, col string) interface{} { + if col == "id" { + var t *string + return &t + } + return nil + }, + } + tc.Validate() +}