Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom JVPs #105

Merged
merged 3 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 80 additions & 66 deletions crates/transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -773,79 +773,93 @@ impl<'a> Transpose<'a> {
}
}

Expr::Call { id, generics, args } => {
let (dep_types, t) = self.deps[id.func()];
let mut types = vec![];
for ty in dep_types {
types.push(self.translate(generics, &types, ty));
Expr::Call { id, generics, args } => match self.f.vars[var.var()] {
REAL => {
self.block.fwd.push(Instr {
var,
expr: Expr::Call {
id: *id,
generics: generics.clone(),
args: args.iter().map(|&arg| self.get_prim(arg)).collect(),
},
});
self.keep(var);
self.prims[var.var()] = Some(Src(None));
}
let t_tup = types[t.ty()];
_ => {
let (dep_types, t) = self.deps[id.func()];
let mut types = vec![];
for ty in dep_types {
types.push(self.translate(generics, &types, ty));
}
let t_tup = types[t.ty()];

let t_bundle = self.ty(Ty::Tuple {
members: [self.f.vars[var.var()], t_tup].into(),
});
let bundle = self.fwd_var(t_bundle);
self.block.fwd.push(Instr {
var: bundle,
expr: Expr::Call {
id: *id,
generics: generics.clone(),
args: args.iter().map(|&arg| self.get_re(arg)).collect(),
},
});
let t_bundle = self.ty(Ty::Tuple {
members: [self.f.vars[var.var()], t_tup].into(),
});
let bundle = self.fwd_var(t_bundle);
self.block.fwd.push(Instr {
var: bundle,
expr: Expr::Call {
id: *id,
generics: generics.clone(),
args: args.iter().map(|&arg| self.get_re(arg)).collect(),
},
});

self.block.fwd.push(Instr {
var,
expr: Expr::Member {
tuple: bundle,
member: id::member(0),
},
});
self.keep(var);
self.block.fwd.push(Instr {
var,
expr: Expr::Member {
tuple: bundle,
member: id::member(0),
},
});
self.keep(var);

let inter_fwd = self.fwd_var(t_tup);
let inter_bwd = self.bwd_var(Some(t_tup));
self.block.fwd.push(Instr {
var: inter_fwd,
expr: Expr::Member {
tuple: bundle,
member: id::member(1),
},
});
self.block.bwd_nonlin.push(Instr {
var: inter_bwd,
expr: Expr::Member {
tuple: self.block.inter_tup,
member: id::member(self.block.inter_mem.len()),
},
});
self.block.inter_mem.push(inter_fwd);
let inter_fwd = self.fwd_var(t_tup);
let inter_bwd = self.bwd_var(Some(t_tup));
self.block.fwd.push(Instr {
var: inter_fwd,
expr: Expr::Member {
tuple: bundle,
member: id::member(1),
},
});
self.block.bwd_nonlin.push(Instr {
var: inter_bwd,
expr: Expr::Member {
tuple: self.block.inter_tup,
member: id::member(self.block.inter_mem.len()),
},
});
self.block.inter_mem.push(inter_fwd);

let lin = self.accum(var);
let unit = self.bwd_var(Some(self.unit));
let mut args: Vec<_> = args
.iter()
.map(|&arg| match self.f.types[self.f.vars[arg.var()].ty()] {
Ty::Ref { .. } => self.get_cotan(arg),
_ => self.get_accum(arg),
})
.collect();
args.push(lin.cot);
args.push(inter_bwd);
self.block.bwd_lin.push(Instr {
var: unit,
expr: Expr::Call {
id: *id,
generics: generics.clone(),
args: args.into(),
},
});
self.resolve(lin);
let lin = self.accum(var);
let unit = self.bwd_var(Some(self.unit));
let mut args: Vec<_> = args
.iter()
.map(|&arg| match self.f.types[self.f.vars[arg.var()].ty()] {
Ty::Ref { .. } => self.get_cotan(arg),
_ => self.get_accum(arg),
})
.collect();
args.push(lin.cot);
args.push(inter_bwd);
self.block.bwd_lin.push(Instr {
var: unit,
expr: Expr::Call {
id: *id,
generics: generics.clone(),
args: args.into(),
},
});
self.resolve(lin);

if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] {
self.duals[var.var()] = Some((Src(None), Src(None)));
if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] {
self.duals[var.var()] = Some((Src(None), Src(None)));
}
}
}
},
Expr::For { arg, body, ret } => {
let t_index = self.f.vars[arg.var()];
let t_elem = self.f.vars[ret.var()];
Expand Down
40 changes: 28 additions & 12 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ struct Pointee {
structs: Box<[Option<Box<[usize]>>]>,

/// Jacobian-vector product.
jvp: RefCell<Option<Weak<Pointee>>>,
jvp: RefCell<Option<Rc<Pointee>>>,

/// Forward pass of the vector-Jacobian product.
fwd: RefCell<Option<Weak<Pointee>>>,
Expand Down Expand Up @@ -262,6 +262,11 @@ impl Func {
Ok(to_js_value(&ret)?)
}

#[wasm_bindgen(js_name = "setJvp")]
pub fn set_jvp(&self, f: &Func) {
self.rc.as_ref().jvp.replace(Some(Rc::clone(&f.rc)));
}

/// Return a function that computes the Jacobian-vector product of this function.
///
/// `re` must be the string ID for the string `"re"` not just in this function, but in every
Expand All @@ -274,7 +279,7 @@ impl Func {
..
} = self.rc.as_ref();
let mut cache = jvp.borrow_mut();
if let Some(rc) = cache.as_ref().and_then(|weak| weak.upgrade()) {
if let Some(rc) = cache.as_ref().map(Rc::clone) {
return Self { rc };
}
let rc =
Expand All @@ -301,9 +306,9 @@ impl Func {
bwd: RefCell::new(None),
})
}
Inner::Opaque { .. } => todo!(),
Inner::Opaque { .. } => panic!("no JVP provided for opaque function"),
};
*cache = Some(Rc::downgrade(&rc));
*cache = Some(Rc::clone(&rc));
Self { rc }
}

Expand Down Expand Up @@ -379,7 +384,7 @@ impl Func {
}),
)
}
Inner::Opaque { .. } => panic!(),
Inner::Opaque { .. } => (Rc::clone(&self.rc), (Rc::clone(&self.rc))),
};
*cache_fwd = Some(Rc::downgrade(&rc_fwd));
*cache_bwd = Some(Rc::downgrade(&rc_bwd));
Expand Down Expand Up @@ -611,6 +616,7 @@ enum Ty {
Unit,
Bool,
F64,
T64,
Fin {
size: usize,
},
Expand Down Expand Up @@ -639,6 +645,7 @@ impl Ty {
Ty::Unit => (rose::Ty::Unit, None),
Ty::Bool => (rose::Ty::Bool, None),
Ty::F64 => (rose::Ty::F64, None),
Ty::T64 => (rose::Ty::F64, None),
Ty::Fin { size } => (rose::Ty::Fin { size }, None),
Ty::Ref { inner } => (rose::Ty::Ref { inner }, None),
Ty::Array { index, elem } => (rose::Ty::Array { index, elem }, None),
Expand Down Expand Up @@ -694,10 +701,13 @@ impl FuncBuilder {
/// Start building a function with the given number of `generics`, all constrained as `Index`.
#[wasm_bindgen(constructor)]
pub fn new(generics: usize) -> Self {
let mut types = IndexMap::new();
types.insert(Ty::F64, EnumSet::only(rose::Constraint::Value));
types.insert(Ty::T64, EnumSet::only(rose::Constraint::Value));
Self {
functions: vec![],
generics: vec![EnumSet::only(rose::Constraint::Index); generics].into(),
types: IndexMap::new(),
types,
vars: vec![],
params: vec![],
constants: vec![],
Expand Down Expand Up @@ -906,7 +916,13 @@ impl FuncBuilder {
/// Return the ID for the 64-bit floating-point type, creating if needed.
#[wasm_bindgen(js_name = "tyF64")]
pub fn ty_f64(&mut self) -> usize {
self.newtype(Ty::F64, EnumSet::only(rose::Constraint::Value))
0
}

/// Return the ID for the 64-bit floating-point tangent type, creating if needed.
#[wasm_bindgen(js_name = "tyT64")]
pub fn ty_t64(&mut self) -> usize {
1
}

/// Return the ID for the type of nonnegative integers less than `size`, creating if needed.
Expand Down Expand Up @@ -1251,7 +1267,7 @@ impl Block {
///
/// Assumes `arg` is defined, in scope, and has 64-bit floating point type.
pub fn neg(&mut self, f: &mut FuncBuilder, arg: usize) -> usize {
let t = id::ty(f.ty_f64());
let t = f.vars[arg].t;
let expr = rose::Expr::Unary {
op: rose::Unop::Neg,
arg: id::var(arg),
Expand Down Expand Up @@ -1433,7 +1449,7 @@ impl Block {
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn add(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_f64());
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Add,
left: id::var(left),
Expand All @@ -1446,7 +1462,7 @@ impl Block {
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn sub(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_f64());
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Sub,
left: id::var(left),
Expand All @@ -1459,7 +1475,7 @@ impl Block {
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn mul(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_f64());
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Mul,
left: id::var(left),
Expand All @@ -1472,7 +1488,7 @@ impl Block {
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
pub fn div(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_f64());
let t = f.vars[left].t;
let expr = rose::Expr::Binary {
op: rose::Binop::Div,
left: id::var(left),
Expand Down
Loading