Skip to content

Commit

Permalink
Transpose select on references
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Sep 17, 2023
1 parent d1dc574 commit 0cb1347
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 10 deletions.
37 changes: 27 additions & 10 deletions crates/transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,22 +711,38 @@ impl<'a> Transpose<'a> {
}
&Expr::Select { cond, then, els } => {
let t = self.f.vars[var.var()];

self.block.fwd.push(Instr {
var,
expr: Expr::Select {
cond,
then: self.get_re(then),
els: self.get_re(els),
},
});

match &self.f.types[t.ty()] {
&Ty::Ref { inner } => todo!(),
_ => {
self.block.fwd.push(Instr {
var,
&Ty::Ref { inner } => {
let cot = self.bwd_var(Some(inner));
self.block.bwd_nonlin.push(Instr {
var: cot,
expr: Expr::Select {
cond,
then: self.get_re(then),
els: self.get_re(els),
then: self.get_cotan(then),
els: self.get_cotan(els),
},
});
self.cotans[var.var()] = Some(cot);
}
_ => {
self.keep(var);
let lin = self.accum(var);
let acc_then = self.get_accum(then);
let acc_els = self.get_accum(els);
let acc = self.bwd_var(Some(self.f.vars[then.var()])); // `els` is fine too
let t_acc = self.ty(Ty::Ref {
inner: self.f.vars[var.var()],
});
let acc = self.bwd_var(Some(t_acc));
let unit = self.bwd_var(Some(self.unit));
self.block.bwd_lin.push(Instr {
var: unit,
Expand All @@ -744,11 +760,12 @@ impl<'a> Transpose<'a> {
},
});
self.resolve(lin);
if let Ty::F64 = self.mapped_types[t.ty()] {
self.duals[var.var()] = Some((Src(None), Src(None)));
}
}
}

if let Ty::F64 = self.mapped_types[t.ty()] {
self.duals[var.var()] = Some((Src(None), Src(None)));
}
}

Expr::Call { id, generics, args } => {
Expand Down
15 changes: 15 additions & 0 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -455,4 +455,19 @@ describe("valid", () => {
const h = fn([Pair, Pair], Pair, (p, q) => vjp(g)(p).grad(q));
expect(interp(h)({ x: 2, y: 3 }, { x: 5, y: 7 })).toEqual({ x: 7, y: 5 });
});

test("VJP twice with select", () => {
const Stuff = { p: Bool, x: Real, y: Real, z: Real } as const;
const f = fn([Stuff], Real, ({ p, x, y, z }) =>
mul(z, select(p, Real, x, y)),
);
const g = fn([Stuff], Stuff, (p) => vjp(f)(p).grad(1));
const h = fn([Stuff, Stuff], Stuff, (p, q) => vjp(g)(p).grad(q));
expect(
interp(h)(
{ p: true, x: 2, y: 3, z: 5 },
{ p: false, x: 7, y: 11, z: 13 },
),
).toEqual({ p: true, x: 13, y: 0, z: 7 });
});
});

0 comments on commit 0cb1347

Please sign in to comment.