diff --git a/core/chains/evm/logpoller/disabled.go b/core/chains/evm/logpoller/disabled.go index 3a1e4ba4fe6..3d97cd5d55a 100644 --- a/core/chains/evm/logpoller/disabled.go +++ b/core/chains/evm/logpoller/disabled.go @@ -118,8 +118,8 @@ func (d disabled) LogsDataWordBetween(ctx context.Context, eventSig common.Hash, return nil, ErrDisabled } -func (d disabled) FilteredLogs(_ query.KeyFilter, _ query.LimitAndSort) ([]Log, error) { - return nil, nil +func (d disabled) FilteredLogs(_ context.Context, _ query.KeyFilter, _ query.LimitAndSort) ([]Log, error) { + return nil, ErrDisabled } func (d disabled) FindLCA(ctx context.Context) (*LogPollerBlock, error) { diff --git a/core/chains/evm/logpoller/log_poller.go b/core/chains/evm/logpoller/log_poller.go index b1d7d1da623..768b7b57b7b 100644 --- a/core/chains/evm/logpoller/log_poller.go +++ b/core/chains/evm/logpoller/log_poller.go @@ -67,7 +67,8 @@ type LogPoller interface { LogsDataWordGreaterThan(ctx context.Context, eventSig common.Hash, address common.Address, wordIndex int, wordValueMin common.Hash, confs evmtypes.Confirmations) ([]Log, error) LogsDataWordBetween(ctx context.Context, eventSig common.Hash, address common.Address, wordIndexMin, wordIndexMax int, wordValue common.Hash, confs evmtypes.Confirmations) ([]Log, error) - FilteredLogs(filter query.KeyFilter, limitAndSrt query.LimitAndSort) ([]Log, error) + // chainlink-common query filtering + FilteredLogs(ctx context.Context, filter query.KeyFilter, limitAndSort query.LimitAndSort) ([]Log, error) } type LogPollerTest interface { @@ -1522,6 +1523,6 @@ func EvmWord(i uint64) common.Hash { return common.BytesToHash(b) } -func (lp *logPoller) FilteredLogs(queryFilter query.KeyFilter, sortAndLimit query.LimitAndSort) ([]Log, error) { - return lp.orm.FilteredLogs(queryFilter, sortAndLimit) +func (lp *logPoller) FilteredLogs(ctx context.Context, queryFilter query.KeyFilter, limitAndSort query.LimitAndSort) ([]Log, error) { + return lp.orm.FilteredLogs(ctx, queryFilter, limitAndSort) } diff --git a/core/chains/evm/logpoller/mocks/log_poller.go b/core/chains/evm/logpoller/mocks/log_poller.go index e30aac56c8b..ed0ce092295 100644 --- a/core/chains/evm/logpoller/mocks/log_poller.go +++ b/core/chains/evm/logpoller/mocks/log_poller.go @@ -59,9 +59,9 @@ func (_m *LogPoller) DeleteLogsAndBlocksAfter(ctx context.Context, start int64) return r0 } -// FilteredLogs provides a mock function with given fields: filter, limitAndSrt -func (_m *LogPoller) FilteredLogs(filter query.KeyFilter, limitAndSrt query.LimitAndSort) ([]logpoller.Log, error) { - ret := _m.Called(filter, limitAndSrt) +// FilteredLogs provides a mock function with given fields: ctx, filter, limitAndSort +func (_m *LogPoller) FilteredLogs(ctx context.Context, filter query.KeyFilter, limitAndSort query.LimitAndSort) ([]logpoller.Log, error) { + ret := _m.Called(ctx, filter, limitAndSort) if len(ret) == 0 { panic("no return value specified for FilteredLogs") @@ -69,19 +69,19 @@ func (_m *LogPoller) FilteredLogs(filter query.KeyFilter, limitAndSrt query.Limi var r0 []logpoller.Log var r1 error - if rf, ok := ret.Get(0).(func(query.KeyFilter, query.LimitAndSort) ([]logpoller.Log, error)); ok { - return rf(filter, limitAndSrt) + if rf, ok := ret.Get(0).(func(context.Context, query.KeyFilter, query.LimitAndSort) ([]logpoller.Log, error)); ok { + return rf(ctx, filter, limitAndSort) } - if rf, ok := ret.Get(0).(func(query.KeyFilter, query.LimitAndSort) []logpoller.Log); ok { - r0 = rf(filter, limitAndSrt) + if rf, ok := ret.Get(0).(func(context.Context, query.KeyFilter, query.LimitAndSort) []logpoller.Log); ok { + r0 = rf(ctx, filter, limitAndSort) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]logpoller.Log) } } - if rf, ok := ret.Get(1).(func(query.KeyFilter, query.LimitAndSort) error); ok { - r1 = rf(filter, limitAndSrt) + if rf, ok := ret.Get(1).(func(context.Context, query.KeyFilter, query.LimitAndSort) error); ok { + r1 = rf(ctx, filter, limitAndSort) } else { r1 = ret.Error(1) } diff --git a/core/chains/evm/logpoller/orm.go b/core/chains/evm/logpoller/orm.go index 3cd8db74849..3f9608891b1 100644 --- a/core/chains/evm/logpoller/orm.go +++ b/core/chains/evm/logpoller/orm.go @@ -61,8 +61,9 @@ type ORM interface { SelectLogsDataWordRange(ctx context.Context, address common.Address, eventSig common.Hash, wordIndex int, wordValueMin, wordValueMax common.Hash, confs evmtypes.Confirmations) ([]Log, error) SelectLogsDataWordGreaterThan(ctx context.Context, address common.Address, eventSig common.Hash, wordIndex int, wordValueMin common.Hash, confs evmtypes.Confirmations) ([]Log, error) SelectLogsDataWordBetween(ctx context.Context, address common.Address, eventSig common.Hash, wordIndexMin int, wordIndexMax int, wordValue common.Hash, confs evmtypes.Confirmations) ([]Log, error) + // FilteredLogs accepts chainlink-common filtering DSL. - FilteredLogs(filter query.KeyFilter, sortAndLimit query.LimitAndSort) ([]Log, error) + FilteredLogs(ctx context.Context, filter query.KeyFilter, limitAndSort query.LimitAndSort) ([]Log, error) } type DSORM struct { @@ -92,10 +93,10 @@ func (o *DSORM) new(ds sqlutil.DataSource) *DSORM { return NewORM(o.chainID, ds, // InsertBlock is idempotent to support replays. func (o *DSORM) InsertBlock(ctx context.Context, blockHash common.Hash, blockNumber int64, blockTimestamp time.Time, finalizedBlock int64) error { args, err := newQueryArgs(o.chainID). - withCustomHashArg("block_hash", blockHash). - withCustomArg("block_number", blockNumber). - withCustomArg("block_timestamp", blockTimestamp). - withCustomArg("finalized_block_number", finalizedBlock). + withField("block_hash", blockHash). + withField("block_number", blockNumber). + withField("block_timestamp", blockTimestamp). + withField("finalized_block_number", finalizedBlock). toArgs() if err != nil { return err @@ -115,7 +116,7 @@ func (o *DSORM) InsertBlock(ctx context.Context, blockHash common.Hash, blockNum func (o *DSORM) InsertFilter(ctx context.Context, filter Filter) (err error) { topicArrays := []types.HashArray{filter.Topic2, filter.Topic3, filter.Topic4} args, err := newQueryArgs(o.chainID). - withCustomArg("name", filter.Name). + withField("name", filter.Name). withRetention(filter.Retention). withMaxLogsKept(filter.MaxLogsKept). withLogsPerBlock(filter.LogsPerBlock). @@ -930,8 +931,8 @@ func (o *DSORM) SelectIndexedLogsWithSigsExcluding(ctx context.Context, sigA, si withTopicIndex(topicIndex). withStartBlock(startBlock). withEndBlock(endBlock). - withCustomHashArg("sigA", sigA). - withCustomHashArg("sigB", sigB). + withField("sigA", sigA). + withField("sigB", sigB). withConfs(confs). toArgs() if err != nil { @@ -970,9 +971,28 @@ func (o *DSORM) SelectIndexedLogsWithSigsExcluding(ctx context.Context, sigA, si return logs, nil } -func (o *DSORM) FilteredLogs(_ query.KeyFilter, _ query.LimitAndSort) ([]Log, error) { - //TODO implement me - panic("implement me") +func (o *DSORM) FilteredLogs(ctx context.Context, filter query.KeyFilter, limitAndSort query.LimitAndSort) ([]Log, error) { + qs, args, err := (&pgDSLParser{}).buildQuery(o.chainID, filter.Expressions, limitAndSort) + if err != nil { + return nil, err + } + + values, err := args.toArgs() + if err != nil { + return nil, err + } + + query, sqlArgs, err := o.ds.BindNamed(qs, values) + if err != nil { + return nil, err + } + + var logs []Log + if err = o.ds.SelectContext(ctx, &logs, query, sqlArgs...); err != nil { + return nil, err + } + + return logs, nil } func nestedBlockNumberQuery(confs evmtypes.Confirmations) string { diff --git a/core/chains/evm/logpoller/orm_test.go b/core/chains/evm/logpoller/orm_test.go index 7e6ce9aada2..f6edf48df1a 100644 --- a/core/chains/evm/logpoller/orm_test.go +++ b/core/chains/evm/logpoller/orm_test.go @@ -19,6 +19,8 @@ import ( "github.com/stretchr/testify/require" "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/types/query" + "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -859,18 +861,65 @@ func TestORM_SelectLogsWithSigsByBlockRangeFilter(t *testing.T) { } require.NoError(t, o1.InsertLogs(ctx, inputLogs)) + filter := func(sigs []common.Hash, startBlock, endBlock int64) query.KeyFilter { + filters := []query.Expression{ + logpoller.NewAddressFilter(sourceAddr), + } + + if len(sigs) > 0 { + exp := make([]query.Expression, len(sigs)) + for idx, val := range sigs { + exp[idx] = logpoller.NewEventSigFilter(val) + } + + filters = append(filters, query.Expression{ + BoolExpression: query.BoolExpression{ + Expressions: exp, + BoolOperator: query.OR, + }, + }) + } + + filters = append(filters, query.Expression{ + BoolExpression: query.BoolExpression{ + Expressions: []query.Expression{ + query.Block(uint64(startBlock), primitives.Gte), + query.Block(uint64(endBlock), primitives.Lte), + }, + BoolOperator: query.AND, + }, + }) + + return query.KeyFilter{ + Expressions: filters, + } + } + + limiter := query.LimitAndSort{ + SortBy: []query.SortBy{query.NewSortBySequence(query.Asc)}, + } + + assertion := func(t *testing.T, logs []logpoller.Log, err error, startBlock, endBlock int64) { + require.NoError(t, err) + assert.Len(t, logs, 4) + for _, l := range logs { + assert.Equal(t, sourceAddr, l.Address, "wrong log address") + assert.True(t, bytes.Equal(topic.Bytes(), l.EventSig.Bytes()) || bytes.Equal(topic2.Bytes(), l.EventSig.Bytes()), "wrong log topic") + assert.True(t, l.BlockNumber >= startBlock && l.BlockNumber <= endBlock) + } + } + startBlock, endBlock := int64(10), int64(15) logs, err := o1.SelectLogsWithSigs(ctx, startBlock, endBlock, sourceAddr, []common.Hash{ topic, topic2, }) - require.NoError(t, err) - assert.Len(t, logs, 4) - for _, l := range logs { - assert.Equal(t, sourceAddr, l.Address, "wrong log address") - assert.True(t, bytes.Equal(topic.Bytes(), l.EventSig.Bytes()) || bytes.Equal(topic2.Bytes(), l.EventSig.Bytes()), "wrong log topic") - assert.True(t, l.BlockNumber >= startBlock && l.BlockNumber <= endBlock) - } + + assertion(t, logs, err, startBlock, endBlock) + + logs, err = th.ORM.FilteredLogs(ctx, filter([]common.Hash{topic, topic2}, startBlock, endBlock), limiter) + + assertion(t, logs, err, startBlock, endBlock) } func TestORM_DeleteBlocksBefore(t *testing.T) { @@ -1404,29 +1453,92 @@ func TestSelectLogsCreatedAfter(t *testing.T) { }, }, } + + filter := func(timestamp time.Time, confs evmtypes.Confirmations, topicIdx int, topicVals []common.Hash) query.KeyFilter { + var queryConfs primitives.ConfirmationLevel + + switch confs { + case evmtypes.Finalized: + queryConfs = primitives.Finalized + case evmtypes.Unconfirmed: + queryConfs = primitives.Unconfirmed + default: + fmt.Println("default") + queryConfs = primitives.ConfirmationLevel(confs) + } + + filters := []query.Expression{ + logpoller.NewAddressFilter(address), + logpoller.NewEventSigFilter(event), + } + + if len(topicVals) > 0 { + exp := make([]query.Expression, len(topicVals)) + for idx, val := range topicVals { + exp[idx] = logpoller.NewEventByTopicFilter(common.Hash{}, uint64(topicIdx), []primitives.ValueComparator{ + {Value: val.String(), Operator: primitives.Eq}, + }) + } + + filters = append(filters, query.Expression{ + BoolExpression: query.BoolExpression{ + Expressions: exp, + BoolOperator: query.OR, + }, + }) + } + + filters = append(filters, []query.Expression{ + query.Timestamp(uint64(timestamp.Unix()), primitives.Gt), + query.Confirmation(queryConfs), + }...) + + return query.KeyFilter{ + Expressions: filters, + } + } + + limiter := query.LimitAndSort{ + SortBy: []query.SortBy{ + query.NewSortBySequence(query.Asc), + }, + } + + assertion := func(t *testing.T, logs []logpoller.Log, err error, exp []expectedLog) { + require.NoError(t, err) + require.Len(t, logs, len(exp)) + + for i, log := range logs { + assert.Equal(t, exp[i].block, log.BlockNumber) + assert.Equal(t, exp[i].log, log.LogIndex) + } + } + for _, tt := range tests { - t.Run("SelectLogsCreatedAfter"+tt.name, func(t *testing.T) { + t.Run(tt.name, func(t *testing.T) { logs, err := th.ORM.SelectLogsCreatedAfter(ctx, address, event, tt.after, tt.confs) - require.NoError(t, err) - require.Len(t, logs, len(tt.expectedLogs)) - for i, log := range logs { - require.Equal(t, tt.expectedLogs[i].block, log.BlockNumber) - require.Equal(t, tt.expectedLogs[i].log, log.LogIndex) - } - }) + assertion(t, logs, err, tt.expectedLogs) - t.Run("SelectIndexedLogsCreatedAfter"+tt.name, func(t *testing.T) { - logs, err := th.ORM.SelectIndexedLogsCreatedAfter(ctx, address, event, 1, []common.Hash{event}, tt.after, tt.confs) - require.NoError(t, err) - require.Len(t, logs, len(tt.expectedLogs)) + logs, err = th.ORM.FilteredLogs(ctx, filter(tt.after, tt.confs, 0, nil), limiter) - for i, log := range logs { - require.Equal(t, tt.expectedLogs[i].block, log.BlockNumber) - require.Equal(t, tt.expectedLogs[i].log, log.LogIndex) - } + assertion(t, logs, err, tt.expectedLogs) }) } + + t.Run("SelectIndexedLogsCreatedAfter", func(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logs, err := th.ORM.SelectIndexedLogsCreatedAfter(ctx, address, event, 1, []common.Hash{event}, tt.after, tt.confs) + + assertion(t, logs, err, tt.expectedLogs) + + logs, err = th.ORM.FilteredLogs(ctx, filter(tt.after, tt.confs, 1, []common.Hash{event}), limiter) + + assertion(t, logs, err, tt.expectedLogs) + }) + } + }) } func TestNestedLogPollerBlocksQuery(t *testing.T) { @@ -1612,6 +1724,12 @@ func TestSelectLogsDataWordBetween(t *testing.T) { logpoller.NewLogPollerBlock(utils.RandomBytes32(), 10, time.Now(), 1), ) require.NoError(t, err) + limiter := query.LimitAndSort{ + SortBy: []query.SortBy{ + query.NewSortByBlock(query.Asc), + query.NewSortBySequence(query.Asc), + }, + } tests := []struct { name string @@ -1640,15 +1758,40 @@ func TestSelectLogsDataWordBetween(t *testing.T) { }, } + wordFilter := func(word uint64) query.KeyFilter { + return query.KeyFilter{ + Expressions: []query.Expression{ + logpoller.NewAddressFilter(address), + logpoller.NewEventSigFilter(eventSig), + logpoller.NewEventByWordFilter(eventSig, 0, []primitives.ValueComparator{ + {Value: logpoller.EvmWord(word).Hex(), Operator: primitives.Lte}, + }), + logpoller.NewEventByWordFilter(eventSig, 1, []primitives.ValueComparator{ + {Value: logpoller.EvmWord(word).Hex(), Operator: primitives.Gte}, + }), + query.Confirmation(primitives.Unconfirmed), + }, + } + } + + assertion := func(t *testing.T, logs []logpoller.Log, err error, expected []int64) { + require.NoError(t, err) + assert.Len(t, logs, len(expected)) + + for index := range logs { + assert.Equal(t, expected[index], logs[index].BlockNumber) + } + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - logs, err1 := th.ORM.SelectLogsDataWordBetween(ctx, address, eventSig, 0, 1, logpoller.EvmWord(tt.wordValue), evmtypes.Unconfirmed) - assert.NoError(t, err1) - assert.Len(t, logs, len(tt.expectedLogs)) + logs, err := th.ORM.SelectLogsDataWordBetween(ctx, address, eventSig, 0, 1, logpoller.EvmWord(tt.wordValue), evmtypes.Unconfirmed) - for index := range logs { - assert.Equal(t, tt.expectedLogs[index], logs[index].BlockNumber) - } + assertion(t, logs, err, tt.expectedLogs) + + logs, err = th.ORM.FilteredLogs(ctx, wordFilter(tt.wordValue), limiter) + + assertion(t, logs, err, tt.expectedLogs) }) } } diff --git a/core/chains/evm/logpoller/parser.go b/core/chains/evm/logpoller/parser.go new file mode 100644 index 00000000000..01ec737658a --- /dev/null +++ b/core/chains/evm/logpoller/parser.go @@ -0,0 +1,499 @@ +package logpoller + +import ( + "errors" + "fmt" + "math/big" + "strconv" + "strings" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/hexutil" + + "github.com/smartcontractkit/chainlink-common/pkg/types/query" + "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" + + evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" +) + +const ( + sequenceFieldAlias = "sequence" + sequenceField = "CONCAT(block_number, '-', tx_hash, '-', log_index)" + blockFieldName = "block_number" + timestampFieldName = "block_timestamp" + txHashFieldName = "tx_hash" + eventSigFieldName = "event_sig" + chainIDFieldName = "evm_chain_id" +) + +var ( + ErrUnexpectedCursorFormat = errors.New("unexpected cursor format") +) + +// pgDSLParser is a visitor that builds a postgres query and arguments from a commontypes.QueryFilter +type pgDSLParser struct { + args *queryArgs + + // transient properties expected to be set and reset with every expression + expression string + err error +} + +var _ primitives.Visitor = (*pgDSLParser)(nil) + +func (v *pgDSLParser) Comparator(p primitives.Comparator) {} + +func (v *pgDSLParser) Block(p primitives.Block) { + cmp, err := cmpOpToString(p.Operator) + if err != nil { + v.err = err + + return + } + + v.expression = fmt.Sprintf( + "%s %s :%s", + blockFieldName, + cmp, + v.args.withIndexedField(blockFieldName, p.Block), + ) +} + +func (v *pgDSLParser) Confirmations(p primitives.Confirmations) { + switch p.ConfirmationLevel { + case primitives.Finalized: + v.expression = v.nestedConfQuery(true, 0) + case primitives.Unconfirmed: + // Unconfirmed in the evm relayer is an alias to the case of 0 confirmations + // set the level to the number 0 and fallthrough to the default case + p.ConfirmationLevel = primitives.ConfirmationLevel(0) + + fallthrough + default: + // the default case passes the confirmation level as a number directly to a subquery + v.expression = v.nestedConfQuery(false, uint64(evmtypes.Confirmations(p.ConfirmationLevel))) + } +} + +func (v *pgDSLParser) Timestamp(p primitives.Timestamp) { + cmp, err := cmpOpToString(p.Operator) + if err != nil { + v.err = err + + return + } + + v.expression = fmt.Sprintf( + "%s %s :%s", + timestampFieldName, + cmp, + v.args.withIndexedField(timestampFieldName, time.Unix(int64(p.Timestamp), 0)), + ) +} + +func (v *pgDSLParser) TxHash(p primitives.TxHash) { + bts, err := hexutil.Decode(p.TxHash) + if errors.Is(err, hexutil.ErrMissingPrefix) { + bts, err = hexutil.Decode("0x" + p.TxHash) + } + + if err != nil { + v.err = err + + return + } + + txHash := common.BytesToHash(bts) + + v.expression = fmt.Sprintf( + "%s = :%s", + txHashFieldName, + v.args.withIndexedField(txHashFieldName, txHash), + ) +} + +func (v *pgDSLParser) VisitAddressFilter(p *addressFilter) { + v.expression = fmt.Sprintf( + "address = :%s", + v.args.withIndexedField("address", p.address), + ) +} + +func (v *pgDSLParser) VisitEventSigFilter(p *eventSigFilter) { + v.expression = fmt.Sprintf( + "%s = :%s", + eventSigFieldName, + v.args.withIndexedField(eventSigFieldName, p.eventSig), + ) +} + +func (v *pgDSLParser) nestedConfQuery(finalized bool, confs uint64) string { + var ( + from = "FROM evm.log_poller_blocks " + where = "WHERE evm_chain_id = :evm_chain_id " + order = "ORDER BY block_number DESC LIMIT 1" + selector string + ) + + if finalized { + selector = "SELECT finalized_block_number " + } else { + selector = fmt.Sprintf("SELECT greatest(block_number - :%s, 0) ", + v.args.withIndexedField("confs", confs), + ) + } + + var builder strings.Builder + + builder.WriteString(selector) + builder.WriteString(from) + builder.WriteString(where) + builder.WriteString(order) + + return fmt.Sprintf("%s <= (%s)", blockFieldName, builder.String()) +} + +func (v *pgDSLParser) VisitEventByWordFilter(p *eventByWordFilter) { + if len(p.ValueComparers) > 0 { + wordIdx := v.args.withIndexedField("word_index", p.WordIndex) + + comps := make([]string, len(p.ValueComparers)) + for idx, comp := range p.ValueComparers { + comps[idx], v.err = makeComp(comp, v.args, "word_value", wordIdx, "substring(data from 32*:%s+1 for 32) %s :%s") + if v.err != nil { + return + } + } + + v.expression = strings.Join(comps, " AND ") + } +} + +func (v *pgDSLParser) VisitEventTopicsByValueFilter(p *eventByTopicFilter) { + if len(p.ValueComparers) > 0 { + topicIdx := v.args.withIndexedField("topic_index", p.Topic) + + comps := make([]string, len(p.ValueComparers)) + for idx, comp := range p.ValueComparers { + comps[idx], v.err = makeComp(comp, v.args, "topic_value", topicIdx, "topics[:%s] %s :%s") + if v.err != nil { + return + } + } + + v.expression = strings.Join(comps, " AND ") + } +} + +func makeComp(comp primitives.ValueComparator, args *queryArgs, field, subfield, pattern string) (string, error) { + cmp, err := cmpOpToString(comp.Operator) + if err != nil { + return "", err + } + + return fmt.Sprintf( + pattern, + subfield, + cmp, + args.withIndexedField(field, common.HexToHash(comp.Value)), + ), nil +} + +func (v *pgDSLParser) buildQuery(chainID *big.Int, expressions []query.Expression, limiter query.LimitAndSort) (string, *queryArgs, error) { + // reset transient properties + v.args = newQueryArgs(chainID) + v.expression = "" + v.err = nil + + // build the query string + clauses := []string{"SELECT evm.logs.* FROM evm.logs"} + + where, err := v.whereClause(expressions, limiter) + if err != nil { + return "", nil, err + } + + clauses = append(clauses, where) + + order, err := v.orderClause(limiter) + if err != nil { + return "", nil, err + } + + if len(order) > 0 { + clauses = append(clauses, order) + } + + limit := v.limitClause(limiter) + if len(limit) > 0 { + clauses = append(clauses, limit) + } + + return strings.Join(clauses, " "), v.args, nil +} + +func (v *pgDSLParser) whereClause(expressions []query.Expression, limiter query.LimitAndSort) (string, error) { + segment := "WHERE evm_chain_id = :evm_chain_id" + + if len(expressions) > 0 { + exp, err := v.combineExpressions(expressions, query.AND) + if err != nil { + return "", err + } + + segment = fmt.Sprintf("%s AND %s", segment, exp) + } + + if limiter.HasCursorLimit() { + var op string + switch limiter.Limit.CursorDirection { + case query.CursorFollowing: + op = ">" + case query.CursorPrevious: + op = "<" + default: + return "", errors.New("invalid cursor direction") + } + + block, txHash, logIdx, err := valuesFromCursor(limiter.Limit.Cursor) + if err != nil { + return "", err + } + + segment = fmt.Sprintf("%s AND block_number %s= :cursor_block AND tx_hash %s= :cursor_txhash AND log_index %s :cursor_log_index", segment, op, op, op) + + v.args.withField("cursor_block_number", block). + withField("cursor_txhash", common.HexToHash(txHash)). + withField("cursor_log_index", logIdx) + } + + return segment, nil +} + +func (v *pgDSLParser) orderClause(limiter query.LimitAndSort) (string, error) { + sorting := limiter.SortBy + + if limiter.HasCursorLimit() && !limiter.HasSequenceSort() { + var dir query.SortDirection + + switch limiter.Limit.CursorDirection { + case query.CursorFollowing: + dir = query.Asc + case query.CursorPrevious: + dir = query.Desc + default: + return "", errors.New("unexpected cursor direction") + } + + sorting = append(sorting, query.NewSortBySequence(dir)) + } + + if len(sorting) == 0 { + return "", nil + } + + sort := make([]string, len(sorting)) + + for idx, sorted := range sorting { + var name string + + order, err := orderToString(sorted.GetDirection()) + if err != nil { + return "", err + } + + switch sorted.(type) { + case query.SortByBlock: + name = blockFieldName + case query.SortBySequence: + sort[idx] = fmt.Sprintf("block_number %s, tx_hash %s, log_index %s", order, order, order) + + continue + case query.SortByTimestamp: + name = timestampFieldName + default: + return "", errors.New("unexpected sort by") + } + + sort[idx] = fmt.Sprintf("%s %s", name, order) + } + + return fmt.Sprintf("ORDER BY %s", strings.Join(sort, ", ")), nil +} + +func (v *pgDSLParser) limitClause(limiter query.LimitAndSort) string { + if !limiter.HasCursorLimit() && limiter.Limit.Count == 0 { + return "" + } + + return fmt.Sprintf("LIMIT %d", limiter.Limit.Count) +} + +func (v *pgDSLParser) getLastExpression() (string, error) { + exp := v.expression + err := v.err + + v.expression = "" + v.err = nil + + return exp, err +} + +func (v *pgDSLParser) combineExpressions(expressions []query.Expression, op query.BoolOperator) (string, error) { + grouped := len(expressions) > 1 + clauses := make([]string, len(expressions)) + + for idx, exp := range expressions { + if exp.IsPrimitive() { + exp.Primitive.Accept(v) + + clause, err := v.getLastExpression() + if err != nil { + return "", err + } + + clauses[idx] = clause + } else { + clause, err := v.combineExpressions(exp.BoolExpression.Expressions, exp.BoolExpression.BoolOperator) + if err != nil { + return "", err + } + + clauses[idx] = clause + } + } + + output := strings.Join(clauses, fmt.Sprintf(" %s ", op.String())) + + if grouped { + output = fmt.Sprintf("(%s)", output) + } + + return output, nil +} + +func cmpOpToString(op primitives.ComparisonOperator) (string, error) { + switch op { + case primitives.Eq: + return "=", nil + case primitives.Neq: + return "!=", nil + case primitives.Gt: + return ">", nil + case primitives.Gte: + return ">=", nil + case primitives.Lt: + return "<", nil + case primitives.Lte: + return "<=", nil + default: + return "", errors.New("invalid comparison operator") + } +} + +func orderToString(dir query.SortDirection) (string, error) { + switch dir { + case query.Asc: + return "ASC", nil + case query.Desc: + return "DESC", nil + default: + return "", errors.New("invalid sort direction") + } +} + +func valuesFromCursor(cursor string) (int64, string, int, error) { + parts := strings.Split(cursor, "-") + if len(parts) != 3 { + return 0, "", 0, fmt.Errorf("%w: must be composed as block-txhash-logindex", ErrUnexpectedCursorFormat) + } + + block, err := strconv.ParseInt(parts[0], 10, 64) + if err != nil { + return 0, "", 0, fmt.Errorf("%w: block number not parsable as int64", ErrUnexpectedCursorFormat) + } + + logIdx, err := strconv.ParseInt(parts[2], 10, 64) + if err != nil { + return 0, "", 0, fmt.Errorf("%w: log index not parsable as int", ErrUnexpectedCursorFormat) + } + + return block, parts[1], int(logIdx), nil +} + +type addressFilter struct { + address common.Address +} + +func NewAddressFilter(address common.Address) query.Expression { + return query.Expression{ + Primitive: &addressFilter{address: address}, + } +} + +func (f *addressFilter) Accept(visitor primitives.Visitor) { + switch v := visitor.(type) { + case *pgDSLParser: + v.VisitAddressFilter(f) + } +} + +type eventSigFilter struct { + eventSig common.Hash +} + +func NewEventSigFilter(hash common.Hash) query.Expression { + return query.Expression{ + Primitive: &eventSigFilter{eventSig: hash}, + } +} + +func (f *eventSigFilter) Accept(visitor primitives.Visitor) { + switch v := visitor.(type) { + case *pgDSLParser: + v.VisitEventSigFilter(f) + } +} + +type eventByWordFilter struct { + EventSig common.Hash + WordIndex uint8 + ValueComparers []primitives.ValueComparator +} + +func NewEventByWordFilter(eventSig common.Hash, wordIndex uint8, valueComparers []primitives.ValueComparator) query.Expression { + return query.Expression{Primitive: &eventByWordFilter{ + EventSig: eventSig, + WordIndex: wordIndex, + ValueComparers: valueComparers, + }} +} + +func (f *eventByWordFilter) Accept(visitor primitives.Visitor) { + switch v := visitor.(type) { + case *pgDSLParser: + v.VisitEventByWordFilter(f) + } +} + +type eventByTopicFilter struct { + EventSig common.Hash + Topic uint64 + ValueComparers []primitives.ValueComparator +} + +func NewEventByTopicFilter(eventSig common.Hash, topicIndex uint64, valueComparers []primitives.ValueComparator) query.Expression { + return query.Expression{Primitive: &eventByTopicFilter{ + EventSig: eventSig, + Topic: topicIndex, + ValueComparers: valueComparers, + }} +} + +func (f *eventByTopicFilter) Accept(visitor primitives.Visitor) { + switch v := visitor.(type) { + case *pgDSLParser: + v.VisitEventTopicsByValueFilter(f) + } +} diff --git a/core/chains/evm/logpoller/parser_test.go b/core/chains/evm/logpoller/parser_test.go new file mode 100644 index 00000000000..f4edcf5191f --- /dev/null +++ b/core/chains/evm/logpoller/parser_test.go @@ -0,0 +1,339 @@ +package logpoller + +import ( + "math/big" + "testing" + + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smartcontractkit/chainlink-common/pkg/types/query" + "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" +) + +func assertArgs(t *testing.T, args *queryArgs, numVals int) { + values, err := args.toArgs() + + assert.Len(t, values, numVals) + assert.NoError(t, err) +} + +func TestDSLParser(t *testing.T) { + t.Parallel() + + t.Run("query with no filters no order and no limit", func(t *testing.T) { + t.Parallel() + + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{} + limiter := query.LimitAndSort{} + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + + require.NoError(t, err) + assert.Equal(t, "SELECT evm.logs.* FROM evm.logs WHERE evm_chain_id = :evm_chain_id", result) + + assertArgs(t, args, 1) + }) + + t.Run("query with cursor and no order by", func(t *testing.T) { + t.Parallel() + + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{ + NewAddressFilter(common.HexToAddress("0x42")), + NewEventSigFilter(common.HexToHash("0x21")), + } + limiter := query.NewLimitAndSort(query.CursorLimit("10-0x42-5", query.CursorFollowing, 20)) + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "AND (address = :address_0 AND event_sig = :event_sig_0) " + + "AND block_number >= :cursor_block AND tx_hash >= :cursor_txhash AND log_index > :cursor_log_index " + + "ORDER BY block_number ASC, tx_hash ASC, log_index ASC " + + "LIMIT 20" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 6) + }) + + t.Run("query with limit and no order by", func(t *testing.T) { + t.Parallel() + + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{ + NewAddressFilter(common.HexToAddress("0x42")), + NewEventSigFilter(common.HexToHash("0x21")), + } + limiter := query.NewLimitAndSort(query.CountLimit(20)) + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "AND (address = :address_0 AND event_sig = :event_sig_0) " + + "LIMIT 20" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 3) + }) + + t.Run("query with order by sequence no cursor no limit", func(t *testing.T) { + t.Parallel() + + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{} + limiter := query.NewLimitAndSort(query.Limit{}, query.NewSortBySequence(query.Desc)) + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "ORDER BY block_number DESC, tx_hash DESC, log_index DESC" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 1) + }) + + t.Run("query with multiple order by no limit", func(t *testing.T) { + t.Parallel() + + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{} + limiter := query.NewLimitAndSort(query.Limit{}, query.NewSortByBlock(query.Asc), query.NewSortByTimestamp(query.Desc)) + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "ORDER BY block_number ASC, block_timestamp DESC" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 1) + }) + + t.Run("basic query with default primitives no order by and cursor", func(t *testing.T) { + t.Parallel() + + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{ + query.Timestamp(10, primitives.Eq), + query.TxHash(common.HexToHash("0x84").String()), + query.Block(99, primitives.Neq), + query.Confirmation(primitives.Finalized), + query.Confirmation(primitives.Unconfirmed), + } + limiter := query.NewLimitAndSort(query.CursorLimit("10-0x42-20", query.CursorPrevious, 20)) + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "AND (block_timestamp = :block_timestamp_0 " + + "AND tx_hash = :tx_hash_0 " + + "AND block_number != :block_number_0 " + + "AND block_number <= " + + "(SELECT finalized_block_number FROM evm.log_poller_blocks WHERE evm_chain_id = :evm_chain_id ORDER BY block_number DESC LIMIT 1) " + + "AND block_number <= (SELECT greatest(block_number - :confs_0, 0) FROM evm.log_poller_blocks WHERE evm_chain_id = :evm_chain_id ORDER BY block_number DESC LIMIT 1)) " + + "AND block_number <= :cursor_block AND tx_hash <= :cursor_txhash AND log_index < :cursor_log_index " + + "ORDER BY block_number DESC, tx_hash DESC, log_index DESC LIMIT 20" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 8) + }) + + t.Run("query for finality", func(t *testing.T) { + t.Parallel() + + t.Run("finalized", func(t *testing.T) { + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{query.Confirmation(primitives.Finalized)} + limiter := query.LimitAndSort{} + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "AND block_number <= (SELECT finalized_block_number FROM evm.log_poller_blocks WHERE evm_chain_id = :evm_chain_id ORDER BY block_number DESC LIMIT 1)" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 1) + }) + + t.Run("unconfirmed", func(t *testing.T) { + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{query.Confirmation(primitives.Unconfirmed)} + limiter := query.LimitAndSort{} + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "AND block_number <= (SELECT greatest(block_number - :confs_0, 0) FROM evm.log_poller_blocks WHERE evm_chain_id = :evm_chain_id ORDER BY block_number DESC LIMIT 1)" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 2) + }) + }) + + t.Run("query for event by word", func(t *testing.T) { + t.Parallel() + + wordFilter := NewEventByWordFilter(common.HexToHash("0x42"), 8, []primitives.ValueComparator{ + {Value: "", Operator: primitives.Gt}, + }) + + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{wordFilter} + limiter := query.LimitAndSort{} + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "AND substring(data from 32*:word_index_0+1 for 32) > :word_value_0" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 3) + }) + + t.Run("query for event topic", func(t *testing.T) { + t.Parallel() + + topicFilter := NewEventByTopicFilter(common.HexToHash("0x42"), 8, []primitives.ValueComparator{ + {Value: "a", Operator: primitives.Gt}, + {Value: "b", Operator: primitives.Lt}, + }) + + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{topicFilter} + limiter := query.LimitAndSort{} + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "AND topics[:topic_index_0] > :topic_value_0 AND topics[:topic_index_0] < :topic_value_1" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 4) + }) + + // nested query -> a & (b || c) + t.Run("nested query", func(t *testing.T) { + t.Parallel() + + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{ + {BoolExpression: query.BoolExpression{ + Expressions: []query.Expression{ + query.Timestamp(10, primitives.Gte), + {BoolExpression: query.BoolExpression{ + Expressions: []query.Expression{ + query.TxHash(common.HexToHash("0x84").Hex()), + query.Confirmation(primitives.Unconfirmed), + }, + BoolOperator: query.OR, + }}, + }, + BoolOperator: query.AND, + }}, + } + limiter := query.LimitAndSort{} + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "AND (block_timestamp >= :block_timestamp_0 " + + "AND (tx_hash = :tx_hash_0 " + + "OR block_number <= (SELECT greatest(block_number - :confs_0, 0) FROM evm.log_poller_blocks WHERE evm_chain_id = :evm_chain_id ORDER BY block_number DESC LIMIT 1)))" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 4) + }) + + // deep nested query -> a & (b || (c & d)) + t.Run("nested query deep", func(t *testing.T) { + t.Parallel() + + wordFilter := NewEventByWordFilter(common.HexToHash("0x42"), 8, []primitives.ValueComparator{ + {Value: "a", Operator: primitives.Gt}, + {Value: "b", Operator: primitives.Lte}, + }) + + parser := &pgDSLParser{} + chainID := big.NewInt(1) + expressions := []query.Expression{ + {BoolExpression: query.BoolExpression{ + Expressions: []query.Expression{ + query.Timestamp(10, primitives.Eq), + {BoolExpression: query.BoolExpression{ + Expressions: []query.Expression{ + query.TxHash(common.HexToHash("0x84").Hex()), + {BoolExpression: query.BoolExpression{ + Expressions: []query.Expression{ + query.Confirmation(primitives.Unconfirmed), + wordFilter, + }, + BoolOperator: query.AND, + }}, + }, + BoolOperator: query.OR, + }}, + }, + BoolOperator: query.AND, + }}, + } + limiter := query.LimitAndSort{} + + result, args, err := parser.buildQuery(chainID, expressions, limiter) + expected := "SELECT evm.logs.* " + + "FROM evm.logs " + + "WHERE evm_chain_id = :evm_chain_id " + + "AND (block_timestamp = :block_timestamp_0 " + + "AND (tx_hash = :tx_hash_0 " + + "OR (block_number <= (SELECT greatest(block_number - :confs_0, 0) FROM evm.log_poller_blocks WHERE evm_chain_id = :evm_chain_id ORDER BY block_number DESC LIMIT 1) " + + "AND substring(data from 32*:word_index_0+1 for 32) > :word_value_0 " + + "AND substring(data from 32*:word_index_0+1 for 32) <= :word_value_1)))" + + require.NoError(t, err) + assert.Equal(t, expected, result) + + assertArgs(t, args, 7) + }) +} diff --git a/core/chains/evm/logpoller/query.go b/core/chains/evm/logpoller/query.go index 244552dbec8..b39a260f034 100644 --- a/core/chains/evm/logpoller/query.go +++ b/core/chains/evm/logpoller/query.go @@ -28,16 +28,18 @@ func concatBytes[T bytesProducer](byteSlice []T) [][]byte { // queryArgs is a helper for building the arguments to a postgres query created by DSORM // Besides the convenience methods, it also keeps track of arguments validation and sanitization. type queryArgs struct { - args map[string]interface{} - err []error + args map[string]any + idxLookup map[string]uint8 + err []error } func newQueryArgs(chainId *big.Int) *queryArgs { return &queryArgs{ - args: map[string]interface{}{ + args: map[string]any{ "evm_chain_id": ubig.New(chainId), }, - err: []error{}, + idxLookup: make(map[string]uint8), + err: []error{}, } } @@ -47,16 +49,62 @@ func newQueryArgsForEvent(chainId *big.Int, address common.Address, eventSig com withEventSig(eventSig) } +func (q *queryArgs) withField(fieldName string, value any) *queryArgs { + _, args := q.withIndexableField(fieldName, value, false) + + return args +} + +func (q *queryArgs) withIndexedField(fieldName string, value any) string { + field, _ := q.withIndexableField(fieldName, value, true) + + return field +} + +func (q *queryArgs) withIndexableField(fieldName string, value any, addIndex bool) (string, *queryArgs) { + if addIndex { + idx := q.nextIdx(fieldName) + idxName := fmt.Sprintf("%s_%d", fieldName, idx) + + q.idxLookup[fieldName] = uint8(idx) + fieldName = idxName + } + + switch typed := value.(type) { + case common.Hash: + q.args[fieldName] = typed.Bytes() + case []common.Hash: + q.args[fieldName] = concatBytes(typed) + case types.HashArray: + q.args[fieldName] = concatBytes(typed) + case []common.Address: + q.args[fieldName] = concatBytes(typed) + default: + q.args[fieldName] = typed + } + + return fieldName, q +} + +func (q *queryArgs) nextIdx(baseFieldName string) int { + idx, ok := q.idxLookup[baseFieldName] + if !ok { + return 0 + } + + return int(idx) + 1 +} + func (q *queryArgs) withEventSig(eventSig common.Hash) *queryArgs { - return q.withCustomHashArg("event_sig", eventSig) + return q.withField("event_sig", eventSig) } func (q *queryArgs) withEventSigArray(eventSigs []common.Hash) *queryArgs { - return q.withCustomArg("event_sig_array", concatBytes(eventSigs)) + return q.withField("event_sig_array", eventSigs) } func (q *queryArgs) withTopicArray(topicValues types.HashArray, topicNum uint64) *queryArgs { - return q.withCustomArg(fmt.Sprintf("topic%d", topicNum), concatBytes(topicValues)) + return q.withField(fmt.Sprintf("topic%d", topicNum), topicValues) } func (q *queryArgs) withTopicArrays(topic2Vals types.HashArray, topic3Vals types.HashArray, topic4Vals types.HashArray) *queryArgs { @@ -66,47 +114,47 @@ func (q *queryArgs) withTopicArrays(topic2Vals types.HashArray, topic3Vals types } func (q *queryArgs) withAddress(address common.Address) *queryArgs { - return q.withCustomArg("address", address) + return q.withField("address", address) } func (q *queryArgs) withAddressArray(addresses []common.Address) *queryArgs { - return q.withCustomArg("address_array", concatBytes(addresses)) + return q.withField("address_array", addresses) } func (q *queryArgs) withStartBlock(startBlock int64) *queryArgs { - return q.withCustomArg("start_block", startBlock) + return q.withField("start_block", startBlock) } func (q *queryArgs) withEndBlock(endBlock int64) *queryArgs { - return q.withCustomArg("end_block", endBlock) + return q.withField("end_block", endBlock) } func (q *queryArgs) withWordIndex(wordIndex int) *queryArgs { - return q.withCustomArg("word_index", wordIndex) + return q.withField("word_index", wordIndex) } func (q *queryArgs) withWordValueMin(wordValueMin common.Hash) *queryArgs { - return q.withCustomHashArg("word_value_min", wordValueMin) + return q.withField("word_value_min", wordValueMin) } func (q *queryArgs) withWordValueMax(wordValueMax common.Hash) *queryArgs { - return q.withCustomHashArg("word_value_max", wordValueMax) + return q.withField("word_value_max", wordValueMax) } func (q *queryArgs) withWordIndexMin(wordIndex int) *queryArgs { - return q.withCustomArg("word_index_min", wordIndex) + return q.withField("word_index_min", wordIndex) } func (q *queryArgs) withWordIndexMax(wordIndex int) *queryArgs { - return q.withCustomArg("word_index_max", wordIndex) + return q.withField("word_index_max", wordIndex) } func (q *queryArgs) withWordValue(wordValue common.Hash) *queryArgs { - return q.withCustomHashArg("word_value", wordValue) + return q.withField("word_value", wordValue) } func (q *queryArgs) withConfs(confs evmtypes.Confirmations) *queryArgs { - return q.withCustomArg("confs", confs) + return q.withField("confs", confs) } func (q *queryArgs) withTopicIndex(index int) *queryArgs { @@ -115,53 +163,45 @@ func (q *queryArgs) withTopicIndex(index int) *queryArgs { q.err = append(q.err, fmt.Errorf("invalid index for topic: %d", index)) } // Add 1 since postgresql arrays are 1-indexed. - return q.withCustomArg("topic_index", index+1) + return q.withField("topic_index", index+1) } func (q *queryArgs) withTopicValueMin(valueMin common.Hash) *queryArgs { - return q.withCustomHashArg("topic_value_min", valueMin) + return q.withField("topic_value_min", valueMin) } func (q *queryArgs) withTopicValueMax(valueMax common.Hash) *queryArgs { - return q.withCustomHashArg("topic_value_max", valueMax) + return q.withField("topic_value_max", valueMax) } func (q *queryArgs) withTopicValues(values []common.Hash) *queryArgs { - return q.withCustomArg("topic_values", concatBytes(values)) + return q.withField("topic_values", concatBytes(values)) } func (q *queryArgs) withBlockTimestampAfter(after time.Time) *queryArgs { - return q.withCustomArg("block_timestamp_after", after) + return q.withField("block_timestamp_after", after) } func (q *queryArgs) withTxHash(hash common.Hash) *queryArgs { - return q.withCustomHashArg("tx_hash", hash) + return q.withField("tx_hash", hash) } func (q *queryArgs) withRetention(retention time.Duration) *queryArgs { - return q.withCustomArg("retention", retention) + return q.withField("retention", retention) } func (q *queryArgs) withLogsPerBlock(logsPerBlock uint64) *queryArgs { - return q.withCustomArg("logs_per_block", logsPerBlock) + return q.withField("logs_per_block", logsPerBlock) } func (q *queryArgs) withMaxLogsKept(maxLogsKept uint64) *queryArgs { - return q.withCustomArg("max_logs_kept", maxLogsKept) + return q.withField("max_logs_kept", maxLogsKept) } -func (q *queryArgs) withCustomHashArg(name string, arg common.Hash) *queryArgs { - return q.withCustomArg(name, arg.Bytes()) -} - -func (q *queryArgs) withCustomArg(name string, arg any) *queryArgs { - q.args[name] = arg - return q -} - -func (q *queryArgs) toArgs() (map[string]interface{}, error) { +func (q *queryArgs) toArgs() (map[string]any, error) { if len(q.err) > 0 { return nil, errors.Join(q.err...) } + return q.args, nil } diff --git a/core/chains/evm/logpoller/query_test.go b/core/chains/evm/logpoller/query_test.go index 832cbbfcb00..67472ecead4 100644 --- a/core/chains/evm/logpoller/query_test.go +++ b/core/chains/evm/logpoller/query_test.go @@ -33,14 +33,14 @@ func Test_QueryArgs(t *testing.T) { }, { name: "custom argument", - queryArgs: newEmptyArgs().withCustomArg("arg", "value"), + queryArgs: newEmptyArgs().withField("arg", "value"), want: map[string]interface{}{ "arg": "value", }, }, { name: "hash converted to bytes", - queryArgs: newEmptyArgs().withCustomHashArg("hash", common.Hash{}), + queryArgs: newEmptyArgs().withField("hash", common.Hash{}), want: map[string]interface{}{ "hash": make([]byte, 32), }, diff --git a/core/services/relay/evm/chain_reader.go b/core/services/relay/evm/chain_reader.go index 85c837e55bc..4a8c3691d1a 100644 --- a/core/services/relay/evm/chain_reader.go +++ b/core/services/relay/evm/chain_reader.go @@ -12,7 +12,6 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/types/query" - "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" commonservices "github.com/smartcontractkit/chainlink-common/pkg/services" commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" @@ -216,7 +215,7 @@ func (cr *chainReader) addEvent(contractName, eventName string, a abi.ABI, chain inputInfo: inputInfo, inputModifier: inputModifier, codecTopicInfo: codecTopicInfo, - topicsInfo: make(map[string]topicInfo), + topics: make(map[string]topicDetail), eventDataWords: chainReaderDefinition.GenericDataWordNames, id: wrapItemType(contractName, eventName, false) + uuid.NewString(), } @@ -227,11 +226,12 @@ func (cr *chainReader) addEvent(contractName, eventName string, a abi.ABI, chain for topicIndex, topic := range event.Inputs { genericTopicName, ok := chainReaderDefinition.GenericTopicNames[topic.Name] if ok { - eb.topicsInfo[genericTopicName] = topicInfo{ - Argument: topic, - topicIndex: uint64(topicIndex), + eb.topics[genericTopicName] = topicDetail{ + Argument: topic, + Index: uint64(topicIndex), } } + // this way querying by key/s values comparison can find its bindings cr.contractBindings.AddReadBinding(contractName, genericTopicName, eb) } @@ -296,56 +296,6 @@ func (cr *chainReader) addDecoderDef(contractName, itemType string, outputs abi. return output.Init() } -// remapFilter, changes chain agnostic filters to match evm specific filters. -func (e *eventBinding) remapFilter(filter query.KeyFilter) (remappedFilter query.KeyFilter, err error) { - addEventSigFilter := false - for _, expression := range filter.Expressions { - remappedExpression, hasComparatorPrimitive, err := e.remapExpression(filter.Key, expression) - if err != nil { - return query.KeyFilter{}, err - } - remappedFilter.Expressions = append(remappedFilter.Expressions, remappedExpression) - // comparator primitive maps to event by topic or event by evm data word filters, which means that event sig filter is not needed - addEventSigFilter = addEventSigFilter != hasComparatorPrimitive - } - - if addEventSigFilter { - remappedFilter.Expressions = append(remappedFilter.Expressions, NewEventBySigFilter(e.address, e.hash)) - } - return remappedFilter, nil -} - -func (e *eventBinding) remapExpression(key string, expression query.Expression) (remappedExpression query.Expression, hasComparerPrimitive bool, err error) { - if !expression.IsPrimitive() { - for i := range expression.BoolExpression.Expressions { - remappedExpression, hasComparerPrimitive, err = e.remapExpression(key, expression.BoolExpression.Expressions[i]) - if err != nil { - return query.Expression{}, false, err - } - remappedExpression.BoolExpression.Expressions = append(remappedExpression.BoolExpression.Expressions, remappedExpression) - } - - if expression.BoolExpression.BoolOperator == query.AND { - return query.And(remappedExpression.BoolExpression.Expressions...), hasComparerPrimitive, nil - } - return query.Or(remappedExpression.BoolExpression.Expressions...), hasComparerPrimitive, nil - } - - // remap chain agnostic primitives to chain specific - switch primitive := expression.Primitive.(type) { - case *primitives.Confirmations: - remappedExpression, err = NewFinalityFilter(primitive) - return remappedExpression, hasComparerPrimitive, err - case *primitives.Comparator: - if val, ok := e.eventDataWords[primitive.Name]; ok { - return NewEventByWordFilter(e.hash, val, primitive.ValueComparators), true, nil - } - return NewEventByTopicFilter(e.hash, e.topicsInfo[key].topicIndex, primitive.ValueComparators), true, nil - default: - return expression, hasComparerPrimitive, nil - } -} - func setupEventInput(event abi.Event, def types.ChainReaderDefinition) ([]abi.Argument, types.CodecEntry, map[string]bool) { topicFieldDefs := map[string]bool{} for _, value := range def.EventInputFields { diff --git a/core/services/relay/evm/chain_reader_test.go b/core/services/relay/evm/chain_reader_test.go index c3cb36b93e3..72332f38a06 100644 --- a/core/services/relay/evm/chain_reader_test.go +++ b/core/services/relay/evm/chain_reader_test.go @@ -49,6 +49,7 @@ const ( func TestChainReaderGetLatestValue(t *testing.T) { t.Parallel() it := &chainReaderInterfaceTester{} + RunChainReaderGetLatestValueInterfaceTests(t, it) RunChainReaderGetLatestValueInterfaceTests(t, commontestutils.WrapChainReaderTesterForLoop(it)) @@ -109,6 +110,14 @@ func TestChainReaderGetLatestValue(t *testing.T) { }) } +func TestChainReaderQueryKey(t *testing.T) { + t.Parallel() + it := &chainReaderInterfaceTester{} + + RunQueryKeyInterfaceTests(t, it) + RunQueryKeyInterfaceTests(t, commontestutils.WrapChainReaderTesterForLoop(it)) +} + func triggerFourTopics(t *testing.T, it *chainReaderInterfaceTester, i1, i2, i3 int32) { tx, err := it.evmTest.ChainReaderTesterTransactor.TriggerWithFourTopics(it.auth, i1, i2, i3) require.NoError(t, err) diff --git a/core/services/relay/evm/dsl.go b/core/services/relay/evm/dsl.go deleted file mode 100644 index 05592feb996..00000000000 --- a/core/services/relay/evm/dsl.go +++ /dev/null @@ -1,96 +0,0 @@ -package evm - -import ( - "fmt" - - "github.com/ethereum/go-ethereum/common" - - "github.com/smartcontractkit/chainlink-common/pkg/types/query" - "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" - evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" -) - -type EventBySigFilter struct { - Address common.Address - EventSig common.Hash -} - -func NewEventBySigFilter(address common.Address, eventSig common.Hash) query.Expression { - var searchEventFilter *EventBySigFilter - searchEventFilter.Address = address - searchEventFilter.EventSig = eventSig - return query.Expression{Primitive: searchEventFilter} -} - -func (f *EventBySigFilter) Accept(visitor primitives.Visitor) { - switch v := visitor.(type) { - case *PgDSLParser: - v.VisitEventBySigFilter(f) - } -} - -type EventByTopicFilter struct { - EventSig common.Hash - Topic uint64 - ValueComparators []primitives.ValueComparator -} - -func NewEventByTopicFilter(eventSig common.Hash, topicIndex uint64, valueComparators []primitives.ValueComparator) query.Expression { - var eventByIndexFilter *EventByTopicFilter - eventByIndexFilter.EventSig = eventSig - eventByIndexFilter.Topic = topicIndex - eventByIndexFilter.ValueComparators = valueComparators - - return query.Expression{Primitive: eventByIndexFilter} -} - -func (f *EventByTopicFilter) Accept(visitor primitives.Visitor) { - switch v := visitor.(type) { - case *PgDSLParser: - v.VisitEventTopicsByValueFilter(f) - } -} - -type EventByWordFilter struct { - EventSig common.Hash - WordIndex uint8 - ValueComparators []primitives.ValueComparator -} - -func NewEventByWordFilter(eventSig common.Hash, wordIndex uint8, valueComparators []primitives.ValueComparator) query.Expression { - var eventByIndexFilter *EventByWordFilter - eventByIndexFilter.EventSig = eventSig - eventByIndexFilter.WordIndex = wordIndex - eventByIndexFilter.ValueComparators = valueComparators - return query.Expression{Primitive: eventByIndexFilter} -} - -func (f *EventByWordFilter) Accept(visitor primitives.Visitor) { - switch v := visitor.(type) { - case *PgDSLParser: - v.VisitEventByWordFilter(f) - } -} - -type FinalityFilter struct { - Confs evmtypes.Confirmations -} - -func NewFinalityFilter(filter *primitives.Confirmations) (query.Expression, error) { - // TODO chain agnostic confidence levels that map to evm finality - switch filter.ConfirmationLevel { - case primitives.Finalized: - return query.Expression{Primitive: &FinalityFilter{evmtypes.Finalized}}, nil - case primitives.Unconfirmed: - return query.Expression{Primitive: &FinalityFilter{evmtypes.Unconfirmed}}, nil - default: - return query.Expression{}, fmt.Errorf("invalid finality confirmations filter value %v", filter.ConfirmationLevel) - } -} - -func (f *FinalityFilter) Accept(visitor primitives.Visitor) { - switch v := visitor.(type) { - case *PgDSLParser: - v.VisitFinalityFilter(f) - } -} diff --git a/core/services/relay/evm/event_binding.go b/core/services/relay/evm/event_binding.go index f43576de4ff..aad991660bb 100644 --- a/core/services/relay/evm/event_binding.go +++ b/core/services/relay/evm/event_binding.go @@ -13,6 +13,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/codec" commontypes "github.com/smartcontractkit/chainlink-common/pkg/types" "github.com/smartcontractkit/chainlink-common/pkg/types/query" + "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" "github.com/smartcontractkit/chainlink/v2/core/chains/evm/logpoller" evmtypes "github.com/smartcontractkit/chainlink/v2/core/chains/evm/types" @@ -33,48 +34,22 @@ type eventBinding struct { inputInfo types.CodecEntry inputModifier codec.Modifier codecTopicInfo types.CodecEntry - // key is generic topic name - topicsInfo map[string]topicInfo + // topics maps a generic topic name (key) to topic data + topics map[string]topicDetail + // eventDataWords maps a generic name to a word index // key is a predefined generic name for evm log event data word // for eg. first evm data word(32bytes) of USDC log event is value so the key can be called value eventDataWords map[string]uint8 - // used to allow Register and Unregister to be unique in case two bindings have the same event. - // otherwise, if one unregisters, it'll unregister both with the LogPoller. - id string + id string } -type topicInfo struct { +type topicDetail struct { abi.Argument - topicIndex uint64 + Index uint64 } var _ readBinding = &eventBinding{} -func (e *eventBinding) decodeLogsIntoSequences(ctx context.Context, logs []logpoller.Log, into any) ([]commontypes.Sequence, error) { - var sequences []commontypes.Sequence - for i := range logs { - sequence := commontypes.Sequence{ - // TODO SequenceCursor, should be combination of block, eventsig, topic and also match a proper db cursor?... - Cursor: "TODO", - Head: commontypes.Head{ - Identifier: fmt.Sprint(logs[i].BlockNumber), - Hash: logs[i].BlockHash.Bytes(), - Timestamp: uint64(logs[i].BlockTimestamp.Unix()), - }, - // TODO test this - Data: reflect.New(reflect.TypeOf(into).Elem()), - } - - if err := e.decodeLog(ctx, &logs[i], sequence.Data); err != nil { - return nil, err - } - - sequences = append(sequences, sequence) - } - - return sequences, nil -} - func (e *eventBinding) SetCodec(codec commontypes.RemoteCodec) { e.codec = codec } @@ -130,16 +105,27 @@ func (e *eventBinding) GetLatestValue(ctx context.Context, params, into any) err } func (e *eventBinding) QueryKey(ctx context.Context, filter query.KeyFilter, limitAndSort query.LimitAndSort, sequenceDataType any) ([]commontypes.Sequence, error) { - remappedFilter, err := e.remapFilter(filter) + remapped, err := e.remap(filter) if err != nil { return nil, err } - logs, err := e.lp.FilteredLogs(remappedFilter, limitAndSort) + // filter should always use the address and event sig + defaultExpressions := []query.Expression{ + logpoller.NewAddressFilter(e.address), + logpoller.NewEventSigFilter(e.hash), + } + remapped.Expressions = append(defaultExpressions, remapped.Expressions...) + + logs, err := e.lp.FilteredLogs(ctx, remapped, limitAndSort) if err != nil { return nil, err } + if len(logs) == 0 { + return nil, commontypes.ErrNotFound + } + return e.decodeLogsIntoSequences(ctx, logs, sequenceDataType) } @@ -326,6 +312,91 @@ func (e *eventBinding) decodeLog(ctx context.Context, log *logpoller.Log, into a return mapstructureDecode(topicsInto, into) } +func (e *eventBinding) decodeLogsIntoSequences(ctx context.Context, logs []logpoller.Log, into any) ([]commontypes.Sequence, error) { + sequences := make([]commontypes.Sequence, len(logs)) + + for idx := range logs { + sequences[idx] = commontypes.Sequence{ + Cursor: fmt.Sprintf("%s-%s-%d", logs[idx].BlockHash, logs[idx].TxHash, logs[idx].LogIndex), + Head: commontypes.Head{ + Identifier: fmt.Sprint(logs[idx].BlockNumber), + Hash: logs[idx].BlockHash.Bytes(), + Timestamp: uint64(logs[idx].BlockTimestamp.Unix()), + }, + } + + var tpVal reflect.Value + + tpInto := reflect.TypeOf(into) + if tpInto.Kind() == reflect.Pointer { + tpVal = reflect.New(tpInto.Elem()) + } else { + tpVal = reflect.Indirect(reflect.New(tpInto)) + } + + // create a new value of the same type as 'into' for the data to be extracted to + sequences[idx].Data = tpVal.Interface() + + if err := e.decodeLog(ctx, &logs[idx], sequences[idx].Data); err != nil { + return nil, err + } + } + + return sequences, nil +} + +func (e *eventBinding) remap(filter query.KeyFilter) (query.KeyFilter, error) { + remapped := query.KeyFilter{} + + for _, expression := range filter.Expressions { + remappedExpression, err := e.remapExpression(filter.Key, expression) + if err != nil { + return query.KeyFilter{}, err + } + + remapped.Expressions = append(remapped.Expressions, remappedExpression) + } + + return remapped, nil +} + +func (e *eventBinding) remapExpression(key string, expression query.Expression) (query.Expression, error) { + if !expression.IsPrimitive() { + remappedBoolExpressions := make([]query.Expression, len(expression.BoolExpression.Expressions)) + + for i := range expression.BoolExpression.Expressions { + remapped, err := e.remapExpression(key, expression.BoolExpression.Expressions[i]) + if err != nil { + return query.Expression{}, err + } + + remappedBoolExpressions[i] = remapped + } + + if expression.BoolExpression.BoolOperator == query.AND { + return query.And(remappedBoolExpressions...), nil + } + + return query.Or(remappedBoolExpressions...), nil + } + + return e.remapPrimitive(key, expression) +} + +func (e *eventBinding) remapPrimitive(key string, expression query.Expression) (query.Expression, error) { + // remap chain agnostic primitives to chain specific + switch primitive := expression.Primitive.(type) { + case *primitives.Comparator: + if val, ok := e.eventDataWords[primitive.Name]; ok { + return logpoller.NewEventByWordFilter(e.hash, val, primitive.ValueComparators), nil + } + + return logpoller.NewEventByTopicFilter(e.hash, e.topics[key].Index, primitive.ValueComparators), nil + default: + return expression, nil + } +} + func wrapInternalErr(err error) error { if err == nil { return nil diff --git a/core/services/relay/evm/pgparser.go b/core/services/relay/evm/pgparser.go deleted file mode 100644 index 62e95c10121..00000000000 --- a/core/services/relay/evm/pgparser.go +++ /dev/null @@ -1,36 +0,0 @@ -package evm - -import ( - "math/big" - - "github.com/smartcontractkit/chainlink-common/pkg/types/query/primitives" -) - -// PgDSLParser is a visitor that builds a postgres query and arguments from a query.KeyFilter -type PgDSLParser struct { - //TODO implement psql parser -} - -var _ primitives.Visitor = (*PgDSLParser)(nil) - -func NewPgParser(evmChainID *big.Int) *PgDSLParser { - return &PgDSLParser{} -} - -func (v *PgDSLParser) Comparator(_ primitives.Comparator) {} - -func (v *PgDSLParser) Block(_ primitives.Block) {} - -func (v *PgDSLParser) Confirmations(_ primitives.Confirmations) {} - -func (v *PgDSLParser) Timestamp(_ primitives.Timestamp) {} - -func (v *PgDSLParser) TxHash(_ primitives.TxHash) {} - -func (v *PgDSLParser) VisitEventTopicsByValueFilter(_ *EventByTopicFilter) {} - -func (v *PgDSLParser) VisitEventByWordFilter(_ *EventByWordFilter) {} - -func (v *PgDSLParser) VisitEventBySigFilter(_ *EventBySigFilter) {} - -func (v *PgDSLParser) VisitFinalityFilter(_ *FinalityFilter) {}