Skip to content

Commit

Permalink
(WIP) Add a draft spec for RMN OffChain Blessing
Browse files Browse the repository at this point in the history
  • Loading branch information
rstout committed Jun 17, 2024
1 parent c651218 commit 2223947
Showing 1 changed file with 353 additions and 0 deletions.
353 changes: 353 additions & 0 deletions core/services/ocr3/plugins/ccip/spec/commit_plugin_rmn_ocb_draft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,353 @@
# TODO: doc

from typing import List, Dict
from collections import defaultdict

RmnNodeId = str
ChainSelector = int

MAX_INTERVAL_LENGTH = 256


class Interval:
# TODO: doc, inclusive, exclusive ranges
def __init__(self, min: int, max: int):
# TODO: invariant check min <= max
self.min = min
self.max = max

def is_empty(self) -> bool:
return self.min == self.max


class RmnSig:
def __init__(self, rmn_node_id: RmnNodeId, sig: bytes):
self.rmn_node_id = rmn_node_id
self.sig = sig


class SignedInterval:
def __init__(self, interval: Interval, root: bytes, sigs: List[RmnSig]):
self.interval = interval
self.root = root
self.sigs = sigs


class CcipMessage:
def __init__(self, seq_num: int):
self.seq_num = seq_num
pass


class CommitQuery:
def __init__(
self,
rmn_max_seq_nums: Dict[ChainSelector, int],
signed_intervals: Dict[ChainSelector, SignedInterval]
):
self.rmn_max_seq_nums = rmn_max_seq_nums
self.signed_intervals = signed_intervals


class CommitObservation:
def __init__(
self,
max_seq_nums_on_dest_chain: Dict[ChainSelector, int],
max_seq_nums_from_source_chains: Dict[ChainSelector, int], # TODO: try without
messages: Dict[ChainSelector, List[CcipMessage]]
):
self.max_seq_nums_on_dest_chain = max_seq_nums_on_dest_chain
self.max_seq_nums_from_source_chains = max_seq_nums_from_source_chains
self.messages = messages


class CommitOutcome:
def __init__(
self,
next_intervals: Dict[ChainSelector, Interval],
signed_intervals: Dict[ChainSelector, SignedInterval]
):
self.next_intervals = next_intervals
self.signed_intervals = signed_intervals
return


class RmnNode:
def __init__(self, node_id: RmnNodeId, ip_address: bytes, pub_key: bytes, supported_chains: List[ChainSelector]):
self.node_id = node_id
self.ip_address = ip_address
self.pub_key = pub_key
self.supported_chains = supported_chains


class RmnClientConfig:
def __init__(self, rmn_nodes: List[RmnNode]):
self.rmn_nodes = rmn_nodes


class RmnClient:
def __init__(self, rmn_client_config: RmnClientConfig):
self.rmn_client_config = rmn_client_config

# TODO: doc
def request_max_seq_nums_from_single_node(
self,
rmn_node_id: RmnNodeId,
chains: List[ChainSelector]
) -> Dict[ChainSelector, int]:
pass

# TODO: doc
def request_signed_intervals_from_single_node(
self,
rmn_node_id: RmnNodeId,
intervals: Dict[ChainSelector, Interval]
) -> Dict[ChainSelector, SignedInterval]:
pass

# TODO: doc
def request_max_seq_nums(
self,
chains: List[ChainSelector]
) -> Dict[ChainSelector, int]:
pass

# TODO: doc
def request_signed_intervals(
self,
intervals: Dict[ChainSelector, Interval]
) -> Dict[ChainSelector, SignedInterval]:
pass


class ChainReader:
def __init__(self):
pass


class OffRamp:
def __init__(self):
pass

# TODO: doc
def get_max_seq_nums_on_dest_chain(self) -> Dict[ChainSelector, int]:
pass


class CommitPlugin:
def __init__(
self,
rmn_client: RmnClient,
all_source_chains: List[ChainSelector],
dest_chain: ChainSelector,
chain_readers: Dict[ChainSelector, ChainReader],
off_ramp: OffRamp,
f: int
):
self.rmn_client = rmn_client
self.all_source_chains = all_source_chains
self.dest_chain = dest_chain
self.chain_readers = chain_readers
self.off_ramp = off_ramp
self.f = f
return

# TODO: doc
def can_read_from_dest_chain(self) -> bool:
return self.dest_chain in self.chain_readers

# TODO: doc
def get_ccip_messages_from_source_chains(
self,
intervals: Dict[ChainSelector, Interval]
) -> Dict[ChainSelector, List[CcipMessage]]:
pass

# TODO: doc
def query(self, previous_outcome: CommitOutcome) -> CommitQuery:
max_seq_nums = self.rmn_client.request_max_seq_nums(self.all_source_chains)
signed_intervals = self.rmn_client.request_signed_intervals(previous_outcome.next_intervals)
return CommitQuery(max_seq_nums, signed_intervals)

# TODO: doc
def observation(self, previous_outcome: CommitOutcome) -> CommitObservation:
# Get persisted min seq nums from dest chain
max_seq_nums_on_dest_chain = {}
if self.can_read_from_dest_chain():
max_seq_nums_on_dest_chain = self.off_ramp.get_max_seq_nums_on_dest_chain()

# Get max seq nums from all chains that can be read from
max_seq_nums_from_source_chains = {}
for (chain_selector, chain_reader) in self.chain_readers.items():
max_seq_nums_from_source_chains[chain_selector] = chain_reader.get_max_seq_num(self.dest_chain)

# Get messages in previous_outcome.next_interval
messages = self.get_ccip_messages_from_source_chains(previous_outcome.next_intervals)

return CommitObservation(max_seq_nums_on_dest_chain, max_seq_nums_from_source_chains, messages)

# TODO: doc
def aggregate_observations(
self,
observations: List[CommitObservation]
) -> CommitObservation:
pass

# TODO: doc
def get_consensus_max_seq_nums_on_dest_chain(
self,
observations: List[CommitObservation]
) -> Dict[ChainSelector, int]:
counts = defaultdict(int)
for observation in observations:
# Convert the dictionary to a frozenset of items so it can be used as a key
if len(observation.max_seq_nums_on_dest_chain) > 0:
frozen_dict = frozenset(observation.max_seq_nums_on_dest_chain.items())
counts[frozen_dict] += 1

# Consensus on the onchain state is reached if there is only one max_seq_nums_on_dest_chain dict that is
# observed by more than f nodes
# TODO: more doc
candidates = []
for (candidate, count) in counts.items():
if count > self.f:
candidates.append(candidate)

if len(candidates) == 1:
return dict(candidates[0])
else:
return {}

# TODO: doc
# the interval mins in previous outcome should be one more than what's onchain
def chains_with_unexpected_onchain_state(
self,
max_seq_nums_on_dest_chain: Dict[ChainSelector, int],
previous_outcome: CommitOutcome,
) -> List[ChainSelector]:
pass

# TODO: doc
def compute_merkle_root(
self,
chain_selector: ChainSelector,
current_interval: Interval,
observations: List[CommitObservation]
) -> bytes:
pass

# TODO: doc
def verify_rmn_sigs(self, sigs: List[RmnSig]) -> bool:
pass

# TODO: doc
def get_rmn_threshold_for_chain(self, chain_selector: ChainSelector) -> int:
pass

# TODO: doc
def outcome(
self,
previous_outcome: CommitOutcome,
query: CommitQuery,
observations: List[CommitObservation]
) -> CommitOutcome:
# filter out entries from signed_intervals
# add entries that don't require RMN blessing
# Determine next intervals

# reconcile_observations()?
# what do you do with observations first?
# need to ensure you get enough observations that have the same max_seq_nums_on_dest_chain
# - from this you have a single max_seq_nums_on_dest_chain
# then you can see if there are any that don't match previous_outcome.next_intervals
# what do you do if there are conflicts?
# - don't compute merkle roots for these entries, remove from signed intervals?
# - maybe just add to a dict/list (of chains to exclude from report)

# check against previous_outcome
# if any conflicts

max_seq_nums_on_dest_chain = self.get_consensus_max_seq_nums_on_dest_chain(observations)
if len(max_seq_nums_on_dest_chain) == 0:
# TODO: return something
pass

next_intervals = {}
signed_intervals = {}
for chain_selector in self.all_source_chains:
# handle key-missing errors
current_interval = previous_outcome.next_intervals[chain_selector]
max_seq_num_on_dest_chain = max_seq_nums_on_dest_chain[chain_selector]

# what are the conditions where all you have to do is update the next interval?
# - when current_interval is empty
# - when current_interval is already persisted
# - i.e. max_seq_num_on_dest_chain >= current_interval.max
# - when the previous interval hasn't been persisted
# - when max_seq_num_on_dest_chain < current_interval.min - 1
# Happy path is when max_seq_num_on_dest_chain == current_interval.min - 1

# TODO: doc
if current_interval.min != max_seq_num_on_dest_chain + 1:
# Something unexpected happened, don't include this chain in the outcome, update next interval
pass

# TODO: doc
elif current_interval.is_empty():
next_max = current_interval.max
if chain_selector in query.rmn_max_seq_nums:
rmn_max_seq_num = query.rmn_max_seq_nums[chain_selector]
next_max = max(next_max, rmn_max_seq_num)
if next_max - current_interval.min > MAX_INTERVAL_LENGTH:
next_max = MAX_INTERVAL_LENGTH

next_intervals[chain_selector] = Interval(current_interval.min, next_max)

# TODO: doc
else:
# compute merkle root
merkle_root = self.compute_merkle_root(chain_selector, current_interval, observations)

if chain_selector in query.signed_intervals:
signed_interval = query.signed_intervals[chain_selector]
# TODO: explain this case
if (current_interval != signed_interval.interval or
merkle_root != signed_interval.root or
not self.verify_rmn_sigs(signed_interval.sigs)):
# TODO: update next_intervals, continue
pass
else:
signed_intervals[chain_selector] = signed_interval
# TODO: update next_intervals
else:
# TODO: doc
if self.get_rmn_threshold_for_chain(chain_selector) == 0:
signed_intervals[chain_selector] = SignedInterval(current_interval, merkle_root, [])
# TODO: update next_intervals
else:
# TODO: update next_intervals
pass

# Check RMN signed interval
# - if exists, construct message sequence for that interval
# - if that fails, update next_intervals, ???
# - compute merkle root, compare to RMN merkle root
# - if different:
# - update next_intervals to the same interval
# - don't add to signed_intervals
# - else:
# - check RMN sigs? if valid, add to signed_intervals, else don't

# maybe compute merkle root first, or check RMN range first
# we need to compute the merkle root in the case where RMN sigs not required
# if RMN interval doesn't match previous_outcome.next_intervals, something went wrong
# - update next_intervals, continue (the for loop, i.e. don't add entry to signed_intervals)
pass

# Handle case if interval is already persisted

# Handle case when the previous interval hasn't been persisted

# Handle happy path when max_seq_num_on_dest_chain == current_interval.min - 1

return CommitOutcome(next_intervals, signed_intervals)

0 comments on commit 2223947

Please sign in to comment.