diff --git a/Cargo.lock b/Cargo.lock index 9f71a70..0079f56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -192,6 +192,7 @@ name = "rose-transpose" version = "0.0.0" dependencies = [ "enumset", + "indexmap", "rose", ] diff --git a/crates/transpose/Cargo.toml b/crates/transpose/Cargo.toml index c122e43..1030629 100644 --- a/crates/transpose/Cargo.toml +++ b/crates/transpose/Cargo.toml @@ -5,4 +5,5 @@ edition = "2021" [dependencies] enumset = "1" +indexmap = "2" rose = { path = "../core" } diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index a3473f1..17d6219 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -1,57 +1,41 @@ +use indexmap::{indexset, IndexSet}; use rose::{id, Binop, Expr, Func, Instr, Ty, Unop}; -use std::mem::{swap, take}; +use std::mem::{replace, swap, take}; +/// By convention, the first type in a function to be transposed must be the nonlinear `F64`. const REAL: id::Ty = id::ty(0); + +/// By convention, the second type in a function to be transposed must be the linear `F64`. const DUAL: id::Ty = id::ty(1); +/// Return true iff `t` is the type ID of a primitive type in a function to be transposed. +/// +/// In this module, "primitive" specifically means a linear or nonlinear `F64` type, and +/// specifically excludes other types that might be considered primitive, such as `Unit` or `Bool`. fn is_primitive(t: id::Ty) -> bool { t == REAL || t == DUAL } -/// Type ID of a variable while building up the backward pass. -enum BwdTy { - /// We already know the type ID of this variable. - /// - /// Usually this means the variable is directly mapped from a variable in the original function. - Known(id::Ty), - - /// This variable has the unit type. - Unit, +/// By convention, the first member in a type for the dual numbers must be the linear part. +const DU: id::Member = id::member(0); - /// This variable is an accumulator. - Accum(id::Ty), - - /// We don't know the type of this variable yet, but will soon. 
- /// - /// Usually this means the variable is a tuple of intermediate values, and we'll update its type - /// to something concrete when we reach the end of the block. - Unknown, -} +/// By convention, the second member in a type for the dual numbers must be the nonlinear part. +const RE: id::Member = id::member(1); -/// The source of a primitive variable. +/// The source of a primitive variable or a component of a dual number variable. +/// +/// The value `None` means that this is the original source, whereas `Some` means that it is an +/// alias of the given primitive variable or a component of the given dual number variable. #[derive(Clone, Copy)] -enum Src { - /// This variable is original. - Original, - - /// This variable is an alias of an original primitive variable of the same type. - Alias(id::Var), - - /// This variable is a component of an original dual number variable. - Projection(id::Var), -} +struct Src(Option); impl Src { - fn prim(self, x: id::Var) -> Self { - match self { - Self::Original => Self::Alias(x), - _ => self, - } - } - - fn dual(self, x: id::Var) -> Self { - match self { - Self::Original => Self::Projection(x), + /// Return the source for a variable derived from `self`. + /// + /// The variable ID `x` represents `self`, not the new source being returned. + fn derive(self, x: id::Var) -> Self { + match self.0 { + None => Self(Some(x)), _ => self, } } @@ -89,12 +73,28 @@ struct Transpose<'a> { /// The function being transposed, which is usually a forward-mode derivative. f: &'a Func, - /// Shared types between the forward and backward passes. + deps: &'a [(&'a [Ty], id::Ty)], + + /// Mapped versions of `f.types`. + /// + /// The only reason this is useful is to easily check whether a type is the dual number type by + /// looking it up here to see if it's equal to `F64`. + mapped_types: Vec, + + /// Additional types, shared between the forward and backward passes. 
+ /// + /// This starts out holding only `Unit`: at first we otherwise only have `mapped_types`, but later we'll add more types + /// for tuples and arrays of intermediate values that are shared between the two passes, and + /// also for new reference types that are only used in the backward pass. These `types` will + /// later all be appended onto `mapped_types`, so any type indices referencing them should be + /// offset by `mapped_types.len()`. + types: IndexSet, + + /// Type ID for `Unit`. /// - /// This starts out the same length as `f.types`, with the only difference being that all dual - /// number types are replaced with `F64`. Then we add more types only for tuples and arrays of - /// intermediate values that are shared between the two passes. - types: Vec, + /// We could get this every time by looking up in `types`, but it's easier to just always put it + /// in at the beginning to save ourselves the repeated hash lookups. + unit: id::Ty, /// Types of variables in the forward pass. /// @@ -104,9 +104,10 @@ struct Transpose<'a> { /// Types of variables in the backward pass. /// - /// This starts out with the `Known` type of every variable from `f.vars`, but more variables - /// can be added for intermediate values, accumulators, and cotangents. - bwd_vars: Vec, + /// This starts out with `Some` type for every variable from `f.vars`, but more variables can be + /// added for intermediate values, accumulators, and cotangents. The only time a variable's type + /// here is `None` is for tuples of intermediate values; see the `inter_tup` field in `Block`. + bwd_vars: Vec>, /// A variable of type `F64` defined at the beginning of the backward pass. /// @@ -122,6 +123,9 @@ struct Transpose<'a> { prims: Box<[Option]>, /// Sources for dual number variables from the original function. + /// + /// The sources are for the nonlinear part and the linear part, respectively; note that this + /// disagrees with the order for tuple members dictated by `DU` and `RE`. 
duals: Box<[Option<(Src, Src)>]>, /// Accumulator variables for variables from the original function. @@ -130,68 +134,127 @@ struct Transpose<'a> { /// Cotangent variables for variables from the original function. cotans: Box<[Option]>, + /// Stack of pending unreversed instructions for the backward pass. + /// + /// In general, we keep track of reversed instructions in the `bwd_lin` field of `block`; those + /// will go after the unreversed `bwd_nonlin` instructions. When we enter a new scope via + /// `Expr::For`, we start an entirely new `block`, so even though those inner `bwd_nonlin` + /// instructions may end up interleaved with our current `bwd_lin` instructions, that's fine + /// because they're going in a separate instruction list anyway. + /// + /// But for `Expr::Accum` and `Expr::Resolve`, we're introducing a new scope without actually + /// starting a new `block`. In that case, we still need for all the instructions we put in + /// `bwd_nonlin` during this scope to go before all our `bwd_lin` instructions from the scope, + /// but we also need them to go after any `bwd_lin` instructions we add after the scope ends. + /// So, what we do is push `bwd_nonlin` onto this `stack` when we enter the scope via + /// `Expr::Accum`, and then when we exit the scope via `Expr::Resolve`, we pop it off, reverse + /// it, and append it to `bwd_lin`. Then when we finally finish the actual block, the stack + /// should be empty, so we just reverse `bwd_lin` and append it to `bwd_nonlin` as normal. + stack: Vec>, + /// The current block under construction. block: Block, } impl<'a> Transpose<'a> { + /// Return the ID for `ty`, adding it to `types` if it isn't already there. 
+ fn ty(&mut self, ty: Ty) -> id::Ty { + let (i, _) = self.types.insert_full(ty); + id::ty(self.f.types.len() + i) + } + + fn translate(&mut self, generics: &[id::Ty], types: &[id::Ty], ty: &rose::Ty) -> id::Ty { + self.ty(match ty { + Ty::Unit => Ty::Unit, + Ty::Bool => Ty::Bool, + Ty::F64 => Ty::F64, + &Ty::Fin { size } => Ty::Fin { size }, + Ty::Generic { id } => return generics[id.generic()], + Ty::Ref { inner } => Ty::Ref { + inner: types[inner.ty()], + }, + Ty::Array { index, elem } => Ty::Array { + index: types[index.ty()], + elem: types[elem.ty()], + }, + Ty::Tuple { members } => Ty::Tuple { + members: members.iter().map(|&member| types[member.ty()]).collect(), + }, + }) + } + + /// Return the source of a variable that is the nonlinear part of `x`. fn re(&self, x: id::Var) -> Src { let (src, _) = self.duals[x.var()].unwrap(); - src.dual(x) + src.derive(x) } + /// Return the source of a variable that is the linear part of `x`. fn du(&self, x: id::Var) -> Src { let (_, src) = self.duals[x.var()].unwrap(); - src.dual(x) + src.derive(x) } + /// Return the source variable for `x`, which has a primitive type. fn get_prim(&self, x: id::Var) -> id::Var { match self.prims[x.var()].unwrap() { - Src::Original => x, - Src::Alias(y) | Src::Projection(y) => y, + Src(None) => x, + Src(Some(y)) => y, } } + /// Return the source variable for the nonlinear part of `x`, which has a non-primitive type. + /// + /// Every non-primitive variable whose type is not the dual numbers is considered original. fn get_re(&self, x: id::Var) -> id::Var { match self.duals[x.var()] { - None | Some((Src::Original, _)) => x, - Some((Src::Alias(y), _)) | Some((Src::Projection(y), _)) => y, + None | Some((Src(None), _)) => x, + Some((Src(Some(y)), _)) => y, } } + /// Return the source variable for the linear part of `x`, which has a non-primitive type. + /// + /// Every non-primitive variable whose type is not the dual numbers is considered original. 
fn get_du(&self, x: id::Var) -> id::Var { match self.duals[x.var()] { - None | Some((_, Src::Original)) => x, - Some((_, Src::Alias(y))) | Some((_, Src::Projection(y))) => y, + None | Some((_, Src(None))) => x, + Some((_, Src(Some(y)))) => y, } } + /// Return the accumulator variable for `x`, which has a primitive type. fn get_prim_accum(&self, x: id::Var) -> id::Var { self.accums[self.get_prim(x).var()].unwrap() } - fn get_dual_accum(&self, x: id::Var) -> Option { - self.accums[self.get_du(x).var()] + /// Return the accumulator variable for the linear part of `x`, which has a non-primitive type. + fn get_accum(&self, x: id::Var) -> id::Var { + self.accums[self.get_du(x).var()].unwrap() } - fn ty(&mut self, ty: Ty) -> id::Ty { - let t = id::ty(self.types.len()); - self.types.push(ty); - t + /// Return the cotangent variable for the linear part of `x`, which has a non-primitive type. + fn get_cotan(&self, x: id::Var) -> id::Var { + self.cotans[self.get_du(x).var()].unwrap() } + /// Return the ID for a new variable with type ID `t` in the forward pass. fn fwd_var(&mut self, t: id::Ty) -> id::Var { let var = id::var(self.fwd_vars.len()); self.fwd_vars.push(t); var } - fn bwd_var(&mut self, t: BwdTy) -> id::Var { + /// Return the ID for a new variable with type ID `t` in the backward pass. + /// + /// `t` should be `None` iff it is a tuple of intermediate values. + fn bwd_var(&mut self, t: Option) -> id::Var { let var = id::var(self.bwd_vars.len()); self.bwd_vars.push(t); var } + /// Include `var` in the intermediate values tuple for the current block. fn keep(&mut self, var: id::Var) { self.block.bwd_nonlin.push(Instr { var, @@ -203,10 +266,12 @@ impl<'a> Transpose<'a> { self.block.inter_mem.push(var); } + /// Create a non-primitive accumulator for `shape`; return it along with its eventual cotangent. 
fn accum(&mut self, shape: id::Var) -> Lin { - let t = self.f.vars[shape.var()]; - let acc = self.bwd_var(BwdTy::Accum(t)); - let cot = self.bwd_var(BwdTy::Known(t)); + let t_cot = self.f.vars[shape.var()]; + let t_acc = self.ty(Ty::Ref { inner: t_cot }); + let acc = self.bwd_var(Some(t_acc)); + let cot = self.bwd_var(Some(t_cot)); self.block.bwd_nonlin.push(Instr { var: acc, expr: Expr::Accum { shape }, @@ -216,21 +281,24 @@ impl<'a> Transpose<'a> { Lin { acc, cot } } - fn calc(&mut self, tan: id::Var) -> Lin { - let t = self.f.vars[tan.var()]; - let acc = self.bwd_var(BwdTy::Accum(t)); - let cot = self.bwd_var(BwdTy::Known(t)); + /// Create a primitive accumulator for the given `tangent`, using `self.real_shape`. + fn calc(&mut self, tangent: id::Var) -> Lin { + let t_cot = self.f.vars[tangent.var()]; + let t_acc = self.ty(Ty::Ref { inner: t_cot }); + let acc = self.bwd_var(Some(t_acc)); + let cot = self.bwd_var(Some(t_cot)); self.block.bwd_nonlin.push(Instr { var: acc, expr: Expr::Accum { shape: self.real_shape, }, }); - self.accums[tan.var()] = Some(acc); - self.cotans[tan.var()] = Some(cot); + self.accums[tangent.var()] = Some(acc); + self.cotans[tangent.var()] = Some(cot); Lin { acc, cot } } + /// Resolve the given accumulator. fn resolve(&mut self, lin: Lin) { self.block.bwd_lin.push(Instr { var: lin.cot, @@ -238,6 +306,7 @@ impl<'a> Transpose<'a> { }) } + /// Process `block` and return the type and forward variable for the intermediate values tuple. fn block(&mut self, block: &[Instr]) -> (id::Ty, id::Var) { for instr in block.iter() { self.instr(instr.var, &instr.expr); @@ -253,10 +322,11 @@ impl<'a> Transpose<'a> { members: vars.into(), }, }); - self.bwd_vars[self.block.inter_tup.var()] = BwdTy::Known(t); + self.bwd_vars[self.block.inter_tup.var()] = Some(t); (t, var) } + /// Process the instruction with the given `var` and `expr`. 
fn instr(&mut self, var: id::Var, expr: &Expr) { match expr { Expr::Unit => { @@ -289,12 +359,20 @@ impl<'a> Transpose<'a> { let lin = self.calc(var); self.resolve(lin); } - _ => self.block.fwd.push(Instr { - var, - expr: Expr::F64 { val }, - }), + _ => { + self.block.fwd.push(Instr { + var, + expr: Expr::F64 { val }, + }); + self.block.bwd_nonlin.push(Instr { + var, + expr: Expr::F64 { val }, + }); + let lin = self.accum(var); + self.resolve(lin); + } } - self.prims[var.var()] = Some(Src::Original); + self.prims[var.var()] = Some(Src(None)); } &Expr::Fin { val } => { self.block.fwd.push(Instr { @@ -323,36 +401,35 @@ impl<'a> Transpose<'a> { self.keep(var); let lin = self.accum(var); for (i, &elem) in elems.iter().enumerate() { - if let Some(accum) = self.get_dual_accum(elem) { - let index = self.bwd_var(BwdTy::Known(t)); - let addend = self.bwd_var(BwdTy::Known(self.f.vars[elem.var()])); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { accum, addend }, - }); - self.block.bwd_lin.push(Instr { - var: addend, - expr: Expr::Index { - array: lin.cot, - index, - }, - }); - self.block.bwd_lin.push(Instr { - var: index, - expr: Expr::Fin { val: i }, - }); - } + let accum = self.get_accum(elem); + let index = self.bwd_var(Some(t)); + let addend = self.bwd_var(Some(self.f.vars[elem.var()])); + let unit = self.bwd_var(Some(self.unit)); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { accum, addend }, + }); + self.block.bwd_lin.push(Instr { + var: addend, + expr: Expr::Index { + array: lin.cot, + index, + }, + }); + self.block.bwd_lin.push(Instr { + var: index, + expr: Expr::Fin { val: i }, + }); } self.resolve(lin); } - Expr::Tuple { members } => match self.types[self.f.vars[var.var()].ty()] { + Expr::Tuple { members } => match self.mapped_types[self.f.vars[var.var()].ty()] { Ty::F64 => { - let x = members[1]; - let dx = members[0]; + let x = members[RE.member()]; + let dx = members[DU.member()]; 
self.duals[var.var()] = Some(( - self.prims[x.var()].unwrap().prim(x), - self.prims[dx.var()].unwrap().prim(dx), + self.prims[x.var()].unwrap().derive(x), + self.prims[dx.var()].unwrap().derive(dx), )); } _ => { @@ -365,21 +442,20 @@ impl<'a> Transpose<'a> { self.keep(var); let lin = self.accum(var); for (i, &member) in members.iter().enumerate() { - if let Some(accum) = self.get_dual_accum(member) { - let addend = self.bwd_var(BwdTy::Known(self.f.vars[member.var()])); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { accum, addend }, - }); - self.block.bwd_lin.push(Instr { - var: addend, - expr: Expr::Member { - tuple: lin.cot, - member: id::member(i), - }, - }); - } + let accum = self.get_accum(member); + let addend = self.bwd_var(Some(self.f.vars[member.var()])); + let unit = self.bwd_var(Some(self.unit)); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { accum, addend }, + }); + self.block.bwd_lin.push(Instr { + var: addend, + expr: Expr::Member { + tuple: lin.cot, + member: id::member(i), + }, + }); } self.resolve(lin); } @@ -394,8 +470,11 @@ impl<'a> Transpose<'a> { var, expr: Expr::Index { array, index }, }); - let arr_acc = self.accums[array.var()].unwrap(); - let acc = self.bwd_var(BwdTy::Accum(self.f.vars[array.var()])); + let arr_acc = self.get_accum(array); + let t_acc = self.ty(Ty::Ref { + inner: self.f.vars[var.var()], + }); + let acc = self.bwd_var(Some(t_acc)); self.accums[var.var()] = Some(acc); self.block.bwd_nonlin.push(Instr { var: acc, @@ -404,8 +483,8 @@ impl<'a> Transpose<'a> { index, }, }); - if let Ty::F64 = self.types[self.f.vars[var.var()].ty()] { - self.duals[var.var()] = Some((Src::Original, Src::Original)); + if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { + self.duals[var.var()] = Some((Src(None), Src(None))); } } &Expr::Member { tuple, member } => { @@ -422,8 +501,11 @@ impl<'a> Transpose<'a> { var, expr: Expr::Member { tuple, member }, }); - 
let tup_acc = self.accums[tuple.var()].unwrap(); - let acc = self.bwd_var(BwdTy::Accum(self.f.vars[tuple.var()])); + let tup_acc = self.get_accum(tuple); + let t_acc = self.ty(Ty::Ref { + inner: self.f.vars[var.var()], + }); + let acc = self.bwd_var(Some(t_acc)); self.accums[var.var()] = Some(acc); self.block.bwd_nonlin.push(Instr { var: acc, @@ -432,15 +514,59 @@ impl<'a> Transpose<'a> { member, }, }); - if let Ty::F64 = self.types[t.ty()] { - self.duals[var.var()] = Some((Src::Original, Src::Original)); + if let Ty::F64 = self.mapped_types[t.ty()] { + self.duals[var.var()] = Some((Src(None), Src(None))); } } } } - &Expr::Slice { array, index } => todo!(), - &Expr::Field { tuple, member } => todo!(), + &Expr::Slice { array, index } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Slice { array, index }, + }); + + let t_cot = match &self.f.types[self.f.vars[var.var()].ty()] { + &Ty::Ref { inner } => inner, + _ => panic!(), + }; + let cot = self.bwd_var(Some(t_cot)); + self.block.bwd_nonlin.push(Instr { + var: cot, + expr: Expr::Index { + array: self.get_cotan(array), + index, + }, + }); + self.cotans[var.var()] = Some(cot); + if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { + self.duals[var.var()] = Some((Src(None), Src(None))); + } + } + &Expr::Field { tuple, member } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Field { tuple, member }, + }); + + let t_cot = match &self.f.types[self.f.vars[var.var()].ty()] { + &Ty::Ref { inner } => inner, + _ => panic!(), + }; + let cot = self.bwd_var(Some(t_cot)); + self.block.bwd_nonlin.push(Instr { + var: cot, + expr: Expr::Member { + tuple: self.get_cotan(tuple), + member, + }, + }); + self.cotans[var.var()] = Some(cot); + if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { + self.duals[var.var()] = Some((Src(None), Src(None))); + } + } &Expr::Unary { op, arg } => { match self.f.vars[var.var()] { @@ -448,8 +574,8 @@ impl<'a> Transpose<'a> { Unop::Not | Unop::Abs | Unop::Sign | 
Unop::Sqrt => panic!(), Unop::Neg => { let lin = self.calc(var); - let res = self.bwd_var(BwdTy::Known(DUAL)); - let unit = self.bwd_var(BwdTy::Unit); + let res = self.bwd_var(Some(DUAL)); + let unit = self.bwd_var(Some(self.unit)); self.block.bwd_lin.push(Instr { var: unit, expr: Expr::Add { @@ -478,7 +604,7 @@ impl<'a> Transpose<'a> { self.keep(var); } } - self.prims[var.var()] = Some(Src::Original); + self.prims[var.var()] = Some(Src(None)); } &Expr::Binary { op, left, right } => { match self.f.vars[var.var()] { @@ -496,8 +622,8 @@ impl<'a> Transpose<'a> { | Binop::Gt | Binop::Geq => panic!(), Binop::Add => { - let a = self.bwd_var(BwdTy::Unit); - let b = self.bwd_var(BwdTy::Unit); + let a = self.bwd_var(Some(self.unit)); + let b = self.bwd_var(Some(self.unit)); self.block.bwd_lin.push(Instr { var: a, expr: Expr::Add { @@ -514,9 +640,9 @@ impl<'a> Transpose<'a> { }); } Binop::Sub => { - let res = self.bwd_var(BwdTy::Known(DUAL)); - let a = self.bwd_var(BwdTy::Unit); - let b = self.bwd_var(BwdTy::Unit); + let res = self.bwd_var(Some(DUAL)); + let a = self.bwd_var(Some(self.unit)); + let b = self.bwd_var(Some(self.unit)); self.block.bwd_lin.push(Instr { var: a, expr: Expr::Add { @@ -540,8 +666,8 @@ impl<'a> Transpose<'a> { }); } Binop::Mul | Binop::Div => { - let res = self.bwd_var(BwdTy::Known(DUAL)); - let unit = self.bwd_var(BwdTy::Unit); + let res = self.bwd_var(Some(DUAL)); + let unit = self.bwd_var(Some(self.unit)); self.block.bwd_lin.push(Instr { var: unit, expr: Expr::Add { @@ -586,9 +712,11 @@ impl<'a> Transpose<'a> { self.keep(var); } } - self.prims[var.var()] = Some(Src::Original); + self.prims[var.var()] = Some(Src(None)); } &Expr::Select { cond, then, els } => { + let t = self.f.vars[var.var()]; + self.block.fwd.push(Instr { var, expr: Expr::Select { @@ -597,161 +725,453 @@ impl<'a> Transpose<'a> { els: self.get_re(els), }, }); + + match &self.f.types[t.ty()] { + &Ty::Ref { inner } => { + let cot = self.bwd_var(Some(inner)); + 
self.block.bwd_nonlin.push(Instr { + var: cot, + expr: Expr::Select { + cond, + then: self.get_cotan(then), + els: self.get_cotan(els), + }, + }); + self.cotans[var.var()] = Some(cot); + } + _ => { + self.keep(var); + let lin = self.accum(var); + let acc_then = self.get_accum(then); + let acc_els = self.get_accum(els); + let t_acc = self.ty(Ty::Ref { + inner: self.f.vars[var.var()], + }); + let acc = self.bwd_var(Some(t_acc)); + let unit = self.bwd_var(Some(self.unit)); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { + accum: acc, + addend: lin.cot, + }, + }); + self.block.bwd_lin.push(Instr { + var: acc, + expr: Expr::Select { + cond, + then: acc_then, + els: acc_els, + }, + }); + self.resolve(lin); + } + } + + if let Ty::F64 = self.mapped_types[t.ty()] { + self.duals[var.var()] = Some((Src(None), Src(None))); + } + } + + Expr::Call { id, generics, args } => { + let (dep_types, t) = self.deps[id.func()]; + let mut types = vec![]; + for ty in dep_types { + types.push(self.translate(generics, &types, ty)); + } + let t_tup = types[t.ty()]; + + let t_bundle = self.ty(Ty::Tuple { + members: [self.f.vars[var.var()], t_tup].into(), + }); + let bundle = self.fwd_var(t_bundle); + self.block.fwd.push(Instr { + var: bundle, + expr: Expr::Call { + id: *id, + generics: generics.clone(), + args: args.iter().map(|&arg| self.get_re(arg)).collect(), + }, + }); + + self.block.fwd.push(Instr { + var, + expr: Expr::Member { + tuple: bundle, + member: id::member(0), + }, + }); self.keep(var); - let lin = self.accum(var); - let acc_then = self.get_dual_accum(then).unwrap(); - let acc_els = self.get_dual_accum(els).unwrap(); - // TODO: this scope is wrong; it actually needs to match both `then` and `els` - let acc = self.bwd_var(BwdTy::Accum(self.f.vars[then.var()])); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { - accum: acc, - addend: lin.cot, + + let inter_fwd = self.fwd_var(t_tup); + let inter_bwd = 
self.bwd_var(Some(t_tup)); + self.block.fwd.push(Instr { + var: inter_fwd, + expr: Expr::Member { + tuple: bundle, + member: id::member(1), + }, + }); + self.block.bwd_nonlin.push(Instr { + var: inter_bwd, + expr: Expr::Member { + tuple: self.block.inter_tup, + member: id::member(self.block.inter_mem.len()), }, }); + self.block.inter_mem.push(inter_fwd); + + let lin = self.accum(var); + let unit = self.bwd_var(Some(self.unit)); + let mut args: Vec<_> = args + .iter() + .map(|&arg| match self.f.types[self.f.vars[arg.var()].ty()] { + Ty::Ref { .. } => self.get_cotan(arg), + _ => self.get_accum(arg), + }) + .collect(); + args.push(lin.cot); + args.push(inter_bwd); self.block.bwd_lin.push(Instr { - var: acc, - expr: Expr::Select { - cond, - then: acc_then, - els: acc_els, + var: unit, + expr: Expr::Call { + id: *id, + generics: generics.clone(), + args: args.into(), }, }); self.resolve(lin); - if let Ty::F64 = self.types[self.f.vars[var.var()].ty()] { - self.duals[var.var()] = Some((Src::Original, Src::Original)); - } } - - Expr::Call { id, generics, args } => todo!(), Expr::For { arg, body, ret } => { + let t_index = self.f.vars[arg.var()]; + let t_elem = self.f.vars[ret.var()]; + let mut block = Block { fwd: vec![], inter_mem: vec![], - inter_tup: self.bwd_var(BwdTy::Unknown), + inter_tup: self.bwd_var(None), bwd_nonlin: vec![], bwd_lin: vec![], }; swap(&mut self.block, &mut block); - self.block(body); // TODO + let (t_inter, fwd_inter) = self.block(body); swap(&mut self.block, &mut block); - } - &Expr::Accum { shape } => todo!(), - - &Expr::Add { accum, addend } => todo!(), + let t_bundle = self.ty(Ty::Tuple { + members: [t_elem, t_inter].into(), + }); + let bundle = self.fwd_var(t_bundle); + block.fwd.push(Instr { + var: bundle, + expr: Expr::Tuple { + members: [self.get_re(*ret), fwd_inter].into(), + }, + }); + let t_arr_bundle = self.ty(Ty::Array { + index: t_index, + elem: t_bundle, + }); + let arr_bundle = self.fwd_var(t_arr_bundle); + self.block.fwd.push(Instr 
{ + var: arr_bundle, + expr: Expr::For { + arg: *arg, + body: block.fwd.into(), + ret: bundle, + }, + }); + let fst_index = self.fwd_var(t_index); + let fst_bundle = self.fwd_var(t_bundle); + let elem = self.fwd_var(t_elem); + self.block.fwd.push(Instr { + var, + expr: Expr::For { + arg: fst_index, + body: [ + Instr { + var: fst_bundle, + expr: Expr::Index { + array: arr_bundle, + index: fst_index, + }, + }, + Instr { + var: elem, + expr: Expr::Member { + tuple: fst_bundle, + member: id::member(0), + }, + }, + ] + .into(), + ret: elem, + }, + }); + self.keep(var); + let t_arr_inter = self.ty(Ty::Array { + index: t_index, + elem: t_inter, + }); + let arr_inter = self.fwd_var(t_arr_inter); + let snd_index = self.fwd_var(t_index); + let snd_bundle = self.fwd_var(t_bundle); + let inter = self.fwd_var(t_inter); + self.block.fwd.push(Instr { + var: arr_inter, + expr: Expr::For { + arg: snd_index, + body: [ + Instr { + var: snd_bundle, + expr: Expr::Index { + array: arr_bundle, + index: snd_index, + }, + }, + Instr { + var: inter, + expr: Expr::Member { + tuple: snd_bundle, + member: id::member(1), + }, + }, + ] + .into(), + ret: inter, + }, + }); - &Expr::Resolve { var } => todo!(), - } - } -} + let arr_inter_bwd = self.bwd_var(Some(t_arr_inter)); + self.block.bwd_nonlin.push(Instr { + var: arr_inter_bwd, + expr: Expr::Member { + tuple: self.block.inter_tup, + member: id::member(self.block.inter_mem.len()), + }, + }); + self.block.inter_mem.push(arr_inter); -/// Return two functions that make up the transpose of `f`. 
-pub fn transpose(f: &Func) -> (Func, Func) { - let types: Vec<_> = f - .types - .iter() - .enumerate() - .map(|(i, ty)| match ty { - Ty::Unit => Ty::Unit, - Ty::Bool => Ty::Bool, - Ty::F64 => { - if !is_primitive(id::ty(i)) { - panic!() - } - Ty::F64 + let lin = self.accum(var); + let bwd_acc = self.get_accum(*ret); + let bwd_cot = self.bwd_var(Some(t_elem)); + let mut bwd_body = vec![ + Instr { + var: bwd_cot, + expr: Expr::Index { + array: lin.cot, + index: *arg, + }, + }, + Instr { + var: block.inter_tup, + expr: Expr::Index { + array: arr_inter_bwd, + index: *arg, + }, + }, + ]; + bwd_body.append(&mut block.bwd_nonlin); + let unit = self.bwd_var(Some(self.unit)); + bwd_body.push(Instr { + var: unit, + expr: Expr::Add { + accum: bwd_acc, + addend: bwd_cot, + }, + }); + block.bwd_lin.reverse(); + bwd_body.append(&mut block.bwd_lin); + let bwd_ret = self.bwd_var(Some(self.unit)); + bwd_body.push(Instr { + var: bwd_ret, + expr: Expr::Unit, + }); + let t_arr_unit = self.ty(Ty::Array { + index: t_index, + elem: self.unit, + }); + let arr_unit = self.bwd_var(Some(t_arr_unit)); + self.block.bwd_lin.push(Instr { + var: arr_unit, + expr: Expr::For { + arg: *arg, + body: bwd_body.into(), + ret: bwd_ret, + }, + }); + self.resolve(lin); } - &Ty::Fin { size } => Ty::Fin { size }, - &Ty::Generic { id } => Ty::Generic { id }, - &Ty::Ref { inner } => { - if is_primitive(inner) { - panic!() - } - Ty::Ref { inner } + + &Expr::Accum { shape } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Accum { + shape: self.get_re(shape), + }, + }); + + let cot = self.bwd_var(Some(self.f.vars[shape.var()])); + self.cotans[var.var()] = Some(cot); + self.stack.push(take(&mut self.block.bwd_nonlin)); } - &Ty::Array { index, elem } => { - if is_primitive(elem) { - panic!() - } - Ty::Array { index, elem } + + &Expr::Add { accum, addend } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Add { + accum, + addend: self.get_re(addend), + }, + }); + + self.block.bwd_nonlin.push(Instr { + 
var, + expr: Expr::Unit, + }); + let lin = self.accum(var); + let unit = self.bwd_var(Some(self.unit)); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { + accum: self.get_accum(addend), + addend: self.get_cotan(accum), + }, + }); + self.resolve(lin); } - Ty::Tuple { members } => { - if members.iter().any(|&t| is_primitive(t)) { - Ty::F64 - } else { - Ty::Tuple { - members: members.clone(), - } + + &Expr::Resolve { var: accum } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Resolve { var: accum }, + }); + + let mut bwd_nonlin = replace(&mut self.block.bwd_nonlin, self.stack.pop().unwrap()); + bwd_nonlin.reverse(); + self.block.bwd_lin.append(&mut bwd_nonlin); + let acc = self.bwd_var(Some(self.f.vars[accum.var()])); + self.block.bwd_lin.push(Instr { + var: self.get_cotan(accum), + expr: Expr::Resolve { var: acc }, + }); + self.keep(var); + self.block.bwd_nonlin.push(Instr { + var: acc, + expr: Expr::Accum { shape: var }, + }); + self.accums[var.var()] = Some(acc); + if let Ty::F64 = self.mapped_types[self.f.vars[var.var()].ty()] { + self.duals[var.var()] = Some((Src(None), Src(None))); } } - }) - .collect(); + } + } +} - let mut bwd_vars: Vec<_> = f.vars.iter().map(|&t| BwdTy::Known(t)).collect(); +/// Return the forward and backward pass for the transpose of `f`. 
+pub fn transpose(f: &Func, deps: &[(&[Ty], id::Ty)]) -> (Func, Func) { + let mut bwd_vars: Vec<_> = f.vars.iter().map(|&t| Some(t)).collect(); let real_shape = id::var(bwd_vars.len()); - bwd_vars.push(BwdTy::Known(DUAL)); + bwd_vars.push(Some(DUAL)); let inter_tup = id::var(bwd_vars.len()); - bwd_vars.push(BwdTy::Unknown); - - let mut duals = vec![None; f.vars.len()].into_boxed_slice(); - let mut accums = vec![None; f.vars.len()].into_boxed_slice(); - - let mut inter_mem = vec![]; - let mut bwd_nonlin = vec![]; - - let mut bwd_params: Vec<_> = f - .params - .iter() - .map(|¶m| { - let t = f.vars[param.var()]; - bwd_nonlin.push(Instr { - var: param, - expr: Expr::Member { - tuple: inter_tup, - member: id::member(inter_mem.len()), - }, - }); - inter_mem.push(param); - let acc = id::var(bwd_vars.len()); - bwd_vars.push(BwdTy::Accum(t)); - if let Ty::F64 = types[t.ty()] { - duals[param.var()] = Some((Src::Original, Src::Original)); - } - accums[param.var()] = Some(acc); - acc - }) - .collect(); + bwd_vars.push(None); let mut tp = Transpose { f, - types, + deps, + mapped_types: f + .types + .iter() + .enumerate() + .map(|(i, ty)| match ty { + Ty::Unit => Ty::Unit, + Ty::Bool => Ty::Bool, + Ty::F64 => { + if !is_primitive(id::ty(i)) { + panic!() + } + Ty::F64 + } + &Ty::Fin { size } => Ty::Fin { size }, + &Ty::Generic { id } => Ty::Generic { id }, + &Ty::Ref { inner } => { + if is_primitive(inner) { + panic!() + } + Ty::Ref { inner } + } + &Ty::Array { index, elem } => { + if is_primitive(elem) { + panic!() + } + Ty::Array { index, elem } + } + Ty::Tuple { members } => { + if members.iter().any(|&t| is_primitive(t)) { + Ty::F64 + } else { + Ty::Tuple { + members: members.clone(), + } + } + } + }) + .collect(), + types: indexset! 
{ Ty::Unit }, + unit: id::ty(f.types.len()), fwd_vars: f.vars.to_vec(), bwd_vars, real_shape, prims: vec![None; f.vars.len()].into(), - duals, - accums, + duals: vec![None; f.vars.len()].into(), + accums: vec![None; f.vars.len()].into(), cotans: vec![None; f.vars.len()].into(), + stack: vec![], block: Block { fwd: vec![], inter_tup, - inter_mem, - bwd_nonlin, + inter_mem: vec![], + bwd_nonlin: vec![], bwd_lin: vec![], }, }; + let mut bwd_params: Vec<_> = f + .params + .iter() + .map(|¶m| { + let t = f.vars[param.var()]; + match &f.types[t.ty()] { + &Ty::Ref { inner } => { + let cot = tp.bwd_var(Some(inner)); + tp.cotans[param.var()] = Some(cot); + cot + } + _ => { + let t_acc = tp.ty(Ty::Ref { inner: t }); + tp.keep(param); + let acc = tp.bwd_var(Some(t_acc)); + if let Ty::F64 = tp.mapped_types[t.ty()] { + tp.duals[param.var()] = Some((Src(None), Src(None))); + } + tp.accums[param.var()] = Some(acc); + acc + } + } + }) + .collect(); + let (t_intermediates, fwd_inter) = tp.block(&f.body); let fwd_ret = tp.get_re(f.ret); - let bwd_acc = tp.get_dual_accum(f.ret).unwrap(); - let mut bwd_types = tp.types.clone(); + let bwd_acc = tp.get_accum(f.ret); + + let mut bwd_types = tp.mapped_types; + bwd_types.extend(tp.types.into_iter()); - let mut fwd_types = tp.types; + let mut fwd_types = bwd_types.clone(); let t_bundle = id::ty(fwd_types.len()); fwd_types.push(Ty::Tuple { - members: vec![f.vars[f.ret.var()], t_intermediates].into(), + members: [f.vars[f.ret.var()], t_intermediates].into(), }); let mut fwd_vars = tp.fwd_vars; let fwd_bundle = id::var(fwd_vars.len()); @@ -760,31 +1180,15 @@ pub fn transpose(f: &Func) -> (Func, Func) { fwd_body.push(Instr { var: fwd_bundle, expr: Expr::Tuple { - members: vec![fwd_ret, fwd_inter].into(), + members: [fwd_ret, fwd_inter].into(), }, }); - let t_unit = id::ty(bwd_types.len()); - bwd_types.push(Ty::Unit); - let mut bwd_vars: Vec<_> = tp - .bwd_vars - .into_iter() - .enumerate() - .map(|(i, t)| match t { - BwdTy::Known(t) => t, - 
BwdTy::Unit => t_unit, - BwdTy::Accum(inner) => { - let t = id::ty(bwd_types.len()); - bwd_types.push(Ty::Ref { inner }); - t - } - BwdTy::Unknown => panic!(), - }) - .collect(); + let mut bwd_vars: Vec<_> = tp.bwd_vars.into_iter().map(|t| t.unwrap()).collect(); let bwd_cot = id::var(bwd_vars.len()); bwd_vars.push(f.vars[f.ret.var()]); let bwd_unit = id::var(bwd_vars.len()); - bwd_vars.push(t_unit); + bwd_vars.push(tp.unit); bwd_params.push(bwd_cot); bwd_params.push(tp.block.inter_tup); let mut bwd_body = vec![Instr { @@ -792,16 +1196,22 @@ pub fn transpose(f: &Func) -> (Func, Func) { expr: Expr::F64 { val: 0. }, }]; bwd_body.append(&mut tp.block.bwd_nonlin); - let mut bwd_linear = tp.block.bwd_lin; - bwd_linear.push(Instr { + bwd_body.push(Instr { var: bwd_unit, expr: Expr::Add { accum: bwd_acc, addend: bwd_cot, }, }); - bwd_linear.reverse(); - bwd_body.append(&mut bwd_linear); + let mut bwd_lin = tp.block.bwd_lin; + bwd_lin.reverse(); + bwd_body.append(&mut bwd_lin); + let bwd_ret = id::var(bwd_vars.len()); // separate var, because `bwd_unit` might not be in scope + bwd_vars.push(tp.unit); + bwd_body.push(Instr { + var: bwd_ret, + expr: Expr::Unit, + }); ( Func { @@ -817,7 +1227,7 @@ pub fn transpose(f: &Func) -> (Func, Func) { types: bwd_types.into(), vars: bwd_vars.into(), params: bwd_params.into(), - ret: bwd_unit, + ret: bwd_ret, body: bwd_body.into(), }, ) diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index c5563f1..d406f1e 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -124,10 +124,13 @@ struct Pointee { /// The actual strings are stored in JavaScript. structs: Box<[Option>]>, + /// Jacobian-vector product. jvp: RefCell>>, + /// Forward pass of the vector-Jacobian product. fwd: RefCell>>, + /// Backward pass of the vector-Jacobian product. bwd: RefCell>>, } @@ -304,6 +307,7 @@ impl Func { Self { rc } } + /// Return the forward and backward pass of the transpose of this function. 
fn transpose_pair(&self) -> (Self, Self) { let Pointee { inner, @@ -324,7 +328,16 @@ impl Func { Inner::Transparent { deps, def } => { let (deps_fwd, deps_bwd): (Vec<_>, Vec<_>) = deps.iter().map(|f| f.transpose_pair()).unzip(); - let (def_fwd, def_bwd) = rose_transpose::transpose(def); + let dep_types: Box<_> = deps_bwd + .iter() + .map(|f| match &f.rc.as_ref().inner { + Inner::Transparent { def, .. } => { + (def.types.as_ref(), def.vars[def.ret.var()]) + } + Inner::Opaque { types, ret, .. } => (types.as_ref(), *ret), + }) + .collect(); + let (def_fwd, def_bwd) = rose_transpose::transpose(def, &dep_types); let structs_fwd = def_fwd .types .iter() @@ -373,6 +386,9 @@ impl Func { (Self { rc: rc_fwd }, Self { rc: rc_bwd }) } + /// Return the transpose of this function. + /// + /// Assumes that this function has already been computed as the `jvp` of another function. pub fn transpose(&self) -> Transpose { let (fwd, bwd) = self.transpose_pair(); Transpose { @@ -382,6 +398,7 @@ impl Func { } } +/// A temporary object to hold the two passes of a transposed function before they are destructured. #[wasm_bindgen] pub struct Transpose { fwd: Option, @@ -390,10 +407,12 @@ pub struct Transpose { #[wasm_bindgen] impl Transpose { + /// Return the forward pass. pub fn fwd(&mut self) -> Option { self.fwd.take() } + /// Return the backward pass. pub fn bwd(&mut self) -> Option { self.bwd.take() } @@ -1536,6 +1555,10 @@ impl Block { self.instr(f, id::ty(t), expr) } + /// Return the variable ID for a new instruction defining an accumulator with the given `shape`. + /// + /// Assumes `shape` is defined and in scope, and that `t` is the ID of a reference type whose + /// inner type is the same as the type of `shape`. 
pub fn accum(&mut self, f: &mut FuncBuilder, t: usize, shape: usize) -> usize { let expr = rose::Expr::Accum { shape: id::var(shape), @@ -1543,6 +1566,10 @@ impl Block { self.instr(f, id::ty(t), expr) } + /// Return the variable ID for a new instruction resolving the given accumulator `var`. + /// + /// Assumes `var` is defined and in scope, and that `t` is the inner type of the reference type + /// for `var`. pub fn resolve(&mut self, f: &mut FuncBuilder, t: usize, var: usize) -> usize { let expr = rose::Expr::Resolve { var: id::var(var) }; self.instr(f, id::ty(t), expr) diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 16f5c83..2198309 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -595,6 +595,7 @@ export const jvp = ( return g; }; +/** Construct a closure that computes the vector-Jacobian product of `f`. */ export const vjp = ( f: Fn & ((arg: A) => R), ): ((arg: A) => { ret: R; grad: (cot: R) => A }) => { diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 6e105ba..3fe1846 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -399,4 +399,75 @@ describe("valid", () => { expect(h(true)).toBe(0); expect(h(false)).toBe(1); }); + + test("VJP with vector comprehension", () => { + const n = 2; + const f = fn([Vec(n, Real)], Vec(n, Real), (v) => + vec(n, Real, (i) => mul(v[i], v[i])), + ); + const g = fn([Vec(n, Real), Vec(n, Real)], Vec(n, Real), (u, v) => + vjp(f)(u).grad(v), + ); + expect(interp(g)([2, 3], [5, 7])).toEqual([20, 42]); + }); + + test("VJP twice", () => { + const f = fn([Real], Real, (x) => { + const y = mul(x, x); + return mul(x, y); + }); + const g = fn([Real], Real, (x) => vjp(f)(x).grad(1)); + const h = fn([Real], Real, (x) => vjp(g)(x).grad(1)); + expect(interp(h)(10)).toBe(60); + }); + + test("Hessian", () => { + const powi = (x: Real, n: number): Real => { + if (!Number.isInteger(n)) + throw new Error(`exponent is not an integer: 
${n}`); + // https://en.wikipedia.org/wiki/Exponentiation_by_squaring + if (n < 0) return powi(div(1, x), -n); + else if (n == 0) return 1; + else if (n == 1) return x; + else if (n % 2 == 0) return powi(mul(x, x), n / 2); + else return mul(x, powi(mul(x, x), (n - 1) / 2)); + }; + const f = fn([Vec(2, Real)], Real, (v) => { + const x = v[0]; + const y = v[1]; + return sub(sub(powi(x, 3), mul(2, mul(x, y))), powi(y, 6)); + }); + const g = fn([Vec(2, Real)], Vec(2, Real), (v) => vjp(f)(v).grad(1)); + const h = fn([Vec(2, Real)], Vec(2, Vec(2, Real)), (v) => { + const { grad } = vjp(g)(v); + return [grad([1, 0] as any), grad([0, 1] as any)]; + }); + expect(interp(h)([1, 2])).toEqual([ + [6, -2], + [-2, -480], + ]); + }); + + test("VJP twice with struct", () => { + const Pair = { x: Real, y: Real } as const; + const f = fn([Pair], Real, ({ x, y }) => mul(x, y)); + const g = fn([Pair], Pair, (p) => vjp(f)(p).grad(1)); + const h = fn([Pair, Pair], Pair, (p, q) => vjp(g)(p).grad(q)); + expect(interp(h)({ x: 2, y: 3 }, { x: 5, y: 7 })).toEqual({ x: 7, y: 5 }); + }); + + test("VJP twice with select", () => { + const Stuff = { p: Bool, x: Real, y: Real, z: Real } as const; + const f = fn([Stuff], Real, ({ p, x, y, z }) => + mul(z, select(p, Real, x, y)), + ); + const g = fn([Stuff], Stuff, (p) => vjp(f)(p).grad(1)); + const h = fn([Stuff, Stuff], Stuff, (p, q) => vjp(g)(p).grad(q)); + expect( + interp(h)( + { p: true, x: 2, y: 3, z: 5 }, + { p: false, x: 7, y: 11, z: 13 }, + ), + ).toEqual({ p: true, x: 13, y: 0, z: 7 }); + }); });