Skip to content

Commit

Permalink
faster bottom up algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
oflatt committed Dec 5, 2023
1 parent d953273 commit 45caef8
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 0 deletions.
117 changes: 117 additions & 0 deletions src/extract/faster_bottom_up.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use super::*;

/// A faster bottom up extractor inspired by the faster-greedy-dag extractor.
/// Picks, for each e-class, the node with the smallest recursive cost sum,
/// propagating cost improvements from leaf nodes up through parents.
pub struct BottomUpExtractor;

impl Extractor for BottomUpExtractor {
/// Extracts one representative node per e-class via bottom-up cost iteration.
///
/// Leaves (nodes with no children) get their cost first; whenever a class's
/// best-known cost improves, every parent node of that class is re-examined.
/// Iteration stops when no further improvement is possible.
///
/// `_roots` is intentionally unused: costs are computed for every class in
/// the e-graph, not just those reachable from the roots.
fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
// 1. build map from class to parent nodes
let mut parents = IndexMap::<ClassId, Vec<NodeId>>::default();
// shorthand: node id -> id of the e-class containing that node
let n2c = |nid: &NodeId| egraph.nid_to_cid(nid);

// pre-populate with every class so the `parents[...]` indexing below
// can never hit a missing key
for class in egraph.classes().values() {
parents.insert(class.id.clone(), Vec::new());
}
// a node is a "parent" of each class that occurs among its children;
// these are exactly the nodes that may get cheaper when that class does
for class in egraph.classes().values() {
for node in &class.nodes {
for c in &egraph[node].children {
parents[n2c(c)].push(node.clone());
}
}
}

// 2. start analysis from leaves
let mut analysis_pending = UniqueQueue::default();

for class in egraph.classes().values() {
for node in &class.nodes {
if egraph[node].is_leaf() {
analysis_pending.insert(node.clone());
}
}
}

// `result` records the chosen node per class; `costs` the best cost so far
let mut result = ExtractionResult::default();
let mut costs = IndexMap::<ClassId, Cost>::default();

while let Some(node_id) = analysis_pending.pop() {
let class_id = n2c(&node_id);
let node = &egraph[&node_id];
// a node's cost is only well-defined once all its children have one
if node.children.iter().all(|c| costs.contains_key(n2c(c))) {
// a class with no recorded cost yet counts as infinitely expensive
let prev_cost = costs.get(class_id).unwrap_or(&INFINITY);

let cost = result.node_sum_cost(egraph, node, &costs);
if cost < *prev_cost {
result.choose(class_id.clone(), node_id.clone());
costs.insert(class_id.clone(), cost);
// this class just got cheaper, so its parents may improve too
analysis_pending.extend(parents[class_id].iter().cloned());
}
} else {
// children not ready yet — retry later.
// NOTE(review): if some child class can never acquire a cost
// (e.g. a cyclic class with no leaf), this re-insert could make
// the loop spin on the same stuck nodes forever; presumably the
// parent-propagation above re-enqueues such nodes anyway, which
// would make this branch redundant — confirm and consider
// dropping the re-insert.
analysis_pending.insert(node_id.clone());
}
}

result
}
}

/** A data structure to maintain a queue of unique elements.
Notably, insert/pop operations have O(1) expected amortized runtime complexity.
Thanks Trevor for the implementation!
*/
#[derive(Clone)]
#[cfg_attr(feature = "serde-1", derive(Serialize, Deserialize))]
pub(crate) struct UniqueQueue<T>
where
T: Eq + std::hash::Hash + Clone,
{
// invariant: `set` and `queue` always hold exactly the same elements;
// `set` gives O(1) membership tests, `queue` preserves FIFO order
set: std::collections::HashSet<T>, // hashbrown::
queue: std::collections::VecDeque<T>,
}

impl<T> Default for UniqueQueue<T>
where
T: Eq + std::hash::Hash + Clone,
{
fn default() -> Self {
UniqueQueue {
set: std::collections::HashSet::default(),
queue: std::collections::VecDeque::new(),
}
}
}

impl<T> UniqueQueue<T>
where
    T: Eq + std::hash::Hash + Clone,
{
    /// Enqueues `t` at the back, unless it is already queued.
    pub fn insert(&mut self, t: T) {
        // `HashSet::insert` returns false on duplicates, so each element
        // occupies at most one slot in `queue`.
        if self.set.insert(t.clone()) {
            self.queue.push_back(t);
        }
    }

    /// Enqueues every element yielded by `iter`, skipping duplicates.
    pub fn extend<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = T>,
    {
        iter.into_iter().for_each(|t| self.insert(t));
    }

    /// Dequeues the oldest element, or returns `None` when empty.
    pub fn pop(&mut self) -> Option<T> {
        let front = self.queue.pop_front();
        // keep `set` in sync so the element may be inserted again later
        if let Some(ref t) = front {
            self.set.remove(t);
        }
        front
    }

    /// Returns true when the queue holds no elements.
    #[allow(dead_code)]
    pub fn is_empty(&self) -> bool {
        let empty = self.queue.is_empty();
        // the set and queue must agree — they mirror each other by invariant
        debug_assert_eq!(empty, self.set.is_empty());
        empty
    }
}
1 change: 1 addition & 0 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashMap;
pub use crate::*;

pub mod bottom_up;
pub mod faster_bottom_up;
pub mod greedy_dag;
pub mod greedy_dag_1;

Expand Down
4 changes: 4 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ fn main() {
env_logger::init();

let extractors: IndexMap<&str, Box<dyn Extractor>> = [
(
"faster-bottom-up",
extract::faster_bottom_up::BottomUpExtractor.boxed(),
),
(
"greedy-dag",
extract::greedy_dag::GreedyDagExtractor.boxed(),
Expand Down

0 comments on commit 45caef8

Please sign in to comment.