From 1ee11a97f3f014846449f34e437769f8926ee99f Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Mon, 11 Sep 2023 09:51:23 -0400 Subject: [PATCH] Transpose functions (#86) --- Cargo.lock | 9 + crates/interp/src/lib.rs | 46 +- crates/transpose/Cargo.toml | 8 + crates/transpose/src/lib.rs | 869 ++++++++++++++++++++++++++++++++ crates/web/Cargo.toml | 1 + crates/web/src/lib.rs | 147 +++++- packages/core/src/impl.test.ts | 1 + packages/core/src/impl.ts | 45 ++ packages/core/src/index.test.ts | 50 +- packages/core/src/index.ts | 1 + 10 files changed, 1155 insertions(+), 22 deletions(-) create mode 100644 crates/transpose/Cargo.toml create mode 100644 crates/transpose/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 59c6c6f..9f71a70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -187,6 +187,14 @@ dependencies = [ "ts-rs", ] +[[package]] +name = "rose-transpose" +version = "0.0.0" +dependencies = [ + "enumset", + "rose", +] + [[package]] name = "rose-web" version = "0.0.0" @@ -198,6 +206,7 @@ dependencies = [ "rose", "rose-autodiff", "rose-interp", + "rose-transpose", "serde", "serde-wasm-bindgen", "wasm-bindgen", diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 8f0a7d1..4132195 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -16,7 +16,7 @@ pub enum Val { Bool(bool), F64(Cell), Fin(usize), - Ref(Rc), + Ref(Rc, Option), Array(Vals), // assume all indices are `Fin` Tuple(Vals), } @@ -50,9 +50,35 @@ impl Val { } } + fn fin(&self) -> usize { + match self { + &Val::Fin(i) => i, + _ => unreachable!(), + } + } + + fn get(&self, i: usize) -> &Self { + match self { + Val::Array(x) => &x[i], + Val::Tuple(x) => &x[i], + _ => unreachable!(), + } + } + + fn slice(&self, i: usize) -> Self { + match self { + Val::Ref(x, None) => Val::Ref(Rc::clone(x), Some(i)), + Val::Ref(x, Some(j)) => Val::Ref(Rc::new(x.get(*j).clone()), Some(i)), + _ => unreachable!(), + } + } + fn inner(&self) -> &Self { match self { - Val::Ref(x) => x.as_ref(), + Val::Ref(x, i) => match i { + None => x.as_ref(), + &Some(j) => x.get(j), + }, _ => unreachable!(), } } @@ -64,7 +90,7 @@ impl Val { &Self::Bool(x) => Self::Bool(x), Self::F64(_) => Self::F64(Cell::new(0.)), &Self::Fin(x) => Self::Fin(x), - Self::Ref(_) => unreachable!(), + Self::Ref(..) => unreachable!(), Self::Array(x) => Self::Array(collect_vals(x.iter().map(|x| x.zero()))), Self::Tuple(x) => Self::Tuple(collect_vals(x.iter().map(|x| x.zero()))), } @@ -192,14 +218,8 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { _ => unreachable!(), }, - &Expr::Slice { array, index } => match (self.get(array).inner(), self.get(index)) { - (Val::Array(v), &Val::Fin(i)) => v[i].clone(), - _ => unreachable!(), - }, - &Expr::Field { tuple, member } => match self.get(tuple).inner() { - Val::Tuple(x) => x[member.member()].clone(), - _ => unreachable!(), - }, + &Expr::Slice { array, index } => self.get(array).slice(self.get(index).fin()), + &Expr::Field { tuple, member } => self.get(tuple).slice(member.member()), &Expr::Unary { op, arg } => { let x = self.get(arg); @@ -257,8 +277,8 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { )) } - &Expr::Read { var } => Val::Ref(Rc::new(self.get(var).clone())), - &Expr::Accum { shape } => Val::Ref(Rc::new(self.get(shape).zero())), + &Expr::Read { var } => Val::Ref(Rc::new(self.get(var).clone()), None), + &Expr::Accum { shape } => Val::Ref(Rc::new(self.get(shape).zero()), None), &Expr::Ask { var } => self.get(var).inner().clone(), &Expr::Add { accum, addend } => { diff --git a/crates/transpose/Cargo.toml b/crates/transpose/Cargo.toml new file mode 100644 index 0000000..c122e43 --- /dev/null +++ b/crates/transpose/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "rose-transpose" +version = "0.0.0" +edition = "2021" + +[dependencies] +enumset = "1" +rose = { path = "../core" } diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs new file mode 100644 index 0000000..853a11c --- /dev/null +++ b/crates/transpose/src/lib.rs @@ -0,0 +1,869 @@ +use enumset::EnumSet; +use rose::{id, Binop, Constraint, Expr, Func, Instr, Ty, Unop}; +use std::mem::{swap, take}; + +const REAL: id::Ty = id::ty(0); +const DUAL: id::Ty = id::ty(1); + +fn is_primitive(t: id::Ty) -> bool { + t == REAL || t == DUAL +} + +enum Scope { + Generic(id::Generic), + Derived(id::Var), + Original, +} + +/// Type ID of a variable while building up the backward pass. +enum BwdTy { + /// We already know the type ID of this variable. + /// + /// Usually this means the variable is directly mapped from a variable in the original function. + Known(id::Ty), + + /// This variable has the unit type. + Unit, + + /// This variable is an accumulator. + /// + /// After we process the entire function, we need to add a scope type for every new accumulator + /// variable we've introduced; we don't add these as we go because we don't know ahead of time + /// how many types we'll need for intermediate values, and we want all those intermediate value + /// type IDs to be the same between the forward and backward passes. + Accum(Scope, id::Ty), + + /// We don't know the type of this variable yet, but will soon. + /// + /// Usually this means the variable is a tuple of intermediate values, and we'll update its type + /// to something concrete when we reach the end of the block. + Unknown, +} + +/// The source of a primitive variable. +#[derive(Clone, Copy)] +enum Src { + /// This variable is original. + Original, + + /// This variable is an alias of an original primitive variable of the same type. + Alias(id::Var), + + /// This variable is a component of an original dual number variable. + Projection(id::Var), +} + +impl Src { + fn prim(self, x: id::Var) -> Self { + match self { + Self::Original => Self::Alias(x), + _ => self, + } + } + + fn dual(self, x: id::Var) -> Self { + match self { + Self::Original => Self::Projection(x), + _ => self, + } + } +} + +/// Linear variables in the backward pass for a variable from the original function. +struct Lin { + /// Accumulator variable. + acc: id::Var, + + /// Resolved cotangent variable. + cot: id::Var, +} + +/// A block under construction for both the forward pass and the backward pass. +struct Block { + /// Instructions in this forward pass block, in order. + fwd: Vec, + + /// Variable IDs for intermediate values to be saved at the end of this forward pass block. + inter_mem: Vec, + + /// Variable ID for the intermediate values tuple in this backward pass block. + inter_tup: id::Var, + + /// Instructions at the beginning of this backward pass block, in order. + bwd_nonlin: Vec, + + /// Instructions at the end of this backward pass block, in reverse order. + bwd_lin: Vec, +} + +/// The forward pass and backward pass of a transposed function under construction. +struct Transpose<'a> { + /// The function being transposed, which is usually a forward-mode derivative. + f: &'a Func, + + /// Shared types between the forward and backward passes. + /// + /// This starts out the same length as `f.types`, with the only difference being that all dual + /// number types are replaced with `F64`. Then we add more types only for tuples and arrays of + /// intermediate values that are shared between the two passes. + types: Vec, + + /// Types of variables in the forward pass. + /// + /// This starts out as a clone of `f.vars`, but more variables can be added for dealing with + /// intermediate values. + fwd_vars: Vec, + + /// Types of variables in the backward pass. + /// + /// This starts out with the `Known` type of every variable from `f.vars`, but more variables + /// can be added for intermediate values, accumulators, and cotangents. + bwd_vars: Vec, + + /// A variable of type `F64` defined at the beginning of the backward pass. + /// + /// Every accumulator must be initialized using a concrete variable to dictate its topology. For + /// most variables, we keep around the original nonlinear value and use that as the shape, but + /// this doesn't work for raw linear `F64` variables, which might be part of some lower-level + /// mathematical calculation that is not clearly attached to any value at the dual number level + /// or higher. All those values have the same shape, though, so in those cases we just use this + /// one dummy variable as the shape. + real_shape: id::Var, + + /// Sources of primitive variables from the original function. + prims: Box<[Option]>, + + /// Sources for dual number variables from the original function. + duals: Box<[Option<(Src, Src)>]>, + + /// Accumulator variables for variables from the original function. + accums: Box<[Option]>, + + /// Cotangent variables for variables from the original function. + cotans: Box<[Option]>, + + /// The current block under construction. + block: Block, +} + +impl<'a> Transpose<'a> { + fn re(&self, x: id::Var) -> Src { + let (src, _) = self.duals[x.var()].unwrap(); + src.dual(x) + } + + fn du(&self, x: id::Var) -> Src { + let (_, src) = self.duals[x.var()].unwrap(); + src.dual(x) + } + + fn get_prim(&self, x: id::Var) -> id::Var { + match self.prims[x.var()].unwrap() { + Src::Original => x, + Src::Alias(y) | Src::Projection(y) => y, + } + } + + fn get_re(&self, x: id::Var) -> id::Var { + match self.duals[x.var()] { + None | Some((Src::Original, _)) => x, + Some((Src::Alias(y), _)) | Some((Src::Projection(y), _)) => y, + } + } + + fn get_du(&self, x: id::Var) -> id::Var { + match self.duals[x.var()] { + None | Some((_, Src::Original)) => x, + Some((_, Src::Alias(y))) | Some((_, Src::Projection(y))) => y, + } + } + + fn get_prim_accum(&self, x: id::Var) -> id::Var { + self.accums[self.get_prim(x).var()].unwrap() + } + + fn get_dual_accum(&self, x: id::Var) -> Option { + self.accums[self.get_du(x).var()] + } + + fn ty(&mut self, ty: Ty) -> id::Ty { + let t = id::ty(self.types.len()); + self.types.push(ty); + t + } + + fn fwd_var(&mut self, t: id::Ty) -> id::Var { + let var = id::var(self.fwd_vars.len()); + self.fwd_vars.push(t); + var + } + + fn bwd_var(&mut self, t: BwdTy) -> id::Var { + let var = id::var(self.bwd_vars.len()); + self.bwd_vars.push(t); + var + } + + fn keep(&mut self, var: id::Var) { + self.block.bwd_nonlin.push(Instr { + var, + expr: Expr::Member { + tuple: self.block.inter_tup, + member: id::member(self.block.inter_mem.len()), + }, + }); + self.block.inter_mem.push(var); + } + + fn accum(&mut self, shape: id::Var, scope: Scope) -> Lin { + let t = self.f.vars[shape.var()]; + let acc = self.bwd_var(BwdTy::Accum(scope, t)); + let cot = self.bwd_var(BwdTy::Known(t)); + self.block.bwd_nonlin.push(Instr { + var: acc, + expr: Expr::Accum { shape }, + }); + self.accums[shape.var()] = Some(acc); + self.cotans[shape.var()] = Some(cot); + Lin { acc, cot } + } + + fn calc(&mut self, tan: id::Var) -> Lin { + let t = self.f.vars[tan.var()]; + let acc = self.bwd_var(BwdTy::Accum(Scope::Original, t)); + let cot = self.bwd_var(BwdTy::Known(t)); + self.block.bwd_nonlin.push(Instr { + var: acc, + expr: Expr::Accum { + shape: self.real_shape, + }, + }); + self.accums[tan.var()] = Some(acc); + self.cotans[tan.var()] = Some(cot); + Lin { acc, cot } + } + + fn resolve(&mut self, lin: Lin) { + self.block.bwd_lin.push(Instr { + var: lin.cot, + expr: Expr::Resolve { var: lin.acc }, + }) + } + + fn block(&mut self, block: &[Instr]) -> (id::Ty, id::Var) { + for instr in block.iter() { + self.instr(instr.var, &instr.expr); + } + let vars = take(&mut self.block.inter_mem); + let t = self.ty(Ty::Tuple { + members: vars.iter().map(|&x| self.fwd_vars[x.var()]).collect(), + }); + let var = self.fwd_var(t); + self.block.fwd.push(Instr { + var, + expr: Expr::Tuple { + members: vars.into(), + }, + }); + self.bwd_vars[self.block.inter_tup.var()] = BwdTy::Known(t); + (t, var) + } + + fn instr(&mut self, var: id::Var, expr: &Expr) { + match expr { + Expr::Unit => { + self.block.fwd.push(Instr { + var, + expr: Expr::Unit, + }); + self.block.bwd_nonlin.push(Instr { + var, + expr: Expr::Unit, + }); + let lin = self.accum(var, Scope::Original); + self.resolve(lin); + } + &Expr::Bool { val } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Bool { val }, + }); + self.block.bwd_nonlin.push(Instr { + var, + expr: Expr::Bool { val }, + }); + let lin = self.accum(var, Scope::Original); + self.resolve(lin); + } + &Expr::F64 { val } => { + match self.f.vars[var.var()] { + DUAL => { + let lin = self.calc(var); + self.resolve(lin); + } + _ => self.block.fwd.push(Instr { + var, + expr: Expr::F64 { val }, + }), + } + self.prims[var.var()] = Some(Src::Original); + } + &Expr::Fin { val } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Fin { val }, + }); + self.block.bwd_nonlin.push(Instr { + var, + expr: Expr::Fin { val }, + }); + let lin = self.accum(var, Scope::Original); + self.resolve(lin); + } + + Expr::Array { elems } => { + let t = match self.f.types[self.f.vars[var.var()].ty()] { + Ty::Array { index, elem: _ } => index, + _ => panic!(), + }; + self.block.fwd.push(Instr { + var, + expr: Expr::Array { + elems: elems.iter().map(|&elem| self.get_re(elem)).collect(), + }, + }); + self.keep(var); + let lin = self.accum(var, Scope::Original); + for (i, &elem) in elems.iter().enumerate() { + if let Some(accum) = self.get_dual_accum(elem) { + let index = self.bwd_var(BwdTy::Known(t)); + let addend = self.bwd_var(BwdTy::Known(self.f.vars[elem.var()])); + let unit = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { accum, addend }, + }); + self.block.bwd_lin.push(Instr { + var: addend, + expr: Expr::Index { + array: lin.cot, + index, + }, + }); + self.block.bwd_lin.push(Instr { + var: index, + expr: Expr::Fin { val: i }, + }); + } + } + self.resolve(lin); + } + Expr::Tuple { members } => match self.types[self.f.vars[var.var()].ty()] { + Ty::F64 => { + let x = members[1]; + let dx = members[0]; + self.duals[var.var()] = Some(( + self.prims[x.var()].unwrap().prim(x), + self.prims[dx.var()].unwrap().prim(dx), + )); + } + _ => { + self.block.fwd.push(Instr { + var, + expr: Expr::Tuple { + members: members.iter().map(|&member| self.get_re(member)).collect(), + }, + }); + self.keep(var); + let lin = self.accum(var, Scope::Original); + for (i, &member) in members.iter().enumerate() { + if let Some(accum) = self.get_dual_accum(member) { + let addend = self.bwd_var(BwdTy::Known(self.f.vars[member.var()])); + let unit = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { accum, addend }, + }); + self.block.bwd_lin.push(Instr { + var: addend, + expr: Expr::Member { + tuple: lin.cot, + member: id::member(i), + }, + }); + } + } + self.resolve(lin); + } + }, + + &Expr::Index { array, index } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Index { array, index }, + }); + self.block.bwd_nonlin.push(Instr { + var, + expr: Expr::Index { array, index }, + }); + let arr_acc = self.accums[array.var()].unwrap(); + let acc = self.bwd_var(BwdTy::Accum( + Scope::Derived(arr_acc), + self.f.vars[array.var()], + )); + self.accums[var.var()] = Some(acc); + self.block.bwd_nonlin.push(Instr { + var: acc, + expr: Expr::Slice { + array: arr_acc, + index, + }, + }); + if let Ty::F64 = self.types[self.f.vars[var.var()].ty()] { + self.duals[var.var()] = Some((Src::Original, Src::Original)); + } + } + &Expr::Member { tuple, member } => { + let t = self.f.vars[var.var()]; + match t { + REAL => self.prims[var.var()] = Some(self.re(tuple)), + DUAL => self.prims[var.var()] = Some(self.du(tuple)), + _ => { + self.block.fwd.push(Instr { + var, + expr: Expr::Member { tuple, member }, + }); + self.block.bwd_nonlin.push(Instr { + var, + expr: Expr::Member { tuple, member }, + }); + let tup_acc = self.accums[tuple.var()].unwrap(); + let acc = self.bwd_var(BwdTy::Accum( + Scope::Derived(tup_acc), + self.f.vars[tuple.var()], + )); + self.accums[var.var()] = Some(acc); + self.block.bwd_nonlin.push(Instr { + var: acc, + expr: Expr::Field { + tuple: tup_acc, + member, + }, + }); + if let Ty::F64 = self.types[t.ty()] { + self.duals[var.var()] = Some((Src::Original, Src::Original)); + } + } + } + } + + &Expr::Slice { array, index } => todo!(), + &Expr::Field { tuple, member } => todo!(), + + &Expr::Unary { op, arg } => { + match self.f.vars[var.var()] { + DUAL => match op { + Unop::Not | Unop::Abs | Unop::Sign | Unop::Sqrt => panic!(), + Unop::Neg => { + let lin = self.calc(var); + let res = self.bwd_var(BwdTy::Known(DUAL)); + let unit = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { + accum: self.get_prim_accum(arg), + addend: res, + }, + }); + self.block.bwd_lin.push(Instr { + var: res, + expr: Expr::Unary { + op: Unop::Neg, + arg: lin.cot, + }, + }); + self.resolve(lin); + } + }, + _ => { + self.block.fwd.push(Instr { + var, + expr: Expr::Unary { + op, + arg: self.get_prim(arg), + }, + }); + self.keep(var); + } + } + self.prims[var.var()] = Some(Src::Original); + } + &Expr::Binary { op, left, right } => { + match self.f.vars[var.var()] { + DUAL => { + let lin = self.calc(var); + match op { + Binop::And + | Binop::Or + | Binop::Iff + | Binop::Xor + | Binop::Neq + | Binop::Lt + | Binop::Leq + | Binop::Eq + | Binop::Gt + | Binop::Geq => panic!(), + Binop::Add => { + let a = self.bwd_var(BwdTy::Unit); + let b = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: a, + expr: Expr::Add { + accum: self.get_prim_accum(left), + addend: lin.cot, + }, + }); + self.block.bwd_lin.push(Instr { + var: b, + expr: Expr::Add { + accum: self.get_prim_accum(right), + addend: lin.cot, + }, + }); + } + Binop::Sub => { + let res = self.bwd_var(BwdTy::Known(DUAL)); + let a = self.bwd_var(BwdTy::Unit); + let b = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: a, + expr: Expr::Add { + accum: self.get_prim_accum(left), + addend: lin.cot, + }, + }); + self.block.bwd_lin.push(Instr { + var: b, + expr: Expr::Add { + accum: self.get_prim_accum(right), + addend: res, + }, + }); + self.block.bwd_lin.push(Instr { + var: res, + expr: Expr::Unary { + op: Unop::Neg, + arg: lin.cot, + }, + }); + } + Binop::Mul | Binop::Div => { + let res = self.bwd_var(BwdTy::Known(DUAL)); + let unit = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { + accum: self.get_prim_accum(left), + addend: res, + }, + }); + self.block.bwd_lin.push(Instr { + var: res, + expr: Expr::Binary { + op, + left: lin.cot, + right: self.get_prim(right), + }, + }); + } + } + self.resolve(lin); + } + _ => { + let (a, b) = match op { + Binop::And | Binop::Or | Binop::Iff | Binop::Xor => (left, right), + Binop::Neq + | Binop::Lt + | Binop::Leq + | Binop::Eq + | Binop::Gt + | Binop::Geq + | Binop::Add + | Binop::Sub + | Binop::Mul + | Binop::Div => (self.get_prim(left), self.get_prim(right)), + }; + self.block.fwd.push(Instr { + var, + expr: Expr::Binary { + op, + left: a, + right: b, + }, + }); + self.keep(var); + } + } + self.prims[var.var()] = Some(Src::Original); + } + &Expr::Select { cond, then, els } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Select { + cond, + then: self.get_re(then), + els: self.get_re(els), + }, + }); + self.keep(var); + let lin = self.accum(var, Scope::Original); + let acc_then = self.get_dual_accum(then).unwrap(); + let acc_els = self.get_dual_accum(els).unwrap(); + // TODO: this scope is wrong; it actually needs to match both `then` and `els` + let acc = self.bwd_var(BwdTy::Accum(Scope::Original, self.f.vars[then.var()])); + let unit = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { + accum: acc, + addend: lin.cot, + }, + }); + self.block.bwd_lin.push(Instr { + var: acc, + expr: Expr::Select { + cond, + then: acc_then, + els: acc_els, + }, + }); + self.resolve(lin); + if let Ty::F64 = self.types[self.f.vars[var.var()].ty()] { + self.duals[var.var()] = Some((Src::Original, Src::Original)); + } + } + + Expr::Call { id, generics, args } => todo!(), + Expr::For { arg, body, ret } => { + let mut block = Block { + fwd: vec![], + inter_mem: vec![], + inter_tup: self.bwd_var(BwdTy::Unknown), + bwd_nonlin: vec![], + bwd_lin: vec![], + }; + swap(&mut self.block, &mut block); + self.block(body); // TODO + swap(&mut self.block, &mut block); + } + + &Expr::Read { var } => todo!(), + &Expr::Accum { shape } => todo!(), + + &Expr::Ask { var } => todo!(), + &Expr::Add { accum, addend } => todo!(), + + &Expr::Resolve { var } => todo!(), + } + } +} + +/// Return two functions that make up the transpose of `f`. +pub fn transpose(f: &Func) -> (Func, Func) { + let types: Vec<_> = f + .types + .iter() + .enumerate() + .map(|(i, ty)| match ty { + Ty::Unit => Ty::Unit, + Ty::Bool => Ty::Bool, + Ty::F64 => { + if !is_primitive(id::ty(i)) { + panic!() + } + Ty::F64 + } + &Ty::Fin { size } => Ty::Fin { size }, + &Ty::Generic { id } => Ty::Generic { id }, + &Ty::Scope { kind, id } => Ty::Scope { kind, id }, + &Ty::Ref { scope, inner } => { + if is_primitive(inner) { + panic!() + } + Ty::Ref { scope, inner } + } + &Ty::Array { index, elem } => { + if is_primitive(elem) { + panic!() + } + Ty::Array { index, elem } + } + Ty::Tuple { members } => { + if members.iter().any(|&t| is_primitive(t)) { + Ty::F64 + } else { + Ty::Tuple { + members: members.clone(), + } + } + } + }) + .collect(); + + let mut bwd_generics: Vec<_> = f + .generics + .iter() + .map(|&before| { + let mut after = before; + if before.contains(Constraint::Read) { + after.remove(Constraint::Read); + after.insert(Constraint::Accum); + } + if before.contains(Constraint::Accum) { + after.remove(Constraint::Accum); + after.insert(Constraint::Read); + } + after + }) + .collect(); + + let mut bwd_vars: Vec<_> = f.vars.iter().map(|&t| BwdTy::Known(t)).collect(); + let real_shape = id::var(bwd_vars.len()); + bwd_vars.push(BwdTy::Known(DUAL)); + let inter_tup = id::var(bwd_vars.len()); + bwd_vars.push(BwdTy::Unknown); + + let mut duals = vec![None; f.vars.len()].into_boxed_slice(); + let mut accums = vec![None; f.vars.len()].into_boxed_slice(); + + let mut inter_mem = vec![]; + let mut bwd_nonlin = vec![]; + + let mut bwd_params: Vec<_> = f + .params + .iter() + .map(|¶m| { + let g = id::generic(bwd_generics.len()); + bwd_generics.push(EnumSet::only(Constraint::Accum)); + let t = f.vars[param.var()]; + bwd_nonlin.push(Instr { + var: param, + expr: Expr::Member { + tuple: inter_tup, + member: id::member(inter_mem.len()), + }, + }); + inter_mem.push(param); + let acc = id::var(bwd_vars.len()); + bwd_vars.push(BwdTy::Accum(Scope::Generic(g), t)); + if let Ty::F64 = types[t.ty()] { + duals[param.var()] = Some((Src::Original, Src::Original)); + } + accums[param.var()] = Some(acc); + acc + }) + .collect(); + + let mut tp = Transpose { + f, + types, + fwd_vars: f.vars.to_vec(), + bwd_vars, + real_shape, + prims: vec![None; f.vars.len()].into(), + duals, + accums, + cotans: vec![None; f.vars.len()].into(), + block: Block { + fwd: vec![], + inter_tup, + inter_mem, + bwd_nonlin, + bwd_lin: vec![], + }, + }; + + let (t_intermediates, fwd_inter) = tp.block(&f.body); + let fwd_ret = tp.get_re(f.ret); + let bwd_acc = tp.get_dual_accum(f.ret).unwrap(); + let mut bwd_types = tp.types.clone(); + + let mut fwd_types = tp.types; + let t_bundle = id::ty(fwd_types.len()); + fwd_types.push(Ty::Tuple { + members: vec![f.vars[f.ret.var()], t_intermediates].into(), + }); + let mut fwd_vars = tp.fwd_vars; + let fwd_bundle = id::var(fwd_vars.len()); + fwd_vars.push(t_bundle); + let mut fwd_body = tp.block.fwd; + fwd_body.push(Instr { + var: fwd_bundle, + expr: Expr::Tuple { + members: vec![fwd_ret, fwd_inter].into(), + }, + }); + + let t_unit = id::ty(bwd_types.len()); + bwd_types.push(Ty::Unit); + let mut bwd_vars: Vec<_> = tp + .bwd_vars + .into_iter() + .enumerate() + .map(|(i, t)| match t { + BwdTy::Known(t) => t, + BwdTy::Unit => t_unit, + BwdTy::Accum(scope, inner) => { + let scope = id::ty(bwd_types.len()); + bwd_types.push(Ty::Scope { + kind: Constraint::Accum, + id: id::var(i), + }); + let t = id::ty(bwd_types.len()); + bwd_types.push(Ty::Ref { scope, inner }); + t + } + BwdTy::Unknown => panic!(), + }) + .collect(); + let bwd_cot = id::var(bwd_vars.len()); + bwd_vars.push(f.vars[f.ret.var()]); + let bwd_unit = id::var(bwd_vars.len()); + bwd_vars.push(t_unit); + bwd_params.push(bwd_cot); + bwd_params.push(tp.block.inter_tup); + let mut bwd_body = vec![Instr { + var: tp.real_shape, + expr: Expr::F64 { val: 0. }, + }]; + bwd_body.append(&mut tp.block.bwd_nonlin); + let mut bwd_linear = tp.block.bwd_lin; + bwd_linear.push(Instr { + var: bwd_unit, + expr: Expr::Add { + accum: bwd_acc, + addend: bwd_cot, + }, + }); + bwd_linear.reverse(); + bwd_body.append(&mut bwd_linear); + + ( + Func { + generics: f.generics.clone(), + types: fwd_types.into(), + vars: fwd_vars.into(), + params: f.params.clone(), + ret: fwd_bundle, + body: fwd_body.into(), + }, + Func { + generics: bwd_generics.into(), + types: bwd_types.into(), + vars: bwd_vars.into(), + params: bwd_params.into(), + ret: bwd_unit, + body: bwd_body.into(), + }, + ) +} diff --git a/crates/web/Cargo.toml b/crates/web/Cargo.toml index 853cd68..a4ee556 100644 --- a/crates/web/Cargo.toml +++ b/crates/web/Cargo.toml @@ -15,6 +15,7 @@ js-sys = "0.3" rose = { path = "../core" } rose-autodiff = { path = "../autodiff" } rose-interp = { path = "../interp", features = ["serde"] } +rose-transpose = { path = "../transpose" } serde = { version = "1", features = ["derive"] } serde-wasm-bindgen = "0.4" wasm-bindgen = "=0.2.87" # Must be this version of wbg diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index ec26335..3511125 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -40,6 +40,7 @@ pub fn layouts() -> Result { ("Func", layout::()), ("Instr", layout::()), ("Ty", layout::()), + ("Val", layout::()), ]) } @@ -124,6 +125,10 @@ struct Pointee { structs: Box<[Option>]>, jvp: RefCell>>, + + fwd: RefCell>>, + + bwd: RefCell>>, } /// A node in a reference-counted acyclic digraph of functions. @@ -149,6 +154,8 @@ impl Func { }, structs: [].into(), jvp: RefCell::new(None), + fwd: RefCell::new(None), + bwd: RefCell::new(None), }), } } @@ -261,6 +268,7 @@ impl Func { inner, structs, jvp, + .. } = self.rc.as_ref(); let mut cache = jvp.borrow_mut(); if let Some(rc) = cache.as_ref().and_then(|weak| weak.upgrade()) { @@ -286,6 +294,8 @@ impl Func { }, structs: structs_jvp.into(), jvp: RefCell::new(None), + fwd: RefCell::new(None), + bwd: RefCell::new(None), }) } Inner::Opaque { .. } => todo!(), @@ -293,6 +303,100 @@ impl Func { *cache = Some(Rc::downgrade(&rc)); Self { rc } } + + fn transpose_pair(&self) -> (Self, Self) { + let Pointee { + inner, + structs, + fwd, + bwd, + .. + } = self.rc.as_ref(); + let mut cache_fwd = fwd.borrow_mut(); + let mut cache_bwd = bwd.borrow_mut(); + if let (Some(rc_fwd), Some(rc_bwd)) = ( + cache_fwd.as_ref().and_then(|weak| weak.upgrade()), + cache_bwd.as_ref().and_then(|weak| weak.upgrade()), + ) { + return (Self { rc: rc_fwd }, Self { rc: rc_bwd }); + } + let (rc_fwd, rc_bwd) = match inner { + Inner::Transparent { deps, def } => { + let (deps_fwd, deps_bwd): (Vec<_>, Vec<_>) = + deps.iter().map(|f| f.transpose_pair()).unzip(); + let (def_fwd, def_bwd) = rose_transpose::transpose(def); + let structs_fwd = def_fwd + .types + .iter() + .enumerate() + .map(|(i, ty)| match ty { + rose::Ty::F64 => None, + _ => structs.get(i).cloned().flatten(), + }) + .collect(); + let structs_bwd = def_bwd + .types + .iter() + .enumerate() + .map(|(i, ty)| match ty { + rose::Ty::F64 => None, + _ => structs.get(i).cloned().flatten(), + }) + .collect(); + ( + Rc::new(Pointee { + inner: Inner::Transparent { + deps: deps_fwd.into(), + def: def_fwd, + }, + structs: structs_fwd, + jvp: RefCell::new(None), + fwd: RefCell::new(None), + bwd: RefCell::new(None), + }), + Rc::new(Pointee { + inner: Inner::Transparent { + deps: deps_bwd.into(), + def: def_bwd, + }, + structs: structs_bwd, + jvp: RefCell::new(None), + fwd: RefCell::new(None), + bwd: RefCell::new(None), + }), + ) + } + Inner::Opaque { .. } => panic!(), + }; + *cache_fwd = Some(Rc::downgrade(&rc_fwd)); + *cache_bwd = Some(Rc::downgrade(&rc_bwd)); + (Self { rc: rc_fwd }, Self { rc: rc_bwd }) + } + + pub fn transpose(&self) -> Transpose { + let (fwd, bwd) = self.transpose_pair(); + Transpose { + fwd: Some(fwd), + bwd: Some(bwd), + } + } +} + +#[wasm_bindgen] +pub struct Transpose { + fwd: Option, + bwd: Option, +} + +#[wasm_bindgen] +impl Transpose { + pub fn fwd(&mut self) -> Option { + self.fwd.take() + } + + pub fn bwd(&mut self) -> Option { + self.bwd.take() + } } #[cfg(feature = "debug")] @@ -498,6 +602,10 @@ enum Ty { Fin { size: usize, }, + Scope { + kind: rose::Constraint, + id: id::Var, + }, Ref { scope: id::Ty, inner: id::Ty, @@ -510,7 +618,7 @@ enum Ty { /// A tuple type, with additional information about key names that makes it into a struct. Struct { /// String IDs for key names, in order; the actual strings are stored in JavaScript. - keys: Box<[usize]>, + keys: Option>, /// Member types of the underlying tuple. Must be the same length as `keys`. members: Box<[id::Ty]>, @@ -525,9 +633,10 @@ impl Ty { Ty::Bool => (rose::Ty::Bool, None), Ty::F64 => (rose::Ty::F64, None), Ty::Fin { size } => (rose::Ty::Fin { size }, None), + Ty::Scope { kind, id } => (rose::Ty::Scope { kind, id }, None), Ty::Ref { scope, inner } => (rose::Ty::Ref { scope, inner }, None), Ty::Array { index, elem } => (rose::Ty::Array { index, elem }, None), - Ty::Struct { keys, members } => (rose::Ty::Tuple { members }, Some(keys)), + Ty::Struct { keys, members } => (rose::Ty::Tuple { members }, keys), } } } @@ -617,6 +726,8 @@ impl FuncBuilder { }, structs: structs.into(), jvp: RefCell::new(None), + fwd: RefCell::new(None), + bwd: RefCell::new(None), }), } } @@ -720,7 +831,10 @@ impl FuncBuilder { /// `Err` if `t` is out of range or does not represent a struct type. pub fn keys(&self, t: usize) -> Result, JsError> { match self.ty(t)? { - Ty::Struct { keys, members: _ } => Ok(keys.clone()), + Ty::Struct { + keys: Some(keys), + members: _, + } => Ok(keys.clone()), _ => Err(JsError::new("type is not a struct")), } } @@ -798,6 +912,13 @@ impl FuncBuilder { ) } + #[wasm_bindgen(js_name = "tyAccum")] + pub fn ty_accum(&mut self, id: usize) -> usize { + let kind = rose::Constraint::Accum; + let id = id::var(id); + self.newtype(Ty::Scope { kind, id }, EnumSet::only(kind)) + } + /// Return the ID for the type of arrays with index type `index` and element type `elem`, /// /// Assumes `index` and `elem` are valid type IDs. @@ -825,7 +946,7 @@ impl FuncBuilder { pub fn ty_struct(&mut self, keys: &[usize], mems: &[usize]) -> usize { self.newtype( Ty::Struct { - keys: keys.into(), + keys: Some(keys.into()), members: mems.iter().map(|&t| id::ty(t)).collect(), }, EnumSet::only(rose::Constraint::Value), @@ -1010,10 +1131,7 @@ impl FuncBuilder { Ty::Struct { keys: structs[t] .as_ref() - .unwrap() - .iter() - .map(|&s| strings[s]) - .collect(), + .map(|ss| ss.iter().map(|&s| strings[s]).collect()), members: members .iter() .map(|x| types[x.ty()]) @@ -1439,4 +1557,17 @@ impl Block { }; self.instr(f, id::ty(t), expr) } + + pub fn accum(&mut self, f: &mut FuncBuilder, inner: usize, shape: usize) -> usize { + let inner = id::ty(inner); + let shape = id::var(shape); + let scope = id::ty(f.ty_accum(f.vars.len())); + let t = id::ty(f.newtype(Ty::Ref { scope, inner }, EnumSet::empty())); + self.instr(f, t, rose::Expr::Accum { shape }) + } + + pub fn resolve(&mut self, f: &mut FuncBuilder, t: usize, var: usize) -> usize { + let expr = rose::Expr::Resolve { var: id::var(var) }; + self.instr(f, id::ty(t), expr) + } } diff --git a/packages/core/src/impl.test.ts b/packages/core/src/impl.test.ts index 0aa4727..b8686cf 100644 --- a/packages/core/src/impl.test.ts +++ b/packages/core/src/impl.test.ts @@ -39,6 +39,7 @@ test("core IR type layouts", () => { Func: { size: 44, align: 4 }, Instr: { size: 32, align: 8 }, Ty: { size: 12, align: 4 }, + Val: { size: 16, align: 8 }, }); }); diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 2da93cf..10bd300 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -595,6 +595,51 @@ export const jvp = ( return g; }; +export const vjp = ( + f: Fn & ((arg: A) => R), +): ((arg: A) => { ret: R; grad: (cot: R) => A }) => { + const g = jvp(f); + const tp = g[inner].transpose(); + const fwdFunc = tp.fwd()!; + const bwdFunc = tp.bwd()!; + const fwd: Fn = { [inner]: fwdFunc, [strings]: [...f[strings]] }; + const bwd: Fn = { [inner]: bwdFunc, [strings]: [...f[strings]] }; + funcs.register(fwd, fwdFunc); + funcs.register(bwd, bwdFunc); + return (arg: A) => { + const ctx = getCtx(); + const strs = intern(ctx, fwd[strings]); + const generics = new Uint32Array(); // TODO: support generics + const [tArg, tBundle] = ctx.func.ingest(fwd[inner], strs, generics); + const [tRet, tInter] = ctx.func.members(tBundle); + const argId = valId(ctx, tArg, arg); + const bundleId = ctx.block.call( + ctx.func, + fwd[inner], + generics, + tBundle, + new Uint32Array([argId]), + ); + const primalId = ctx.block.member(ctx.func, tRet, bundleId, 0); + const interId = ctx.block.member(ctx.func, tInter, bundleId, 1); + const grad = (cot: R) => { + if (getCtx() !== ctx) throw Error("VJP closure escaped its context"); + const cotId = valId(ctx, tRet, cot); + const accId = ctx.block.accum(ctx.func, tArg, argId); + const tScope = ctx.func.tyAccum(accId); + ctx.block.call( + ctx.func, + bwd[inner], + new Uint32Array([tScope]), // TODO: support generics + ctx.func.tyUnit(), + new Uint32Array([accId, cotId, interId]), + ); + return idVal(ctx, tArg, ctx.block.resolve(ctx.func, tArg, accId)) as A; + }; + return { ret: idVal(ctx, tRet, primalId) as R, grad }; + }; +}; + /** Return the variable ID for the abstract boolean `x`. */ const boolId = (ctx: Context, x: Bool): number => valId(ctx, ctx.func.tyBool(), x); diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index c99fc20..6e105ba 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -12,11 +12,13 @@ import { interp, jvp, mul, + or, select, sign, sqrt, sub, vec, + vjp, } from "./index.js"; describe("invalid", () => { @@ -345,10 +347,56 @@ describe("valid", () => { test("JVP with sharing in call graph", () => { let f = fn([Real], Real, (x) => x); - for (let i = 0; i < 20; i++) { + for (let i = 0; i < 20; ++i) { f = fn([Real], Real, (x) => add(f(x), f(x))); } const g = interp(jvp(f)); expect(g({ re: 2, du: 3 })).toEqual({ re: 2097152, du: 3145728 }); }); + + test("VJP", () => { + const f = fn([Vec(2, Real)], Real, (v) => mul(v[0], v[1])); + const g = fn([], Vec(3, Real), () => { + const { ret: x, grad } = vjp(f)([2, 3]); + const v = grad(1); + return [x, v[0], v[1]]; + }); + expect(interp(g)()).toEqual([6, 3, 2]); + }); + + test("VJP with struct and select", () => { + const Stuff = { a: Null, b: Bool, c: Real } as const; + const f = fn([Stuff], Real, ({ b, c }) => select(or(false, b), Real, c, 2)); + const g = fn([Bool, Real], { x: Real, stuff: Stuff }, (b, c) => { + const { ret: x, grad } = vjp(f)({ a: null, b, c }); + return { x, stuff: grad(3) }; + }); + const h = interp(g); + expect(h(true, 5)).toEqual({ x: 5, stuff: { a: null, b: true, c: 3 } }); + expect(h(false, 7)).toEqual({ x: 2, stuff: { a: null, b: false, c: 0 } }); + }); + + test("VJP with select on null", () => { + const f = fn([Null], Null, () => select(true, Null, null, null)); + const g = fn([], Null, () => vjp(f)(null).ret); + const h = interp(g); + expect(h()).toBe(null); + }); + + test("VJP with select on booleans", () => { + const f = fn([Bool], Bool, (p) => select(p, Bool, false, true)); + const g = fn([Bool], Bool, (p) => vjp(f)(p).ret); + const h = interp(g); + expect(h(true)).toBe(false); + expect(h(false)).toBe(true); + }); + + test("VJP with select on indices", () => { + const n = 2; + const f = fn([Bool], n, (p) => select(p, n, 0, 1)); + const g = fn([Bool], n, (p) => vjp(f)(p).ret); + const h = interp(g); + expect(h(true)).toBe(0); + expect(h(false)).toBe(1); + }); }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 602c51b..fad6ed2 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -27,5 +27,6 @@ export { sqrt, sub, vec, + vjp, xor, } from "./impl.js";