-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
(WIP) Add a draft spec for RMN OffChain Blessing
- Loading branch information
Showing
1 changed file
with
311 additions
and
0 deletions.
There are no files selected for viewing
311 changes: 311 additions & 0 deletions
311
core/services/ocr3/plugins/ccip/spec/commit_plugin_rmn_ocb_draft.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,311 @@ | ||
# TODO: doc | ||
|
||
from typing import List, Dict, Optional | ||
from dataclasses import dataclass | ||
from collections import defaultdict | ||
|
||
RmnNodeId = str | ||
ChainSelector = int | ||
|
||
MAX_INTERVAL_LENGTH = 256 | ||
|
||
|
||
@dataclass | ||
class Interval: | ||
min: int | ||
max: int | ||
|
||
def is_empty(self) -> bool: | ||
return self.min == self.max | ||
|
||
|
||
@dataclass | ||
class RmnSig: | ||
rmn_node_id: RmnNodeId | ||
sig: bytes | ||
|
||
|
||
@dataclass | ||
class SignedInterval: | ||
interval: Interval | ||
root: bytes | ||
sigs: List[RmnSig] | ||
|
||
|
||
@dataclass | ||
class CcipMessage: | ||
seq_num: int | ||
|
||
|
||
@dataclass | ||
class CommitQuery: | ||
rmn_max_seq_nums: Dict[ChainSelector, int] | ||
signed_intervals: Dict[ChainSelector, SignedInterval] | ||
|
||
|
||
@dataclass | ||
class CommitObservation: | ||
max_seq_nums_on_dest_chain: Dict[ChainSelector, int] | ||
messages: Dict[ChainSelector, List[CcipMessage]] | ||
|
||
|
||
@dataclass | ||
class CommitOutcome: | ||
next_intervals: Dict[ChainSelector, Interval] | ||
signed_intervals: Dict[ChainSelector, SignedInterval] | ||
|
||
|
||
@dataclass | ||
class RmnNode: | ||
node_id: RmnNodeId | ||
ip_address: bytes | ||
pub_key: bytes | ||
supported_chains: List[ChainSelector] | ||
|
||
|
||
@dataclass | ||
class RmnClientConfig: | ||
rmn_nodes: List[RmnNode] | ||
|
||
|
||
@dataclass | ||
class RmnClient: | ||
rmn_client_config: RmnClientConfig | ||
|
||
# TODO: doc | ||
def request_max_seq_nums_from_single_node( | ||
self, | ||
rmn_node_id: RmnNodeId, | ||
chains: List[ChainSelector] | ||
) -> Dict[ChainSelector, int]: | ||
pass | ||
|
||
# TODO: doc | ||
def request_signed_intervals_from_single_node( | ||
self, | ||
rmn_node_id: RmnNodeId, | ||
intervals: Dict[ChainSelector, Interval] | ||
) -> Dict[ChainSelector, SignedInterval]: | ||
pass | ||
|
||
# TODO: doc | ||
def request_max_seq_nums( | ||
self, | ||
chains: List[ChainSelector] | ||
) -> Dict[ChainSelector, int]: | ||
pass | ||
|
||
# TODO: doc | ||
def request_signed_intervals( | ||
self, | ||
intervals: Dict[ChainSelector, Interval] | ||
) -> Dict[ChainSelector, SignedInterval]: | ||
pass | ||
|
||
|
||
class ChainReader: | ||
def __init__(self): | ||
pass | ||
|
||
|
||
class OffRamp: | ||
def __init__(self): | ||
pass | ||
|
||
# TODO: doc | ||
def get_max_seq_nums_on_dest_chain(self) -> Dict[ChainSelector, int]: | ||
pass | ||
|
||
|
||
@dataclass | ||
class CommitPlugin: | ||
rmn_client: RmnClient | ||
all_source_chains: List[ChainSelector] | ||
dest_chain: ChainSelector | ||
chain_readers: Dict[ChainSelector, ChainReader] | ||
off_ramp: OffRamp | ||
f: int | ||
|
||
# TODO: doc | ||
def can_read_from_dest_chain(self) -> bool: | ||
return self.dest_chain in self.chain_readers | ||
|
||
# TODO: doc | ||
def get_ccip_messages_from_source_chains( | ||
self, | ||
intervals: Dict[ChainSelector, Interval] | ||
) -> Dict[ChainSelector, List[CcipMessage]]: | ||
pass | ||
|
||
# TODO: doc | ||
def query(self, previous_outcome: CommitOutcome) -> CommitQuery: | ||
max_seq_nums = self.rmn_client.request_max_seq_nums(self.all_source_chains) | ||
signed_intervals = self.rmn_client.request_signed_intervals(previous_outcome.next_intervals) | ||
return CommitQuery(max_seq_nums, signed_intervals) | ||
|
||
# TODO: doc | ||
def observation(self, previous_outcome: CommitOutcome) -> CommitObservation: | ||
# Get persisted min seq nums from dest chain | ||
max_seq_nums_on_dest_chain = {} | ||
if self.can_read_from_dest_chain(): | ||
max_seq_nums_on_dest_chain = self.off_ramp.get_max_seq_nums_on_dest_chain() | ||
|
||
# Get messages in previous_outcome.next_interval | ||
messages = self.get_ccip_messages_from_source_chains(previous_outcome.next_intervals) | ||
|
||
return CommitObservation(max_seq_nums_on_dest_chain, messages) | ||
|
||
# TODO: doc | ||
def get_consensus_max_seq_nums_on_dest_chain( | ||
self, | ||
observations: List[CommitObservation] | ||
) -> Dict[ChainSelector, int]: | ||
counts = defaultdict(int) | ||
for observation in observations: | ||
# Convert the dictionary to a frozenset of items so it can be used as a key | ||
if len(observation.max_seq_nums_on_dest_chain) > 0: | ||
frozen_dict = frozenset(observation.max_seq_nums_on_dest_chain.items()) | ||
counts[frozen_dict] += 1 | ||
|
||
# Consensus on the onchain state is reached if there is only one max_seq_nums_on_dest_chain dict that is | ||
# observed by more than f nodes | ||
# TODO: more doc | ||
candidates = [] | ||
for (candidate, count) in counts.items(): | ||
if count > self.f: | ||
candidates.append(candidate) | ||
|
||
if len(candidates) == 1: | ||
return dict(candidates[0]) | ||
else: | ||
return {} | ||
|
||
# Compute the merkle root for a sequence of messages | ||
def compute_merkle_root( | ||
self, | ||
messages: List[CcipMessage] | ||
) -> bytes: | ||
pass | ||
|
||
# Return true if the given sigs are valid and sufficient signatures of merkle_root | ||
def verify_rmn_sigs(self, merkle_root: bytes, rmn_threshold: int, sigs: List[RmnSig]) -> bool: | ||
pass | ||
|
||
# Return the number of RMN signatures required for the given chain | ||
def get_rmn_threshold_for_chain(self, chain_selector: ChainSelector) -> int: | ||
pass | ||
|
||
# TODO: doc | ||
# TODO: maybe rename to get_signed_interval_for_report | ||
def get_verified_signed_interval( | ||
self, | ||
chain_selector: ChainSelector, | ||
current_interval: Interval, | ||
max_seq_num_on_dest_chain: int, | ||
signed_intervals: Dict[ChainSelector, SignedInterval], | ||
messages: List[CcipMessage] | ||
) -> Optional[SignedInterval]: | ||
if current_interval.min != max_seq_num_on_dest_chain + 1: | ||
return None | ||
|
||
merkle_root = self.compute_merkle_root(messages) | ||
|
||
if len(merkle_root) != 32: | ||
return None | ||
|
||
rmn_threshold = self.get_rmn_threshold_for_chain(chain_selector) | ||
if rmn_threshold > 0: | ||
if chain_selector not in signed_intervals: | ||
return None | ||
else: | ||
signed_interval = signed_intervals[chain_selector] | ||
if (current_interval != signed_interval.interval or | ||
merkle_root != signed_interval.root or | ||
not self.verify_rmn_sigs(merkle_root, rmn_threshold, signed_interval.sigs)): | ||
return None | ||
else: | ||
return signed_interval | ||
else: | ||
if messages[0].seq_num != max_seq_num_on_dest_chain + 1: | ||
return None | ||
interval = Interval(messages[0].seq_num, messages[-1].seq_num + 1) | ||
return SignedInterval(interval, merkle_root, []) | ||
|
||
# Given a list of observations, return a mapping from chains to the message sequence that a sufficient number of | ||
# observations have reached consensus on | ||
def get_messages_consensus( | ||
self, | ||
previous_outcome: CommitOutcome, | ||
observations: List[CommitObservation] | ||
) -> Dict[ChainSelector, List[CcipMessage]]: | ||
# build Dict[ChainSelector, Dict[int, Dict[CcipMessage, int]]] | ||
# map of chains to maps of seq nums to maps of CcipMessage to occurrence | ||
pass | ||
|
||
# Given a SignedInterval that will be included in this round's report, return the interval to be used in the next | ||
# round | ||
def build_next_interval(self, signed_interval: SignedInterval, rmn_max_seq_num: Optional[int]) -> Interval: | ||
interval_min = signed_interval.interval.max | ||
if rmn_max_seq_num is None: | ||
return Interval(interval_min, interval_min + MAX_INTERVAL_LENGTH) | ||
else: | ||
interval_max = max(signed_interval.interval.max, rmn_max_seq_num) | ||
if interval_max - interval_min > MAX_INTERVAL_LENGTH: | ||
interval_max = interval_min + MAX_INTERVAL_LENGTH | ||
return Interval(interval_min, interval_max) | ||
|
||
# TODO: doc, impl | ||
def rebuild_current_interval( | ||
self, | ||
current_interval: Interval, | ||
max_seq_num_on_dest_chain: Optional[int], | ||
rmn_max_seq_num: Optional[int] | ||
) -> Interval: | ||
interval_min = current_interval.min | ||
if max_seq_num_on_dest_chain is not None: | ||
interval_min = max_seq_num_on_dest_chain + 1 | ||
|
||
pass | ||
|
||
# TODO: doc | ||
def outcome( | ||
self, | ||
previous_outcome: CommitOutcome, | ||
query: CommitQuery, | ||
observations: List[CommitObservation] | ||
) -> CommitOutcome: | ||
max_seq_nums_on_dest_chain = self.get_consensus_max_seq_nums_on_dest_chain(observations) | ||
if len(max_seq_nums_on_dest_chain) == 0: | ||
# TODO: doc | ||
return CommitOutcome(previous_outcome.next_intervals, {}) | ||
|
||
messages = self.get_messages_consensus(previous_outcome, observations) | ||
|
||
next_intervals = {} | ||
signed_intervals = {} | ||
for chain_selector in self.all_source_chains: | ||
# TODO: handle key-missing errors | ||
current_interval = previous_outcome.next_intervals[chain_selector] | ||
max_seq_num_on_dest_chain = max_seq_nums_on_dest_chain[chain_selector] | ||
|
||
signed_interval = self.get_verified_signed_interval( | ||
chain_selector, | ||
current_interval, | ||
max_seq_num_on_dest_chain, | ||
query.signed_intervals, | ||
messages[chain_selector] | ||
) | ||
|
||
rmn_max_seq_num = query.rmn_max_seq_nums.get(chain_selector) | ||
|
||
# TODO: doc | ||
if signed_interval is None: | ||
next_interval = self.rebuild_current_interval(current_interval, max_seq_num_on_dest_chain, | ||
rmn_max_seq_num) | ||
next_intervals[chain_selector] = next_interval | ||
else: | ||
next_interval = self.build_next_interval(signed_interval, rmn_max_seq_num) | ||
next_intervals[chain_selector] = next_interval | ||
signed_intervals[chain_selector] = signed_interval | ||
|
||
return CommitOutcome(next_intervals, signed_intervals) |