Skip to content

Commit

Permalink
Skip opaque function indices in accumulation calls
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Sep 25, 2023
1 parent 7940319 commit 7de7510
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
8 changes: 3 additions & 5 deletions crates/wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ impl<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> {
self.pointer();
let j = self.funcs.get_index_of(&(ByAddress(def), gens)).unwrap();
self.bump(self.costs[j]);
self.imports.len() + self.extras + j
self.extras + j
}
Node::Opaque { def, .. } => {
self.imports.get_index_of(&(def, gens)).unwrap()
Expand Down Expand Up @@ -662,7 +662,7 @@ pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
let mut code_section = CodeSection::new();

let mut metas: Vec<Meta> = vec![];
let mut extras: usize = 0;
let mut extras: usize = imports.len();
for ty in types.into_iter() {
let (layout, cost, members) = match &ty {
Ty::Unit => (Layout::Unit, None, None),
Expand Down Expand Up @@ -976,9 +976,7 @@ pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
export_section.export(
"f",
wasm_encoder::ExportKind::Func,
(imports.len() + extras + funcs.len() - 1)
.try_into()
.unwrap(),
(extras + funcs.len() - 1).try_into().unwrap(),
);
export_section.export("m", wasm_encoder::ExportKind::Memory, 0);

Expand Down
11 changes: 11 additions & 0 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,17 @@ describe("valid", () => {
expect((await compile(h))(1)).toBe(1);
});

test.only("compile VJP with opaque call", async () => {
const exp = opaque([Real], Real, Math.exp);
exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => {
const y = exp(x);
return { re: y, du: mulLin(dx, y) };
});
const g = fn([Real], Real, (x) => exp(x));
const h = fn([Real], Real, (x) => vjp(g)(x).ret);
expect((await compile(h))(1)).toBeCloseTo(Math.E);
});

test("compile nulls in signature", async () => {
const f = fn([Null], Null, (x) => x);
const g = await compile(f);
Expand Down

0 comments on commit 7de7510

Please sign in to comment.