-
Notifications
You must be signed in to change notification settings - Fork 57
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
(WIP) Add a draft spec for RMN OffChain Blessing
- Loading branch information
Showing
1 changed file
with
353 additions
and
0 deletions.
There are no files selected for viewing
353 changes: 353 additions & 0 deletions
353
core/services/ocr3/plugins/ccip/spec/commit_plugin_rmn_ocb_draft.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,353 @@ | ||
# TODO: doc | ||
|
||
from typing import List, Dict | ||
from collections import defaultdict | ||
|
||
RmnNodeId = str | ||
ChainSelector = int | ||
|
||
MAX_INTERVAL_LENGTH = 256 | ||
|
||
|
||
class Interval: | ||
# TODO: doc, inclusive, exclusive ranges | ||
def __init__(self, min: int, max: int): | ||
# TODO: invariant check min <= max | ||
self.min = min | ||
self.max = max | ||
|
||
def is_empty(self) -> bool: | ||
return self.min == self.max | ||
|
||
|
||
class RmnSig: | ||
def __init__(self, rmn_node_id: RmnNodeId, sig: bytes): | ||
self.rmn_node_id = rmn_node_id | ||
self.sig = sig | ||
|
||
|
||
class SignedInterval: | ||
def __init__(self, interval: Interval, root: bytes, sigs: List[RmnSig]): | ||
self.interval = interval | ||
self.root = root | ||
self.sigs = sigs | ||
|
||
|
||
class CcipMessage: | ||
def __init__(self, seq_num: int): | ||
self.seq_num = seq_num | ||
pass | ||
|
||
|
||
class CommitQuery: | ||
def __init__( | ||
self, | ||
rmn_max_seq_nums: Dict[ChainSelector, int], | ||
signed_intervals: Dict[ChainSelector, SignedInterval] | ||
): | ||
self.rmn_max_seq_nums = rmn_max_seq_nums | ||
self.signed_intervals = signed_intervals | ||
|
||
|
||
class CommitObservation: | ||
def __init__( | ||
self, | ||
max_seq_nums_on_dest_chain: Dict[ChainSelector, int], | ||
max_seq_nums_from_source_chains: Dict[ChainSelector, int], # TODO: try without | ||
messages: Dict[ChainSelector, List[CcipMessage]] | ||
): | ||
self.max_seq_nums_on_dest_chain = max_seq_nums_on_dest_chain | ||
self.max_seq_nums_from_source_chains = max_seq_nums_from_source_chains | ||
self.messages = messages | ||
|
||
|
||
class CommitOutcome: | ||
def __init__( | ||
self, | ||
next_intervals: Dict[ChainSelector, Interval], | ||
signed_intervals: Dict[ChainSelector, SignedInterval] | ||
): | ||
self.next_intervals = next_intervals | ||
self.signed_intervals = signed_intervals | ||
return | ||
|
||
|
||
class RmnNode: | ||
def __init__(self, node_id: RmnNodeId, ip_address: bytes, pub_key: bytes, supported_chains: List[ChainSelector]): | ||
self.node_id = node_id | ||
self.ip_address = ip_address | ||
self.pub_key = pub_key | ||
self.supported_chains = supported_chains | ||
|
||
|
||
class RmnClientConfig: | ||
def __init__(self, rmn_nodes: List[RmnNode]): | ||
self.rmn_nodes = rmn_nodes | ||
|
||
|
||
class RmnClient: | ||
def __init__(self, rmn_client_config: RmnClientConfig): | ||
self.rmn_client_config = rmn_client_config | ||
|
||
# TODO: doc | ||
def request_max_seq_nums_from_single_node( | ||
self, | ||
rmn_node_id: RmnNodeId, | ||
chains: List[ChainSelector] | ||
) -> Dict[ChainSelector, int]: | ||
pass | ||
|
||
# TODO: doc | ||
def request_signed_intervals_from_single_node( | ||
self, | ||
rmn_node_id: RmnNodeId, | ||
intervals: Dict[ChainSelector, Interval] | ||
) -> Dict[ChainSelector, SignedInterval]: | ||
pass | ||
|
||
# TODO: doc | ||
def request_max_seq_nums( | ||
self, | ||
chains: List[ChainSelector] | ||
) -> Dict[ChainSelector, int]: | ||
pass | ||
|
||
# TODO: doc | ||
def request_signed_intervals( | ||
self, | ||
intervals: Dict[ChainSelector, Interval] | ||
) -> Dict[ChainSelector, SignedInterval]: | ||
pass | ||
|
||
|
||
class ChainReader: | ||
def __init__(self): | ||
pass | ||
|
||
|
||
class OffRamp: | ||
def __init__(self): | ||
pass | ||
|
||
# TODO: doc | ||
def get_max_seq_nums_on_dest_chain(self) -> Dict[ChainSelector, int]: | ||
pass | ||
|
||
|
||
class CommitPlugin: | ||
def __init__( | ||
self, | ||
rmn_client: RmnClient, | ||
all_source_chains: List[ChainSelector], | ||
dest_chain: ChainSelector, | ||
chain_readers: Dict[ChainSelector, ChainReader], | ||
off_ramp: OffRamp, | ||
f: int | ||
): | ||
self.rmn_client = rmn_client | ||
self.all_source_chains = all_source_chains | ||
self.dest_chain = dest_chain | ||
self.chain_readers = chain_readers | ||
self.off_ramp = off_ramp | ||
self.f = f | ||
return | ||
|
||
# TODO: doc | ||
def can_read_from_dest_chain(self) -> bool: | ||
return self.dest_chain in self.chain_readers | ||
|
||
# TODO: doc | ||
def get_ccip_messages_from_source_chains( | ||
self, | ||
intervals: Dict[ChainSelector, Interval] | ||
) -> Dict[ChainSelector, List[CcipMessage]]: | ||
pass | ||
|
||
# TODO: doc | ||
def query(self, previous_outcome: CommitOutcome) -> CommitQuery: | ||
max_seq_nums = self.rmn_client.request_max_seq_nums(self.all_source_chains) | ||
signed_intervals = self.rmn_client.request_signed_intervals(previous_outcome.next_intervals) | ||
return CommitQuery(max_seq_nums, signed_intervals) | ||
|
||
# TODO: doc | ||
def observation(self, previous_outcome: CommitOutcome) -> CommitObservation: | ||
# Get persisted min seq nums from dest chain | ||
max_seq_nums_on_dest_chain = {} | ||
if self.can_read_from_dest_chain(): | ||
max_seq_nums_on_dest_chain = self.off_ramp.get_max_seq_nums_on_dest_chain() | ||
|
||
# Get max seq nums from all chains that can be read from | ||
max_seq_nums_from_source_chains = {} | ||
for (chain_selector, chain_reader) in self.chain_readers.items(): | ||
max_seq_nums_from_source_chains[chain_selector] = chain_reader.get_max_seq_num(self.dest_chain) | ||
|
||
# Get messages in previous_outcome.next_interval | ||
messages = self.get_ccip_messages_from_source_chains(previous_outcome.next_intervals) | ||
|
||
return CommitObservation(max_seq_nums_on_dest_chain, max_seq_nums_from_source_chains, messages) | ||
|
||
# TODO: doc | ||
def aggregate_observations( | ||
self, | ||
observations: List[CommitObservation] | ||
) -> CommitObservation: | ||
pass | ||
|
||
# TODO: doc | ||
def get_consensus_max_seq_nums_on_dest_chain( | ||
self, | ||
observations: List[CommitObservation] | ||
) -> Dict[ChainSelector, int]: | ||
counts = defaultdict(int) | ||
for observation in observations: | ||
# Convert the dictionary to a frozenset of items so it can be used as a key | ||
if len(observation.max_seq_nums_on_dest_chain) > 0: | ||
frozen_dict = frozenset(observation.max_seq_nums_on_dest_chain.items()) | ||
counts[frozen_dict] += 1 | ||
|
||
# Consensus on the onchain state is reached if there is only one max_seq_nums_on_dest_chain dict that is | ||
# observed by more than f nodes | ||
# TODO: more doc | ||
candidates = [] | ||
for (candidate, count) in counts.items(): | ||
if count > self.f: | ||
candidates.append(candidate) | ||
|
||
if len(candidates) == 1: | ||
return dict(candidates[0]) | ||
else: | ||
return {} | ||
|
||
# TODO: doc | ||
# the interval mins in previous outcome should be one more than what's onchain | ||
def chains_with_unexpected_onchain_state( | ||
self, | ||
max_seq_nums_on_dest_chain: Dict[ChainSelector, int], | ||
previous_outcome: CommitOutcome, | ||
) -> List[ChainSelector]: | ||
pass | ||
|
||
# TODO: doc | ||
def compute_merkle_root( | ||
self, | ||
chain_selector: ChainSelector, | ||
current_interval: Interval, | ||
observations: List[CommitObservation] | ||
) -> bytes: | ||
pass | ||
|
||
# TODO: doc | ||
def verify_rmn_sigs(self, sigs: List[RmnSig]) -> bool: | ||
pass | ||
|
||
# TODO: doc | ||
def get_rmn_threshold_for_chain(self, chain_selector: ChainSelector) -> int: | ||
pass | ||
|
||
# TODO: doc | ||
def outcome( | ||
self, | ||
previous_outcome: CommitOutcome, | ||
query: CommitQuery, | ||
observations: List[CommitObservation] | ||
) -> CommitOutcome: | ||
# filter out entries from signed_intervals | ||
# add entries that don't require RMN blessing | ||
# Determine next intervals | ||
|
||
# reconcile_observations()? | ||
# what do you do with observations first? | ||
# need to ensure you get enough observations that have the same max_seq_nums_on_dest_chain | ||
# - from this you have a single max_seq_nums_on_dest_chain | ||
# then you can see if there are any that don't match previous_outcome.next_intervals | ||
# what do you do if there are conflicts? | ||
# - don't compute merkle roots for these entries, remove from signed intervals? | ||
# - maybe just add to a dict/list (of chains to exclude from report) | ||
|
||
# check against previous_outcome | ||
# if any conflicts | ||
|
||
max_seq_nums_on_dest_chain = self.get_consensus_max_seq_nums_on_dest_chain(observations) | ||
if len(max_seq_nums_on_dest_chain) == 0: | ||
# TODO: return something | ||
pass | ||
|
||
next_intervals = {} | ||
signed_intervals = {} | ||
for chain_selector in self.all_source_chains: | ||
# handle key-missing errors | ||
current_interval = previous_outcome.next_intervals[chain_selector] | ||
max_seq_num_on_dest_chain = max_seq_nums_on_dest_chain[chain_selector] | ||
|
||
# what are the conditions where all you have to do is update the next interval? | ||
# - when current_interval is empty | ||
# - when current_interval is already persisted | ||
# - i.e. max_seq_num_on_dest_chain >= current_interval.max | ||
# - when the previous interval hasn't been persisted | ||
# - when max_seq_num_on_dest_chain < current_interval.min - 1 | ||
# Happy path is when max_seq_num_on_dest_chain == current_interval.min - 1 | ||
|
||
# TODO: doc | ||
if current_interval.min != max_seq_num_on_dest_chain + 1: | ||
# Something unexpected happened, don't include this chain in the outcome, update next interval | ||
pass | ||
|
||
# TODO: doc | ||
elif current_interval.is_empty(): | ||
next_max = current_interval.max | ||
if chain_selector in query.rmn_max_seq_nums: | ||
rmn_max_seq_num = query.rmn_max_seq_nums[chain_selector] | ||
next_max = max(next_max, rmn_max_seq_num) | ||
if next_max - current_interval.min > MAX_INTERVAL_LENGTH: | ||
next_max = MAX_INTERVAL_LENGTH | ||
|
||
next_intervals[chain_selector] = Interval(current_interval.min, next_max) | ||
|
||
# TODO: doc | ||
else: | ||
# compute merkle root | ||
merkle_root = self.compute_merkle_root(chain_selector, current_interval, observations) | ||
|
||
if chain_selector in query.signed_intervals: | ||
signed_interval = query.signed_intervals[chain_selector] | ||
# TODO: explain this case | ||
if (current_interval != signed_interval.interval or | ||
merkle_root != signed_interval.root or | ||
not self.verify_rmn_sigs(signed_interval.sigs)): | ||
# TODO: update next_intervals, continue | ||
pass | ||
else: | ||
signed_intervals[chain_selector] = signed_interval | ||
# TODO: update next_intervals | ||
else: | ||
# TODO: doc | ||
if self.get_rmn_threshold_for_chain(chain_selector) == 0: | ||
signed_intervals[chain_selector] = SignedInterval(current_interval, merkle_root, []) | ||
# TODO: update next_intervals | ||
else: | ||
# TODO: update next_intervals | ||
pass | ||
|
||
# Check RMN signed interval | ||
# - if exists, construct message sequence for that interval | ||
# - if that fails, update next_intervals, ??? | ||
# - compute merkle root, compare to RMN merkle root | ||
# - if different: | ||
# - update next_intervals to the same interval | ||
# - don't add to signed_intervals | ||
# - else: | ||
# - check RMN sigs? if valid, add to signed_intervals, else don't | ||
|
||
# maybe compute merkle root first, or check RMN range first | ||
# we need to compute the merkle root in the case where RMN sigs not required | ||
# if RMN interval doesn't match previous_outcome.next_intervals, something went wrong | ||
# - update next_intervals, continue (the for loop, i.e. don't add entry to signed_intervals) | ||
pass | ||
|
||
# Handle case if interval is already persisted | ||
|
||
# Handle case when the previous interval hasn't been persisted | ||
|
||
# Handle happy path when max_seq_num_on_dest_chain == current_interval.min - 1 | ||
|
||
return CommitOutcome(next_intervals, signed_intervals) |