Skip to content

Commit

Permalink
Support custom JVPs
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Sep 19, 2023
1 parent 2424d58 commit 31e8d0f
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 113 deletions.
146 changes: 80 additions & 66 deletions crates/transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -773,79 +773,93 @@ impl<'a> Transpose<'a> {
}
}

Expr::Call { id, generics, args } => {
let (dep_types, t) = self.deps[id.func()];
let mut types = vec![];
for ty in dep_types {
types.push(self.translate(generics, &types, ty));
Expr::Call { id, generics, args } => match self.f.vars[var.var()] {
REAL => {
self.block.fwd.push(Instr {
var,
expr: Expr::Call {
id: *id,
generics: generics.clone(),
args: args.iter().map(|&arg| self.get_prim(arg)).collect(),
},
});
self.keep(var);
self.prims[var.var()] = Some(Src(None));
}
let t_tup = types[t.ty()];
_ => {
let (dep_types, t) = self.deps[id.func()];
let mut types = vec![];
for ty in dep_types {
types.push(self.translate(generics, &types, ty));
}
let t_tup = types[t.ty()];

let t_bundle = self.ty(Ty::Tuple {
members: [self.f.vars[var.var()], t_tup].into(),
});
let bundle = self.fwd_var(t_bundle);
self.block.fwd.push(Instr {
var: bundle,
expr: Expr::Call {
id: *id,
generics: generics.clone(),
args: args.iter().map(|&arg| self.get_re(arg)).collect(),
},
});
let t_bundle = self.ty(Ty::Tuple {
members: [self.f.vars[var.var()], t_tup].into(),
});
let bundle = self.fwd_var(t_bundle);
self.block.fwd.push(Instr {
var: bundle,
expr: Expr::Call {
id: *id,
generics: generics.clone(),
args: args.iter().map(|&arg| self.get_re(arg)).collect(),
},
});

self.block.fwd.push(Instr {
var,
expr: Expr::Member {
tuple: bundle,
member: id::member(0),
},
});
self.keep(var);
self.block.fwd.push(Instr {
var,
expr: Expr::Member {
tuple: bundle,
member: id::member(0),
},
});
self.keep(var);

let inter_fwd = self.fwd_var(t_tup);
let inter_bwd = self.bwd_var(Some(t_tup));
self.block.fwd.push(Instr {
var: inter_fwd,
expr: Expr::Member {
tuple: bundle,
member: id::member(1),
},
});
self.block.bwd_nonlin.push(Instr {
var: inter_bwd,
expr: Expr::Member {
tuple: self.block.inter_tup,
member: id::member(self.block.inter_mem.len()),
},
});
self.block.inter_mem.push(inter_fwd);
let inter_fwd = self.fwd_var(t_tup);
let inter_bwd = self.bwd_var(Some(t_tup));
self.block.fwd.push(Instr {
var: inter_fwd,
expr: Expr::Member {
tuple: bundle,
member: id::member(1),
},
});
self.block.bwd_nonlin.push(Instr {
var: inter_bwd,
expr: Expr::Member {
tuple: self.block.inter_tup,
member: id::member(self.block.inter_mem.len()),
},
});
self.block.inter_mem.push(inter_fwd);

let lin = self.accum(var);
let unit = self.bwd_var(Some(self.unit));
let mut args: Vec<_> = args
.iter()
.map(|&arg| match self.f.types[self.f.vars[arg.var()].ty()] {
Ty::Ref { .. } => self.get_cotan(arg),
_ => self.get_accum(arg),
})
.collect();
args.push(lin.cot);
args.push(inter_bwd);
self.block.bwd_lin.push(Instr {
var: unit,
expr: Expr::Call {
id: *id,
generics: generics.clone(),
args: args.into(),
},
});
self.resolve(lin);
let lin = self.accum(var);
let unit = self.bwd_var(Some(self.unit));
let mut args: Vec<_> = args
.iter()
.map(|&arg| match self.f.types[self.f.vars[arg.var()].ty()] {
Ty::Ref { .. } => self.get_cotan(arg),
_ => self.get_accum(arg),
})
.collect();
args.push(lin.cot);
args.push(inter_bwd);
self.block.bwd_lin.push(Instr {
var: unit,
expr: Expr::Call {
id: *id,
generics: generics.clone(),
args: args.into(),
},
});
self.resolve(lin);

if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] {
self.duals[var.var()] = Some((Src(None), Src(None)));
if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] {
self.duals[var.var()] = Some((Src(None), Src(None)));
}
}
}
},
Expr::For { arg, body, ret } => {
let t_index = self.f.vars[arg.var()];
let t_elem = self.f.vars[ret.var()];
Expand Down
40 changes: 28 additions & 12 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ struct Pointee {
structs: Box<[Option<Box<[usize]>>]>,

/// Jacobian-vector product.
jvp: RefCell<Option<Weak<Pointee>>>,
jvp: RefCell<Option<Rc<Pointee>>>,

/// Forward pass of the vector-Jacobian product.
fwd: RefCell<Option<Weak<Pointee>>>,
Expand Down Expand Up @@ -262,6 +262,11 @@ impl Func {
Ok(to_js_value(&ret)?)
}

#[wasm_bindgen(js_name = "setJvp")]
pub fn set_jvp(&self, f: &Func) {
self.rc.as_ref().jvp.replace(Some(Rc::clone(&f.rc)));
}

/// Return a function that computes the Jacobian-vector product of this function.
///
/// `re` must be the string ID for the string `"re"` not just in this function, but in every
Expand All @@ -274,7 +279,7 @@ impl Func {
..
} = self.rc.as_ref();
let mut cache = jvp.borrow_mut();
if let Some(rc) = cache.as_ref().and_then(|weak| weak.upgrade()) {
if let Some(rc) = cache.as_ref().map(Rc::clone) {
return Self { rc };
}
let rc =
Expand All @@ -301,9 +306,9 @@ impl Func {
bwd: RefCell::new(None),
})
}
Inner::Opaque { .. } => todo!(),
Inner::Opaque { .. } => panic!("no JVP provided for opaque function"),
};
*cache = Some(Rc::downgrade(&rc));
*cache = Some(Rc::clone(&rc));
Self { rc }
}

Expand Down Expand Up @@ -379,7 +384,7 @@ impl Func {
}),
)
}
Inner::Opaque { .. } => panic!(),
Inner::Opaque { .. } => (Rc::clone(&self.rc), (Rc::clone(&self.rc))),
};
*cache_fwd = Some(Rc::downgrade(&rc_fwd));
*cache_bwd = Some(Rc::downgrade(&rc_bwd));
Expand Down Expand Up @@ -611,6 +616,7 @@ enum Ty {
Unit,
Bool,
F64,
T64,
Fin {
size: usize,
},
Expand Down Expand Up @@ -639,6 +645,7 @@ impl Ty {
Ty::Unit => (rose::Ty::Unit, None),
Ty::Bool => (rose::Ty::Bool, None),
Ty::F64 => (rose::Ty::F64, None),
Ty::T64 => (rose::Ty::F64, None),
Ty::Fin { size } => (rose::Ty::Fin { size }, None),
Ty::Ref { inner } => (rose::Ty::Ref { inner }, None),
Ty::Array { index, elem } => (rose::Ty::Array { index, elem }, None),
Expand Down Expand Up @@ -694,10 +701,13 @@ impl FuncBuilder {
/// Start building a function with the given number of `generics`, all constrained as `Index`.
#[wasm_bindgen(constructor)]
pub fn new(generics: usize) -> Self {
let mut types = IndexMap::new();
types.insert(Ty::F64, EnumSet::only(rose::Constraint::Value));
types.insert(Ty::T64, EnumSet::only(rose::Constraint::Value));
Self {
functions: vec![],
generics: vec![EnumSet::only(rose::Constraint::Index); generics].into(),
types: IndexMap::new(),
types,
vars: vec![],
params: vec![],
constants: vec![],
Expand Down Expand Up @@ -906,7 +916,13 @@ impl FuncBuilder {
/// Return the ID for the 64-bit floating-point type, creating if needed.
#[wasm_bindgen(js_name = "tyF64")]
pub fn ty_f64(&mut self) -> usize {
self.newtype(Ty::F64, EnumSet::only(rose::Constraint::Value))
0
}

/// Return the ID for the 64-bit floating-point tangent type, creating if needed.
#[wasm_bindgen(js_name = "tyT64")]
pub fn ty_t64(&mut self) -> usize {
1
}

/// Return the ID for the type of nonnegative integers less than `size`, creating if needed.
Expand Down Expand Up @@ -1251,7 +1267,7 @@ impl Block {
///
/// Assumes `arg` is defined, in scope, and has 64-bit floating point type.
pub fn neg(&mut self, f: &mut FuncBuilder, arg: usize) -> usize {
let t = id::ty(f.ty_f64());
let t = f.vars[arg].t;
let expr = rose::Expr::Unary {
op: rose::Unop::Neg,
arg: id::var(arg),
Expand Down Expand Up @@ -1433,7 +1449,7 @@ impl Block {
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn add(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_f64());
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Add,
left: id::var(left),
Expand All @@ -1446,7 +1462,7 @@ impl Block {
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn sub(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_f64());
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Sub,
left: id::var(left),
Expand All @@ -1459,7 +1475,7 @@ impl Block {
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn mul(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_f64());
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Mul,
left: id::var(left),
Expand All @@ -1472,7 +1488,7 @@ impl Block {
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn div(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_f64());
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Div,
left: id::var(left),
Expand Down
Loading

0 comments on commit 31e8d0f

Please sign in to comment.