Skip to content

Commit

Permalink
feat: add multi-output support
Browse files Browse the repository at this point in the history
  • Loading branch information
fd0r committed Nov 24, 2023
1 parent 111c7e3 commit a089021
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 25 deletions.
20 changes: 10 additions & 10 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ OPEN_PR="true"
# Force the installation of a Concrete Python version, which is very useful with nightly versions
# /!\ WARNING /!\: This version should NEVER be a wildcard as it might create some
# issues when trying to run it in the future.
CONCRETE_PYTHON_VERSION="concrete-python==2.5.0rc1"
CONCRETE_PYTHON_VERSION="concrete-python==2023.11.23"

# Force the installation of Concrete Python's latest version, release-candidates included
# CONCRETE_PYTHON_VERSION="$$(poetry run python \
Expand All @@ -32,30 +32,30 @@ setup_env:
@# The keyring install is to allow pip to fetch credentials for our internal repo if needed
PIP_INDEX_URL=https://pypi.org/simple \
PIP_EXTRA_INDEX_URL=https://pypi.org/simple \
poetry run python --version
poetry run python -m pip install keyring
poetry run python -m pip install -U pip wheel
python -m poetry run python --version
python -m poetry run python -m pip install keyring
python -m poetry run python -m pip install -U pip wheel

@# Only for linux and docker, reinstall setuptools. On macOS, it creates warnings, see 169
if [[ $$(uname) != "Darwin" ]]; then \
poetry run python -m pip install -U --force-reinstall setuptools; \
python -m poetry run python -m pip install -U --force-reinstall setuptools; \
fi
if [[ $$(uname) != "Linux" ]] && [[ $$(uname) != "Darwin" ]]; then \
poetry install --only dev; \
python -m poetry install --only dev; \
else \
poetry install; \
python -m poetry install; \
fi

echo "Installing $(CONCRETE_PYTHON_VERSION)" && \
poetry run python -m pip install -U --pre "$(CONCRETE_PYTHON_VERSION)" --no-deps
python -m poetry run python -m pip install -U --pre "$(CONCRETE_PYTHON_VERSION)" --no-deps
"$(MAKE)" fix_omp_issues_for_intel_mac

.PHONY: sync_env # Synchronise the environment
sync_env: check_poetry_version
if [[ $$(uname) != "Linux" ]] && [[ $$(uname) != "Darwin" ]]; then \
poetry install --remove-untracked --only dev; \
python -m poetry install --remove-untracked --only dev; \
else \
poetry install --remove-untracked; \
python -m poetry install --remove-untracked; \
fi
"$(MAKE)" setup_env

Expand Down
2 changes: 1 addition & 1 deletion deps_licenses/licenses_mac_silicon_user.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ certifi, 2023.7.22, Mozilla Public License 2.0 (MPL 2.0)
charset-normalizer, 3.2.0, MIT License
click, 8.1.7, BSD License
coloredlogs, 15.0.1, MIT License
concrete-python, 2.5.0rc1, BSD-3-Clause
concrete-python, 2023.11.23, BSD-3-Clause
dependencies, 2.0.1, BSD License
dill, 0.3.7, BSD License
exceptiongroup, 1.1.3, MIT License
Expand Down
2 changes: 1 addition & 1 deletion deps_licenses/licenses_mac_silicon_user.txt.md5
Original file line number Diff line number Diff line change
@@ -1 +1 @@
6e23de913fba55c72e9420fdef5e78de
a87fa07ad174cbccb364c183680c2cd9
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ readme = "README.md"
# Investigate if it is better to fix specific versions or use lower and upper bounds
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2665
python = ">=3.8.1,<3.11"
concrete-python = "2.5.0-rc1 "
#concrete-python = "2023.11.23"
setuptools = "65.6.3"
skops = {version = "0.5.0"}
xgboost = "1.6.2"
Expand Down
22 changes: 22 additions & 0 deletions src/concrete/ml/pytest/torch_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,28 @@
# pylint: disable=too-many-lines


class MultiOutputModel(nn.Module):
"""Multi-output model"""

def __init__(
self,
) -> None:
"""Torch Model."""
super().__init__()

def forward(self, x, y):
"""Forward pass.
Args:
x (torch.Tensor): The input of the model.
y (torch.Tensor): The input of the model.
Returns:
Tuple[torch.Tensor. torch.Tensor]: Output of the network.
"""
return x + y, (x - y) ** 2


class SimpleNet(torch.nn.Module):
"""Fake torch model used to generate some onnx."""

Expand Down
14 changes: 2 additions & 12 deletions src/concrete/ml/quantization/quantized_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,6 @@ def __init__(
self.ordered_module_input_names = tuple(ordered_module_input_names)
self.ordered_module_output_names = tuple(ordered_module_output_names)

num_outputs = len(self.ordered_module_output_names)
assert_true(
(num_outputs) == 1,
f"{QuantizedModule.__class__.__name__} only supports a single output for now, "
f"got {num_outputs}",
)

assert quant_layers_dict is not None
self.quant_layers_dict = copy.deepcopy(quant_layers_dict)
self.output_quantizers = self._set_output_quantizers()
Expand Down Expand Up @@ -445,12 +438,10 @@ def _clear_forward(self, *q_x: numpy.ndarray) -> numpy.ndarray:
layer_results[output_name] for output_name in self.ordered_module_output_names
)

assert_true(len(output_quantized_arrays) == 1)

# The output of a graph must be a QuantizedArray
assert isinstance(output_quantized_arrays[0], QuantizedArray)
assert all(isinstance(elt, QuantizedArray) for elt in output_quantized_arrays)

return output_quantized_arrays[0].qvalues
return tuple(elt.qvalues for elt in output_quantized_arrays)

def _fhe_forward(self, *q_x: numpy.ndarray, simulate: bool = True) -> numpy.ndarray:
"""Forward function executed in FHE or with simulation.
Expand Down Expand Up @@ -688,7 +679,6 @@ def compile(
single_precision=False,
fhe_simulation=False,
fhe_execution=True,
jit=False,
)

self._is_compiled = True
Expand Down
15 changes: 15 additions & 0 deletions tests/torch/test_compile_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
MultiInputNN,
MultiInputNNConfigurable,
MultiInputNNDifferentSize,
MultiOutputModel,
NetWithLoops,
PaddingNet,
ShapeOperationsNet,
Expand Down Expand Up @@ -1256,6 +1257,20 @@ def test_fancy_indexing_torch(model_object, default_configuration):
compile_brevitas_qat_model(model, x, n_bits=4, configuration=default_configuration)


@pytest.mark.parametrize(
"model_object",
[
pytest.param(MultiOutputModel),
],
)
def test_multi_output(model_object, default_configuration):
"""Test fancy indexing torch."""
model = model_object()
x = numpy.random.randint(0, 2, size=(100, 3, 10)).astype(numpy.float64)
y = numpy.random.randint(0, 2, size=(100, 3, 10)).astype(numpy.float64)
compile_torch_model(model, (x, y), n_bits=4, configuration=default_configuration)


@pytest.mark.parametrize(
"model, input_output_feature",
[
Expand Down

0 comments on commit a089021

Please sign in to comment.