From 26066a9c9abe1a694b59fd24551dc44240a46e8d Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 14 Nov 2024 14:10:40 +0100 Subject: [PATCH] docs(frontend): TFHE-rs interop example using ML model Co-authored-by: jfrery --- .../examples/tfhers-ml/README.md | 108 +++++++++++ .../examples/tfhers-ml/compute_error.py | 63 ++++++ .../examples/tfhers-ml/example.py | 179 ++++++++++++++++++ .../examples/tfhers-ml/input_quantizer.json | 1 + .../examples/tfhers-ml/output_quantizer.json | 1 + .../examples/tfhers-ml/test.sh | 8 + .../examples/tfhers-ml/test_values.json | 120 ++++++++++++ .../examples/tfhers-ml/tfhers_params.json | 1 + .../tests/tfhers-utils/Cargo.lock | 68 +++++++ .../tests/tfhers-utils/Cargo.toml | 3 + .../tests/tfhers-utils/src/main.rs | 113 ++++++++++- 11 files changed, 664 insertions(+), 1 deletion(-) create mode 100644 frontends/concrete-python/examples/tfhers-ml/README.md create mode 100644 frontends/concrete-python/examples/tfhers-ml/compute_error.py create mode 100644 frontends/concrete-python/examples/tfhers-ml/example.py create mode 100644 frontends/concrete-python/examples/tfhers-ml/input_quantizer.json create mode 100644 frontends/concrete-python/examples/tfhers-ml/output_quantizer.json create mode 100755 frontends/concrete-python/examples/tfhers-ml/test.sh create mode 100644 frontends/concrete-python/examples/tfhers-ml/test_values.json create mode 100644 frontends/concrete-python/examples/tfhers-ml/tfhers_params.json diff --git a/frontends/concrete-python/examples/tfhers-ml/README.md b/frontends/concrete-python/examples/tfhers-ml/README.md new file mode 100644 index 0000000000..3d2859ead0 --- /dev/null +++ b/frontends/concrete-python/examples/tfhers-ml/README.md @@ -0,0 +1,108 @@ +# TFHE-rs Interoperability Example + +This is a similar example to the [first TFHE-rs example](../tfhers/), except that it uses tensors and run a linear ML model. It also uses quantization. + +## Make tmpdir + +We want to setup a temporary working directory first + +```sh +export TDIR=`mktemp -d` +``` + +## KeyGen + +First we need to build the TFHE-rs utility in [this directory](../../tests/tfhers-utils/) by running + +```sh +cd ../../tests/tfhers-utils/ +make build +cd - +``` + +Then we can generate keys in two different ways. You only need to run one of the following methods + +#### Generate the Secret Key in Concrete + +We start by doing keygen in Concrete + +```sh +python example.py keygen -o $TDIR/concrete_sk -k $TDIR/concrete_keyset +``` + +Then we do a partial keygen in TFHE-rs + +```sh +../../tests/tfhers-utils/target/release/tfhers_utils keygen --lwe-sk $TDIR/concrete_sk --output-lwe-sk $TDIR/tfhers_sk -c $TDIR/tfhers_client_key -s $TDIR/tfhers_server_key +``` + +#### Generate the Secret Key in TFHE-rs + +We start by doing keygen in TFHE-rs + +```sh +../../tests/tfhers-utils/target/release/tfhers_utils keygen --output-lwe-sk $TDIR/tfhers_sk -c $TDIR/tfhers_client_key -s $TDIR/tfhers_server_key +``` + +Then we do a partial keygen in Concrete. + +```sh +python example.py keygen -s $TDIR/tfhers_sk -o $TDIR/concrete_sk -k $TDIR/concrete_keyset +``` + +## Quantize values + +We need to quantize floating point inputs using a pre-built quantizer for our ML model + +```sh +../../tests/tfhers-utils/target/release/tfhers_utils quantize --value=5.1,3.5,1.4,0.2,4.9,3,1.4,0.2,4.7,3.2,1.3,0.2,4.6,3.1,1.5,0.2,5,3.6,1.4,0.2 --config ./input_quantizer.json -o $TDIR/quantized_values +``` + +## Encrypt in TFHE-rs + +```sh +../../tests/tfhers-utils/target/release/tfhers_utils encrypt-with-key --signed --value=$(cat $TDIR/quantized_values) --ciphertext $TDIR/tfhers_ct --client-key $TDIR/tfhers_client_key +``` + +## Run in Concrete + +```sh +python example.py run -k $TDIR/concrete_keyset -c $TDIR/tfhers_ct -o $TDIR/tfhers_ct_out +``` + +## Decrypt in TFHE-rs + +```sh +../../tests/tfhers-utils/target/release/tfhers_utils decrypt-with-key --tensor --signed --ciphertext $TDIR/tfhers_ct_out --client-key $TDIR/tfhers_client_key --plaintext $TDIR/result_plaintext +``` + +## Rescale Output + +At the end of the circuit, we are rounding the result to 8 bits, discarding the remaining LSB bits. As we have `lsbs_to_remove=10` we are re-introducing the 10 bits of LSB. + +```sh +python -c "print(','.join(map(lambda x: str(x << 10), [$(cat $TDIR/result_plaintext)])))" > $TDIR/rescaled_plaintext +``` + + +## Dequantize values + +We need to dequantize integer outputs using a pre-built quantizer for our ML model + +```sh +../../tests/tfhers-utils/target/release/tfhers_utils dequantize --value=$(cat $TDIR/rescaled_plaintext) --config ./output_quantizer.json +``` + +## Compute error + +We compare the output to the expected result + +```sh +python compute_error.py --plaintext-file "$TDIR/rescaled_plaintext" --quantized-predictions-file "test_values.json" +``` + +## Clean tmpdir + +```sh +rm -rf $TDIR +``` diff --git a/frontends/concrete-python/examples/tfhers-ml/compute_error.py b/frontends/concrete-python/examples/tfhers-ml/compute_error.py new file mode 100644 index 0000000000..99c9e407db --- /dev/null +++ b/frontends/concrete-python/examples/tfhers-ml/compute_error.py @@ -0,0 +1,63 @@ +import json + +import click + + +@click.command() +@click.option( + "--plaintext-file", "-p", required=True, help="Path to the rescaled plaintext values file." +) +@click.option( + "--quantized-predictions-file", + "-q", + required=True, + help="Path to the test_values.json file containing quantized predictions.", +) +def compute_error(plaintext_file, quantized_predictions_file): + """Compute the error between decrypted rescaled values and quantized predictions.""" + # Read rescaled plaintext values from plaintext_file + with open(plaintext_file) as f: + rescaled_plaintext_values = [int(x) for x in f.read().strip().split(",")] + + # Read quantized_predictions from quantized_predictions_file + with open(quantized_predictions_file) as f: + data = json.load(f) + quantized_predictions = data["quantized_predictions"] + + # Flatten quantized_predictions + quantized_predictions_flat = [int(x) for sublist in quantized_predictions for x in sublist] + + # Round down 10 bits using (x // (1 << 10)) * (1 << 10) + rounded_quantized_predictions = [ + (x // (1 << 10)) * (1 << 10) for x in quantized_predictions_flat + ] + + # Compare rescaled_plaintext_values with rounded_quantized_predictions + num_differences = 0 + total_values = len(rescaled_plaintext_values) + errors = [] + for i in range(total_values): + a = rescaled_plaintext_values[i] + b = rounded_quantized_predictions[i] + print(f"output: {a}, expected: {b}") + if a != b: + num_differences += 1 + error_in_units = round((a - b) / (1 << 10)) + errors.append((i, error_in_units)) + + print("Number of differing values: {}".format(num_differences)) + print("Total values compared: {}".format(total_values)) + if num_differences > 0: + print("Differences (index, error in units of 2^10):") + for idx, error_in_units in errors: + print("Index {}: error = {}".format(idx, error_in_units)) + + # success is when we don't offset by more than 1 + for error in errors: + if error[1] > 1: + return 1 + return 0 + + +if __name__ == "__main__": + compute_error() diff --git a/frontends/concrete-python/examples/tfhers-ml/example.py b/frontends/concrete-python/examples/tfhers-ml/example.py new file mode 100644 index 0000000000..bd43d1ce1b --- /dev/null +++ b/frontends/concrete-python/examples/tfhers-ml/example.py @@ -0,0 +1,179 @@ +import os +import typing +from functools import partial + +import click +import numpy as np + +import concrete.fhe as fhe +from concrete.fhe import tfhers + +### Options ########################### +# These parameters were saved by running the tfhers_utils utility: +# tfhers_utils save-params tfhers_params.json +TFHERS_PARAMS_FILE = "tfhers_params.json" +FHEUINT_PRECISION = 8 +IS_SIGNED = True +####################################### + +tfhers_type = tfhers.get_type_from_params( + TFHERS_PARAMS_FILE, + is_signed=IS_SIGNED, + precision=FHEUINT_PRECISION, +) +tfhers_int = partial(tfhers.TFHERSInteger, tfhers_type) + +#### Model Parameters ################## +q_weights = np.array([[-25, 21, -10], [42, -20, -37], [-128, -15, 127], [-58, -51, 94]]) +q_bias = np.array([[35167, 9417, -44584]]) +weight_quantizer_zero_point = -5 +######################################## + +rounder = fhe.AutoRounder(target_msbs=8) # We want to keep 8 MSBs + + +@typing.no_type_check +def ml_inference(q_X: np.ndarray) -> np.ndarray: + y_pred = q_X @ q_weights - weight_quantizer_zero_point * np.sum(q_X, axis=1, keepdims=True) + y_pred += q_bias + y_pred = fhe.round_bit_pattern(y_pred, rounder) + y_pred = y_pred >> rounder.lsbs_to_remove + return y_pred + + +def compute(tfhers_x): + ####### TFHE-rs to Concrete ######### + + # x and y are supposed to be TFHE-rs values. + # to_native will use type information from x and y to do + # a correct conversion from TFHE-rs to Concrete + concrete_x = tfhers.to_native(tfhers_x) + ####### TFHE-rs to Concrete ######### + + ####### Concrete Computation ######## + concrete_res = ml_inference(concrete_x) + ####### Concrete Computation ######## + + ####### Concrete to TFHE-rs ######### + tfhers_res = tfhers.from_native( + concrete_res, tfhers_type + ) # we have to specify the type we want to convert to + ####### Concrete to TFHE-rs ######### + return tfhers_res + + +def ccompilee(): + compiler = fhe.Compiler( + compute, + { + "tfhers_x": "encrypted", + }, + ) + + inputset = [ + ( + tfhers_int( + np.array( + [ + [36, -17, -85, -124], + [29, -33, -85, -124], + [23, -26, -88, -124], + [19, -30, -82, -124], + [32, -13, -85, -124], + ] + ) + ), + ) + ] + + # Add the auto-adjustment before compilation + fhe.AutoRounder.adjust(compute, inputset) + + # Print the number of bits rounded + print(f"lsbs_to_remove: {rounder.lsbs_to_remove}") + + circuit = compiler.compile(inputset) + + tfhers_bridge = tfhers.new_bridge(circuit=circuit) + return circuit, tfhers_bridge + + +@click.group() +def cli(): + pass + + +@cli.command() +@click.option("-s", "--secret-key", type=str, required=False) +@click.option("-o", "--output-secret-key", type=str, required=True) +@click.option("-k", "--concrete-keyset-path", type=str, required=True) +def keygen(output_secret_key: str, secret_key: str, concrete_keyset_path: str): + """Concrete Key Generation""" + + circuit, tfhers_bridge = ccompilee() + + if os.path.exists(concrete_keyset_path): + print(f"removing old keyset at '{concrete_keyset_path}'") + os.remove(concrete_keyset_path) + + if secret_key: + print(f"partial keygen from sk at '{secret_key}'") + # load the initial secret key to use for keygen + with open( + secret_key, + "rb", + ) as f: + buff = f.read() + input_idx_to_key = {0: buff} + tfhers_bridge.keygen_with_initial_keys(input_idx_to_key_buffer=input_idx_to_key) + else: + print("full keygen") + circuit.keygen() + + print("saving Concrete keyset") + circuit.client.keys.save(concrete_keyset_path) + print(f"saved Concrete keyset to '{concrete_keyset_path}'") + + sk: bytes = tfhers_bridge.serialize_input_secret_key(input_idx=0) + print(f"writing secret key of size {len(sk)} to '{output_secret_key}'") + with open(output_secret_key, "wb") as f: + f.write(sk) + + +@cli.command() +@click.option("-c", "--rust-ct", type=str, required=True) +@click.option("-o", "--output-rust-ct", type=str, required=False) +@click.option("-k", "--concrete-keyset-path", type=str, required=True) +def run(rust_ct: str, output_rust_ct: str, concrete_keyset_path: str): + """Run circuit""" + circuit, tfhers_bridge = ccompilee() + + if not os.path.exists(concrete_keyset_path): + raise RuntimeError("cannot find keys, you should run keygen before") + print(f"loading keys from '{concrete_keyset_path}'") + circuit.client.keys.load(concrete_keyset_path) + + # read tfhers int from file + with open(rust_ct, "rb") as f: + buff = f.read() + # import fheuint8 and get its description + tfhers_uint8_x = tfhers_bridge.import_value(buff, input_idx=0) + + print("Homomorphic evaluation...") + encrypted_result = circuit.run(tfhers_uint8_x) + + if output_rust_ct: + print("exporting Rust ciphertexts") + # export fheuint8 + buff = tfhers_bridge.export_value(encrypted_result, output_idx=0) + # write it to file + with open(output_rust_ct, "wb") as f: + f.write(buff) + else: + result = circuit.decrypt(encrypted_result) + decoded = tfhers_type.decode(result) + print(f"Concrete decryption result: raw({result}), decoded({decoded})") + + +if __name__ == "__main__": + cli() diff --git a/frontends/concrete-python/examples/tfhers-ml/input_quantizer.json b/frontends/concrete-python/examples/tfhers-ml/input_quantizer.json new file mode 100644 index 0000000000..0daa127a7a --- /dev/null +++ b/frontends/concrete-python/examples/tfhers-ml/input_quantizer.json @@ -0,0 +1 @@ +{"type_name": "UniformQuantizer", "serialized_value": {"n_bits": 8, "is_signed": true, "is_symmetric": false, "is_qat": false, "is_narrow": false, "is_precomputed_qat": false, "rmax": {"type_name": "numpy_float", "serialized_value": 7.9, "dtype": "float64"}, "rmin": {"type_name": "numpy_float", "serialized_value": 0.1, "dtype": "float64"}, "scale": {"type_name": "numpy_float", "serialized_value": 0.03058823529411765, "dtype": "float64"}, "zero_point": {"type_name": "numpy_integer", "serialized_value": -131, "dtype": "int64"}, "offset": 128, "no_clipping": false}} diff --git a/frontends/concrete-python/examples/tfhers-ml/output_quantizer.json b/frontends/concrete-python/examples/tfhers-ml/output_quantizer.json new file mode 100644 index 0000000000..dc97e78618 --- /dev/null +++ b/frontends/concrete-python/examples/tfhers-ml/output_quantizer.json @@ -0,0 +1 @@ +{"type_name": "UniformQuantizer", "serialized_value": {"is_signed": false, "is_symmetric": false, "is_qat": false, "is_narrow": false, "is_precomputed_qat": false, "rmax": null, "rmin": null, "scale": {"type_name": "numpy_float", "serialized_value": 0.0006288117860507253, "dtype": "float64"}, "zero_point": {"type_name": "numpy_array", "serialized_value": [[39038, 11790, -50828]], "dtype": "int64"}, "offset": 0, "no_clipping": true}} diff --git a/frontends/concrete-python/examples/tfhers-ml/test.sh b/frontends/concrete-python/examples/tfhers-ml/test.sh new file mode 100755 index 0000000000..dd0a9eebab --- /dev/null +++ b/frontends/concrete-python/examples/tfhers-ml/test.sh @@ -0,0 +1,8 @@ +#!/bin/sh + +# This file tests that the example is working + +shell_blocks=$(sed -n '/^```sh/,/^```/ p' < README.md | sed '/^```sh/d' | sed '/^```/d') + +set -e +output=$(eval "$shell_blocks" 2>&1) || echo "$output" diff --git a/frontends/concrete-python/examples/tfhers-ml/test_values.json b/frontends/concrete-python/examples/tfhers-ml/test_values.json new file mode 100644 index 0000000000..eab37cbbe6 --- /dev/null +++ b/frontends/concrete-python/examples/tfhers-ml/test_values.json @@ -0,0 +1,120 @@ +{ + "input_fp32_values": [ + [ + 5.1, + 3.5, + 1.4, + 0.2 + ], + [ + 4.9, + 3.0, + 1.4, + 0.2 + ], + [ + 4.7, + 3.2, + 1.3, + 0.2 + ], + [ + 4.6, + 3.1, + 1.5, + 0.2 + ], + [ + 5.0, + 3.6, + 1.4, + 0.2 + ] + ], + "input_quantized_values": [ + [ + 36, + -17, + -85, + -124 + ], + [ + 29, + -33, + -85, + -124 + ], + [ + 23, + -26, + -88, + -124 + ], + [ + 19, + -30, + -82, + -124 + ], + [ + 32, + -13, + -85, + -124 + ] + ], + "quantized_predictions": [ + [ + 50675, + 17162, + -67716 + ], + [ + 50063, + 17220, + -67169 + ], + [ + 50881, + 16989, + -67759 + ], + [ + 50035, + 16885, + -66819 + ], + [ + 50943, + 16998, + -67824 + ] + ], + "dequantized_predictions": [ + [ + 7.317482754272291, + 3.3779769146644965, + -10.61937344282465 + ], + [ + 6.932649941209247, + 3.4144479982554388, + -10.275413395854903 + ], + [ + 7.44701798219874, + 3.269192475677721, + -10.64641234962483 + ], + [ + 6.915043211199826, + 3.2037960499284455, + -10.055329270737149 + ], + [ + 7.486004312933885, + 3.2748517817521776, + -10.687285115718128 + ] + ] +} diff --git a/frontends/concrete-python/examples/tfhers-ml/tfhers_params.json b/frontends/concrete-python/examples/tfhers-ml/tfhers_params.json new file mode 100644 index 0000000000..bf62c959f2 --- /dev/null +++ b/frontends/concrete-python/examples/tfhers-ml/tfhers_params.json @@ -0,0 +1 @@ +{"lwe_dimension":902,"glwe_dimension":1,"polynomial_size":4096,"lwe_noise_distribution":{"Gaussian":{"std":1.0994794733558207e-6,"mean":0.0}},"glwe_noise_distribution":{"Gaussian":{"std":2.168404344971009e-19,"mean":0.0}},"pbs_base_log":15,"pbs_level":2,"ks_base_log":3,"ks_level":6,"message_modulus":4,"carry_modulus":8,"max_noise_level":10,"log2_p_fail":-64.084,"ciphertext_modulus":{"modulus":0,"scalar_bits":64},"encryption_key_choice":"Big"} diff --git a/frontends/concrete-python/tests/tfhers-utils/Cargo.lock b/frontends/concrete-python/tests/tfhers-utils/Cargo.lock index e07be46402..a90f9c843b 100644 --- a/frontends/concrete-python/tests/tfhers-utils/Cargo.lock +++ b/frontends/concrete-python/tests/tfhers-utils/Cargo.lock @@ -205,6 +205,16 @@ dependencies = [ "pulp", ] +[[package]] +name = "concrete_quantizer" +version = "0.1.0" +source = "git+https://github.com/zama-ai/concrete-ml-processing-rs#0e0ba9a036cc24fa38dad199230f74be7fdcfd3f" +dependencies = [ + "ndarray", + "serde", + "serde_json", +] + [[package]] name = "cpufeatures" version = "0.2.13" @@ -368,12 +378,38 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", + "serde", +] + [[package]] name = "num-complex" version = "0.4.6" @@ -385,6 +421,15 @@ dependencies = [ "serde", ] +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -406,6 +451,21 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "portable-atomic" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" + +[[package]] +name = "portable-atomic-util" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90a7d5beecc52a491b54d6dd05c7a45ba1801666a5baad9fdbfc6fef8d2d206c" +dependencies = [ + "portable-atomic", +] + [[package]] name = "proc-macro2" version = "1.0.86" @@ -445,6 +505,12 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.10.0" @@ -588,6 +654,8 @@ version = "0.1.0" dependencies = [ "bincode", "clap", + "concrete_quantizer", + "ndarray", "serde", "serde_json", "tfhe", diff --git a/frontends/concrete-python/tests/tfhers-utils/Cargo.toml b/frontends/concrete-python/tests/tfhers-utils/Cargo.toml index af0a7f67a8..6b13a49d88 100644 --- a/frontends/concrete-python/tests/tfhers-utils/Cargo.toml +++ b/frontends/concrete-python/tests/tfhers-utils/Cargo.toml @@ -11,6 +11,9 @@ serde_json = "1.0.128" clap = { version = "4.5.16", features = ["derive"] } +concrete_quantizer = { git = "https://github.com/zama-ai/concrete-ml-processing-rs" } +ndarray = "0.16.1" + tfhe = { version = "0.8.6", features = ["integer"] } [target.x86_64-unknown-linux-gnu.dependencies] diff --git a/frontends/concrete-python/tests/tfhers-utils/src/main.rs b/frontends/concrete-python/tests/tfhers-utils/src/main.rs index 546a1ffae0..a936df11ff 100644 --- a/frontends/concrete-python/tests/tfhers-utils/src/main.rs +++ b/frontends/concrete-python/tests/tfhers-utils/src/main.rs @@ -5,6 +5,7 @@ use std::path::Path; use clap::{Arg, ArgAction, Command}; +use concrete_quantizer::Quantizer; use tfhe::core_crypto::prelude::LweSecretKey; use tfhe::named::Named; use tfhe::prelude::*; @@ -17,7 +18,6 @@ use tfhe::{ use serde::de::DeserializeOwned; use serde::Serialize; -use serde_json; const BLOCK_PARAMS: ClassicPBSParameters = tfhe::shortint::prelude::PARAM_MESSAGE_2_CARRY_3_KS_PBS; const SERIALIZE_SIZE_LIMIT: u64 = 1_000_000_000; @@ -421,6 +421,70 @@ fn main() { .required(true), ), ) + .subcommand( + Command::new("quantize") + .long_flag("quantize") + .about("Quantize float values into integers.") + .arg( + Arg::new("config") + .short('c') + .long("config") + .help("quantizer configuration") + .required(true) + .action(ArgAction::Set) + .num_args(1), + ) + .arg( + Arg::new("value") + .short('v') + .long("value") + .help("value(s) to quantize") + .action(ArgAction::Set) + .required(true) + .value_delimiter(',') + .num_args(1..), + ) + .arg( + Arg::new("output") + .long("output") + .short('o') + .help("where to write quantized output") + .action(ArgAction::Set) + .num_args(1), + ), + ) + .subcommand( + Command::new("dequantize") + .long_flag("dequantize") + .about("Dequantize integers values into floats.") + .arg( + Arg::new("config") + .short('c') + .long("config") + .help("quantizer configuration") + .required(true) + .action(ArgAction::Set) + .num_args(1), + ) + .arg( + Arg::new("value") + .short('v') + .long("value") + .help("value(s) to quantize") + .action(ArgAction::Set) + .required(true) + .value_delimiter(',') + .num_args(1..), + ) + .arg( + Arg::new("output") + .long("output") + .short('o') + .help("where to write dequantized output") + .action(ArgAction::Set) + .num_args(1), + ), + ) .get_matches(); match matches.subcommand() { @@ -505,6 +569,53 @@ fn main() { let mut file = fs::File::create(path).unwrap(); file.write_all(json_string.as_bytes()).unwrap(); } + Some(("quantize", quantize_matches)) => { + let value_str: Vec<&String> = quantize_matches + .get_many::("value") + .unwrap() + .collect(); + let config_path = quantize_matches.get_one::("config").unwrap(); + let output_path = quantize_matches.get_one::("output"); + + let quantizer = Quantizer::from_json_file(config_path).unwrap(); + let value: Vec = value_str.iter().map(|v| v.parse().unwrap()).collect(); + let quantized_array = quantizer.quantize( + &ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[value.len()]), value).unwrap(), + ); + let quantized_values: Vec<&i64> = quantized_array.iter().collect(); + let results_str: Vec = quantized_values.iter().map(|v| v.to_string()).collect(); + let string_result = results_str.join(","); + if let Some(path) = output_path { + let pt_path: &Path = Path::new(path); + fs::write(pt_path, string_result).unwrap(); + } else { + println!("quantized: {}", string_result); + } + } + Some(("dequantize", dequantize_matches)) => { + let value_str: Vec<&String> = dequantize_matches + .get_many::("value") + .unwrap() + .collect(); + let config_path = dequantize_matches.get_one::("config").unwrap(); + let output_path = dequantize_matches.get_one::("output"); + + let quantizer = Quantizer::from_json_file(config_path).unwrap(); + let value: Vec = value_str.iter().map(|v| v.parse().unwrap()).collect(); + let dequantized_array = quantizer.dequantize( + &ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&[value.len()]), value).unwrap(), + ); + let dequantized_values: Vec<&f64> = dequantized_array.iter().collect(); + let results_str: Vec = + dequantized_values.iter().map(|v| v.to_string()).collect(); + let string_result = results_str.join(","); + if let Some(path) = output_path { + let pt_path: &Path = Path::new(path); + fs::write(pt_path, string_result).unwrap(); + } else { + println!("dequantized: {}", string_result); + } + } _ => unreachable!(), // If all subcommands are defined above, anything else is unreachable } }