Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global DAG extraction #21

Merged
merged 8 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,17 @@ version = "0.1.0"
ilp-cbc = ["coin_cbc"]

[dependencies]
env_logger = {version = "0.10.0", default-features = false}
env_logger = { version = "0.10.0", default-features = false }
indexmap = "2.0.0"
log = "0.4.19"
ordered-float = "3"
pico-args = {version = "0.5.0", features = ["eq-separator"]}
pico-args = { version = "0.5.0", features = ["eq-separator"] }

anyhow = "1.0.71"
coin_cbc = {version = "0.1.6", optional = true}
coin_cbc = { version = "0.1.6", optional = true }
im-rc = "15.1.0"

rpds = "1.1.0"
[dependencies.egraph-serialize]
git = "https://github.com/egraphs-good/egraph-serialize"
rev = "951b829a434f4008c7b45ba4ac0da1037d2da90"
27 changes: 21 additions & 6 deletions plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def load_jsons(files):
return js


def process(js, extractors=[]):
extractors = extractors or sorted(set(j["extractor"] for j in js))

def process(js, extractors):
by_name = {}
for j in js:
n, e = j["name"], j["extractor"]
Expand All @@ -39,7 +37,9 @@ def process(js, extractors=[]):
for name, d in by_name.items():
try:
if d[e1]["tree"] != d[e2]["tree"]:
print(name, d[e1]["tree"], d[e2]["tree"]);
print(name, " differs in tree cost: ", d[e1]["tree"], d[e2]["tree"]);
if d[e1]["dag"] != d[e2]["dag"]:
print(name, " differs in dag cost: ", d[e1]["dag"], d[e2]["dag"]);

tree_ratio = d[e1]["tree"] / d[e2]["tree"]
dag_ratio = d[e1]["dag"] / d[e2]["dag"]
Expand All @@ -56,10 +56,15 @@ def process(js, extractors=[]):
except Exception as e:
print(f"Error processing {name}")
raise e

print(f"cumulative time for {e1}: {e1_cumulative/1000:.0f}ms")
print(f"cumulative time for {e2}: {e2_cumulative/1000:.0f}ms")

print(f"cumulative tree cost for {e1}: {sum(d[e1]['tree'] for d in by_name.values()):.0f}")
print(f"cumulative tree cost for {e2}: {sum(d[e2]['tree'] for d in by_name.values()):.0f}")
print(f"cumulative dag cost for {e1}: {sum(d[e1]['dag'] for d in by_name.values()):.0f}")
print(f"cumulative dag cost for {e2}: {sum(d[e2]['dag'] for d in by_name.values()):.0f}")

print(f"{e1} / {e2}")

print("geo mean")
Expand Down Expand Up @@ -93,4 +98,14 @@ def quantiles(key):
files = sys.argv[1:] or glob.glob("output/**/*.json", recursive=True)
js = load_jsons(files)
print(f"Loaded {len(js)} jsons.")

# Compare every unordered pair of extractors that appear in the data.
extractors = sorted(set(j["extractor"] for j in js))

for i in range(len(extractors)):
    for j in range(i + 1, len(extractors)):
        # extractors is sorted and deduplicated and j > i, so the two
        # entries are always distinct — no equality check needed.
        ex1, ex2 = extractors[i], extractors[j]
        print(f"###################################################\n{ex1} vs {ex2}\n\n")
        process(js, [ex1, ex2])
        print("\n\n")
203 changes: 203 additions & 0 deletions src/extract/global_greedy_dag.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
use std::iter;

use rpds::{HashTrieMap, HashTrieSet};

use super::*;

// Index of a term in `TermDag::nodes` / `TermDag::info`.
type TermId = usize;

/// A term: an operator applied to child terms.
/// Hash-consed via `TermDag::hash_cons`, so structurally equal terms
/// share a single `TermId`.
#[derive(Clone, PartialEq, Eq, Hash)]
struct Term {
    op: String,
    children: Vec<TermId>,
}

// Persistent (structurally shared) set of eclasses reachable from a term.
type Reachable = HashTrieSet<ClassId>;

/// Metadata for one term, stored parallel to `TermDag::nodes`.
struct TermInfo {
    // the egraph node this term was built from
    node: NodeId,
    // the eclass that node belongs to
    eclass: ClassId,
    // cost of this node alone
    node_cost: Cost,
    // cost of the whole term, with shared subterms counted once
    total_cost: Cost,
    // store the set of reachable eclasses from this term; used both for
    // cycle prevention and for shared-cost accounting
    reachable: Reachable,
    // number of terms in this term's DAG, counting duplicates
    size: usize,
}

/// A TermDag needs to store terms that share common
/// subterms using a hashmap.
/// However, it also critically needs to be able to answer
/// reachability queries in this dag `reachable`.
/// This prevents double-counting costs when
/// computing the cost of a term.
#[derive(Default)]
pub struct TermDag {
    // all terms, indexed by `TermId`
    nodes: Vec<Term>,
    // per-term metadata, parallel to `nodes`
    info: Vec<TermInfo>,
    // hash-consing map: structurally equal terms map to the same `TermId`
    hash_cons: HashMap<Term, TermId>,
}

impl TermDag {
/// Makes a new term using a node and children terms
/// Correctly computes total_cost with sharing
/// If this term contains itself, returns None
/// If this term costs more than target, returns None
pub fn make(
&mut self,
node_id: NodeId,
node: &Node,
children: Vec<TermId>,
target: Cost,
) -> Option<TermId> {
let term = Term {
op: node.op.clone(),
children: children.clone(),
};

if let Some(id) = self.hash_cons.get(&term) {
return Some(*id);
}

let node_cost = node.cost;

if children.is_empty() {
let next_id = self.nodes.len();
self.nodes.push(term.clone());
self.info.push(TermInfo {
node: node_id,
eclass: node.eclass.clone(),
node_cost,
total_cost: node_cost,
reachable: iter::once(node.eclass.clone()).collect(),
size: 1,
});
self.hash_cons.insert(term, next_id);
Some(next_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional:

I think when the if/else are this long it can help to unnest after the "easy case", namely:

if children.is_empty() {
...
  return Some(next_id);
}
// check if children contains this node
for child in &children {
...

Copy link
Member Author

@oflatt oflatt Dec 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eh, I like nesting
return considered harmful (only half joking)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to you! I suggested this because there's already a good amount of early-return in this function (78, 92, 99, for example).

I personally think it is good sometimes, and not good other times. For code like this that feels pretty imperative either way I think it can make the code clearer. Feel free to keep as is though

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see
re-examining my head it seems that early return from a big code block makes me sad

} else {
// check if children contains this node, preventing cycles
// This is sound because `reachable` is the set of reachable eclasses
// from this term.
for child in &children {
if self.info[*child].reachable.contains(&node.eclass) {
return None;
}
}

let biggest_child = (0..children.len())
.max_by_key(|i| self.info[children[*i]].size)
.unwrap();

let mut cost = node_cost + self.total_cost(children[biggest_child]);
let mut reachable = self.info[children[biggest_child]].reachable.clone();
let next_id = self.nodes.len();

for child in children.iter() {
if cost > target {
return None;
}
let child_cost = self.get_cost(&mut reachable, *child);
cost += child_cost;
}

if cost > target {
return None;
}

reachable = reachable.insert(node.eclass.clone());

self.info.push(TermInfo {
node: node_id,
node_cost,
eclass: node.eclass.clone(),
total_cost: cost,
reachable,
size: 1 + children.iter().map(|c| self.info[*c].size).sum::<usize>(),
});
self.nodes.push(term.clone());
self.hash_cons.insert(term, next_id);
Some(next_id)
}
}

/// Return a new term, like this one but making use of shared terms.
/// Also return the cost of the new nodes.
fn get_cost(&self, shared: &mut Reachable, id: TermId) -> Cost {
let eclass = self.info[id].eclass.clone();

// This is the key to why this algorithm is faster than greedy_dag.
// While doing the set union between reachable sets, we can stop early
// if we find a shared term.
// Since the term with `id` is shared, the reachable set of `id` will already
// be in `shared`.
if shared.contains(&eclass) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the key "why" for how this is better than greedy-dag: instead of doing an $O(n)$ union operation when you merge reachable sets, we "stop early" based on the shared egraph-level structure: if two nodes share a child somewhere down we don't continue unioning because we know the sets are the same.

Is that right? If so, I think it's worth a comment :) (If not, I'd definitely appreciate a comment)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly right- I left a comment
// This is the key to why this algorithm is faster than greedy_dag.
// While doing the set union between reachable sets, we can stop early
// if we find a shared term.
// Since the term with id is shared, the reachable set of id will already
// be in shared.

NotNan::<f64>::new(0.0).unwrap()
} else {
let mut cost = self.node_cost(id);
for child in &self.nodes[id].children {
let child_cost = self.get_cost(shared, *child);
cost += child_cost;
}
*shared = shared.insert(eclass);
cost
}
}

/// Cost of just the root node of term `id`, excluding its children.
pub fn node_cost(&self, id: TermId) -> Cost {
    self.info[id].node_cost
}

/// Total cost of term `id`, with shared subterms counted once
/// (as computed by `make`).
pub fn total_cost(&self, id: TermId) -> Cost {
    self.info[id].total_cost
}
}

/// Extractor that greedily improves the best known term for every eclass,
/// iterating to a fixed point; term sharing is tracked globally through a
/// `TermDag` so DAG costs are not double-counted.
pub struct GlobalGreedyDagExtractor;
impl Extractor for GlobalGreedyDagExtractor {
    fn extract(&self, egraph: &EGraph, _roots: &[ClassId]) -> ExtractionResult {
        let mut termdag = TermDag::default();
        // Cheapest term found so far for each eclass.
        let mut best_in_class: HashMap<ClassId, TermId> = HashMap::default();

        // Keep sweeping over all nodes until no eclass improves.
        let mut keep_going = true;
        while keep_going {
            keep_going = false;

            'node_loop: for (node_id, node) in &egraph.nodes {
                let mut children: Vec<TermId> = vec![];
                // A node is only usable once every child eclass already
                // has a best term; otherwise skip it for this sweep.
                for child in &node.children {
                    let child_cid = egraph.nid_to_cid(child);
                    if let Some(best) = best_in_class.get(child_cid) {
                        children.push(*best);
                    } else {
                        continue 'node_loop;
                    }
                }

                // Current best cost for this node's eclass; passed as the
                // budget so `make` can bail out early on losing candidates.
                let old_cost = best_in_class
                    .get(&node.eclass)
                    .map(|id| termdag.total_cost(*id))
                    .unwrap_or(INFINITY);

                if let Some(candidate) = termdag.make(node_id.clone(), node, children, old_cost) {
                    let candidate_cost = termdag.total_cost(candidate);

                    if candidate_cost < old_cost {
                        best_in_class.insert(node.eclass.clone(), candidate);
                        keep_going = true;
                    }
                }
            }
        }

        let mut result = ExtractionResult::default();
        for (class, term) in best_in_class {
            result.choose(class, termdag.info[term].node.clone());
        }
        result
    }
}
1 change: 1 addition & 0 deletions src/extract/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashMap;
pub use crate::*;

pub mod bottom_up;
pub mod global_greedy_dag;
pub mod greedy_dag;
pub mod greedy_dag_1;

Expand Down
5 changes: 5 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn main() {
env_logger::init();

let extractors: IndexMap<&str, Box<dyn Extractor>> = [
("bottom-up", extract::bottom_up::BottomUpExtractor.boxed()),
(
"greedy-dag",
extract::greedy_dag::GreedyDagExtractor.boxed(),
Expand All @@ -27,6 +28,10 @@ fn main() {
"faster-greedy-dag",
extract::greedy_dag_1::FasterGreedyDagExtractor.boxed(),
),
(
"global-greedy-dag",
extract::global_greedy_dag::GlobalGreedyDagExtractor.boxed(),
),
]
.into_iter()
.collect();
Expand Down