Skip to content

Commit

Permalink
Allow literal tangents
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Sep 27, 2023
1 parent 6abe24f commit f3d7057
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ impl FuncBuilder {
/// integer type; or if `t` is an integer type which cannot represent the given value of `x`.
pub fn num(&mut self, t: usize, x: f64) -> Result<usize, JsError> {
match self.ty(t)? {
Ty::F64 => Ok(self.constant(t, rose::Expr::F64 { val: x })),
Ty::F64 | Ty::T64 => Ok(self.constant(t, rose::Expr::F64 { val: x })),
&Ty::Fin { size } => {
let y = x as usize;
if y as f64 != x {
Expand Down
7 changes: 7 additions & 0 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,13 @@ describe("valid", () => {
expect(interp(g)(0, 1)).toBeCloseTo(50000);
});

test("custom JVP with zero tangent", () => {
const signum = opaque([Real], Real, Math.sign);
signum.jvp = fn([Dual], Dual, ({ re: x }) => ({ re: sign(x), du: 0 }));
const f = interp(jvp(fn([Real], Real, (x) => signum(x))));
expect(f({ re: 2, du: 1 })).toEqual({ re: 1, du: 0 });
});

test("VJP", () => {
const f = fn([Vec(2, Real)], Real, (v) => mul(v[0], v[1]));
const g = fn([], Vec(3, Real), () => {
Expand Down

0 comments on commit f3d7057

Please sign in to comment.