Skip to content

Commit

Permalink
refactor: calculate_distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
vgorkavenko committed Feb 4, 2025
1 parent f48ce02 commit 76fdc39
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,26 +231,35 @@ def calculate_distribution(
"""Computes distribution of fee shares at the given timestamp"""
operators_to_validators = self.module_validators_by_node_operators(blockstamp)

distributed = 0
# Calculate share of each CSM node operator.
shares = defaultdict[NodeOperatorId, int](int)
total_distributed = 0
total_shares = defaultdict[NodeOperatorId, int](int)
logs: list[FramePerfLog] = []

for frame in self.state.frames:
from_epoch, to_epoch = frame
logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"})

frame_blockstamp = blockstamp
if to_epoch != blockstamp.ref_epoch:
frame_blockstamp = self._get_ref_blockstamp_for_frame(blockstamp, to_epoch)

total_to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash)
to_distribute_in_frame = total_to_distribute - total_distributed

distributed_in_frame, shares_in_frame, log = self._calculate_distribution_in_frame(
frame_blockstamp, operators_to_validators, frame, distributed
frame_blockstamp, operators_to_validators, frame, to_distribute_in_frame
)
distributed += distributed_in_frame

total_distributed += distributed_in_frame
if total_distributed > total_to_distribute:
raise CSMError(f"Invalid distribution: {total_distributed=} > {total_to_distribute=}")

for no_id, share in shares_in_frame.items():
shares[no_id] += share
total_shares[no_id] += share

logs.append(log)

return distributed, shares, logs
return total_distributed, total_shares, logs

def _get_ref_blockstamp_for_frame(
self, blockstamp: ReferenceBlockStamp, frame_ref_epoch: EpochNumber
Expand All @@ -268,7 +277,7 @@ def _calculate_distribution_in_frame(
blockstamp: ReferenceBlockStamp,
operators_to_validators: ValidatorsByNodeOperator,
frame: Frame,
distributed: int,
to_distribute: int,
):
network_perf = self._calculate_network_performance(frame)
threshold = self._calculate_threshold(network_perf, blockstamp)
Expand All @@ -281,7 +290,7 @@ def _calculate_distribution_in_frame(
for (_, no_id), validators in operators_to_validators.items():
self._process_operator(validators, no_id, stuck_operators, frame, log, distribution, threshold)

return self._finalize_distribution(distribution, blockstamp, distributed, log)
return self._finalize_distribution(distribution, to_distribute, log)

def _calculate_network_performance(self, frame: Frame) -> float:
att_perf = self.state.get_att_network_aggr(frame).perf
Expand Down Expand Up @@ -368,16 +377,14 @@ def _calculate_validator_performance(att_aggr, prop_aggr, sync_aggr) -> float:
raise ValueError(f"Invalid performance: {performance=}")
return performance

@staticmethod
def _finalize_distribution(
self,
distribution: dict[NodeOperatorId, int],
blockstamp: ReferenceBlockStamp,
distributed: int,
to_distribute: int,
log: FramePerfLog
) -> tuple[int, dict[NodeOperatorId, int], FramePerfLog]:
shares: dict[NodeOperatorId, int] = defaultdict(int)
total = sum(distribution.values())
to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - distributed
log.distributable = to_distribute

if not total:
Expand Down

0 comments on commit 76fdc39

Please sign in to comment.