Skip to content

Commit

Permalink
Fix event definitions and filters handling in CR
Browse files Browse the repository at this point in the history
  • Loading branch information
ilija42 committed Jun 4, 2024
1 parent a1cdcf9 commit 474097a
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 84 deletions.
18 changes: 9 additions & 9 deletions core/chains/evm/logpoller/log_poller.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,15 @@ func NewLogPoller(orm ORM, ec Client, lggr logger.Logger, opts Opts) *logPoller
}

type Filter struct {
Name string `json:"name"` // see FilterName(id, args) below
Addresses evmtypes.AddressArray `json:"addresses"`
EventSigs evmtypes.HashArray `json:"eventSigs"` // list of possible values for eventsig (aka topic1)
Topic2 evmtypes.HashArray `json:"topic2"` // list of possible values for topic2
Topic3 evmtypes.HashArray `json:"topic3"` // list of possible values for topic3
Topic4 evmtypes.HashArray `json:"topic4"` // list of possible values for topic4
Retention time.Duration `json:"retention"` // maximum amount of time to retain logs
MaxLogsKept uint64 `json:"maxLogsKept"` // maximum number of logs to retain ( 0 = unlimited )
LogsPerBlock uint64 `json:"logsPerBlock"` // rate limit ( maximum # of logs per block, 0 = unlimited )
Name string // see FilterName(id, args) below
Addresses evmtypes.AddressArray
EventSigs evmtypes.HashArray // list of possible values for eventsig (aka topic1)
Topic2 evmtypes.HashArray // list of possible values for topic2
Topic3 evmtypes.HashArray // list of possible values for topic3
Topic4 evmtypes.HashArray // list of possible values for topic4
Retention time.Duration // maximum amount of time to retain logs
MaxLogsKept uint64 // maximum number of logs to retain ( 0 = unlimited )
LogsPerBlock uint64 // rate limit ( maximum # of logs per block, 0 = unlimited )
}

// FilterName is a suggested convenience function for clients to construct unique filter names
Expand Down
2 changes: 2 additions & 0 deletions core/internal/features/ocr2/features_ocr2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,8 @@ fromBlock = %d
if test.chainReaderAndCodec {
chainReaderSpec = `
[relayConfig.chainReader.contracts.median]
contractPollingFilter.genericEventNames = ["LatestRoundRequested"]
contractABI = '''
[
{
Expand Down
18 changes: 9 additions & 9 deletions core/services/relay/evm/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ type FilterRegisterer struct {
}

type contractBindings struct {
// FilterRegisterer is used to manage polling filter registration.
// FilterRegisterer is used to manage polling filter registration for the contact wide event filter.
FilterRegisterer
// key is read name
bindings map[string]readBinding
readBindings map[string]readBinding
}

func (b bindings) GetReadBinding(contractName, readName string) (readBinding, error) {
Expand All @@ -35,7 +35,7 @@ func (b bindings) GetReadBinding(contractName, readName string) (readBinding, er
return nil, fmt.Errorf("%w: no contract named %s", commontypes.ErrInvalidType, contractName)
}

reader, readerExists := rb.bindings[readName]
reader, readerExists := rb.readBindings[readName]
if !readerExists {
return nil, fmt.Errorf("%w: no readName named %s in contract %s", commontypes.ErrInvalidType, readName, contractName)
}
Expand All @@ -45,10 +45,10 @@ func (b bindings) GetReadBinding(contractName, readName string) (readBinding, er
func (b bindings) AddReadBinding(contractName, readName string, rb readBinding) {
rbs, rbsExists := b[contractName]
if !rbsExists {
rbs = &contractBindings{}
rbs = &contractBindings{readBindings: make(map[string]readBinding)}
b[contractName] = rbs
}
rbs.bindings[readName] = rb
rbs.readBindings[readName] = rb
}

func (b bindings) Bind(ctx context.Context, logPoller logpoller.LogPoller, boundContracts []commontypes.BoundContract) error {
Expand All @@ -61,12 +61,12 @@ func (b bindings) Bind(ctx context.Context, logPoller logpoller.LogPoller, bound
rbs.pollingFilter.Addresses = evmtypes.AddressArray{common.HexToAddress(bc.Address)}
rbs.pollingFilter.Name = logpoller.FilterName(bc.Name+"."+uuid.NewString(), bc.Address)

// we are changing contract address reference, so we need to unregister old filters
// we are changing contract address reference, so we need to unregister old filters if they exist
if err := rbs.Unregister(ctx, logPoller); err != nil {
return err
}

for _, r := range rbs.bindings {
for _, r := range rbs.readBindings {
r.Bind(bc)
}

Expand Down Expand Up @@ -105,7 +105,7 @@ func (rb *contractBindings) Register(ctx context.Context, logPoller logpoller.Lo
return fmt.Errorf("%w: %w", commontypes.ErrInternal, err)
}

for _, binding := range rb.bindings {
for _, binding := range rb.readBindings {
if err := binding.Register(ctx); err != nil {
return err
}
Expand All @@ -126,7 +126,7 @@ func (rb *contractBindings) Unregister(ctx context.Context, logPoller logpoller.
return fmt.Errorf("%w: %w", commontypes.ErrInternal, err)
}

for _, binding := range rb.bindings {
for _, binding := range rb.readBindings {
if err := binding.Unregister(ctx); err != nil {
return err
}
Expand Down
86 changes: 48 additions & 38 deletions core/services/relay/evm/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"reflect"
"slices"
"strings"
"sync"
"time"

"github.com/ethereum/go-ethereum/accounts/abi"
Expand Down Expand Up @@ -64,7 +65,7 @@ func NewChainReaderService(ctx context.Context, lggr logger.Logger, lp logpoller
}

err = cr.contractBindings.ForEach(ctx, func(c context.Context, rbs *contractBindings) error {
for _, rb := range rbs.bindings {
for _, rb := range rbs.readBindings {
rb.SetCodec(cr.codec)
}
return nil
Expand Down Expand Up @@ -106,27 +107,29 @@ func (cr *chainReader) init(chainContractReaders map[string]types.ChainContractR
return err
}

var eventSigsForPollingFilter evmtypes.HashArray
contractFilterEvents := chainContractReader.ContractPollingFilter.GenericEventNames
var eventSigsForContractFilter evmtypes.HashArray
for typeName, chainReaderDefinition := range chainContractReader.Configs {
switch chainReaderDefinition.ReadType {
case types.Method:
err = cr.addMethod(contractName, typeName, contractAbi, *chainReaderDefinition)
case types.Event:
partOfContractPollingFilter := slices.Contains(chainContractReader.GenericEventNames, typeName)
hasContractFilterOverride := chainReaderDefinition.HasPollingFilter()
if !partOfContractPollingFilter && !chainReaderDefinition.HasPollingFilter() {
partOfContractFilter := slices.Contains(contractFilterEvents, typeName)
if !partOfContractFilter && !chainReaderDefinition.HasPollingFilter() {
return fmt.Errorf(
"%w: chain reader has no polling filter defined for contract: %s event: %s",
"%w: chain reader has no polling filter defined for contract: %s, event: %s",
commontypes.ErrInvalidConfig, contractName, typeName)
}
if hasContractFilterOverride && partOfContractPollingFilter {

eventOverridesContractFilter := chainReaderDefinition.HasPollingFilter()
if eventOverridesContractFilter && partOfContractFilter {
return fmt.Errorf(
"%w: conflicting chain reader polling filter definitions for contract: %s event: %s, can't have polling filter defined both on contract and event level",
commontypes.ErrInvalidConfig, contractName, typeName)
}

if !hasContractFilterOverride {
eventSigsForPollingFilter = append(eventSigsForPollingFilter, contractAbi.Events[chainReaderDefinition.ChainSpecificName].ID)
if !eventOverridesContractFilter {
eventSigsForContractFilter = append(eventSigsForContractFilter, contractAbi.Events[chainReaderDefinition.ChainSpecificName].ID)
}

err = cr.addEvent(contractName, typeName, contractAbi, *chainReaderDefinition)
Expand All @@ -140,7 +143,7 @@ func (cr *chainReader) init(chainContractReaders map[string]types.ChainContractR
return err
}
}
cr.contractBindings[contractName].pollingFilter = chainContractReader.PollingFilter.ToLPFilter(eventSigsForPollingFilter)
cr.contractBindings[contractName].pollingFilter = chainContractReader.PollingFilter.ToLPFilter(eventSigsForContractFilter)
}
return nil
}
Expand Down Expand Up @@ -189,13 +192,6 @@ func (cr *chainReader) addMethod(
return fmt.Errorf("%w: method %s doesn't exist", commontypes.ErrInvalidConfig, chainReaderDefinition.ChainSpecificName)
}

if chainReaderDefinition.EventDefinitions != nil {
return fmt.Errorf(
"%w: method %s has event definition, but is not an event",
commontypes.ErrInvalidConfig,
chainReaderDefinition.ChainSpecificName)
}

cr.contractBindings.AddReadBinding(contractName, methodName, &methodBinding{
contractName: contractName,
method: methodName,
Expand All @@ -215,12 +211,13 @@ func (cr *chainReader) addEvent(contractName, eventName string, a abi.ABI, chain
return fmt.Errorf("%w: event %s doesn't exist", commontypes.ErrInvalidConfig, chainReaderDefinition.ChainSpecificName)
}

if chainReaderDefinition.EventDefinitions == nil {
return fmt.Errorf("%w: event %s doesn't have event definitions set", commontypes.ErrInvalidConfig, chainReaderDefinition.ChainSpecificName)
var inputFields []string
if chainReaderDefinition.EventDefinitions != nil {
inputFields = chainReaderDefinition.EventDefinitions.InputFields
}

filterArgs, codecTopicInfo, indexArgNames := setupEventInput(event, chainReaderDefinition)
if err := verifyEventInputsUsed(chainReaderDefinition, indexArgNames); err != nil {
filterArgs, codecTopicInfo, indexArgNames := setupEventInput(event, inputFields)
if err := verifyEventInputsUsed(eventName, inputFields, indexArgNames); err != nil {
return err
}

Expand All @@ -243,44 +240,57 @@ func (cr *chainReader) addEvent(contractName, eventName string, a abi.ABI, chain
return err
}

eventDefinitions := chainReaderDefinition.EventDefinitions
eb := &eventBinding{
contractName: contractName,
eventName: eventName,
logPollerFilter: eventDefinitions.PollingFilter.ToLPFilter(evmtypes.HashArray{a.Events[event.Name].ID}),
lp: cr.lp,
hash: event.ID,
inputInfo: inputInfo,
inputModifier: inputModifier,
codecTopicInfo: codecTopicInfo,
topics: make(map[string]topicDetail),
eventDataWords: eventDefinitions.GenericDataWordNames,
eventDataWords: make(map[string]uint8),
confirmationsMapping: confirmations,
}

if eventDefinitions := chainReaderDefinition.EventDefinitions; eventDefinitions != nil {
if eventDefinitions.PollingFilter != nil {
eb.FilterRegisterer = &FilterRegisterer{
pollingFilter: eventDefinitions.PollingFilter.ToLPFilter(evmtypes.HashArray{a.Events[event.Name].ID}),
filterLock: sync.Mutex{},
}
}

if eventDefinitions.GenericDataWordNames != nil {
eb.eventDataWords = eventDefinitions.GenericDataWordNames
}

cr.addQueryingReadBindings(contractName, eventDefinitions.GenericTopicNames, event.Inputs, eb)
}

cr.contractBindings.AddReadBinding(contractName, eventName, eb)

// set topic mappings for QueryKeys
for topicIndex, topic := range event.Inputs {
genericTopicName, ok := eventDefinitions.GenericTopicNames[topic.Name]
return cr.addDecoderDef(contractName, eventName, event.Inputs, chainReaderDefinition)
}

// addQueryingReadBindings reuses the eventBinding and maps it to topic and dataWord keys used for QueryKey.
func (cr *chainReader) addQueryingReadBindings(contractName string, genericTopicNames map[string]string, eventInputs abi.Arguments, eb *eventBinding) {
// add topic read readBindings for QueryKey
for topicIndex, topic := range eventInputs {
genericTopicName, ok := genericTopicNames[topic.Name]
if ok {
eb.topics[genericTopicName] = topicDetail{
Argument: topic,
Index: uint64(topicIndex),
}
}

// this way querying by key/s values comparison can find its bindings
cr.contractBindings.AddReadBinding(contractName, genericTopicName, eb)
}

// set data word mappings for QueryKeys
// add data word read readBindings for QueryKey
for genericDataWordName := range eb.eventDataWords {
// this way querying by key/s values comparison can find its bindings
cr.contractBindings.AddReadBinding(contractName, genericDataWordName, eb)
}

return cr.addDecoderDef(contractName, eventName, event.Inputs, chainReaderDefinition)
}

func (cr *chainReader) getEventInput(def types.ChainReaderDefinition, contractName, eventName string) (
Expand All @@ -299,10 +309,10 @@ func (cr *chainReader) getEventInput(def types.ChainReaderDefinition, contractNa
return inputInfo, inMod, nil
}

func verifyEventInputsUsed(chainReaderDefinition types.ChainReaderDefinition, indexArgNames map[string]bool) error {
for _, value := range chainReaderDefinition.EventDefinitions.InputFields {
func verifyEventInputsUsed(eventName string, inputFields []string, indexArgNames map[string]bool) error {
for _, value := range inputFields {
if !indexArgNames[abi.ToCamelCase(value)] {
return fmt.Errorf("%w: %s is not an indexed argument of event %s", commontypes.ErrInvalidConfig, value, chainReaderDefinition.ChainSpecificName)
return fmt.Errorf("%w: %s is not an indexed argument of event %s", commontypes.ErrInvalidConfig, value, eventName)
}
}
return nil
Expand Down Expand Up @@ -334,9 +344,9 @@ func (cr *chainReader) addDecoderDef(contractName, itemType string, outputs abi.
return output.Init()
}

func setupEventInput(event abi.Event, def types.ChainReaderDefinition) ([]abi.Argument, types.CodecEntry, map[string]bool) {
func setupEventInput(event abi.Event, inputFields []string) ([]abi.Argument, types.CodecEntry, map[string]bool) {
topicFieldDefs := map[string]bool{}
for _, value := range def.EventDefinitions.InputFields {
for _, value := range inputFields {
capFirstValue := abi.ToCamelCase(value)
topicFieldDefs[capFirstValue] = true
}
Expand Down
2 changes: 0 additions & 2 deletions core/services/relay/evm/chain_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ const (
func TestChainReaderInterfaceTests(t *testing.T) {
t.Parallel()
it := &chainReaderInterfaceTester{}

RunChainReaderInterfaceTests(t, it)
RunChainReaderInterfaceTests(t, commontestutils.WrapChainReaderTesterForLoop(it))

t.Run("Dynamically typed topics can be used to filter and have type correct in return", func(t *testing.T) {
it.Setup(t)

Expand Down
45 changes: 28 additions & 17 deletions core/services/relay/evm/event_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@ import (
)

type eventBinding struct {
FilterRegisterer
address common.Address
contractName string
eventName string
lp logpoller.LogPoller
logPollerFilter logpoller.Filter
hash common.Hash
codec commontypes.RemoteCodec
pending bool
bound bool
inputInfo types.CodecEntry
inputModifier codec.Modifier
codecTopicInfo types.CodecEntry
address common.Address
contractName string
eventName string
lp logpoller.LogPoller
// FilterRegisterer in eventBinding is to be used as an override for lp filter defined in the contract binding.
// If FilterRegisterer is nil, this event should be registered with the lp filter defined in the contract binding.
*FilterRegisterer
hash common.Hash
codec commontypes.RemoteCodec
pending bool
bound bool
inputInfo types.CodecEntry
inputModifier codec.Modifier
codecTopicInfo types.CodecEntry
// topics maps a generic topic name (key) to topic data
topics map[string]topicDetail
// eventDataWords maps a generic name to a word index
Expand All @@ -55,6 +56,10 @@ func (e *eventBinding) SetCodec(codec commontypes.RemoteCodec) {
}

func (e *eventBinding) Register(ctx context.Context) error {
if e.FilterRegisterer == nil {
return nil
}

e.filterLock.Lock()
defer e.filterLock.Unlock()

Expand All @@ -63,13 +68,17 @@ func (e *eventBinding) Register(ctx context.Context) error {
return nil
}

if err := e.lp.RegisterFilter(ctx, e.logPollerFilter); err != nil {
if err := e.lp.RegisterFilter(ctx, e.pollingFilter); err != nil {
return fmt.Errorf("%w: %w", commontypes.ErrInternal, err)
}
return nil
}

func (e *eventBinding) Unregister(ctx context.Context) error {
if e.FilterRegisterer == nil {
return nil
}

e.filterLock.Lock()
defer e.filterLock.Unlock()

Expand Down Expand Up @@ -135,9 +144,11 @@ func (e *eventBinding) Bind(binding commontypes.BoundContract) {
e.address = common.HexToAddress(binding.Address)
e.pending = binding.Pending

id := fmt.Sprintf("%s,%s,%s", e.contractName, e.eventName, uuid.NewString())
e.logPollerFilter.Name = logpoller.FilterName(id, e.address)
e.logPollerFilter.Addresses = evmtypes.AddressArray{e.address}
if e.FilterRegisterer != nil {
id := fmt.Sprintf("%s,%s,%s", e.contractName, e.eventName, uuid.NewString())
e.pollingFilter.Name = logpoller.FilterName(id, e.address)
e.pollingFilter.Addresses = evmtypes.AddressArray{e.address}
}

e.bound = true
}
Expand Down
Loading

0 comments on commit 474097a

Please sign in to comment.