Skip to content

Commit

Permalink
use shapes in tfhers-utils
Browse files Browse the repository at this point in the history
  • Loading branch information
youben11 committed Nov 28, 2024
1 parent 3b8bc13 commit 03f4748
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
2 changes: 1 addition & 1 deletion frontends/concrete-python/examples/tfhers-ml/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ python -c "print(','.join(map(lambda x: str(x << 10), [$(cat $TDIR/result_plaint
## Dequantize values

```sh
../../tests/tfhers-utils/target/release/tfhers_utils dequantize --value=$(cat $TDIR/rescaled_plaintext) --config ./output_quantizer.json
../../tests/tfhers-utils/target/release/tfhers_utils dequantize --value=$(cat $TDIR/rescaled_plaintext) --shape=5,3 --config ./output_quantizer.json
```

## Clean tmpdir
Expand Down
32 changes: 30 additions & 2 deletions frontends/concrete-python/tests/tfhers-utils/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,16 @@ fn main() {
.value_delimiter(',')
.num_args(1..),
)
.arg(
Arg::new("shape")
.short('s')
.long("shape")
.help("shape of values")
.action(ArgAction::Set)
.required(false)
.value_delimiter(',')
.num_args(0..),
)
.arg(
Arg::new("output")
.long("output")
Expand Down Expand Up @@ -476,6 +486,16 @@ fn main() {
.value_delimiter(',')
.num_args(1..),
)
.arg(
Arg::new("shape")
.short('s')
.long("shape")
.help("shape of values")
.action(ArgAction::Set)
.required(false)
.value_delimiter(',')
.num_args(0..),
)
.arg(
Arg::new("output")
.long("output")
Expand Down Expand Up @@ -574,13 +594,17 @@ fn main() {
.get_many::<String>("value")
.unwrap()
.collect();
let shapes: Vec<usize> = match quantize_matches.get_many::<String>("shape") {
Some(shapes) => shapes.into_iter().map(|s| s.parse().unwrap()).collect(),
None => vec![value_str.len()],
};
let config_path = quantize_matches.get_one::<String>("config").unwrap();
let output_path = quantize_matches.get_one::<String>("output");

let quantizer = Quantizer::from_json_file(config_path).unwrap();
let value: Vec<f64> = value_str.iter().map(|v| v.parse().unwrap()).collect();
let quantized_array = quantizer.quantize(
&ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[value.len()]), value).unwrap(),
&ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shapes), value).unwrap(),
);
let quantized_values: Vec<&i64> = quantized_array.iter().collect();
let results_str: Vec<String> = quantized_values.iter().map(|v| v.to_string()).collect();
Expand All @@ -597,13 +621,17 @@ fn main() {
.get_many::<String>("value")
.unwrap()
.collect();
let shapes: Vec<usize> = match dequantize_matches.get_many::<String>("shape") {
Some(shapes) => shapes.into_iter().map(|s| s.parse().unwrap()).collect(),
None => vec![value_str.len()],
};
let config_path = dequantize_matches.get_one::<String>("config").unwrap();
let output_path = dequantize_matches.get_one::<String>("output");

let quantizer = Quantizer::from_json_file(config_path).unwrap();
let value: Vec<i64> = value_str.iter().map(|v| v.parse().unwrap()).collect();
let dequantized_array = quantizer.dequantize(
&ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[value.len()]), value).unwrap(),
&ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shapes), value).unwrap(),
);
let dequantized_values: Vec<&f64> = dequantized_array.iter().collect();
let results_str: Vec<String> =
Expand Down

0 comments on commit 03f4748

Please sign in to comment.