Skip to content

Commit

Permalink
add auto rounding to force 8 bits output in ml_inference
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery authored and youben11 committed Nov 18, 2024
1 parent ddf7827 commit b38d49c
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions frontends/concrete-python/examples/tfhers-ml/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
tfhers_int = partial(tfhers.TFHERSInteger, tfhers_type)

rounder = fhe.AutoRounder(target_msbs=8) # We want to keep 8 MSBs

q_weights = np.array(
[
Expand Down Expand Up @@ -66,6 +67,8 @@ def ml_inference(q_X: np.ndarray) -> np.ndarray:
# Quantizing weights and inputs makes an additional term appear in the inference function
y_pred = q_X @ q_weights - weight_quantizer_zero_point * np.sum(q_X, axis=1, keepdims=True)
y_pred += q_bias
y_pred = fhe.round_bit_pattern(y_pred, rounder)
y_pred = (y_pred >> rounder.lsbs_to_remove)
return y_pred


Expand Down Expand Up @@ -110,6 +113,10 @@ def ccompilee():
),
)
]

# Add the auto-adjustment before compilation
fhe.AutoRounder.adjust(compute, inputset)

circuit = compiler.compile(inputset, show_graph=True, show_mlir=True)

tfhers_bridge = tfhers.new_bridge(circuit=circuit)
Expand Down

0 comments on commit b38d49c

Please sign in to comment.