Skip to content

Commit

Permalink
Merge pull request #611 from lidofinance/check-consensus-for-atts
Browse files Browse the repository at this point in the history
chore: check consensus version to enable EIP7549 support
  • Loading branch information
F4ever authored Jan 31, 2025
2 parents 1b24302 + cf0fc75 commit b76ce13
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 20 deletions.
22 changes: 19 additions & 3 deletions src/modules/csm/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,21 @@ class FrameCheckpointProcessor:
state: State
finalized_blockstamp: BlockStamp

def __init__(self, cc: ConsensusClient, state: State, converter: Web3Converter, finalized_blockstamp: BlockStamp):
eip7549_supported: bool

def __init__(
self,
cc: ConsensusClient,
state: State,
converter: Web3Converter,
finalized_blockstamp: BlockStamp,
eip7549_supported: bool = True,
):
self.cc = cc
self.converter = converter
self.state = state
self.finalized_blockstamp = finalized_blockstamp
self.eip7549_supported = eip7549_supported

def exec(self, checkpoint: FrameCheckpoint) -> int:
logger.info(
Expand Down Expand Up @@ -193,7 +203,7 @@ def _check_duty(
committees = self._prepare_committees(duty_epoch)
for root in block_roots:
attestations = self.cc.get_block_attestations(root)
process_attestations(attestations, committees)
process_attestations(attestations, committees, self.eip7549_supported)

with lock:
for committee in committees.values():
Expand Down Expand Up @@ -227,8 +237,14 @@ def _prepare_committees(self, epoch: EpochNumber) -> Committees:
return committees


def process_attestations(attestations: Iterable[BlockAttestation], committees: Committees) -> None:
def process_attestations(
attestations: Iterable[BlockAttestation],
committees: Committees,
eip7549_supported: bool = True,
) -> None:
for attestation in attestations:
if is_eip7549_attestation(attestation) and not eip7549_supported:
raise ValueError("EIP-7549 support is not enabled")
committee_offset = 0
for committee_idx in get_committee_indices(attestation):
committee = committees.get((attestation.data.slot, committee_idx), [])
Expand Down
7 changes: 5 additions & 2 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ def validate_state(self, blockstamp: ReferenceBlockStamp) -> None:
def collect_data(self, blockstamp: BlockStamp) -> bool:
"""Ongoing report data collection for the estimated reference slot"""

consensus_version = self.get_consensus_version(blockstamp)
eip7549_supported = consensus_version != 1

logger.info({"msg": "Collecting data for the report"})

converter = self.converter(blockstamp)
Expand Down Expand Up @@ -198,7 +201,7 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
logger.info({"msg": "The starting epoch of the frame is not finalized yet"})
return False

self.state.migrate(l_epoch, r_epoch)
self.state.migrate(l_epoch, r_epoch, consensus_version)
self.state.log_progress()

if self.state.is_fulfilled:
Expand All @@ -212,7 +215,7 @@ def collect_data(self, blockstamp: BlockStamp) -> bool:
except MinStepIsNotReached:
return False

processor = FrameCheckpointProcessor(self.w3.cc, self.state, converter, blockstamp)
processor = FrameCheckpointProcessor(self.w3.cc, self.state, converter, blockstamp, eip7549_supported)

for checkpoint in checkpoints:
if self.current_frame_range(self._receive_last_finalized_slot()) != (l_epoch, r_epoch):
Expand Down
16 changes: 14 additions & 2 deletions src/modules/csm/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from pathlib import Path
from typing import Self

from src import variables
from src.types import EpochNumber, ValidatorIndex
from src.utils.range import sequence
from src import variables

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,6 +49,8 @@ class State:
_epochs_to_process: tuple[EpochNumber, ...]
_processed_epochs: set[EpochNumber]

_consensus_version: int = 1

def __init__(self, data: dict[ValidatorIndex, AttestationsAccumulator] | None = None) -> None:
self.data = defaultdict(AttestationsAccumulator, data or {})
self._epochs_to_process = tuple()
Expand Down Expand Up @@ -102,7 +104,16 @@ def add_processed_epoch(self, epoch: EpochNumber) -> None:
def log_progress(self) -> None:
logger.info({"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs"})

def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber):
def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version: int):
if consensus_version != self._consensus_version:
logger.warning(
{
"msg": f"Cache was built for consensus version {self._consensus_version}. "
f"Discarding data to migrate to consensus version {consensus_version}"
}
)
self.clear()

for state_epochs in (self._epochs_to_process, self._processed_epochs):
for epoch in state_epochs:
if epoch < l_epoch or epoch > r_epoch:
Expand All @@ -111,6 +122,7 @@ def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber):
break

self._epochs_to_process = tuple(sequence(l_epoch, r_epoch))
self._consensus_version = consensus_version
self.commit()

def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
Expand Down
42 changes: 37 additions & 5 deletions tests/modules/csm/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from typing import cast
from unittest.mock import Mock

from faker import Faker
import pytest
from faker import Faker

import src.modules.csm.checkpoint as checkpoint_module
from src.modules.csm.checkpoint import (
FrameCheckpoint,
FrameCheckpointProcessor,
Expand All @@ -16,7 +17,7 @@
from src.modules.submodules.types import ChainConfig, FrameConfig
from src.providers.consensus.client import ConsensusClient
from src.providers.consensus.types import BeaconSpecResponse, BlockAttestation, SlotAttestationCommittee
from src.types import SlotNumber, ValidatorIndex
from src.types import EpochNumber, SlotNumber, ValidatorIndex
from src.utils.web3converter import Web3Converter
from tests.factory.bitarrays import BitListFactory
from tests.factory.configs import (
Expand Down Expand Up @@ -313,6 +314,37 @@ def _get_block_attestations(root):
consensus_client.get_block_attestations = Mock(side_effect=_get_block_attestations)


@pytest.mark.usefixtures(
"mock_get_state_block_roots",
"mock_get_attestation_committees",
"mock_get_block_attestations",
"mock_get_config_spec",
)
def test_checkpoints_processor_no_eip7549_support(
consensus_client,
converter,
monkeypatch: pytest.MonkeyPatch,
):
state = State()
state.migrate(EpochNumber(0), EpochNumber(255))
processor = FrameCheckpointProcessor(
consensus_client,
state,
converter,
Mock(),
eip7549_supported=False,
)
roots = processor._get_block_roots(SlotNumber(0))
with monkeypatch.context():
monkeypatch.setattr(
checkpoint_module,
"is_eip7549_attestation",
Mock(return_value=True),
)
with pytest.raises(ValueError, match="support is not enabled"):
processor._check_duty(0, roots[:64])


def test_checkpoints_processor_check_duty(
mock_get_state_block_roots,
mock_get_attestation_committees,
Expand All @@ -322,7 +354,7 @@ def test_checkpoints_processor_check_duty(
converter,
):
state = State()
state.migrate(0, 255)
state.migrate(0, 255, 1)
finalized_blockstamp = ...
processor = FrameCheckpointProcessor(
consensus_client,
Expand All @@ -347,7 +379,7 @@ def test_checkpoints_processor_process(
converter,
):
state = State()
state.migrate(0, 255)
state.migrate(0, 255, 1)
finalized_blockstamp = ...
processor = FrameCheckpointProcessor(
consensus_client,
Expand All @@ -372,7 +404,7 @@ def test_checkpoints_processor_exec(
converter,
):
state = State()
state.migrate(0, 255)
state.migrate(0, 255, 1)
finalized_blockstamp = ...
processor = FrameCheckpointProcessor(
consensus_client,
Expand Down
2 changes: 1 addition & 1 deletion tests/modules/csm/test_csm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_calculate_distribution(module: CSOracle, csm: CSM):
ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000),
}
)
module.state.migrate(EpochNumber(100), EpochNumber(500))
module.state.migrate(EpochNumber(100), EpochNumber(500), 1)

_, shares, log = module.calculate_distribution(blockstamp=Mock())

Expand Down
34 changes: 27 additions & 7 deletions tests/modules/csm/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ def test_state_avg_perf():
def test_state_frame():
state = State()

state.migrate(EpochNumber(100), EpochNumber(500))
state.migrate(EpochNumber(100), EpochNumber(500), 1)
assert state.frame == (100, 500)

state.migrate(EpochNumber(300), EpochNumber(301))
state.migrate(EpochNumber(300), EpochNumber(301), 1)
assert state.frame == (300, 301)

state.clear()
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_empty_to_new_frame(self):
l_epoch = EpochNumber(1)
r_epoch = EpochNumber(255)

state.migrate(l_epoch, r_epoch)
state.migrate(l_epoch, r_epoch, 1)

assert not state.is_empty
assert state.unprocessed_epochs == set(sequence(l_epoch, r_epoch))
Expand All @@ -171,10 +171,10 @@ def test_empty_to_new_frame(self):
def test_new_frame_requires_discarding_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new):
state = State()
state.clear = Mock(side_effect=state.clear)
state.migrate(l_epoch_old, r_epoch_old)
state.migrate(l_epoch_old, r_epoch_old, 1)
state.clear.assert_not_called()

state.migrate(l_epoch_new, r_epoch_new)
state.migrate(l_epoch_new, r_epoch_new, 1)
state.clear.assert_called_once()

assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new))
Expand All @@ -190,10 +190,30 @@ def test_new_frame_extends_old_state(self, l_epoch_old, r_epoch_old, l_epoch_new
state = State()
state.clear = Mock(side_effect=state.clear)

state.migrate(l_epoch_old, r_epoch_old)
state.migrate(l_epoch_old, r_epoch_old, 1)
state.clear.assert_not_called()

state.migrate(l_epoch_new, r_epoch_new)
state.migrate(l_epoch_new, r_epoch_new, 1)
state.clear.assert_not_called()

assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new))

@pytest.mark.parametrize(
("old_version", "new_version"),
[
pytest.param(2, 3, id="Increase consensus version"),
pytest.param(3, 2, id="Decrease consensus version"),
],
)
def test_consensus_version_change(self, old_version, new_version):
state = State()
state.clear = Mock(side_effect=state.clear)
state._consensus_version = old_version

l_epoch = r_epoch = EpochNumber(255)

state.migrate(l_epoch, r_epoch, old_version)
state.clear.assert_not_called()

state.migrate(l_epoch, r_epoch, new_version)
state.clear.assert_called_once()

0 comments on commit b76ce13

Please sign in to comment.