Skip to content

Commit

Permalink
feat(frontend): provide an API to reset the compiler state
Browse files Browse the repository at this point in the history
  • Loading branch information
umut-sahin authored and BourgerieQuentin committed Sep 13, 2024
1 parent 69d5a35 commit a8f435f
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -681,3 +681,10 @@ def pretty(d, indent=0): # pragma: no cover
return circuit

# pylint: enable=too-many-branches,too-many-statements

def reset(self):
"""
Reset the compiler so that another compilation with another inputset can be performed.
"""
fresh_compiler = Compiler(self.function, self.parameter_encryption_statuses)
self.__dict__.update(fresh_compiler.__dict__)
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def compile(

return self.compiler.compile(inputset, configuration, artifacts, **kwargs)

def reset(self):
"""
Reset the compilable so that another compilation with another inputset can be performed.
"""

self.compiler.reset()


def compiler(parameters: Mapping[str, Union[str, EncryptionStatus]]):
"""
Expand Down
62 changes: 62 additions & 0 deletions frontends/concrete-python/tests/compilation/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,3 +434,65 @@ def f(x):
helpers.configuration().fork(enable_tlu_fusing=False),
)
assert circuit4.programmable_bootstrap_count == 6


def test_compiler_reset(helpers):
def f(x, y):
return x + y

configuration = helpers.configuration()
compiler = fhe.Compiler(f, {"x": "encrypted", "y": "encrypted"})

inputset1 = fhe.inputset(fhe.uint3, fhe.uint3)
circuit1 = compiler.compile(inputset1, configuration)

helpers.check_str(
"""
module {
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
return %0 : !FHE.eint<4>
}
}
""".strip(),
circuit1.mlir.strip(),
)
compiler.reset()

inputset2 = fhe.inputset(fhe.uint10, fhe.uint10)
circuit2 = compiler.compile(inputset2, configuration)

helpers.check_str(
"""
module {
func.func @main(%arg0: !FHE.eint<11>, %arg1: !FHE.eint<11>) -> !FHE.eint<11> {
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<11>, !FHE.eint<11>) -> !FHE.eint<11>
return %0 : !FHE.eint<11>
}
}
""".strip(),
circuit2.mlir.strip(),
)
compiler.reset()

inputset3 = fhe.inputset(fhe.tensor[fhe.uint2, 3, 2], fhe.tensor[fhe.uint2, 2]) # type: ignore
circuit3 = compiler.compile(inputset3, configuration)

helpers.check_str(
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<3>>, %arg1: tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>> {
%0 = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<3>>, tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>>
return %0 : tensor<3x2x!FHE.eint<3>>
}
}
""".strip(),
circuit3.mlir.strip(),
)
compiler.reset()
62 changes: 62 additions & 0 deletions frontends/concrete-python/tests/compilation/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,65 @@ def circuit4(x: fhe.uint3):
assert str(excinfo.value) == (
"'round(x)' cannot be used in direct definition (you may use np.around instead)"
)


def test_compiler_reset(helpers):
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def compiler(x, y):
return x + y

configuration = helpers.configuration()

inputset1 = fhe.inputset(fhe.uint3, fhe.uint3)
circuit1 = compiler.compile(inputset1, configuration)

helpers.check_str(
"""
module {
func.func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<4>, !FHE.eint<4>) -> !FHE.eint<4>
return %0 : !FHE.eint<4>
}
}
""".strip(),
circuit1.mlir.strip(),
)
compiler.reset()

inputset2 = fhe.inputset(fhe.uint10, fhe.uint10)
circuit2 = compiler.compile(inputset2, configuration)

helpers.check_str(
"""
module {
func.func @main(%arg0: !FHE.eint<11>, %arg1: !FHE.eint<11>) -> !FHE.eint<11> {
%0 = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<11>, !FHE.eint<11>) -> !FHE.eint<11>
return %0 : !FHE.eint<11>
}
}
""".strip(),
circuit2.mlir.strip(),
)
compiler.reset()

inputset3 = fhe.inputset(fhe.tensor[fhe.uint2, 3, 2], fhe.tensor[fhe.uint2, 2]) # type: ignore
circuit3 = compiler.compile(inputset3, configuration)

helpers.check_str(
"""
module {
func.func @main(%arg0: tensor<3x2x!FHE.eint<3>>, %arg1: tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>> {
%0 = "FHELinalg.add_eint"(%arg0, %arg1) : (tensor<3x2x!FHE.eint<3>>, tensor<2x!FHE.eint<3>>) -> tensor<3x2x!FHE.eint<3>>
return %0 : tensor<3x2x!FHE.eint<3>>
}
}
""".strip(),
circuit3.mlir.strip(),
)
compiler.reset()

0 comments on commit a8f435f

Please sign in to comment.