From 6b4537e234fe6f7b91215e8779b211f36f8c6d09 Mon Sep 17 00:00:00 2001 From: Geoffrey Ragot Date: Wed, 28 Feb 2024 17:48:35 +0100 Subject: [PATCH] fix(ledger): balance filter on accounts --- .../internal/storage/ledgerstore/accounts.go | 24 +++++++++++++++---- .../storage/ledgerstore/accounts_test.go | 8 +++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/components/ledger/internal/storage/ledgerstore/accounts.go b/components/ledger/internal/storage/ledgerstore/accounts.go index 579d713ddb..72c79b6b34 100644 --- a/components/ledger/internal/storage/ledgerstore/accounts.go +++ b/components/ledger/internal/storage/ledgerstore/accounts.go @@ -3,6 +3,7 @@ package ledgerstore import ( "context" "errors" + "fmt" "regexp" "github.com/formancehq/stack/libs/go-libs/bun/bunpaginate" @@ -54,6 +55,21 @@ func (store *Store) accountQueryContext(qb query.Builder, q GetAccountsQuery) (s balanceRegex := regexp.MustCompile("balance\\[(.*)\\]") return qb.Build(query.ContextFn(func(key, operator string, value any) (string, []any, error) { + convertOperatorToSQL := func() string { + switch operator { + case "$match": + return "=" + case "$lt": + return "<" + case "$gt": + return ">" + case "$lte": + return "<=" + case "$gte": + return ">=" + } + panic("unreachable") + } switch { case key == "address": // TODO: Should allow comparison operator only if segments not used @@ -83,21 +99,21 @@ func (store *Store) accountQueryContext(qb query.Builder, q GetAccountsQuery) (s case balanceRegex.Match([]byte(key)): match := balanceRegex.FindAllStringSubmatch(key, 2) - return `( + return fmt.Sprintf(`( select balance_from_volumes(post_commit_volumes) from moves where asset = ? and account_address = accounts.address and ledger = ? order by seq desc limit 1 - ) < ?`, []any{match[0][1], store.name, value}, nil + ) %s ?`, convertOperatorToSQL()), []any{match[0][1], store.name, value}, nil case key == "balance": - return `( + return fmt.Sprintf(`( select balance_from_volumes(post_commit_volumes) from moves where account_address = accounts.address and ledger = ? order by seq desc limit 1 - ) < ?`, []any{store.name, value}, nil + ) %s ?`, convertOperatorToSQL()), []any{store.name, value}, nil default: return "", nil, newErrInvalidQuery("unknown key '%s' when building query", key) } diff --git a/components/ledger/internal/storage/ledgerstore/accounts_test.go b/components/ledger/internal/storage/ledgerstore/accounts_test.go index 5c3cac9183..0db1be9296 100644 --- a/components/ledger/internal/storage/ledgerstore/accounts_test.go +++ b/components/ledger/internal/storage/ledgerstore/accounts_test.go @@ -163,6 +163,14 @@ func TestGetAccounts(t *testing.T) { )) require.NoError(t, err) require.Len(t, accounts.Data, 1) // world + + accounts, err = store.GetAccountsWithVolumes(context.Background(), NewGetAccountsQuery(NewPaginatedQueryOptions(PITFilterWithVolumes{}). + WithQueryBuilder(query.Gt("balance[USD]", 0)), + )) + require.NoError(t, err) + require.Len(t, accounts.Data, 2) // world + require.Equal(t, "account:1", accounts.Data[0].Account.Address) + require.Equal(t, "bank", accounts.Data[1].Account.Address) }) t.Run("list using filter invalid field", func(t *testing.T) { t.Parallel()