From 31e8d0f11881cac65600a069d6422df01b047ca5 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Tue, 19 Sep 2023 17:33:27 -0400 Subject: [PATCH 1/3] Support custom JVPs --- crates/transpose/src/lib.rs | 146 +++++++++++++++++--------------- crates/web/src/lib.rs | 40 ++++++--- packages/core/src/impl.test.ts | 57 +++++++------ packages/core/src/impl.ts | 62 +++++++++++++- packages/core/src/index.test.ts | 38 +++++++-- packages/core/src/index.ts | 9 +- 6 files changed, 239 insertions(+), 113 deletions(-) diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index c4c2547..5f76b6c 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -773,79 +773,93 @@ impl<'a> Transpose<'a> { } } - Expr::Call { id, generics, args } => { - let (dep_types, t) = self.deps[id.func()]; - let mut types = vec![]; - for ty in dep_types { - types.push(self.translate(generics, &types, ty)); + Expr::Call { id, generics, args } => match self.f.vars[var.var()] { + REAL => { + self.block.fwd.push(Instr { + var, + expr: Expr::Call { + id: *id, + generics: generics.clone(), + args: args.iter().map(|&arg| self.get_prim(arg)).collect(), + }, + }); + self.keep(var); + self.prims[var.var()] = Some(Src(None)); } - let t_tup = types[t.ty()]; + _ => { + let (dep_types, t) = self.deps[id.func()]; + let mut types = vec![]; + for ty in dep_types { + types.push(self.translate(generics, &types, ty)); + } + let t_tup = types[t.ty()]; - let t_bundle = self.ty(Ty::Tuple { - members: [self.f.vars[var.var()], t_tup].into(), - }); - let bundle = self.fwd_var(t_bundle); - self.block.fwd.push(Instr { - var: bundle, - expr: Expr::Call { - id: *id, - generics: generics.clone(), - args: args.iter().map(|&arg| self.get_re(arg)).collect(), - }, - }); + let t_bundle = self.ty(Ty::Tuple { + members: [self.f.vars[var.var()], t_tup].into(), + }); + let bundle = self.fwd_var(t_bundle); + self.block.fwd.push(Instr { + var: bundle, + expr: Expr::Call { + id: *id, + generics: generics.clone(), + args: args.iter().map(|&arg| self.get_re(arg)).collect(), + }, + }); - self.block.fwd.push(Instr { - var, - expr: Expr::Member { - tuple: bundle, - member: id::member(0), - }, - }); - self.keep(var); + self.block.fwd.push(Instr { + var, + expr: Expr::Member { + tuple: bundle, + member: id::member(0), + }, + }); + self.keep(var); - let inter_fwd = self.fwd_var(t_tup); - let inter_bwd = self.bwd_var(Some(t_tup)); - self.block.fwd.push(Instr { - var: inter_fwd, - expr: Expr::Member { - tuple: bundle, - member: id::member(1), - }, - }); - self.block.bwd_nonlin.push(Instr { - var: inter_bwd, - expr: Expr::Member { - tuple: self.block.inter_tup, - member: id::member(self.block.inter_mem.len()), - }, - }); - self.block.inter_mem.push(inter_fwd); + let inter_fwd = self.fwd_var(t_tup); + let inter_bwd = self.bwd_var(Some(t_tup)); + self.block.fwd.push(Instr { + var: inter_fwd, + expr: Expr::Member { + tuple: bundle, + member: id::member(1), + }, + }); + self.block.bwd_nonlin.push(Instr { + var: inter_bwd, + expr: Expr::Member { + tuple: self.block.inter_tup, + member: id::member(self.block.inter_mem.len()), + }, + }); + self.block.inter_mem.push(inter_fwd); - let lin = self.accum(var); - let unit = self.bwd_var(Some(self.unit)); - let mut args: Vec<_> = args - .iter() - .map(|&arg| match self.f.types[self.f.vars[arg.var()].ty()] { - Ty::Ref { .. } => self.get_cotan(arg), - _ => self.get_accum(arg), - }) - .collect(); - args.push(lin.cot); - args.push(inter_bwd); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Call { - id: *id, - generics: generics.clone(), - args: args.into(), - }, - }); - self.resolve(lin); + let lin = self.accum(var); + let unit = self.bwd_var(Some(self.unit)); + let mut args: Vec<_> = args + .iter() + .map(|&arg| match self.f.types[self.f.vars[arg.var()].ty()] { + Ty::Ref { .. } => self.get_cotan(arg), + _ => self.get_accum(arg), + }) + .collect(); + args.push(lin.cot); + args.push(inter_bwd); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Call { + id: *id, + generics: generics.clone(), + args: args.into(), + }, + }); + self.resolve(lin); - if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { - self.duals[var.var()] = Some((Src(None), Src(None))); + if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { + self.duals[var.var()] = Some((Src(None), Src(None))); + } } - } + }, Expr::For { arg, body, ret } => { let t_index = self.f.vars[arg.var()]; let t_elem = self.f.vars[ret.var()]; diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index d406f1e..91f9d20 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -125,7 +125,7 @@ struct Pointee { structs: Box<[Option>]>, /// Jacobian-vector product. - jvp: RefCell>>, + jvp: RefCell>>, /// Forward pass of the vector-Jacobian product. fwd: RefCell>>, @@ -262,6 +262,11 @@ impl Func { Ok(to_js_value(&ret)?) } + #[wasm_bindgen(js_name = "setJvp")] + pub fn set_jvp(&self, f: &Func) { + self.rc.as_ref().jvp.replace(Some(Rc::clone(&f.rc))); + } + /// Return a function that computes the Jacobian-vector product of this function. /// /// `re` must be the string ID for the string `"re"` not just in this function, but in every @@ -274,7 +279,7 @@ impl Func { .. } = self.rc.as_ref(); let mut cache = jvp.borrow_mut(); - if let Some(rc) = cache.as_ref().and_then(|weak| weak.upgrade()) { + if let Some(rc) = cache.as_ref().map(Rc::clone) { return Self { rc }; } let rc = @@ -301,9 +306,9 @@ impl Func { bwd: RefCell::new(None), }) } - Inner::Opaque { .. } => todo!(), + Inner::Opaque { .. } => panic!("no JVP provided for opaque function"), }; - *cache = Some(Rc::downgrade(&rc)); + *cache = Some(Rc::clone(&rc)); Self { rc } } @@ -379,7 +384,7 @@ impl Func { }), ) } - Inner::Opaque { .. } => panic!(), + Inner::Opaque { .. } => (Rc::clone(&self.rc), (Rc::clone(&self.rc))), }; *cache_fwd = Some(Rc::downgrade(&rc_fwd)); *cache_bwd = Some(Rc::downgrade(&rc_bwd)); @@ -611,6 +616,7 @@ enum Ty { Unit, Bool, F64, + T64, Fin { size: usize, }, @@ -639,6 +645,7 @@ impl Ty { Ty::Unit => (rose::Ty::Unit, None), Ty::Bool => (rose::Ty::Bool, None), Ty::F64 => (rose::Ty::F64, None), + Ty::T64 => (rose::Ty::F64, None), Ty::Fin { size } => (rose::Ty::Fin { size }, None), Ty::Ref { inner } => (rose::Ty::Ref { inner }, None), Ty::Array { index, elem } => (rose::Ty::Array { index, elem }, None), @@ -694,10 +701,13 @@ impl FuncBuilder { /// Start building a function with the given number of `generics`, all constrained as `Index`. #[wasm_bindgen(constructor)] pub fn new(generics: usize) -> Self { + let mut types = IndexMap::new(); + types.insert(Ty::F64, EnumSet::only(rose::Constraint::Value)); + types.insert(Ty::T64, EnumSet::only(rose::Constraint::Value)); Self { functions: vec![], generics: vec![EnumSet::only(rose::Constraint::Index); generics].into(), - types: IndexMap::new(), + types, vars: vec![], params: vec![], constants: vec![], @@ -906,7 +916,13 @@ impl FuncBuilder { /// Return the ID for the 64-bit floating-point type, creating if needed. #[wasm_bindgen(js_name = "tyF64")] pub fn ty_f64(&mut self) -> usize { - self.newtype(Ty::F64, EnumSet::only(rose::Constraint::Value)) + 0 + } + + /// Return the ID for the 64-bit floating-point tangent type, creating if needed. + #[wasm_bindgen(js_name = "tyT64")] + pub fn ty_t64(&mut self) -> usize { + 1 } /// Return the ID for the type of nonnegative integers less than `size`, creating if needed. @@ -1251,7 +1267,7 @@ impl Block { /// /// Assumes `arg` is defined, in scope, and has 64-bit floating point type. pub fn neg(&mut self, f: &mut FuncBuilder, arg: usize) -> usize { - let t = id::ty(f.ty_f64()); + let t = f.vars[arg].t; let expr = rose::Expr::Unary { op: rose::Unop::Neg, arg: id::var(arg), @@ -1433,7 +1449,7 @@ impl Block { /// /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. pub fn add(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_f64()); + let t = f.vars[left].t; let expr = rose::Expr::Binary { op: rose::Binop::Add, left: id::var(left), @@ -1446,7 +1462,7 @@ impl Block { /// /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. pub fn sub(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_f64()); + let t = f.vars[left].t; let expr = rose::Expr::Binary { op: rose::Binop::Sub, left: id::var(left), @@ -1459,7 +1475,7 @@ impl Block { /// /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. pub fn mul(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_f64()); + let t = f.vars[left].t; let expr = rose::Expr::Binary { op: rose::Binop::Mul, left: id::var(left), @@ -1472,7 +1488,7 @@ impl Block { /// /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. pub fn div(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { - let t = id::ty(f.ty_f64()); + let t = f.vars[left].t; let expr = rose::Expr::Binary { op: rose::Binop::Div, left: id::var(left), diff --git a/packages/core/src/impl.test.ts b/packages/core/src/impl.test.ts index b7f3a2a..cc93cdd 100644 --- a/packages/core/src/impl.test.ts +++ b/packages/core/src/impl.test.ts @@ -57,9 +57,10 @@ describe("pprint", () => { expect(s).toBe( ` T0 = F64 -T1 = Bool +T1 = F64 +T2 = Bool (x0: T0, x1: T0) -> T0 { - x2: T1 = x0 < x1 + x2: T2 = x0 < x1 x3: T0 = x0 * x1 x4: T0 = x1 - x0 x5: T0 = x3 + x0 @@ -85,6 +86,7 @@ T1 = Bool expect(s).toBe( ` T0 = F64 +T1 = F64 (x0: T0) -> T0 { x1: T0 = f0<>(x0) x2: T0 = f1<>(x0) @@ -108,10 +110,11 @@ T0 = F64 expect(s).toBe( ` T0 = F64 -T1 = Bool +T1 = F64 +T2 = Bool (x0: T0) -> T0 { - x1: T1 = true - x2: T1 = not x1 + x1: T2 = true + x2: T2 = not x1 x3: T0 = -x0 x4: T0 = |x3| x5: T0 = sign(x0) @@ -143,24 +146,25 @@ T1 = Bool expect(s).toBe( ` T0 = F64 -T1 = Bool -(x0: T0, x1: T0) -> T1 { - x6: T1 = true - x7: T1 = false +T1 = F64 +T2 = Bool +(x0: T0, x1: T0) -> T2 { + x6: T2 = true + x7: T2 = false x2: T0 = x0 + x1 x3: T0 = x0 - x1 x4: T0 = x0 * x1 x5: T0 = x0 / x1 - x8: T1 = x6 and x7 - x9: T1 = x6 or x7 - x10: T1 = x6 iff x7 - x11: T1 = x6 xor x7 - x12: T1 = x0 != x1 - x13: T1 = x0 < x1 - x14: T1 = x0 <= x1 - x15: T1 = x0 == x1 - x16: T1 = x0 > x1 - x17: T1 = x4 >= x5 + x8: T2 = x6 and x7 + x9: T2 = x6 or x7 + x10: T2 = x6 iff x7 + x11: T2 = x6 xor x7 + x12: T2 = x0 != x1 + x13: T2 = x0 < x1 + x14: T2 = x0 <= x1 + x15: T2 = x0 == x1 + x16: T2 = x0 > x1 + x17: T2 = x4 >= x5 x17 } `.trimStart(), @@ -174,14 +178,15 @@ T1 = Bool const s = pprint(f); expect(s).toBe( ` -T0 = 3 +T0 = F64 T1 = F64 -T2 = [T0]T1 -(x0: T2, x1: T2) -> T2 { - x6: T2 = for x2: T0 { - x3: T1 = x0[x2] - x4: T1 = x1[x2] - x5: T1 = x3 + x4 +T2 = 3 +T3 = [T2]T0 +(x0: T3, x1: T3) -> T3 { + x6: T3 = for x2: T2 { + x3: T0 = x0[x2] + x4: T0 = x1[x2] + x5: T0 = x3 + x4 x5 } x6 diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 2198309..8fdc186 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -30,6 +30,7 @@ const strings = Symbol("strings"); export interface Fn { [inner]: wasm.Func; [strings]: string[]; + jvp: Fn; } /** Property key for a variable ID. */ @@ -65,6 +66,9 @@ export type Bool = boolean | Var; /** An abstract 64-bit floating point number. */ export type Real = number | Var; +/** An abstract 64-bit floating point tangent number. */ +export type Tan = number | Var; + /** An abstract natural number, which can be used to index into a vector. */ type Nat = number | symbol; @@ -215,6 +219,9 @@ export const Bool = Symbol("Bool"); /** The 64-bit floating-point type. */ export const Real = Symbol("Real"); +/** The 64-bit floating-point tangent type. */ +export const Tan = Symbol("Tan"); + /** Representation of the null type. */ type Nulls = typeof Null; @@ -224,6 +231,9 @@ type Bools = typeof Bool; /** Representation of the 64-bit floating point type. */ type Reals = typeof Real; +/** Representation of the 64-bit floating point tangent type. */ +type Tans = typeof Tan; + /** Representation of a bounded index type (it's just the upper bound). */ type Nats = number; @@ -244,6 +254,9 @@ export const Vec = (index: K, elem: V): Vecs => { return { [ind]: index, [elm]: elem }; }; +/** The 128-bit floating-point dual number type. */ +export const Dual = { re: Real, du: Tan } as const; + // TODO: make this locale-independent const compare = (a: string, b: string): number => a.localeCompare(b); @@ -266,6 +279,7 @@ const tyId = (ctx: Context, ty: unknown): number => { if (ty === Null) return ctx.func.tyUnit(); else if (ty === Bool) return ctx.func.tyBool(); else if (ty === Real) return ctx.func.tyF64(); + else if (ty === Tan) return ctx.func.tyT64(); else if (typeof ty === "number") return ctx.func.tyFin(ty); else if (typeof ty === "object" && ty !== null) { if (ind in ty && elm in ty) @@ -385,6 +399,8 @@ type ToSymbolic = T extends Nulls ? Bool : T extends Reals ? Real + : T extends Tans + ? Tan : T extends Nats ? Nat : T extends Vecs @@ -404,6 +420,8 @@ type ToValue = T extends Nulls ? Bool : T extends Reals ? Real + : T extends Tans + ? Tan : T extends Nats ? Nat : T extends Vecs @@ -472,7 +490,7 @@ export const fn = ( }; /** Construct an opaque function whose implementation runs `f`. */ -export const custom = ( +export const opaque = ( params: P, ret: R, f: (...args: JsArgs>) => ToJs>, @@ -485,6 +503,11 @@ export const custom = ( funcs.register(g, func); g[inner] = func; g[strings] = []; // TODO: support tuples in opaque functions + Object.defineProperty(g, "jvp", { + set(h: Fn) { + func.setJvp(h[inner]); + }, + }); return g; }; @@ -603,8 +626,8 @@ export const vjp = ( const tp = g[inner].transpose(); const fwdFunc = tp.fwd()!; const bwdFunc = tp.bwd()!; - const fwd: Fn = { [inner]: fwdFunc, [strings]: [...f[strings]] }; - const bwd: Fn = { [inner]: bwdFunc, [strings]: [...f[strings]] }; + const fwd: Fn = { [inner]: fwdFunc, [strings]: [...f[strings]] } as Fn; + const bwd: Fn = { [inner]: bwdFunc, [strings]: [...f[strings]] } as Fn; funcs.register(fwd, fwdFunc); funcs.register(bwd, bwdFunc); return (arg: A) => { @@ -802,3 +825,36 @@ export const vec = ( const id = block.vec(ctx.func, t, arg, body, out); return idVal(ctx, t, id) as Vec>; }; + +/** Return the variable ID for the abstract floating point tangent `x`. */ +const tanId = (ctx: Context, x: Tan): number => valId(ctx, ctx.func.tyT64(), x); + +/** Return the negative of the abstract tangent `x`. */ +export const negLin = (x: Tan): Tan => { + const ctx = getCtx(); + return newVar(ctx.block.neg(ctx.func, tanId(ctx, x))); +}; + +/** Return the abstract tangent `x` plus the abstract tangent `y`. */ +export const addLin = (x: Tan, y: Tan): Tan => { + const ctx = getCtx(); + return newVar(ctx.block.add(ctx.func, tanId(ctx, x), tanId(ctx, y))); +}; + +/** Return the abstract tangent `x` minus the abstract tangent `y`. */ +export const subLin = (x: Tan, y: Tan): Tan => { + const ctx = getCtx(); + return newVar(ctx.block.sub(ctx.func, tanId(ctx, x), tanId(ctx, y))); +}; + +/** Return the abstract tangent `x` times the abstract number `y`. */ +export const mulLin = (x: Tan, y: Real): Tan => { + const ctx = getCtx(); + return newVar(ctx.block.mul(ctx.func, tanId(ctx, x), realId(ctx, y))); +}; + +/** Return the abstract tangent `x` divided by the abstract number `y`. */ +export const divLin = (x: Tan, y: Real): Tan => { + const ctx = getCtx(); + return newVar(ctx.block.div(ctx.func, tanId(ctx, x), realId(ctx, y))); +}; diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 7bc2e0f..2128db6 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -1,17 +1,20 @@ import { describe, expect, test } from "vitest"; import { Bool, + Dual, Null, Real, Vec, add, - custom, div, fn, gt, interp, jvp, mul, + mulLin, + neg, + opaque, or, select, sign, @@ -326,14 +329,14 @@ describe("valid", () => { expect(g({ x: 42 })).toBe(42); }); - test("custom unary function", () => { - const log = custom([Real], Real, Math.log); + test("opaque unary function", () => { + const log = opaque([Real], Real, Math.log); const f = interp(log); expect(f(Math.PI)).toBe(1.1447298858494002); }); - test("custom binary function", () => { - const pow = custom([Real, Real], Real, Math.pow); + test("opaque binary function", () => { + const pow = opaque([Real, Real], Real, Math.pow); const f = interp(pow); expect(f(Math.E, Math.PI)).toBe(23.140692632779263); }); @@ -501,4 +504,29 @@ describe("valid", () => { ), ).toEqual({ p: true, x: 13, y: 0, z: 7 }); }); + + test("opaque functions with derivatives", () => { + const grad = (f: any) => fn([Real], Real, (x) => vjp(f)(x).grad(1) as Real); + + const sin = opaque([Real], Real, Math.sin); + const cos = opaque([Real], Real, Math.cos); + + sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + return { re: sin(x), du: mulLin(dx, cos(x)) }; + }); + cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + return { re: cos(x), du: mulLin(dx, neg(sin(x))) }; + }); + + let f = sin; + expect(interp(f)(1)).toBeCloseTo(Math.sin(1)); + f = grad(f); + expect(interp(f)(1)).toBeCloseTo(Math.cos(1)); + f = grad(f); + expect(interp(f)(1)).toBeCloseTo(-Math.sin(1)); + f = grad(f); + expect(interp(f)(1)).toBeCloseTo(-Math.cos(1)); + f = grad(f); + expect(interp(f)(1)).toBeCloseTo(Math.sin(1)); + }); }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index fad6ed2..1888868 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,13 +1,16 @@ export { Bool, + Dual, Null, Real, + Tan, Vec, abs, add, + addLin, and, - custom, div, + divLin, eq, fn, geq, @@ -18,14 +21,18 @@ export { leq, lt, mul, + mulLin, neg, + negLin, neq, not, + opaque, or, select, sign, sqrt, sub, + subLin, vec, vjp, xor, From f0c083a998a6f73782aa58744cc3ea20240cc3ae Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Tue, 19 Sep 2023 20:36:16 -0400 Subject: [PATCH 2/3] Let the test pass --- packages/core/src/index.test.ts | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 2128db6..37cc031 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -524,9 +524,12 @@ describe("valid", () => { expect(interp(f)(1)).toBeCloseTo(Math.cos(1)); f = grad(f); expect(interp(f)(1)).toBeCloseTo(-Math.sin(1)); + + f = cos; + expect(interp(f)(1)).toBeCloseTo(Math.cos(1)); f = grad(f); - expect(interp(f)(1)).toBeCloseTo(-Math.cos(1)); + expect(interp(f)(1)).toBeCloseTo(-Math.sin(1)); f = grad(f); - expect(interp(f)(1)).toBeCloseTo(Math.sin(1)); + expect(interp(f)(1)).toBeCloseTo(-Math.cos(1)); }); }); From 1e616d36ecfc7986d22a380cc8763bc3fb8e4043 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Tue, 19 Sep 2023 21:04:11 -0400 Subject: [PATCH 3/3] Allow custom JVP on a transparent function --- packages/core/src/impl.ts | 40 ++++++++++++++++++--------------- packages/core/src/index.test.ts | 11 +++++++++ 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 8fdc186..687481c 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -26,13 +26,27 @@ export const inner = Symbol("inner"); /** Property key for a `Fn`'s string array for resolving struct key names. */ const strings = Symbol("strings"); -/** An abstract function. */ -export interface Fn { +interface FnBase { [inner]: wasm.Func; [strings]: string[]; - jvp: Fn; } +/** An abstract function. */ +export interface Fn extends FnBase { + jvp?: Fn; +} + +/** Adds `f` to the registry, mutates it into a full `Fn`, then returns it. */ +const makeFn = (f: FnBase): Fn => { + funcs.register(f, f[inner]); + Object.defineProperty(f, "jvp", { + set(g: Fn) { + f[inner].setJvp(g[inner]); + }, + }); + return f as Fn; +}; + /** Property key for a variable ID. */ const variable = Symbol("variable"); @@ -483,10 +497,9 @@ export const fn = ( const g: any = (...args: any): any => // TODO: support generics call(g, new Uint32Array(), args); - funcs.register(g, func); g[inner] = func; g[strings] = strs; - return g; + return makeFn(g) as any; }; /** Construct an opaque function whose implementation runs `f`. */ @@ -500,15 +513,9 @@ export const opaque = ( const g: any = (...args: any): any => // TODO: support generics call(g, new Uint32Array(), args); - funcs.register(g, func); g[inner] = func; g[strings] = []; // TODO: support tuples in opaque functions - Object.defineProperty(g, "jvp", { - set(h: Fn) { - func.setJvp(h[inner]); - }, - }); - return g; + return makeFn(g) as any; }; /** A concrete value. */ @@ -612,10 +619,9 @@ export const jvp = ( const g: any = (...args: any): any => // TODO: support generics call(g, new Uint32Array(), args); - funcs.register(g, func); g[inner] = func; g[strings] = strs; - return g; + return makeFn(g) as any; }; /** Construct a closure that computes the Jacobian-vector product of `f`. */ @@ -626,10 +632,8 @@ export const vjp = ( const tp = g[inner].transpose(); const fwdFunc = tp.fwd()!; const bwdFunc = tp.bwd()!; - const fwd: Fn = { [inner]: fwdFunc, [strings]: [...f[strings]] } as Fn; - const bwd: Fn = { [inner]: bwdFunc, [strings]: [...f[strings]] } as Fn; - funcs.register(fwd, fwdFunc); - funcs.register(bwd, bwdFunc); + const fwd = makeFn({ [inner]: fwdFunc, [strings]: [...f[strings]] }); + const bwd = makeFn({ [inner]: bwdFunc, [strings]: [...f[strings]] }); return (arg: A) => { const ctx = getCtx(); const strs = intern(ctx, fwd[strings]); diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 37cc031..aef3457 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -357,6 +357,17 @@ describe("valid", () => { expect(g({ re: 2, du: 3 })).toEqual({ re: 2097152, du: 3145728 }); }); + test("custom JVP", () => { + const max = fn([Real, Real], Real, (x, y) => select(gt(x, y), Real, x, y)); + const f = fn([Real], Real, (x) => sqrt(x)); + const epsilon = 1e-5; + f.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = sqrt(x); + return { re: y, du: mulLin(dx, div(1 / 2, max(epsilon, y))) }; + }); + expect(interp(jvp(f))({ re: 0, du: 1 }).du).toBeCloseTo(50000); + }); + test("VJP", () => { const f = fn([Vec(2, Real)], Real, (v) => mul(v[0], v[1])); const g = fn([], Vec(3, Real), () => {