diff --git a/frontends/concrete-python/examples/tfhers-ml/README.md b/frontends/concrete-python/examples/tfhers-ml/README.md index 3d2859ead..f263131a4 100644 --- a/frontends/concrete-python/examples/tfhers-ml/README.md +++ b/frontends/concrete-python/examples/tfhers-ml/README.md @@ -90,7 +90,7 @@ python -c "print(','.join(map(lambda x: str(x << 10), [$(cat $TDIR/result_plaint We need to dequantize integer outputs using a pre-built quantizer for our ML model ```sh -../../tests/tfhers-utils/target/release/tfhers_utils dequantize --value=$(cat $TDIR/rescaled_plaintext) --config ./output_quantizer.json +../../tests/tfhers-utils/target/release/tfhers_utils dequantize --value=$(cat $TDIR/rescaled_plaintext) --shape=5,3 --config ./output_quantizer.json ``` ## Compute error diff --git a/frontends/concrete-python/tests/tfhers-utils/src/main.rs b/frontends/concrete-python/tests/tfhers-utils/src/main.rs index a936df11f..1ea174d6e 100644 --- a/frontends/concrete-python/tests/tfhers-utils/src/main.rs +++ b/frontends/concrete-python/tests/tfhers-utils/src/main.rs @@ -444,6 +444,16 @@ fn main() { .value_delimiter(',') .num_args(1..), ) + .arg( + Arg::new("shape") + .short('s') + .long("shape") + .help("shape of values") + .action(ArgAction::Set) + .required(false) + .value_delimiter(',') + .num_args(0..), + ) .arg( Arg::new("output") .long("output") @@ -476,6 +486,16 @@ fn main() { .value_delimiter(',') .num_args(1..), ) + .arg( + Arg::new("shape") + .short('s') + .long("shape") + .help("shape of values") + .action(ArgAction::Set) + .required(false) + .value_delimiter(',') + .num_args(0..), + ) .arg( Arg::new("output") .long("output") @@ -574,13 +594,17 @@ fn main() { .get_many::("value") .unwrap() .collect(); + let shapes: Vec = match quantize_matches.get_many::("shape") { + Some(shapes) => shapes.into_iter().map(|s| s.parse().unwrap()).collect(), + None => vec![value_str.len()], + }; let config_path = quantize_matches.get_one::("config").unwrap(); let output_path = quantize_matches.get_one::("output"); let quantizer = Quantizer::from_json_file(config_path).unwrap(); let value: Vec = value_str.iter().map(|v| v.parse().unwrap()).collect(); let quantized_array = quantizer.quantize( - &ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[value.len()]), value).unwrap(), + &ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shapes), value).unwrap(), ); let quantized_values: Vec<&i64> = quantized_array.iter().collect(); let results_str: Vec = quantized_values.iter().map(|v| v.to_string()).collect(); @@ -597,13 +621,17 @@ fn main() { .get_many::("value") .unwrap() .collect(); + let shapes: Vec = match dequantize_matches.get_many::("shape") { + Some(shapes) => shapes.into_iter().map(|s| s.parse().unwrap()).collect(), + None => vec![value_str.len()], + }; let config_path = dequantize_matches.get_one::("config").unwrap(); let output_path = dequantize_matches.get_one::("output"); let quantizer = Quantizer::from_json_file(config_path).unwrap(); let value: Vec = value_str.iter().map(|v| v.parse().unwrap()).collect(); let dequantized_array = quantizer.dequantize( - &ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[value.len()]), value).unwrap(), + &ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shapes), value).unwrap(), ); let dequantized_values: Vec<&f64> = dequantized_array.iter().collect(); let results_str: Vec =