Skip to content

Commit

Permalink
Define richer "ColumnsExt" syntax for CRUD base
Browse files Browse the repository at this point in the history
Rather than continuing to add per-column flags in a piecemeal fashion,
define Columns as a map. Each column may override its select query,
add query modifiers (such as joins), and be declared as immutable or
read-only.

Signed-off-by: Andrew Richardson <[email protected]>
  • Loading branch information
awrichar committed Sep 20, 2023
1 parent 27e89cb commit 6f71cb1
Show file tree
Hide file tree
Showing 2 changed files with 266 additions and 76 deletions.
248 changes: 174 additions & 74 deletions pkg/dbsql/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,28 +112,57 @@ type CRUD[T Resource] interface {
Scoped(scope sq.Eq) *CrudBase[T] // allows dynamic scoping to a collection
}

type QueryModifier func(sq.SelectBuilder) sq.SelectBuilder
type IDValidator func(ctx context.Context, idStr string) error

type Column[T Resource] struct {
Select string
Immutable bool // disallow update
ReadOnly bool // disallow insert and update
QueryModifier QueryModifier
GetFieldPtr func(inst T) interface{}
}

type ColumnMap[T Resource] map[string]*Column[T]

type namedColumn[T Resource] struct {
Column[T]
Name string
}

type namedColumnList[T Resource] []*namedColumn[T]

func (c *Column[T]) getFieldValue(inst T) interface{} {
// Validate() will have checked this is safe for microservices (as long as they use that at build time in their UTs)
return reflect.ValueOf(c.GetFieldPtr(inst)).Elem().Interface()
}

type CrudBase[T Resource] struct {
DB *Database
Table string
Columns []string
FilterFieldMap map[string]string
TimesDisabled bool // no management of the time columns
PatchDisabled bool // allows non-pointer fields, but prevents UpdateSparse function
ImmutableColumns []string
NameField string // If supporting name semantics
QueryFactory ffapi.QueryFactory // Must be set when name is set
IDValidator func(ctx context.Context, idStr string) error // if IDs must conform to a pattern, such as a UUID (prebuilt UUIDValidator provided for that)
DB *Database
Table string
ReadTableAlias string

// Old-style column definitions
Columns []string
ReadOnlyColumns []string // appended to Columns (for query only)
ImmutableColumns []string // must be a subset of Columns (disallows update)
ReadQueryModifier QueryModifier
GetFieldPtr func(inst T, col string) interface{}

// New-style column definitions
ColumnsExt ColumnMap[T]

FilterFieldMap map[string]string
TimesDisabled bool // no management of the time columns
PatchDisabled bool // allows non-pointer fields, but prevents UpdateSparse function
NameField string // If supporting name semantics
QueryFactory ffapi.QueryFactory // Must be set when name is set
IDValidator IDValidator // if IDs must conform to a pattern, such as a UUID (prebuilt UUIDValidator provided for that)

NilValue func() T // nil value typed to T
NewInstance func() T
ScopedFilter func() sq.Eq
EventHandler func(id string, eventType ChangeEventType)
GetFieldPtr func(inst T, col string) interface{}

// Optional extensions
ReadTableAlias string
ReadOnlyColumns []string
ReadQueryModifier func(sq.SelectBuilder) sq.SelectBuilder
}

func (c *CrudBase[T]) Scoped(scope sq.Eq) *CrudBase[T] {
Expand Down Expand Up @@ -168,23 +197,30 @@ func (c *CrudBase[T]) Validate() {
fieldMap[ColumnCreated] = false
fieldMap[ColumnUpdated] = false
}
for _, col := range c.Columns {
if ok, set := fieldMap[col]; ok && set {
panic(fmt.Sprintf("%s is a duplicated column", col))
for _, col := range c.getCols() {
if ok, set := fieldMap[col.Name]; ok && set {
panic(fmt.Sprintf("%s is a duplicated column", col.Name))
}
fieldMap[col.Name] = true

if col.ReadOnly {
continue
}
if col.QueryModifier != nil {
panic(fmt.Sprintf("%s: query modifiers are only supported on read-only columns", col.Name))
}

fieldMap[col] = true
if col == c.DB.sequenceColumn {
panic(fmt.Sprintf("cannot have column named '%s'", c.DB.sequenceColumn))
if col.Immutable {
continue
}
fieldPtr := c.GetFieldPtr(inst, col)
fieldPtr := col.GetFieldPtr(inst)
ptrVal := reflect.ValueOf(fieldPtr)
if ptrVal.Kind() != reflect.Ptr || !isNil(ptrVal.Elem().Interface()) {
if !c.PatchDisabled {
panic(fmt.Sprintf("field %s does not seem to be a pointer type - prevents null-check for PATCH semantics functioning", col))
panic(fmt.Sprintf("field %s does not seem to be a pointer type - prevents null-check for PATCH semantics functioning", col.Name))
}
}
ptrs[col] = fieldPtr
ptrs[col.Name] = fieldPtr
}
for col, set := range fieldMap {
if !set {
Expand All @@ -194,7 +230,7 @@ func (c *CrudBase[T]) Validate() {
if !isNil(c.NilValue()) {
panic("NilValue() value must be nil")
}
if !isNil(c.GetFieldPtr(inst, fftypes.NewUUID().String())) {
if c.GetFieldPtr != nil && !isNil(c.GetFieldPtr(inst, fftypes.NewUUID().String())) {
panic("GetFieldPtr() must return nil for unknown column")
}
if c.NameField != "" && c.QueryFactory == nil {
Expand All @@ -218,16 +254,13 @@ func (c *CrudBase[T]) idFilter(id string) sq.Eq {
}

func (c *CrudBase[T]) buildUpdateList(_ context.Context, update sq.UpdateBuilder, inst T, includeNil bool) sq.UpdateBuilder {
colLoop:
for _, col := range c.Columns {
for _, immutable := range append(c.ImmutableColumns, ColumnID, ColumnCreated, ColumnUpdated, c.DB.sequenceColumn) {
if col == immutable {
continue colLoop
}
for _, col := range c.getCols() {
if col.Immutable || col.ReadOnly {
continue
}
value := c.getFieldValue(inst, col)
value := col.getFieldValue(inst)
if includeNil || !isNil(value) {
update = update.Set(col, value)
update = update.Set(col.Name, value)
}
}
if !c.TimesDisabled {
Expand All @@ -252,11 +285,6 @@ func (c *CrudBase[T]) updateFromInstance(ctx context.Context, tx *TXWrapper, ins
})
}

func (c *CrudBase[T]) getFieldValue(inst T, col string) interface{} {
// Validate() will have checked this is safe for microservices (as long as they use that at build time in their UTs)
return reflect.ValueOf(c.GetFieldPtr(inst, col)).Elem().Interface()
}

func (c *CrudBase[T]) setInsertTimestamps(inst T) {
if !c.TimesDisabled {
now := fftypes.Now()
Expand All @@ -279,12 +307,19 @@ func (c *CrudBase[T]) attemptInsert(ctx context.Context, tx *TXWrapper, inst T,
}

c.setInsertTimestamps(inst)
insert := sq.Insert(c.Table).Columns(c.Columns...)
values := make([]interface{}, len(c.Columns))
for i, col := range c.Columns {
values[i] = c.getFieldValue(inst, col)

cols := c.getCols()
colNames := make([]string, 0, len(cols))
values := make([]interface{}, 0, len(cols))
for _, col := range cols {
if col.ReadOnly {
continue
}
colNames = append(colNames, col.Name)
values = append(values, col.getFieldValue(inst))
}
insert = insert.Values(values...)
insert := sq.Insert(c.Table).Columns(colNames...).Values(values...)

seq, err := c.DB.InsertTxExt(ctx, c.Table, tx, insert,
func() {
if c.EventHandler != nil {
Expand Down Expand Up @@ -359,12 +394,22 @@ func (c *CrudBase[T]) InsertMany(ctx context.Context, instances []T, allowPartia
}
defer c.DB.RollbackTx(ctx, tx, autoCommit)
if c.DB.Features().MultiRowInsert {
insert := sq.Insert(c.Table).Columns(c.Columns...)
cols := make(namedColumnList[T], 0)
colNames := make([]string, 0)
for _, col := range c.getCols() {
if col.ReadOnly {
continue
}
cols = append(cols, col)
colNames = append(colNames, col.Name)
}

insert := sq.Insert(c.Table).Columns(colNames...)
for _, inst := range instances {
c.setInsertTimestamps(inst)
values := make([]interface{}, len(c.Columns))
for i, col := range c.Columns {
values[i] = c.getFieldValue(inst, col)
values := make([]interface{}, 0, len(cols))
for _, col := range cols {
values = append(values, col.getFieldValue(inst))
}
insert = insert.Values(values...)
}
Expand Down Expand Up @@ -443,14 +488,14 @@ func (c *CrudBase[T]) Replace(ctx context.Context, inst T, hooks ...PostCompleti
return c.DB.CommitTx(ctx, tx, autoCommit)
}

func (c *CrudBase[T]) scanRow(ctx context.Context, cols []string, row *sql.Rows) (T, error) {
func (c *CrudBase[T]) scanRow(ctx context.Context, cols namedColumnList[T], row *sql.Rows) (T, error) {
inst := c.NewInstance()
var seq int64
fieldPointers := make([]interface{}, len(cols))
fieldPointers[0] = &seq // The first column is alway the sequence
for i, col := range cols {
if i != 0 {
fieldPointers[i] = c.GetFieldPtr(inst, col)
fieldPointers[i] = col.GetFieldPtr(inst)
}
}
err := row.Scan(fieldPointers...)
Expand All @@ -461,40 +506,95 @@ func (c *CrudBase[T]) scanRow(ctx context.Context, cols []string, row *sql.Rows)
return inst, nil
}

func (c *CrudBase[T]) getReadCols(f *ffapi.FilterInfo) (tableFrom string, cols, readCols []string) {
cols = append([]string{c.DB.SequenceColumn()}, c.Columns...)
if c.ReadOnlyColumns != nil {
cols = append(cols, c.ReadOnlyColumns...)
func (c *CrudBase[T]) isImmutable(col string) bool {
for _, name := range c.ImmutableColumns {
if name == col {
return true
}
}
return false
}

func (c *CrudBase[T]) getCols() (cols namedColumnList[T]) {
cols = make(namedColumnList[T], 0)
cols = append(cols, &namedColumn[T]{
Name: c.DB.SequenceColumn(),
Column: Column[T]{
ReadOnly: true,
},
})
for _, name := range c.Columns {
nameCopy := name
cols = append(cols, &namedColumn[T]{
Name: name,
Column: Column[T]{
Immutable: c.isImmutable(name),
GetFieldPtr: func(inst T) interface{} {
return c.GetFieldPtr(inst, nameCopy)
},
},
})
}
for _, name := range c.ReadOnlyColumns {
nameCopy := name
cols = append(cols, &namedColumn[T]{
Name: name,
Column: Column[T]{
ReadOnly: true,
GetFieldPtr: func(inst T) interface{} {
return c.GetFieldPtr(inst, nameCopy)
},
},
})
}
for name, col := range c.ColumnsExt {
cols = append(cols, &namedColumn[T]{
Name: name,
Column: *col,
})
}
return cols
}

func (c *CrudBase[T]) getReadCols(f *ffapi.FilterInfo) (tableFrom string, cols namedColumnList[T], readCols []string, modifiers []QueryModifier) {
cols = c.getCols()
newCols := namedColumnList[T]{cols[0] /* first column is always the sequence, and must be */}
if f != nil && len(f.RequiredFields) > 0 {
newCols := []string{cols[0] /* first column is always the sequence, and must be */}
for _, requiredFieldName := range f.RequiredFields {
requiredColName := c.FilterFieldMap[requiredFieldName]
if requiredColName == "" {
requiredColName = requiredFieldName
}
for i, col := range cols {
if i > 0 /* idx==0 handled above */ && col == requiredColName {
if i > 0 /* idx==0 handled above */ && col.Name == requiredColName {
newCols = append(newCols, col)
}
}
}
cols = newCols
}
tableFrom = c.Table
readCols = cols
if c.ReadTableAlias != "" {
tableFrom = fmt.Sprintf("%s AS %s", c.Table, c.ReadTableAlias)
readCols = make([]string, len(cols))
for i, col := range cols {
if strings.Contains(col, ".") {
readCols[i] = cols[i]
} else {
readCols[i] = fmt.Sprintf("%s.%s", c.ReadTableAlias, col)
}
}
readCols = make([]string, len(cols))
for i, col := range cols {
switch {
case col.Select != "":
readCols[i] = col.Select
case c.ReadTableAlias == "" || strings.Contains(col.Name, "."):
readCols[i] = col.Name
default:
readCols[i] = fmt.Sprintf("%s.%s", c.ReadTableAlias, col.Name)
}
if col.QueryModifier != nil {
modifiers = append(modifiers, col.QueryModifier)
}
}
return tableFrom, cols, readCols
if c.ReadQueryModifier != nil {
modifiers = append(modifiers, c.ReadQueryModifier)
}
return tableFrom, cols, readCols, modifiers
}

func (c *CrudBase[T]) GetSequenceForID(ctx context.Context, id string) (seq int64, err error) {
Expand Down Expand Up @@ -526,18 +626,17 @@ func processGetOpts(ctx context.Context, getOpts []GetOption) (failNotFound bool
}

func (c *CrudBase[T]) GetByID(ctx context.Context, id string, getOpts ...GetOption) (inst T, err error) {

failNotFound, err := processGetOpts(ctx, getOpts)
if err != nil {
return c.NilValue(), err
}

tableFrom, cols, readCols := c.getReadCols(nil)
tableFrom, cols, readCols, modifiers := c.getReadCols(nil)
query := sq.Select(readCols...).
From(tableFrom).
Where(c.idFilter(id))
if c.ReadQueryModifier != nil {
query = c.ReadQueryModifier(query)
for _, mod := range modifiers {
query = mod(query)
}

rows, _, err := c.DB.Query(ctx, c.Table, query)
Expand Down Expand Up @@ -566,12 +665,13 @@ func (c *CrudBase[T]) GetMany(ctx context.Context, filter ffapi.Filter) (instanc
if err != nil {
return nil, nil, err
}
tableFrom, cols, readCols := c.getReadCols(fi)

tableFrom, cols, readCols, modifiers := c.getReadCols(fi)
var preconditions []sq.Sqlizer
if c.ScopedFilter != nil {
preconditions = []sq.Sqlizer{c.ScopedFilter()}
}
return c.getManyScoped(ctx, tableFrom, fi, cols, readCols, preconditions)
return c.getManyScoped(ctx, tableFrom, fi, cols, readCols, preconditions, modifiers)
}

// GetFirst returns a single match (like GetByID), but using a generic filter
Expand Down Expand Up @@ -630,16 +730,16 @@ func (c *CrudBase[T]) GetByUUIDOrName(ctx context.Context, uuidOrName string, ge
return c.GetByName(ctx, uuidOrName, getOpts...)
}

func (c *CrudBase[T]) getManyScoped(ctx context.Context, tableFrom string, fi *ffapi.FilterInfo, cols, readCols []string, preconditions []sq.Sqlizer) (instances []T, fr *ffapi.FilterResult, err error) {
func (c *CrudBase[T]) getManyScoped(ctx context.Context, tableFrom string, fi *ffapi.FilterInfo, cols namedColumnList[T], readCols []string, preconditions []sq.Sqlizer, modifiers []QueryModifier) (instances []T, fr *ffapi.FilterResult, err error) {
query, fop, fi, err := c.DB.filterSelectFinalized(ctx, c.ReadTableAlias, sq.Select(readCols...).From(tableFrom), fi, c.FilterFieldMap,
[]interface{}{
&ffapi.SortField{Field: c.DB.sequenceColumn, Descending: true},
}, preconditions...)
if err != nil {
return nil, nil, err
}
if c.ReadQueryModifier != nil {
query = c.ReadQueryModifier(query)
for _, mod := range modifiers {
query = mod(query)
}

rows, tx, err := c.DB.Query(ctx, c.Table, query)
Expand Down
Loading

0 comments on commit 6f71cb1

Please sign in to comment.