From 5a2a9c4c5cd20bb754566bf215b9c2cf379949c1 Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Wed, 27 Sep 2023 19:40:02 -0400 Subject: [PATCH] Use local instead of memory for `F64` accumulator (#118) --- crates/wasm/src/lib.rs | 313 +++++++++++++++++++++++++++-------------- 1 file changed, 205 insertions(+), 108 deletions(-) diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs index 790aa6a..0dfcd92 100644 --- a/crates/wasm/src/lib.rs +++ b/crates/wasm/src/lib.rs @@ -1,7 +1,10 @@ use by_address::ByAddress; use indexmap::{map::Entry, IndexMap, IndexSet}; use rose::{id, Binop, Expr, Func, Instr, Node, Refs, Ty, Unop}; -use std::{hash::Hash, mem::take}; +use std::{ + hash::Hash, + mem::{replace, take}, +}; use wasm_encoder::{ BlockType, CodeSection, EntityType, ExportSection, Function, FunctionSection, ImportSection, Instruction, MemArg, MemorySection, MemoryType, Module, TypeSection, ValType, @@ -136,20 +139,6 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Topsort<'a, O, T> { } } -/// Return the WebAssembly value type used to represent a local of type `ty`. -fn val_type(ty: &Ty) -> ValType { - match ty { - Ty::Unit - | Ty::Bool - | Ty::Fin { .. } - | Ty::Ref { .. } - | Ty::Array { .. } - | Ty::Tuple { .. } => ValType::I32, - Ty::F64 => ValType::F64, - Ty::Generic { .. } => unreachable!(), - } -} - /// A WebAssembly memory offset or size. type Size = u32; @@ -318,6 +307,25 @@ struct Meta { members: Option>, } +/// Return the WebAssembly value type used to represent a local for the type with global ID `t`. +/// +/// The second component of the pair is `true` iff a parameter with this result should be switched +/// to a result instead; the only case where this happens is if the type is a `Ref` containing an +/// `F64`. +fn val_type(metas: &[Meta], t: id::Ty) -> (ValType, bool) { + match metas[t.ty()].ty { + Ty::Unit | Ty::Bool | Ty::Fin { .. } | Ty::Array { .. } | Ty::Tuple { .. 
} => { + (ValType::I32, false) + } + Ty::F64 => (ValType::F64, false), + Ty::Generic { .. } => unreachable!(), + Ty::Ref { inner } => { + let (vt, _) = val_type(metas, inner); + (vt, vt == ValType::F64) + } + } +} + /// Generates WebAssembly code for a function. struct Codegen<'a, 'b, O, T> { /// Metadata about all the types in the global type index. @@ -339,7 +347,7 @@ struct Codegen<'a, 'b, O, T> { refs: &'b T, /// The definition of the particular function we're generating code for. - def: &'b Func, + def: &'a Func, /// Mapping from this function's type indices to type indices in the global type index. types: &'b [id::Ty], @@ -353,6 +361,15 @@ struct Codegen<'a, 'b, O, T> { /// allocation cost depends both on its block's allocation cost and on the number of iterations. offset: Size, + /// Stack of pending accumulator instructions to process at the end of each scope. + /// + /// The bottom element is always an empty vector, because after we process the entire function, + /// we call `resolve` again, which always pops this stack. + stack: Vec<Vec<&'a Instr>>, + + /// Pending accumulator instructions to process at the end of this scope, in reverse order. + unresolved: Vec<&'a Instr>, + /// The WebAssembly function under construction. wasm: Function, } @@ -415,8 +432,73 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { layout.store(&mut self.wasm, offset) } + /// Pop the top of the scope stack and process all pending accumulator instructions. + fn resolve(&mut self) { + let unresolved = replace(&mut self.unresolved, self.stack.pop().unwrap()); + for instr in unresolved.into_iter().rev() { + match instr.expr { + Expr::Slice { array, index } => { + let &Meta { layout, .. 
} = + self.meta(match self.def.types[self.def.vars[instr.var.var()].ty()] { + Ty::Ref { inner } => inner, + _ => unreachable!(), + }); + self.get(array); + self.get(index); + self.u32_const(layout.size()); + self.wasm.instruction(&Instruction::I32Mul); + self.wasm.instruction(&Instruction::I32Add); + self.get(array); + self.get(index); + self.u32_const(layout.size()); + // TODO: avoid recalculating the offset + self.wasm.instruction(&Instruction::I32Mul); + self.wasm.instruction(&Instruction::I32Add); + self.load(layout, 0); + self.get(instr.var); + self.wasm.instruction(&Instruction::F64Add); + self.store(layout, 0); + } + Expr::Field { tuple, member } => { + let &Meta { layout, .. } = + self.meta(match self.def.types[self.def.vars[instr.var.var()].ty()] { + Ty::Ref { inner } => inner, + _ => unreachable!(), + }); + let Meta { members, .. } = + self.meta(match self.def.types[self.def.vars[tuple.var()].ty()] { + Ty::Ref { inner } => inner, + _ => unreachable!(), + }); + let offset = members.as_ref().unwrap()[member.member()]; + self.get(tuple); + self.get(tuple); + self.load(layout, offset); + self.get(instr.var); + self.wasm.instruction(&Instruction::F64Add); + self.store(layout, offset); + } + Expr::Select { cond, then, els } => { + self.get(cond); + self.wasm.instruction(&Instruction::If(BlockType::Empty)); + self.get(then); + self.get(instr.var); + self.wasm.instruction(&Instruction::F64Add); + self.set(then); + self.wasm.instruction(&Instruction::Else); + self.get(els); + self.get(instr.var); + self.wasm.instruction(&Instruction::F64Add); + self.set(els); + self.wasm.instruction(&Instruction::End); + } + _ => unreachable!(), + } + } + } + /// Generate code for the given `block`. 
- fn block(&mut self, block: &[Instr]) { + fn block(&mut self, block: &'a [Instr]) { for instr in block.iter() { match &instr.expr { Expr::Unit => { @@ -483,45 +565,41 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { Ty::Ref { inner } => inner, _ => unreachable!(), }); - let size = meta.layout.size(); - self.get(array); - self.get(index); - self.u32_const(size); - self.wasm.instruction(&Instruction::I32Mul); - self.wasm.instruction(&Instruction::I32Add); - if let Ty::Array { .. } | Ty::Tuple { .. } = &meta.ty { - // if this array holds primitives then we just want a pointer to the - // element, but if it's actually another composite value then it's already a - // pointer, so we need to do a load because otherwise we'd have a pointer to - // a pointer instead of just one direct pointer - self.load(meta.layout, 0); + match meta.ty { + Ty::F64 => { + self.wasm.instruction(&Instruction::F64Const(0.)); + self.unresolved.push(instr); + } + _ => { + let size = meta.layout.size(); + self.get(array); + self.get(index); + self.u32_const(size); + self.wasm.instruction(&Instruction::I32Mul); + self.wasm.instruction(&Instruction::I32Add); + self.load(meta.layout, 0); + } } } &Expr::Field { tuple, member } => { - let Meta { members, .. } = - self.meta(match self.def.types[self.def.vars[tuple.var()].ty()] { - Ty::Ref { inner } => inner, - _ => unreachable!(), - }); - let offset = members.as_ref().unwrap()[member.member()]; let meta = self.meta(match self.def.types[self.def.vars[instr.var.var()].ty()] { Ty::Ref { inner } => inner, _ => unreachable!(), }); - self.get(tuple); - match &meta.ty { - Ty::Unit | Ty::Bool | Ty::F64 | Ty::Fin { .. } => { - // if this array holds primitives then we just want a pointer to the - // element - self.u32_const(offset); - self.wasm.instruction(&Instruction::I32Add); + match meta.ty { + Ty::F64 => { + self.wasm.instruction(&Instruction::F64Const(0.)); + self.unresolved.push(instr); } - Ty::Generic { .. } | Ty::Ref { .. 
} => unreachable!(), - Ty::Array { .. } | Ty::Tuple { .. } => { - // if this array holds other composite values then each element is - // already a pointer, so we need to do a load because otherwise we'd - // have a pointer to a pointer instead of just one direct pointer + _ => { + let Meta { members, .. } = + self.meta(match self.def.types[self.def.vars[tuple.var()].ty()] { + Ty::Ref { inner } => inner, + _ => unreachable!(), + }); + let offset = members.as_ref().unwrap()[member.member()]; + self.get(tuple); self.load(meta.layout, offset); } } @@ -584,10 +662,18 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { }; } &Expr::Select { cond, then, els } => { - self.get(then); - self.get(els); - self.get(cond); - self.wasm.instruction(&Instruction::Select); + match self.def.types[self.def.vars[instr.var.var()].ty()] { + Ty::Ref { inner } if self.meta(inner).ty == Ty::F64 => { + self.wasm.instruction(&Instruction::F64Const(0.)); + self.unresolved.push(instr); + } + _ => { + self.get(then); + self.get(els); + self.get(cond); + self.wasm.instruction(&Instruction::Select); + } + } } Expr::Call { id, generics, args } => { let gens = generics @@ -595,7 +681,11 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { .map(|t| self.types[self.def.vars[t.ty()].ty()]) .collect(); for &arg in args.iter() { - self.get(arg); + match self.def.types[self.def.vars[arg.var()].ty()] { + // `F64` accumulators become results, not params + Ty::Ref { inner } if self.meta(inner).ty == Ty::F64 => {} + _ => self.get(arg), + } } let i = match self.refs.get(*id).unwrap() { Node::Transparent { def, .. 
} => { @@ -610,6 +700,17 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { }; self.wasm .instruction(&Instruction::Call(i.try_into().unwrap())); + for &arg in args.iter().rev() { + match &self.def.types[self.def.vars[arg.var()].ty()] { + &Ty::Ref { inner } if self.meta(inner).ty == Ty::F64 => { + // `F64` accumulators became results + self.get(arg); + self.wasm.instruction(&Instruction::F64Add); + self.set(arg); + } + _ => {} + } + } } Expr::For { arg, body, ret } => { let n = u_size(match self.meta(self.def.vars[arg.var()]).ty { @@ -636,7 +737,9 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { self.set(*arg); self.wasm.instruction(&Instruction::Loop(BlockType::Empty)); + self.stack.push(take(&mut self.unresolved)); self.block(body); + self.resolve(); self.get(instr.var); self.get(*arg); @@ -663,26 +766,11 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { &Expr::Accum { shape } => { let meta = self.meta(self.def.vars[shape.var()]); match &meta.ty { - // this is a bit subtle: usually a `Ref` variable is a pointer, and that is - // also true for `Ref` variables to values of these three discrete - // continuous types if they come from an `Expr::Slice` or `Expr::Field`; - // but, if we're directly starting an accumulator for a discrete primitive - // value, then its value can't be modified, so we can just store it directly - // instead of allocating extra memory; this works because the WebAssembly - // value types for all these discrete primitive types are the same as for - // pointers, and it's OK to have the representation be different depending - // on whether it's directly introduced by `Expr::Accum` or not, because - // those are the only ones on which we can use `Expr::Resolve`, and `Ref`s - // cannot be directly read before they're resolved anyway Ty::Unit | Ty::Bool | Ty::Fin { .. 
} => self.get(shape), Ty::F64 => { - self.pointer(); - self.pointer(); // TODO: `f64.const` instructions are always 8 bytes, much larger than // most instructions; maybe we should just keep this constant in a local self.wasm.instruction(&Instruction::F64Const(0.)); - self.store(Layout::F64, 0); - self.bump(8); } Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), Ty::Array { .. } | Ty::Tuple { .. } => { @@ -694,6 +782,7 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { self.bump(cost); } } + self.stack.push(take(&mut self.unresolved)); } &Expr::Add { accum, addend } => { let meta = self.meta(self.def.vars[addend.var()]); @@ -701,11 +790,9 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { Ty::Unit | Ty::Bool | Ty::Fin { .. } => {} Ty::F64 => { self.get(accum); - self.get(accum); - self.load(Layout::F64, 0); self.get(addend); self.wasm.instruction(&Instruction::F64Add); - self.store(Layout::F64, 0); + self.set(accum); } Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), Ty::Array { .. } | Ty::Tuple { .. 
} => { @@ -718,15 +805,8 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { self.wasm.instruction(&Instruction::I32Const(0)); } &Expr::Resolve { var } => { + self.resolve(); self.get(var); - if let Ty::F64 = &self.meta(self.def.vars[instr.var.var()]).ty { - // as explained above, if the `inner` value is a discrete primitive type - // then we cheated and stored it directly in the local so there's nothing to - // do, and if it's a pointer then we still just need a pointer so there's - // also nothing to do; but if it's a continuous primitive type then we - // introduced an extra layer of indirection so we need to loads - self.load(Layout::F64, 0); - } } } self.set(instr.var); @@ -796,21 +876,27 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> // we add to this lazily as we generate our imports, functions, and code, after which we'll // generate the actual function types section right near the end; it doesn't matter as long as // the order we actually add the sections to the module itself is correct - let mut func_types: IndexSet<(Box<[ValType]>, ValType)> = IndexSet::new(); + type Signature = (Box<[ValType]>, Box<[ValType]>); + let mut func_types: IndexSet<Signature> = IndexSet::new(); + let (accum_type_index, _) = + func_types.insert_full(([ValType::I32, ValType::I32].into(), [].into())); let mut import_section = ImportSection::new(); for (i, (params, ret)) in imports.values().enumerate() { - let (type_index, _) = func_types.insert_full(( - params.iter().map(|t| val_type(&types[t.ty()])).collect(), - val_type(&types[ret.ty()]), - )); + // short for `ValType` + let vt = |t: id::Ty| match types[t.ty()] { + Ty::F64 => ValType::F64, + _ => unreachable!(), + }; + let (type_index, _) = + func_types.insert_full((params.iter().map(|&t| vt(t)).collect(), [vt(*ret)].into())); // we reserve type index zero for the type with two `i32` params and no results, which we // use for accumulation zero and add functions; we don't 
include that in the `func_types` // index itself, because that index only holds function types with exactly one result import_section.import( "", &i.to_string(), - EntityType::Function((1 + type_index).try_into().unwrap()), + EntityType::Function(type_index.try_into().unwrap()), ); } @@ -1052,9 +1138,9 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> layout, accum: cost.map(|cost| { let zero = extras.try_into().unwrap(); - function_section.function(0); + function_section.function(accum_type_index.try_into().unwrap()); let add = (extras + 1).try_into().unwrap(); - function_section.function(0); + function_section.function(accum_type_index.try_into().unwrap()); extras += 2; Accum { zero, cost, add } }), @@ -1064,29 +1150,31 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> let mut costs = vec![]; // allocation cost of each function, in bytes for ((def, _), (refs, def_types)) in funcs.iter() { - let vt = |t: id::Ty| val_type(&metas[def_types[t.ty()].ty()].ty); // short for `ValType` - let params: Local = (def.params.len() + 1).try_into().unwrap(); // extra pointer parameter + let vt = |t: id::Ty| val_type(&metas, def_types[t.ty()]); // short for `ValType` let mut locals = vec![None; def.vars.len()]; - let (type_index, _) = func_types.insert_full(( - def.params - .iter() - .enumerate() - .map(|(i, param)| { - locals[param.var()] = Some(i.try_into().unwrap()); - vt(def.vars[param.var()]) - }) - .chain([ValType::I32]) // extra pointer parameter - .collect(), - vt(def.vars[def.ret.var()]), - )); - function_section.function((1 + type_index).try_into().unwrap()); + let (ret_ty, _) = vt(def.vars[def.ret.var()]); + let mut params = vec![]; + let mut results = vec![ret_ty]; + for param in def.params.iter() { + let (val_ty, result) = vt(def.vars[param.var()]); + if result { + results.push(val_ty); + } else { + locals[param.var()] = Some(params.len().try_into().unwrap()); + params.push(val_ty); + } + } + 
params.push(ValType::I32); // extra pointer parameter + let num_params: u32 = params.len().try_into().unwrap(); + let (type_index, _) = func_types.insert_full((params.into(), results.into())); + function_section.function(type_index.try_into().unwrap()); let mut i32s = 0; for (i, &t) in def.vars.iter().enumerate() { if locals[i].is_none() { - if let ValType::I32 = vt(t) { - locals[i] = Some(params + i32s); + if let (ValType::I32, _) = vt(t) { + locals[i] = Some(num_params + i32s); i32s += 1; } } @@ -1094,8 +1182,8 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> let mut f64s = 0; for (i, &t) in def.vars.iter().enumerate() { if locals[i].is_none() { - assert_eq!(vt(t), ValType::F64); - locals[i] = Some(params + i32s + f64s); + assert!(matches!(vt(t), (ValType::F64, _))); + locals[i] = Some(num_params + i32s + f64s); f64s += 1; } } @@ -1112,19 +1200,28 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> types: def_types, locals: &locals, offset: 0, + stack: vec![vec![]], + unresolved: vec![], wasm: Function::new([(i32s, ValType::I32), (f64s, ValType::F64)]), }; + // accumulator result variables are automatically zero: https://stackoverflow.com/a/77170544 codegen.block(&def.body); + codegen.resolve(); codegen.get(def.ret); + for &param in def.params.iter() { + if let (_, true) = vt(def.vars[param.var()]) { + // return the accumulator variables we moved from params to results + codegen.get(param); + } + } codegen.wasm.instruction(&Instruction::End); code_section.function(&codegen.wasm); costs.push(codegen.offset); } let mut type_section = TypeSection::new(); - type_section.function([ValType::I32, ValType::I32], []); // for accumulation functions - for (params, ret) in func_types { - type_section.function(params.into_vec(), [ret]); + for (params, results) in func_types { + type_section.function(params.into_vec(), results.into_vec()); } let mut memory_section = MemorySection::new();