diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs
index 40e80b9487..e4d6115897 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs
@@ -235,6 +235,7 @@ impl VariancedDag {
                 nb_constraints: out_shape.flat_size(),
                 safe_variance_bound: max_variance,
                 noise_expression: variance.clone(),
+                noise_evaluator: None,
                 location: op.location.clone(),
             };
             self.external_variance_constraints.push(constraint);
@@ -273,6 +274,7 @@ impl VariancedDag {
                 nb_constraints: out_shape.flat_size(),
                 safe_variance_bound: max_variance,
                 noise_expression: variance.clone(),
+                noise_evaluator: None,
                 location: dag_op.location.clone(),
             };
             self.external_variance_constraints.push(constraint);
@@ -646,6 +648,7 @@ fn variance_constraint(
         safe_variance_bound,
         nb_partitions,
         noise_expression: noise,
+        noise_evaluator: None,
         location,
     }
 }
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs
index a050e4817a..c75f4d5baf 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs
@@ -2,7 +2,7 @@ use std::{fmt, ops::Add};
 
 use super::{
     partitions::PartitionIndex,
-    symbolic::{fast_keyswitch, keyswitch, Symbol, SymbolMap},
+    symbolic::{fast_keyswitch, keyswitch, Symbol, SymbolArray, SymbolMap, SymbolScheme},
 };
 
 /// A structure storing the number of times an fhe operation gets executed in a circuit.
@@ -29,36 +29,44 @@ impl fmt::Display for OperationsCount {
 
 /// An ensemble of costs associated with fhe operation symbols.
 #[derive(Clone, Debug)]
-pub struct ComplexityValues(SymbolMap);
+pub struct ComplexityValues(SymbolArray);
 
 impl ComplexityValues {
     /// Returns an empty set of cost values.
-    pub fn new() -> Self {
-        ComplexityValues(SymbolMap::new())
+    pub fn from_scheme(scheme: &SymbolScheme) -> ComplexityValues {
+        ComplexityValues(SymbolArray::from_scheme(scheme))
     }
 
     /// Sets the cost associated with an fhe operation symbol.
     pub fn set_cost(&mut self, source: Symbol, value: f64) {
-        self.0.set(source, value);
+        self.0.set(&source, value);
    }
 }
 
 /// A complexity expression is a sum of complexity terms associating operation
 /// symbols with the number of time they gets executed in the circuit.
 #[derive(Clone, Debug)]
-pub struct ComplexityExpression(SymbolMap);
+pub struct ComplexityEvaluator(SymbolArray);
 
-impl ComplexityExpression {
+impl ComplexityEvaluator {
     /// Creates a complexity expression from a set of operation counts.
-    pub fn from(counts: &OperationsCount) -> Self {
-        Self(counts.0.clone())
+    pub fn from_scheme_and_counts(
+        scheme: &SymbolScheme,
+        counts: &OperationsCount,
+    ) -> ComplexityEvaluator {
+        Self(SymbolArray::from_scheme_and_map(scheme, &counts.0))
+    }
+
+    pub fn scheme(&self) -> &SymbolScheme {
+        self.0.scheme()
     }
 
     /// Evaluates the total cost expression on a set of cost values.
     pub fn evaluate_total_cost(&self, costs: &ComplexityValues) -> f64 {
-        self.0.iter().fold(0.0, |acc, (symbol, n_ops)| {
-            acc + (n_ops as f64) * costs.0.get(symbol)
-        })
+        self.0
+            .iter()
+            .zip(costs.0.iter())
+            .fold(0.0, |acc, (n_ops, cost)| acc + (*n_ops as f64) * *cost)
     }
 
     /// Evaluates the max ks cost expression on a set of cost values.
@@ -69,11 +77,11 @@ impl ComplexityExpression {
         src_partition: PartitionIndex,
         dst_partition: PartitionIndex,
     ) -> f64 {
-        let actual_ks_cost = costs.0.get(keyswitch(src_partition, dst_partition));
-        let ks_coeff = self.0.get(keyswitch(src_partition, dst_partition));
+        let actual_ks_cost = costs.0.get(&keyswitch(src_partition, dst_partition));
+        let ks_coeff = self.0.get(&keyswitch(src_partition, dst_partition));
         let actual_complexity =
-            self.evaluate_total_cost(costs) - (ks_coeff as f64) * actual_ks_cost;
-        (complexity_cut - actual_complexity) / (ks_coeff as f64)
+            self.evaluate_total_cost(costs) - (*ks_coeff as f64) * actual_ks_cost;
+        (complexity_cut - actual_complexity) / (*ks_coeff as f64)
     }
 
     /// Evaluates the max fks cost expression on a set of cost values.
@@ -84,10 +92,10 @@ impl ComplexityExpression {
         src_partition: PartitionIndex,
         dst_partition: PartitionIndex,
     ) -> f64 {
-        let actual_fks_cost = costs.0.get(fast_keyswitch(src_partition, dst_partition));
-        let fks_coeff = self.0.get(fast_keyswitch(src_partition, dst_partition));
+        let actual_fks_cost = costs.0.get(&fast_keyswitch(src_partition, dst_partition));
+        let fks_coeff = self.0.get(&fast_keyswitch(src_partition, dst_partition));
         let actual_complexity =
-            self.evaluate_total_cost(costs) - (fks_coeff as f64) * actual_fks_cost;
-        (complexity_cut - actual_complexity) / (fks_coeff as f64)
+            self.evaluate_total_cost(costs) - (*fks_coeff as f64) * actual_fks_cost;
+        (complexity_cut - actual_complexity) / (*fks_coeff as f64)
     }
 }
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs
index 7f93c883c3..af98536291 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs
@@ -49,13 +49,20 @@ impl Feasible {
         for constraint in &self.undominated_constraints {
             let pbs_coeff = constraint
-                .noise_expression
+                .noise_evaluator
+                .as_ref()
+                .unwrap()
                 .coeff(bootstrap_noise(partition));
             if pbs_coeff == 0.0 {
                 continue;
             }
-            let actual_variance = constraint.noise_expression.evaluate(operations_variance)
-                - pbs_coeff * actual_pbs_variance;
+            let actual_variance = unsafe {
+                constraint
+                    .noise_evaluator
+                    .as_ref()
+                    .unwrap_unchecked()
+                    .evaluate(operations_variance)
+            } - pbs_coeff * actual_pbs_variance;
             let pbs_max_variance = (constraint.safe_variance_bound - actual_variance) / pbs_coeff;
             smallest_pbs_max_variance = smallest_pbs_max_variance.min(pbs_max_variance);
         }
@@ -75,13 +82,20 @@ impl Feasible {
         for constraint in &self.undominated_constraints {
             let ks_coeff = constraint
-                .noise_expression
+                .noise_evaluator
+                .as_ref()
+                .unwrap()
                 .coeff(keyswitch_noise(src_partition, dst_partition));
             if ks_coeff == 0.0 {
                 continue;
             }
-            let actual_variance = constraint.noise_expression.evaluate(operations_variance)
-                - ks_coeff * actual_ks_variance;
+            let actual_variance = unsafe {
+                constraint
+                    .noise_evaluator
+                    .as_ref()
+                    .unwrap_unchecked()
+                    .evaluate(operations_variance)
+            } - ks_coeff * actual_ks_variance;
             let ks_max_variance = (constraint.safe_variance_bound - actual_variance) / ks_coeff;
             smallest_ks_max_variance = smallest_ks_max_variance.min(ks_max_variance);
         }
@@ -102,13 +116,20 @@ impl Feasible {
         for constraint in &self.undominated_constraints {
             let fks_coeff = constraint
-                .noise_expression
+                .noise_evaluator
+                .as_ref()
+                .unwrap()
                 .coeff(fast_keyswitch_noise(src_partition, dst_partition));
             if fks_coeff == 0.0 {
                 continue;
             }
-            let actual_variance = constraint.noise_expression.evaluate(operations_variance)
-                - fks_coeff * actual_fks_variance;
+            let actual_variance = unsafe {
+                constraint
+                    .noise_evaluator
+                    .as_ref()
+                    .unwrap_unchecked()
+                    .evaluate(operations_variance)
+            } - fks_coeff * actual_fks_variance;
             let fks_max_variance = (constraint.safe_variance_bound - actual_variance) / fks_coeff;
             smallest_fks_max_variance = smallest_fks_max_variance.min(fks_max_variance);
         }
@@ -126,8 +147,13 @@ impl Feasible {
 
     fn local_feasible(&self, operations_variance: &NoiseValues) -> bool {
         for constraint in &self.undominated_constraints {
-            if constraint.noise_expression.evaluate(operations_variance)
-                > constraint.safe_variance_bound
+            if unsafe {
+                constraint
+                    .noise_evaluator
+                    .as_ref()
+                    .unwrap_unchecked()
+                    .evaluate(operations_variance)
+            } > constraint.safe_variance_bound
             {
                 return false;
             };
@@ -148,7 +174,13 @@ impl Feasible {
         let mut worst_relative_variance = 0.0;
         let mut worst_variance = 0.0;
         for constraint in &self.undominated_constraints {
-            let variance = constraint.noise_expression.evaluate(operations_variance);
+            let variance = unsafe {
+                constraint
+                    .noise_evaluator
+                    .as_ref()
+                    .unwrap_unchecked()
+                    .evaluate(operations_variance)
+            };
             let relative_variance = variance / constraint.safe_variance_bound;
             if relative_variance > worst_relative_variance {
                 worst_relative_variance = relative_variance;
@@ -167,7 +199,13 @@ impl Feasible {
     fn global_p_error_with_cut(&self, operations_variance: &NoiseValues, cut: f64) -> Option {
         let mut global_p_error = 0.0;
         for constraint in &self.constraints {
-            let variance = constraint.noise_expression.evaluate(operations_variance);
+            let variance = unsafe {
+                constraint
+                    .noise_evaluator
+                    .as_ref()
+                    .unwrap_unchecked()
+                    .evaluate(operations_variance)
+            };
             let relative_variance = variance / constraint.safe_variance_bound;
             let p_error = p_error_from_relative_variance(relative_variance, self.kappa);
             global_p_error = combine_errors(
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/noise_expression.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/noise_expression.rs
index ef0c2e40a8..bdb120c45c 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/noise_expression.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/noise_expression.rs
@@ -5,28 +5,27 @@ use std::{
 
 use super::{
     partitions::PartitionIndex,
-    symbolic::{Symbol, SymbolMap},
+    symbolic::{Symbol, SymbolArray, SymbolMap, SymbolScheme},
 };
 
 /// An ensemble of noise values for fhe operations.
 #[derive(Debug, Clone, PartialEq)]
-pub struct NoiseValues(SymbolMap);
+pub struct NoiseValues(SymbolArray);
 
 impl NoiseValues {
     /// Returns an empty set of noise values.
-    #[allow(clippy::new_without_default)]
-    pub fn new() -> Self {
-        NoiseValues(SymbolMap::new())
+    pub fn from_scheme(scheme: &SymbolScheme) -> NoiseValues {
+        NoiseValues(SymbolArray::from_scheme(scheme))
     }
 
     /// Sets the noise variance associated with a noise source.
     pub fn set_variance(&mut self, source: NoiseSource, value: f64) {
-        self.0.set(source.0, value);
+        self.0.set(&source.0, value);
     }
 
     /// Returns the variance associated with a noise source
     pub fn variance(&self, source: NoiseSource) -> f64 {
-        self.0.get(source.0)
+        *self.0.get(&source.0)
     }
 }
 
@@ -36,10 +35,36 @@ impl Display for NoiseValues {
     }
 }
 
+#[derive(Debug, Clone, PartialEq)]
+pub struct NoiseEvaluator(SymbolArray);
+
+impl NoiseEvaluator {
+    /// Creates a noise evaluator from a scheme and a noise expression.
+    pub fn from_scheme_and_expression(
+        scheme: &SymbolScheme,
+        expr: &NoiseExpression,
+    ) -> NoiseEvaluator {
+        NoiseEvaluator(SymbolArray::from_scheme_and_map(scheme, &expr.0))
+    }
+
+    /// Returns the coefficient associated with a noise source.
+    pub fn coeff(&self, source: NoiseSource) -> f64 {
+        *self.0.get(&source.0)
+    }
+
+    /// Evaluate the noise expression on a set of noise values.
+    pub fn evaluate(&self, values: &NoiseValues) -> f64 {
+        self.0
+            .iter()
+            .zip(values.0.iter())
+            .fold(0.0, |acc, (coef, var)| acc + coef * var)
+    }
+}
+
 /// A noise expression, i.e. a sum of noise terms associating a noise source,
 /// with a multiplicative coefficient.
 #[derive(Debug, Clone, PartialEq)]
-pub struct NoiseExpression(SymbolMap);
+pub struct NoiseExpression(pub SymbolMap);
 
 impl NoiseExpression {
     /// Returns a zero noise expression
@@ -70,12 +95,12 @@ impl NoiseExpression {
         lhs
     }
 
-    /// Evaluate the noise expression on a set of noise values.
-    pub fn evaluate(&self, values: &NoiseValues) -> f64 {
-        self.terms_iter().fold(0.0, |acc, term| {
-            acc + term.coefficient * values.variance(term.source)
-        })
-    }
+    // /// Evaluate the noise expression on a set of noise values.
+    // pub fn evaluate(&self, values: &NoiseValues) -> f64 {
+    //     self.terms_iter().fold(0.0, |acc, term| {
+    //         acc + term.coefficient * values.variance(term.source)
+    //     })
+    // }
 }
 
 impl Display for NoiseExpression {
@@ -196,7 +221,7 @@ impl Mul for f64 {
 
 /// A symbolic source of noise, or a noise source variable.
 #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
-pub struct NoiseSource(Symbol);
+pub struct NoiseSource(pub Symbol);
 
 /// Returns an input noise source symbol.
 pub fn input_noise(partition: PartitionIndex) -> NoiseSource {
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs
index 352f97d047..0b8cdf9b10 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/mod.rs
@@ -15,7 +15,7 @@ use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
 use crate::optimization::decomposition::{cmux, keyswitch, DecompCaches, PersistDecompCaches};
 use crate::parameters::GlweParameters;
 
-use crate::optimization::dag::multi_parameters::complexity::ComplexityExpression;
+use crate::optimization::dag::multi_parameters::complexity::ComplexityEvaluator;
 use crate::optimization::dag::multi_parameters::feasible::Feasible;
 use crate::optimization::dag::multi_parameters::partition_cut::PartitionCut;
 use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
@@ -28,7 +28,7 @@ use super::noise_expression::{
     bootstrap_noise, fast_keyswitch_noise, input_noise, keyswitch_noise, modulus_switching_noise,
     NoiseValues,
 };
-use super::symbolic::{bootstrap, fast_keyswitch, keyswitch};
+use super::symbolic::{bootstrap, fast_keyswitch, keyswitch, SymbolScheme};
 
 const DEBUG: bool = false;
 
@@ -90,7 +90,7 @@ fn optimize_1_ks(
     ks_pareto: &[KsComplexityNoise],
     operations: &mut OperationsCV,
     feasible: &Feasible,
-    complexity: &ComplexityExpression,
+    complexity: &ComplexityEvaluator,
     cut_complexity: f64,
 ) -> Option {
     // find the first feasible (and less complex)
@@ -132,7 +132,7 @@ fn optimize_many_independant_ks(
     ks_used: &[Vec],
     operations: &OperationsCV,
     feasible: &Feasible,
-    complexity: &ComplexityExpression,
+    complexity: &ComplexityEvaluator,
     caches: &mut keyswitch::Cache,
     cut_complexity: f64,
 ) -> Option<(Vec<(KsDst, KsComplexityNoise)>, OperationsCV)> {
@@ -184,7 +184,7 @@ fn optimize_1_fks_and_all_compatible_ks(
     fks_dst: PartitionIndex,
     operations: &OperationsCV,
     feasible: &Feasible,
-    complexity: &ComplexityExpression,
+    complexity: &ComplexityEvaluator,
     caches: &mut keyswitch::Cache,
     cut_complexity: f64,
     ciphertext_modulus_log: u32,
@@ -320,7 +320,7 @@ fn optimize_dst_exclusive_fks_subset_and_all_ks(
     ks_used: &[Vec],
     operations: &OperationsCV,
     feasible: &Feasible,
-    complexity: &ComplexityExpression,
+    complexity: &ComplexityEvaluator,
     caches: &mut keyswitch::Cache,
     cut_complexity: f64,
     ciphertext_modulus_log: u32,
@@ -383,7 +383,7 @@ fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks(
     ks_used: &[Vec],
     operations: &OperationsCV,
     feasible: &Feasible,
-    complexity: &ComplexityExpression,
+    complexity: &ComplexityEvaluator,
     caches: &mut keyswitch::Cache,
     cut_complexity: f64,
     best_p_error: f64,
@@ -711,7 +711,7 @@ fn optimize_macro(
     used_tlu_keyswitch: &[Vec],
     used_conversion_keyswitch: &[Vec],
     feasible: &Feasible,
-    complexity: &ComplexityExpression,
+    complexity: &ComplexityEvaluator,
     caches: &mut DecompCaches,
     init_parameters: &Parameters,
     best_complexity: f64,
@@ -735,8 +735,8 @@ fn optimize_macro(
     let fks_to_optimize = fks_to_optimize(nb_partitions, used_conversion_keyswitch, partition);
 
     let operations = OperationsCV {
-        variance: NoiseValues::new(),
-        cost: ComplexityValues::new(),
+        variance: NoiseValues::from_scheme(complexity.scheme()),
+        cost: ComplexityValues::from_scheme(complexity.scheme()),
     };
 
     let partition_feasible = feasible.filter_constraints(partition);
@@ -1024,14 +1024,24 @@ pub fn optimize(
         ciphertext_modulus_log,
     };
 
-    let dag = analyze(dag, &noise_config, p_cut, default_partition)?;
+    let dag_p_cut = p_cut
+        .clone()
+        .or(Some(PartitionCut::for_each_precision(dag)));
+
+    let mut dag = analyze(dag, &noise_config, &dag_p_cut, default_partition)?;
 
     let kappa = error::sigma_scale_of_error_probability(config.maximum_acceptable_error_probability);
 
     let mut caches = persistent_caches.caches();
+    let scheme = SymbolScheme::new(dag.nb_partitions);
+
+    dag.variance_constraints
+        .iter_mut()
+        .for_each(|c| c.init_evaluator(&scheme));
 
     let feasible = Feasible::of(&dag.variance_constraints, kappa, None);
-    let complexity = ComplexityExpression::from(&dag.operations_count);
+
+    let complexity = ComplexityEvaluator::from_scheme_and_counts(&scheme, &dag.operations_count);
 
     let used_tlu_keyswitch = used_tlu_keyswitch(&dag);
     let used_conversion_keyswitch = used_conversion_keyswitch(&dag);
@@ -1223,7 +1233,7 @@ fn sanity_check(
     ciphertext_modulus_log: u32,
     security_level: u64,
     feasible: &Feasible,
-    complexity: &ComplexityExpression,
+    complexity: &ComplexityEvaluator,
 ) {
     assert!(params.is_feasible.is_feasible());
     assert!(
@@ -1232,8 +1242,8 @@ fn sanity_check(
     );
     let nb_partitions = params.macro_params.len();
     let mut operations = OperationsCV {
-        variance: NoiseValues::new(),
-        cost: ComplexityValues::new(),
+        variance: NoiseValues::from_scheme(complexity.scheme()),
+        cost: ComplexityValues::from_scheme(complexity.scheme()),
     };
     let micro_params = &params.micro_params;
     for partition in PartitionIndex::range(0, nb_partitions) {
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic.rs
index c5ae790571..7482f59efc 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic.rs
@@ -1,11 +1,16 @@
 use super::partitions::PartitionIndex;
 use std::{collections::HashMap, fmt::Display};
 
-/// A map associating symbols with values.
+/// A flexible and slow map associating values with symbols.
 ///
 /// By default all symbols are assumed to be associated with the default value
 /// of the type T. In practice, only associations with non-default values are
 /// stored in the map.
+///
+/// Note:
+/// -----
+/// This map is flexible but slow to lookup. Hence it is mostly suited to the
+/// analysis part of the optimizer.
 #[derive(Clone, Debug, PartialEq)]
 pub struct SymbolMap(HashMap);
 
@@ -60,6 +65,12 @@ impl SymbolMap {
     }
 }
 
+impl Default for SymbolMap {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
 impl SymbolMap {
     /// Formats the symbol map with a given separator and symbol prefix.
     pub fn fmt_with(
@@ -82,7 +93,141 @@ impl SymbolMap {
     }
 }
 
-/// A symbol related to an fhe operation.
+/// An indexing scheme for symbol arrays.
+///
+/// Maps each symbol to a linear index in a symbol array.
+#[derive(Clone, Debug, PartialEq)]
+pub struct SymbolScheme(usize);
+
+impl SymbolScheme {
+    /// Creates a new symbol scheme for a given number of partitions.
+    pub fn new(n_partitions: usize) -> Self {
+        SymbolScheme(n_partitions)
+    }
+
+    /// Checks if a symbol is valid.
+    fn has_symbol(&self, sym: &Symbol) -> bool {
+        match sym {
+            Symbol::Input(i) => i.0 < self.0,
+            Symbol::Bootstrap(i) => i.0 < self.0,
+            Symbol::ModulusSwitch(i) => i.0 < self.0,
+            Symbol::Keyswitch(i, j) => i.0 < self.0 && j.0 < self.0,
+            Symbol::FastKeyswitch(i, j) => i.0 < self.0 && j.0 < self.0,
+        }
+    }
+
+    /// Returns the linear index for a given symbol
+    pub fn get_symbol_index(&self, sym: &Symbol) -> usize {
+        debug_assert!(self.has_symbol(sym));
+        match sym {
+            Symbol::Input(i) => i.0,
+            Symbol::Bootstrap(i) => self.0 + i.0,
+            Symbol::ModulusSwitch(i) => self.0 * 2 + i.0,
+            Symbol::Keyswitch(i, j) => self.0 * 3 + i.0 * self.0 + j.0,
+            Symbol::FastKeyswitch(i, j) => self.0 * (3 + self.0) + i.0 * self.0 + j.0,
+        }
+    }
+
+    /// Returns the number of symbols in the scheme.
+    pub fn len(&self) -> usize {
+        self.0 * (3 + 2 * self.0)
+    }
+
+    /// Returns an iterator over valid symbols.
+    pub fn iter(&self) -> impl Iterator<Item = Symbol> + '_ {
+        (0..self.len()).map(|i| {
+            if i < self.0 {
+                Symbol::Input(PartitionIndex(i))
+            } else if i < 2 * self.0 {
+                Symbol::Bootstrap(PartitionIndex(i - self.0))
+            } else if i < 3 * self.0 {
+                Symbol::ModulusSwitch(PartitionIndex(i - 2 * self.0))
+            } else if i < self.0 * (3 + self.0) {
+                let a = i - 3 * self.0;
+                Symbol::Keyswitch(PartitionIndex(a / self.0), PartitionIndex(a % self.0))
+            } else {
+                let a = i - (3 + self.0) * self.0;
+                Symbol::FastKeyswitch(PartitionIndex(a / self.0), PartitionIndex(a % self.0))
+            }
+        })
+    }
+}
+
+/// A rigid and fast map associating values with symbols.
+///
+/// Stores all the possible values for a circuit with a given number of partitions.
+///
+/// Note:
+/// -----
+/// This map is rigid but allows fast lookup and iteration. Hence it is mostly suited to
+/// the optimization part of the optimizer.
+#[derive(Clone, Debug, PartialEq)]
+pub struct SymbolArray {
+    pub(super) scheme: SymbolScheme,
+    pub(super) values: Vec,
+}
+
+impl SymbolArray {
+    /// Creates a new Symbol array from a scheme.
+    pub fn from_scheme(scheme: &SymbolScheme) -> SymbolArray {
+        SymbolArray {
+            scheme: scheme.to_owned(),
+            values: vec![T::default(); scheme.len()],
+        }
+    }
+
+    pub fn from_scheme_and_map(scheme: &SymbolScheme, map: &SymbolMap) -> SymbolArray {
+        let mut output = Self::from_scheme(scheme);
+        map.iter().for_each(|(sym, v)| output.set(&sym, v));
+        output
+    }
+
+    /// Sets the value associated with a given symbol.
+    pub fn set(&mut self, sym: &Symbol, val: T) {
+        self.values[self.scheme.get_symbol_index(sym)] = val;
+    }
+
+    /// Returns the value associated with a given symbol.
+    pub fn get<'a>(&'a self, sym: &Symbol) -> &'a T {
+        &self.values[self.scheme.get_symbol_index(sym)]
+    }
+
+    /// Returns the scheme used for this array.
+    pub fn scheme(&self) -> &SymbolScheme {
+        &self.scheme
+    }
+
+    /// Returns an iterator over value refs.
+    pub fn iter(&self) -> impl Iterator {
+        self.values.iter()
+    }
+
+    /// Returns an iterator over values and associated symbols.
+    pub fn iter_with_sym(&self) -> impl Iterator {
+        self.scheme.iter().zip(self.values.iter())
+    }
+}
+
+impl SymbolArray {
+    /// Formats the symbol array with a given separator and symbol prefix.
+    pub fn fmt_with(
+        &self,
+        f: &mut std::fmt::Formatter<'_>,
+        separator: &str,
+        sym_prefix: &str,
+    ) -> std::fmt::Result {
+        let mut terms = self.iter_with_sym();
+        match terms.next() {
+            Some((sym, val)) => write!(f, "{val}{sym_prefix}{sym}")?,
+            None => return write!(f, "∅"),
+        }
+        for (sym, val) in terms {
+            write!(f, " {separator} {val}{sym_prefix}{sym}")?;
+        }
+        Ok(())
+    }
+}
+
 #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
 pub enum Symbol {
     Input(PartitionIndex),
@@ -132,5 +277,5 @@ pub fn bootstrap(partition: PartitionIndex) -> Symbol {
 /// Returns a modulus switch symbol.
 #[allow(unused)]
 pub fn modulus_switching(partition: PartitionIndex) -> Symbol {
-    Symbol::Bootstrap(partition)
+    Symbol::ModulusSwitch(partition)
 }
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs
index 67e9c40bc3..1acf7a7abb 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs
@@ -4,8 +4,9 @@ use std::fmt;
 
 use super::noise_expression::{
     bootstrap_noise, fast_keyswitch_noise, input_noise, keyswitch_noise, modulus_switching_noise,
-    NoiseExpression,
+    NoiseEvaluator, NoiseExpression,
 };
+use super::symbolic::SymbolScheme;
 
 #[derive(Clone, Debug, PartialEq)]
 pub struct VarianceConstraint {
@@ -15,6 +16,7 @@ pub struct VarianceConstraint {
     pub nb_constraints: u64,
     pub safe_variance_bound: f64,
     pub noise_expression: NoiseExpression,
+    pub noise_evaluator: Option<NoiseEvaluator>,
     pub location: Location,
 }
 
@@ -35,6 +37,13 @@ impl fmt::Display for VarianceConstraint {
 }
 
 impl VarianceConstraint {
+    pub fn init_evaluator(&mut self, scheme: &SymbolScheme) {
+        self.noise_evaluator = Some(NoiseEvaluator::from_scheme_and_expression(
+            scheme,
+            &self.noise_expression,
+        ));
+    }
+
     #[allow(clippy::cast_sign_loss)]
     fn dominance_index(&self) -> u64 {
         let max_coeff = self
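
Editor's note, outside the patch: the sketch below is a stand-alone illustration (not the crate's code) of the linear layout that the new `SymbolScheme`/`SymbolArray` rely on — three per-partition blocks (input, bootstrap, modulus switch) followed by two n-by-n blocks (keyswitch, fast keyswitch), for n * (3 + 2 * n) slots in total. The local `Symbol` and `PartitionIndex` stand-ins and the `index_of` helper are assumptions introduced only for this example.

// Illustrative sketch: mirrors SymbolScheme::get_symbol_index from the patch.
#[derive(Clone, Copy)]
struct PartitionIndex(usize);

enum Symbol {
    Input(PartitionIndex),
    Bootstrap(PartitionIndex),
    ModulusSwitch(PartitionIndex),
    Keyswitch(PartitionIndex, PartitionIndex),
    FastKeyswitch(PartitionIndex, PartitionIndex),
}

// Linear index of a symbol for `n` partitions: three per-partition blocks,
// then two n*n pair blocks, for a total of n * (3 + 2 * n) slots.
fn index_of(n: usize, sym: &Symbol) -> usize {
    match sym {
        Symbol::Input(i) => i.0,
        Symbol::Bootstrap(i) => n + i.0,
        Symbol::ModulusSwitch(i) => 2 * n + i.0,
        Symbol::Keyswitch(i, j) => 3 * n + i.0 * n + j.0,
        Symbol::FastKeyswitch(i, j) => n * (3 + n) + i.0 * n + j.0,
    }
}

fn main() {
    let n = 2;
    // With 2 partitions the array has 2 * (3 + 2 * 2) = 14 slots.
    assert_eq!(index_of(n, &Symbol::Input(PartitionIndex(1))), 1);
    assert_eq!(index_of(n, &Symbol::Bootstrap(PartitionIndex(0))), 2);
    assert_eq!(index_of(n, &Symbol::Keyswitch(PartitionIndex(1), PartitionIndex(0))), 8);
    assert_eq!(index_of(n, &Symbol::FastKeyswitch(PartitionIndex(1), PartitionIndex(1))), 13);
}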