Skip to content

Commit

Permalink
fix(frontends): forbid clear nodes in composition rules
Browse files Browse the repository at this point in the history
  • Loading branch information
aPere3 committed Sep 20, 2024
1 parent 9ae9a53 commit 4d2eb73
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ class NotComposable:
Composition policy that does not allow the forwarding of any output to any input.
"""

def get_rules_iter(self, _funcs: List[FunctionDef]) -> Iterable[CompositionRule]:
def get_rules_iter(self, _) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
Expand All @@ -341,18 +341,17 @@ def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
outputs = chain(
*[
map(CompositionClause.create, zip(repeat(f.name), range(len(f.output_nodes))))
for f in funcs
]
)
inputs = chain(
*[
map(CompositionClause.create, zip(repeat(f.name), range(len(f.input_nodes))))
for f in funcs
]
)
outputs = []
for f in funcs:
for pos, node in f.output_nodes.items():
if node.output.is_encrypted:
outputs.append(CompositionClause.create((f.name, pos)))
inputs = []
for f in funcs:
for pos, node in f.input_nodes.items():
if node.output.is_encrypted:
inputs.append(CompositionClause.create((f.name, pos)))

return map(CompositionRule.create, product(outputs, inputs))


Expand Down Expand Up @@ -397,7 +396,7 @@ def get_outputs_iter(self) -> Iterable[CompositionClause]:

class AllOutputs(NamedTuple):
"""
All the outputs of a given function of a module.
All the encrypted outputs of a given function of a module.
"""

func: FunctionDef
Expand All @@ -407,6 +406,7 @@ def get_outputs_iter(self) -> Iterable[CompositionClause]:
Return an iterator over the possible outputs of the wire output.
"""
assert self.func.graph # pragma: no cover
# No need to filter since only encrypted outputs are valid.
return map( # pragma: no cover
CompositionClause.create,
zip(repeat(self.func.name), range(self.func.graph.outputs_count)),
Expand All @@ -430,7 +430,7 @@ def get_inputs_iter(self) -> Iterable[CompositionClause]:

class AllInputs(NamedTuple):
"""
All the inputs of a given function of a module.
All the encrypted inputs of a given function of a module.
"""

func: FunctionDef
Expand All @@ -440,10 +440,11 @@ def get_inputs_iter(self) -> Iterable[CompositionClause]:
Return an iterator over the possible inputs of the wire input.
"""
assert self.func.graph # pragma: no cover
return map( # pragma: no cover
CompositionClause.create,
zip(repeat(self.func.name), range(self.func.graph.inputs_count)),
)
output = []
for i in range(self.func.graph.inputs_count):
if self.func.graph.input_nodes[i].output.is_encrypted:
output.append(CompositionClause.create((self.func.name, i)))
return output


class Wire(NamedTuple):
Expand Down Expand Up @@ -474,11 +475,27 @@ class Wired:
def __init__(self, wires: Optional[Set[Wire]] = None):
self.wires = wires if wires else set()

def get_rules_iter(self, _) -> Iterable[CompositionRule]:
def get_rules_iter(self, funcs: List[Graph]) -> Iterable[CompositionRule]:
"""
Return an iterator over composition rules.
"""
return chain(*[w.get_rules_iter(_) for w in self.wires])
funcsd = {f.name: f for f in funcs}
rules = list(chain(*[w.get_rules_iter(funcs) for w in self.wires]))

# We check that the given rules are legit (they concern only encrypted values)
for rule in rules:
if (
not funcsd[rule.from_.func].output_nodes[rule.from_.pos].output.is_encrypted
): # pragma: no cover
message = f"Invalid composition rule encountered: \
Output {rule.from_.pos} of {rule.from_.func} is not encrypted"
raise RuntimeError(message)
if not funcsd[rule.to.func].input_nodes[rule.to.pos].output.is_encrypted:
message = f"Invalid composition rule encountered: \
Input {rule.from_.pos} of {rule.from_.func} is not encrypted"
raise RuntimeError(message)

return rules


class DebugManager:
Expand Down
73 changes: 73 additions & 0 deletions frontends/concrete-python/tests/compilation/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,3 +879,76 @@ def inc(x):

assert module.execution_runtime.initialized
assert module.simulation_runtime.initialized


def test_all_composable_with_clears(helpers):

@fhe.module()
class Module:
@fhe.function({"x": "encrypted", "y": "clear"})
def inc(x, y):
return (x + y + 1) % 20

module = Module.compile(
{
"inc": [
(np.random.randint(1, 20, size=()), np.random.randint(1, 20, size=()))
for _ in range(100)
]
},
helpers.configuration().fork(),
)


def test_wired_with_invalid_wire(helpers):

@fhe.module()
class Module:
@fhe.function({"x": "encrypted", "y": "clear"})
def inc(x, y):
return (x + y + 1) % 20

composition = fhe.Wired(
{
fhe.Wire(fhe.Output(inc, 0), fhe.Input(inc, 0)),
fhe.Wire(fhe.Output(inc, 0), fhe.Input(inc, 1)), # Faulty one
}
)

with pytest.raises(
Exception, match="Invalid composition rule encountered: Input 0 of inc is not encrypted"
):
module = Module.compile(
{
"inc": [
(np.random.randint(1, 20, size=()), np.random.randint(1, 20, size=()))
for _ in range(100)
]
},
helpers.configuration().fork(),
)


def test_wired_with_all_encrypted_inputs(helpers):

@fhe.module()
class Module:
@fhe.function({"x": "encrypted", "y": "clear"})
def inc(x, y):
return (x + y + 1) % 20

composition = fhe.Wired(
{
fhe.Wire(fhe.Output(inc, 0), fhe.AllInputs(inc)),
}
)

module = Module.compile(
{
"inc": [
(np.random.randint(1, 20, size=()), np.random.randint(1, 20, size=()))
for _ in range(100)
]
},
helpers.configuration().fork(),
)

0 comments on commit 4d2eb73

Please sign in to comment.