Skip to content

Commit

Permalink
feat(ci): update ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
bcm-at-zama committed Sep 5, 2024
1 parent 2c2f149 commit 8477f95
Show file tree
Hide file tree
Showing 20 changed files with 233 additions and 165 deletions.
6 changes: 4 additions & 2 deletions frontends/concrete-python/.ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ ignore = [
"concrete/fhe/mlir/processors/all.py" = ["F401"]
"concrete/fhe/mlir/processors/assign_bit_widths.py" = ["ARG002"]
"concrete/fhe/mlir/converter.py" = ["ARG002", "B011", "F403", "F405"]
"examples/**" = ["PLR2004"]
"tests/**" = ["PLR2004", "PLW0603", "SIM300", "S311"]
"concrete/**" = ["RUF010"]
"examples/**" = ["PLR2004", "RUF010"]
"tests/**" = ["PLR2004", "PLW0603", "SIM300", "S311", "RUF010"]
"tests/execution/test_tfhers.py" = ["S605"]
"benchmarks/**" = ["S311", "B023"]
"scripts/**" = ["DTZ005"]
7 changes: 6 additions & 1 deletion frontends/concrete-python/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,12 @@ pylint:

ruff:
eval $(shell make silent_cp_activate)
ruff concrete/ examples/ scripts/ tests/ benchmarks/

ruff check concrete/
ruff check examples/
ruff check scripts/
ruff check tests/
ruff check benchmarks/

check-links:
@# Check that no links target the main branch, some internal repositories (Concrete ML or Concrete) or our internal GitBook
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,10 @@ def export(self):
# wrapt 1.12.1
# zipp 3.5.0

pip_process = subprocess.run(
["pip", "--disable-pip-version-check", "list"],
# S603 `subprocess` call: check for execution of untrusted input
# S607 Starting a process with a partial executable path
pip_process = subprocess.run( # noqa: S603
["pip", "--disable-pip-version-check", "list"], # noqa: S607
stdout=subprocess.PIPE,
check=True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,12 +464,15 @@ def get_rules_iter(self, _) -> Iterable[CompositionRule]:
)


class Wired(NamedTuple):
class Wired:
"""
Composition policy which allows the forwarding of certain outputs to certain inputs.
"""

wires: Set[Wire] = set()
wires: Set[Wire]

def __init__(self, wires: Optional[Set[Wire]] = None):
self.wires = wires if wires else set()

def get_rules_iter(self, _) -> Iterable[CompositionRule]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
KeyType,
OptimizerMultiParameterStrategy,
OptimizerStrategy,
Encoding,
PrimitiveOperation,
)
from mlir.ir import Module as MlirModule
Expand Down
5 changes: 1 addition & 4 deletions frontends/concrete-python/concrete/fhe/compilation/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ def __init__(self, client_parameters: ClientParameters):
self.client_parameters = client_parameters

def __eq__(self, other: Any): # pragma: no cover
if self.client_parameters.serialize() != other.client_parameters.serialize():
return False

return True
return self.client_parameters.serialize() == other.client_parameters.serialize()

def serialize(self) -> bytes:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def _trace_conv(
"conv2d": _evaluate_conv2d,
"conv3d": _evaluate_conv3d,
}
eval_func = conv_eval_funcs.get(conv_func, None)
eval_func = conv_eval_funcs.get(conv_func)
assert_that(
eval_func is not None,
f"expected conv_func to be one of {list(conv_eval_funcs.keys())}, but got {conv_func}",
Expand Down Expand Up @@ -658,7 +658,7 @@ def _evaluate_conv(
"conv3d": torch.conv3d,
}

torch_conv_func = conv_funcs.get(conv_func, None)
torch_conv_func = conv_funcs.get(conv_func)
assert_that(
torch_conv_func is not None,
f"expected conv_func to be one of {list(conv_funcs.keys())}, but got {conv_func}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,65 +396,65 @@ def min_max(self, node: Node, preds: List[Node]):
# Operations
# ==========

add = {
add = { # noqa: RUF012
inputs_and_output_share_precision,
}

array = {
array = { # noqa: RUF012
inputs_and_output_share_precision,
}

assign_dynamic = {
assign_dynamic = { # noqa: RUF012
inputs_and_output_share_precision,
}

assign_static = {
assign_static = { # noqa: RUF012
inputs_and_output_share_precision,
}

bitwise_and = {
bitwise_and = { # noqa: RUF012
all_inputs_are_encrypted: {
bitwise,
},
}

bitwise_or = {
bitwise_or = { # noqa: RUF012
all_inputs_are_encrypted: {
bitwise,
},
}

bitwise_xor = {
bitwise_xor = { # noqa: RUF012
all_inputs_are_encrypted: {
bitwise,
},
}

broadcast_to = {
broadcast_to = { # noqa: RUF012
inputs_and_output_share_precision,
}

concatenate = {
concatenate = { # noqa: RUF012
inputs_and_output_share_precision,
}

conv1d = {
conv1d = { # noqa: RUF012
inputs_and_output_share_precision,
}

conv2d = {
conv2d = { # noqa: RUF012
inputs_and_output_share_precision,
}

conv3d = {
conv3d = { # noqa: RUF012
inputs_and_output_share_precision,
}

copy = {
copy = { # noqa: RUF012
inputs_and_output_share_precision,
}

dot = {
dot = { # noqa: RUF012
all_inputs_are_encrypted: {
inputs_share_precision,
inputs_require_one_more_bit,
Expand All @@ -464,51 +464,51 @@ def min_max(self, node: Node, preds: List[Node]):
},
}

equal = {
equal = { # noqa: RUF012
all_inputs_are_encrypted: {
comparison,
},
}

expand_dims = {
expand_dims = { # noqa: RUF012
inputs_and_output_share_precision,
}

greater = {
greater = { # noqa: RUF012
all_inputs_are_encrypted: {
comparison,
},
}

greater_equal = {
greater_equal = { # noqa: RUF012
all_inputs_are_encrypted: {
comparison,
},
}

index_static = {
index_static = { # noqa: RUF012
inputs_and_output_share_precision,
}

left_shift = {
left_shift = { # noqa: RUF012
all_inputs_are_encrypted: {
bitwise,
},
}

less = {
less = { # noqa: RUF012
all_inputs_are_encrypted: {
comparison,
},
}

less_equal = {
less_equal = { # noqa: RUF012
all_inputs_are_encrypted: {
comparison,
},
}

matmul = {
matmul = { # noqa: RUF012
all_inputs_are_encrypted: {
inputs_share_precision,
inputs_require_one_more_bit,
Expand All @@ -518,34 +518,34 @@ def min_max(self, node: Node, preds: List[Node]):
},
}

maximum = {
maximum = { # noqa: RUF012
all_inputs_are_encrypted: {
min_max,
},
}

maxpool1d = {
maxpool1d = { # noqa: RUF012
inputs_and_output_share_precision,
inputs_require_one_more_bit,
}

maxpool2d = {
maxpool2d = { # noqa: RUF012
inputs_and_output_share_precision,
inputs_require_one_more_bit,
}

maxpool3d = {
maxpool3d = { # noqa: RUF012
inputs_and_output_share_precision,
inputs_require_one_more_bit,
}

minimum = {
minimum = { # noqa: RUF012
all_inputs_are_encrypted: {
min_max,
},
}

multiply = {
multiply = { # noqa: RUF012
all_inputs_are_encrypted: {
inputs_share_precision,
inputs_require_one_more_bit,
Expand All @@ -555,48 +555,48 @@ def min_max(self, node: Node, preds: List[Node]):
},
}

negative = {
negative = { # noqa: RUF012
inputs_and_output_share_precision,
}

not_equal = {
not_equal = { # noqa: RUF012
all_inputs_are_encrypted: {
comparison,
},
}

reshape = {
reshape = { # noqa: RUF012
inputs_and_output_share_precision,
}

right_shift = {
right_shift = { # noqa: RUF012
all_inputs_are_encrypted: {
bitwise,
},
}

round_bit_pattern = {
round_bit_pattern = { # noqa: RUF012
has_overflow_protection: {
inputs_and_output_share_precision,
},
}

subtract = {
subtract = { # noqa: RUF012
inputs_and_output_share_precision,
}

sum = {
sum = { # noqa: RUF012
inputs_and_output_share_precision,
}

squeeze = {
squeeze = { # noqa: RUF012
inputs_and_output_share_precision,
}

transpose = {
transpose = { # noqa: RUF012
inputs_and_output_share_precision,
}

truncate_bit_pattern = {
truncate_bit_pattern = { # noqa: RUF012
inputs_and_output_share_precision,
}
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,10 @@ def format(
bounds += "∈ ["

lower, upper = node.bounds
assert type(lower) == type(upper) # pylint: disable=unidiomatic-typecheck

# pylint: disable=unidiomatic-typecheck
assert type(lower) == type(upper) # noqa: E721
# pylint: enable=unidiomatic-typecheck

if isinstance(lower, (float, np.float32, np.float64)):
bounds += f"{round(lower, 6)}, {round(upper, 6)}"
Expand Down
6 changes: 3 additions & 3 deletions frontends/concrete-python/concrete/fhe/tracing/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import inspect
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, cast
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Type, Union, cast

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -216,7 +216,7 @@ def sanitize(value: Any) -> Any:
computation = Node.constant(value)
return Tracer(computation, [])

SUPPORTED_NUMPY_OPERATORS: Set[Any] = {
SUPPORTED_NUMPY_OPERATORS: ClassVar[Set[Any]] = {
np.abs,
np.absolute,
np.add,
Expand Down Expand Up @@ -318,7 +318,7 @@ def sanitize(value: Any) -> Any:
np.zeros_like,
}

SUPPORTED_KWARGS: Dict[Any, Set[str]] = {
SUPPORTED_KWARGS: ClassVar[Dict[Any, Set[str]]] = {
np.around: {
"decimals",
},
Expand Down
Loading

0 comments on commit 8477f95

Please sign in to comment.