From 9377ec42a10fb66bb1f143df6a22bdf5dea07a9a Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Thu, 7 Sep 2023 11:24:53 -0400 Subject: [PATCH 01/11] Transpose functions --- Cargo.lock | 9 ++ crates/transpose/Cargo.toml | 8 + crates/transpose/src/lib.rs | 286 ++++++++++++++++++++++++++++++++++++ crates/web/Cargo.toml | 1 + 4 files changed, 304 insertions(+) create mode 100644 crates/transpose/Cargo.toml create mode 100644 crates/transpose/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 59c6c6f..863541f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -187,6 +187,14 @@ dependencies = [ "ts-rs", ] +[[package]] +name = "rose-transpose" +version = "0.0.0" +dependencies = [ + "indexmap", + "rose", +] + [[package]] name = "rose-web" version = "0.0.0" @@ -198,6 +206,7 @@ dependencies = [ "rose", "rose-autodiff", "rose-interp", + "rose-transpose", "serde", "serde-wasm-bindgen", "wasm-bindgen", diff --git a/crates/transpose/Cargo.toml b/crates/transpose/Cargo.toml new file mode 100644 index 0000000..b6d2f45 --- /dev/null +++ b/crates/transpose/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "rose-transpose" +version = "0.0.0" +edition = "2021" + +[dependencies] +indexmap = "2" +rose = { path = "../core" } diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs new file mode 100644 index 0000000..0b064c8 --- /dev/null +++ b/crates/transpose/src/lib.rs @@ -0,0 +1,286 @@ +use rose::{id, Binop, Constraint, Expr, Func, Instr, Ty, Unop}; + +struct Transpose<'a> { + f: &'a Func, + linear: id::Ty, + fwd_block: Vec, + intermediate_tuple: id::Var, + intermediate_member: id::Member, + bwd_nonlinear: Vec, + bwd_linear: Vec, +} + +impl Transpose<'_> { + fn block(&mut self, block: &[Instr]) { + let mut members = vec![]; + for instr in block.iter() { + let var = instr.var; + self.instr(var, &instr.expr); + let t = self.f.vars[var.var()]; + if t != self.linear && self.f.types[t.ty()] == Ty::F64 { + members.push(var); + self.bwd_nonlinear.push(Instr { + var, + expr: Expr::Member { + tuple: self.intermediate_tuple, + member: self.intermediate_member, + }, + }); + self.intermediate_member = id::member(self.intermediate_member.member() + 1); + } + } + self.fwd_block.push(Instr { + var: self.intermediate_tuple, + expr: Expr::Tuple { + members: members.into(), + }, + }); + } + + fn instr(&mut self, var: id::Var, expr: &Expr) { + match expr { + Expr::Unit => self.fwd_block.push(Instr { + var, + expr: Expr::Unit, + }), + &Expr::Bool { val } => self.fwd_block.push(Instr { + var, + expr: Expr::Bool { val }, + }), + &Expr::F64 { val } => { + if self.f.vars[var.var()] == self.linear { + self.bwd_linear.push(Instr { + var, + expr: Expr::F64 { val }, + }) + } else { + self.fwd_block.push(Instr { + var, + expr: Expr::F64 { val }, + }) + } + } + &Expr::Fin { val } => self.fwd_block.push(Instr { + var, + expr: Expr::Fin { val }, + }), + + Expr::Array { elems } => todo!(), + Expr::Tuple { members } => todo!(), + + Expr::Index { array, index } => todo!(), + Expr::Member { tuple, member } => todo!(), + + Expr::Slice { array, index } => todo!(), + Expr::Field { tuple, member } => todo!(), + + Expr::Unary { op, arg } => todo!(), + Expr::Binary { op, left, right } => todo!(), + Expr::Select { cond, then, els } => todo!(), + + Expr::Call { id, generics, args } => todo!(), + Expr::For { arg, body, ret } => todo!(), + Expr::Read { + var, + arg, + body, + ret, + } => todo!(), + Expr::Accum { + shape, + arg, + body, + ret, + } => todo!(), + + Expr::Ask { var } => todo!(), + Expr::Add { accum, addend } => todo!(), + } + } +} + +/// Return two functions that make up the transpose of `f`. +/// +/// `linear` must be the type index of an `F64` type in `f`. +pub fn transpose(f: &Func, linear: id::Ty) -> (Func, Func) { + let mut tp = Transpose { + f, + linear, + fwd_block: vec![], + intermediate_tuple: id::var(0), // TODO + intermediate_member: id::member(0), // TODO + bwd_nonlinear: vec![], + bwd_linear: vec![], + }; + + tp.block(&f.body); + + let t_f64 = id::ty(f.types.len()); + + let mut fwd_types = f.types.to_vec(); + fwd_types[linear.ty()] = Ty::Unit; + fwd_types.push(Ty::F64); + + let mut bwd_types = f.types.to_vec(); + bwd_types.push(Ty::F64); + + let mut bwd_params: Vec<_> = f + .params + .iter() + .map(|param| match &f.types[f.vars[param.var()].ty()] { + Ty::Unit => todo!(), + Ty::Bool => todo!(), + Ty::F64 => todo!(), + Ty::Fin { size } => todo!(), + Ty::Generic { id } => todo!(), + Ty::Scope { kind, id } => todo!(), + Ty::Ref { scope, inner } => todo!(), + Ty::Array { index, elem } => todo!(), + Ty::Tuple { members } => todo!(), + }) + .collect(); + + let mut fwd_body = vec![]; + let mut bwd_body = vec![]; + + let mut intermediates = vec![]; + + for instr in f.body.iter() { + let var = instr.var; + let t = f.vars[var.var()]; + match &instr.expr { + Expr::Unit => fwd_body.push(Instr { + var, + expr: Expr::Unit, + }), + &Expr::Bool { val } => fwd_body.push(Instr { + var, + expr: Expr::Bool { val }, + }), + &Expr::F64 { val } => { + if t == linear { + bwd_body.push(Instr { + var, + expr: Expr::F64 { val }, + }) + } else { + fwd_body.push(Instr { + var, + expr: Expr::F64 { val }, + }) + } + } + &Expr::Fin { val } => fwd_body.push(Instr { + var, + expr: Expr::Fin { val }, + }), + + Expr::Array { .. } => todo!(), + Expr::Tuple { members } => todo!(), + + Expr::Index { .. } => todo!(), + Expr::Member { tuple, member } => todo!(), + + Expr::Slice { .. } => todo!(), + Expr::Field { .. } => todo!(), + + &Expr::Unary { op, arg } => match op { + Unop::Not => fwd_body.push(Instr { + var, + expr: Expr::Unary { op: Unop::Not, arg }, + }), + + Unop::Neg => { + if f.vars[instr.var.var()] == linear { + bwd_body.push(Instr { + var, + expr: Expr::Unary { op: Unop::Neg, arg }, + }) + } else { + fwd_body.push(Instr { + var, + expr: Expr::Unary { op: Unop::Neg, arg }, + }) + } + } + Unop::Abs => todo!(), + Unop::Sign => todo!(), + Unop::Sqrt => todo!(), + }, + &Expr::Binary { op, left, right } => match op { + Binop::And => todo!(), + Binop::Or => todo!(), + Binop::Iff => todo!(), + Binop::Xor => todo!(), + + Binop::Neq => todo!(), + Binop::Lt => todo!(), + Binop::Leq => todo!(), + Binop::Eq => todo!(), + Binop::Gt => todo!(), + Binop::Geq => todo!(), + + Binop::Add => todo!(), + Binop::Sub => todo!(), + Binop::Mul => todo!(), + Binop::Div => todo!(), + }, + Expr::Select { .. } => todo!(), + + Expr::Call { .. } => todo!(), + Expr::For { .. } => todo!(), + Expr::Read { .. } => todo!(), + Expr::Accum { .. } => todo!(), + + Expr::Ask { .. } => todo!(), + Expr::Add { .. } => todo!(), + } + if t != linear && f.types[t.ty()] == Ty::F64 { + fwd_body.push(Instr { + var: id::var(fwd_types.len() - 1), + expr: Expr::F64 { val: 1.0 }, + }); + } + } + + let t_tuple = fwd_types.len(); // or `bwd_types.len()`, shouldn't matter + let members = intermediates.into_boxed_slice(); + fwd_types.push(Ty::Tuple { + members: members.clone(), + }); + bwd_types.push(Ty::Tuple { members }); + + ( + Func { + generics: f.generics.clone(), + types: fwd_types.into(), + vars: [].into(), // TODO + params: f.params.clone(), + ret: id::var(0), // TODO + body: fwd_body.into(), + }, + Func { + generics: f + .generics + .iter() + .map(|&before| { + let mut after = before; + if before.contains(Constraint::Read) { + after.remove(Constraint::Read); + after.insert(Constraint::Accum); + } + if before.contains(Constraint::Accum) { + after.remove(Constraint::Accum); + after.insert(Constraint::Read); + } + after + }) + .collect(), + types: bwd_types.into(), + vars: [].into(), // TODO + params: [].into(), // TODO + ret: id::var(0), // TODO + body: bwd_body.into(), + }, + ) +} diff --git a/crates/web/Cargo.toml b/crates/web/Cargo.toml index 853cd68..a4ee556 100644 --- a/crates/web/Cargo.toml +++ b/crates/web/Cargo.toml @@ -15,6 +15,7 @@ js-sys = "0.3" rose = { path = "../core" } rose-autodiff = { path = "../autodiff" } rose-interp = { path = "../interp", features = ["serde"] } +rose-transpose = { path = "../transpose" } serde = { version = "1", features = ["derive"] } serde-wasm-bindgen = "0.4" wasm-bindgen = "=0.2.87" # Must be this version of wbg From 7a17a03da50a6b0298942b97c2de91a7ff9cda99 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Thu, 7 Sep 2023 11:51:25 -0400 Subject: [PATCH 02/11] Fix build --- crates/transpose/src/lib.rs | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index 0b064c8..d544ecd 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -80,21 +80,14 @@ impl Transpose<'_> { Expr::Call { id, generics, args } => todo!(), Expr::For { arg, body, ret } => todo!(), - Expr::Read { - var, - arg, - body, - ret, - } => todo!(), - Expr::Accum { - shape, - arg, - body, - ret, - } => todo!(), + + Expr::Read { var } => todo!(), + Expr::Accum { shape } => todo!(), Expr::Ask { var } => todo!(), Expr::Add { accum, addend } => todo!(), + + Expr::Resolve { var } => todo!(), } } } @@ -229,11 +222,14 @@ pub fn transpose(f: &Func, linear: id::Ty) -> (Func, Func) { Expr::Call { .. } => todo!(), Expr::For { .. } => todo!(), + Expr::Read { .. } => todo!(), Expr::Accum { .. } => todo!(), Expr::Ask { .. } => todo!(), Expr::Add { .. } => todo!(), + + Expr::Resolve { .. } => todo!(), } if t != linear && f.types[t.ty()] == Ty::F64 { fwd_body.push(Instr { From 411e1c6d81ae963ee5181b89f954202effa74bf9 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Thu, 7 Sep 2023 21:53:36 -0400 Subject: [PATCH 03/11] Implement something of a skeleton --- Cargo.lock | 1 - crates/transpose/Cargo.toml | 1 - crates/transpose/src/lib.rs | 590 ++++++++++++++++++++++++------------ 3 files changed, 391 insertions(+), 201 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 863541f..d452993 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -191,7 +191,6 @@ dependencies = [ name = "rose-transpose" version = "0.0.0" dependencies = [ - "indexmap", "rose", ] diff --git a/crates/transpose/Cargo.toml b/crates/transpose/Cargo.toml index b6d2f45..e35fdcc 100644 --- a/crates/transpose/Cargo.toml +++ b/crates/transpose/Cargo.toml @@ -4,5 +4,4 @@ version = "0.0.0" edition = "2021" [dependencies] -indexmap = "2" rose = { path = "../core" } diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index d544ecd..5e86ee1 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -1,40 +1,111 @@ use rose::{id, Binop, Constraint, Expr, Func, Instr, Ty, Unop}; +const REAL: id::Ty = id::ty(0); +const DUAL: id::Ty = id::ty(1); + +fn is_primitive(t: id::Ty) -> bool { + t == REAL || t == DUAL +} + +enum BwdTy { + Known(id::Ty), + Unit, + Fin(usize), + Accum, +} + +struct Lin { + acc: id::Var, + cot: id::Var, +} + struct Transpose<'a> { f: &'a Func, - linear: id::Ty, + types: Vec, + fwd_vars: Vec, fwd_block: Vec, - intermediate_tuple: id::Var, - intermediate_member: id::Member, + intermediates_tuple: id::Var, + intermediate_members: Vec, + bwd_vars: Vec, + real_shape: id::Var, + reals: Box<[Option]>, + duals: Box<[Option]>, + accums: Box<[Option]>, + cotangents: Box<[Option]>, bwd_nonlinear: Vec, bwd_linear: Vec, } impl Transpose<'_> { - fn block(&mut self, block: &[Instr]) { - let mut members = vec![]; + fn ty(&mut self, ty: Ty) -> id::Ty { + let t = id::ty(self.types.len()); + self.types.push(ty); + t + } + + fn keep(&mut self, var: id::Var) { + self.bwd_nonlinear.push(Instr { + var, + expr: Expr::Member { + tuple: self.intermediates_tuple, + member: id::member(self.intermediate_members.len()), + }, + }); + self.intermediate_members.push(var); + } + + fn bwd_var(&mut self, t: BwdTy) -> id::Var { + let var = id::var(self.bwd_vars.len()); + self.bwd_vars.push(t); + var + } + + fn accum(&mut self, shape: id::Var) -> Lin { + let acc = self.bwd_var(BwdTy::Accum); + let cot = self.bwd_var(BwdTy::Known(self.f.vars[shape.var()])); + self.bwd_nonlinear.push(Instr { + var: acc, + expr: Expr::Accum { shape }, + }); + self.accums[shape.var()] = Some(acc); + self.cotangents[shape.var()] = Some(cot); + Lin { acc, cot } + } + + fn calc(&mut self, tan: id::Var) -> Lin { + let acc = self.bwd_var(BwdTy::Accum); + let cot = self.bwd_var(BwdTy::Known(DUAL)); + self.bwd_nonlinear.push(Instr { + var: acc, + expr: Expr::Accum { + shape: self.real_shape, + }, + }); + self.accums[tan.var()] = Some(acc); + self.cotangents[tan.var()] = Some(cot); + Lin { acc, cot } + } + + fn resolve(&mut self, lin: Lin) { + self.bwd_linear.push(Instr { + var: lin.cot, + expr: Expr::Resolve { var: lin.acc }, + }) + } + + fn block(&mut self, block: &[Instr]) -> id::Ty { for instr in block.iter() { - let var = instr.var; - self.instr(var, &instr.expr); - let t = self.f.vars[var.var()]; - if t != self.linear && self.f.types[t.ty()] == Ty::F64 { - members.push(var); - self.bwd_nonlinear.push(Instr { - var, - expr: Expr::Member { - tuple: self.intermediate_tuple, - member: self.intermediate_member, - }, - }); - self.intermediate_member = id::member(self.intermediate_member.member() + 1); - } + self.instr(instr.var, &instr.expr); } + let vars = std::mem::take(&mut self.intermediate_members); + let types = vars.iter().map(|&x| self.fwd_vars[x.var()]).collect(); self.fwd_block.push(Instr { - var: self.intermediate_tuple, + var: self.intermediates_tuple, expr: Expr::Tuple { - members: members.into(), + members: vars.into(), }, }); + self.ty(Ty::Tuple { members: types }) } fn instr(&mut self, var: id::Var, expr: &Expr) { @@ -47,212 +118,333 @@ impl Transpose<'_> { var, expr: Expr::Bool { val }, }), - &Expr::F64 { val } => { - if self.f.vars[var.var()] == self.linear { - self.bwd_linear.push(Instr { - var, - expr: Expr::F64 { val }, - }) - } else { - self.fwd_block.push(Instr { - var, - expr: Expr::F64 { val }, - }) - } + &Expr::F64 { val } => todo!(), + &Expr::Fin { val } => { + self.fwd_block.push(Instr { + var, + expr: Expr::Fin { val }, + }); + self.bwd_nonlinear.push(Instr { + var, + expr: Expr::Fin { val }, + }); } - &Expr::Fin { val } => self.fwd_block.push(Instr { - var, - expr: Expr::Fin { val }, - }), - - Expr::Array { elems } => todo!(), - Expr::Tuple { members } => todo!(), - Expr::Index { array, index } => todo!(), - Expr::Member { tuple, member } => todo!(), + Expr::Array { elems } => { + self.fwd_block.push(Instr { + var, + expr: Expr::Array { + elems: elems.clone(), + }, + }); + self.keep(var); + let lin = self.accum(var); + for (i, &elem) in elems.iter().enumerate() { + if let Some(accum) = self.accums[elem.var()] { + let index = self.bwd_var(BwdTy::Fin(elems.len())); + let addend = self.bwd_var(BwdTy::Known(self.f.vars[elem.var()])); + let unit = self.bwd_var(BwdTy::Unit); + self.bwd_linear.push(Instr { + var: unit, + expr: Expr::Add { accum, addend }, + }); + self.bwd_linear.push(Instr { + var: addend, + expr: Expr::Index { + array: lin.cot, + index, + }, + }); + self.bwd_linear.push(Instr { + var: index, + expr: Expr::Fin { val: i }, + }); + } + } + self.resolve(lin); + } + Expr::Tuple { members } => { + self.fwd_block.push(Instr { + var, + expr: Expr::Tuple { + members: members.clone(), + }, + }); + self.keep(var); + let lin = self.accum(var); + for (i, &member) in members.iter().enumerate() { + if let Some(accum) = self.accums[member.var()] { + let addend = self.bwd_var(BwdTy::Known(self.f.vars[member.var()])); + let unit = self.bwd_var(BwdTy::Unit); + self.bwd_linear.push(Instr { + var: unit, + expr: Expr::Add { accum, addend }, + }); + self.bwd_linear.push(Instr { + var: addend, + expr: Expr::Member { + tuple: lin.cot, + member: id::member(i), + }, + }); + } + } + self.resolve(lin); + } - Expr::Slice { array, index } => todo!(), - Expr::Field { tuple, member } => todo!(), + &Expr::Index { array, index } => { + self.fwd_block.push(Instr { + var, + expr: Expr::Index { array, index }, + }); + self.bwd_nonlinear.push(Instr { + var, + expr: Expr::Index { array, index }, + }); + let acc = self.bwd_var(BwdTy::Accum); + self.accums[var.var()] = Some(acc); + self.bwd_nonlinear.push(Instr { + var: acc, + expr: Expr::Slice { + array: self.accums[array.var()].unwrap(), + index, + }, + }); + } + &Expr::Member { tuple, member } => match self.f.vars[var.var()] { + REAL => self.reals[var.var()] = Some(tuple), + DUAL => self.duals[var.var()] = Some(tuple), + _ => { + self.fwd_block.push(Instr { + var, + expr: Expr::Member { tuple, member }, + }); + self.bwd_nonlinear.push(Instr { + var, + expr: Expr::Member { tuple, member }, + }); + let acc = self.bwd_var(BwdTy::Accum); + self.accums[var.var()] = Some(acc); + self.bwd_nonlinear.push(Instr { + var: acc, + expr: Expr::Field { + tuple: self.accums[tuple.var()].unwrap(), + member, + }, + }); + } + }, - Expr::Unary { op, arg } => todo!(), - Expr::Binary { op, left, right } => todo!(), - Expr::Select { cond, then, els } => todo!(), + &Expr::Slice { array, index } => todo!(), + &Expr::Field { tuple, member } => todo!(), + + &Expr::Unary { op, arg } => match self.f.vars[var.var()] { + DUAL => match op { + Unop::Not | Unop::Abs | Unop::Sign | Unop::Sqrt => panic!(), + Unop::Neg => { + let lin = self.calc(var); + let res = self.bwd_var(BwdTy::Known(DUAL)); + let unit = self.bwd_var(BwdTy::Unit); + self.bwd_linear.push(Instr { + var: unit, + expr: Expr::Add { + accum: self.accums[arg.var()].unwrap(), + addend: res, + }, + }); + self.bwd_linear.push(Instr { + var: res, + expr: Expr::Unary { + op: Unop::Neg, + arg: lin.cot, + }, + }); + self.resolve(lin); + } + }, + _ => { + self.fwd_block.push(Instr { + var, + expr: Expr::Unary { op, arg }, + }); + self.keep(var); + } + }, + &Expr::Binary { op, left, right } => match self.f.vars[var.var()] { + DUAL => { + let lin = self.calc(var); + match op { + Binop::And + | Binop::Or + | Binop::Iff + | Binop::Xor + | Binop::Neq + | Binop::Lt + | Binop::Leq + | Binop::Eq + | Binop::Gt + | Binop::Geq => panic!(), + Binop::Add => { + let a = self.bwd_var(BwdTy::Unit); + let b = self.bwd_var(BwdTy::Unit); + self.bwd_linear.push(Instr { + var: a, + expr: Expr::Add { + accum: self.accums[left.var()].unwrap(), + addend: lin.cot, + }, + }); + self.bwd_linear.push(Instr { + var: b, + expr: Expr::Add { + accum: self.accums[right.var()].unwrap(), + addend: lin.cot, + }, + }); + } + Binop::Sub => { + let res = self.bwd_var(BwdTy::Known(DUAL)); + let a = self.bwd_var(BwdTy::Unit); + let b = self.bwd_var(BwdTy::Unit); + self.bwd_linear.push(Instr { + var: a, + expr: Expr::Add { + accum: self.accums[left.var()].unwrap(), + addend: lin.cot, + }, + }); + self.bwd_linear.push(Instr { + var: b, + expr: Expr::Add { + accum: self.accums[right.var()].unwrap(), + addend: res, + }, + }); + self.bwd_linear.push(Instr { + var: res, + expr: Expr::Unary { + op: Unop::Neg, + arg: lin.cot, + }, + }); + } + Binop::Mul | Binop::Div => { + let res = self.bwd_var(BwdTy::Known(DUAL)); + let unit = self.bwd_var(BwdTy::Unit); + self.bwd_linear.push(Instr { + var: unit, + expr: Expr::Add { + accum: self.accums[left.var()].unwrap(), + addend: res, + }, + }); + self.bwd_linear.push(Instr { + var: res, + expr: Expr::Binary { + op, + left: lin.cot, + right, + }, + }); + } + } + self.resolve(lin); + } + _ => { + self.fwd_block.push(Instr { + var, + expr: Expr::Binary { op, left, right }, + }); + self.keep(var); + } + }, + &Expr::Select { cond, then, els } => todo!(), Expr::Call { id, generics, args } => todo!(), Expr::For { arg, body, ret } => todo!(), - Expr::Read { var } => todo!(), - Expr::Accum { shape } => todo!(), + &Expr::Read { var } => todo!(), + &Expr::Accum { shape } => todo!(), - Expr::Ask { var } => todo!(), - Expr::Add { accum, addend } => todo!(), + &Expr::Ask { var } => todo!(), + &Expr::Add { accum, addend } => todo!(), - Expr::Resolve { var } => todo!(), + &Expr::Resolve { var } => todo!(), } } } /// Return two functions that make up the transpose of `f`. -/// -/// `linear` must be the type index of an `F64` type in `f`. -pub fn transpose(f: &Func, linear: id::Ty) -> (Func, Func) { +pub fn transpose(f: &Func) -> (Func, Func) { let mut tp = Transpose { f, - linear, - fwd_block: vec![], - intermediate_tuple: id::var(0), // TODO - intermediate_member: id::member(0), // TODO - bwd_nonlinear: vec![], - bwd_linear: vec![], - }; - - tp.block(&f.body); - - let t_f64 = id::ty(f.types.len()); - - let mut fwd_types = f.types.to_vec(); - fwd_types[linear.ty()] = Ty::Unit; - fwd_types.push(Ty::F64); - - let mut bwd_types = f.types.to_vec(); - bwd_types.push(Ty::F64); - - let mut bwd_params: Vec<_> = f - .params - .iter() - .map(|param| match &f.types[f.vars[param.var()].ty()] { - Ty::Unit => todo!(), - Ty::Bool => todo!(), - Ty::F64 => todo!(), - Ty::Fin { size } => todo!(), - Ty::Generic { id } => todo!(), - Ty::Scope { kind, id } => todo!(), - Ty::Ref { scope, inner } => todo!(), - Ty::Array { index, elem } => todo!(), - Ty::Tuple { members } => todo!(), - }) - .collect(); - - let mut fwd_body = vec![]; - let mut bwd_body = vec![]; - - let mut intermediates = vec![]; - - for instr in f.body.iter() { - let var = instr.var; - let t = f.vars[var.var()]; - match &instr.expr { - Expr::Unit => fwd_body.push(Instr { - var, - expr: Expr::Unit, - }), - &Expr::Bool { val } => fwd_body.push(Instr { - var, - expr: Expr::Bool { val }, - }), - &Expr::F64 { val } => { - if t == linear { - bwd_body.push(Instr { - var, - expr: Expr::F64 { val }, - }) - } else { - fwd_body.push(Instr { - var, - expr: Expr::F64 { val }, - }) + types: f + .types + .iter() + .map(|ty| match ty { + Ty::Unit => Ty::Unit, + Ty::Bool => Ty::Unit, + Ty::F64 => Ty::F64, + &Ty::Fin { size } => Ty::Fin { size }, + &Ty::Generic { id } => Ty::Generic { id }, + &Ty::Scope { kind, id } => Ty::Scope { kind, id }, + &Ty::Ref { scope, inner } => { + if is_primitive(inner) { + panic!() + } + Ty::Ref { scope, inner } } - } - &Expr::Fin { val } => fwd_body.push(Instr { - var, - expr: Expr::Fin { val }, - }), - - Expr::Array { .. } => todo!(), - Expr::Tuple { members } => todo!(), - - Expr::Index { .. } => todo!(), - Expr::Member { tuple, member } => todo!(), - - Expr::Slice { .. } => todo!(), - Expr::Field { .. } => todo!(), - - &Expr::Unary { op, arg } => match op { - Unop::Not => fwd_body.push(Instr { - var, - expr: Expr::Unary { op: Unop::Not, arg }, - }), - - Unop::Neg => { - if f.vars[instr.var.var()] == linear { - bwd_body.push(Instr { - var, - expr: Expr::Unary { op: Unop::Neg, arg }, - }) + &Ty::Array { index, elem } => Ty::Array { index, elem }, + Ty::Tuple { members } => { + if members.iter().any(|&t| is_primitive(t)) { + Ty::F64 } else { - fwd_body.push(Instr { - var, - expr: Expr::Unary { op: Unop::Neg, arg }, - }) + Ty::Tuple { + members: members.clone(), + } } } - Unop::Abs => todo!(), - Unop::Sign => todo!(), - Unop::Sqrt => todo!(), - }, - &Expr::Binary { op, left, right } => match op { - Binop::And => todo!(), - Binop::Or => todo!(), - Binop::Iff => todo!(), - Binop::Xor => todo!(), - - Binop::Neq => todo!(), - Binop::Lt => todo!(), - Binop::Leq => todo!(), - Binop::Eq => todo!(), - Binop::Gt => todo!(), - Binop::Geq => todo!(), - - Binop::Add => todo!(), - Binop::Sub => todo!(), - Binop::Mul => todo!(), - Binop::Div => todo!(), - }, - Expr::Select { .. } => todo!(), - - Expr::Call { .. } => todo!(), - Expr::For { .. } => todo!(), - - Expr::Read { .. } => todo!(), - Expr::Accum { .. } => todo!(), - - Expr::Ask { .. } => todo!(), - Expr::Add { .. } => todo!(), - - Expr::Resolve { .. } => todo!(), - } - if t != linear && f.types[t.ty()] == Ty::F64 { - fwd_body.push(Instr { - var: id::var(fwd_types.len() - 1), - expr: Expr::F64 { val: 1.0 }, - }); - } - } + }) + .collect(), + fwd_vars: f.vars.to_vec(), + fwd_block: vec![], + intermediates_tuple: id::var(0), // TODO + intermediate_members: vec![], + bwd_vars: f.vars.iter().map(|&t| BwdTy::Known(t)).collect(), + real_shape: id::var(0), // TODO + reals: vec![None; f.vars.len()].into(), + duals: vec![None; f.vars.len()].into(), + accums: vec![None; f.vars.len()].into(), // TODO + cotangents: vec![None; f.vars.len()].into(), // TODO + bwd_nonlinear: vec![], + bwd_linear: vec![], + }; + let t_intermediates = tp.block(&f.body); + let bwd_types = tp.types.clone(); - let t_tuple = fwd_types.len(); // or `bwd_types.len()`, shouldn't matter - let members = intermediates.into_boxed_slice(); + let mut fwd_types = tp.types; + let t_bundle = id::ty(fwd_types.len()); fwd_types.push(Ty::Tuple { - members: members.clone(), + members: vec![f.vars[f.ret.var()], t_intermediates].into(), + }); + let mut fwd_vars = tp.fwd_vars; + let fwd_ret = id::var(fwd_vars.len()); + fwd_vars.push(t_bundle); + let mut fwd_body = tp.fwd_block; + fwd_body.push(Instr { + var: fwd_ret, + expr: Expr::Tuple { + members: vec![f.ret, tp.intermediates_tuple].into(), + }, }); - bwd_types.push(Ty::Tuple { members }); ( Func { generics: f.generics.clone(), types: fwd_types.into(), - vars: [].into(), // TODO + vars: fwd_vars.into(), params: f.params.clone(), - ret: id::var(0), // TODO + ret: fwd_ret, body: fwd_body.into(), }, Func { @@ -271,12 +463,12 @@ pub fn transpose(f: &Func, linear: id::Ty) -> (Func, Func) { } after }) - .collect(), + .collect(), // TODO types: bwd_types.into(), vars: [].into(), // TODO params: [].into(), // TODO ret: id::var(0), // TODO - body: bwd_body.into(), + body: [].into(), // TODO }, ) } From dc69d40401c2ef7fab1834cf0c3184743ac95c9b Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Fri, 8 Sep 2023 12:55:48 -0400 Subject: [PATCH 04/11] Separate function-level from block-level state --- crates/transpose/src/lib.rs | 238 ++++++++++++++++++++++++++---------- 1 file changed, 175 insertions(+), 63 deletions(-) diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index 5e86ee1..464f638 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -1,4 +1,5 @@ use rose::{id, Binop, Constraint, Expr, Func, Instr, Ty, Unop}; +use std::mem::{swap, take}; const REAL: id::Ty = id::ty(0); const DUAL: id::Ty = id::ty(1); @@ -7,33 +8,106 @@ fn is_primitive(t: id::Ty) -> bool { t == REAL || t == DUAL } +/// Type ID of a variable while building up the backward pass. enum BwdTy { + /// We already know the type ID of this variable. + /// + /// Usually this means the variable is directly mapped from a variable in the original function. Known(id::Ty), + + /// This variable has the unit type. Unit, - Fin(usize), + + /// This variable is an accumulator. + /// + /// After we process the entire function, we need to add a scope type for every new accumulator + /// variable we've introduced; we don't add these as we go because we don't know ahead of time + /// how many types we'll need for intermediate values, and we want all those intermediate value + /// type IDs to be the same between the forward and backward passes. Accum, + + /// We don't know the type of this variable yet, but will soon. + /// + /// Usually this means the variable is a tuple of intermediate values, and we'll update its type + /// to something concrete when we reach the end of the block. + Unknown, } +/// Linear variables in the backward pass for a variable from the original function. struct Lin { + /// Accumulator variable. acc: id::Var, + + /// Resolved cotangent variable. cot: id::Var, } +/// A block under construction for both the forward pass and the backward pass. +struct Block { + /// Instructions in this forward pass block, in order. + fwd: Vec, + + /// Variable IDs for intermediate values to be saved at the end of this forward pass block. + intermediate_members: Vec, + + /// Variable ID for the intermediate values tuple in this backward pass block. + intermediates_tuple: id::Var, + + /// Instructions at the beginning of this backward pass block, in order. + bwd_nonlinear: Vec, + + /// Instructions at the end of this backward pass block, in reverse order. + bwd_linear: Vec, +} + +/// The forward pass and backward pass of a transposed function under construction. struct Transpose<'a> { + /// The function being transposed, which is usually a forward-mode derivative. f: &'a Func, + + /// Shared types between the forward and backward passes. + /// + /// This starts out the same length as `f.types`, with the only difference being that all dual + /// number types are replaced with `F64`. Then we add more types only for tuples and arrays of + /// intermediate values that are shared between the two passes. types: Vec, + + /// Types of variables in the forward pass. + /// + /// This starts out as a clone of `f.vars`, but more variables can be added for dealing with + /// intermediate values. fwd_vars: Vec, - fwd_block: Vec, - intermediates_tuple: id::Var, - intermediate_members: Vec, + + /// Types of variables in the backward pass. + /// + /// This starts out with the `Known` type of every variable from `f.vars`, but more variables + /// can be added for intermediate values, accumulators, and cotangents. bwd_vars: Vec, + + /// A variable of type `F64` defined at the beginning of the backward pass. + /// + /// Every accumulator must be initialized using a concrete variable to dictate its topology. For + /// most variables, we keep around the original nonlinear value and use that as the shape, but + /// this doesn't work for raw linear `F64` variables, which might be part of some lower-level + /// mathematical calculation that is not clearly attached to any value at the dual number level + /// or higher. All those values have the same shape, though, so in those cases we just use this + /// one dummy variable as the shape. real_shape: id::Var, + + /// Variables from the original function that are just the real part of a dual number variable. reals: Box<[Option]>, + + /// Variables from the original function that are just the dual part of a dual number variable. duals: Box<[Option]>, + + /// Accumulator variables for variables from the original function. accums: Box<[Option]>, + + /// Cotangent variables for variables from the original function. cotangents: Box<[Option]>, - bwd_nonlinear: Vec, - bwd_linear: Vec, + + /// The current block under construction. + block: Block, } impl Transpose<'_> { @@ -43,15 +117,10 @@ impl Transpose<'_> { t } - fn keep(&mut self, var: id::Var) { - self.bwd_nonlinear.push(Instr { - var, - expr: Expr::Member { - tuple: self.intermediates_tuple, - member: id::member(self.intermediate_members.len()), - }, - }); - self.intermediate_members.push(var); + fn fwd_var(&mut self, t: id::Ty) -> id::Var { + let var = id::var(self.fwd_vars.len()); + self.fwd_vars.push(t); + var } fn bwd_var(&mut self, t: BwdTy) -> id::Var { @@ -60,10 +129,21 @@ impl Transpose<'_> { var } + fn keep(&mut self, var: id::Var) { + self.block.bwd_nonlinear.push(Instr { + var, + expr: Expr::Member { + tuple: self.block.intermediates_tuple, + member: id::member(self.block.intermediate_members.len()), + }, + }); + self.block.intermediate_members.push(var); + } + fn accum(&mut self, shape: id::Var) -> Lin { let acc = self.bwd_var(BwdTy::Accum); let cot = self.bwd_var(BwdTy::Known(self.f.vars[shape.var()])); - self.bwd_nonlinear.push(Instr { + self.block.bwd_nonlinear.push(Instr { var: acc, expr: Expr::Accum { shape }, }); @@ -75,7 +155,7 @@ impl Transpose<'_> { fn calc(&mut self, tan: id::Var) -> Lin { let acc = self.bwd_var(BwdTy::Accum); let cot = self.bwd_var(BwdTy::Known(DUAL)); - self.bwd_nonlinear.push(Instr { + self.block.bwd_nonlinear.push(Instr { var: acc, expr: Expr::Accum { shape: self.real_shape, @@ -87,7 +167,7 @@ impl Transpose<'_> { } fn resolve(&mut self, lin: Lin) { - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: lin.cot, expr: Expr::Resolve { var: lin.acc }, }) @@ -97,41 +177,49 @@ impl Transpose<'_> { for instr in block.iter() { self.instr(instr.var, &instr.expr); } - let vars = std::mem::take(&mut self.intermediate_members); - let types = vars.iter().map(|&x| self.fwd_vars[x.var()]).collect(); - self.fwd_block.push(Instr { - var: self.intermediates_tuple, + let vars = take(&mut self.block.intermediate_members); + let t = self.ty(Ty::Tuple { + members: vars.iter().map(|&x| self.fwd_vars[x.var()]).collect(), + }); + let var = self.fwd_var(t); + self.block.fwd.push(Instr { + var, expr: Expr::Tuple { members: vars.into(), }, }); - self.ty(Ty::Tuple { members: types }) + self.bwd_vars[self.block.intermediates_tuple.var()] = BwdTy::Known(t); + t } fn instr(&mut self, var: id::Var, expr: &Expr) { match expr { - Expr::Unit => self.fwd_block.push(Instr { + Expr::Unit => self.block.fwd.push(Instr { var, expr: Expr::Unit, }), - &Expr::Bool { val } => self.fwd_block.push(Instr { + &Expr::Bool { val } => self.block.fwd.push(Instr { var, expr: Expr::Bool { val }, }), &Expr::F64 { val } => todo!(), &Expr::Fin { val } => { - self.fwd_block.push(Instr { + self.block.fwd.push(Instr { var, expr: Expr::Fin { val }, }); - self.bwd_nonlinear.push(Instr { + self.block.bwd_nonlinear.push(Instr { var, expr: Expr::Fin { val }, }); } Expr::Array { elems } => { - self.fwd_block.push(Instr { + let t = match self.f.types[self.f.vars[var.var()].ty()] { + Ty::Array { index, elem: _ } => index, + _ => panic!(), + }; + self.block.fwd.push(Instr { var, expr: Expr::Array { elems: elems.clone(), @@ -141,21 +229,21 @@ impl Transpose<'_> { let lin = self.accum(var); for (i, &elem) in elems.iter().enumerate() { if let Some(accum) = self.accums[elem.var()] { - let index = self.bwd_var(BwdTy::Fin(elems.len())); + let index = self.bwd_var(BwdTy::Known(t)); let addend = self.bwd_var(BwdTy::Known(self.f.vars[elem.var()])); let unit = self.bwd_var(BwdTy::Unit); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: unit, expr: Expr::Add { accum, addend }, }); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: addend, expr: Expr::Index { array: lin.cot, index, }, }); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: index, expr: Expr::Fin { val: i }, }); @@ -164,7 +252,7 @@ impl Transpose<'_> { self.resolve(lin); } Expr::Tuple { members } => { - self.fwd_block.push(Instr { + self.block.fwd.push(Instr { var, expr: Expr::Tuple { members: members.clone(), @@ -176,11 +264,11 @@ impl Transpose<'_> { if let Some(accum) = self.accums[member.var()] { let addend = self.bwd_var(BwdTy::Known(self.f.vars[member.var()])); let unit = self.bwd_var(BwdTy::Unit); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: unit, expr: Expr::Add { accum, addend }, }); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: addend, expr: Expr::Member { tuple: lin.cot, @@ -193,17 +281,17 @@ impl Transpose<'_> { } &Expr::Index { array, index } => { - self.fwd_block.push(Instr { + self.block.fwd.push(Instr { var, expr: Expr::Index { array, index }, }); - self.bwd_nonlinear.push(Instr { + self.block.bwd_nonlinear.push(Instr { var, expr: Expr::Index { array, index }, }); let acc = self.bwd_var(BwdTy::Accum); self.accums[var.var()] = Some(acc); - self.bwd_nonlinear.push(Instr { + self.block.bwd_nonlinear.push(Instr { var: acc, expr: Expr::Slice { array: self.accums[array.var()].unwrap(), @@ -215,17 +303,17 @@ impl Transpose<'_> { REAL => self.reals[var.var()] = Some(tuple), DUAL => self.duals[var.var()] = Some(tuple), _ => { - self.fwd_block.push(Instr { + self.block.fwd.push(Instr { var, expr: Expr::Member { tuple, member }, }); - self.bwd_nonlinear.push(Instr { + self.block.bwd_nonlinear.push(Instr { var, expr: Expr::Member { tuple, member }, }); let acc = self.bwd_var(BwdTy::Accum); self.accums[var.var()] = Some(acc); - self.bwd_nonlinear.push(Instr { + self.block.bwd_nonlinear.push(Instr { var: acc, expr: Expr::Field { tuple: self.accums[tuple.var()].unwrap(), @@ -245,14 +333,14 @@ impl Transpose<'_> { let lin = self.calc(var); let res = self.bwd_var(BwdTy::Known(DUAL)); let unit = self.bwd_var(BwdTy::Unit); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: unit, expr: Expr::Add { accum: self.accums[arg.var()].unwrap(), addend: res, }, }); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: res, expr: Expr::Unary { op: Unop::Neg, @@ -263,7 +351,7 @@ impl Transpose<'_> { } }, _ => { - self.fwd_block.push(Instr { + self.block.fwd.push(Instr { var, expr: Expr::Unary { op, arg }, }); @@ -287,14 +375,14 @@ impl Transpose<'_> { Binop::Add => { let a = self.bwd_var(BwdTy::Unit); let b = self.bwd_var(BwdTy::Unit); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: a, expr: Expr::Add { accum: self.accums[left.var()].unwrap(), addend: lin.cot, }, }); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: b, expr: Expr::Add { accum: self.accums[right.var()].unwrap(), @@ -306,21 +394,21 @@ impl Transpose<'_> { let res = self.bwd_var(BwdTy::Known(DUAL)); let a = self.bwd_var(BwdTy::Unit); let b = self.bwd_var(BwdTy::Unit); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: a, expr: Expr::Add { accum: self.accums[left.var()].unwrap(), addend: lin.cot, }, }); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: b, expr: Expr::Add { accum: self.accums[right.var()].unwrap(), addend: res, }, }); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: res, expr: Expr::Unary { op: Unop::Neg, @@ -331,14 +419,14 @@ impl Transpose<'_> { Binop::Mul | Binop::Div => { let res = self.bwd_var(BwdTy::Known(DUAL)); let unit = self.bwd_var(BwdTy::Unit); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: unit, expr: Expr::Add { accum: self.accums[left.var()].unwrap(), addend: res, }, }); - self.bwd_linear.push(Instr { + self.block.bwd_linear.push(Instr { var: res, expr: Expr::Binary { op, @@ -351,7 +439,7 @@ impl Transpose<'_> { self.resolve(lin); } _ => { - self.fwd_block.push(Instr { + self.block.fwd.push(Instr { var, expr: Expr::Binary { op, left, right }, }); @@ -361,7 +449,18 @@ impl Transpose<'_> { &Expr::Select { cond, then, els } => todo!(), Expr::Call { id, generics, args } => todo!(), - Expr::For { arg, body, ret } => todo!(), + Expr::For { arg, body, ret } => { + let mut block = Block { + fwd: vec![], + intermediate_members: vec![], + intermediates_tuple: self.bwd_var(BwdTy::Unknown), + bwd_nonlinear: vec![], + bwd_linear: vec![], + }; + swap(&mut self.block, &mut block); + self.block(body); + swap(&mut self.block, &mut block); + } &Expr::Read { var } => todo!(), &Expr::Accum { shape } => todo!(), @@ -381,10 +480,16 @@ pub fn transpose(f: &Func) -> (Func, Func) { types: f .types .iter() - .map(|ty| match ty { + .enumerate() + .map(|(i, ty)| match ty { Ty::Unit => Ty::Unit, Ty::Bool => Ty::Unit, - Ty::F64 => Ty::F64, + Ty::F64 => { + if is_primitive(id::ty(i)) { + panic!() + } + Ty::F64 + } &Ty::Fin { size } => Ty::Fin { size }, &Ty::Generic { id } => Ty::Generic { id }, &Ty::Scope { kind, id } => Ty::Scope { kind, id }, @@ -394,7 +499,12 @@ pub fn transpose(f: &Func) -> (Func, Func) { } Ty::Ref { scope, inner } } - &Ty::Array { index, elem } => Ty::Array { index, elem }, + &Ty::Array { index, elem } => { + if is_primitive(elem) { + panic!() + } + Ty::Array { index, elem } + } Ty::Tuple { members } => { if members.iter().any(|&t| is_primitive(t)) { Ty::F64 @@ -407,17 +517,19 @@ pub fn transpose(f: &Func) -> (Func, Func) { }) .collect(), fwd_vars: f.vars.to_vec(), - fwd_block: vec![], - intermediates_tuple: id::var(0), // TODO - intermediate_members: vec![], bwd_vars: f.vars.iter().map(|&t| BwdTy::Known(t)).collect(), real_shape: id::var(0), // TODO reals: vec![None; f.vars.len()].into(), duals: vec![None; f.vars.len()].into(), accums: vec![None; f.vars.len()].into(), // TODO cotangents: vec![None; f.vars.len()].into(), // TODO - bwd_nonlinear: vec![], - bwd_linear: vec![], + block: Block { + fwd: vec![], + intermediates_tuple: id::var(0), // TODO + intermediate_members: vec![], + bwd_nonlinear: vec![], + bwd_linear: vec![], + }, }; let t_intermediates = tp.block(&f.body); let bwd_types = tp.types.clone(); @@ -430,11 +542,11 @@ pub fn transpose(f: &Func) -> (Func, Func) { let mut fwd_vars = tp.fwd_vars; let fwd_ret = id::var(fwd_vars.len()); fwd_vars.push(t_bundle); - let mut fwd_body = tp.fwd_block; + let mut fwd_body = tp.block.fwd; fwd_body.push(Instr { var: fwd_ret, expr: Expr::Tuple { - members: vec![f.ret, tp.intermediates_tuple].into(), + members: vec![f.ret, tp.block.intermediates_tuple].into(), }, }); From fe6394acac1bd195d44ec1ddf4cc83daf99240ee Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Fri, 8 Sep 2023 17:06:05 -0400 Subject: [PATCH 05/11] Try pretty-printing transpose to debug --- crates/transpose/src/lib.rs | 108 +++++++++++++++++++++++++-------- crates/web/src/lib.rs | 33 ++++++++++ packages/core/src/impl.test.ts | 8 +++ 3 files changed, 125 insertions(+), 24 deletions(-) diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index 464f638..60ec792 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -24,7 +24,7 @@ enum BwdTy { /// variable we've introduced; we don't add these as we go because we don't know ahead of time /// how many types we'll need for intermediate values, and we want all those intermediate value /// type IDs to be the same between the forward and backward passes. - Accum, + Accum(Option, id::Ty), /// We don't know the type of this variable yet, but will soon. /// @@ -110,7 +110,7 @@ struct Transpose<'a> { block: Block, } -impl Transpose<'_> { +impl<'a> Transpose<'a> { fn ty(&mut self, ty: Ty) -> id::Ty { let t = id::ty(self.types.len()); self.types.push(ty); @@ -140,9 +140,10 @@ impl Transpose<'_> { self.block.intermediate_members.push(var); } - fn accum(&mut self, shape: id::Var) -> Lin { - let acc = self.bwd_var(BwdTy::Accum); - let cot = self.bwd_var(BwdTy::Known(self.f.vars[shape.var()])); + fn accum(&mut self, shape: id::Var, scope: Option) -> Lin { + let t = self.f.vars[shape.var()]; + let acc = self.bwd_var(BwdTy::Accum(scope, t)); + let cot = self.bwd_var(BwdTy::Known(t)); self.block.bwd_nonlinear.push(Instr { var: acc, expr: Expr::Accum { shape }, @@ -153,8 +154,9 @@ impl Transpose<'_> { } fn calc(&mut self, tan: id::Var) -> Lin { - let acc = self.bwd_var(BwdTy::Accum); - let cot = self.bwd_var(BwdTy::Known(DUAL)); + let t = self.f.vars[tan.var()]; + let acc = self.bwd_var(BwdTy::Accum(None, t)); + let cot = self.bwd_var(BwdTy::Known(t)); self.block.bwd_nonlinear.push(Instr { var: acc, expr: Expr::Accum { @@ -202,7 +204,16 @@ impl Transpose<'_> { var, expr: Expr::Bool { val }, }), - &Expr::F64 { val } => todo!(), + &Expr::F64 { val } => match self.f.vars[var.var()] { + DUAL => { + let lin = self.calc(var); + self.resolve(lin); + } + _ => self.block.fwd.push(Instr { + var, + expr: Expr::F64 { val }, + }), + }, &Expr::Fin { val } => { self.block.fwd.push(Instr { var, @@ -226,7 +237,7 @@ impl Transpose<'_> { }, }); self.keep(var); - let lin = self.accum(var); + let lin = self.accum(var, None); for (i, &elem) in elems.iter().enumerate() { if let Some(accum) = self.accums[elem.var()] { let index = self.bwd_var(BwdTy::Known(t)); @@ -259,7 +270,7 @@ impl Transpose<'_> { }, }); self.keep(var); - let lin = self.accum(var); + let lin = self.accum(var, None); for (i, &member) in members.iter().enumerate() { if let Some(accum) = self.accums[member.var()] { let addend = self.bwd_var(BwdTy::Known(self.f.vars[member.var()])); @@ -289,19 +300,24 @@ impl Transpose<'_> { var, expr: Expr::Index { array, index }, }); - let acc = self.bwd_var(BwdTy::Accum); + let arr_acc = self.accums[array.var()].unwrap(); + let acc = self.bwd_var(BwdTy::Accum(Some(arr_acc), self.f.vars[array.var()])); self.accums[var.var()] = Some(acc); self.block.bwd_nonlinear.push(Instr { var: acc, expr: Expr::Slice { - array: self.accums[array.var()].unwrap(), + array: arr_acc, index, }, }); } &Expr::Member { tuple, member } => match self.f.vars[var.var()] { REAL => self.reals[var.var()] = Some(tuple), - DUAL => self.duals[var.var()] = Some(tuple), + DUAL => { + let lin = self.accum(var, None); // TODO + self.duals[var.var()] = Some(tuple); + self.resolve(lin); + } _ => { self.block.fwd.push(Instr { var, @@ -311,12 +327,13 @@ impl Transpose<'_> { var, expr: Expr::Member { tuple, member }, }); - let acc = self.bwd_var(BwdTy::Accum); + let tup_acc = self.accums[tuple.var()].unwrap(); + let acc = self.bwd_var(BwdTy::Accum(Some(tup_acc), self.f.vars[tuple.var()])); self.accums[var.var()] = Some(acc); self.block.bwd_nonlinear.push(Instr { var: acc, expr: Expr::Field { - tuple: self.accums[tuple.var()].unwrap(), + tuple: tup_acc, member, }, }); @@ -475,6 +492,11 @@ impl Transpose<'_> { /// Return two functions that make up the transpose of `f`. pub fn transpose(f: &Func) -> (Func, Func) { + let mut bwd_vars: Vec<_> = f.vars.iter().map(|&t| BwdTy::Known(t)).collect(); + let real_shape = id::var(bwd_vars.len()); + bwd_vars.push(BwdTy::Known(DUAL)); + let intermediates_tuple = id::var(bwd_vars.len()); + bwd_vars.push(BwdTy::Unknown); let mut tp = Transpose { f, types: f @@ -485,7 +507,7 @@ pub fn transpose(f: &Func) -> (Func, Func) { Ty::Unit => Ty::Unit, Ty::Bool => Ty::Unit, Ty::F64 => { - if is_primitive(id::ty(i)) { + if !is_primitive(id::ty(i)) { panic!() } Ty::F64 @@ -517,22 +539,23 @@ pub fn transpose(f: &Func) -> (Func, Func) { }) .collect(), fwd_vars: f.vars.to_vec(), - bwd_vars: f.vars.iter().map(|&t| BwdTy::Known(t)).collect(), - real_shape: id::var(0), // TODO + bwd_vars, + real_shape, reals: vec![None; f.vars.len()].into(), duals: vec![None; f.vars.len()].into(), accums: vec![None; f.vars.len()].into(), // TODO cotangents: vec![None; f.vars.len()].into(), // TODO block: Block { fwd: vec![], - intermediates_tuple: id::var(0), // TODO + intermediates_tuple, intermediate_members: vec![], bwd_nonlinear: vec![], bwd_linear: vec![], }, }; + let t_intermediates = tp.block(&f.body); - let bwd_types = tp.types.clone(); + let mut bwd_types = tp.types.clone(); let mut fwd_types = tp.types; let t_bundle = id::ty(fwd_types.len()); @@ -550,6 +573,43 @@ pub fn transpose(f: &Func) -> (Func, Func) { }, }); + let t_unit = id::ty(bwd_types.len()); + bwd_types.push(Ty::Unit); + let mut bwd_vars: Vec<_> = tp + .bwd_vars + .into_iter() + .enumerate() + .map(|(i, t)| match t { + BwdTy::Known(t) => t, + BwdTy::Unit => t_unit, + BwdTy::Accum(scope, inner) => { + let scope = id::ty(bwd_types.len()); + bwd_types.push(Ty::Scope { + kind: Constraint::Accum, + id: id::var(i), + }); + let t = id::ty(bwd_types.len()); + bwd_types.push(Ty::Ref { scope, inner }); + t + } + BwdTy::Unknown => panic!(), + }) + .collect(); + let bwd_ret = id::var(bwd_vars.len()); + bwd_vars.push(t_unit); + let mut bwd_body = vec![Instr { + var: tp.real_shape, + expr: Expr::F64 { val: 0. }, + }]; + bwd_body.append(&mut tp.block.bwd_nonlinear); + let mut bwd_linear = tp.block.bwd_linear; + bwd_linear.reverse(); + bwd_body.append(&mut bwd_linear); + bwd_body.push(Instr { + var: bwd_ret, + expr: Expr::Unit, + }); + ( Func { generics: f.generics.clone(), @@ -577,10 +637,10 @@ pub fn transpose(f: &Func) -> (Func, Func) { }) .collect(), // TODO types: bwd_types.into(), - vars: [].into(), // TODO - params: [].into(), // TODO - ret: id::var(0), // TODO - body: [].into(), // TODO + vars: bwd_vars.into(), + params: f.params.clone(), // TODO + ret: bwd_ret, + body: bwd_body.into(), }, ) } diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index ed4e185..5647b82 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -293,6 +293,39 @@ impl Func { *cache = Some(Rc::downgrade(&rc)); Self { rc } } + + #[cfg(feature = "debug")] + pub fn transpose(&self) -> Result { + match &self.rc.as_ref().inner { + Inner::Transparent { def, .. } => { + let (fwd, bwd) = rose_transpose::transpose(def); + let fwd = Func { + rc: Rc::new(Pointee { + inner: Inner::Transparent { + deps: [].into(), + def: fwd, + }, + structs: [].into(), + jvp: RefCell::new(None), + }), + }; + let bwd = Func { + rc: Rc::new(Pointee { + inner: Inner::Transparent { + deps: [].into(), + def: bwd, + }, + structs: [].into(), + jvp: RefCell::new(None), + }), + }; + let fwd_str = pprint(&fwd)?; + let bwd_str = pprint(&bwd)?; + Ok(format!("{fwd_str}\n\n{bwd_str}")) + } + Inner::Opaque { .. } => todo!(), + } + } } #[cfg(feature = "debug")] diff --git a/packages/core/src/impl.test.ts b/packages/core/src/impl.test.ts index 0aa4727..04b1be5 100644 --- a/packages/core/src/impl.test.ts +++ b/packages/core/src/impl.test.ts @@ -15,6 +15,7 @@ import { gt, iff, inner, + jvp, leq, lt, mul, @@ -188,4 +189,11 @@ T2 = [T1; T0] `.trimStart(), ); }); + + test("transpose", () => { + const f = fn([Real, Real], Real, (x, y) => mul(x, y)); + const g = jvp(f); + console.log(pprint(g)); + console.log(g[inner].transpose()); + }); }); From 2ecf91b226bd9e3d55aa828e103cd901ac9d78aa Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Fri, 8 Sep 2023 20:44:41 -0400 Subject: [PATCH 06/11] Handle VJP on the JS side --- crates/web/src/lib.rs | 139 +++++++++++++++++++++++++++----- packages/core/src/impl.test.ts | 8 -- packages/core/src/impl.ts | 45 +++++++++++ packages/core/src/index.test.ts | 6 ++ packages/core/src/index.ts | 1 + 5 files changed, 170 insertions(+), 29 deletions(-) diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index c9ca6d8..8cf7f57 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -124,6 +124,10 @@ struct Pointee { structs: Box<[Option>]>, jvp: RefCell>>, + + fwd: RefCell>>, + + bwd: RefCell>>, } /// A node in a reference-counted acyclic digraph of functions. @@ -149,6 +153,8 @@ impl Func { }, structs: [].into(), jvp: RefCell::new(None), + fwd: RefCell::new(None), + bwd: RefCell::new(None), }), } } @@ -261,6 +267,7 @@ impl Func { inner, structs, jvp, + .. } = self.rc.as_ref(); let mut cache = jvp.borrow_mut(); if let Some(rc) = cache.as_ref().and_then(|weak| weak.upgrade()) { @@ -286,6 +293,8 @@ impl Func { }, structs: structs_jvp.into(), jvp: RefCell::new(None), + fwd: RefCell::new(None), + bwd: RefCell::new(None), }) } Inner::Opaque { .. } => todo!(), @@ -294,40 +303,101 @@ impl Func { Self { rc } } - #[cfg(feature = "debug")] - pub fn transpose(&self) -> Result { - match &self.rc.as_ref().inner { - Inner::Transparent { def, .. } => { - let (fwd, bwd) = rose_transpose::transpose(def); - let fwd = Func { - rc: Rc::new(Pointee { + fn transpose_pair(&self) -> (Self, Self) { + let Pointee { + inner, + structs, + fwd, + bwd, + .. + } = self.rc.as_ref(); + let mut cache_fwd = fwd.borrow_mut(); + let mut cache_bwd = bwd.borrow_mut(); + if let (Some(rc_fwd), Some(rc_bwd)) = ( + cache_fwd.as_ref().and_then(|weak| weak.upgrade()), + cache_bwd.as_ref().and_then(|weak| weak.upgrade()), + ) { + return (Self { rc: rc_fwd }, Self { rc: rc_bwd }); + } + let (rc_fwd, rc_bwd) = match inner { + Inner::Transparent { deps, def } => { + let (deps_fwd, deps_bwd): (Vec<_>, Vec<_>) = + deps.iter().map(|f| f.transpose_pair()).unzip(); + let (def_fwd, def_bwd) = rose_transpose::transpose(def); + let structs_fwd = def_fwd + .types + .iter() + .enumerate() + .map(|(i, ty)| match ty { + rose::Ty::F64 => None, + _ => structs.get(i).cloned().flatten(), + }) + .collect(); + let structs_bwd = def_bwd + .types + .iter() + .enumerate() + .map(|(i, ty)| match ty { + rose::Ty::F64 => None, + _ => structs.get(i).cloned().flatten(), + }) + .collect(); + ( + Rc::new(Pointee { inner: Inner::Transparent { - deps: [].into(), - def: fwd, + deps: deps_fwd.into(), + def: def_fwd, }, - structs: [].into(), + structs: structs_fwd, jvp: RefCell::new(None), + fwd: RefCell::new(None), + bwd: RefCell::new(None), }), - }; - let bwd = Func { - rc: Rc::new(Pointee { + Rc::new(Pointee { inner: Inner::Transparent { - deps: [].into(), - def: bwd, + deps: deps_bwd.into(), + def: def_bwd, }, - structs: [].into(), + structs: structs_bwd, jvp: RefCell::new(None), + fwd: RefCell::new(None), + bwd: RefCell::new(None), }), - }; - let fwd_str = pprint(&fwd)?; - let bwd_str = pprint(&bwd)?; - Ok(format!("{fwd_str}\n\n{bwd_str}")) + ) } - Inner::Opaque { .. } => todo!(), + Inner::Opaque { .. } => panic!(), + }; + *cache_fwd = Some(Rc::downgrade(&rc_fwd)); + *cache_bwd = Some(Rc::downgrade(&rc_bwd)); + (Self { rc: rc_fwd }, Self { rc: rc_bwd }) + } + + pub fn transpose(&self) -> Transpose { + let (fwd, bwd) = self.transpose_pair(); + Transpose { + fwd: Some(fwd), + bwd: Some(bwd), } } } +#[wasm_bindgen] +pub struct Transpose { + fwd: Option, + bwd: Option, +} + +#[wasm_bindgen] +impl Transpose { + pub fn fwd(&mut self) -> Option { + self.fwd.take() + } + + pub fn bwd(&mut self) -> Option { + self.bwd.take() + } +} + #[cfg(feature = "debug")] #[wasm_bindgen] pub fn pprint(f: &Func) -> Result { @@ -531,6 +601,10 @@ enum Ty { Fin { size: usize, }, + Scope { + kind: rose::Constraint, + id: id::Var, + }, Ref { scope: id::Ty, inner: id::Ty, @@ -558,6 +632,7 @@ impl Ty { Ty::Bool => (rose::Ty::Bool, None), Ty::F64 => (rose::Ty::F64, None), Ty::Fin { size } => (rose::Ty::Fin { size }, None), + Ty::Scope { kind, id } => (rose::Ty::Scope { kind, id }, None), Ty::Ref { scope, inner } => (rose::Ty::Ref { scope, inner }, None), Ty::Array { index, elem } => (rose::Ty::Array { index, elem }, None), Ty::Struct { keys, members } => (rose::Ty::Tuple { members }, Some(keys)), @@ -650,6 +725,8 @@ impl FuncBuilder { }, structs: structs.into(), jvp: RefCell::new(None), + fwd: RefCell::new(None), + bwd: RefCell::new(None), }), } } @@ -831,6 +908,13 @@ impl FuncBuilder { ) } + #[wasm_bindgen(js_name = "tyAccum")] + pub fn ty_accum(&mut self, id: usize) -> usize { + let kind = rose::Constraint::Accum; + let id = id::var(id); + self.newtype(Ty::Scope { kind, id }, EnumSet::only(kind)) + } + /// Return the ID for the type of arrays with index type `index` and element type `elem`, /// /// Assumes `index` and `elem` are valid type IDs. @@ -1472,4 +1556,17 @@ impl Block { }; self.instr(f, id::ty(t), expr) } + + pub fn accum(&mut self, f: &mut FuncBuilder, inner: usize, shape: usize) -> usize { + let inner = id::ty(inner); + let shape = id::var(shape); + let scope = id::ty(f.ty_accum(f.vars.len())); + let t = id::ty(f.newtype(Ty::Ref { scope, inner }, EnumSet::empty())); + self.instr(f, t, rose::Expr::Accum { shape }) + } + + pub fn resolve(&mut self, f: &mut FuncBuilder, t: usize, var: usize) -> usize { + let expr = rose::Expr::Resolve { var: id::var(var) }; + self.instr(f, id::ty(t), expr) + } } diff --git a/packages/core/src/impl.test.ts b/packages/core/src/impl.test.ts index 04b1be5..0aa4727 100644 --- a/packages/core/src/impl.test.ts +++ b/packages/core/src/impl.test.ts @@ -15,7 +15,6 @@ import { gt, iff, inner, - jvp, leq, lt, mul, @@ -189,11 +188,4 @@ T2 = [T1; T0] `.trimStart(), ); }); - - test("transpose", () => { - const f = fn([Real, Real], Real, (x, y) => mul(x, y)); - const g = jvp(f); - console.log(pprint(g)); - console.log(g[inner].transpose()); - }); }); diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 2da93cf..10bd300 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -595,6 +595,51 @@ export const jvp = ( return g; }; +export const vjp = ( + f: Fn & ((arg: A) => R), +): ((arg: A) => { ret: R; grad: (cot: R) => A }) => { + const g = jvp(f); + const tp = g[inner].transpose(); + const fwdFunc = tp.fwd()!; + const bwdFunc = tp.bwd()!; + const fwd: Fn = { [inner]: fwdFunc, [strings]: [...f[strings]] }; + const bwd: Fn = { [inner]: bwdFunc, [strings]: [...f[strings]] }; + funcs.register(fwd, fwdFunc); + funcs.register(bwd, bwdFunc); + return (arg: A) => { + const ctx = getCtx(); + const strs = intern(ctx, fwd[strings]); + const generics = new Uint32Array(); // TODO: support generics + const [tArg, tBundle] = ctx.func.ingest(fwd[inner], strs, generics); + const [tRet, tInter] = ctx.func.members(tBundle); + const argId = valId(ctx, tArg, arg); + const bundleId = ctx.block.call( + ctx.func, + fwd[inner], + generics, + tBundle, + new Uint32Array([argId]), + ); + const primalId = ctx.block.member(ctx.func, tRet, bundleId, 0); + const interId = ctx.block.member(ctx.func, tInter, bundleId, 1); + const grad = (cot: R) => { + if (getCtx() !== ctx) throw Error("VJP closure escaped its context"); + const cotId = valId(ctx, tRet, cot); + const accId = ctx.block.accum(ctx.func, tArg, argId); + const tScope = ctx.func.tyAccum(accId); + ctx.block.call( + ctx.func, + bwd[inner], + new Uint32Array([tScope]), // TODO: support generics + ctx.func.tyUnit(), + new Uint32Array([accId, cotId, interId]), + ); + return idVal(ctx, tArg, ctx.block.resolve(ctx.func, tArg, accId)) as A; + }; + return { ret: idVal(ctx, tRet, primalId) as R, grad }; + }; +}; + /** Return the variable ID for the abstract boolean `x`. */ const boolId = (ctx: Context, x: Bool): number => valId(ctx, ctx.func.tyBool(), x); diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index c99fc20..ec4f809 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -17,6 +17,7 @@ import { sqrt, sub, vec, + vjp, } from "./index.js"; describe("invalid", () => { @@ -351,4 +352,9 @@ describe("valid", () => { const g = interp(jvp(f)); expect(g({ re: 2, du: 3 })).toEqual({ re: 2097152, du: 3145728 }); }); + + test("VJP", () => { + const f = fn([Vec(2, Real)], Real, (v) => mul(v[0], v[1])); + vjp(f); + }); }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 602c51b..fad6ed2 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -27,5 +27,6 @@ export { sqrt, sub, vec, + vjp, xor, } from "./impl.js"; From e85470f0ce6342be67e86b5b64e75f267a7d13a8 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Fri, 8 Sep 2023 23:56:27 -0400 Subject: [PATCH 07/11] Track variable sources --- crates/transpose/src/lib.rs | 594 +++++++++++++++++++------------- crates/web/src/lib.rs | 16 +- packages/core/src/index.test.ts | 8 +- 3 files changed, 360 insertions(+), 258 deletions(-) diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index 60ec792..7d0663a 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -33,6 +33,35 @@ enum BwdTy { Unknown, } +/// The source of a primitive variable. +#[derive(Clone, Copy)] +enum Src { + /// This variable is original. + Original, + + /// This variable is an alias of an original primitive variable of the same type. + Alias(id::Var), + + /// This variable is a component of an original dual number variable. + Projection(id::Var), +} + +impl Src { + fn prim(self, x: id::Var) -> Self { + match self { + Self::Original => Self::Alias(x), + _ => self, + } + } + + fn dual(self, x: id::Var) -> Self { + match self { + Self::Original => Self::Projection(x), + _ => self, + } + } +} + /// Linear variables in the backward pass for a variable from the original function. struct Lin { /// Accumulator variable. @@ -48,16 +77,16 @@ struct Block { fwd: Vec, /// Variable IDs for intermediate values to be saved at the end of this forward pass block. - intermediate_members: Vec, + inter_mem: Vec, /// Variable ID for the intermediate values tuple in this backward pass block. - intermediates_tuple: id::Var, + inter_tup: id::Var, /// Instructions at the beginning of this backward pass block, in order. - bwd_nonlinear: Vec, + bwd_nonlin: Vec, /// Instructions at the end of this backward pass block, in reverse order. - bwd_linear: Vec, + bwd_lin: Vec, } /// The forward pass and backward pass of a transposed function under construction. @@ -94,23 +123,49 @@ struct Transpose<'a> { /// one dummy variable as the shape. real_shape: id::Var, - /// Variables from the original function that are just the real part of a dual number variable. - reals: Box<[Option]>, + /// Sources of primitive variables from the original function. + prims: Box<[Option]>, - /// Variables from the original function that are just the dual part of a dual number variable. - duals: Box<[Option]>, + /// Sources for dual number variables from the original function. + duals: Box<[Option<(Src, Src)>]>, /// Accumulator variables for variables from the original function. accums: Box<[Option]>, /// Cotangent variables for variables from the original function. - cotangents: Box<[Option]>, + cotans: Box<[Option]>, /// The current block under construction. block: Block, } impl<'a> Transpose<'a> { + fn re(&self, x: id::Var) -> Src { + let (src, _) = self.duals[x.var()].unwrap(); + src.dual(x) + } + + fn du(&self, x: id::Var) -> Src { + let (_, src) = self.duals[x.var()].unwrap(); + src.dual(x) + } + + fn get_prim_accum(&self, x: id::Var) -> id::Var { + let z = match self.prims[x.var()].unwrap() { + Src::Original => x, + Src::Alias(y) | Src::Projection(y) => y, + }; + self.accums[z.var()].unwrap() + } + + fn get_dual_accum(&self, x: id::Var) -> Option { + let z = match self.duals[x.var()] { + None | Some((_, Src::Original)) => x, + Some((_, Src::Alias(y))) | Some((_, Src::Projection(y))) => y, + }; + self.accums[z.var()] + } + fn ty(&mut self, ty: Ty) -> id::Ty { let t = id::ty(self.types.len()); self.types.push(ty); @@ -130,26 +185,26 @@ impl<'a> Transpose<'a> { } fn keep(&mut self, var: id::Var) { - self.block.bwd_nonlinear.push(Instr { + self.block.bwd_nonlin.push(Instr { var, expr: Expr::Member { - tuple: self.block.intermediates_tuple, - member: id::member(self.block.intermediate_members.len()), + tuple: self.block.inter_tup, + member: id::member(self.block.inter_mem.len()), }, }); - self.block.intermediate_members.push(var); + self.block.inter_mem.push(var); } fn accum(&mut self, shape: id::Var, scope: Option) -> Lin { let t = self.f.vars[shape.var()]; let acc = self.bwd_var(BwdTy::Accum(scope, t)); let cot = self.bwd_var(BwdTy::Known(t)); - self.block.bwd_nonlinear.push(Instr { + self.block.bwd_nonlin.push(Instr { var: acc, expr: Expr::Accum { shape }, }); self.accums[shape.var()] = Some(acc); - self.cotangents[shape.var()] = Some(cot); + self.cotans[shape.var()] = Some(cot); Lin { acc, cot } } @@ -157,19 +212,19 @@ impl<'a> Transpose<'a> { let t = self.f.vars[tan.var()]; let acc = self.bwd_var(BwdTy::Accum(None, t)); let cot = self.bwd_var(BwdTy::Known(t)); - self.block.bwd_nonlinear.push(Instr { + self.block.bwd_nonlin.push(Instr { var: acc, expr: Expr::Accum { shape: self.real_shape, }, }); self.accums[tan.var()] = Some(acc); - self.cotangents[tan.var()] = Some(cot); + self.cotans[tan.var()] = Some(cot); Lin { acc, cot } } fn resolve(&mut self, lin: Lin) { - self.block.bwd_linear.push(Instr { + self.block.bwd_lin.push(Instr { var: lin.cot, expr: Expr::Resolve { var: lin.acc }, }) @@ -179,7 +234,7 @@ impl<'a> Transpose<'a> { for instr in block.iter() { self.instr(instr.var, &instr.expr); } - let vars = take(&mut self.block.intermediate_members); + let vars = take(&mut self.block.inter_mem); let t = self.ty(Ty::Tuple { members: vars.iter().map(|&x| self.fwd_vars[x.var()]).collect(), }); @@ -190,7 +245,7 @@ impl<'a> Transpose<'a> { members: vars.into(), }, }); - self.bwd_vars[self.block.intermediates_tuple.var()] = BwdTy::Known(t); + self.bwd_vars[self.block.inter_tup.var()] = BwdTy::Known(t); t } @@ -204,22 +259,25 @@ impl<'a> Transpose<'a> { var, expr: Expr::Bool { val }, }), - &Expr::F64 { val } => match self.f.vars[var.var()] { - DUAL => { - let lin = self.calc(var); - self.resolve(lin); + &Expr::F64 { val } => { + match self.f.vars[var.var()] { + DUAL => { + let lin = self.calc(var); + self.resolve(lin); + } + _ => self.block.fwd.push(Instr { + var, + expr: Expr::F64 { val }, + }), } - _ => self.block.fwd.push(Instr { - var, - expr: Expr::F64 { val }, - }), - }, + self.prims[var.var()] = Some(Src::Original); + } &Expr::Fin { val } => { self.block.fwd.push(Instr { var, expr: Expr::Fin { val }, }); - self.block.bwd_nonlinear.push(Instr { + self.block.bwd_nonlin.push(Instr { var, expr: Expr::Fin { val }, }); @@ -239,22 +297,22 @@ impl<'a> Transpose<'a> { self.keep(var); let lin = self.accum(var, None); for (i, &elem) in elems.iter().enumerate() { - if let Some(accum) = self.accums[elem.var()] { + if let Some(accum) = self.get_dual_accum(elem) { let index = self.bwd_var(BwdTy::Known(t)); let addend = self.bwd_var(BwdTy::Known(self.f.vars[elem.var()])); let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_linear.push(Instr { + self.block.bwd_lin.push(Instr { var: unit, expr: Expr::Add { accum, addend }, }); - self.block.bwd_linear.push(Instr { + self.block.bwd_lin.push(Instr { var: addend, expr: Expr::Index { array: lin.cot, index, }, }); - self.block.bwd_linear.push(Instr { + self.block.bwd_lin.push(Instr { var: index, expr: Expr::Fin { val: i }, }); @@ -262,220 +320,242 @@ impl<'a> Transpose<'a> { } self.resolve(lin); } - Expr::Tuple { members } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Tuple { - members: members.clone(), - }, - }); - self.keep(var); - let lin = self.accum(var, None); - for (i, &member) in members.iter().enumerate() { - if let Some(accum) = self.accums[member.var()] { - let addend = self.bwd_var(BwdTy::Known(self.f.vars[member.var()])); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_linear.push(Instr { - var: unit, - expr: Expr::Add { accum, addend }, - }); - self.block.bwd_linear.push(Instr { - var: addend, - expr: Expr::Member { - tuple: lin.cot, - member: id::member(i), - }, - }); + Expr::Tuple { members } => match self.types[self.f.vars[var.var()].ty()] { + Ty::F64 => { + let x = members[0]; + let dx = members[1]; + self.duals[var.var()] = Some(( + self.prims[x.var()].unwrap().prim(x), + self.prims[dx.var()].unwrap().prim(dx), + )); + } + _ => { + self.block.fwd.push(Instr { + var, + expr: Expr::Tuple { + members: members.clone(), + }, + }); + self.keep(var); + let lin = self.accum(var, None); + for (i, &member) in members.iter().enumerate() { + if let Some(accum) = self.get_dual_accum(member) { + let addend = self.bwd_var(BwdTy::Known(self.f.vars[member.var()])); + let unit = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { accum, addend }, + }); + self.block.bwd_lin.push(Instr { + var: addend, + expr: Expr::Member { + tuple: lin.cot, + member: id::member(i), + }, + }); + } } + self.resolve(lin); } - self.resolve(lin); - } + }, &Expr::Index { array, index } => { self.block.fwd.push(Instr { var, expr: Expr::Index { array, index }, }); - self.block.bwd_nonlinear.push(Instr { + self.block.bwd_nonlin.push(Instr { var, expr: Expr::Index { array, index }, }); let arr_acc = self.accums[array.var()].unwrap(); let acc = self.bwd_var(BwdTy::Accum(Some(arr_acc), self.f.vars[array.var()])); self.accums[var.var()] = Some(acc); - self.block.bwd_nonlinear.push(Instr { + self.block.bwd_nonlin.push(Instr { var: acc, expr: Expr::Slice { array: arr_acc, index, }, }); - } - &Expr::Member { tuple, member } => match self.f.vars[var.var()] { - REAL => self.reals[var.var()] = Some(tuple), - DUAL => { - let lin = self.accum(var, None); // TODO - self.duals[var.var()] = Some(tuple); - self.resolve(lin); - } - _ => { - self.block.fwd.push(Instr { - var, - expr: Expr::Member { tuple, member }, - }); - self.block.bwd_nonlinear.push(Instr { - var, - expr: Expr::Member { tuple, member }, - }); - let tup_acc = self.accums[tuple.var()].unwrap(); - let acc = self.bwd_var(BwdTy::Accum(Some(tup_acc), self.f.vars[tuple.var()])); - self.accums[var.var()] = Some(acc); - self.block.bwd_nonlinear.push(Instr { - var: acc, - expr: Expr::Field { - tuple: tup_acc, - member, - }, - }); + if let Ty::F64 = self.types[self.f.vars[var.var()].ty()] { + self.duals[var.var()] = Some((Src::Original, Src::Original)); } - }, - - &Expr::Slice { array, index } => todo!(), - &Expr::Field { tuple, member } => todo!(), - - &Expr::Unary { op, arg } => match self.f.vars[var.var()] { - DUAL => match op { - Unop::Not | Unop::Abs | Unop::Sign | Unop::Sqrt => panic!(), - Unop::Neg => { - let lin = self.calc(var); - let res = self.bwd_var(BwdTy::Known(DUAL)); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_linear.push(Instr { - var: unit, - expr: Expr::Add { - accum: self.accums[arg.var()].unwrap(), - addend: res, - }, + } + &Expr::Member { tuple, member } => { + let t = self.f.vars[var.var()]; + match t { + REAL => self.prims[var.var()] = Some(self.re(tuple)), + DUAL => self.prims[var.var()] = Some(self.du(tuple)), + _ => { + self.block.fwd.push(Instr { + var, + expr: Expr::Member { tuple, member }, + }); + self.block.bwd_nonlin.push(Instr { + var, + expr: Expr::Member { tuple, member }, }); - self.block.bwd_linear.push(Instr { - var: res, - expr: Expr::Unary { - op: Unop::Neg, - arg: lin.cot, + let tup_acc = self.accums[tuple.var()].unwrap(); + let acc = + self.bwd_var(BwdTy::Accum(Some(tup_acc), self.f.vars[tuple.var()])); + self.accums[var.var()] = Some(acc); + self.block.bwd_nonlin.push(Instr { + var: acc, + expr: Expr::Field { + tuple: tup_acc, + member, }, }); - self.resolve(lin); + if let Ty::F64 = self.types[t.ty()] { + self.duals[var.var()] = Some((Src::Original, Src::Original)); + } } - }, - _ => { - self.block.fwd.push(Instr { - var, - expr: Expr::Unary { op, arg }, - }); - self.keep(var); } - }, - &Expr::Binary { op, left, right } => match self.f.vars[var.var()] { - DUAL => { - let lin = self.calc(var); - match op { - Binop::And - | Binop::Or - | Binop::Iff - | Binop::Xor - | Binop::Neq - | Binop::Lt - | Binop::Leq - | Binop::Eq - | Binop::Gt - | Binop::Geq => panic!(), - Binop::Add => { - let a = self.bwd_var(BwdTy::Unit); - let b = self.bwd_var(BwdTy::Unit); - self.block.bwd_linear.push(Instr { - var: a, - expr: Expr::Add { - accum: self.accums[left.var()].unwrap(), - addend: lin.cot, - }, - }); - self.block.bwd_linear.push(Instr { - var: b, - expr: Expr::Add { - accum: self.accums[right.var()].unwrap(), - addend: lin.cot, - }, - }); - } - Binop::Sub => { + } + + &Expr::Slice { array, index } => todo!(), + &Expr::Field { tuple, member } => todo!(), + + &Expr::Unary { op, arg } => { + match self.f.vars[var.var()] { + DUAL => match op { + Unop::Not | Unop::Abs | Unop::Sign | Unop::Sqrt => panic!(), + Unop::Neg => { + let lin = self.calc(var); let res = self.bwd_var(BwdTy::Known(DUAL)); - let a = self.bwd_var(BwdTy::Unit); - let b = self.bwd_var(BwdTy::Unit); - self.block.bwd_linear.push(Instr { - var: a, - expr: Expr::Add { - accum: self.accums[left.var()].unwrap(), - addend: lin.cot, - }, - }); - self.block.bwd_linear.push(Instr { - var: b, + let unit = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: unit, expr: Expr::Add { - accum: self.accums[right.var()].unwrap(), + accum: self.get_prim_accum(arg), addend: res, }, }); - self.block.bwd_linear.push(Instr { + self.block.bwd_lin.push(Instr { var: res, expr: Expr::Unary { op: Unop::Neg, arg: lin.cot, }, }); + self.resolve(lin); } - Binop::Mul | Binop::Div => { - let res = self.bwd_var(BwdTy::Known(DUAL)); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_linear.push(Instr { - var: unit, - expr: Expr::Add { - accum: self.accums[left.var()].unwrap(), - addend: res, - }, - }); - self.block.bwd_linear.push(Instr { - var: res, - expr: Expr::Binary { - op, - left: lin.cot, - right, - }, - }); - } + }, + _ => { + self.block.fwd.push(Instr { + var, + expr: Expr::Unary { op, arg }, + }); + self.keep(var); } - self.resolve(lin); } - _ => { - self.block.fwd.push(Instr { - var, - expr: Expr::Binary { op, left, right }, - }); - self.keep(var); + self.prims[var.var()] = Some(Src::Original); + } + &Expr::Binary { op, left, right } => { + match self.f.vars[var.var()] { + DUAL => { + let lin = self.calc(var); + match op { + Binop::And + | Binop::Or + | Binop::Iff + | Binop::Xor + | Binop::Neq + | Binop::Lt + | Binop::Leq + | Binop::Eq + | Binop::Gt + | Binop::Geq => panic!(), + Binop::Add => { + let a = self.bwd_var(BwdTy::Unit); + let b = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: a, + expr: Expr::Add { + accum: self.get_prim_accum(left), + addend: lin.cot, + }, + }); + self.block.bwd_lin.push(Instr { + var: b, + expr: Expr::Add { + accum: self.get_prim_accum(right), + addend: lin.cot, + }, + }); + } + Binop::Sub => { + let res = self.bwd_var(BwdTy::Known(DUAL)); + let a = self.bwd_var(BwdTy::Unit); + let b = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: a, + expr: Expr::Add { + accum: self.get_prim_accum(left), + addend: lin.cot, + }, + }); + self.block.bwd_lin.push(Instr { + var: b, + expr: Expr::Add { + accum: self.get_prim_accum(right), + addend: res, + }, + }); + self.block.bwd_lin.push(Instr { + var: res, + expr: Expr::Unary { + op: Unop::Neg, + arg: lin.cot, + }, + }); + } + Binop::Mul | Binop::Div => { + let res = self.bwd_var(BwdTy::Known(DUAL)); + let unit = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { + accum: self.get_prim_accum(left), + addend: res, + }, + }); + self.block.bwd_lin.push(Instr { + var: res, + expr: Expr::Binary { + op, + left: lin.cot, + right, + }, + }); + } + } + self.resolve(lin); + } + _ => { + self.block.fwd.push(Instr { + var, + expr: Expr::Binary { op, left, right }, + }); + self.keep(var); + } } - }, + self.prims[var.var()] = Some(Src::Original); + } &Expr::Select { cond, then, els } => todo!(), Expr::Call { id, generics, args } => todo!(), Expr::For { arg, body, ret } => { let mut block = Block { fwd: vec![], - intermediate_members: vec![], - intermediates_tuple: self.bwd_var(BwdTy::Unknown), - bwd_nonlinear: vec![], - bwd_linear: vec![], + inter_mem: vec![], + inter_tup: self.bwd_var(BwdTy::Unknown), + bwd_nonlin: vec![], + bwd_lin: vec![], }; swap(&mut self.block, &mut block); - self.block(body); + self.block(body); // TODO swap(&mut self.block, &mut block); } @@ -492,65 +572,81 @@ impl<'a> Transpose<'a> { /// Return two functions that make up the transpose of `f`. pub fn transpose(f: &Func) -> (Func, Func) { + let types: Vec<_> = f + .types + .iter() + .enumerate() + .map(|(i, ty)| match ty { + Ty::Unit => Ty::Unit, + Ty::Bool => Ty::Unit, + Ty::F64 => { + if !is_primitive(id::ty(i)) { + panic!() + } + Ty::F64 + } + &Ty::Fin { size } => Ty::Fin { size }, + &Ty::Generic { id } => Ty::Generic { id }, + &Ty::Scope { kind, id } => Ty::Scope { kind, id }, + &Ty::Ref { scope, inner } => { + if is_primitive(inner) { + panic!() + } + Ty::Ref { scope, inner } + } + &Ty::Array { index, elem } => { + if is_primitive(elem) { + panic!() + } + Ty::Array { index, elem } + } + Ty::Tuple { members } => { + if members.iter().any(|&t| is_primitive(t)) { + Ty::F64 + } else { + Ty::Tuple { + members: members.clone(), + } + } + } + }) + .collect(); + let mut bwd_vars: Vec<_> = f.vars.iter().map(|&t| BwdTy::Known(t)).collect(); let real_shape = id::var(bwd_vars.len()); bwd_vars.push(BwdTy::Known(DUAL)); let intermediates_tuple = id::var(bwd_vars.len()); bwd_vars.push(BwdTy::Unknown); + + let mut duals = vec![None; f.vars.len()].into_boxed_slice(); + let mut accums = vec![None; f.vars.len()].into_boxed_slice(); + + for param in f.params.iter() { + let t = f.vars[param.var()]; + if let Ty::F64 = types[t.ty()] { + duals[param.var()] = Some((Src::Original, Src::Original)); + } + let acc = id::var(bwd_vars.len()); + bwd_vars.push(BwdTy::Accum(None, t)); + accums[param.var()] = Some(acc); + } + let mut tp = Transpose { f, - types: f - .types - .iter() - .enumerate() - .map(|(i, ty)| match ty { - Ty::Unit => Ty::Unit, - Ty::Bool => Ty::Unit, - Ty::F64 => { - if !is_primitive(id::ty(i)) { - panic!() - } - Ty::F64 - } - &Ty::Fin { size } => Ty::Fin { size }, - &Ty::Generic { id } => Ty::Generic { id }, - &Ty::Scope { kind, id } => Ty::Scope { kind, id }, - &Ty::Ref { scope, inner } => { - if is_primitive(inner) { - panic!() - } - Ty::Ref { scope, inner } - } - &Ty::Array { index, elem } => { - if is_primitive(elem) { - panic!() - } - Ty::Array { index, elem } - } - Ty::Tuple { members } => { - if members.iter().any(|&t| is_primitive(t)) { - Ty::F64 - } else { - Ty::Tuple { - members: members.clone(), - } - } - } - }) - .collect(), + types, fwd_vars: f.vars.to_vec(), bwd_vars, real_shape, - reals: vec![None; f.vars.len()].into(), - duals: vec![None; f.vars.len()].into(), - accums: vec![None; f.vars.len()].into(), // TODO - cotangents: vec![None; f.vars.len()].into(), // TODO + prims: vec![None; f.vars.len()].into(), + duals, + accums, + cotans: vec![None; f.vars.len()].into(), block: Block { fwd: vec![], - intermediates_tuple, - intermediate_members: vec![], - bwd_nonlinear: vec![], - bwd_linear: vec![], + inter_tup: intermediates_tuple, + inter_mem: vec![], + bwd_nonlin: vec![], + bwd_lin: vec![], }, }; @@ -569,7 +665,7 @@ pub fn transpose(f: &Func) -> (Func, Func) { fwd_body.push(Instr { var: fwd_ret, expr: Expr::Tuple { - members: vec![f.ret, tp.block.intermediates_tuple].into(), + members: vec![f.ret, tp.block.inter_tup].into(), }, }); @@ -601,8 +697,8 @@ pub fn transpose(f: &Func) -> (Func, Func) { var: tp.real_shape, expr: Expr::F64 { val: 0. }, }]; - bwd_body.append(&mut tp.block.bwd_nonlinear); - let mut bwd_linear = tp.block.bwd_linear; + bwd_body.append(&mut tp.block.bwd_nonlin); + let mut bwd_linear = tp.block.bwd_lin; bwd_linear.reverse(); bwd_body.append(&mut bwd_linear); bwd_body.push(Instr { diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 8cf7f57..bacdb01 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -617,7 +617,7 @@ enum Ty { /// A tuple type, with additional information about key names that makes it into a struct. Struct { /// String IDs for key names, in order; the actual strings are stored in JavaScript. - keys: Box<[usize]>, + keys: Option>, /// Member types of the underlying tuple. Must be the same length as `keys`. members: Box<[id::Ty]>, @@ -635,7 +635,7 @@ impl Ty { Ty::Scope { kind, id } => (rose::Ty::Scope { kind, id }, None), Ty::Ref { scope, inner } => (rose::Ty::Ref { scope, inner }, None), Ty::Array { index, elem } => (rose::Ty::Array { index, elem }, None), - Ty::Struct { keys, members } => (rose::Ty::Tuple { members }, Some(keys)), + Ty::Struct { keys, members } => (rose::Ty::Tuple { members }, keys), } } } @@ -830,7 +830,10 @@ impl FuncBuilder { /// `Err` if `t` is out of range or does not represent a struct type. pub fn keys(&self, t: usize) -> Result, JsError> { match self.ty(t)? { - Ty::Struct { keys, members: _ } => Ok(keys.clone()), + Ty::Struct { + keys: Some(keys), + members: _, + } => Ok(keys.clone()), _ => Err(JsError::new("type is not a struct")), } } @@ -942,7 +945,7 @@ impl FuncBuilder { pub fn ty_struct(&mut self, keys: &[usize], mems: &[usize]) -> usize { self.newtype( Ty::Struct { - keys: keys.into(), + keys: Some(keys.into()), members: mems.iter().map(|&t| id::ty(t)).collect(), }, EnumSet::only(rose::Constraint::Value), @@ -1127,10 +1130,7 @@ impl FuncBuilder { Ty::Struct { keys: structs[t] .as_ref() - .unwrap() - .iter() - .map(|&s| strings[s]) - .collect(), + .map(|ss| ss.iter().map(|&s| strings[s]).collect()), members: members .iter() .map(|x| types[x.ty()]) diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index ec4f809..1516dad 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -355,6 +355,12 @@ describe("valid", () => { test("VJP", () => { const f = fn([Vec(2, Real)], Real, (v) => mul(v[0], v[1])); - vjp(f); + const g = vjp(f); + const h = fn([], Vec(3, Real), () => { + const { ret: x, grad } = g([2, 3]); + const v = grad(1); + return [x, v[0], v[1]]; + }); + expect(interp(h)()).toBe([6, 3, 2]); }); }); From 7af3ecd260f20fac6b58205b9ee5b735d119cbba Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Sat, 9 Sep 2023 01:23:27 -0400 Subject: [PATCH 08/11] Fix more bugs --- Cargo.lock | 1 + crates/interp/src/lib.rs | 4 +- crates/transpose/Cargo.toml | 1 + crates/transpose/src/lib.rs | 193 ++++++++++++++++++++++++------------ 4 files changed, 131 insertions(+), 68 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d452993..9f71a70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -191,6 +191,7 @@ dependencies = [ name = "rose-transpose" version = "0.0.0" dependencies = [ + "enumset", "rose", ] diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 8f0a7d1..4f3fd5b 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -193,11 +193,11 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { }, &Expr::Slice { array, index } => match (self.get(array).inner(), self.get(index)) { - (Val::Array(v), &Val::Fin(i)) => v[i].clone(), + (Val::Array(v), &Val::Fin(i)) => Val::Ref(Rc::new(v[i].clone())), _ => unreachable!(), }, &Expr::Field { tuple, member } => match self.get(tuple).inner() { - Val::Tuple(x) => x[member.member()].clone(), + Val::Tuple(x) => Val::Ref(Rc::new(x[member.member()].clone())), _ => unreachable!(), }, diff --git a/crates/transpose/Cargo.toml b/crates/transpose/Cargo.toml index e35fdcc..c122e43 100644 --- a/crates/transpose/Cargo.toml +++ b/crates/transpose/Cargo.toml @@ -4,4 +4,5 @@ version = "0.0.0" edition = "2021" [dependencies] +enumset = "1" rose = { path = "../core" } diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index 7d0663a..6bc14cf 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -1,3 +1,4 @@ +use enumset::EnumSet; use rose::{id, Binop, Constraint, Expr, Func, Instr, Ty, Unop}; use std::mem::{swap, take}; @@ -8,6 +9,12 @@ fn is_primitive(t: id::Ty) -> bool { t == REAL || t == DUAL } +enum Scope { + Generic(id::Generic), + Derived(id::Var), + Original, +} + /// Type ID of a variable while building up the backward pass. enum BwdTy { /// We already know the type ID of this variable. @@ -24,7 +31,7 @@ enum BwdTy { /// variable we've introduced; we don't add these as we go because we don't know ahead of time /// how many types we'll need for intermediate values, and we want all those intermediate value /// type IDs to be the same between the forward and backward passes. - Accum(Option, id::Ty), + Accum(Scope, id::Ty), /// We don't know the type of this variable yet, but will soon. /// @@ -150,20 +157,33 @@ impl<'a> Transpose<'a> { src.dual(x) } - fn get_prim_accum(&self, x: id::Var) -> id::Var { - let z = match self.prims[x.var()].unwrap() { + fn get_prim(&self, x: id::Var) -> id::Var { + match self.prims[x.var()].unwrap() { Src::Original => x, Src::Alias(y) | Src::Projection(y) => y, - }; - self.accums[z.var()].unwrap() + } } - fn get_dual_accum(&self, x: id::Var) -> Option { - let z = match self.duals[x.var()] { + fn get_re(&self, x: id::Var) -> id::Var { + match self.duals[x.var()] { + None | Some((Src::Original, _)) => x, + Some((Src::Alias(y), _)) | Some((Src::Projection(y), _)) => y, + } + } + + fn get_du(&self, x: id::Var) -> id::Var { + match self.duals[x.var()] { None | Some((_, Src::Original)) => x, Some((_, Src::Alias(y))) | Some((_, Src::Projection(y))) => y, - }; - self.accums[z.var()] + } + } + + fn get_prim_accum(&self, x: id::Var) -> id::Var { + self.accums[self.get_prim(x).var()].unwrap() + } + + fn get_dual_accum(&self, x: id::Var) -> Option { + self.accums[self.get_du(x).var()] } fn ty(&mut self, ty: Ty) -> id::Ty { @@ -195,7 +215,7 @@ impl<'a> Transpose<'a> { self.block.inter_mem.push(var); } - fn accum(&mut self, shape: id::Var, scope: Option) -> Lin { + fn accum(&mut self, shape: id::Var, scope: Scope) -> Lin { let t = self.f.vars[shape.var()]; let acc = self.bwd_var(BwdTy::Accum(scope, t)); let cot = self.bwd_var(BwdTy::Known(t)); @@ -210,7 +230,7 @@ impl<'a> Transpose<'a> { fn calc(&mut self, tan: id::Var) -> Lin { let t = self.f.vars[tan.var()]; - let acc = self.bwd_var(BwdTy::Accum(None, t)); + let acc = self.bwd_var(BwdTy::Accum(Scope::Original, t)); let cot = self.bwd_var(BwdTy::Known(t)); self.block.bwd_nonlin.push(Instr { var: acc, @@ -230,7 +250,7 @@ impl<'a> Transpose<'a> { }) } - fn block(&mut self, block: &[Instr]) -> id::Ty { + fn block(&mut self, block: &[Instr]) -> (id::Ty, id::Var) { for instr in block.iter() { self.instr(instr.var, &instr.expr); } @@ -246,7 +266,7 @@ impl<'a> Transpose<'a> { }, }); self.bwd_vars[self.block.inter_tup.var()] = BwdTy::Known(t); - t + (t, var) } fn instr(&mut self, var: id::Var, expr: &Expr) { @@ -291,11 +311,11 @@ impl<'a> Transpose<'a> { self.block.fwd.push(Instr { var, expr: Expr::Array { - elems: elems.clone(), + elems: elems.iter().map(|&elem| self.get_re(elem)).collect(), }, }); self.keep(var); - let lin = self.accum(var, None); + let lin = self.accum(var, Scope::Original); for (i, &elem) in elems.iter().enumerate() { if let Some(accum) = self.get_dual_accum(elem) { let index = self.bwd_var(BwdTy::Known(t)); @@ -322,8 +342,8 @@ impl<'a> Transpose<'a> { } Expr::Tuple { members } => match self.types[self.f.vars[var.var()].ty()] { Ty::F64 => { - let x = members[0]; - let dx = members[1]; + let x = members[1]; + let dx = members[0]; self.duals[var.var()] = Some(( self.prims[x.var()].unwrap().prim(x), self.prims[dx.var()].unwrap().prim(dx), @@ -333,11 +353,11 @@ impl<'a> Transpose<'a> { self.block.fwd.push(Instr { var, expr: Expr::Tuple { - members: members.clone(), + members: members.iter().map(|&member| self.get_re(member)).collect(), }, }); self.keep(var); - let lin = self.accum(var, None); + let lin = self.accum(var, Scope::Original); for (i, &member) in members.iter().enumerate() { if let Some(accum) = self.get_dual_accum(member) { let addend = self.bwd_var(BwdTy::Known(self.f.vars[member.var()])); @@ -369,7 +389,10 @@ impl<'a> Transpose<'a> { expr: Expr::Index { array, index }, }); let arr_acc = self.accums[array.var()].unwrap(); - let acc = self.bwd_var(BwdTy::Accum(Some(arr_acc), self.f.vars[array.var()])); + let acc = self.bwd_var(BwdTy::Accum( + Scope::Derived(arr_acc), + self.f.vars[array.var()], + )); self.accums[var.var()] = Some(acc); self.block.bwd_nonlin.push(Instr { var: acc, @@ -397,8 +420,10 @@ impl<'a> Transpose<'a> { expr: Expr::Member { tuple, member }, }); let tup_acc = self.accums[tuple.var()].unwrap(); - let acc = - self.bwd_var(BwdTy::Accum(Some(tup_acc), self.f.vars[tuple.var()])); + let acc = self.bwd_var(BwdTy::Accum( + Scope::Derived(tup_acc), + self.f.vars[tuple.var()], + )); self.accums[var.var()] = Some(acc); self.block.bwd_nonlin.push(Instr { var: acc, @@ -445,7 +470,10 @@ impl<'a> Transpose<'a> { _ => { self.block.fwd.push(Instr { var, - expr: Expr::Unary { op, arg }, + expr: Expr::Unary { + op, + arg: self.get_prim(arg), + }, }); self.keep(var); } @@ -526,7 +554,7 @@ impl<'a> Transpose<'a> { expr: Expr::Binary { op, left: lin.cot, - right, + right: self.get_prim(right), }, }); } @@ -536,7 +564,11 @@ impl<'a> Transpose<'a> { _ => { self.block.fwd.push(Instr { var, - expr: Expr::Binary { op, left, right }, + expr: Expr::Binary { + op, + left: self.get_prim(left), + right: self.get_prim(right), + }, }); self.keep(var); } @@ -612,24 +644,59 @@ pub fn transpose(f: &Func) -> (Func, Func) { }) .collect(); + let mut bwd_generics: Vec<_> = f + .generics + .iter() + .map(|&before| { + let mut after = before; + if before.contains(Constraint::Read) { + after.remove(Constraint::Read); + after.insert(Constraint::Accum); + } + if before.contains(Constraint::Accum) { + after.remove(Constraint::Accum); + after.insert(Constraint::Read); + } + after + }) + .collect(); + let mut bwd_vars: Vec<_> = f.vars.iter().map(|&t| BwdTy::Known(t)).collect(); let real_shape = id::var(bwd_vars.len()); bwd_vars.push(BwdTy::Known(DUAL)); - let intermediates_tuple = id::var(bwd_vars.len()); + let inter_tup = id::var(bwd_vars.len()); bwd_vars.push(BwdTy::Unknown); let mut duals = vec![None; f.vars.len()].into_boxed_slice(); let mut accums = vec![None; f.vars.len()].into_boxed_slice(); - for param in f.params.iter() { - let t = f.vars[param.var()]; - if let Ty::F64 = types[t.ty()] { - duals[param.var()] = Some((Src::Original, Src::Original)); - } - let acc = id::var(bwd_vars.len()); - bwd_vars.push(BwdTy::Accum(None, t)); - accums[param.var()] = Some(acc); - } + let mut inter_mem = vec![]; + let mut bwd_nonlin = vec![]; + + let mut bwd_params: Vec<_> = f + .params + .iter() + .map(|¶m| { + let g = id::generic(bwd_generics.len()); + bwd_generics.push(EnumSet::only(Constraint::Accum)); + let t = f.vars[param.var()]; + bwd_nonlin.push(Instr { + var: param, + expr: Expr::Member { + tuple: inter_tup, + member: id::member(inter_mem.len()), + }, + }); + inter_mem.push(param); + let acc = id::var(bwd_vars.len()); + bwd_vars.push(BwdTy::Accum(Scope::Generic(g), t)); + if let Ty::F64 = types[t.ty()] { + duals[param.var()] = Some((Src::Original, Src::Original)); + } + accums[param.var()] = Some(acc); + acc + }) + .collect(); let mut tp = Transpose { f, @@ -643,14 +710,16 @@ pub fn transpose(f: &Func) -> (Func, Func) { cotans: vec![None; f.vars.len()].into(), block: Block { fwd: vec![], - inter_tup: intermediates_tuple, - inter_mem: vec![], - bwd_nonlin: vec![], + inter_tup, + inter_mem, + bwd_nonlin, bwd_lin: vec![], }, }; - let t_intermediates = tp.block(&f.body); + let (t_intermediates, fwd_inter) = tp.block(&f.body); + let fwd_ret = tp.get_re(f.ret); + let bwd_acc = tp.get_dual_accum(f.ret).unwrap(); let mut bwd_types = tp.types.clone(); let mut fwd_types = tp.types; @@ -659,13 +728,13 @@ pub fn transpose(f: &Func) -> (Func, Func) { members: vec![f.vars[f.ret.var()], t_intermediates].into(), }); let mut fwd_vars = tp.fwd_vars; - let fwd_ret = id::var(fwd_vars.len()); + let fwd_bundle = id::var(fwd_vars.len()); fwd_vars.push(t_bundle); let mut fwd_body = tp.block.fwd; fwd_body.push(Instr { - var: fwd_ret, + var: fwd_bundle, expr: Expr::Tuple { - members: vec![f.ret, tp.block.inter_tup].into(), + members: vec![fwd_ret, fwd_inter].into(), }, }); @@ -691,20 +760,27 @@ pub fn transpose(f: &Func) -> (Func, Func) { BwdTy::Unknown => panic!(), }) .collect(); - let bwd_ret = id::var(bwd_vars.len()); + let bwd_cot = id::var(bwd_vars.len()); + bwd_vars.push(f.vars[f.ret.var()]); + let bwd_unit = id::var(bwd_vars.len()); bwd_vars.push(t_unit); + bwd_params.push(bwd_cot); + bwd_params.push(tp.block.inter_tup); let mut bwd_body = vec![Instr { var: tp.real_shape, expr: Expr::F64 { val: 0. }, }]; bwd_body.append(&mut tp.block.bwd_nonlin); let mut bwd_linear = tp.block.bwd_lin; + bwd_linear.push(Instr { + var: bwd_unit, + expr: Expr::Add { + accum: bwd_acc, + addend: bwd_cot, + }, + }); bwd_linear.reverse(); bwd_body.append(&mut bwd_linear); - bwd_body.push(Instr { - var: bwd_ret, - expr: Expr::Unit, - }); ( Func { @@ -712,30 +788,15 @@ pub fn transpose(f: &Func) -> (Func, Func) { types: fwd_types.into(), vars: fwd_vars.into(), params: f.params.clone(), - ret: fwd_ret, + ret: fwd_bundle, body: fwd_body.into(), }, Func { - generics: f - .generics - .iter() - .map(|&before| { - let mut after = before; - if before.contains(Constraint::Read) { - after.remove(Constraint::Read); - after.insert(Constraint::Accum); - } - if before.contains(Constraint::Accum) { - after.remove(Constraint::Accum); - after.insert(Constraint::Read); - } - after - }) - .collect(), // TODO + generics: bwd_generics.into(), types: bwd_types.into(), vars: bwd_vars.into(), - params: f.params.clone(), // TODO - ret: bwd_ret, + params: bwd_params.into(), + ret: bwd_unit, body: bwd_body.into(), }, ) From c91d7e57142372b9844ce7045cce27b182022023 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Sat, 9 Sep 2023 01:50:10 -0400 Subject: [PATCH 09/11] Fix the interpreter --- crates/interp/src/lib.rs | 46 +++++++++++++++++++++++---------- crates/web/src/lib.rs | 1 + packages/core/src/impl.test.ts | 1 + packages/core/src/index.test.ts | 2 +- 4 files changed, 36 insertions(+), 14 deletions(-) diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 4f3fd5b..4132195 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -16,7 +16,7 @@ pub enum Val { Bool(bool), F64(Cell), Fin(usize), - Ref(Rc), + Ref(Rc, Option), Array(Vals), // assume all indices are `Fin` Tuple(Vals), } @@ -50,9 +50,35 @@ impl Val { } } + fn fin(&self) -> usize { + match self { + &Val::Fin(i) => i, + _ => unreachable!(), + } + } + + fn get(&self, i: usize) -> &Self { + match self { + Val::Array(x) => &x[i], + Val::Tuple(x) => &x[i], + _ => unreachable!(), + } + } + + fn slice(&self, i: usize) -> Self { + match self { + Val::Ref(x, None) => Val::Ref(Rc::clone(x), Some(i)), + Val::Ref(x, Some(j)) => Val::Ref(Rc::new(x.get(*j).clone()), Some(i)), + _ => unreachable!(), + } + } + fn inner(&self) -> &Self { match self { - Val::Ref(x) => x.as_ref(), + Val::Ref(x, i) => match i { + None => x.as_ref(), + &Some(j) => x.get(j), + }, _ => unreachable!(), } } @@ -64,7 +90,7 @@ impl Val { &Self::Bool(x) => Self::Bool(x), Self::F64(_) => Self::F64(Cell::new(0.)), &Self::Fin(x) => Self::Fin(x), - Self::Ref(_) => unreachable!(), + Self::Ref(..) => unreachable!(), Self::Array(x) => Self::Array(collect_vals(x.iter().map(|x| x.zero()))), Self::Tuple(x) => Self::Tuple(collect_vals(x.iter().map(|x| x.zero()))), } @@ -192,14 +218,8 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { _ => unreachable!(), }, - &Expr::Slice { array, index } => match (self.get(array).inner(), self.get(index)) { - (Val::Array(v), &Val::Fin(i)) => Val::Ref(Rc::new(v[i].clone())), - _ => unreachable!(), - }, - &Expr::Field { tuple, member } => match self.get(tuple).inner() { - Val::Tuple(x) => Val::Ref(Rc::new(x[member.member()].clone())), - _ => unreachable!(), - }, + &Expr::Slice { array, index } => self.get(array).slice(self.get(index).fin()), + &Expr::Field { tuple, member } => self.get(tuple).slice(member.member()), &Expr::Unary { op, arg } => { let x = self.get(arg); @@ -257,8 +277,8 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { )) } - &Expr::Read { var } => Val::Ref(Rc::new(self.get(var).clone())), - &Expr::Accum { shape } => Val::Ref(Rc::new(self.get(shape).zero())), + &Expr::Read { var } => Val::Ref(Rc::new(self.get(var).clone()), None), + &Expr::Accum { shape } => Val::Ref(Rc::new(self.get(shape).zero()), None), &Expr::Ask { var } => self.get(var).inner().clone(), &Expr::Add { accum, addend } => { diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index bacdb01..3511125 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -40,6 +40,7 @@ pub fn layouts() -> Result { ("Func", layout::()), ("Instr", layout::()), ("Ty", layout::()), + ("Val", layout::()), ]) } diff --git a/packages/core/src/impl.test.ts b/packages/core/src/impl.test.ts index 0aa4727..b8686cf 100644 --- a/packages/core/src/impl.test.ts +++ b/packages/core/src/impl.test.ts @@ -39,6 +39,7 @@ test("core IR type layouts", () => { Func: { size: 44, align: 4 }, Instr: { size: 32, align: 8 }, Ty: { size: 12, align: 4 }, + Val: { size: 16, align: 8 }, }); }); diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 1516dad..7432bc1 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -361,6 +361,6 @@ describe("valid", () => { const v = grad(1); return [x, v[0], v[1]]; }); - expect(interp(h)()).toBe([6, 3, 2]); + expect(interp(h)()).toEqual([6, 3, 2]); }); }); From 8278f8253296f68b67c24d9ed5f175e40c97943c Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Sat, 9 Sep 2023 16:01:17 -0400 Subject: [PATCH 10/11] Test ternary --- crates/transpose/src/lib.rs | 70 +++++++++++++++++++++++++++++---- packages/core/src/index.test.ts | 22 ++++++++--- 2 files changed, 79 insertions(+), 13 deletions(-) diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index 6bc14cf..79b0a20 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -275,10 +275,16 @@ impl<'a> Transpose<'a> { var, expr: Expr::Unit, }), - &Expr::Bool { val } => self.block.fwd.push(Instr { - var, - expr: Expr::Bool { val }, - }), + &Expr::Bool { val } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Bool { val }, + }); + self.block.bwd_nonlin.push(Instr { + var, + expr: Expr::Bool { val }, + }); + } &Expr::F64 { val } => { match self.f.vars[var.var()] { DUAL => { @@ -562,12 +568,25 @@ impl<'a> Transpose<'a> { self.resolve(lin); } _ => { + let (a, b) = match op { + Binop::And | Binop::Or | Binop::Iff | Binop::Xor => (left, right), + Binop::Neq + | Binop::Lt + | Binop::Leq + | Binop::Eq + | Binop::Gt + | Binop::Geq + | Binop::Add + | Binop::Sub + | Binop::Mul + | Binop::Div => (self.get_prim(left), self.get_prim(right)), + }; self.block.fwd.push(Instr { var, expr: Expr::Binary { op, - left: self.get_prim(left), - right: self.get_prim(right), + left: a, + right: b, }, }); self.keep(var); @@ -575,7 +594,42 @@ impl<'a> Transpose<'a> { } self.prims[var.var()] = Some(Src::Original); } - &Expr::Select { cond, then, els } => todo!(), + &Expr::Select { cond, then, els } => { + self.block.fwd.push(Instr { + var, + expr: Expr::Select { + cond, + then: self.get_re(then), + els: self.get_re(els), + }, + }); + self.keep(var); + let lin = self.accum(var, Scope::Original); + let acc_then = self.get_dual_accum(then).unwrap(); + let acc_els = self.get_dual_accum(els).unwrap(); + // TODO: this scope is wrong; it actually needs to match both `then` and `els` + let acc = self.bwd_var(BwdTy::Accum(Scope::Original, self.f.vars[then.var()])); + let unit = self.bwd_var(BwdTy::Unit); + self.block.bwd_lin.push(Instr { + var: unit, + expr: Expr::Add { + accum: acc, + addend: lin.cot, + }, + }); + self.block.bwd_lin.push(Instr { + var: acc, + expr: Expr::Select { + cond, + then: acc_then, + els: acc_els, + }, + }); + self.resolve(lin); + if let Ty::F64 = self.types[self.f.vars[var.var()].ty()] { + self.duals[var.var()] = Some((Src::Original, Src::Original)); + } + } Expr::Call { id, generics, args } => todo!(), Expr::For { arg, body, ret } => { @@ -610,7 +664,7 @@ pub fn transpose(f: &Func) -> (Func, Func) { .enumerate() .map(|(i, ty)| match ty { Ty::Unit => Ty::Unit, - Ty::Bool => Ty::Unit, + Ty::Bool => Ty::Bool, Ty::F64 => { if !is_primitive(id::ty(i)) { panic!() diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 7432bc1..fee8102 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -12,6 +12,7 @@ import { interp, jvp, mul, + or, select, sign, sqrt, @@ -346,7 +347,7 @@ describe("valid", () => { test("JVP with sharing in call graph", () => { let f = fn([Real], Real, (x) => x); - for (let i = 0; i < 20; i++) { + for (let i = 0; i < 20; ++i) { f = fn([Real], Real, (x) => add(f(x), f(x))); } const g = interp(jvp(f)); @@ -355,12 +356,23 @@ describe("valid", () => { test("VJP", () => { const f = fn([Vec(2, Real)], Real, (v) => mul(v[0], v[1])); - const g = vjp(f); - const h = fn([], Vec(3, Real), () => { - const { ret: x, grad } = g([2, 3]); + const g = fn([], Vec(3, Real), () => { + const { ret: x, grad } = vjp(f)([2, 3]); const v = grad(1); return [x, v[0], v[1]]; }); - expect(interp(h)()).toEqual([6, 3, 2]); + expect(interp(g)()).toEqual([6, 3, 2]); + }); + + test("VJP with struct and select", () => { + const Stuff = { a: Null, b: Bool, c: Real } as const; + const f = fn([Stuff], Real, ({ b, c }) => select(or(false, b), Real, c, 2)); + const g = fn([Bool, Real], { x: Real, stuff: Stuff }, (b, c) => { + const { ret: x, grad } = vjp(f)({ a: null, b, c }); + return { x, stuff: grad(3) }; + }); + const h = interp(g); + expect(h(true, 5)).toEqual({ x: 5, stuff: { a: null, b: true, c: 3 } }); + expect(h(false, 7)).toEqual({ x: 2, stuff: { a: null, b: false, c: 0 } }); }); }); From 7f0161e544ad7a03e32249ec8b79183d2e736e30 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Sat, 9 Sep 2023 16:11:28 -0400 Subject: [PATCH 11/11] Always make accumulators --- crates/transpose/src/lib.rs | 20 ++++++++++++++++---- packages/core/src/index.test.ts | 24 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index 79b0a20..853a11c 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -271,10 +271,18 @@ impl<'a> Transpose<'a> { fn instr(&mut self, var: id::Var, expr: &Expr) { match expr { - Expr::Unit => self.block.fwd.push(Instr { - var, - expr: Expr::Unit, - }), + Expr::Unit => { + self.block.fwd.push(Instr { + var, + expr: Expr::Unit, + }); + self.block.bwd_nonlin.push(Instr { + var, + expr: Expr::Unit, + }); + let lin = self.accum(var, Scope::Original); + self.resolve(lin); + } &Expr::Bool { val } => { self.block.fwd.push(Instr { var, @@ -284,6 +292,8 @@ impl<'a> Transpose<'a> { var, expr: Expr::Bool { val }, }); + let lin = self.accum(var, Scope::Original); + self.resolve(lin); } &Expr::F64 { val } => { match self.f.vars[var.var()] { @@ -307,6 +317,8 @@ impl<'a> Transpose<'a> { var, expr: Expr::Fin { val }, }); + let lin = self.accum(var, Scope::Original); + self.resolve(lin); } Expr::Array { elems } => { diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index fee8102..6e105ba 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -375,4 +375,28 @@ describe("valid", () => { expect(h(true, 5)).toEqual({ x: 5, stuff: { a: null, b: true, c: 3 } }); expect(h(false, 7)).toEqual({ x: 2, stuff: { a: null, b: false, c: 0 } }); }); + + test("VJP with select on null", () => { + const f = fn([Null], Null, () => select(true, Null, null, null)); + const g = fn([], Null, () => vjp(f)(null).ret); + const h = interp(g); + expect(h()).toBe(null); + }); + + test("VJP with select on booleans", () => { + const f = fn([Bool], Bool, (p) => select(p, Bool, false, true)); + const g = fn([Bool], Bool, (p) => vjp(f)(p).ret); + const h = interp(g); + expect(h(true)).toBe(false); + expect(h(false)).toBe(true); + }); + + test("VJP with select on indices", () => { + const n = 2; + const f = fn([Bool], n, (p) => select(p, n, 0, 1)); + const g = fn([Bool], n, (p) => vjp(f)(p).ret); + const h = interp(g); + expect(h(true)).toBe(0); + expect(h(false)).toBe(1); + }); });