From 15b22918dfb4599d9c51e72ec2cf9f3cff47107c Mon Sep 17 00:00:00 2001 From: rudy Date: Wed, 3 Apr 2024 12:04:32 +0200 Subject: [PATCH] fix(optimizer): bad variance on zero noise input on levelled op --- .../dag/multi_parameters/optimize/tests.rs | 14 ++++++++++++++ .../dag/multi_parameters/symbolic_variance.rs | 3 +++ 2 files changed, 17 insertions(+) diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs index d15dc1dff3..835916e48d 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize/tests.rs @@ -845,3 +845,17 @@ fn test_maximal_multi() { // note: we have a 5% relative margin since dag complexity is slightly better than v0 assert!(sol.complexity < 1.05 * (sol_ref.complexity / expected_speedup)); } + +#[test] +fn test_bug_with_zero_noise() { + let complexity = LevelledComplexity::ZERO; + let out_shape = Shape::number(); + let mut dag = unparametrized::OperationDag::new(); + let v0 = dag.add_input(2, &out_shape); + let v1 = dag.add_levelled_op([v0], complexity, 0.0, &out_shape, "comment"); + let v2 = dag.add_levelled_op([v1], complexity, 1.0, &out_shape, "comment"); + let v3 = dag.add_unsafe_cast(v2, 1); + let _ = dag.add_lut(v3, FunctionTable { values: vec![] }, 1); + let sol = optimize(&dag, &None, 0); + assert!(sol.is_some()); +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs index 892b2c9805..bba2cc357a 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs @@ -157,6 +157,9 @@ impl SymbolicVariance { // replace all current_max by new_coeff // multiply everything else by new_coeff / current_max let mut new = self.clone(); + if current_max == 0.0 { + return new; + } for cell in &mut new.coeffs.values { if *cell == current_max { *cell = new_coeff;