diff --git a/atrium-repo/src/mst.rs b/atrium-repo/src/mst.rs index 72f9fa66..8951b459 100644 --- a/atrium-repo/src/mst.rs +++ b/atrium-repo/src/mst.rs @@ -133,6 +133,37 @@ mod algos { } } + /// Traverse through the tree, finding the node that contains a key. This will record + /// the CIDs of all nodes traversed. + pub fn traverse_find_path<'a>( + key: &'a str, + ) -> impl FnMut(Node, Cid) -> Result, Error> + 'a { + move |node, cid| -> Result<_, Error> { + if let Some(index) = node.find_ge(key) { + if let Some(NodeEntry::Leaf(e)) = node.entries.get(index) { + if e.key == key { + return Ok(TraverseAction::Stop(cid)); + } + } + + // Check if the left neighbor is a tree, and if so, recurse into it. + if let Some(index) = index.checked_sub(1) { + if let Some(subtree) = node.entries.get(index).unwrap().tree() { + return Ok(TraverseAction::Continue((subtree.clone(), subtree.clone()))); + } else { + return Err(Error::KeyNotFound); + } + } else { + // There is no left neighbor. The key is not present. + return Err(Error::KeyNotFound); + } + } else { + // We've recursed into an empty node, so the key is not present in the tree. + return Err(Error::KeyNotFound); + } + } + } + /// Traverse through the tree, finding the first node that consists of more than just a single /// nested tree entry. pub fn traverse_prune() -> impl FnMut(Node, Cid) -> Result, Error> { @@ -688,6 +719,16 @@ impl Tree { Err(e) => Err(e), } } + + /// Returns the path to a node that contains the specified key (including the containing node). + /// + /// This is useful for exporting portions of the MST for e.g. generating firehose records. + pub async fn get_path(&mut self, key: &str) -> Result, Error> { + match algos::traverse(&mut self.storage, self.root, algos::traverse_find_path(key)).await { + Ok((node_path, cid)) => Ok(node_path.into_iter().map(|(_, cid)| cid).chain([cid])), + Err(e) => Err(e), + } + } } /// The location of an entry in a Merkle Search Tree.