Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(optimizer): fix performance regression #1164

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ impl VariancedDag {
nb_constraints: out_shape.flat_size(),
safe_variance_bound: max_variance,
noise_expression: variance.clone(),
noise_evaluator: None,
location: op.location.clone(),
};
self.external_variance_constraints.push(constraint);
Expand Down Expand Up @@ -273,6 +274,7 @@ impl VariancedDag {
nb_constraints: out_shape.flat_size(),
safe_variance_bound: max_variance,
noise_expression: variance.clone(),
noise_evaluator: None,
location: dag_op.location.clone(),
};
self.external_variance_constraints.push(constraint);
Expand Down Expand Up @@ -646,6 +648,7 @@ fn variance_constraint(
safe_variance_bound,
nb_partitions,
noise_expression: noise,
noise_evaluator: None,
location,
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{fmt, ops::Add};

use super::{
partitions::PartitionIndex,
symbolic::{fast_keyswitch, keyswitch, Symbol, SymbolMap},
symbolic::{fast_keyswitch, keyswitch, Symbol, SymbolArray, SymbolMap, SymbolScheme},
};

/// A structure storing the number of times an fhe operation gets executed in a circuit.
Expand All @@ -29,36 +29,44 @@ impl fmt::Display for OperationsCount {

/// An ensemble of costs associated with fhe operation symbols.
#[derive(Clone, Debug)]
pub struct ComplexityValues(SymbolMap<f64>);
pub struct ComplexityValues(SymbolArray<f64>);

impl ComplexityValues {
/// Returns an empty set of cost values.
pub fn new() -> Self {
ComplexityValues(SymbolMap::new())
pub fn from_scheme(scheme: &SymbolScheme) -> ComplexityValues {
ComplexityValues(SymbolArray::from_scheme(scheme))
}

/// Sets the cost associated with an fhe operation symbol.
pub fn set_cost(&mut self, source: Symbol, value: f64) {
self.0.set(source, value);
self.0.set(&source, value);
}
}

/// A complexity expression is a sum of complexity terms associating operation
/// symbols with the number of time they gets executed in the circuit.
#[derive(Clone, Debug)]
pub struct ComplexityExpression(SymbolMap<usize>);
pub struct ComplexityEvaluator(SymbolArray<usize>);

impl ComplexityExpression {
impl ComplexityEvaluator {
/// Creates a complexity expression from a set of operation counts.
pub fn from(counts: &OperationsCount) -> Self {
Self(counts.0.clone())
pub fn from_scheme_and_counts(
scheme: &SymbolScheme,
counts: &OperationsCount,
) -> ComplexityEvaluator {
Self(SymbolArray::from_scheme_and_map(scheme, &counts.0))
}

pub fn scheme(&self) -> &SymbolScheme {
self.0.scheme()
}

/// Evaluates the total cost expression on a set of cost values.
pub fn evaluate_total_cost(&self, costs: &ComplexityValues) -> f64 {
self.0.iter().fold(0.0, |acc, (symbol, n_ops)| {
acc + (n_ops as f64) * costs.0.get(symbol)
})
self.0
.iter()
.zip(costs.0.iter())
.fold(0.0, |acc, (n_ops, cost)| acc + (*n_ops as f64) * *cost)
}

/// Evaluates the max ks cost expression on a set of cost values.
Expand All @@ -69,11 +77,11 @@ impl ComplexityExpression {
src_partition: PartitionIndex,
dst_partition: PartitionIndex,
) -> f64 {
let actual_ks_cost = costs.0.get(keyswitch(src_partition, dst_partition));
let ks_coeff = self.0.get(keyswitch(src_partition, dst_partition));
let actual_ks_cost = costs.0.get(&keyswitch(src_partition, dst_partition));
let ks_coeff = self.0.get(&keyswitch(src_partition, dst_partition));
let actual_complexity =
self.evaluate_total_cost(costs) - (ks_coeff as f64) * actual_ks_cost;
(complexity_cut - actual_complexity) / (ks_coeff as f64)
self.evaluate_total_cost(costs) - (*ks_coeff as f64) * actual_ks_cost;
(complexity_cut - actual_complexity) / (*ks_coeff as f64)
}

/// Evaluates the max fks cost expression on a set of cost values.
Expand All @@ -84,10 +92,10 @@ impl ComplexityExpression {
src_partition: PartitionIndex,
dst_partition: PartitionIndex,
) -> f64 {
let actual_fks_cost = costs.0.get(fast_keyswitch(src_partition, dst_partition));
let fks_coeff = self.0.get(fast_keyswitch(src_partition, dst_partition));
let actual_fks_cost = costs.0.get(&fast_keyswitch(src_partition, dst_partition));
let fks_coeff = self.0.get(&fast_keyswitch(src_partition, dst_partition));
let actual_complexity =
self.evaluate_total_cost(costs) - (fks_coeff as f64) * actual_fks_cost;
(complexity_cut - actual_complexity) / (fks_coeff as f64)
self.evaluate_total_cost(costs) - (*fks_coeff as f64) * actual_fks_cost;
(complexity_cut - actual_complexity) / (*fks_coeff as f64)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,18 @@ impl Feasible {

for constraint in &self.undominated_constraints {
let pbs_coeff = constraint
.noise_expression
.noise_evaluator
.as_ref()
.unwrap()
.coeff(bootstrap_noise(partition));
if pbs_coeff == 0.0 {
continue;
}
let actual_variance = constraint.noise_expression.evaluate(operations_variance)
let actual_variance = constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance)
- pbs_coeff * actual_pbs_variance;
let pbs_max_variance = (constraint.safe_variance_bound - actual_variance) / pbs_coeff;
smallest_pbs_max_variance = smallest_pbs_max_variance.min(pbs_max_variance);
Expand All @@ -75,12 +81,18 @@ impl Feasible {

for constraint in &self.undominated_constraints {
let ks_coeff = constraint
.noise_expression
.noise_evaluator
.as_ref()
.unwrap()
.coeff(keyswitch_noise(src_partition, dst_partition));
if ks_coeff == 0.0 {
continue;
}
let actual_variance = constraint.noise_expression.evaluate(operations_variance)
let actual_variance = constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance)
- ks_coeff * actual_ks_variance;
let ks_max_variance = (constraint.safe_variance_bound - actual_variance) / ks_coeff;
smallest_ks_max_variance = smallest_ks_max_variance.min(ks_max_variance);
Expand All @@ -102,12 +114,18 @@ impl Feasible {

for constraint in &self.undominated_constraints {
let fks_coeff = constraint
.noise_expression
.noise_evaluator
.as_ref()
.unwrap()
.coeff(fast_keyswitch_noise(src_partition, dst_partition));
if fks_coeff == 0.0 {
continue;
}
let actual_variance = constraint.noise_expression.evaluate(operations_variance)
let actual_variance = constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance)
- fks_coeff * actual_fks_variance;
let fks_max_variance = (constraint.safe_variance_bound - actual_variance) / fks_coeff;
smallest_fks_max_variance = smallest_fks_max_variance.min(fks_max_variance);
Expand All @@ -126,7 +144,11 @@ impl Feasible {

fn local_feasible(&self, operations_variance: &NoiseValues) -> bool {
for constraint in &self.undominated_constraints {
if constraint.noise_expression.evaluate(operations_variance)
if constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance)
> constraint.safe_variance_bound
{
return false;
Expand All @@ -148,7 +170,11 @@ impl Feasible {
let mut worst_relative_variance = 0.0;
let mut worst_variance = 0.0;
for constraint in &self.undominated_constraints {
let variance = constraint.noise_expression.evaluate(operations_variance);
let variance = constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance);
let relative_variance = variance / constraint.safe_variance_bound;
if relative_variance > worst_relative_variance {
worst_relative_variance = relative_variance;
Expand All @@ -167,7 +193,11 @@ impl Feasible {
fn global_p_error_with_cut(&self, operations_variance: &NoiseValues, cut: f64) -> Option<f64> {
let mut global_p_error = 0.0;
for constraint in &self.constraints {
let variance = constraint.noise_expression.evaluate(operations_variance);
let variance = constraint
.noise_evaluator
.as_ref()
.unwrap()
.evaluate(operations_variance);
let relative_variance = variance / constraint.safe_variance_bound;
let p_error = p_error_from_relative_variance(relative_variance, self.kappa);
global_p_error = combine_errors(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,27 @@ use std::{

use super::{
partitions::PartitionIndex,
symbolic::{Symbol, SymbolMap},
symbolic::{Symbol, SymbolArray, SymbolMap, SymbolScheme},
};

/// An ensemble of noise values for fhe operations.
#[derive(Debug, Clone, PartialEq)]
pub struct NoiseValues(SymbolMap<f64>);
pub struct NoiseValues(SymbolArray<f64>);

impl NoiseValues {
/// Returns an empty set of noise values.
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
NoiseValues(SymbolMap::new())
pub fn from_scheme(scheme: &SymbolScheme) -> NoiseValues {
NoiseValues(SymbolArray::from_scheme(scheme))
}

/// Sets the noise variance associated with a noise source.
pub fn set_variance(&mut self, source: NoiseSource, value: f64) {
self.0.set(source.0, value);
self.0.set(&source.0, value);
}

/// Returns the variance associated with a noise source
pub fn variance(&self, source: NoiseSource) -> f64 {
self.0.get(source.0)
*self.0.get(&source.0)
}
}

Expand All @@ -36,10 +35,36 @@ impl Display for NoiseValues {
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct NoiseEvaluator(SymbolArray<f64>);

impl NoiseEvaluator {
/// Returns a zero noise expression
pub fn from_scheme_and_expression(
scheme: &SymbolScheme,
expr: &NoiseExpression,
) -> NoiseEvaluator {
NoiseEvaluator(SymbolArray::from_scheme_and_map(scheme, &expr.0))
}

/// Returns the coefficient associated with a noise source.
pub fn coeff(&self, source: NoiseSource) -> f64 {
*self.0.get(&source.0)
}

/// Evaluate the noise expression on a set of noise values.
pub fn evaluate(&self, values: &NoiseValues) -> f64 {
self.0
.iter()
.zip(values.0.iter())
.fold(0.0, |acc, (coef, var)| acc + coef * var)
}
}

/// A noise expression, i.e. a sum of noise terms associating a noise source,
/// with a multiplicative coefficient.
#[derive(Debug, Clone, PartialEq)]
pub struct NoiseExpression(SymbolMap<f64>);
pub struct NoiseExpression(pub SymbolMap<f64>);

impl NoiseExpression {
/// Returns a zero noise expression
Expand Down Expand Up @@ -70,12 +95,12 @@ impl NoiseExpression {
lhs
}

/// Evaluate the noise expression on a set of noise values.
pub fn evaluate(&self, values: &NoiseValues) -> f64 {
self.terms_iter().fold(0.0, |acc, term| {
acc + term.coefficient * values.variance(term.source)
})
}
// /// Evaluate the noise expression on a set of noise values.
// pub fn evaluate(&self, values: &NoiseValues) -> f64 {
// self.terms_iter().fold(0.0, |acc, term| {
// acc + term.coefficient * values.variance(term.source)
// })
// }
}

impl Display for NoiseExpression {
Expand Down Expand Up @@ -196,7 +221,7 @@ impl Mul<NoiseSource> for f64 {

/// A symbolic source of noise, or a noise source variable.
#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub struct NoiseSource(Symbol);
pub struct NoiseSource(pub Symbol);

/// Returns an input noise source symbol.
pub fn input_noise(partition: PartitionIndex) -> NoiseSource {
Expand Down
Loading
Loading