Commit: wip
vgorkavenko committed Jan 29, 2025
1 parent c13d47a commit 684afa5
Showing 8 changed files with 184 additions and 104 deletions.
148 changes: 83 additions & 65 deletions src/modules/csm/checkpoint.py
@@ -6,18 +6,15 @@
from typing import Iterable, Sequence, TypeGuard

from src import variables
from src.constants import SLOTS_PER_HISTORICAL_ROOT
from src.metrics.prometheus.csm import CSM_MIN_UNPROCESSED_EPOCH, CSM_UNPROCESSED_EPOCHS_COUNT
from src.constants import SLOTS_PER_HISTORICAL_ROOT, EPOCHS_PER_SYNC_COMMITTEE_PERIOD
from src.metrics.prometheus.csm import CSM_UNPROCESSED_EPOCHS_COUNT, CSM_MIN_UNPROCESSED_EPOCH
from src.modules.csm.state import State
from src.providers.consensus.client import ConsensusClient
from src.providers.consensus.types import BlockAttestation, BlockAttestationEIP7549
from src.providers.consensus.types import BlockAttestation, BlockDetailsResponse, SyncCommittee, ProposerDuties
from src.providers.consensus.types import BlockAttestation, BlockAttestationEIP7549, SyncCommittee, SyncAggregate
from src.types import BlockRoot, BlockStamp, EpochNumber, SlotNumber, ValidatorIndex
from src.utils.blockstamp import build_blockstamp
from src.utils.range import sequence
from src.utils.slot import get_next_non_missed_slot, get_blockstamp
from src.utils.slot import get_next_non_missed_slot, get_prev_non_missed_slot
from src.utils.timeit import timeit
from src.utils.types import hex_str_to_bytes
from src.utils.web3converter import Web3Converter
@@ -110,6 +107,8 @@ def _is_min_step_reached(self):

type Slot = str
type CommitteeIndex = str
type SlotBlockRoot = tuple[SlotNumber, BlockRoot | None]
type SyncCommittees = dict[SlotNumber, list[ValidatorDuty]]
type AttestationCommittees = dict[tuple[Slot, CommitteeIndex], list[ValidatorDuty]]


@@ -157,7 +156,7 @@ def _get_block_roots(self, checkpoint_slot: SlotNumber):

def _select_block_roots(
self, duty_epoch: EpochNumber, block_roots: list[BlockRoot | None], checkpoint_slot: SlotNumber
) -> list[tuple[SlotNumber, BlockRoot | None]]:
) -> tuple[list[SlotBlockRoot], list[SlotBlockRoot]]:
roots_to_check = []
# To check duties of the duty epoch we need all 32 slots of that epoch plus
# the 32 slots of the next epoch, since attestations may be included up to one epoch later
@@ -172,13 +171,19 @@
raise ValueError("Slot is out of the state block roots range")
roots_to_check.append((slot_to_check, block_roots[slot_to_check % SLOTS_PER_HISTORICAL_ROOT]))

return roots_to_check
duty_epoch_roots, next_epoch_roots = roots_to_check[:32], roots_to_check[32:]

def _process(self, unprocessed_epochs: list[EpochNumber], duty_epochs_roots: dict[EpochNumber, list[tuple[SlotNumber, BlockRoot | None]]]):
return duty_epoch_roots, next_epoch_roots
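
Aside: the 32 + 32 window above exists because an attestation made in the duty epoch may still be included in a block up to the end of the following epoch, so both epochs' block roots are needed. A rough sketch of the selection (hypothetical helper; the bounds check is assumed to mirror the original):

```python
SLOTS_PER_EPOCH = 32
SLOTS_PER_HISTORICAL_ROOT = 8192  # mainnet spec value

def select_roots(duty_epoch, block_roots, checkpoint_slot):
    # Pick 64 consecutive (slot, root) pairs, then split 32/32.
    start = duty_epoch * SLOTS_PER_EPOCH
    pairs = []
    for slot in range(start, start + 2 * SLOTS_PER_EPOCH):
        if not checkpoint_slot - SLOTS_PER_HISTORICAL_ROOT <= slot <= checkpoint_slot:
            raise ValueError("Slot is out of the state block roots range")
        pairs.append((slot, block_roots[slot % SLOTS_PER_HISTORICAL_ROOT]))
    return pairs[:SLOTS_PER_EPOCH], pairs[SLOTS_PER_EPOCH:]
```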

def _process(
self,
unprocessed_epochs: list[EpochNumber],
duty_epochs_roots: dict[EpochNumber, tuple[list[SlotBlockRoot], list[SlotBlockRoot]]]
):
executor = ThreadPoolExecutor(max_workers=variables.CSM_ORACLE_MAX_CONCURRENCY)
try:
futures = {
executor.submit(self._check_duty, duty_epoch, duty_epochs_roots[duty_epoch])
executor.submit(self._check_duties, duty_epoch, *duty_epochs_roots[duty_epoch])
for duty_epoch in unprocessed_epochs
}
for future in as_completed(futures):
@@ -192,45 +197,51 @@ def _process(self, unprocessed_epochs: list[EpochNumber], duty_epochs_roots: dic
logger.info({"msg": "The executor was shut down"})

@timeit(lambda args, duration: logger.info({"msg": f"Epoch {args.duty_epoch} processed in {duration:.2f} seconds"}))
def _check_duty(
def _check_duties(
self,
duty_epoch: EpochNumber,
two_epochs_block_roots: list[tuple[SlotNumber, BlockRoot | None]],
duty_epoch_roots: list[SlotBlockRoot],
next_epoch_roots: list[SlotBlockRoot],
):

logger.info({"msg": f"Processing epoch {duty_epoch}"})

att_committees = self._prepare_att_committees(EpochNumber(duty_epoch))
for slot, root in two_epochs_block_roots:
missed = root is None
if not missed:
# TODO: should we use get_block_details here?
attestations = self.cc.get_block_attestations(BlockRoot(root))
process_attestations(attestations, att_committees)

duty_epoch_block_roots = two_epochs_block_roots[:32]
propose_duties = self._prepare_propose_duties(EpochNumber(duty_epoch))
sync_committee = self._prepare_sync_committee(EpochNumber(duty_epoch), duty_epoch_block_roots)
process_sync(sync_committee, duty_epoch_block_roots)
process_proposals(propose_duties, duty_epoch_block_roots)
propose_duties = self._prepare_propose_duties(EpochNumber(duty_epoch), self.finalized_blockstamp)
sync_committees = self._prepare_sync_committee(EpochNumber(duty_epoch), duty_epoch_roots)
for slot, root in [*duty_epoch_roots, *next_epoch_roots]:
missed_slot = root is None
if missed_slot:
continue
attestations, sync_aggregate = self.cc.get_block_attestations_and_sync(BlockRoot(root))
process_attestations(attestations, att_committees)
if (slot, root) in duty_epoch_roots:
propose_duties[slot].included = True
process_sync(slot, sync_aggregate, sync_committees)

with lock:
if duty_epoch not in self.state.unprocessed_epochs:
raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed")
frame = self.state.find_frame(duty_epoch)
for att_committee in att_committees.values():
for att_duty in att_committee:
self.state.increment_att_duty(
duty_epoch,
frame,
att_duty.index,
included=att_duty.included,
)
for sync_duty in sync_committee:
self.state.increment_sync_duty(
duty_epoch,
sync_duty.index,
included=sync_duty.included,
)
for sync_committee in sync_committees.values():
for sync_duty in sync_committee:
self.state.increment_sync_duty(
frame,
sync_duty.index,
included=sync_duty.included,
)
for proposer_duty in propose_duties.values():
self.state.increment_prop_duty(duty_epoch, proposer_duty.index, proposer_duty.included)
self.state.increment_prop_duty(
frame,
proposer_duty.index,
included=proposer_duty.included
)
self.state.add_processed_epoch(duty_epoch)
self.state.log_progress()
unprocessed_epochs = self.state.unprocessed_epochs
Expand Down Expand Up @@ -258,12 +269,15 @@ def _prepare_att_committees(self, epoch: EpochNumber) -> AttestationCommittees:
)
)
def _prepare_sync_committee(
self, epoch: EpochNumber, block_roots: list[tuple[SlotNumber, BlockRoot | None]]
) -> list[str]: # TODO: list ValidatorIndex
self, epoch: EpochNumber, block_roots: list[SlotBlockRoot]
) -> dict[SlotNumber, list[ValidatorDuty]]:
# TODO: should be under lock?
sync_committee_epochs = epoch % EPOCHS_PER_SYNC_COMMITTEE_PERIOD
# TODO: compare the committee returned by this func against the real one on border cases
if not self.current_sync_committee or sync_committee_epochs == 0:
for_epochs = EPOCHS_PER_SYNC_COMMITTEE_PERIOD - sync_committee_epochs
logger.info({"msg": f"Prepare sync committee {for_epochs=}"})
# TODO: do better
epochs_range = EPOCHS_PER_SYNC_COMMITTEE_PERIOD - sync_committee_epochs
logger.info({"msg": f"Preparing Sync Committee for {epochs_range} epochs"})
first_slot_root, *_ = block_roots
slot, _ = first_slot_root
blockstamp = build_blockstamp(
@@ -275,20 +289,49 @@
)
# TODO: can we use lru cache here?
self.current_sync_committee = self.cc.get_sync_committee(blockstamp, epoch)
return self.current_sync_committee.validators
duties = {}
for slot, root in block_roots:
missed_slot = root is None
if missed_slot:
continue
duties[slot] = [
ValidatorDuty(index=ValidatorIndex(int(validator)), included=False)
for validator in self.current_sync_committee.validators
]

return duties
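
Aside: the sync committee rotates only every EPOCHS_PER_SYNC_COMMITTEE_PERIOD epochs (256 on mainnet), so a committee fetched once can be reused for every epoch within its period; the modulo check above refetches it exactly at period boundaries. A minimal sketch of that condition (hypothetical helper):

```python
EPOCHS_PER_SYNC_COMMITTEE_PERIOD = 256  # mainnet spec value

def committee_is_stale(cached_committee, epoch: int) -> bool:
    # Refetch when nothing is cached yet or a new period has started.
    return cached_committee is None or epoch % EPOCHS_PER_SYNC_COMMITTEE_PERIOD == 0
```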

@timeit(
lambda args, duration: logger.info(
{"msg": f"Propose Duties for epoch {args.epoch} prepared in {duration:.2f} seconds"}
)
)
def _prepare_propose_duties(self, epoch: EpochNumber) -> dict[SlotNumber, ValidatorDuty]:
def _prepare_propose_duties(self, epoch: EpochNumber, blockstamp: BlockStamp) -> dict[SlotNumber, ValidatorDuty]:
duties = {}
for duty in self.cc.get_proposer_duties(epoch):
duties[SlotNumber(int(duty.slot))] = ValidatorDuty(index=ValidatorIndex(int(duty.validator_index)), included=False)
dependent_slot = self.converter.get_epoch_last_slot(EpochNumber(epoch - 1))
# TODO: can we just take root from the state?
dependent_non_missed_slot = SlotNumber(int(
get_prev_non_missed_slot(
self.cc,
dependent_slot,
blockstamp.slot_number
).message.slot)
)
for duty in self.cc.get_proposer_duties(epoch, dependent_non_missed_slot):
duties[SlotNumber(int(duty.slot))] = ValidatorDuty(
index=ValidatorIndex(int(duty.validator_index)), included=False
)
return duties
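
Aside: proposer duties for an epoch are derived from the chain as of the last slot of the previous epoch (the dependent slot); when that slot was missed, get_prev_non_missed_slot walks back to the closest proposed one. A worked example of the dependent-slot arithmetic (hypothetical helper; 32-slot epochs):

```python
SLOTS_PER_EPOCH = 32

def dependent_slot_for(epoch: int) -> int:
    # The last slot of the previous epoch.
    return epoch * SLOTS_PER_EPOCH - 1

assert dependent_slot_for(100) == 3199  # epoch 100 starts at slot 3200
```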


def process_sync(slot: SlotNumber, sync_aggregate: SyncAggregate, committees: SyncCommittees) -> None:
committee = committees[slot]
# Spec: https://github.com/ethereum/consensus-specs/blob/dev/specs/altair/beacon-chain.md#syncaggregate
sync_bits = hex_bitvector_to_list(sync_aggregate.sync_committee_bits)
for index_in_committee in get_set_indices(sync_bits):
committee[index_in_committee].included = True
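
Aside: sync_committee_bits is a hex-encoded SSZ Bitvector[512]; bit i set means committee member i participated in the sync aggregate. A self-contained sketch of the decoding (hypothetical helper; SSZ bit order is little-endian within each byte):

```python
def set_bit_indices(bitvector_hex: str) -> list[int]:
    data = bytes.fromhex(bitvector_hex.removeprefix("0x"))
    # Bit i of byte b corresponds to committee member 8*b + i.
    return [8 * b + i for b, byte in enumerate(data) for i in range(8) if byte >> i & 1]

assert set_bit_indices("0x03") == [0, 1]  # members 0 and 1 participated
```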


def process_attestations(attestations: Iterable[BlockAttestation], committees: AttestationCommittees) -> None:
for attestation in attestations:
committee_offset = 0
Expand Down Expand Up @@ -325,31 +368,6 @@ def hex_bitvector_to_list(bitvector: str) -> list[bool]:
return _bytes_to_bool_list(bytes_)


def process_sync(committee: list[str], block_roots: list[tuple[SlotNumber, BlockRoot | None]]):
duties = {}
for slot, root in block_roots:
if root is None:
continue
sync_bits = self.cc.get_block_details(BlockRoot(root)).message.body.sync_aggregate.sync_committee_bits
sync_bits = _to_bits(sync_bits)
for index_in_committee, validator_index in enumerate(committee):
duty = ValidatorDuty(index=ValidatorIndex(int(validator_index)), included=False)
duty.included = duty.included or _is_attested(sync_bits, index_in_committee)
duties[slot] = duty
return duties


def process_proposals(
duties: dict[SlotNumber, ValidatorDuty], block_roots: list[tuple[SlotNumber, BlockRoot | None]]
) -> None:
for slot, root in block_roots:
duties[slot].included = root is not None


def _is_attested(bits: Sequence[bool], index: int) -> bool:
return bits[index]


def hex_bitlist_to_list(bitlist: str) -> list[bool]:
bytes_ = hex_str_to_bytes(bitlist)
if not bytes_ or bytes_[-1] == 0:
46 changes: 37 additions & 9 deletions src/modules/csm/csm.py
@@ -234,7 +234,8 @@ def calculate_distribution(
shares = defaultdict[NodeOperatorId, int](int)
logs: list[FramePerfLog] = []

for frame in self.state.att_data:
frames = self.state.calculate_frames(self.state._epochs_to_process, self.state._epochs_per_frame)
for frame in frames:
from_epoch, to_epoch = frame
logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"})
frame_blockstamp = blockstamp
@@ -268,7 +269,12 @@ def _calculate_distribution_in_frame(
frame: Frame,
distributed: int,
):
network_perf = self.state.get_network_aggr(frame).perf
att_network_perf = self.state.get_att_network_aggr(frame).perf
prop_network_perf = self.state.get_prop_network_aggr(frame).perf
sync_network_perf = self.state.get_sync_network_aggr(frame).perf

network_perf = 54/64 * att_network_perf + 8/64 * prop_network_perf + 2/64 * sync_network_perf

threshold = network_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS
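
Aside: a quick numeric check of the threshold arithmetic above, with illustrative performance values and an assumed leeway of 500 bp (TOTAL_BASIS_POINTS = 10_000):

```python
att, prop, sync = 0.98, 0.90, 0.95  # illustrative per-duty network performance
network_perf = 54/64 * att + 8/64 * prop + 2/64 * sync  # ~0.9691
threshold = network_perf - 500 / 10_000                 # ~0.9191
```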

# Build the map of the current distribution operators.
@@ -277,30 +283,52 @@
log = FramePerfLog(blockstamp, frame, threshold)

for (_, no_id), validators in operators_to_validators.items():
# TODO: Should this check happen later so the logs still get the other data?
if no_id in stuck_operators:
log.operators[no_id].stuck = True
continue

for v in validators:
aggr = self.state.att_data[frame].get(ValidatorIndex(int(v.index)))
att_aggr = self.state.att_data[frame].get(ValidatorIndex(int(v.index)))
prop_aggr = self.state.prop_data[frame].get(ValidatorIndex(int(v.index)))
sync_aggr = self.state.sync_data[frame].get(ValidatorIndex(int(v.index)))

if aggr is None:
if att_aggr is None:
# It's possible that the validator is not assigned to any duty, hence its performance
# is not present in the aggregates (e.g. exited, pending activation, etc.).
# TODO: do we need to check sync_aggr to strike the validator?
continue

log_data = log.operators[no_id].validators[v.index]

if v.validator.slashed is True:
# It means the validator was active during the frame, got slashed, and hasn't yet reached its
# exit epoch, so we should not count such a validator toward the operator's share.
log.operators[no_id].validators[v.index].slashed = True
log_data.slashed = True
continue

if aggr.perf > threshold:
performance = att_aggr.perf

if prop_aggr is not None and sync_aggr is not None:
performance = 54/64 * att_aggr.perf + 8/64 * prop_aggr.perf + 2/64 * sync_aggr.perf

if prop_aggr is not None and sync_aggr is None:
performance = 54/62 * att_aggr.perf + 8/62 * prop_aggr.perf

if prop_aggr is None and sync_aggr is not None:
performance = 54/56 * att_aggr.perf + 2/56 * sync_aggr.perf

if performance > threshold:
# The count of assigned attestations is used as a metric of how long
# the validator was active in the current frame.
distribution[no_id] += aggr.assigned

log.operators[no_id].validators[v.index].perf = aggr
distribution[no_id] += att_aggr.assigned

log_data.performance = performance
log_data.attestations = att_aggr
if prop_aggr is not None:
log_data.proposals = prop_aggr
if sync_aggr is not None:
log_data.sync_committee = sync_aggr

# Calculate share of each CSM node operator.
shares = defaultdict[NodeOperatorId, int](int)
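
Aside: the three `performance` branches above implement one rule: average over whichever duty types have data for the validator, with attestations, proposals, and sync participation weighted 54:8:2. A hedged sketch of that factoring (hypothetical helper):

```python
ATT_W, PROP_W, SYNC_W = 54, 8, 2

def weighted_performance(att: float, prop: float | None, sync: float | None) -> float:
    # Average over the duty types that actually have aggregates.
    parts = [(ATT_W, att)]
    if prop is not None:
        parts.append((PROP_W, prop))
    if sync is not None:
        parts.append((SYNC_W, sync))
    total_weight = sum(w for w, _ in parts)
    return sum(w * p for w, p in parts) / total_weight
```

With all three present this reduces to 54/64·att + 8/64·prop + 2/64·sync; dropping sync gives 54/62 and 8/62; dropping proposals gives 54/56 and 2/56 — the three branches above.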
6 changes: 4 additions & 2 deletions src/modules/csm/log.py
@@ -12,8 +12,10 @@ class LogJSONEncoder(json.JSONEncoder): ...

@dataclass
class ValidatorFrameSummary:
# TODO: Should be renamed. Perf means different things in different contexts
perf: DutyAccumulator = field(default_factory=DutyAccumulator)
attestations: DutyAccumulator = field(default_factory=DutyAccumulator)
proposals: DutyAccumulator = field(default_factory=DutyAccumulator)
sync_committee: DutyAccumulator = field(default_factory=DutyAccumulator)
performance: float = 0.0
slashed: bool = False


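
Aside: the per-duty accumulators referenced above expose at least .assigned and .perf (see the csm.py usages). A minimal sketch of the assumed shape — not the project's actual definition:

```python
from dataclasses import dataclass

@dataclass
class DutyAccumulator:  # assumed shape, inferred from csm.py usages
    assigned: int = 0
    included: int = 0

    @property
    def perf(self) -> float:
        return self.included / self.assigned if self.assigned else 0.0

acc = DutyAccumulator(assigned=225, included=220)
print(f"{acc.perf:.3f}")  # 0.978
```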