From bdcfd38e4d8c9fb9dcd4298ac263f681e470a6ce Mon Sep 17 00:00:00 2001 From: Maxim Vezenov Date: Wed, 29 Jan 2025 11:32:32 -0500 Subject: [PATCH] fix(brillig): Globals entry point reachability analysis (#7188) Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com> --- .../src/brillig/brillig_gen.rs | 3 +- .../brillig/brillig_gen/brillig_globals.rs | 430 +++++++++++++++++- .../noirc_evaluator/src/brillig/brillig_ir.rs | 9 +- .../src/brillig/brillig_ir/artifact.rs | 16 +- .../src/brillig/brillig_ir/entry_point.rs | 2 +- .../src/brillig/brillig_ir/instructions.rs | 4 +- .../src/brillig/brillig_ir/registers.rs | 4 +- compiler/noirc_evaluator/src/brillig/mod.rs | 42 +- compiler/noirc_evaluator/src/ssa.rs | 5 +- compiler/noirc_evaluator/src/ssa/opt/die.rs | 32 +- .../noirc_evaluator/src/ssa/opt/inlining.rs | 2 +- compiler/noirc_evaluator/src/ssa/opt/mod.rs | 2 +- .../noirc_evaluator/src/ssa/opt/unrolling.rs | 5 +- .../src/ssa/parser/into_ssa.rs | 1 + .../src/ssa/ssa_gen/program.rs | 6 +- cspell.json | 1 + 16 files changed, 509 insertions(+), 55 deletions(-) diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen.rs index a6117a8f2da..f23e64aec52 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen.rs @@ -58,12 +58,13 @@ pub(crate) fn gen_brillig_for( brillig: &Brillig, ) -> Result, InternalError> { // Create the entry point artifact + let globals_memory_size = brillig.globals_memory_size.get(&func.id()).copied().unwrap_or(0); let mut entry_point = BrilligContext::new_entry_point_artifact( arguments, FunctionContext::return_values(func), func.id(), true, - brillig.globals_memory_size, + globals_memory_size, ); entry_point.name = func.name().to_string(); diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs index 9f9d271283d..6f5645485a2 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_globals.rs @@ -1,24 +1,227 @@ +use std::collections::BTreeMap; + use acvm::FieldElement; use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; -use super::{BrilligArtifact, BrilligBlock, BrilligVariable, FunctionContext, Label, ValueId}; -use crate::{ - brillig::{brillig_ir::BrilligContext, DataFlowGraph}, - ssa::ir::dfg::GlobalsGraph, +use super::{ + BrilligArtifact, BrilligBlock, BrilligVariable, Function, FunctionContext, Label, ValueId, +}; +use crate::brillig::{ + brillig_ir::BrilligContext, called_functions_vec, Brillig, DataFlowGraph, FunctionId, + Instruction, Value, }; +/// Context structure for generating Brillig globals +/// it stores globals related data required for code generation of regular Brillig functions. +#[derive(Default)] +pub(crate) struct BrilligGlobals { + /// Both `used_globals` and `brillig_entry_points` need to be built + /// from a function call graph. + /// + /// Maps a Brillig function to the globals used in that function. + /// This includes all globals used in functions called internally. + used_globals: HashMap>, + /// Maps a Brillig entry point to all functions called in that entry point. + /// This includes any nested calls as well, as we want to be able to associate + /// any Brillig function with the appropriate global allocations. + brillig_entry_points: HashMap>, + + /// Maps an inner call to its Brillig entry point + /// This is simply used to simplify fetching global allocations when compiling + /// individual Brillig functions. + inner_call_to_entry_point: HashMap>, + /// Final map that associated an entry point with its Brillig global allocations + entry_point_globals_map: HashMap, +} + +/// Mapping of SSA value ids to their Brillig allocations +pub(crate) type SsaToBrilligGlobals = HashMap; + +impl BrilligGlobals { + pub(crate) fn new( + functions: &BTreeMap, + mut used_globals: HashMap>, + main_id: FunctionId, + ) -> Self { + let mut brillig_entry_points = HashMap::default(); + let acir_functions = functions.iter().filter(|(_, func)| func.runtime().is_acir()); + for (_, function) in acir_functions { + for block_id in function.reachable_blocks() { + for instruction_id in function.dfg[block_id].instructions() { + let instruction = &function.dfg[*instruction_id]; + let Instruction::Call { func: func_id, arguments: _ } = instruction else { + continue; + }; + + let func_value = &function.dfg[*func_id]; + let Value::Function(func_id) = func_value else { continue }; + + let called_function = &functions[func_id]; + if called_function.runtime().is_acir() { + continue; + } + + // We have now found a Brillig entry point. + // Let's recursively build a call graph to determine any functions + // whose parent is this entry point and any globals used in those internal calls. + brillig_entry_points.insert(*func_id, HashSet::default()); + Self::mark_entry_points_calls_recursive( + functions, + *func_id, + called_function, + &mut used_globals, + &mut brillig_entry_points, + im::HashSet::new(), + ); + } + } + } + + // If main has been marked as Brillig, it is itself an entry point. + // Run the same analysis from above on main. + let main_func = &functions[&main_id]; + if main_func.runtime().is_brillig() { + brillig_entry_points.insert(main_id, HashSet::default()); + Self::mark_entry_points_calls_recursive( + functions, + main_id, + main_func, + &mut used_globals, + &mut brillig_entry_points, + im::HashSet::new(), + ); + } + + Self { used_globals, brillig_entry_points, ..Default::default() } + } + + /// Recursively mark any functions called in an entry point as well as + /// any globals used in those functions. + /// Using the information collected we can determine which globals + /// an entry point must initialize. + fn mark_entry_points_calls_recursive( + functions: &BTreeMap, + entry_point: FunctionId, + called_function: &Function, + used_globals: &mut HashMap>, + brillig_entry_points: &mut HashMap>, + mut explored_functions: im::HashSet, + ) { + if explored_functions.insert(called_function.id()).is_some() { + return; + } + + let inner_calls = called_functions_vec(called_function).into_iter().collect::>(); + + for inner_call in inner_calls { + let inner_globals = used_globals + .get(&inner_call) + .expect("Should have a slot for each function") + .clone(); + used_globals + .get_mut(&entry_point) + .expect("ICE: should have func") + .extend(inner_globals); + + if let Some(inner_calls) = brillig_entry_points.get_mut(&entry_point) { + inner_calls.insert(inner_call); + } + + Self::mark_entry_points_calls_recursive( + functions, + entry_point, + &functions[&inner_call], + used_globals, + brillig_entry_points, + explored_functions.clone(), + ); + } + } + + pub(crate) fn declare_globals( + &mut self, + globals_dfg: &DataFlowGraph, + brillig: &mut Brillig, + enable_debug_trace: bool, + ) { + // Map for fetching the correct entry point globals when compiling any function + let mut inner_call_to_entry_point: HashMap> = + HashMap::default(); + let mut entry_point_globals_map = HashMap::default(); + // We only need to generate globals for entry points + for (entry_point, entry_point_inner_calls) in self.brillig_entry_points.iter() { + let entry_point = *entry_point; + + for inner_call in entry_point_inner_calls { + inner_call_to_entry_point.entry(*inner_call).or_default().push(entry_point); + } + + let used_globals = self.used_globals.remove(&entry_point).unwrap_or_default(); + let (artifact, brillig_globals, globals_size) = + convert_ssa_globals(enable_debug_trace, globals_dfg, &used_globals, entry_point); + + entry_point_globals_map.insert(entry_point, brillig_globals); + + brillig.globals.insert(entry_point, artifact); + brillig.globals_memory_size.insert(entry_point, globals_size); + } + + self.inner_call_to_entry_point = inner_call_to_entry_point; + self.entry_point_globals_map = entry_point_globals_map; + } + + /// Fetch the global allocations that can possibly be accessed + /// by any given Brillig function (non-entry point or entry point). + /// The allocations available to a function are determined by its entry point. + /// For a given function id input, this function will search for that function's + /// entry point (or multiple entry points) and fetch the global allocations + /// associated with those entry points. + /// These allocations can then be used when compiling the Brillig function + /// and resolving global variables. + pub(crate) fn get_brillig_globals( + &self, + brillig_function_id: FunctionId, + ) -> SsaToBrilligGlobals { + let entry_points = self.inner_call_to_entry_point.get(&brillig_function_id); + + let mut globals_allocations = HashMap::default(); + if let Some(entry_points) = entry_points { + // A Brillig function is used by multiple entry points. Fetch both globals allocations + // in case one is used by the internal call. + let entry_point_allocations = entry_points + .iter() + .flat_map(|entry_point| self.entry_point_globals_map.get(entry_point)) + .collect::>(); + for map in entry_point_allocations { + globals_allocations.extend(map); + } + } else if let Some(globals) = self.entry_point_globals_map.get(&brillig_function_id) { + // If there is no mapping from an inner call to an entry point, that means `brillig_function_id` + // is itself an entry point and we can fetch the global allocations directly from `self.entry_point_globals_map`. + // vec![globals] + globals_allocations.extend(globals); + } else { + unreachable!( + "ICE: Expected global allocation to be set for function {brillig_function_id}" + ); + } + globals_allocations + } +} + pub(crate) fn convert_ssa_globals( enable_debug_trace: bool, - globals: GlobalsGraph, + globals_dfg: &DataFlowGraph, used_globals: &HashSet, + entry_point: FunctionId, ) -> (BrilligArtifact, HashMap, usize) { - let mut brillig_context = BrilligContext::new_for_global_init(enable_debug_trace); + let mut brillig_context = BrilligContext::new_for_global_init(enable_debug_trace, entry_point); // The global space does not have globals itself let empty_globals = HashMap::default(); // We can use any ID here as this context is only going to be used for globals which does not differentiate // by functions and blocks. The only Label that should be used in the globals context is `Label::globals_init()` let mut function_context = FunctionContext::default(); - brillig_context.enter_context(Label::globals_init()); + brillig_context.enter_context(Label::globals_init(entry_point)); let block_id = DataFlowGraph::default().make_block(); let mut brillig_block = BrilligBlock { @@ -31,13 +234,220 @@ pub(crate) fn convert_ssa_globals( building_globals: true, }; - let globals_dfg = DataFlowGraph::from(globals); - brillig_block.compile_globals(&globals_dfg, used_globals); + brillig_block.compile_globals(globals_dfg, used_globals); - let globals_size = brillig_block.brillig_context.global_space_size(); + let globals_size = brillig_context.global_space_size(); brillig_context.return_instruction(); let artifact = brillig_context.artifact(); (artifact, function_context.ssa_value_allocations, globals_size) } + +#[cfg(test)] +mod tests { + use acvm::{ + acir::brillig::{BitSize, Opcode}, + FieldElement, + }; + + use crate::brillig::{brillig_ir::registers::RegisterAllocator, GlobalSpace, LabelType, Ssa}; + + #[test] + fn entry_points_different_globals() { + let src = " + g0 = Field 2 + + acir(inline) fn main f0 { + b0(v1: Field, v2: Field): + v4 = call f1(v1) -> Field + constrain v4 == Field 2 + v6 = call f2(v1) -> Field + constrain v6 == Field 2 + return + } + brillig(inline) fn entry_point_no_globals f1 { + b0(v1: Field): + v3 = add v1, Field 1 + v4 = add v3, Field 1 + return v4 + } + brillig(inline) fn entry_point_globals f2 { + b0(v1: Field): + v2 = add v1, Field 2 + return v2 + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + // Need to run DIE to generate the used globals map, which is necessary for Brillig globals generation. + let mut ssa = ssa.dead_instruction_elimination(); + + let used_globals_map = std::mem::take(&mut ssa.used_globals); + let brillig = ssa.to_brillig_with_globals(false, used_globals_map); + + assert_eq!( + brillig.globals.len(), + 2, + "Should have a globals artifact associated with each entry point" + ); + for (func_id, mut artifact) in brillig.globals { + let labels = artifact.take_labels(); + // When entering a context two labels are created. + // One is a context label and another is a section label. + assert_eq!(labels.len(), 2); + for (label, position) in labels { + assert_eq!(label.label_type, LabelType::GlobalInit(func_id)); + assert_eq!(position, 0); + } + if func_id.to_u32() == 1 { + assert_eq!( + artifact.byte_code.len(), + 1, + "Expected just a `Return`, but got more than a single opcode" + ); + assert!(matches!(&artifact.byte_code[0], Opcode::Return)); + } else if func_id.to_u32() == 2 { + assert_eq!( + artifact.byte_code.len(), + 2, + "Expected enough opcodes to initialize the globals" + ); + let Opcode::Const { destination, bit_size, value } = &artifact.byte_code[0] else { + panic!("First opcode is expected to be `Const`"); + }; + assert_eq!(destination.unwrap_direct(), GlobalSpace::start()); + assert!(matches!(bit_size, BitSize::Field)); + assert_eq!(*value, FieldElement::from(2u128)); + + assert!(matches!(&artifact.byte_code[1], Opcode::Return)); + } else { + panic!("Unexpected function id: {func_id}"); + } + } + } + + #[test] + fn entry_point_nested_globals() { + let src = " + g0 = Field 1 + g1 = make_array [Field 1, Field 1] : [Field; 2] + g2 = Field 0 + g3 = make_array [Field 0, Field 0] : [Field; 2] + g4 = make_array [g1, g3] : [[Field; 2]; 2] + + acir(inline) fn main f0 { + b0(v5: Field, v6: Field): + v8 = call f1(v5) -> Field + constrain v8 == Field 2 + call f2(v5, v6) + v12 = call f1(v5) -> Field + constrain v12 == Field 2 + call f3(v5, v6) + v15 = call f1(v5) -> Field + constrain v15 == Field 2 + return + } + brillig(inline) fn entry_point_no_globals f1 { + b0(v5: Field): + v6 = add v5, Field 1 + v7 = add v6, Field 1 + return v7 + } + brillig(inline) fn check_acc_entry_point f2 { + b0(v5: Field, v6: Field): + v8 = allocate -> &mut Field + store Field 0 at v8 + jmp b1(u32 0) + b1(v7: u32): + v11 = lt v7, u32 2 + jmpif v11 then: b3, else: b2 + b2(): + v12 = load v8 -> Field + v13 = eq v12, Field 0 + constrain v13 == u1 0 + v15 = eq v5, v6 + constrain v15 == u1 0 + v16 = add v5, Field 1 + v17 = add v16, Field 1 + constrain v17 == Field 2 + return + b3(): + v19 = array_get g4, index v7 -> [Field; 2] + v20 = load v8 -> Field + v21 = array_get v19, index u32 0 -> Field + v22 = add v20, v21 + v24 = array_get v19, index u32 1 -> Field + v25 = add v22, v24 + store v25 at v8 + v26 = unchecked_add v7, u32 1 + jmp b1(v26) + } + brillig(inline) fn entry_point_inner_func_globals f3 { + b0(v5: Field, v6: Field): + call f4(v5, v6) + return + } + brillig(inline) fn non_entry_point_wrapper f4 { + b0(v5: Field, v6: Field): + call f2(v5, v6) + call f2(v5, v6) + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + // Need to run DIE to generate the used globals map, which is necessary for Brillig globals generation. + let mut ssa = ssa.dead_instruction_elimination(); + + let used_globals_map = std::mem::take(&mut ssa.used_globals); + let brillig = ssa.to_brillig_with_globals(false, used_globals_map); + + assert_eq!( + brillig.globals.len(), + 3, + "Should have a globals artifact associated with each entry point" + ); + for (func_id, mut artifact) in brillig.globals { + let labels = artifact.take_labels(); + // When entering a context two labels are created. + // One is a context label and another is a section label. + assert_eq!(labels.len(), 2); + for (label, position) in labels { + assert_eq!(label.label_type, LabelType::GlobalInit(func_id)); + assert_eq!(position, 0); + } + if func_id.to_u32() == 1 { + assert_eq!( + artifact.byte_code.len(), + 2, + "Expected enough opcodes to initialize the globals" + ); + let Opcode::Const { destination, bit_size, value } = &artifact.byte_code[0] else { + panic!("First opcode is expected to be `Const`"); + }; + assert_eq!(destination.unwrap_direct(), GlobalSpace::start()); + assert!(matches!(bit_size, BitSize::Field)); + assert_eq!(*value, FieldElement::from(1u128)); + + assert!(matches!(&artifact.byte_code[1], Opcode::Return)); + } else if func_id.to_u32() == 2 || func_id.to_u32() == 3 { + // We want the entry point which uses globals (f2) and the entry point which calls f2 function internally (f3 through f4) + // to have the same globals initialized. + assert_eq!( + artifact.byte_code.len(), + 30, + "Expected enough opcodes to initialize the globals" + ); + let globals_max_memory = brillig + .globals_memory_size + .get(&func_id) + .copied() + .expect("Should have globals memory size"); + assert_eq!(globals_max_memory, 7); + } else { + panic!("Unexpected function id: {func_id}"); + } + } + } +} diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir.rs index ad09f73e90f..520fd5aad96 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir.rs @@ -37,7 +37,7 @@ use acvm::{ }; use debug_show::DebugShow; -use super::{GlobalSpace, ProcedureId}; +use super::{FunctionId, GlobalSpace, ProcedureId}; /// The Brillig VM does not apply a limit to the memory address space, /// As a convention, we take use 32 bits. This means that we assume that @@ -221,11 +221,14 @@ impl BrilligContext { /// Special brillig context to codegen global values initialization impl BrilligContext { - pub(crate) fn new_for_global_init(enable_debug_trace: bool) -> BrilligContext { + pub(crate) fn new_for_global_init( + enable_debug_trace: bool, + entry_point: FunctionId, + ) -> BrilligContext { BrilligContext { obj: BrilligArtifact::default(), registers: GlobalSpace::new(), - context_label: Label::globals_init(), + context_label: Label::globals_init(entry_point), current_section: 0, next_section: 1, debug_show: DebugShow::new(enable_debug_trace), diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs index 4c48675d1e7..c9223715042 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs @@ -76,7 +76,8 @@ pub(crate) enum LabelType { /// Labels for intrinsic procedures Procedure(ProcedureId), /// Label for initialization of globals - GlobalInit, + /// Stores a function ID referencing the entry point + GlobalInit(FunctionId), } impl std::fmt::Display for LabelType { @@ -91,7 +92,9 @@ impl std::fmt::Display for LabelType { } LabelType::Entrypoint => write!(f, "Entrypoint"), LabelType::Procedure(procedure_id) => write!(f, "Procedure({:?})", procedure_id), - LabelType::GlobalInit => write!(f, "Globals Initialization"), + LabelType::GlobalInit(function_id) => { + write!(f, "Globals Initialization({function_id:?})") + } } } } @@ -127,8 +130,8 @@ impl Label { Label { label_type: LabelType::Procedure(procedure_id), section: None } } - pub(crate) fn globals_init() -> Self { - Label { label_type: LabelType::GlobalInit, section: None } + pub(crate) fn globals_init(function_id: FunctionId) -> Self { + Label { label_type: LabelType::GlobalInit(function_id), section: None } } } @@ -334,4 +337,9 @@ impl BrilligArtifact { pub(crate) fn set_call_stack(&mut self, call_stack: CallStack) { self.call_stack = call_stack; } + + #[cfg(test)] + pub(crate) fn take_labels(&mut self) -> HashMap { + std::mem::take(&mut self.labels) + } } diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs index 030ed7133e8..6d4cc814d3e 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs @@ -32,7 +32,7 @@ impl BrilligContext { context.codegen_entry_point(&arguments, &return_parameters); if globals_init { - context.add_globals_init_instruction(); + context.add_globals_init_instruction(target_function); } context.add_external_call_instruction(target_function); diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs index d67da423d44..9dd541c7180 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs @@ -200,8 +200,8 @@ impl BrilligContext< self.obj.add_unresolved_external_call(BrilligOpcode::Call { location: 0 }, proc_label); } - pub(super) fn add_globals_init_instruction(&mut self) { - let globals_init_label = Label::globals_init(); + pub(super) fn add_globals_init_instruction(&mut self, func_id: FunctionId) { + let globals_init_label = Label::globals_init(func_id); self.debug_show.add_external_call_instruction(globals_init_label.to_string()); self.obj .add_unresolved_external_call(BrilligOpcode::Call { location: 0 }, globals_init_label); diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/registers.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/registers.rs index 88b8a598b10..093c99dec3b 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/registers.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/registers.rs @@ -150,13 +150,14 @@ impl RegisterAllocator for ScratchSpace { /// Globals have a separate memory space /// This memory space is initialized once at the beginning of a program /// and is read-only. +#[derive(Default)] pub(crate) struct GlobalSpace { storage: DeallocationListAllocator, max_memory_address: usize, } impl GlobalSpace { - pub(super) fn new() -> Self { + pub(crate) fn new() -> Self { Self { storage: DeallocationListAllocator::new(Self::start()), max_memory_address: Self::start(), @@ -224,6 +225,7 @@ impl RegisterAllocator for GlobalSpace { } } +#[derive(Default)] struct DeallocationListAllocator { /// A free-list of registers that have been deallocated and can be used again. deallocated_registers: BTreeSet, diff --git a/compiler/noirc_evaluator/src/brillig/mod.rs b/compiler/noirc_evaluator/src/brillig/mod.rs index b74c519f61a..791f6a466cf 100644 --- a/compiler/noirc_evaluator/src/brillig/mod.rs +++ b/compiler/noirc_evaluator/src/brillig/mod.rs @@ -2,7 +2,7 @@ pub(crate) mod brillig_gen; pub(crate) mod brillig_ir; use acvm::FieldElement; -use brillig_gen::brillig_globals::convert_ssa_globals; +use brillig_gen::brillig_globals::BrilligGlobals; use brillig_ir::{artifact::LabelType, brillig_variable::BrilligVariable, registers::GlobalSpace}; use self::{ @@ -12,15 +12,18 @@ use self::{ procedures::compile_procedure, }, }; + use crate::ssa::{ ir::{ dfg::DataFlowGraph, function::{Function, FunctionId}, - value::ValueId, + instruction::Instruction, + value::{Value, ValueId}, }, + opt::inlining::called_functions_vec, ssa_gen::Ssa, }; -use fxhash::FxHashMap as HashMap; +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use std::{borrow::Cow, collections::BTreeSet}; pub use self::brillig_ir::procedures::ProcedureId; @@ -31,8 +34,8 @@ pub use self::brillig_ir::procedures::ProcedureId; pub struct Brillig { /// Maps SSA function labels to their brillig artifact ssa_function_to_brillig: HashMap>, - globals: BrilligArtifact, - globals_memory_size: usize, + globals: HashMap>, + globals_memory_size: HashMap, } impl Brillig { @@ -58,7 +61,7 @@ impl Brillig { } // Procedures are compiled as needed LabelType::Procedure(procedure_id) => Some(Cow::Owned(compile_procedure(procedure_id))), - LabelType::GlobalInit => Some(Cow::Borrowed(&self.globals)), + LabelType::GlobalInit(function_id) => self.globals.get(&function_id).map(Cow::Borrowed), _ => unreachable!("ICE: Expected a function or procedure label"), } } @@ -72,9 +75,18 @@ impl std::ops::Index for Brillig { } impl Ssa { - /// Compile Brillig functions and ACIR functions reachable from them #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn to_brillig(&self, enable_debug_trace: bool) -> Brillig { + self.to_brillig_with_globals(enable_debug_trace, HashMap::default()) + } + + /// Compile Brillig functions and ACIR functions reachable from them + #[tracing::instrument(level = "trace", skip_all)] + pub(crate) fn to_brillig_with_globals( + &self, + enable_debug_trace: bool, + used_globals_map: HashMap>, + ) -> Brillig { // Collect all the function ids that are reachable from brillig // That means all the functions marked as brillig and ACIR functions called by them let brillig_reachable_function_ids = self @@ -89,17 +101,21 @@ impl Ssa { return brillig; } - // Globals are computed once at compile time and shared across all functions, + let mut brillig_globals = + BrilligGlobals::new(&self.functions, used_globals_map, self.main_id); + + // SSA Globals are computed once at compile time and shared across all functions, // thus we can just fetch globals from the main function. + // This same globals graph will then be used to declare Brillig globals for the respective entry points. let globals = (*self.functions[&self.main_id].dfg.globals).clone(); - let (artifact, brillig_globals, globals_size) = - convert_ssa_globals(enable_debug_trace, globals, &self.used_global_values); - brillig.globals = artifact; - brillig.globals_memory_size = globals_size; + let globals_dfg = DataFlowGraph::from(globals); + brillig_globals.declare_globals(&globals_dfg, &mut brillig, enable_debug_trace); for brillig_function_id in brillig_reachable_function_ids { + let globals_allocations = brillig_globals.get_brillig_globals(brillig_function_id); + let func = &self.functions[&brillig_function_id]; - brillig.compile(func, enable_debug_trace, &brillig_globals); + brillig.compile(func, enable_debug_trace, &globals_allocations); } brillig diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 49d033bacc2..0d4435f9214 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -38,7 +38,7 @@ use crate::acir::{Artifacts, GeneratedAcir}; mod checks; pub(super) mod function_builder; pub mod ir; -mod opt; +pub(crate) mod opt; #[cfg(test)] pub(crate) mod parser; pub mod ssa_gen; @@ -122,8 +122,9 @@ pub(crate) fn optimize_into_acir( drop(ssa_gen_span_guard); + let used_globals_map = std::mem::take(&mut ssa.used_globals); let brillig = time("SSA to Brillig", options.print_codegen_timings, || { - ssa.to_brillig(options.enable_brillig_logging) + ssa.to_brillig_with_globals(options.enable_brillig_logging, used_globals_map) }); let ssa_gen_span = span!(Level::TRACE, "ssa_generation"); diff --git a/compiler/noirc_evaluator/src/ssa/opt/die.rs b/compiler/noirc_evaluator/src/ssa/opt/die.rs index a8f0659f8db..cb90ac6e492 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -30,26 +30,36 @@ impl Ssa { } fn dead_instruction_elimination_inner(mut self, flattened: bool) -> Ssa { - let mut used_global_values: HashSet<_> = self + let mut used_globals_map: HashMap<_, _> = self .functions .par_iter_mut() - .flat_map(|(_, func)| func.dead_instruction_elimination(true, flattened)) + .filter_map(|(id, func)| { + let set = func.dead_instruction_elimination(true, flattened); + if func.runtime().is_brillig() { + Some((*id, set)) + } else { + None + } + }) .collect(); let globals = &self.functions[&self.main_id].dfg.globals; - // Check which globals are used across all functions - for (id, value) in globals.values_iter().rev() { - if used_global_values.contains(&id) { - if let Value::Instruction { instruction, .. } = &value { - let instruction = &globals[*instruction]; - instruction.for_each_value(|value_id| { - used_global_values.insert(value_id); - }); + for used_global_values in used_globals_map.values_mut() { + // DIE only tracks used instruction results, however, globals include constants. + // Back track globals for internal values which may be in use. + for (id, value) in globals.values_iter().rev() { + if used_global_values.contains(&id) { + if let Value::Instruction { instruction, .. } = &value { + let instruction = &globals[*instruction]; + instruction.for_each_value(|value_id| { + used_global_values.insert(value_id); + }); + } } } } - self.used_global_values = used_global_values; + self.used_globals = used_globals_map; self } diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 7f96df1384b..a2da9bf6f3d 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -184,7 +184,7 @@ struct PerFunctionContext<'function> { /// Utility function to find out the direct calls of a function. /// /// Returns the function IDs from all `Call` instructions without deduplication. -fn called_functions_vec(func: &Function) -> Vec { +pub(crate) fn called_functions_vec(func: &Function) -> Vec { let mut called_function_ids = Vec::new(); for block_id in func.reachable_blocks() { for instruction_id in func.dfg[block_id].instructions() { diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index 44796e2531e..2dd73ff7211 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -12,7 +12,7 @@ mod defunctionalize; mod die; pub(crate) mod flatten_cfg; mod hint; -mod inlining; +pub(crate) mod inlining; mod loop_invariant; mod make_constrain_not_equal; mod mem2reg; diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index eb0bbd8c532..f6dda107d9c 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -80,10 +80,11 @@ impl Ssa { if let Some(max_incr_pct) = max_bytecode_increase_percent { if global_cache.is_none() { let globals = (*function.dfg.globals).clone(); - // DIE is run at the end of our SSA optimizations, so we mark all globals as in use here. let used_globals = &globals.values_iter().map(|(id, _)| id).collect(); + let globals_dfg = DataFlowGraph::from(globals); + // DIE is run at the end of our SSA optimizations, so we mark all globals as in use here. let (_, brillig_globals, _) = - convert_ssa_globals(false, globals, used_globals); + convert_ssa_globals(false, &globals_dfg, used_globals, function.id()); global_cache = Some(brillig_globals); } let brillig_globals = global_cache.as_ref().unwrap(); diff --git a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs index 37d2cd720f9..9fb6f43535c 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs @@ -122,6 +122,7 @@ impl Translator { } RuntimeType::Brillig(inline_type) => { self.builder.new_brillig_function(external_name, function_id, inline_type); + self.builder.set_globals(self.globals_graph.clone()); } } diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs index 04986bd8db1..ad52473620d 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/program.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use acvm::acir::circuit::ErrorSelector; -use fxhash::FxHashSet as HashSet; +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use iter_extended::btree_map; use serde::{Deserialize, Serialize}; use serde_with::serde_as; @@ -20,7 +20,7 @@ use super::ValueId; pub(crate) struct Ssa { #[serde_as(as = "Vec<(_, _)>")] pub(crate) functions: BTreeMap, - pub(crate) used_global_values: HashSet, + pub(crate) used_globals: HashMap>, pub(crate) main_id: FunctionId, #[serde(skip)] pub(crate) next_id: AtomicCounter, @@ -59,7 +59,7 @@ impl Ssa { error_selector_to_type: error_types, // This field is set only after running DIE and is utilized // for optimizing implementation of globals post-SSA. - used_global_values: HashSet::default(), + used_globals: HashMap::default(), } } diff --git a/cspell.json b/cspell.json index 1174a56dd33..c877c2f0529 100644 --- a/cspell.json +++ b/cspell.json @@ -91,6 +91,7 @@ "Elligator", "endianness", "envrc", + "EXPONENTIATE", "Flamegraph", "flate", "fmtstr",