Skip to content

Commit

Permalink
cache inverse caculation
Browse files Browse the repository at this point in the history
  • Loading branch information
kaikalii committed Nov 25, 2024
1 parent 799265e commit f1bf60e
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 7 deletions.
14 changes: 11 additions & 3 deletions src/compile/invert/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
mod un;
mod under;

use std::{boxed, cell::RefCell, collections::HashMap, error::Error, fmt};
use std::{
boxed,
cell::RefCell,
collections::HashMap,
error::Error,
fmt,
hash::{DefaultHasher, Hasher},
};

use ecow::eco_vec;
use serde::*;

use crate::{
assembly::{Assembly, Function},
check::{nodes_clean_sig, SigCheckError},
ArrayLen, FunctionId,
check::{nodes_clean_sig, nodes_sig, SigCheckError},
compile::algebra::algebraic_inverse,
ArrayLen, CustomInverse, FunctionId,
ImplPrimitive::{self, *},
Node::{self, *},
Primitive::{self, *},
Expand Down
43 changes: 39 additions & 4 deletions src/compile/invert/un.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use crate::{check::nodes_sig, compile::algebra::algebraic_inverse, CustomInverse};

use super::*;

impl Node {
Expand Down Expand Up @@ -33,6 +31,24 @@ pub fn un_inverse(input: &[Node], asm: &Assembly) -> InversionResult<Node> {
if input.is_empty() {
return Ok(Node::empty());
}

thread_local! {
static CACHE: RefCell<HashMap<u64, InversionResult<Node>>> = Default::default();
}
let mut hasher = DefaultHasher::new();
for node in input {
node.hash_with_span(&mut hasher);
}
let hash = hasher.finish();
if let Some(cached) = CACHE.with(|cache| cache.borrow_mut().get(&hash).cloned()) {
return cached;
}
let res = un_inverse_impl(input, asm);
CACHE.with(|cache| cache.borrow_mut().insert(hash, res.clone()));
res
}

fn un_inverse_impl(input: &[Node], asm: &Assembly) -> InversionResult<Node> {
let mut node = Node::empty();
let mut curr = input;
let mut error = Generic;
Expand All @@ -57,9 +73,28 @@ pub fn un_inverse(input: &[Node], asm: &Assembly) -> InversionResult<Node> {
Err(error)
}

pub fn anti_inverse(mut input: &[Node], asm: &Assembly) -> InversionResult<Node> {
// An anti inverse can be optionaly sandwiched by an un inverse on either side
pub fn anti_inverse(input: &[Node], asm: &Assembly) -> InversionResult<Node> {
if input.is_empty() {
return generic();
}
thread_local! {
static CACHE: RefCell<HashMap<u64, InversionResult<Node>>> = Default::default();
}
let mut hasher = DefaultHasher::new();
for node in input {
node.hash_with_span(&mut hasher);
}
let hash = hasher.finish();
if let Some(cached) = CACHE.with(|cache| cache.borrow_mut().get(&hash).cloned()) {
return cached;
}
let res = anti_inverse_impl(input, asm);
CACHE.with(|cache| cache.borrow_mut().insert(hash, res.clone()));
res
}

fn anti_inverse_impl(mut input: &[Node], asm: &Assembly) -> InversionResult<Node> {
// An anti inverse can be optionaly sandwiched by an un inverse on either side
let orig_input = input;
let mut error = Generic;

Expand Down
33 changes: 33 additions & 0 deletions src/compile/invert/under.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,39 @@ fn under_inverse(
if input.is_empty() {
return Ok((Node::empty(), Node::empty()));
}

type Key = (u64, Signature, bool);
thread_local! {
static CACHE: RefCell<HashMap<Key, InversionResult<(Node, Node)>>> = Default::default();
}
let mut hasher = DefaultHasher::new();
for node in input {
node.hash_with_span(&mut hasher);
}
let hash = hasher.finish();
if let Some(cached) =
CACHE.with(|cache| cache.borrow_mut().get(&(hash, g_sig, inverse)).cloned())
{
return cached;
}
let res = under_inverse_impl(input, g_sig, inverse, asm);
CACHE.with(|cache| {
cache
.borrow_mut()
.insert((hash, g_sig, inverse), res.clone())
});
res
}

fn under_inverse_impl(
input: &[Node],
g_sig: Signature,
inverse: bool,
asm: &Assembly,
) -> InversionResult<(Node, Node)> {
if input.is_empty() {
return Ok((Node::empty(), Node::empty()));
}
let mut before = Node::empty();
let mut after = Node::empty();
let mut curr = input;
Expand Down
6 changes: 6 additions & 0 deletions src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,12 @@ impl Node {
}
recurse(None, None, self, asm, f)
}
pub(crate) fn hash_with_span(&self, hasher: &mut impl Hasher) {
self.hash(hasher);
if let Some(span) = self.span() {
span.hash(hasher);
}
}
}

impl From<&[Node]> for Node {
Expand Down

0 comments on commit f1bf60e

Please sign in to comment.