Skip to content

Commit

Permalink
Merge pull request #21 from oflatt/oflatt-faster-dag
Browse files Browse the repository at this point in the history
Global DAG extraction
  • Loading branch information
oflatt authored Dec 13, 2023
2 parents d953273 + c722bd3 commit 934bab7
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 9 deletions.
32 changes: 32 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ version = "0.1.0"
ilp-cbc = ["coin_cbc"]

[dependencies]
env_logger = {version = "0.10.0", default-features = false}
env_logger = { version = "0.10.0", default-features = false }
indexmap = "2.0.0"
log = "0.4.19"
ordered-float = "3"
pico-args = {version = "0.5.0", features = ["eq-separator"]}
pico-args = { version = "0.5.0", features = ["eq-separator"] }

anyhow = "1.0.71"
coin_cbc = {version = "0.1.6", optional = true}
coin_cbc = { version = "0.1.6", optional = true }
im-rc = "15.1.0"

rpds = "1.1.0"
[dependencies.egraph-serialize]
git = "https://github.com/egraphs-good/egraph-serialize"
rev = "951b829a434f4008c7b45ba4ac0da1037d2da90"
27 changes: 21 additions & 6 deletions plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def load_jsons(files):
return js


def process(js, extractors=[]):
extractors = extractors or sorted(set(j["extractor"] for j in js))

def process(js, extractors):
by_name = {}
for j in js:
n, e = j["name"], j["extractor"]
Expand All @@ -39,7 +37,9 @@ def process(js, extractors=[]):
for name, d in by_name.items():
try:
if d[e1]["tree"] != d[e2]["tree"]:
print(name, d[e1]["tree"], d[e2]["tree"]);
print(name, " differs in tree cost: ", d[e1]["tree"], d[e2]["tree"]);
if d[e1]["dag"] != d[e2]["dag"]:
print(name, " differs in dag cost: ", d[e1]["dag"], d[e2]["dag"]);

tree_ratio = d[e1]["tree"] / d[e2]["tree"]
dag_ratio = d[e1]["dag"] / d[e2]["dag"]
Expand All @@ -56,10 +56,15 @@ def process(js, extractors=[]):
except Exception as e:
print(f"Error processing {name}")
raise e

print(f"cumulative time for {e1}: {e1_cumulative/1000:.0f}ms")
print(f"cumulative time for {e2}: {e2_cumulative/1000:.0f}ms")

print(f"cumulative tree cost for {e1}: {sum(d[e1]['tree'] for d in by_name.values()):.0f}")
print(f"cumulative tree cost for {e2}: {sum(d[e2]['tree'] for d in by_name.values()):.0f}")
print(f"cumulative dag cost for {e1}: {sum(d[e1]['dag'] for d in by_name.values()):.0f}")
print(f"cumulative dag cost for {e2}: {sum(d[e2]['dag'] for d in by_name.values()):.0f}")

print(f"{e1} / {e2}")

print("geo mean")
Expand Down Expand Up @@ -93,4 +98,14 @@ def quantiles(key):
files = sys.argv[1:] or glob.glob("output/**/*.json", recursive=True)
js = load_jsons(files)
print(f"Loaded {len(js)} jsons.")
process(js)

extractors = sorted(set(j["extractor"] for j in js))

for i in range(len(extractors)):
for j in range(i + 1, len(extractors)):
ex1, ex2 = extractors[i], extractors[j]
if ex1 == ex2:
continue
print(f"###################################################\n{ex1} vs {ex2}\n\n")
process(js, [ex1, ex2])
print("\n\n")
203 changes: 203 additions & 0 deletions src/extract/global_greedy_dag.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
use std::iter;

use rpds::{HashTrieMap, HashTrieSet};

use super::*;

/// Index of a `Term` in `TermDag::nodes` / `TermDag::info`.
type TermId = usize;

/// An operator applied to child terms, identified by `TermId`.
/// Terms are hash-consed: structurally equal terms share one `TermId`.
#[derive(Clone, PartialEq, Eq, Hash)]
struct Term {
    op: String,
    children: Vec<TermId>,
}

/// The set of eclasses reachable from a term.
/// A persistent (rpds) set, so cloning is cheap and shares structure.
type Reachable = HashTrieSet<ClassId>;

/// Per-term metadata, stored parallel to `TermDag::nodes`.
struct TermInfo {
    // The e-node this term was built from.
    node: NodeId,
    // The eclass that node belongs to.
    eclass: ClassId,
    // Cost of this node alone, excluding children.
    node_cost: Cost,
    // Cost of the whole term, with shared subterms counted once.
    total_cost: Cost,
    // store the set of reachable eclasses from this term
    // (used both for cycle prevention and to avoid double-counting costs)
    reachable: Reachable,
    // 1 + sum of children's sizes (shared subterms counted per occurrence);
    // only used to pick the biggest child when merging reachable sets.
    size: usize,
}

/// A TermDag needs to store terms that share common
/// subterms using a hashmap.
/// However, it also critically needs to be able to answer
/// reachability queries in this dag `reachable`.
/// This prevents double-counting costs when
/// computing the cost of a term.
#[derive(Default)]
pub struct TermDag {
    // All terms, indexed by `TermId`.
    nodes: Vec<Term>,
    // Metadata for each term, parallel to `nodes`.
    info: Vec<TermInfo>,
    // Hash-consing table: structurally equal terms map to a single id.
    hash_cons: HashMap<Term, TermId>,
}

impl TermDag {
    /// Makes a new term using a node and children terms
    /// Correctly computes total_cost with sharing
    /// If this term contains itself, returns None
    /// If this term costs more than target, returns None
    ///
    /// NOTE: a hash-cons hit returns the existing term's id *without*
    /// checking `target`, so the returned term's cost may exceed it;
    /// callers must compare costs themselves (as the extractor does).
    pub fn make(
        &mut self,
        node_id: NodeId,
        node: &Node,
        children: Vec<TermId>,
        target: Cost,
    ) -> Option<TermId> {
        let term = Term {
            op: node.op.clone(),
            children: children.clone(),
        };

        // Hash-consing: a structurally identical term already exists.
        if let Some(id) = self.hash_cons.get(&term) {
            return Some(*id);
        }

        let node_cost = node.cost;

        if children.is_empty() {
            // Leaf: total cost is just the node's own cost, and the only
            // reachable eclass is its own.
            let next_id = self.nodes.len();
            self.nodes.push(term.clone());
            self.info.push(TermInfo {
                node: node_id,
                eclass: node.eclass.clone(),
                node_cost,
                total_cost: node_cost,
                reachable: iter::once(node.eclass.clone()).collect(),
                size: 1,
            });
            self.hash_cons.insert(term, next_id);
            Some(next_id)
        } else {
            // check if children contains this node, preventing cycles
            // This is sound because `reachable` is the set of reachable eclasses
            // from this term.
            for child in &children {
                if self.info[*child].reachable.contains(&node.eclass) {
                    return None;
                }
            }

            // Seed cost/reachable from the child with the most term nodes
            // (a proxy for the largest reachable set), so the persistent-set
            // clone below captures as much sharing as possible up front.
            let biggest_child = (0..children.len())
                .max_by_key(|i| self.info[children[*i]].size)
                .unwrap();

            let mut cost = node_cost + self.total_cost(children[biggest_child]);
            let mut reachable = self.info[children[biggest_child]].reachable.clone();
            let next_id = self.nodes.len();

            // Add each child's *unshared* cost, bailing out as soon as we
            // exceed `target`. For the biggest child itself `get_cost` is a
            // no-op, since its eclass is already in `reachable`.
            for child in children.iter() {
                if cost > target {
                    return None;
                }
                let child_cost = self.get_cost(&mut reachable, *child);
                cost += child_cost;
            }

            if cost > target {
                return None;
            }

            reachable = reachable.insert(node.eclass.clone());

            self.info.push(TermInfo {
                node: node_id,
                node_cost,
                eclass: node.eclass.clone(),
                total_cost: cost,
                reachable,
                size: 1 + children.iter().map(|c| self.info[*c].size).sum::<usize>(),
            });
            self.nodes.push(term.clone());
            self.hash_cons.insert(term, next_id);
            Some(next_id)
        }
    }

    /// Returns the cost of the nodes in the subterm `id` that are not yet
    /// accounted for in `shared`, and records the eclasses it visits in
    /// `shared` so they are never charged twice. No new term is created
    /// here — only a cost is computed.
    fn get_cost(&self, shared: &mut Reachable, id: TermId) -> Cost {
        let eclass = self.info[id].eclass.clone();

        // This is the key to why this algorithm is faster than greedy_dag.
        // While doing the set union between reachable sets, we can stop early
        // if we find a shared term.
        // Since the term with `id` is shared, the reachable set of `id` will already
        // be in `shared`.
        if shared.contains(&eclass) {
            NotNan::<f64>::new(0.0).unwrap()
        } else {
            let mut cost = self.node_cost(id);
            for child in &self.nodes[id].children {
                let child_cost = self.get_cost(shared, *child);
                cost += child_cost;
            }
            // Mark this eclass as charged only after its children have been
            // folded in.
            *shared = shared.insert(eclass);
            cost
        }
    }

    /// Cost of the node at `id` alone, excluding its children.
    pub fn node_cost(&self, id: TermId) -> Cost {
        self.info[id].node_cost
    }

    /// Cost of the whole term at `id`, with shared subterms counted once.
    pub fn total_cost(&self, id: TermId) -> Cost {
        self.info[id].total_cost
    }
}

/// Extractor that greedily builds, for every eclass, the cheapest term
/// found so far (with shared subterms counted once), iterating to a
/// fixpoint over the whole egraph.
pub struct GlobalGreedyDagExtractor;

impl Extractor for GlobalGreedyDagExtractor {
    fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
        let mut termdag = TermDag::default();
        // Best (lowest total-cost) term discovered so far for each eclass.
        let mut best_in_class: HashMap<ClassId, TermId> = HashMap::default();

        let mut iteration = 0;
        let mut keep_going = true;
        while keep_going {
            iteration += 1;
            println!("iteration {}", iteration);
            keep_going = false;

            // Iterate the egraph's nodes directly; no need to clone the
            // whole node map since nothing mutates the egraph here.
            'node_loop: for (node_id, node) in &egraph.nodes {
                // A node is only a candidate once every child eclass has
                // some term; otherwise skip it until a later iteration.
                let mut children: Vec<TermId> = vec![];
                for child in &node.children {
                    let child_cid = egraph.nid_to_cid(child);
                    if let Some(best) = best_in_class.get(child_cid) {
                        children.push(*best);
                    } else {
                        continue 'node_loop;
                    }
                }

                let old_cost = best_in_class
                    .get(&node.eclass)
                    .map(|id| termdag.total_cost(*id))
                    .unwrap_or(INFINITY);

                // `make` returns None on a would-be cycle or when the new
                // term's cost exceeds `old_cost`; a hash-cons hit can still
                // return a term costing >= old_cost, hence the re-check.
                if let Some(candidate) = termdag.make(node_id.clone(), node, children, old_cost) {
                    let candidate_cost = termdag.total_cost(candidate);

                    if candidate_cost < old_cost {
                        best_in_class.insert(node.eclass.clone(), candidate);
                        // An improvement here may enable cheaper terms
                        // elsewhere, so run another pass.
                        keep_going = true;
                    }
                }
            }
        }

        let mut result = ExtractionResult::default();
        for (class, term) in best_in_class {
            result.choose(class, termdag.info[term].node.clone());
        }
        result
    }
}
1 change: 1 addition & 0 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashMap;
pub use crate::*;

pub mod bottom_up;
pub mod global_greedy_dag;
pub mod greedy_dag;
pub mod greedy_dag_1;

Expand Down
5 changes: 5 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn main() {
env_logger::init();

let extractors: IndexMap<&str, Box<dyn Extractor>> = [
("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()),
(
"greedy-dag",
extract::greedy_dag::GreedyDagExtractor.boxed(),
Expand All @@ -27,6 +28,10 @@ fn main() {
"faster-greedy-dag",
extract::greedy_dag_1::FasterGreedyDagExtractor.boxed(),
),
(
"global-greedy-dag",
extract::global_greedy_dag::GlobalGreedyDagExtractor.boxed(),
),
]
.into_iter()
.collect();
Expand Down

0 comments on commit 934bab7

Please sign in to comment.