From 6a3f6e3b614006e68eb8c41c0b6047d6b4df1d32 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Sun, 17 Sep 2023 14:59:22 -0400 Subject: [PATCH] Write a few more comments --- crates/transpose/src/lib.rs | 5 +++++ crates/web/src/lib.rs | 18 ++++++++++++++++++ packages/core/src/impl.ts | 1 + 3 files changed, 24 insertions(+) diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index fc374e0..17d6219 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -266,6 +266,7 @@ impl<'a> Transpose<'a> { self.block.inter_mem.push(var); } + /// Create a non-primitive accumulator for `shape`; return it along with its eventual cotangent. fn accum(&mut self, shape: id::Var) -> Lin { let t_cot = self.f.vars[shape.var()]; let t_acc = self.ty(Ty::Ref { inner: t_cot }); @@ -280,6 +281,7 @@ impl<'a> Transpose<'a> { Lin { acc, cot } } + /// Create a primitive accumulator for the given `tangent`, using `self.real_shape`. fn calc(&mut self, tangent: id::Var) -> Lin { let t_cot = self.f.vars[tangent.var()]; let t_acc = self.ty(Ty::Ref { inner: t_cot }); @@ -296,6 +298,7 @@ impl<'a> Transpose<'a> { Lin { acc, cot } } + /// Resolve the given accumulator. fn resolve(&mut self, lin: Lin) { self.block.bwd_lin.push(Instr { var: lin.cot, @@ -303,6 +306,7 @@ impl<'a> Transpose<'a> { }) } + /// Process `block` and return the type and forward variable for the intermediate values tuple. fn block(&mut self, block: &[Instr]) -> (id::Ty, id::Var) { for instr in block.iter() { self.instr(instr.var, &instr.expr); @@ -322,6 +326,7 @@ impl<'a> Transpose<'a> { (t, var) } + /// Process the instruction with the given `var` and `expr`. fn instr(&mut self, var: id::Var, expr: &Expr) { match expr { Expr::Unit => { diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 0f31044..d406f1e 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -124,10 +124,13 @@ struct Pointee { /// The actual strings are stored in JavaScript. structs: Box<[Option>]>, + /// Jacobian-vector product. jvp: RefCell>>, + /// Forward pass of the vector-Jacobian product. fwd: RefCell>>, + /// Backward pass of the vector-Jacobian product. bwd: RefCell>>, } @@ -304,6 +307,7 @@ impl Func { Self { rc } } + /// Return the forward and backward pass of the transpose of this function. fn transpose_pair(&self) -> (Self, Self) { let Pointee { inner, @@ -382,6 +386,9 @@ impl Func { (Self { rc: rc_fwd }, Self { rc: rc_bwd }) } + /// Return the transpose of this function. + /// + /// Assumes that this function has already been computed as the `jvp` of another function. pub fn transpose(&self) -> Transpose { let (fwd, bwd) = self.transpose_pair(); Transpose { @@ -391,6 +398,7 @@ impl Func { } } +/// A temporary object to hold the two passes of a transposed function before they are destructured. #[wasm_bindgen] pub struct Transpose { fwd: Option, @@ -399,10 +407,12 @@ pub struct Transpose { #[wasm_bindgen] impl Transpose { + /// Return the forward pass. pub fn fwd(&mut self) -> Option { self.fwd.take() } + /// Return the backward pass. pub fn bwd(&mut self) -> Option { self.bwd.take() } @@ -1545,6 +1555,10 @@ impl Block { self.instr(f, id::ty(t), expr) } + /// Return the variable ID for a new instruction defining an accumulator with the given `shape`. + /// + /// Assumes `shape` is defined and in scope, and that `t` is the ID of a reference type whose + /// inner type is the same as the type of `shape`. pub fn accum(&mut self, f: &mut FuncBuilder, t: usize, shape: usize) -> usize { let expr = rose::Expr::Accum { shape: id::var(shape), @@ -1552,6 +1566,10 @@ impl Block { self.instr(f, id::ty(t), expr) } + /// Return the variable ID for a new instruction resolving the given accumulator `var`. + /// + /// Assumes `var` is defined and in scope, and that `t` is the inner type of the reference type + /// for `var`. pub fn resolve(&mut self, f: &mut FuncBuilder, t: usize, var: usize) -> usize { let expr = rose::Expr::Resolve { var: id::var(var) }; self.instr(f, id::ty(t), expr) diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 16f5c83..2198309 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -595,6 +595,7 @@ export const jvp = ( return g; }; +/** Construct a closure that computes the Jacobian-vector product of `f`. */ export const vjp = ( f: Fn & ((arg: A) => R), ): ((arg: A) => { ret: R; grad: (cot: R) => A }) => {