diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 35c2062..185c629 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -160,7 +160,7 @@ impl Opaque for Infallible { /// basically, the `'a` lifetime is for the graph of functions, and the `'b` lifetime is just for /// this particular instance of interpretation -struct Interpreter<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> { +struct Interpreter<'a, 'b, O, T: Refs<'a, Opaque = O>> { typemap: &'b mut IndexSet, refs: T, def: &'a Func, diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs index f1e0455..790aa6a 100644 --- a/crates/wasm/src/lib.rs +++ b/crates/wasm/src/lib.rs @@ -52,7 +52,7 @@ type Imports = IndexMap<(O, Box<[id::Ty]>), (Box<[id::Ty]>, id::Ty)>; type Funcs<'a, T> = IndexMap<(ByAddress<&'a Func>, Box<[id::Ty]>), (T, Box<[id::Ty]>)>; /// Computes a topological sort of a call graph via depth-first search. -struct Topsort<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>> { +struct Topsort<'a, O, T> { /// All types seen so far. types: IndexSet, @@ -63,7 +63,7 @@ struct Topsort<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>> { funcs: Funcs<'a, T>, } -impl<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>> Topsort<'a, O, T> { +impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Topsort<'a, O, T> { /// Search in the given `block` of `f`, using `refs` to resolve immediate function calls. /// /// The `types` argument is the resolved type ID for each of `f.types` in `self.types`. @@ -86,9 +86,7 @@ impl<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>> Topsort<'a, O, T> { | Expr::Accum { .. } | Expr::Add { .. } | Expr::Resolve { .. } => {} - Expr::For { body, .. } => { - self.block(refs, f, types, body); - } + Expr::For { body, .. } => self.block(refs, f, types, body), Expr::Call { id, generics, args } => { let gens = generics.iter().map(|t| types[t.ty()]).collect(); match refs.get(*id).unwrap() { @@ -321,7 +319,7 @@ struct Meta { } /// Generates WebAssembly code for a function. -struct Codegen<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> { +struct Codegen<'a, 'b, O, T> { /// Metadata about all the types in the global type index. metas: &'b [Meta], @@ -359,7 +357,7 @@ struct Codegen<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> { wasm: Function, } -impl<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { +impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { /// Return metadata for the type ID `t` in the current function. /// /// Do not use this if your type ID is already resolved to refer to the global type index. @@ -758,7 +756,7 @@ pub struct Wasm { } /// Compile `f` and all its direct and indirect callees to a WebAssembly module. -pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> Wasm { +pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> Wasm { let mut topsort = Topsort { types: IndexSet::new(), imports: IndexMap::new(), diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 748989a..4f09527 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "debug")] +mod pprint; + use by_address::ByAddress; use enumset::EnumSet; use indexmap::{IndexMap, IndexSet}; @@ -143,6 +146,13 @@ pub struct Func { rc: Rc, } +#[cfg(feature = "debug")] +impl std::fmt::Display for Func { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + pprint::write_graph(f, self.node()) + } +} + #[wasm_bindgen] impl Func { /// Return an opaque function taking `params` `F64` parameters and returning `F64`. @@ -189,6 +199,12 @@ impl Func { } } + #[cfg(feature = "debug")] + #[wasm_bindgen] + pub fn pprint(&self) -> String { + format!("{self}") + } + /// Return the IDs of this function's parameter types. #[wasm_bindgen(js_name = "paramTypes")] pub fn param_types(&self) -> Box<[usize]> { @@ -507,196 +523,6 @@ impl Transpose { } } -#[cfg(feature = "debug")] -#[wasm_bindgen] -pub fn pprint(f: &Func) -> Result { - use std::fmt::Write as _; // see https://doc.rust-lang.org/std/macro.write.html - - fn print_instr( - mut s: &mut String, - def: &rose::Func, - spaces: usize, - instr: &rose::Instr, - ) -> Result<(), JsError> { - for _ in 0..spaces { - write!(&mut s, " ")?; - } - let x = instr.var.var(); - write!(&mut s, "x{}: T{} = ", x, def.vars[x].ty())?; - match &instr.expr { - rose::Expr::Unit => writeln!(&mut s, "unit")?, - rose::Expr::Bool { val } => writeln!(&mut s, "{val}")?, - rose::Expr::F64 { val } => writeln!(&mut s, "{val}")?, - rose::Expr::Fin { val } => writeln!(&mut s, "{val}")?, - rose::Expr::Array { elems } => { - write!(&mut s, "[")?; - print_elems(s, 'x', elems.iter().map(|elem| elem.var()))?; - writeln!(&mut s, "]")?; - } - rose::Expr::Tuple { members } => { - write!(&mut s, "(")?; - print_elems(s, 'x', members.iter().map(|member| member.var()))?; - writeln!(&mut s, ")")?; - } - rose::Expr::Index { array, index } => { - writeln!(&mut s, "x{}[x{}]", array.var(), index.var())? - } - rose::Expr::Member { tuple, member } => { - writeln!(&mut s, "x{}[{}]", tuple.var(), member.member())? - } - rose::Expr::Slice { array, index } => { - writeln!(&mut s, "&x{}[x{}]", array.var(), index.var())? - } - rose::Expr::Field { tuple, member } => { - writeln!(&mut s, "&x{}[{}]", tuple.var(), member.member())? - } - rose::Expr::Unary { op, arg } => match op { - rose::Unop::Not => writeln!(&mut s, "not x{}", arg.var())?, - rose::Unop::Neg => writeln!(&mut s, "-x{}", arg.var())?, - rose::Unop::Abs => writeln!(&mut s, "|x{}|", arg.var())?, - rose::Unop::Sign => writeln!(&mut s, "sign(x{})", arg.var())?, - rose::Unop::Ceil => writeln!(&mut s, "ceil(x{})", arg.var())?, - rose::Unop::Floor => writeln!(&mut s, "floor(x{})", arg.var())?, - rose::Unop::Trunc => writeln!(&mut s, "trunc(x{})", arg.var())?, - rose::Unop::Sqrt => writeln!(&mut s, "sqrt(x{})", arg.var())?, - }, - rose::Expr::Binary { op, left, right } => match op { - rose::Binop::And => writeln!(&mut s, "x{} and x{}", left.var(), right.var())?, - rose::Binop::Or => writeln!(&mut s, "x{} or x{}", left.var(), right.var())?, - rose::Binop::Iff => writeln!(&mut s, "x{} iff x{}", left.var(), right.var())?, - rose::Binop::Xor => writeln!(&mut s, "x{} xor x{}", left.var(), right.var())?, - rose::Binop::Neq => writeln!(&mut s, "x{} != x{}", left.var(), right.var())?, - rose::Binop::Lt => writeln!(&mut s, "x{} < x{}", left.var(), right.var())?, - rose::Binop::Leq => writeln!(&mut s, "x{} <= x{}", left.var(), right.var())?, - rose::Binop::Eq => writeln!(&mut s, "x{} == x{}", left.var(), right.var())?, - rose::Binop::Gt => writeln!(&mut s, "x{} > x{}", left.var(), right.var())?, - rose::Binop::Geq => writeln!(&mut s, "x{} >= x{}", left.var(), right.var())?, - rose::Binop::Add => writeln!(&mut s, "x{} + x{}", left.var(), right.var())?, - rose::Binop::Sub => writeln!(&mut s, "x{} - x{}", left.var(), right.var())?, - rose::Binop::Mul => writeln!(&mut s, "x{} * x{}", left.var(), right.var())?, - rose::Binop::Div => writeln!(&mut s, "x{} / x{}", left.var(), right.var())?, - }, - rose::Expr::Select { cond, then, els } => { - writeln!(&mut s, "x{} ? x{} : x{}", cond.var(), then.var(), els.var())? - } - rose::Expr::Call { id, generics, args } => { - write!(&mut s, "f{}<", id.func())?; - print_elems(s, 'T', generics.iter().map(|generic| generic.ty()))?; - write!(&mut s, ">(")?; - print_elems(s, 'x', args.iter().map(|arg| arg.var()))?; - writeln!(&mut s, ")")?; - } - rose::Expr::For { arg, body, ret } => { - writeln!( - &mut s, - "for x{}: T{} {{", - arg.var(), - def.vars[arg.var()].ty() - )?; - print_block(s, def, spaces + 2, body, *ret)?; - for _ in 0..spaces { - write!(&mut s, " ")?; - } - writeln!(&mut s, "}}")? - } - rose::Expr::Accum { shape } => writeln!(&mut s, "accum x{}", shape.var())?, - rose::Expr::Add { accum, addend } => { - writeln!(&mut s, "x{} += x{}", accum.var(), addend.var())? - } - rose::Expr::Resolve { var } => writeln!(&mut s, "resolve x{}", var.var())?, - } - Ok(()) - } - - fn print_block( - mut s: &mut String, - def: &rose::Func, - spaces: usize, - body: &[rose::Instr], - ret: id::Var, - ) -> Result<(), JsError> { - for instr in body.iter() { - print_instr(s, def, spaces, instr)?; - } - for _ in 0..spaces { - write!(&mut s, " ")?; - } - writeln!(&mut s, "x{}", ret.var())?; - Ok(()) - } - - fn print_elems( - s: &mut String, - prefix: char, - items: impl Iterator, - ) -> std::fmt::Result { - let mut first = true; - for item in items { - if first { - first = false; - } else { - write!(s, ", ")?; - } - write!(s, "{}{}", prefix, item)?; - } - Ok(()) - } - - let mut s = String::new(); - let Pointee { inner, .. } = f.rc.as_ref(); - let def = match inner { - Inner::Transparent { def, .. } => def, - Inner::Opaque { .. } => return Err(JsError::new("opaque function")), - }; - - for (i, constraints) in def.generics.iter().enumerate() { - write!(&mut s, "G{i} = ")?; - let mut first = true; - for constraint in constraints.iter() { - if first { - first = false; - } else { - write!(&mut s, " + ")?; - } - write!(&mut s, "{constraint:?}")?; - } - writeln!(&mut s)?; - } - for (i, ty) in def.types.iter().enumerate() { - write!(&mut s, "T{i} = ")?; - match ty { - rose::Ty::Unit | rose::Ty::Bool | rose::Ty::F64 => writeln!(&mut s, "{ty:?}")?, - rose::Ty::Fin { size } => writeln!(&mut s, "{size}")?, - rose::Ty::Generic { id } => writeln!(&mut s, "G{}", id.generic())?, - rose::Ty::Ref { inner } => writeln!(&mut s, "&T{}", inner.ty())?, - rose::Ty::Array { index, elem } => writeln!(&mut s, "[T{}]T{}", index.ty(), elem.ty())?, - rose::Ty::Tuple { members } => { - write!(&mut s, "(")?; - print_elems(&mut s, 'T', members.iter().map(|member| member.ty()))?; - writeln!(&mut s, ")")?; - } - } - } - write!(&mut s, "(")?; - let mut first = true; - for param in def.params.iter() { - if first { - first = false; - } else { - write!(&mut s, ", ")?; - } - write!(&mut s, "x{}: T{}", param.var(), def.vars[param.var()].ty())?; - } - writeln!(&mut s, ") -> T{} {{", def.vars[def.ret.var()].ty())?; - for instr in def.body.iter() { - print_instr(&mut s, def, 2, instr)?; - } - writeln!(&mut s, " x{}", def.ret.var())?; - writeln!(&mut s, "}}")?; - - Ok(s) -} - /// A type, with key name information in the case of tuples (which thus become structs). #[derive(Clone, Debug, Eq, Hash, PartialEq)] enum Ty { diff --git a/crates/web/src/pprint.rs b/crates/web/src/pprint.rs new file mode 100644 index 0000000..df38690 --- /dev/null +++ b/crates/web/src/pprint.rs @@ -0,0 +1,299 @@ +use by_address::ByAddress; +use enumset::EnumSet; +use indexmap::{IndexMap, IndexSet}; +use rose::{id, Binop, Constraint, Expr, Func, Instr, Node, Refs, Ty, Unop}; +use std::{fmt, hash::Hash}; + +fn write_constraints(f: &mut fmt::Formatter<'_>, constraints: EnumSet) -> fmt::Result { + let mut first = true; + for constraint in constraints.iter() { + if first { + first = false; + } else { + write!(f, " + ")?; + } + write!(f, "{constraint:?}")?; + } + Ok(()) +} + +fn write_generics(f: &mut fmt::Formatter<'_>, generics: &[EnumSet]) -> fmt::Result { + write!(f, "<")?; + let mut first = true; + for (i, &constraints) in generics.iter().enumerate() { + if first { + first = false; + } else { + write!(f, ", ")?; + } + write!(f, "G{i}: ")?; + write_constraints(f, constraints)?; + } + write!(f, ">") +} + +fn write_types(f: &mut fmt::Formatter<'_>, types: &[Ty]) -> fmt::Result { + for (i, ty) in types.iter().enumerate() { + write!(f, " type T{i} = ")?; + match ty { + Ty::Unit | Ty::Bool | Ty::F64 => writeln!(f, "{ty:?}")?, + Ty::Fin { size } => writeln!(f, "{size}")?, + Ty::Generic { id } => writeln!(f, "G{}", id.generic())?, + Ty::Ref { inner } => writeln!(f, "&T{}", inner.ty())?, + Ty::Array { index, elem } => writeln!(f, "[T{}]T{}", index.ty(), elem.ty())?, + Ty::Tuple { members } => { + write!(f, "(")?; + write_elems(f, 'T', members.iter().map(|member| member.ty()))?; + writeln!(f, ")")?; + } + } + } + Ok(()) +} + +fn write_opaque( + f: &mut fmt::Formatter<'_>, + generics: &[EnumSet], + types: &[Ty], + params: &[id::Ty], + ret: id::Ty, +) -> fmt::Result { + write_generics(f, generics)?; + writeln!(f, "{{")?; + write_types(f, types)?; + write!(f, " opaque: (")?; + write_elems(f, 'T', params.iter().map(|t| t.ty()))?; + writeln!(f, ") -> T{}", ret.ty())?; + writeln!(f, "}}") +} + +fn search<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>( + f: &mut fmt::Formatter<'_>, + imports: &mut IndexSet, + funcs: &mut IndexMap, T>, + refs: &T, + block: &[Instr], +) -> fmt::Result { + for instr in block.iter() { + match &instr.expr { + &Expr::Call { id, .. } => match refs.get(id).unwrap() { + Node::Transparent { refs, def } => { + let key = ByAddress(def); + if !funcs.contains_key(&key) { + search(f, imports, funcs, &refs, &def.body)?; + funcs.insert(key, refs); + } + } + Node::Opaque { + generics, + types, + params, + ret, + def, + } => { + let (i, new) = imports.insert_full(def); + if new { + write!(f, "fn f{i} = ")?; + write_opaque(f, generics, types, params, ret)?; + writeln!(f)?; + } + } + }, + Expr::For { body, .. } => search(f, imports, funcs, refs, body)?, + _ => {} + } + } + Ok(()) +} + +fn write_elems( + f: &mut fmt::Formatter<'_>, + prefix: char, + items: impl Iterator, +) -> std::fmt::Result { + let mut first = true; + for item in items { + if first { + first = false; + } else { + write!(f, ", ")?; + } + write!(f, "{}{}", prefix, item)?; + } + Ok(()) +} + +struct Function<'a, 'b, O, T> { + imports: &'b IndexSet, + funcs: &'b IndexMap, T>, + refs: &'b T, + def: &'a Func, +} + +impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> { + fn write_instr(&self, f: &mut fmt::Formatter<'_>, spaces: usize, instr: &Instr) -> fmt::Result { + for _ in 0..spaces { + write!(f, " ")?; + } + let x = instr.var.var(); + write!(f, "let x{}: T{} = ", x, self.def.vars[x].ty())?; + match &instr.expr { + Expr::Unit => writeln!(f, "unit")?, + Expr::Bool { val } => writeln!(f, "{val}")?, + Expr::F64 { val } => writeln!(f, "{val}")?, + Expr::Fin { val } => writeln!(f, "{val}")?, + Expr::Array { elems } => { + write!(f, "[")?; + write_elems(f, 'x', elems.iter().map(|elem| elem.var()))?; + writeln!(f, "]")?; + } + Expr::Tuple { members } => { + write!(f, "(")?; + write_elems(f, 'x', members.iter().map(|member| member.var()))?; + writeln!(f, ")")?; + } + Expr::Index { array, index } => writeln!(f, "x{}[x{}]", array.var(), index.var())?, + Expr::Member { tuple, member } => writeln!(f, "x{}[{}]", tuple.var(), member.member())?, + Expr::Slice { array, index } => writeln!(f, "&x{}[x{}]", array.var(), index.var())?, + Expr::Field { tuple, member } => writeln!(f, "&x{}[{}]", tuple.var(), member.member())?, + Expr::Unary { op, arg } => match op { + Unop::Not => writeln!(f, "not x{}", arg.var())?, + Unop::Neg => writeln!(f, "-x{}", arg.var())?, + Unop::Abs => writeln!(f, "|x{}|", arg.var())?, + Unop::Sign => writeln!(f, "sign(x{})", arg.var())?, + Unop::Ceil => writeln!(f, "ceil(x{})", arg.var())?, + Unop::Floor => writeln!(f, "floor(x{})", arg.var())?, + Unop::Trunc => writeln!(f, "trunc(x{})", arg.var())?, + Unop::Sqrt => writeln!(f, "sqrt(x{})", arg.var())?, + }, + Expr::Binary { op, left, right } => match op { + Binop::And => writeln!(f, "x{} and x{}", left.var(), right.var())?, + Binop::Or => writeln!(f, "x{} or x{}", left.var(), right.var())?, + Binop::Iff => writeln!(f, "x{} iff x{}", left.var(), right.var())?, + Binop::Xor => writeln!(f, "x{} xor x{}", left.var(), right.var())?, + Binop::Neq => writeln!(f, "x{} != x{}", left.var(), right.var())?, + Binop::Lt => writeln!(f, "x{} < x{}", left.var(), right.var())?, + Binop::Leq => writeln!(f, "x{} <= x{}", left.var(), right.var())?, + Binop::Eq => writeln!(f, "x{} == x{}", left.var(), right.var())?, + Binop::Gt => writeln!(f, "x{} > x{}", left.var(), right.var())?, + Binop::Geq => writeln!(f, "x{} >= x{}", left.var(), right.var())?, + Binop::Add => writeln!(f, "x{} + x{}", left.var(), right.var())?, + Binop::Sub => writeln!(f, "x{} - x{}", left.var(), right.var())?, + Binop::Mul => writeln!(f, "x{} * x{}", left.var(), right.var())?, + Binop::Div => writeln!(f, "x{} / x{}", left.var(), right.var())?, + }, + Expr::Select { cond, then, els } => { + writeln!(f, "x{} ? x{} : x{}", cond.var(), then.var(), els.var())? + } + Expr::Call { id, generics, args } => { + let i = match self.refs.get(*id).unwrap() { + Node::Transparent { def, .. } => { + self.imports.len() + self.funcs.get_index_of(&ByAddress(def)).unwrap() + } + Node::Opaque { def, .. } => self.imports.get_index_of(&def).unwrap(), + }; + write!(f, "f{i}<")?; + write_elems(f, 'T', generics.iter().map(|generic| generic.ty()))?; + write!(f, ">(")?; + write_elems(f, 'x', args.iter().map(|arg| arg.var()))?; + writeln!(f, ")")?; + } + Expr::For { arg, body, ret } => { + writeln!( + f, + "for x{}: T{} {{", + arg.var(), + self.def.vars[arg.var()].ty() + )?; + self.write_block(f, spaces + 2, body, *ret)?; + for _ in 0..spaces { + write!(f, " ")?; + } + writeln!(f, "}}")? + } + Expr::Accum { shape } => writeln!(f, "accum x{}", shape.var())?, + Expr::Add { accum, addend } => writeln!(f, "x{} += x{}", accum.var(), addend.var())?, + Expr::Resolve { var } => writeln!(f, "resolve x{}", var.var())?, + } + Ok(()) + } + + fn write_block( + &self, + f: &mut fmt::Formatter<'_>, + spaces: usize, + body: &[Instr], + ret: id::Var, + ) -> fmt::Result { + for instr in body.iter() { + self.write_instr(f, spaces, instr)?; + } + for _ in 0..spaces { + write!(f, " ")?; + } + writeln!(f, "x{}", ret.var())?; + Ok(()) + } + + fn write_func(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write_generics(f, &self.def.generics)?; + writeln!(f, "{{")?; + write_types(f, &self.def.types)?; + write!(f, " (")?; + let mut first = true; + for param in self.def.params.iter() { + if first { + first = false; + } else { + write!(f, ", ")?; + } + write!(f, "x{}: T{}", param.var(), self.def.vars[param.var()].ty())?; + } + writeln!(f, ") -> T{} {{", self.def.vars[self.def.ret.var()].ty())?; + self.write_block(f, 4, &self.def.body, self.def.ret)?; + writeln!(f, " }}")?; + writeln!(f, "}}") + } +} + +pub fn write_graph<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>( + f: &mut fmt::Formatter<'_>, + root: Node<'a, O, T>, +) -> fmt::Result { + match root { + Node::Opaque { + generics, + types, + params, + ret, + def: _, + } => { + write!(f, "fn f0 = ")?; + write_opaque(f, generics, types, params, ret) + } + Node::Transparent { refs, def } => { + let mut imports = IndexSet::new(); + let mut funcs = IndexMap::new(); + search(f, &mut imports, &mut funcs, &refs, &def.body)?; + for (i, (def, refs)) in funcs.iter().enumerate() { + write!(f, "fn f{} = ", imports.len() + i)?; + Function { + imports: &imports, + funcs: &funcs, + refs, + def, + } + .write_func(f)?; + writeln!(f)?; + } + write!(f, "fn f{} = ", imports.len() + funcs.len())?; + Function { + imports: &imports, + funcs: &funcs, + refs: &refs, + def, + } + .write_func(f) + } + } +} diff --git a/packages/core/src/impl.test.ts b/packages/core/src/impl.test.ts index cc93cdd..c8192eb 100644 --- a/packages/core/src/impl.test.ts +++ b/packages/core/src/impl.test.ts @@ -1,36 +1,20 @@ import * as wasm from "@rose-lang/wasm"; import { describe, expect, test } from "vitest"; import { - Bool, + Dual, Fn, Real, Vec, - abs, - add, - and, - div, - eq, fn, - geq, - gt, - iff, inner, - leq, - lt, mul, + mulLin, neg, - neq, - not, - or, - select, - sign, - sqrt, - sub, - vec, - xor, + opaque, + vjp, } from "./impl.js"; -const pprint = (f: Fn): string => wasm.pprint(f[inner]); +const pprint = (f: Fn): string => f[inner].pprint(); test("core IR type layouts", () => { // these don't matter too much, but it's good to notice if sizes increase @@ -44,152 +28,360 @@ test("core IR type layouts", () => { }); describe("pprint", () => { - test("if", () => { - const f = fn([Real, Real], Real, (x, y) => { - const p = lt(x, y); - const a = mul(x, y); - const b = sub(y, x); - const z = select(p, Real, add(a, x), mul(b, y)); - const w = add(z, x); - return add(y, w); - }); + test("opaque", () => { + const f = opaque([Real], Real, (x) => x); const s = pprint(f); expect(s).toBe( ` -T0 = F64 -T1 = F64 -T2 = Bool -(x0: T0, x1: T0) -> T0 { - x2: T2 = x0 < x1 - x3: T0 = x0 * x1 - x4: T0 = x1 - x0 - x5: T0 = x3 + x0 - x6: T0 = x4 * x1 - x7: T0 = x2 ? x5 : x6 - x8: T0 = x7 + x0 - x9: T0 = x1 + x8 - x9 +fn f0 = <>{ + type T0 = F64 + opaque: (T0) -> T0 } `.trimStart(), ); }); - test("call funcs", () => { - const g = fn([Real], Real, (y) => add(2, y)); - const h = fn([Real], Real, (z) => mul(2, z)); - const f = fn([Real], Real, (x) => { - const a = g(x); - const b = h(x); - return add(a, b); + test("graph", () => { + const exp = opaque([Real], Real, Math.exp); + const sin = opaque([Real], Real, Math.sin); + const cos = opaque([Real], Real, Math.cos); + + exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = exp(x); + return { re: y, du: mulLin(dx, y) }; + }); + sin.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + return { re: sin(x), du: mulLin(dx, cos(x)) }; + }); + cos.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + return { re: cos(x), du: mulLin(dx, neg(sin(x))) }; }); - const s = pprint(f); - expect(s).toBe( - ` -T0 = F64 -T1 = F64 -(x0: T0) -> T0 { - x1: T0 = f0<>(x0) - x2: T0 = f1<>(x0) - x3: T0 = x1 + x2 - x3 -} -`.trimStart(), - ); - }); - test("unary operations", () => { - const f = fn([Real], Real, (x) => { - const a = not(true); - const b = neg(x); - const c = abs(b); - const d = sign(x); - const e = sqrt(x); - return e; + const Complex = { re: Real, im: Real } as const; + + const complexp = fn([Complex], Complex, (z) => { + const c = exp(z.re); + return { re: mul(c, cos(z.im)), im: mul(c, sin(z.im)) }; }); - const s = pprint(f); - expect(s).toBe( - ` -T0 = F64 -T1 = F64 -T2 = Bool -(x0: T0) -> T0 { - x1: T2 = true - x2: T2 = not x1 - x3: T0 = -x0 - x4: T0 = |x3| - x5: T0 = sign(x0) - x6: T0 = sqrt(x0) - x6 -} -`.trimStart(), - ); - }); - test("binary operations", () => { - const f = fn([Real, Real], Bool, (x, y) => { - const a = add(x, y); - const b = sub(x, y); - const c = mul(x, y); - const d = div(x, y); - const e = and(true, false); - const f = or(true, false); - const g = iff(true, false); - const h = xor(true, false); - const i = neq(x, y); - const j = lt(x, y); - const k = leq(x, y); - const l = eq(x, y); - const m = gt(x, y); - return geq(c, d); + const f = fn([Complex, Complex], Vec(2, Complex), (z, w) => { + const { ret, grad } = vjp(complexp)(z); + return [ret, grad(w)]; }); + const s = pprint(f); expect(s).toBe( ` -T0 = F64 -T1 = F64 -T2 = Bool -(x0: T0, x1: T0) -> T2 { - x6: T2 = true - x7: T2 = false - x2: T0 = x0 + x1 - x3: T0 = x0 - x1 - x4: T0 = x0 * x1 - x5: T0 = x0 / x1 - x8: T2 = x6 and x7 - x9: T2 = x6 or x7 - x10: T2 = x6 iff x7 - x11: T2 = x6 xor x7 - x12: T2 = x0 != x1 - x13: T2 = x0 < x1 - x14: T2 = x0 <= x1 - x15: T2 = x0 == x1 - x16: T2 = x0 > x1 - x17: T2 = x4 >= x5 - x17 +fn f0 = <>{ + type T0 = F64 + opaque: (T0) -> T0 } -`.trimStart(), - ); - }); - test("for", () => { - const n = 3; - const Rn = Vec(n, Real); - const f = fn([Rn, Rn], Rn, (a, b) => vec(n, Real, (i) => add(a[i], b[i]))); - const s = pprint(f); - expect(s).toBe( - ` -T0 = F64 -T1 = F64 -T2 = 3 -T3 = [T2]T0 -(x0: T3, x1: T3) -> T3 { - x6: T3 = for x2: T2 { - x3: T0 = x0[x2] - x4: T0 = x1[x2] - x5: T0 = x3 + x4 - x5 +fn f1 = <>{ + type T0 = F64 + opaque: (T0) -> T0 +} + +fn f2 = <>{ + type T0 = F64 + opaque: (T0) -> T0 +} + +fn f3 = <>{ + type T0 = F64 + type T1 = F64 + type T2 = F64 + type T3 = Unit + type T4 = &T2 + type T5 = &T1 + type T6 = (T2, T0) + type T7 = (T2, T6) + (x0: T2) -> T7 { + let x3: T0 = f0<>(x0) + let x6: T6 = (x0, x3) + let x7: T7 = (x3, x6) + x7 + } +} + +fn f4 = <>{ + type T0 = F64 + type T1 = F64 + type T2 = F64 + type T3 = Unit + type T4 = &T2 + type T5 = &T0 + type T6 = &T1 + type T7 = (T2, T0, T0, T0) + type T8 = (T2, T7) + (x0: T2) -> T8 { + let x3: T0 = f1<>(x0) + let x4: T0 = f2<>(x0) + let x5: T0 = -x4 + let x8: T7 = (x0, x3, x4, x5) + let x9: T8 = (x3, x8) + x9 + } +} + +fn f5 = <>{ + type T0 = F64 + type T1 = F64 + type T2 = F64 + type T3 = Unit + type T4 = &T2 + type T5 = &T1 + type T6 = (T2, T0, T0) + type T7 = (T2, T6) + (x0: T2) -> T7 { + let x3: T0 = f2<>(x0) + let x4: T0 = f1<>(x0) + let x7: T6 = (x0, x3, x4) + let x8: T7 = (x3, x7) + x8 + } +} + +fn f6 = <>{ + type T0 = F64 + type T1 = F64 + type T2 = F64 + type T3 = F64 + type T4 = (T2, T2) + type T5 = Unit + type T6 = &T4 + type T7 = &T1 + type T8 = &T2 + type T9 = F64 + type T10 = &T9 + type T11 = (T9, T9) + type T12 = (T9, T11) + type T13 = (T2, T12) + type T14 = (T9, T9, T9, T9) + type T15 = (T9, T14) + type T16 = (T2, T15) + type T17 = &T0 + type T18 = (T9, T9, T9) + type T19 = (T9, T18) + type T20 = (T2, T19) + type T21 = (T4, T2, T12, T2, T15, T0, T2, T19, T0, T4) + type T22 = (T4, T21) + (x0: T4) -> T22 { + let x1: T2 = x0[1] + let x31: T13 = f3<>(x1) + let x2: T2 = x31[0] + let x32: T12 = x31[1] + let x3: T2 = x0[0] + let x33: T16 = f4<>(x3) + let x4: T2 = x33[0] + let x34: T15 = x33[1] + let x19: T0 = x2 * x4 + let x6: T2 = x0[0] + let x35: T20 = f5<>(x6) + let x7: T2 = x35[0] + let x36: T19 = x35[1] + let x27: T0 = x2 * x7 + let x9: T4 = (x27, x19) + let x37: T21 = (x0, x2, x32, x4, x34, x19, x7, x36, x27, x9) + let x38: T22 = (x9, x37) + x38 + } +} + +fn f7 = <>{ + type T0 = F64 + type T1 = F64 + type T2 = F64 + type T3 = Unit + type T4 = &T2 + type T5 = &T1 + type T6 = (T2, T0, T0) + (x9: T4, x14: T2, x8: T6) -> T3 { + let x7: T1 = 0 + let x0: T2 = x8[0] + let x3: T0 = x8[1] + let x4: T0 = x8[2] + let x10: T5 = accum x7 + let x15: T3 = x10 += x14 + let x11: T1 = resolve x10 + let x12: T1 = x11 * x4 + let x13: T3 = x9 += x12 + let x16: T3 = unit + x16 + } +} + +fn f8 = <>{ + type T0 = F64 + type T1 = F64 + type T2 = F64 + type T3 = Unit + type T4 = &T2 + type T5 = &T0 + type T6 = &T1 + type T7 = (T2, T0, T0, T0) + (x10: T4, x17: T2, x9: T7) -> T3 { + let x8: T1 = 0 + let x0: T2 = x9[0] + let x3: T0 = x9[1] + let x4: T0 = x9[2] + let x5: T0 = x9[3] + let x11: T5 = accum x5 + let x13: T6 = accum x8 + let x18: T3 = x13 += x17 + let x14: T1 = resolve x13 + let x15: T1 = x14 * x5 + let x16: T3 = x10 += x15 + let x12: T0 = resolve x11 + let x19: T3 = unit + x19 + } +} + +fn f9 = <>{ + type T0 = F64 + type T1 = F64 + type T2 = F64 + type T3 = Unit + type T4 = &T2 + type T5 = &T1 + type T6 = (T2, T0) + (x8: T4, x13: T2, x7: T6) -> T3 { + let x6: T1 = 0 + let x0: T2 = x7[0] + let x3: T0 = x7[1] + let x9: T5 = accum x6 + let x14: T3 = x9 += x13 + let x10: T1 = resolve x9 + let x11: T1 = x10 * x3 + let x12: T3 = x8 += x11 + let x15: T3 = unit + x15 + } +} + +fn f10 = <>{ + type T0 = F64 + type T1 = F64 + type T2 = F64 + type T3 = F64 + type T4 = (T2, T2) + type T5 = Unit + type T6 = &T4 + type T7 = &T1 + type T8 = &T2 + type T9 = F64 + type T10 = &T9 + type T11 = (T9, T9) + type T12 = (T9, T11) + type T13 = (T2, T12) + type T14 = (T9, T9, T9, T9) + type T15 = (T9, T14) + type T16 = (T2, T15) + type T17 = &T0 + type T18 = (T9, T9, T9) + type T19 = (T9, T18) + type T20 = (T2, T19) + type T21 = (T4, T2, T12, T2, T15, T0, T2, T19, T0, T4) + (x33: T6, x85: T4, x32: T21) -> T5 { + let x31: T1 = 0 + let x0: T4 = x32[0] + let x34: T7 = accum x31 + let x1: T2 = x0[1] + let x36: T8 = &x33[1] + let x2: T2 = x32[1] + let x37: T12 = x32[2] + let x38: T8 = accum x2 + let x3: T2 = x0[0] + let x41: T8 = &x33[0] + let x4: T2 = x32[3] + let x42: T15 = x32[4] + let x43: T8 = accum x4 + let x19: T0 = x32[5] + let x46: T17 = accum x19 + let x48: T7 = accum x31 + let x52: T7 = accum x31 + let x56: T7 = accum x31 + let x6: T2 = x0[0] + let x60: T8 = &x33[0] + let x7: T2 = x32[6] + let x61: T19 = x32[7] + let x62: T8 = accum x7 + let x27: T0 = x32[8] + let x65: T17 = accum x27 + let x67: T7 = accum x31 + let x71: T7 = accum x31 + let x75: T7 = accum x31 + let x9: T4 = x32[9] + let x79: T6 = accum x9 + let x86: T5 = x79 += x85 + let x80: T4 = resolve x79 + let x83: T2 = x80[1] + let x84: T5 = x56 += x83 + let x81: T2 = x80[0] + let x82: T5 = x75 += x81 + let x76: T1 = resolve x75 + let x78: T5 = x71 += x76 + let x77: T5 = x67 += x76 + let x72: T1 = resolve x71 + let x73: T1 = x72 * x2 + let x74: T5 = x62 += x73 + let x68: T1 = resolve x67 + let x69: T1 = x68 * x7 + let x70: T5 = x38 += x69 + let x66: T0 = resolve x65 + let x63: T2 = resolve x62 + let x64: T5 = f7<>(x60, x63, x61) + let x57: T1 = resolve x56 + let x59: T5 = x52 += x57 + let x58: T5 = x48 += x57 + let x53: T1 = resolve x52 + let x54: T1 = x53 * x2 + let x55: T5 = x43 += x54 + let x49: T1 = resolve x48 + let x50: T1 = x49 * x4 + let x51: T5 = x38 += x50 + let x47: T0 = resolve x46 + let x44: T2 = resolve x43 + let x45: T5 = f8<>(x41, x44, x42) + let x39: T2 = resolve x38 + let x40: T5 = f9<>(x36, x39, x37) + let x35: T1 = resolve x34 + let x87: T5 = unit + x87 + } +} + +fn f11 = <>{ + type T0 = F64 + type T1 = F64 + type T2 = (T0, T0) + type T3 = 2 + type T4 = [T3]T2 + type T5 = Unit + type T6 = &T2 + type T7 = &T0 + type T8 = (T0, T0) + type T9 = (T0, T8) + type T10 = (T0, T9) + type T11 = (T0, T0, T0, T0) + type T12 = (T0, T11) + type T13 = (T0, T12) + type T14 = (T0, T0, T0) + type T15 = (T0, T14) + type T16 = (T0, T15) + type T17 = (T2, T0, T9, T0, T12, T0, T0, T15, T0, T2) + type T18 = (T2, T17) + (x0: T2, x1: T2) -> T4 { + let x2: T18 = f6<>(x0) + let x3: T2 = x2[0] + let x4: T17 = x2[1] + let x5: T6 = accum x0 + let x6: T5 = f10<>(x5, x1, x4) + let x7: T2 = resolve x5 + let x8: T4 = [x3, x7] + x8 } - x6 } `.trimStart(), );