Skip to content

Commit

Permalink
(WIP) Add a draft spec for RMN OffChain Blessing
Browse files Browse the repository at this point in the history
  • Loading branch information
rstout committed Jun 19, 2024
1 parent c651218 commit a1d27cf
Showing 1 changed file with 311 additions and 0 deletions.
311 changes: 311 additions & 0 deletions core/services/ocr3/plugins/ccip/spec/commit_plugin_rmn_ocb_draft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
# TODO: doc

from typing import List, Dict, Optional
from dataclasses import dataclass
from collections import defaultdict

RmnNodeId = str
ChainSelector = int

MAX_INTERVAL_LENGTH = 256


@dataclass
class Interval:
min: int
max: int

def is_empty(self) -> bool:
return self.min == self.max


@dataclass
class RmnSig:
rmn_node_id: RmnNodeId
sig: bytes


@dataclass
class SignedInterval:
interval: Interval
root: bytes
sigs: List[RmnSig]


@dataclass
class CcipMessage:
seq_num: int


@dataclass
class CommitQuery:
rmn_max_seq_nums: Dict[ChainSelector, int]
signed_intervals: Dict[ChainSelector, SignedInterval]


@dataclass
class CommitObservation:
max_seq_nums_on_dest_chain: Dict[ChainSelector, int]
messages: Dict[ChainSelector, List[CcipMessage]]


@dataclass
class CommitOutcome:
next_intervals: Dict[ChainSelector, Interval]
signed_intervals: Dict[ChainSelector, SignedInterval]


@dataclass
class RmnNode:
node_id: RmnNodeId
ip_address: bytes
pub_key: bytes
supported_chains: List[ChainSelector]


@dataclass
class RmnClientConfig:
rmn_nodes: List[RmnNode]


@dataclass
class RmnClient:
rmn_client_config: RmnClientConfig

# TODO: doc
def request_max_seq_nums_from_single_node(
self,
rmn_node_id: RmnNodeId,
chains: List[ChainSelector]
) -> Dict[ChainSelector, int]:
pass

# TODO: doc
def request_signed_intervals_from_single_node(
self,
rmn_node_id: RmnNodeId,
intervals: Dict[ChainSelector, Interval]
) -> Dict[ChainSelector, SignedInterval]:
pass

# TODO: doc
def request_max_seq_nums(
self,
chains: List[ChainSelector]
) -> Dict[ChainSelector, int]:
pass

# TODO: doc
def request_signed_intervals(
self,
intervals: Dict[ChainSelector, Interval]
) -> Dict[ChainSelector, SignedInterval]:
pass


class ChainReader:
def __init__(self):
pass


class OffRamp:
def __init__(self):
pass

# TODO: doc
def get_max_seq_nums_on_dest_chain(self) -> Dict[ChainSelector, int]:
pass


@dataclass
class CommitPlugin:
rmn_client: RmnClient
all_source_chains: List[ChainSelector]
dest_chain: ChainSelector
chain_readers: Dict[ChainSelector, ChainReader]
off_ramp: OffRamp
f: int

# TODO: doc
def can_read_from_dest_chain(self) -> bool:
return self.dest_chain in self.chain_readers

# TODO: doc
def get_ccip_messages_from_source_chains(
self,
intervals: Dict[ChainSelector, Interval]
) -> Dict[ChainSelector, List[CcipMessage]]:
pass

# TODO: doc
def query(self, previous_outcome: CommitOutcome) -> CommitQuery:
max_seq_nums = self.rmn_client.request_max_seq_nums(self.all_source_chains)
signed_intervals = self.rmn_client.request_signed_intervals(previous_outcome.next_intervals)
return CommitQuery(max_seq_nums, signed_intervals)

# TODO: doc
def observation(self, previous_outcome: CommitOutcome) -> CommitObservation:
# Get persisted min seq nums from dest chain
max_seq_nums_on_dest_chain = {}
if self.can_read_from_dest_chain():
max_seq_nums_on_dest_chain = self.off_ramp.get_max_seq_nums_on_dest_chain()

# Get messages in previous_outcome.next_interval
messages = self.get_ccip_messages_from_source_chains(previous_outcome.next_intervals)

return CommitObservation(max_seq_nums_on_dest_chain, messages)

# TODO: doc
def get_consensus_max_seq_nums_on_dest_chain(
self,
observations: List[CommitObservation]
) -> Dict[ChainSelector, int]:
counts = defaultdict(int)
for observation in observations:
# Convert the dictionary to a frozenset of items so it can be used as a key
if len(observation.max_seq_nums_on_dest_chain) > 0:
frozen_dict = frozenset(observation.max_seq_nums_on_dest_chain.items())
counts[frozen_dict] += 1

# Consensus on the onchain state is reached if there is only one max_seq_nums_on_dest_chain dict that is
# observed by more than f nodes
# TODO: more doc
candidates = []
for (candidate, count) in counts.items():
if count > self.f:
candidates.append(candidate)

if len(candidates) == 1:
return dict(candidates[0])
else:
return {}

# Compute the merkle root for a sequence of messages
def compute_merkle_root(
self,
messages: List[CcipMessage]
) -> bytes:
pass

# Return true if the given sigs are valid and sufficient signatures of merkle_root
def verify_rmn_sigs(self, merkle_root: bytes, rmn_threshold: int, sigs: List[RmnSig]) -> bool:
pass

# Return the number of RMN signatures required for the given chain
def get_rmn_threshold_for_chain(self, chain_selector: ChainSelector) -> int:
pass

# TODO: doc
# TODO: maybe rename to get_signed_interval_for_report
def get_verified_signed_interval(
self,
chain_selector: ChainSelector,
current_interval: Interval,
max_seq_num_on_dest_chain: int,
signed_intervals: Dict[ChainSelector, SignedInterval],
messages: List[CcipMessage]
) -> Optional[SignedInterval]:
if current_interval.min != max_seq_num_on_dest_chain + 1:
return None

merkle_root = self.compute_merkle_root(messages)

if len(merkle_root) != 32:
return None

rmn_threshold = self.get_rmn_threshold_for_chain(chain_selector)
if rmn_threshold > 0:
if chain_selector not in signed_intervals:
return None
else:
signed_interval = signed_intervals[chain_selector]
if (current_interval != signed_interval.interval or
merkle_root != signed_interval.root or
not self.verify_rmn_sigs(merkle_root, rmn_threshold, signed_interval.sigs)):
return None
else:
return signed_interval
else:
if messages[0].seq_num != max_seq_num_on_dest_chain + 1:
return None
interval = Interval(messages[0].seq_num, messages[-1].seq_num + 1)
return SignedInterval(interval, merkle_root, [])

# Given a list of observations, return a mapping from chains to the message sequence that a sufficient number of
# observations have reached consensus on
def get_messages_consensus(
self,
previous_outcome: CommitOutcome,
observations: List[CommitObservation]
) -> Dict[ChainSelector, List[CcipMessage]]:
# build Dict[ChainSelector, Dict[int, Dict[CcipMessage, int]]]
# map of chains to maps of seq nums to maps of CcipMessage to occurrence
pass

# Given a SignedInterval that will be included in this round's report, return the interval to be used in the next
# round
def build_next_interval(self, signed_interval: SignedInterval, rmn_max_seq_num: Optional[int]) -> Interval:
interval_min = signed_interval.interval.max
if rmn_max_seq_num is None:
return Interval(interval_min, interval_min + MAX_INTERVAL_LENGTH)
else:
interval_max = max(signed_interval.interval.max, rmn_max_seq_num)
if interval_max - interval_min > MAX_INTERVAL_LENGTH:
interval_max = interval_min + MAX_INTERVAL_LENGTH
return Interval(interval_min, interval_max)

# TODO: doc, impl
def rebuild_current_interval(
self,
current_interval: Interval,
max_seq_num_on_dest_chain: Optional[int],
rmn_max_seq_num: Optional[int]
) -> Interval:
interval_min = current_interval.min
if max_seq_num_on_dest_chain is not None:
interval_min = max_seq_num_on_dest_chain + 1

pass

# TODO: doc
def outcome(
self,
previous_outcome: CommitOutcome,
query: CommitQuery,
observations: List[CommitObservation]
) -> CommitOutcome:
max_seq_nums_on_dest_chain = self.get_consensus_max_seq_nums_on_dest_chain(observations)
if len(max_seq_nums_on_dest_chain) == 0:
# TODO: doc
return CommitOutcome(previous_outcome.next_intervals, {})

messages = self.get_messages_consensus(previous_outcome, observations)

next_intervals = {}
signed_intervals = {}
for chain_selector in self.all_source_chains:
# TODO: handle key-missing errors
current_interval = previous_outcome.next_intervals[chain_selector]
max_seq_num_on_dest_chain = max_seq_nums_on_dest_chain[chain_selector]

signed_interval = self.get_verified_signed_interval(
chain_selector,
current_interval,
max_seq_num_on_dest_chain,
query.signed_intervals,
messages[chain_selector]
)

rmn_max_seq_num = query.rmn_max_seq_nums.get(chain_selector)

# TODO: doc
if signed_interval is None:
next_interval = self.rebuild_current_interval(current_interval, max_seq_num_on_dest_chain,
rmn_max_seq_num)
next_intervals[chain_selector] = next_interval
else:
next_interval = self.build_next_interval(signed_interval, rmn_max_seq_num)
next_intervals[chain_selector] = next_interval
signed_intervals[chain_selector] = signed_interval

return CommitOutcome(next_intervals, signed_intervals)

0 comments on commit a1d27cf

Please sign in to comment.