Skip to content

Commit

Permalink
Merge pull request #137 from kaleido-io/after-load
Browse files Browse the repository at this point in the history
Add custom ID field and AfterLoad hook
  • Loading branch information
awrichar authored May 9, 2024
2 parents d2ba7f7 + d9ac6df commit dab0027
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 3 deletions.
24 changes: 21 additions & 3 deletions pkg/dbsql/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ type CrudBase[T Resource] struct {
TimesDisabled bool // no management of the time columns
PatchDisabled bool // allows non-pointer fields, but prevents UpdateSparse function
ImmutableColumns []string
IDField string // override default ID field
NameField string // If supporting name semantics
QueryFactory ffapi.QueryFactory // Must be set when name is set
DefaultSort func() []interface{} // optionally override the default sort - array of *ffapi.SortField or string
Expand All @@ -144,6 +145,7 @@ type CrudBase[T Resource] struct {
ReadTableAlias string
ReadOnlyColumns []string
ReadQueryModifier QueryModifier
AfterLoad func(ctx context.Context, inst T) error // perform final validation/formatting after an instance is loaded from db
}

func (c *CrudBase[T]) Scoped(scope sq.Eq) CRUD[T] {
Expand All @@ -159,6 +161,13 @@ func (c *CrudBase[T]) TableAlias() string {
return c.Table
}

func (c *CrudBase[T]) GetIDField() string {
if c.IDField != "" {
return c.IDField
}
return ColumnID
}

func (c *CrudBase[T]) GetQueryFactory() ffapi.QueryFactory {
return c.QueryFactory
}
Expand Down Expand Up @@ -212,7 +221,7 @@ func (c *CrudBase[T]) Validate() {
ptrs := map[string]interface{}{}
fieldMap := map[string]bool{
// Mandatory column checks
ColumnID: false,
c.GetIDField(): false,
}
if !c.TimesDisabled {
fieldMap[ColumnCreated] = false
Expand Down Expand Up @@ -260,7 +269,7 @@ func (c *CrudBase[T]) idFilter(id string) sq.Eq {
filter = sq.Eq{}
}
if c.ReadTableAlias != "" {
filter[fmt.Sprintf("%s.id", c.ReadTableAlias)] = id
filter[fmt.Sprintf("%s.%s", c.ReadTableAlias, c.GetIDField())] = id
} else {
filter["id"] = id
}
Expand All @@ -270,7 +279,7 @@ func (c *CrudBase[T]) idFilter(id string) sq.Eq {
func (c *CrudBase[T]) buildUpdateList(_ context.Context, update sq.UpdateBuilder, inst T, includeNil bool) sq.UpdateBuilder {
colLoop:
for _, col := range c.Columns {
for _, immutable := range append(c.ImmutableColumns, ColumnID, ColumnCreated, ColumnUpdated, c.DB.sequenceColumn) {
for _, immutable := range append(c.ImmutableColumns, c.GetIDField(), ColumnCreated, ColumnUpdated, c.DB.sequenceColumn) {
if col == immutable {
continue colLoop
}
Expand Down Expand Up @@ -626,6 +635,9 @@ func (c *CrudBase[T]) GetByID(ctx context.Context, id string, getOpts ...GetOpti
if err != nil {
return c.NilValue(), err
}
if c.AfterLoad != nil {
return inst, c.AfterLoad(ctx, inst)
}
return inst, nil
}

Expand Down Expand Up @@ -728,6 +740,12 @@ func (c *CrudBase[T]) getManyScoped(ctx context.Context, tableFrom string, fi *f
if err != nil {
return nil, nil, err
}
if c.AfterLoad != nil {
err = c.AfterLoad(ctx, inst)
if err != nil {
return nil, nil, err
}
}
instances = append(instances, inst)
}
log.L(ctx).Debugf("SQL<- GetMany(%s): %d", c.Table, len(instances))
Expand Down
38 changes: 38 additions & 0 deletions pkg/dbsql/crud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,19 @@ func TestCRUDWithDBEnd2End(t *testing.T) {
assert.Equal(t, Created, collection.events[0])
collection.events = nil

// Install an AfterLoad handler
afterLoadCalled := false
collection.AfterLoad = func(ctx context.Context, inst *TestCRUDable) error {
afterLoadCalled = true
return nil
}

// Check we get it back
c1copy, err := iCrud.GetByID(ctx, c1.ID.String())
assert.NoError(t, err)
checkEqualExceptTimes(t, *c1, *c1copy)
assert.True(t, afterLoadCalled)
collection.AfterLoad = nil

// Check we get it back by name
c1copy, err = iCrud.GetByName(ctx, *c1.Name)
Expand Down Expand Up @@ -412,6 +421,14 @@ func TestCRUDWithDBEnd2End(t *testing.T) {
assert.NoError(t, err)
checkEqualExceptTimes(t, *c1, *c1copy)

// Check AfterLoad error behavior
collection.AfterLoad = func(ctx context.Context, inst *TestCRUDable) error {
return fmt.Errorf("pop")
}
_, _, err = iCrud.GetMany(ctx, CRUDableQueryFactory.NewFilter(ctx).And())
assert.EqualError(t, err, "pop")
collection.AfterLoad = nil

// Upsert the existing row optimized
c1copy.Field1 = ptrTo("hello again - 1")
created, err := iCrud.Upsert(ctx, c1copy, UpsertOptimizationExisting)
Expand Down Expand Up @@ -1294,3 +1311,24 @@ func TestValidateNameSemanticsWithoutQueryFactory(t *testing.T) {
tc.Validate()
})
}

func TestCustomIDColumn(t *testing.T) {
db, _ := NewMockProvider().UTInit()
tc := &CrudBase[*TestCRUDable]{
DB: &db.Database,
NewInstance: func() *TestCRUDable { return &TestCRUDable{} },
NilValue: func() *TestCRUDable { return nil },
IDField: "f1",
Columns: []string{"f1"},
TimesDisabled: true,
PatchDisabled: true,
GetFieldPtr: func(inst *TestCRUDable, col string) interface{} {
if col == "id" {
var t *string
return &t
}
return nil
},
}
tc.Validate()
}

0 comments on commit dab0027

Please sign in to comment.