Skip to content

Commit

Permalink
chore: fix pcc
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery committed Apr 17, 2024
1 parent 33e1cc8 commit 0dee236
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 15 deletions.
10 changes: 2 additions & 8 deletions benchmarks/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,14 @@ def get_preprocessor() -> ColumnTransformer:

def get_train_test_data(data: pandas.DataFrame) -> Tuple[pandas.DataFrame, pandas.DataFrame]:
"""Split the data into a train and test set."""
(
train_data,
test_data,
) = train_test_split(
(train_data, test_data,) = train_test_split(
data,
test_size=0.2,
random_state=0,
)

# The test set is reduced for faster FHE runs.
(
_,
test_data,
) = train_test_split(
(_, test_data,) = train_test_split(
test_data,
test_size=500,
random_state=0,
Expand Down
4 changes: 3 additions & 1 deletion src/concrete/ml/pytest/torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,9 @@ def __init__(self, input_output, activation_function, n_bits=2, disable_bit_chec
n_bits_weights = n_bits

# Generate the pattern 0, 1, ..., 2^N-1, 0, 1, .. 2^N-1, 0, 1..
all_weights = numpy.mod(numpy.arange(numpy.prod(self.fc1.weight.shape)), 2**n_bits_weights)
all_weights = numpy.mod(
numpy.arange(numpy.prod(self.fc1.weight.shape)), 2**n_bits_weights
)

# Shuffle the pattern and reshape to weight shape
numpy.random.shuffle(all_weights)
Expand Down
12 changes: 6 additions & 6 deletions src/concrete/ml/quantization/base_quantized_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,14 +237,14 @@ def dump_dict(self) -> Dict:
metadata["_input_idx_to_params_name"] = self._input_idx_to_params_name
metadata["_params_that_are_onnx_inputs"] = self._params_that_are_onnx_inputs
metadata["_params_that_are_onnx_var_inputs"] = self._params_that_are_onnx_var_inputs
metadata["_params_that_are_required_onnx_inputs"] = (
self._params_that_are_required_onnx_inputs
)
metadata[
"_params_that_are_required_onnx_inputs"
] = self._params_that_are_required_onnx_inputs
metadata["_has_attr"] = self._has_attr
metadata["_inputs_not_quantized"] = self._inputs_not_quantized
metadata["quantize_inputs_with_model_outputs_precision"] = (
self.quantize_inputs_with_model_outputs_precision
)
metadata[
"quantize_inputs_with_model_outputs_precision"
] = self.quantize_inputs_with_model_outputs_precision
metadata["produces_graph_output"] = self.produces_graph_output
metadata["produces_raw_output"] = self.produces_raw_output
metadata["error_tracker"] = self.error_tracker
Expand Down

0 comments on commit 0dee236

Please sign in to comment.