From 2d2c73b462abdf1c1bcbe35cac6f48bd685dbbce Mon Sep 17 00:00:00 2001 From: jfecher Date: Wed, 29 Jan 2025 11:00:45 -0600 Subject: [PATCH] chore: Rework defunctionalize pass to not rely on DFG bugs (#7222) --- compiler/noirc_evaluator/src/ssa/ir/dfg.rs | 8 +- .../src/ssa/opt/defunctionalize.rs | 157 +++++++++++++----- .../src/ssa/opt/loop_invariant.rs | 8 +- 3 files changed, 121 insertions(+), 52 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index 8e87db15caf..c6b8c2324d9 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -424,7 +424,9 @@ impl DataFlowGraph { if let Some(existing) = self.functions.get(&function) { return *existing; } - self.values.insert(Value::Function(function)) + let result = self.values.insert(Value::Function(function)); + self.functions.insert(function, result); + result } /// Gets or creates a ValueId for the given FunctionId. @@ -432,7 +434,9 @@ impl DataFlowGraph { if let Some(existing) = self.foreign_functions.get(function) { return *existing; } - self.values.insert(Value::ForeignFunction(function.to_owned())) + let result = self.values.insert(Value::ForeignFunction(function.to_owned())); + self.foreign_functions.insert(function.to_owned(), result); + result } /// Gets or creates a ValueId for the given Intrinsic. diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index 4afddbef41a..fc83b8f2c1a 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -4,7 +4,7 @@ //! with a non-literal target can be replaced with a call to an apply function. //! The apply function is a dispatch function that takes the function id as a parameter //! and dispatches to the correct target. -use std::collections::{BTreeMap, BTreeSet, HashSet}; +use std::collections::{BTreeMap, BTreeSet}; use acvm::FieldElement; use iter_extended::vecmap; @@ -80,21 +80,54 @@ impl DefunctionalizationContext { /// Defunctionalize a single function fn defunctionalize(&mut self, func: &mut Function) { - let mut call_target_values = HashSet::new(); - for block_id in func.reachable_blocks() { - let block = &func.dfg[block_id]; - let instructions = block.instructions().to_vec(); + let block = &mut func.dfg[block_id]; + + // Temporarily take the parameters here just to avoid cloning them + let parameters = block.take_parameters(); + for parameter in ¶meters { + if func.dfg.type_of_value(*parameter) == Type::Function { + func.dfg.set_type_of_value(*parameter, Type::field()); + } + } + + let block = &mut func.dfg[block_id]; + block.set_parameters(parameters); + + // Do the same for the terminator + let mut terminator = block.take_terminator(); + terminator.map_values_mut(|value| map_function_to_field(func, value).unwrap_or(value)); + + let block = &mut func.dfg[block_id]; + block.set_terminator(terminator); - for instruction_id in instructions { - let instruction = func.dfg[instruction_id].clone(); + // Now we can finally change each instruction, replacing + // each first class function with a field value and replacing calls + // to a first class function to a call to the relevant `apply` function. + #[allow(clippy::unnecessary_to_owned)] // clippy is wrong here + for instruction_id in block.instructions().to_vec() { + let mut instruction = func.dfg[instruction_id].clone(); let mut replacement_instruction = None; + + if remove_first_class_functions_in_instruction(func, &mut instruction) { + func.dfg[instruction_id] = instruction.clone(); + } + + #[allow(clippy::unnecessary_to_owned)] // clippy is wrong here + for result in func.dfg.instruction_results(instruction_id).to_vec() { + if func.dfg.type_of_value(result) == Type::Function { + func.dfg.set_type_of_value(result, Type::field()); + } + } + // Operate on call instructions let (target_func_id, arguments) = match &instruction { Instruction::Call { func: target_func_id, arguments } => { (*target_func_id, arguments) } - _ => continue, + _ => { + continue; + } }; match func.dfg[target_func_id] { @@ -116,13 +149,8 @@ impl DefunctionalizationContext { arguments.insert(0, target_func_id); } let func = apply_function_value_id; - call_target_values.insert(func); - replacement_instruction = Some(Instruction::Call { func, arguments }); } - Value::Function(..) => { - call_target_values.insert(target_func_id); - } _ => {} } if let Some(new_instruction) = replacement_instruction { @@ -130,29 +158,6 @@ impl DefunctionalizationContext { } } } - - // Change the type of all the values that are not call targets to NativeField - let value_ids = vecmap(func.dfg.values_iter(), |(id, _)| id); - for value_id in value_ids { - if let Type::Function = func.dfg[value_id].get_type().as_ref() { - match &func.dfg[value_id] { - // If the value is a static function, transform it to the function id - Value::Function(id) => { - if !call_target_values.contains(&value_id) { - let field = NumericType::NativeField; - let new_value = - func.dfg.make_constant(function_id_to_field(*id), field); - func.dfg.set_value_from_id(value_id, new_value); - } - } - // If the value is a function used as value, just change the type of it - Value::Instruction { .. } | Value::Param { .. } => { - func.dfg.set_type_of_value(value_id, Type::field()); - } - _ => {} - } - } - } } /// Returns the apply function for the given signature @@ -161,6 +166,54 @@ impl DefunctionalizationContext { } } +/// Replace any first class functions used in an instruction with a field value. +/// This applies to any function used anywhere else other than the function position +/// of a call instruction. Returns true if the instruction was modified +fn remove_first_class_functions_in_instruction( + func: &mut Function, + instruction: &mut Instruction, +) -> bool { + let mut modified = false; + let mut map_value = |value: ValueId| { + if let Some(new_value) = map_function_to_field(func, value) { + modified = true; + new_value + } else { + value + } + }; + + if let Instruction::Call { func: _, arguments } = instruction { + for arg in arguments { + *arg = map_value(*arg); + } + } else { + instruction.map_values_mut(map_value); + } + + modified +} + +/// Try to map the given function literal to a field, returning Some(field) on success. +/// Returns none if the given value was not a function or doesn't need to be mapped. +fn map_function_to_field(func: &mut Function, value: ValueId) -> Option { + if let Type::Function = func.dfg[value].get_type().as_ref() { + match &func.dfg[value] { + // If the value is a static function, transform it to the function id + Value::Function(id) => { + let new_value = function_id_to_field(*id); + return Some(func.dfg.make_constant(new_value, NumericType::NativeField)); + } + // If the value is a function used as value, just change the type of it + Value::Instruction { .. } | Value::Param { .. } => { + func.dfg.set_type_of_value(value, Type::field()); + } + _ => (), + } + } + None +} + /// Collects all functions used as values that can be called by their signatures fn find_variants(ssa: &Ssa) -> Variants { let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new(); @@ -252,13 +305,25 @@ fn create_apply_functions( variants_map: BTreeMap<(Signature, RuntimeType), Vec>, ) -> ApplyFunctions { let mut apply_functions = HashMap::default(); - for ((signature, runtime), variants) in variants_map.into_iter() { + for ((mut signature, runtime), variants) in variants_map.into_iter() { assert!( !variants.is_empty(), "ICE: at least one variant should exist for a dynamic call {signature:?}" ); let dispatches_to_multiple_functions = variants.len() > 1; + for param in &mut signature.params { + if *param == Type::Function { + *param = Type::field(); + } + } + + for ret in &mut signature.returns { + if *ret == Type::Function { + *ret = Type::field(); + } + } + let id = if dispatches_to_multiple_functions { create_apply_function(ssa, signature.clone(), runtime, variants) } else { @@ -282,7 +347,7 @@ fn create_apply_function( function_ids: Vec, ) -> FunctionId { assert!(!function_ids.is_empty()); - let globals = ssa.functions[&function_ids[0]].dfg.globals.clone(); + let globals = ssa.main().dfg.globals.clone(); ssa.add_fn(|id| { let mut function_builder = FunctionBuilder::new("apply".to_string(), id); function_builder.set_globals(globals); @@ -386,10 +451,10 @@ mod tests { v5 = add v0, u32 1 v6 = eq v3, v5 constrain v3 == v5 - v9 = call f1(f3, v0) -> u32 - v10 = add v0, u32 1 - v11 = eq v9, v10 - constrain v9 == v10 + v8 = call f1(f3, v0) -> u32 + v9 = add v0, u32 1 + v10 = eq v8, v9 + constrain v8 == v9 return } brillig(inline) fn wrapper f1 { @@ -419,10 +484,10 @@ mod tests { v5 = add v0, u32 1 v6 = eq v3, v5 constrain v3 == v5 - v9 = call f1(Field 3, v0) -> u32 - v10 = add v0, u32 1 - v11 = eq v9, v10 - constrain v9 == v10 + v8 = call f1(Field 3, v0) -> u32 + v9 = add v0, u32 1 + v10 = eq v8, v9 + constrain v8 == v9 return } brillig(inline) fn wrapper f1 { diff --git a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs index 1e2e783d516..6efed689f51 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs @@ -649,8 +649,8 @@ mod test { v21 = add v1, v2 v23 = array_set v19, index v21, value Field 128 call f1(v23) - v25 = add v2, u32 1 - jmp b1(v25) + v24 = add v2, u32 1 + jmp b1(v24) } brillig(inline) fn foo f1 { b0(v0: [Field; 5]): @@ -685,8 +685,8 @@ mod test { v21 = add v1, v2 v23 = array_set v14, index v21, value Field 128 call f1(v23) - v25 = add v2, u32 1 - jmp b1(v25) + v24 = add v2, u32 1 + jmp b1(v24) } brillig(inline) fn foo f1 { b0(v0: [Field; 5]):