From a1d27cf5ce7b8f1a16414d110fd904ba8dfe9375 Mon Sep 17 00:00:00 2001 From: Ryan Stout Date: Mon, 17 Jun 2024 13:27:49 -0700 Subject: [PATCH] (WIP) Add a draft spec for RMN OffChain Blessing --- .../ccip/spec/commit_plugin_rmn_ocb_draft.py | 311 ++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 core/services/ocr3/plugins/ccip/spec/commit_plugin_rmn_ocb_draft.py diff --git a/core/services/ocr3/plugins/ccip/spec/commit_plugin_rmn_ocb_draft.py b/core/services/ocr3/plugins/ccip/spec/commit_plugin_rmn_ocb_draft.py new file mode 100644 index 00000000000..a3d76b24b4d --- /dev/null +++ b/core/services/ocr3/plugins/ccip/spec/commit_plugin_rmn_ocb_draft.py @@ -0,0 +1,311 @@ +# TODO: doc + +from typing import List, Dict, Optional +from dataclasses import dataclass +from collections import defaultdict + +RmnNodeId = str +ChainSelector = int + +MAX_INTERVAL_LENGTH = 256 + + +@dataclass +class Interval: + min: int + max: int + + def is_empty(self) -> bool: + return self.min == self.max + + +@dataclass +class RmnSig: + rmn_node_id: RmnNodeId + sig: bytes + + +@dataclass +class SignedInterval: + interval: Interval + root: bytes + sigs: List[RmnSig] + + +@dataclass +class CcipMessage: + seq_num: int + + +@dataclass +class CommitQuery: + rmn_max_seq_nums: Dict[ChainSelector, int] + signed_intervals: Dict[ChainSelector, SignedInterval] + + +@dataclass +class CommitObservation: + max_seq_nums_on_dest_chain: Dict[ChainSelector, int] + messages: Dict[ChainSelector, List[CcipMessage]] + + +@dataclass +class CommitOutcome: + next_intervals: Dict[ChainSelector, Interval] + signed_intervals: Dict[ChainSelector, SignedInterval] + + +@dataclass +class RmnNode: + node_id: RmnNodeId + ip_address: bytes + pub_key: bytes + supported_chains: List[ChainSelector] + + +@dataclass +class RmnClientConfig: + rmn_nodes: List[RmnNode] + + +@dataclass +class RmnClient: + rmn_client_config: RmnClientConfig + + # TODO: doc + def request_max_seq_nums_from_single_node( + self, + rmn_node_id: RmnNodeId, + chains: List[ChainSelector] + ) -> Dict[ChainSelector, int]: + pass + + # TODO: doc + def request_signed_intervals_from_single_node( + self, + rmn_node_id: RmnNodeId, + intervals: Dict[ChainSelector, Interval] + ) -> Dict[ChainSelector, SignedInterval]: + pass + + # TODO: doc + def request_max_seq_nums( + self, + chains: List[ChainSelector] + ) -> Dict[ChainSelector, int]: + pass + + # TODO: doc + def request_signed_intervals( + self, + intervals: Dict[ChainSelector, Interval] + ) -> Dict[ChainSelector, SignedInterval]: + pass + + +class ChainReader: + def __init__(self): + pass + + +class OffRamp: + def __init__(self): + pass + + # TODO: doc + def get_max_seq_nums_on_dest_chain(self) -> Dict[ChainSelector, int]: + pass + + +@dataclass +class CommitPlugin: + rmn_client: RmnClient + all_source_chains: List[ChainSelector] + dest_chain: ChainSelector + chain_readers: Dict[ChainSelector, ChainReader] + off_ramp: OffRamp + f: int + + # TODO: doc + def can_read_from_dest_chain(self) -> bool: + return self.dest_chain in self.chain_readers + + # TODO: doc + def get_ccip_messages_from_source_chains( + self, + intervals: Dict[ChainSelector, Interval] + ) -> Dict[ChainSelector, List[CcipMessage]]: + pass + + # TODO: doc + def query(self, previous_outcome: CommitOutcome) -> CommitQuery: + max_seq_nums = self.rmn_client.request_max_seq_nums(self.all_source_chains) + signed_intervals = self.rmn_client.request_signed_intervals(previous_outcome.next_intervals) + return CommitQuery(max_seq_nums, signed_intervals) + + # TODO: doc + def observation(self, previous_outcome: CommitOutcome) -> CommitObservation: + # Get persisted min seq nums from dest chain + max_seq_nums_on_dest_chain = {} + if self.can_read_from_dest_chain(): + max_seq_nums_on_dest_chain = self.off_ramp.get_max_seq_nums_on_dest_chain() + + # Get messages in previous_outcome.next_interval + messages = self.get_ccip_messages_from_source_chains(previous_outcome.next_intervals) + + return CommitObservation(max_seq_nums_on_dest_chain, messages) + + # TODO: doc + def get_consensus_max_seq_nums_on_dest_chain( + self, + observations: List[CommitObservation] + ) -> Dict[ChainSelector, int]: + counts = defaultdict(int) + for observation in observations: + # Convert the dictionary to a frozenset of items so it can be used as a key + if len(observation.max_seq_nums_on_dest_chain) > 0: + frozen_dict = frozenset(observation.max_seq_nums_on_dest_chain.items()) + counts[frozen_dict] += 1 + + # Consensus on the onchain state is reached if there is only one max_seq_nums_on_dest_chain dict that is + # observed by more than f nodes + # TODO: more doc + candidates = [] + for (candidate, count) in counts.items(): + if count > self.f: + candidates.append(candidate) + + if len(candidates) == 1: + return dict(candidates[0]) + else: + return {} + + # Compute the merkle root for a sequence of messages + def compute_merkle_root( + self, + messages: List[CcipMessage] + ) -> bytes: + pass + + # Return true if the given sigs are valid and sufficient signatures of merkle_root + def verify_rmn_sigs(self, merkle_root: bytes, rmn_threshold: int, sigs: List[RmnSig]) -> bool: + pass + + # Return the number of RMN signatures required for the given chain + def get_rmn_threshold_for_chain(self, chain_selector: ChainSelector) -> int: + pass + + # TODO: doc + # TODO: maybe rename to get_signed_interval_for_report + def get_verified_signed_interval( + self, + chain_selector: ChainSelector, + current_interval: Interval, + max_seq_num_on_dest_chain: int, + signed_intervals: Dict[ChainSelector, SignedInterval], + messages: List[CcipMessage] + ) -> Optional[SignedInterval]: + if current_interval.min != max_seq_num_on_dest_chain + 1: + return None + + merkle_root = self.compute_merkle_root(messages) + + if len(merkle_root) != 32: + return None + + rmn_threshold = self.get_rmn_threshold_for_chain(chain_selector) + if rmn_threshold > 0: + if chain_selector not in signed_intervals: + return None + else: + signed_interval = signed_intervals[chain_selector] + if (current_interval != signed_interval.interval or + merkle_root != signed_interval.root or + not self.verify_rmn_sigs(merkle_root, rmn_threshold, signed_interval.sigs)): + return None + else: + return signed_interval + else: + if messages[0].seq_num != max_seq_num_on_dest_chain + 1: + return None + interval = Interval(messages[0].seq_num, messages[-1].seq_num + 1) + return SignedInterval(interval, merkle_root, []) + + # Given a list of observations, return a mapping from chains to the message sequence that a sufficient number of + # observations have reached consensus on + def get_messages_consensus( + self, + previous_outcome: CommitOutcome, + observations: List[CommitObservation] + ) -> Dict[ChainSelector, List[CcipMessage]]: + # build Dict[ChainSelector, Dict[int, Dict[CcipMessage, int]]] + # map of chains to maps of seq nums to maps of CcipMessage to occurrence + pass + + # Given a SignedInterval that will be included in this round's report, return the interval to be used in the next + # round + def build_next_interval(self, signed_interval: SignedInterval, rmn_max_seq_num: Optional[int]) -> Interval: + interval_min = signed_interval.interval.max + if rmn_max_seq_num is None: + return Interval(interval_min, interval_min + MAX_INTERVAL_LENGTH) + else: + interval_max = max(signed_interval.interval.max, rmn_max_seq_num) + if interval_max - interval_min > MAX_INTERVAL_LENGTH: + interval_max = interval_min + MAX_INTERVAL_LENGTH + return Interval(interval_min, interval_max) + + # TODO: doc, impl + def rebuild_current_interval( + self, + current_interval: Interval, + max_seq_num_on_dest_chain: Optional[int], + rmn_max_seq_num: Optional[int] + ) -> Interval: + interval_min = current_interval.min + if max_seq_num_on_dest_chain is not None: + interval_min = max_seq_num_on_dest_chain + 1 + + pass + + # TODO: doc + def outcome( + self, + previous_outcome: CommitOutcome, + query: CommitQuery, + observations: List[CommitObservation] + ) -> CommitOutcome: + max_seq_nums_on_dest_chain = self.get_consensus_max_seq_nums_on_dest_chain(observations) + if len(max_seq_nums_on_dest_chain) == 0: + # TODO: doc + return CommitOutcome(previous_outcome.next_intervals, {}) + + messages = self.get_messages_consensus(previous_outcome, observations) + + next_intervals = {} + signed_intervals = {} + for chain_selector in self.all_source_chains: + # TODO: handle key-missing errors + current_interval = previous_outcome.next_intervals[chain_selector] + max_seq_num_on_dest_chain = max_seq_nums_on_dest_chain[chain_selector] + + signed_interval = self.get_verified_signed_interval( + chain_selector, + current_interval, + max_seq_num_on_dest_chain, + query.signed_intervals, + messages[chain_selector] + ) + + rmn_max_seq_num = query.rmn_max_seq_nums.get(chain_selector) + + # TODO: doc + if signed_interval is None: + next_interval = self.rebuild_current_interval(current_interval, max_seq_num_on_dest_chain, + rmn_max_seq_num) + next_intervals[chain_selector] = next_interval + else: + next_interval = self.build_next_interval(signed_interval, rmn_max_seq_num) + next_intervals[chain_selector] = next_interval + signed_intervals[chain_selector] = signed_interval + + return CommitOutcome(next_intervals, signed_intervals)