Skip to content

Commit

Permalink
Write a few more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Sep 17, 2023
1 parent 0cb1347 commit 6a3f6e3
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 0 deletions.
5 changes: 5 additions & 0 deletions crates/transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ impl<'a> Transpose<'a> {
self.block.inter_mem.push(var);
}

/// Create a non-primitive accumulator for `shape`; return it along with its eventual cotangent.
fn accum(&mut self, shape: id::Var) -> Lin {
let t_cot = self.f.vars[shape.var()];
let t_acc = self.ty(Ty::Ref { inner: t_cot });
Expand All @@ -280,6 +281,7 @@ impl<'a> Transpose<'a> {
Lin { acc, cot }
}

/// Create a primitive accumulator for the given `tangent`, using `self.real_shape`.
fn calc(&mut self, tangent: id::Var) -> Lin {
let t_cot = self.f.vars[tangent.var()];
let t_acc = self.ty(Ty::Ref { inner: t_cot });
Expand All @@ -296,13 +298,15 @@ impl<'a> Transpose<'a> {
Lin { acc, cot }
}

/// Resolve the given accumulator.
fn resolve(&mut self, lin: Lin) {
self.block.bwd_lin.push(Instr {
var: lin.cot,
expr: Expr::Resolve { var: lin.acc },
})
}

/// Process `block` and return the type and forward variable for the intermediate values tuple.
fn block(&mut self, block: &[Instr]) -> (id::Ty, id::Var) {
for instr in block.iter() {
self.instr(instr.var, &instr.expr);
Expand All @@ -322,6 +326,7 @@ impl<'a> Transpose<'a> {
(t, var)
}

/// Process the instruction with the given `var` and `expr`.
fn instr(&mut self, var: id::Var, expr: &Expr) {
match expr {
Expr::Unit => {
Expand Down
18 changes: 18 additions & 0 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,13 @@ struct Pointee {
/// The actual strings are stored in JavaScript.
structs: Box<[Option<Box<[usize]>>]>,

/// Jacobian-vector product.
jvp: RefCell<Option<Weak<Pointee>>>,

/// Forward pass of the vector-Jacobian product.
fwd: RefCell<Option<Weak<Pointee>>>,

/// Backward pass of the vector-Jacobian product.
bwd: RefCell<Option<Weak<Pointee>>>,
}

Expand Down Expand Up @@ -304,6 +307,7 @@ impl Func {
Self { rc }
}

/// Return the forward and backward pass of the transpose of this function.
fn transpose_pair(&self) -> (Self, Self) {
let Pointee {
inner,
Expand Down Expand Up @@ -382,6 +386,9 @@ impl Func {
(Self { rc: rc_fwd }, Self { rc: rc_bwd })
}

/// Return the transpose of this function.
///
/// Assumes that this function has already been computed as the `jvp` of another function.
pub fn transpose(&self) -> Transpose {
let (fwd, bwd) = self.transpose_pair();
Transpose {
Expand All @@ -391,6 +398,7 @@ impl Func {
}
}

/// A temporary object to hold the two passes of a transposed function before they are destructured.
#[wasm_bindgen]
pub struct Transpose {
fwd: Option<Func>,
Expand All @@ -399,10 +407,12 @@ pub struct Transpose {

#[wasm_bindgen]
impl Transpose {
/// Return the forward pass.
pub fn fwd(&mut self) -> Option<Func> {
self.fwd.take()
}

/// Return the backward pass.
pub fn bwd(&mut self) -> Option<Func> {
self.bwd.take()
}
Expand Down Expand Up @@ -1545,13 +1555,21 @@ impl Block {
self.instr(f, id::ty(t), expr)
}

/// Return the variable ID for a new instruction defining an accumulator with the given `shape`.
///
/// Assumes `shape` is defined and in scope, and that `t` is the ID of a reference type whose
/// inner type is the same as the type of `shape`.
pub fn accum(&mut self, f: &mut FuncBuilder, t: usize, shape: usize) -> usize {
let expr = rose::Expr::Accum {
shape: id::var(shape),
};
self.instr(f, id::ty(t), expr)
}

/// Return the variable ID for a new instruction resolving the given accumulator `var`.
///
/// Assumes `var` is defined and in scope, and that `t` is the inner type of the reference type
/// for `var`.
pub fn resolve(&mut self, f: &mut FuncBuilder, t: usize, var: usize) -> usize {
let expr = rose::Expr::Resolve { var: id::var(var) };
self.instr(f, id::ty(t), expr)
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ export const jvp = <const A extends readonly any[], const R>(
return g;
};

/** Construct a closure that computes the Jacobian-vector product of `f`. */
export const vjp = <const A, const R>(
f: Fn & ((arg: A) => R),
): ((arg: A) => { ret: R; grad: (cot: R) => A }) => {
Expand Down

0 comments on commit 6a3f6e3

Please sign in to comment.