Skip to content

Commit

Permalink
Merge pull request #116 from hyperledger/rq-err
Browse files Browse the repository at this point in the history
Add error return to query modifier
  • Loading branch information
peterbroadhurst authored Jan 2, 2024
2 parents fa739e9 + 8809383 commit a81b68a
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 19 deletions.
4 changes: 2 additions & 2 deletions mocks/crudmocks/crud.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 11 additions & 7 deletions pkg/dbsql/crud.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,14 @@ func (c *CrudBase[T]) NewUpdateBuilder(ctx context.Context) ffapi.UpdateBuilder
func (c *CrudBase[T]) ModifyQuery(newModifier QueryModifier) CRUDQuery[T] {
cModified := *c
originalModifier := cModified.ReadQueryModifier
cModified.ReadQueryModifier = func(sb sq.SelectBuilder) sq.SelectBuilder {
cModified.ReadQueryModifier = func(sb sq.SelectBuilder) (_ sq.SelectBuilder, err error) {
if originalModifier != nil {
sb = originalModifier(sb)
sb, err = originalModifier(sb)
}
if newModifier != nil {
sb = newModifier(sb)
if err == nil && newModifier != nil {
sb, err = newModifier(sb)
}
return sb
return sb, err
}
return &cModified
}
Expand Down Expand Up @@ -595,7 +595,9 @@ func (c *CrudBase[T]) GetByID(ctx context.Context, id string, getOpts ...GetOpti
From(tableFrom).
Where(c.idFilter(id))
if c.ReadQueryModifier != nil {
query = c.ReadQueryModifier(query)
if query, err = c.ReadQueryModifier(query); err != nil {
return c.NilValue(), err
}
}

rows, _, err := c.DB.Query(ctx, c.Table, query)
Expand Down Expand Up @@ -701,7 +703,9 @@ func (c *CrudBase[T]) getManyScoped(ctx context.Context, tableFrom string, fi *f
return nil, nil, err
}
if c.ReadQueryModifier != nil {
query = c.ReadQueryModifier(query)
if query, err = c.ReadQueryModifier(query); err != nil {
return nil, nil, err
}
}

rows, tx, err := c.DB.Query(ctx, c.Table, query)
Expand Down
34 changes: 28 additions & 6 deletions pkg/dbsql/crud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ func newLinkableCollection(db *Database, ns string) *CrudBase[*TestLinkable] {
"description": "desc",
"crud": "crud_id",
},
ReadQueryModifier: func(query sq.SelectBuilder) sq.SelectBuilder {
return query.LeftJoin("crudables AS c ON c.id = l.crud_id")
ReadQueryModifier: func(query sq.SelectBuilder) (sq.SelectBuilder, error) {
return query.LeftJoin("crudables AS c ON c.id = l.crud_id"), nil
},
DefaultSort: func() []interface{} {
// Return an empty list
Expand Down Expand Up @@ -399,11 +399,11 @@ func TestCRUDWithDBEnd2End(t *testing.T) {
checkEqualExceptTimes(t, *c1, *c1copy)

// Check we get it back with custom modifiers
collection.ReadQueryModifier = func(sb sq.SelectBuilder) sq.SelectBuilder {
return sb.Where(sq.Eq{"ns": "ns1"})
collection.ReadQueryModifier = func(sb sq.SelectBuilder) (sq.SelectBuilder, error) {
return sb.Where(sq.Eq{"ns": "ns1"}), nil
}
c1copy, err = iCrud.ModifyQuery(func(sb sq.SelectBuilder) sq.SelectBuilder {
return sb.Where(sq.Eq{"field1": "hello1"})
c1copy, err = iCrud.ModifyQuery(func(sb sq.SelectBuilder) (sq.SelectBuilder, error) {
return sb.Where(sq.Eq{"field1": "hello1"}), nil
}).GetByName(ctx, *c1.Name)
assert.NoError(t, err)
checkEqualExceptTimes(t, *c1, *c1copy)
Expand Down Expand Up @@ -895,6 +895,17 @@ func TestGetByIDScanFail(t *testing.T) {
assert.NoError(t, mock.ExpectationsWereMet())
}

func TestGetByIDReqQueryModifierFail(t *testing.T) {
db, mock := NewMockProvider().UTInit()
tc := newCRUDCollection(&db.Database, "ns1")
tc.ReadQueryModifier = func(sb sq.SelectBuilder) (sq.SelectBuilder, error) {
return sb, fmt.Errorf("pop")
}
_, err := tc.GetByID(context.Background(), fftypes.NewUUID().String())
assert.Regexp(t, "pop", err)
assert.NoError(t, mock.ExpectationsWereMet())
}

func TestGetByNameNoNameSemantics(t *testing.T) {
db, _ := NewMockProvider().UTInit()
tc := newLinkableCollection(&db.Database, "ns1")
Expand Down Expand Up @@ -958,6 +969,17 @@ func TestGetManySelectFail(t *testing.T) {
assert.NoError(t, mock.ExpectationsWereMet())
}

func TestGetManyReadModifierFail(t *testing.T) {
db, mock := NewMockProvider().UTInit()
tc := newCRUDCollection(&db.Database, "ns1")
tc.ReadQueryModifier = func(sb sq.SelectBuilder) (sq.SelectBuilder, error) {
return sb, fmt.Errorf("pop")
}
_, _, err := tc.GetMany(context.Background(), CRUDableQueryFactory.NewFilter(context.Background()).And())
assert.Regexp(t, "pop", err)
assert.NoError(t, mock.ExpectationsWereMet())
}

func TestGetByManyScanFail(t *testing.T) {
db, mock := NewMockProvider().UTInit()
tc := newCRUDCollection(&db.Database, "ns1")
Expand Down
6 changes: 4 additions & 2 deletions pkg/dbsql/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type Database struct {
sequenceColumn string
}

type QueryModifier = func(sq.SelectBuilder) sq.SelectBuilder
type QueryModifier = func(sq.SelectBuilder) (sq.SelectBuilder, error)

// PreCommitAccumulator is a structure that can accumulate state during
// the transaction, then has a function that is called just before commit.
Expand Down Expand Up @@ -241,7 +241,9 @@ func (s *Database) CountQuery(ctx context.Context, table string, tx *TXWrapper,
}
q := sq.Select(fmt.Sprintf("COUNT(%s)", countExpr)).From(table).Where(fop)
if qm != nil {
q = qm(q)
if q, err = qm(q); err != nil {
return -1, err
}
}
sqlQuery, args, err := q.PlaceholderFormat(s.features.PlaceholderFormat).ToSql()
if err != nil {
Expand Down
14 changes: 12 additions & 2 deletions pkg/dbsql/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,14 +545,24 @@ func TestCountQueryWithExpr(t *testing.T) {
func TestCountQueryWithModifier(t *testing.T) {
s, mdb := NewMockProvider().UTInit()
mdb.ExpectQuery("^SELECT COUNT\\(\\*\\)").WillReturnRows(sqlmock.NewRows([]string{"col1"}).AddRow(10))
qm := func(sb sq.SelectBuilder) sq.SelectBuilder {
return sb.Where(sq.Eq{"col1": "val1"})
qm := func(sb sq.SelectBuilder) (sq.SelectBuilder, error) {
return sb.Where(sq.Eq{"col1": "val1"}), nil
}
_, err := s.CountQuery(context.Background(), "table1", nil, sq.Eq{"col1": "val1"}, qm, "")
assert.NoError(t, err)
assert.NoError(t, mdb.ExpectationsWereMet())
}

func TestCountQueryWithModifierErr(t *testing.T) {
s, mdb := NewMockProvider().UTInit()
mdb.ExpectQuery("^SELECT COUNT\\(\\*\\)").WillReturnRows(sqlmock.NewRows([]string{"col1"}).AddRow(10))
qm := func(sb sq.SelectBuilder) (sq.SelectBuilder, error) {
return sb, fmt.Errorf("pop")
}
_, err := s.CountQuery(context.Background(), "table1", nil, sq.Eq{"col1": "val1"}, qm, "")
assert.Regexp(t, "pop", err)
}

func TestQueryResSwallowError(t *testing.T) {
s, _ := NewMockProvider().UTInit()
res := s.QueryRes(context.Background(), "table1", nil, sq.Insert("wrong"), nil, &ffapi.FilterInfo{
Expand Down

0 comments on commit a81b68a

Please sign in to comment.