From d4a7216cdc01868a831891e2739b4d76651f2524 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Mon, 11 Sep 2023 09:51:44 -0400 Subject: [PATCH] Revert "Transpose functions (#86)" This reverts commit 1ee11a97f3f014846449f34e437769f8926ee99f. --- Cargo.lock | 9 - crates/interp/src/lib.rs | 46 +- crates/transpose/Cargo.toml | 8 - crates/transpose/src/lib.rs | 869 -------------------------------- crates/web/Cargo.toml | 1 - crates/web/src/lib.rs | 147 +----- packages/core/src/impl.test.ts | 1 - packages/core/src/impl.ts | 45 -- packages/core/src/index.test.ts | 50 +- packages/core/src/index.ts | 1 - 10 files changed, 22 insertions(+), 1155 deletions(-) delete mode 100644 crates/transpose/Cargo.toml delete mode 100644 crates/transpose/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 9f71a70..59c6c6f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -187,14 +187,6 @@ dependencies = [ "ts-rs", ] -[[package]] -name = "rose-transpose" -version = "0.0.0" -dependencies = [ - "enumset", - "rose", -] - [[package]] name = "rose-web" version = "0.0.0" @@ -206,7 +198,6 @@ dependencies = [ "rose", "rose-autodiff", "rose-interp", - "rose-transpose", "serde", "serde-wasm-bindgen", "wasm-bindgen", diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 4132195..8f0a7d1 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -16,7 +16,7 @@ pub enum Val { Bool(bool), F64(Cell), Fin(usize), - Ref(Rc, Option), + Ref(Rc), Array(Vals), // assume all indices are `Fin` Tuple(Vals), } @@ -50,35 +50,9 @@ impl Val { } } - fn fin(&self) -> usize { - match self { - &Val::Fin(i) => i, - _ => unreachable!(), - } - } - - fn get(&self, i: usize) -> &Self { - match self { - Val::Array(x) => &x[i], - Val::Tuple(x) => &x[i], - _ => unreachable!(), - } - } - - fn slice(&self, i: usize) -> Self { - match self { - Val::Ref(x, None) => Val::Ref(Rc::clone(x), Some(i)), - Val::Ref(x, Some(j)) => Val::Ref(Rc::new(x.get(*j).clone()), Some(i)), - _ => unreachable!(), - } - } - fn inner(&self) -> &Self { match self { - Val::Ref(x, i) => match i { - None => x.as_ref(), - &Some(j) => x.get(j), - }, + Val::Ref(x) => x.as_ref(), _ => unreachable!(), } } @@ -90,7 +64,7 @@ impl Val { &Self::Bool(x) => Self::Bool(x), Self::F64(_) => Self::F64(Cell::new(0.)), &Self::Fin(x) => Self::Fin(x), - Self::Ref(..) => unreachable!(), + Self::Ref(_) => unreachable!(), Self::Array(x) => Self::Array(collect_vals(x.iter().map(|x| x.zero()))), Self::Tuple(x) => Self::Tuple(collect_vals(x.iter().map(|x| x.zero()))), } @@ -218,8 +192,14 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { _ => unreachable!(), }, - &Expr::Slice { array, index } => self.get(array).slice(self.get(index).fin()), - &Expr::Field { tuple, member } => self.get(tuple).slice(member.member()), + &Expr::Slice { array, index } => match (self.get(array).inner(), self.get(index)) { + (Val::Array(v), &Val::Fin(i)) => v[i].clone(), + _ => unreachable!(), + }, + &Expr::Field { tuple, member } => match self.get(tuple).inner() { + Val::Tuple(x) => x[member.member()].clone(), + _ => unreachable!(), + }, &Expr::Unary { op, arg } => { let x = self.get(arg); @@ -277,8 +257,8 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { )) } - &Expr::Read { var } => Val::Ref(Rc::new(self.get(var).clone()), None), - &Expr::Accum { shape } => Val::Ref(Rc::new(self.get(shape).zero()), None), + &Expr::Read { var } => Val::Ref(Rc::new(self.get(var).clone())), + &Expr::Accum { shape } => Val::Ref(Rc::new(self.get(shape).zero())), &Expr::Ask { var } => self.get(var).inner().clone(), &Expr::Add { accum, addend } => { diff --git a/crates/transpose/Cargo.toml b/crates/transpose/Cargo.toml deleted file mode 100644 index c122e43..0000000 --- a/crates/transpose/Cargo.toml +++ /dev/null @@ -1,8 +0,0 @@ -[package] -name = "rose-transpose" -version = "0.0.0" -edition = "2021" - -[dependencies] -enumset = "1" -rose = { path = "../core" } diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs deleted file mode 100644 index 853a11c..0000000 --- a/crates/transpose/src/lib.rs +++ /dev/null @@ -1,869 +0,0 @@ -use enumset::EnumSet; -use rose::{id, Binop, Constraint, Expr, Func, Instr, Ty, Unop}; -use std::mem::{swap, take}; - -const REAL: id::Ty = id::ty(0); -const DUAL: id::Ty = id::ty(1); - -fn is_primitive(t: id::Ty) -> bool { - t == REAL || t == DUAL -} - -enum Scope { - Generic(id::Generic), - Derived(id::Var), - Original, -} - -/// Type ID of a variable while building up the backward pass. -enum BwdTy { - /// We already know the type ID of this variable. - /// - /// Usually this means the variable is directly mapped from a variable in the original function. - Known(id::Ty), - - /// This variable has the unit type. - Unit, - - /// This variable is an accumulator. - /// - /// After we process the entire function, we need to add a scope type for every new accumulator - /// variable we've introduced; we don't add these as we go because we don't know ahead of time - /// how many types we'll need for intermediate values, and we want all those intermediate value - /// type IDs to be the same between the forward and backward passes. - Accum(Scope, id::Ty), - - /// We don't know the type of this variable yet, but will soon. - /// - /// Usually this means the variable is a tuple of intermediate values, and we'll update its type - /// to something concrete when we reach the end of the block. - Unknown, -} - -/// The source of a primitive variable. -#[derive(Clone, Copy)] -enum Src { - /// This variable is original. - Original, - - /// This variable is an alias of an original primitive variable of the same type. - Alias(id::Var), - - /// This variable is a component of an original dual number variable. - Projection(id::Var), -} - -impl Src { - fn prim(self, x: id::Var) -> Self { - match self { - Self::Original => Self::Alias(x), - _ => self, - } - } - - fn dual(self, x: id::Var) -> Self { - match self { - Self::Original => Self::Projection(x), - _ => self, - } - } -} - -/// Linear variables in the backward pass for a variable from the original function. -struct Lin { - /// Accumulator variable. - acc: id::Var, - - /// Resolved cotangent variable. - cot: id::Var, -} - -/// A block under construction for both the forward pass and the backward pass. -struct Block { - /// Instructions in this forward pass block, in order. - fwd: Vec, - - /// Variable IDs for intermediate values to be saved at the end of this forward pass block. - inter_mem: Vec, - - /// Variable ID for the intermediate values tuple in this backward pass block. - inter_tup: id::Var, - - /// Instructions at the beginning of this backward pass block, in order. - bwd_nonlin: Vec, - - /// Instructions at the end of this backward pass block, in reverse order. - bwd_lin: Vec, -} - -/// The forward pass and backward pass of a transposed function under construction. -struct Transpose<'a> { - /// The function being transposed, which is usually a forward-mode derivative. - f: &'a Func, - - /// Shared types between the forward and backward passes. - /// - /// This starts out the same length as `f.types`, with the only difference being that all dual - /// number types are replaced with `F64`. Then we add more types only for tuples and arrays of - /// intermediate values that are shared between the two passes. - types: Vec, - - /// Types of variables in the forward pass. - /// - /// This starts out as a clone of `f.vars`, but more variables can be added for dealing with - /// intermediate values. - fwd_vars: Vec, - - /// Types of variables in the backward pass. - /// - /// This starts out with the `Known` type of every variable from `f.vars`, but more variables - /// can be added for intermediate values, accumulators, and cotangents. - bwd_vars: Vec, - - /// A variable of type `F64` defined at the beginning of the backward pass. - /// - /// Every accumulator must be initialized using a concrete variable to dictate its topology. For - /// most variables, we keep around the original nonlinear value and use that as the shape, but - /// this doesn't work for raw linear `F64` variables, which might be part of some lower-level - /// mathematical calculation that is not clearly attached to any value at the dual number level - /// or higher. All those values have the same shape, though, so in those cases we just use this - /// one dummy variable as the shape. - real_shape: id::Var, - - /// Sources of primitive variables from the original function. - prims: Box<[Option]>, - - /// Sources for dual number variables from the original function. - duals: Box<[Option<(Src, Src)>]>, - - /// Accumulator variables for variables from the original function. - accums: Box<[Option]>, - - /// Cotangent variables for variables from the original function. - cotans: Box<[Option]>, - - /// The current block under construction. - block: Block, -} - -impl<'a> Transpose<'a> { - fn re(&self, x: id::Var) -> Src { - let (src, _) = self.duals[x.var()].unwrap(); - src.dual(x) - } - - fn du(&self, x: id::Var) -> Src { - let (_, src) = self.duals[x.var()].unwrap(); - src.dual(x) - } - - fn get_prim(&self, x: id::Var) -> id::Var { - match self.prims[x.var()].unwrap() { - Src::Original => x, - Src::Alias(y) | Src::Projection(y) => y, - } - } - - fn get_re(&self, x: id::Var) -> id::Var { - match self.duals[x.var()] { - None | Some((Src::Original, _)) => x, - Some((Src::Alias(y), _)) | Some((Src::Projection(y), _)) => y, - } - } - - fn get_du(&self, x: id::Var) -> id::Var { - match self.duals[x.var()] { - None | Some((_, Src::Original)) => x, - Some((_, Src::Alias(y))) | Some((_, Src::Projection(y))) => y, - } - } - - fn get_prim_accum(&self, x: id::Var) -> id::Var { - self.accums[self.get_prim(x).var()].unwrap() - } - - fn get_dual_accum(&self, x: id::Var) -> Option { - self.accums[self.get_du(x).var()] - } - - fn ty(&mut self, ty: Ty) -> id::Ty { - let t = id::ty(self.types.len()); - self.types.push(ty); - t - } - - fn fwd_var(&mut self, t: id::Ty) -> id::Var { - let var = id::var(self.fwd_vars.len()); - self.fwd_vars.push(t); - var - } - - fn bwd_var(&mut self, t: BwdTy) -> id::Var { - let var = id::var(self.bwd_vars.len()); - self.bwd_vars.push(t); - var - } - - fn keep(&mut self, var: id::Var) { - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Member { - tuple: self.block.inter_tup, - member: id::member(self.block.inter_mem.len()), - }, - }); - self.block.inter_mem.push(var); - } - - fn accum(&mut self, shape: id::Var, scope: Scope) -> Lin { - let t = self.f.vars[shape.var()]; - let acc = self.bwd_var(BwdTy::Accum(scope, t)); - let cot = self.bwd_var(BwdTy::Known(t)); - self.block.bwd_nonlin.push(Instr { - var: acc, - expr: Expr::Accum { shape }, - }); - self.accums[shape.var()] = Some(acc); - self.cotans[shape.var()] = Some(cot); - Lin { acc, cot } - } - - fn calc(&mut self, tan: id::Var) -> Lin { - let t = self.f.vars[tan.var()]; - let acc = self.bwd_var(BwdTy::Accum(Scope::Original, t)); - let cot = self.bwd_var(BwdTy::Known(t)); - self.block.bwd_nonlin.push(Instr { - var: acc, - expr: Expr::Accum { - shape: self.real_shape, - }, - }); - self.accums[tan.var()] = Some(acc); - self.cotans[tan.var()] = Some(cot); - Lin { acc, cot } - } - - fn resolve(&mut self, lin: Lin) { - self.block.bwd_lin.push(Instr { - var: lin.cot, - expr: Expr::Resolve { var: lin.acc }, - }) - } - - fn block(&mut self, block: &[Instr]) -> (id::Ty, id::Var) { - for instr in block.iter() { - self.instr(instr.var, &instr.expr); - } - let vars = take(&mut self.block.inter_mem); - let t = self.ty(Ty::Tuple { - members: vars.iter().map(|&x| self.fwd_vars[x.var()]).collect(), - }); - let var = self.fwd_var(t); - self.block.fwd.push(Instr { - var, - expr: Expr::Tuple { - members: vars.into(), - }, - }); - self.bwd_vars[self.block.inter_tup.var()] = BwdTy::Known(t); - (t, var) - } - - fn instr(&mut self, var: id::Var, expr: &Expr) { - match expr { - Expr::Unit => { - self.block.fwd.push(Instr { - var, - expr: Expr::Unit, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Unit, - }); - let lin = self.accum(var, Scope::Original); - self.resolve(lin); - } - &Expr::Bool { val } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Bool { val }, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Bool { val }, - }); - let lin = self.accum(var, Scope::Original); - self.resolve(lin); - } - &Expr::F64 { val } => { - match self.f.vars[var.var()] { - DUAL => { - let lin = self.calc(var); - self.resolve(lin); - } - _ => self.block.fwd.push(Instr { - var, - expr: Expr::F64 { val }, - }), - } - self.prims[var.var()] = Some(Src::Original); - } - &Expr::Fin { val } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Fin { val }, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Fin { val }, - }); - let lin = self.accum(var, Scope::Original); - self.resolve(lin); - } - - Expr::Array { elems } => { - let t = match self.f.types[self.f.vars[var.var()].ty()] { - Ty::Array { index, elem: _ } => index, - _ => panic!(), - }; - self.block.fwd.push(Instr { - var, - expr: Expr::Array { - elems: elems.iter().map(|&elem| self.get_re(elem)).collect(), - }, - }); - self.keep(var); - let lin = self.accum(var, Scope::Original); - for (i, &elem) in elems.iter().enumerate() { - if let Some(accum) = self.get_dual_accum(elem) { - let index = self.bwd_var(BwdTy::Known(t)); - let addend = self.bwd_var(BwdTy::Known(self.f.vars[elem.var()])); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { accum, addend }, - }); - self.block.bwd_lin.push(Instr { - var: addend, - expr: Expr::Index { - array: lin.cot, - index, - }, - }); - self.block.bwd_lin.push(Instr { - var: index, - expr: Expr::Fin { val: i }, - }); - } - } - self.resolve(lin); - } - Expr::Tuple { members } => match self.types[self.f.vars[var.var()].ty()] { - Ty::F64 => { - let x = members[1]; - let dx = members[0]; - self.duals[var.var()] = Some(( - self.prims[x.var()].unwrap().prim(x), - self.prims[dx.var()].unwrap().prim(dx), - )); - } - _ => { - self.block.fwd.push(Instr { - var, - expr: Expr::Tuple { - members: members.iter().map(|&member| self.get_re(member)).collect(), - }, - }); - self.keep(var); - let lin = self.accum(var, Scope::Original); - for (i, &member) in members.iter().enumerate() { - if let Some(accum) = self.get_dual_accum(member) { - let addend = self.bwd_var(BwdTy::Known(self.f.vars[member.var()])); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { accum, addend }, - }); - self.block.bwd_lin.push(Instr { - var: addend, - expr: Expr::Member { - tuple: lin.cot, - member: id::member(i), - }, - }); - } - } - self.resolve(lin); - } - }, - - &Expr::Index { array, index } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Index { array, index }, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Index { array, index }, - }); - let arr_acc = self.accums[array.var()].unwrap(); - let acc = self.bwd_var(BwdTy::Accum( - Scope::Derived(arr_acc), - self.f.vars[array.var()], - )); - self.accums[var.var()] = Some(acc); - self.block.bwd_nonlin.push(Instr { - var: acc, - expr: Expr::Slice { - array: arr_acc, - index, - }, - }); - if let Ty::F64 = self.types[self.f.vars[var.var()].ty()] { - self.duals[var.var()] = Some((Src::Original, Src::Original)); - } - } - &Expr::Member { tuple, member } => { - let t = self.f.vars[var.var()]; - match t { - REAL => self.prims[var.var()] = Some(self.re(tuple)), - DUAL => self.prims[var.var()] = Some(self.du(tuple)), - _ => { - self.block.fwd.push(Instr { - var, - expr: Expr::Member { tuple, member }, - }); - self.block.bwd_nonlin.push(Instr { - var, - expr: Expr::Member { tuple, member }, - }); - let tup_acc = self.accums[tuple.var()].unwrap(); - let acc = self.bwd_var(BwdTy::Accum( - Scope::Derived(tup_acc), - self.f.vars[tuple.var()], - )); - self.accums[var.var()] = Some(acc); - self.block.bwd_nonlin.push(Instr { - var: acc, - expr: Expr::Field { - tuple: tup_acc, - member, - }, - }); - if let Ty::F64 = self.types[t.ty()] { - self.duals[var.var()] = Some((Src::Original, Src::Original)); - } - } - } - } - - &Expr::Slice { array, index } => todo!(), - &Expr::Field { tuple, member } => todo!(), - - &Expr::Unary { op, arg } => { - match self.f.vars[var.var()] { - DUAL => match op { - Unop::Not | Unop::Abs | Unop::Sign | Unop::Sqrt => panic!(), - Unop::Neg => { - let lin = self.calc(var); - let res = self.bwd_var(BwdTy::Known(DUAL)); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { - accum: self.get_prim_accum(arg), - addend: res, - }, - }); - self.block.bwd_lin.push(Instr { - var: res, - expr: Expr::Unary { - op: Unop::Neg, - arg: lin.cot, - }, - }); - self.resolve(lin); - } - }, - _ => { - self.block.fwd.push(Instr { - var, - expr: Expr::Unary { - op, - arg: self.get_prim(arg), - }, - }); - self.keep(var); - } - } - self.prims[var.var()] = Some(Src::Original); - } - &Expr::Binary { op, left, right } => { - match self.f.vars[var.var()] { - DUAL => { - let lin = self.calc(var); - match op { - Binop::And - | Binop::Or - | Binop::Iff - | Binop::Xor - | Binop::Neq - | Binop::Lt - | Binop::Leq - | Binop::Eq - | Binop::Gt - | Binop::Geq => panic!(), - Binop::Add => { - let a = self.bwd_var(BwdTy::Unit); - let b = self.bwd_var(BwdTy::Unit); - self.block.bwd_lin.push(Instr { - var: a, - expr: Expr::Add { - accum: self.get_prim_accum(left), - addend: lin.cot, - }, - }); - self.block.bwd_lin.push(Instr { - var: b, - expr: Expr::Add { - accum: self.get_prim_accum(right), - addend: lin.cot, - }, - }); - } - Binop::Sub => { - let res = self.bwd_var(BwdTy::Known(DUAL)); - let a = self.bwd_var(BwdTy::Unit); - let b = self.bwd_var(BwdTy::Unit); - self.block.bwd_lin.push(Instr { - var: a, - expr: Expr::Add { - accum: self.get_prim_accum(left), - addend: lin.cot, - }, - }); - self.block.bwd_lin.push(Instr { - var: b, - expr: Expr::Add { - accum: self.get_prim_accum(right), - addend: res, - }, - }); - self.block.bwd_lin.push(Instr { - var: res, - expr: Expr::Unary { - op: Unop::Neg, - arg: lin.cot, - }, - }); - } - Binop::Mul | Binop::Div => { - let res = self.bwd_var(BwdTy::Known(DUAL)); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { - accum: self.get_prim_accum(left), - addend: res, - }, - }); - self.block.bwd_lin.push(Instr { - var: res, - expr: Expr::Binary { - op, - left: lin.cot, - right: self.get_prim(right), - }, - }); - } - } - self.resolve(lin); - } - _ => { - let (a, b) = match op { - Binop::And | Binop::Or | Binop::Iff | Binop::Xor => (left, right), - Binop::Neq - | Binop::Lt - | Binop::Leq - | Binop::Eq - | Binop::Gt - | Binop::Geq - | Binop::Add - | Binop::Sub - | Binop::Mul - | Binop::Div => (self.get_prim(left), self.get_prim(right)), - }; - self.block.fwd.push(Instr { - var, - expr: Expr::Binary { - op, - left: a, - right: b, - }, - }); - self.keep(var); - } - } - self.prims[var.var()] = Some(Src::Original); - } - &Expr::Select { cond, then, els } => { - self.block.fwd.push(Instr { - var, - expr: Expr::Select { - cond, - then: self.get_re(then), - els: self.get_re(els), - }, - }); - self.keep(var); - let lin = self.accum(var, Scope::Original); - let acc_then = self.get_dual_accum(then).unwrap(); - let acc_els = self.get_dual_accum(els).unwrap(); - // TODO: this scope is wrong; it actually needs to match both `then` and `els` - let acc = self.bwd_var(BwdTy::Accum(Scope::Original, self.f.vars[then.var()])); - let unit = self.bwd_var(BwdTy::Unit); - self.block.bwd_lin.push(Instr { - var: unit, - expr: Expr::Add { - accum: acc, - addend: lin.cot, - }, - }); - self.block.bwd_lin.push(Instr { - var: acc, - expr: Expr::Select { - cond, - then: acc_then, - els: acc_els, - }, - }); - self.resolve(lin); - if let Ty::F64 = self.types[self.f.vars[var.var()].ty()] { - self.duals[var.var()] = Some((Src::Original, Src::Original)); - } - } - - Expr::Call { id, generics, args } => todo!(), - Expr::For { arg, body, ret } => { - let mut block = Block { - fwd: vec![], - inter_mem: vec![], - inter_tup: self.bwd_var(BwdTy::Unknown), - bwd_nonlin: vec![], - bwd_lin: vec![], - }; - swap(&mut self.block, &mut block); - self.block(body); // TODO - swap(&mut self.block, &mut block); - } - - &Expr::Read { var } => todo!(), - &Expr::Accum { shape } => todo!(), - - &Expr::Ask { var } => todo!(), - &Expr::Add { accum, addend } => todo!(), - - &Expr::Resolve { var } => todo!(), - } - } -} - -/// Return two functions that make up the transpose of `f`. -pub fn transpose(f: &Func) -> (Func, Func) { - let types: Vec<_> = f - .types - .iter() - .enumerate() - .map(|(i, ty)| match ty { - Ty::Unit => Ty::Unit, - Ty::Bool => Ty::Bool, - Ty::F64 => { - if !is_primitive(id::ty(i)) { - panic!() - } - Ty::F64 - } - &Ty::Fin { size } => Ty::Fin { size }, - &Ty::Generic { id } => Ty::Generic { id }, - &Ty::Scope { kind, id } => Ty::Scope { kind, id }, - &Ty::Ref { scope, inner } => { - if is_primitive(inner) { - panic!() - } - Ty::Ref { scope, inner } - } - &Ty::Array { index, elem } => { - if is_primitive(elem) { - panic!() - } - Ty::Array { index, elem } - } - Ty::Tuple { members } => { - if members.iter().any(|&t| is_primitive(t)) { - Ty::F64 - } else { - Ty::Tuple { - members: members.clone(), - } - } - } - }) - .collect(); - - let mut bwd_generics: Vec<_> = f - .generics - .iter() - .map(|&before| { - let mut after = before; - if before.contains(Constraint::Read) { - after.remove(Constraint::Read); - after.insert(Constraint::Accum); - } - if before.contains(Constraint::Accum) { - after.remove(Constraint::Accum); - after.insert(Constraint::Read); - } - after - }) - .collect(); - - let mut bwd_vars: Vec<_> = f.vars.iter().map(|&t| BwdTy::Known(t)).collect(); - let real_shape = id::var(bwd_vars.len()); - bwd_vars.push(BwdTy::Known(DUAL)); - let inter_tup = id::var(bwd_vars.len()); - bwd_vars.push(BwdTy::Unknown); - - let mut duals = vec![None; f.vars.len()].into_boxed_slice(); - let mut accums = vec![None; f.vars.len()].into_boxed_slice(); - - let mut inter_mem = vec![]; - let mut bwd_nonlin = vec![]; - - let mut bwd_params: Vec<_> = f - .params - .iter() - .map(|¶m| { - let g = id::generic(bwd_generics.len()); - bwd_generics.push(EnumSet::only(Constraint::Accum)); - let t = f.vars[param.var()]; - bwd_nonlin.push(Instr { - var: param, - expr: Expr::Member { - tuple: inter_tup, - member: id::member(inter_mem.len()), - }, - }); - inter_mem.push(param); - let acc = id::var(bwd_vars.len()); - bwd_vars.push(BwdTy::Accum(Scope::Generic(g), t)); - if let Ty::F64 = types[t.ty()] { - duals[param.var()] = Some((Src::Original, Src::Original)); - } - accums[param.var()] = Some(acc); - acc - }) - .collect(); - - let mut tp = Transpose { - f, - types, - fwd_vars: f.vars.to_vec(), - bwd_vars, - real_shape, - prims: vec![None; f.vars.len()].into(), - duals, - accums, - cotans: vec![None; f.vars.len()].into(), - block: Block { - fwd: vec![], - inter_tup, - inter_mem, - bwd_nonlin, - bwd_lin: vec![], - }, - }; - - let (t_intermediates, fwd_inter) = tp.block(&f.body); - let fwd_ret = tp.get_re(f.ret); - let bwd_acc = tp.get_dual_accum(f.ret).unwrap(); - let mut bwd_types = tp.types.clone(); - - let mut fwd_types = tp.types; - let t_bundle = id::ty(fwd_types.len()); - fwd_types.push(Ty::Tuple { - members: vec![f.vars[f.ret.var()], t_intermediates].into(), - }); - let mut fwd_vars = tp.fwd_vars; - let fwd_bundle = id::var(fwd_vars.len()); - fwd_vars.push(t_bundle); - let mut fwd_body = tp.block.fwd; - fwd_body.push(Instr { - var: fwd_bundle, - expr: Expr::Tuple { - members: vec![fwd_ret, fwd_inter].into(), - }, - }); - - let t_unit = id::ty(bwd_types.len()); - bwd_types.push(Ty::Unit); - let mut bwd_vars: Vec<_> = tp - .bwd_vars - .into_iter() - .enumerate() - .map(|(i, t)| match t { - BwdTy::Known(t) => t, - BwdTy::Unit => t_unit, - BwdTy::Accum(scope, inner) => { - let scope = id::ty(bwd_types.len()); - bwd_types.push(Ty::Scope { - kind: Constraint::Accum, - id: id::var(i), - }); - let t = id::ty(bwd_types.len()); - bwd_types.push(Ty::Ref { scope, inner }); - t - } - BwdTy::Unknown => panic!(), - }) - .collect(); - let bwd_cot = id::var(bwd_vars.len()); - bwd_vars.push(f.vars[f.ret.var()]); - let bwd_unit = id::var(bwd_vars.len()); - bwd_vars.push(t_unit); - bwd_params.push(bwd_cot); - bwd_params.push(tp.block.inter_tup); - let mut bwd_body = vec![Instr { - var: tp.real_shape, - expr: Expr::F64 { val: 0. }, - }]; - bwd_body.append(&mut tp.block.bwd_nonlin); - let mut bwd_linear = tp.block.bwd_lin; - bwd_linear.push(Instr { - var: bwd_unit, - expr: Expr::Add { - accum: bwd_acc, - addend: bwd_cot, - }, - }); - bwd_linear.reverse(); - bwd_body.append(&mut bwd_linear); - - ( - Func { - generics: f.generics.clone(), - types: fwd_types.into(), - vars: fwd_vars.into(), - params: f.params.clone(), - ret: fwd_bundle, - body: fwd_body.into(), - }, - Func { - generics: bwd_generics.into(), - types: bwd_types.into(), - vars: bwd_vars.into(), - params: bwd_params.into(), - ret: bwd_unit, - body: bwd_body.into(), - }, - ) -} diff --git a/crates/web/Cargo.toml b/crates/web/Cargo.toml index a4ee556..853cd68 100644 --- a/crates/web/Cargo.toml +++ b/crates/web/Cargo.toml @@ -15,7 +15,6 @@ js-sys = "0.3" rose = { path = "../core" } rose-autodiff = { path = "../autodiff" } rose-interp = { path = "../interp", features = ["serde"] } -rose-transpose = { path = "../transpose" } serde = { version = "1", features = ["derive"] } serde-wasm-bindgen = "0.4" wasm-bindgen = "=0.2.87" # Must be this version of wbg diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 3511125..ec26335 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -40,7 +40,6 @@ pub fn layouts() -> Result { ("Func", layout::()), ("Instr", layout::()), ("Ty", layout::()), - ("Val", layout::()), ]) } @@ -125,10 +124,6 @@ struct Pointee { structs: Box<[Option>]>, jvp: RefCell>>, - - fwd: RefCell>>, - - bwd: RefCell>>, } /// A node in a reference-counted acyclic digraph of functions. @@ -154,8 +149,6 @@ impl Func { }, structs: [].into(), jvp: RefCell::new(None), - fwd: RefCell::new(None), - bwd: RefCell::new(None), }), } } @@ -268,7 +261,6 @@ impl Func { inner, structs, jvp, - .. } = self.rc.as_ref(); let mut cache = jvp.borrow_mut(); if let Some(rc) = cache.as_ref().and_then(|weak| weak.upgrade()) { @@ -294,8 +286,6 @@ impl Func { }, structs: structs_jvp.into(), jvp: RefCell::new(None), - fwd: RefCell::new(None), - bwd: RefCell::new(None), }) } Inner::Opaque { .. } => todo!(), @@ -303,100 +293,6 @@ impl Func { *cache = Some(Rc::downgrade(&rc)); Self { rc } } - - fn transpose_pair(&self) -> (Self, Self) { - let Pointee { - inner, - structs, - fwd, - bwd, - .. - } = self.rc.as_ref(); - let mut cache_fwd = fwd.borrow_mut(); - let mut cache_bwd = bwd.borrow_mut(); - if let (Some(rc_fwd), Some(rc_bwd)) = ( - cache_fwd.as_ref().and_then(|weak| weak.upgrade()), - cache_bwd.as_ref().and_then(|weak| weak.upgrade()), - ) { - return (Self { rc: rc_fwd }, Self { rc: rc_bwd }); - } - let (rc_fwd, rc_bwd) = match inner { - Inner::Transparent { deps, def } => { - let (deps_fwd, deps_bwd): (Vec<_>, Vec<_>) = - deps.iter().map(|f| f.transpose_pair()).unzip(); - let (def_fwd, def_bwd) = rose_transpose::transpose(def); - let structs_fwd = def_fwd - .types - .iter() - .enumerate() - .map(|(i, ty)| match ty { - rose::Ty::F64 => None, - _ => structs.get(i).cloned().flatten(), - }) - .collect(); - let structs_bwd = def_bwd - .types - .iter() - .enumerate() - .map(|(i, ty)| match ty { - rose::Ty::F64 => None, - _ => structs.get(i).cloned().flatten(), - }) - .collect(); - ( - Rc::new(Pointee { - inner: Inner::Transparent { - deps: deps_fwd.into(), - def: def_fwd, - }, - structs: structs_fwd, - jvp: RefCell::new(None), - fwd: RefCell::new(None), - bwd: RefCell::new(None), - }), - Rc::new(Pointee { - inner: Inner::Transparent { - deps: deps_bwd.into(), - def: def_bwd, - }, - structs: structs_bwd, - jvp: RefCell::new(None), - fwd: RefCell::new(None), - bwd: RefCell::new(None), - }), - ) - } - Inner::Opaque { .. } => panic!(), - }; - *cache_fwd = Some(Rc::downgrade(&rc_fwd)); - *cache_bwd = Some(Rc::downgrade(&rc_bwd)); - (Self { rc: rc_fwd }, Self { rc: rc_bwd }) - } - - pub fn transpose(&self) -> Transpose { - let (fwd, bwd) = self.transpose_pair(); - Transpose { - fwd: Some(fwd), - bwd: Some(bwd), - } - } -} - -#[wasm_bindgen] -pub struct Transpose { - fwd: Option, - bwd: Option, -} - -#[wasm_bindgen] -impl Transpose { - pub fn fwd(&mut self) -> Option { - self.fwd.take() - } - - pub fn bwd(&mut self) -> Option { - self.bwd.take() - } } #[cfg(feature = "debug")] @@ -602,10 +498,6 @@ enum Ty { Fin { size: usize, }, - Scope { - kind: rose::Constraint, - id: id::Var, - }, Ref { scope: id::Ty, inner: id::Ty, @@ -618,7 +510,7 @@ enum Ty { /// A tuple type, with additional information about key names that makes it into a struct. Struct { /// String IDs for key names, in order; the actual strings are stored in JavaScript. - keys: Option>, + keys: Box<[usize]>, /// Member types of the underlying tuple. Must be the same length as `keys`. members: Box<[id::Ty]>, @@ -633,10 +525,9 @@ impl Ty { Ty::Bool => (rose::Ty::Bool, None), Ty::F64 => (rose::Ty::F64, None), Ty::Fin { size } => (rose::Ty::Fin { size }, None), - Ty::Scope { kind, id } => (rose::Ty::Scope { kind, id }, None), Ty::Ref { scope, inner } => (rose::Ty::Ref { scope, inner }, None), Ty::Array { index, elem } => (rose::Ty::Array { index, elem }, None), - Ty::Struct { keys, members } => (rose::Ty::Tuple { members }, keys), + Ty::Struct { keys, members } => (rose::Ty::Tuple { members }, Some(keys)), } } } @@ -726,8 +617,6 @@ impl FuncBuilder { }, structs: structs.into(), jvp: RefCell::new(None), - fwd: RefCell::new(None), - bwd: RefCell::new(None), }), } } @@ -831,10 +720,7 @@ impl FuncBuilder { /// `Err` if `t` is out of range or does not represent a struct type. pub fn keys(&self, t: usize) -> Result, JsError> { match self.ty(t)? { - Ty::Struct { - keys: Some(keys), - members: _, - } => Ok(keys.clone()), + Ty::Struct { keys, members: _ } => Ok(keys.clone()), _ => Err(JsError::new("type is not a struct")), } } @@ -912,13 +798,6 @@ impl FuncBuilder { ) } - #[wasm_bindgen(js_name = "tyAccum")] - pub fn ty_accum(&mut self, id: usize) -> usize { - let kind = rose::Constraint::Accum; - let id = id::var(id); - self.newtype(Ty::Scope { kind, id }, EnumSet::only(kind)) - } - /// Return the ID for the type of arrays with index type `index` and element type `elem`, /// /// Assumes `index` and `elem` are valid type IDs. @@ -946,7 +825,7 @@ impl FuncBuilder { pub fn ty_struct(&mut self, keys: &[usize], mems: &[usize]) -> usize { self.newtype( Ty::Struct { - keys: Some(keys.into()), + keys: keys.into(), members: mems.iter().map(|&t| id::ty(t)).collect(), }, EnumSet::only(rose::Constraint::Value), @@ -1131,7 +1010,10 @@ impl FuncBuilder { Ty::Struct { keys: structs[t] .as_ref() - .map(|ss| ss.iter().map(|&s| strings[s]).collect()), + .unwrap() + .iter() + .map(|&s| strings[s]) + .collect(), members: members .iter() .map(|x| types[x.ty()]) @@ -1557,17 +1439,4 @@ impl Block { }; self.instr(f, id::ty(t), expr) } - - pub fn accum(&mut self, f: &mut FuncBuilder, inner: usize, shape: usize) -> usize { - let inner = id::ty(inner); - let shape = id::var(shape); - let scope = id::ty(f.ty_accum(f.vars.len())); - let t = id::ty(f.newtype(Ty::Ref { scope, inner }, EnumSet::empty())); - self.instr(f, t, rose::Expr::Accum { shape }) - } - - pub fn resolve(&mut self, f: &mut FuncBuilder, t: usize, var: usize) -> usize { - let expr = rose::Expr::Resolve { var: id::var(var) }; - self.instr(f, id::ty(t), expr) - } } diff --git a/packages/core/src/impl.test.ts b/packages/core/src/impl.test.ts index b8686cf..0aa4727 100644 --- a/packages/core/src/impl.test.ts +++ b/packages/core/src/impl.test.ts @@ -39,7 +39,6 @@ test("core IR type layouts", () => { Func: { size: 44, align: 4 }, Instr: { size: 32, align: 8 }, Ty: { size: 12, align: 4 }, - Val: { size: 16, align: 8 }, }); }); diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 10bd300..2da93cf 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -595,51 +595,6 @@ export const jvp = ( return g; }; -export const vjp = ( - f: Fn & ((arg: A) => R), -): ((arg: A) => { ret: R; grad: (cot: R) => A }) => { - const g = jvp(f); - const tp = g[inner].transpose(); - const fwdFunc = tp.fwd()!; - const bwdFunc = tp.bwd()!; - const fwd: Fn = { [inner]: fwdFunc, [strings]: [...f[strings]] }; - const bwd: Fn = { [inner]: bwdFunc, [strings]: [...f[strings]] }; - funcs.register(fwd, fwdFunc); - funcs.register(bwd, bwdFunc); - return (arg: A) => { - const ctx = getCtx(); - const strs = intern(ctx, fwd[strings]); - const generics = new Uint32Array(); // TODO: support generics - const [tArg, tBundle] = ctx.func.ingest(fwd[inner], strs, generics); - const [tRet, tInter] = ctx.func.members(tBundle); - const argId = valId(ctx, tArg, arg); - const bundleId = ctx.block.call( - ctx.func, - fwd[inner], - generics, - tBundle, - new Uint32Array([argId]), - ); - const primalId = ctx.block.member(ctx.func, tRet, bundleId, 0); - const interId = ctx.block.member(ctx.func, tInter, bundleId, 1); - const grad = (cot: R) => { - if (getCtx() !== ctx) throw Error("VJP closure escaped its context"); - const cotId = valId(ctx, tRet, cot); - const accId = ctx.block.accum(ctx.func, tArg, argId); - const tScope = ctx.func.tyAccum(accId); - ctx.block.call( - ctx.func, - bwd[inner], - new Uint32Array([tScope]), // TODO: support generics - ctx.func.tyUnit(), - new Uint32Array([accId, cotId, interId]), - ); - return idVal(ctx, tArg, ctx.block.resolve(ctx.func, tArg, accId)) as A; - }; - return { ret: idVal(ctx, tRet, primalId) as R, grad }; - }; -}; - /** Return the variable ID for the abstract boolean `x`. */ const boolId = (ctx: Context, x: Bool): number => valId(ctx, ctx.func.tyBool(), x); diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 6e105ba..c99fc20 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -12,13 +12,11 @@ import { interp, jvp, mul, - or, select, sign, sqrt, sub, vec, - vjp, } from "./index.js"; describe("invalid", () => { @@ -347,56 +345,10 @@ describe("valid", () => { test("JVP with sharing in call graph", () => { let f = fn([Real], Real, (x) => x); - for (let i = 0; i < 20; ++i) { + for (let i = 0; i < 20; i++) { f = fn([Real], Real, (x) => add(f(x), f(x))); } const g = interp(jvp(f)); expect(g({ re: 2, du: 3 })).toEqual({ re: 2097152, du: 3145728 }); }); - - test("VJP", () => { - const f = fn([Vec(2, Real)], Real, (v) => mul(v[0], v[1])); - const g = fn([], Vec(3, Real), () => { - const { ret: x, grad } = vjp(f)([2, 3]); - const v = grad(1); - return [x, v[0], v[1]]; - }); - expect(interp(g)()).toEqual([6, 3, 2]); - }); - - test("VJP with struct and select", () => { - const Stuff = { a: Null, b: Bool, c: Real } as const; - const f = fn([Stuff], Real, ({ b, c }) => select(or(false, b), Real, c, 2)); - const g = fn([Bool, Real], { x: Real, stuff: Stuff }, (b, c) => { - const { ret: x, grad } = vjp(f)({ a: null, b, c }); - return { x, stuff: grad(3) }; - }); - const h = interp(g); - expect(h(true, 5)).toEqual({ x: 5, stuff: { a: null, b: true, c: 3 } }); - expect(h(false, 7)).toEqual({ x: 2, stuff: { a: null, b: false, c: 0 } }); - }); - - test("VJP with select on null", () => { - const f = fn([Null], Null, () => select(true, Null, null, null)); - const g = fn([], Null, () => vjp(f)(null).ret); - const h = interp(g); - expect(h()).toBe(null); - }); - - test("VJP with select on booleans", () => { - const f = fn([Bool], Bool, (p) => select(p, Bool, false, true)); - const g = fn([Bool], Bool, (p) => vjp(f)(p).ret); - const h = interp(g); - expect(h(true)).toBe(false); - expect(h(false)).toBe(true); - }); - - test("VJP with select on indices", () => { - const n = 2; - const f = fn([Bool], n, (p) => select(p, n, 0, 1)); - const g = fn([Bool], n, (p) => vjp(f)(p).ret); - const h = interp(g); - expect(h(true)).toBe(0); - expect(h(false)).toBe(1); - }); }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index fad6ed2..602c51b 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -27,6 +27,5 @@ export { sqrt, sub, vec, - vjp, xor, } from "./impl.js";