diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 4f09527..68ac572 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -945,7 +945,7 @@ impl FuncBuilder { /// integer type; or if `t` is an integer type which cannot represent the given value of `x`. pub fn num(&mut self, t: usize, x: f64) -> Result { match self.ty(t)? { - Ty::F64 => Ok(self.constant(t, rose::Expr::F64 { val: x })), + Ty::F64 | Ty::T64 => Ok(self.constant(t, rose::Expr::F64 { val: x })), &Ty::Fin { size } => { let y = x as usize; if y as f64 != x { diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 23c2df5..f44d665 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -433,6 +433,13 @@ describe("valid", () => { expect(interp(g)(0, 1)).toBeCloseTo(50000); }); + test("custom JVP with zero tangent", () => { + const signum = opaque([Real], Real, Math.sign); + signum.jvp = fn([Dual], Dual, ({ re: x }) => ({ re: sign(x), du: 0 })); + const f = interp(jvp(fn([Real], Real, (x) => signum(x)))); + expect(f({ re: 2, du: 1 })).toEqual({ re: 1, du: 0 }); + }); + test("VJP", () => { const f = fn([Vec(2, Real)], Real, (v) => mul(v[0], v[1])); const g = fn([], Vec(3, Real), () => {