Skip to content

Commit

Permalink
chore(frontend-python): update black dev dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
aquint-zama committed May 6, 2024
1 parent b030148 commit 8f26a73
Show file tree
Hide file tree
Showing 14 changed files with 116 additions and 82 deletions.
14 changes: 8 additions & 6 deletions frontends/concrete-python/concrete/fhe/compilation/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,14 @@ def simulate(self, *args: Any) -> Any:

exporter = SimulatedValueExporter.new(self.simulator.client_specs.client_parameters)
exported = [
None
if arg is None
else Value(
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
if isinstance(arg, np.ndarray) and arg.shape != ()
else exporter.export_scalar(position, int(arg))
(
None
if arg is None
else Value(
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
if isinstance(arg, np.ndarray) and arg.shape != ()
else exporter.export_scalar(position, int(arg))
)
)
for position, arg in enumerate(ordered_validated_args)
]
Expand Down
14 changes: 8 additions & 6 deletions frontends/concrete-python/concrete/fhe/compilation/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,14 @@ def encrypt(

exporter = ValueExporter.new(keyset, self.specs.client_parameters, function_name)
exported = [
None
if arg is None
else Value(
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
if isinstance(arg, np.ndarray) and arg.shape != ()
else exporter.export_scalar(position, int(arg))
(
None
if arg is None
else Value(
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
if isinstance(arg, np.ndarray) and arg.shape != ()
else exporter.export_scalar(position, int(arg))
)
)
for position, arg in enumerate(ordered_sanitized_args)
]
Expand Down
16 changes: 10 additions & 6 deletions frontends/concrete-python/concrete/fhe/compilation/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,11 @@ def trace(
self.artifacts = (
artifacts
if artifacts is not None
else DebugArtifacts()
if self.configuration.dump_artifacts_on_unexpected_failures
else None
else (
DebugArtifacts()
if self.configuration.dump_artifacts_on_unexpected_failures
else None
)
)

try:
Expand Down Expand Up @@ -447,9 +449,11 @@ def compile(
self.artifacts = (
artifacts
if artifacts is not None
else DebugArtifacts()
if self.configuration.dump_artifacts_on_unexpected_failures
else None
else (
DebugArtifacts()
if self.configuration.dump_artifacts_on_unexpected_failures
else None
)
)

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1233,9 +1233,11 @@ def fork(
args = locals()
return Configuration(
**{
name: getattr(self, name)
if isinstance(args[name], Configuration.Keep)
else args[name]
name: (
getattr(self, name)
if isinstance(args[name], Configuration.Keep)
else args[name]
)
for name in get_type_hints(Configuration.__init__)
}
)
Expand Down
14 changes: 8 additions & 6 deletions frontends/concrete-python/concrete/fhe/compilation/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,14 @@ def simulate(self, *args: Any) -> Any:
self.runtime.server.client_specs.client_parameters, self.name
)
exported = [
None
if arg is None
else Value(
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
if isinstance(arg, np.ndarray) and arg.shape != ()
else exporter.export_scalar(position, int(arg))
(
None
if arg is None
else Value(
exporter.export_tensor(position, arg.flatten().tolist(), list(arg.shape))
if isinstance(arg, np.ndarray) and arg.shape != ()
else exporter.export_scalar(position, int(arg))
)
)
for position, arg in enumerate(ordered_validated_args)
]
Expand Down
8 changes: 5 additions & 3 deletions frontends/concrete-python/concrete/fhe/mlir/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2368,9 +2368,11 @@ def index_static(
size = 1
stride = 1
offset = int(
indexing_element
if indexing_element >= 0
else indexing_element + dimension_size,
(
indexing_element
if indexing_element >= 0
else indexing_element + dimension_size
),
)

offsets.append(offset)
Expand Down
4 changes: 1 addition & 3 deletions frontends/concrete-python/concrete/fhe/mlir/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,9 +749,7 @@ def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:

if variable_input.bit_width > original_bit_width:
bit_width_difference = variable_input.bit_width - original_bit_width
shifter = ctx.constant(
ctx.i(variable_input.bit_width + 1), 2**bit_width_difference
)
shifter = ctx.constant(ctx.i(variable_input.bit_width + 1), 2**bit_width_difference)
variable_input = ctx.mul(variable_input.type, variable_input, shifter)

variable_input = ctx.reinterpret(variable_input, bit_width=truncated_bit_width)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def main():
global_p_error=None, # 2**log2_global_p_error,
p_error=2**log2_p_error,
bitwise_strategy_preference=fhe.BitwiseStrategy.ONE_TLU_PROMOTED,
verbose=args.verbose_compilation
verbose=args.verbose_compilation,
# parameter_selection_strategy=fhe.ParameterSelectionStrategy.MULTI,
# single_precision=False,
)
Expand Down
2 changes: 1 addition & 1 deletion frontends/concrete-python/requirements.dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pytest-randomly==3.15.0
pytest-xdist==3.2.1
pytest==7.2.2

black==23.1.0
black==24.4.0
isort==5.12.0

mypy==1.1.1
Expand Down
4 changes: 1 addition & 3 deletions frontends/concrete-python/tests/compilation/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,7 @@ def test_circuit_run_with_unused_arg(helpers):
def f(x, y): # pylint: disable=unused-argument
return x + 10

inputset = [
(np.random.randint(2**3, 2**4), np.random.randint(2**4, 2**5)) for _ in range(100)
]
inputset = [(np.random.randint(2**3, 2**4), np.random.randint(2**4, 2**5)) for _ in range(100)]
circuit = f.compile(inputset, configuration)

with pytest.raises(ValueError, match="Expected 2 inputs but got 1"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def test_configuration_bad_fork(kwargs, expected_error, expected_message):
main.%1 == main.%4
""",
"""
(
"""
main.%0 = 3
main.%1 = 8
Expand All @@ -240,8 +241,8 @@ def test_configuration_bad_fork(kwargs, expected_error, expected_message):
main.max = 8
"""
if USE_MULTI_PRECISION
else """
if USE_MULTI_PRECISION
else """
main.%0 = 8
main.%1 = 8
Expand All @@ -250,7 +251,8 @@ def test_configuration_bad_fork(kwargs, expected_error, expected_message):
main.%4 = 8
main.max = 8
""",
"""
),
),
],
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests of everything related to multi-circuit.
"""

import tempfile

import numpy as np
Expand Down
Loading

0 comments on commit 8f26a73

Please sign in to comment.