Skip to content

Commit

Permalink
chore(frontends): add tests to increase coverage of fhe modules
Browse files Browse the repository at this point in the history
  • Loading branch information
aPere3 committed Apr 16, 2024
1 parent a3762ab commit 34de883
Show file tree
Hide file tree
Showing 5 changed files with 336 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,15 @@ class ModuleDebugArtifacts:

def __init__(
self,
function_names: List[str],
function_names: Optional[List[str]] = None,
output_directory: Union[str, Path] = DEFAULT_OUTPUT_DIRECTORY,
):
self.output_directory = Path(output_directory)
self.mlir_to_compile = None
self.client_parameters = None
self.functions = {name: FunctionDebugArtifacts() for name in function_names}
self.functions = (
{name: FunctionDebugArtifacts() for name in function_names} if function_names else {}
)

def add_mlir_to_compile(self, mlir: str):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def decoration(class_):
if not functions:
error = "Tried to define an @fhe.module without any @fhe.function"
raise RuntimeError(error)
return functools.wraps(class_)(ModuleCompiler([f for (_, f) in functions]))
return ModuleCompiler([f for (_, f) in functions])

return decoration

Expand Down
7 changes: 5 additions & 2 deletions frontends/concrete-python/concrete/fhe/compilation/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,9 +674,12 @@ def functions(self) -> Dict[str, FheFunction]:
"""
Return a dictionnary containing all the functions of the module.
"""
return {name: getattr(self, name) for name in self.graphs.keys()}
return {
name: FheFunction(name, self.runtime, self.graphs[name]) for name in self.graphs.keys()
}

def __getattr__(self, item):
if item not in list(self.graphs.keys()):
self.__getattribute__(item)
error = f"No attribute {item}"
raise AttributeError(error)
return FheFunction(item, self.runtime, self.graphs[item])
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ class FunctionDef:
parameter_encryption_statuses: Dict[str, EncryptionStatus]
inputset: List[Any]
graph: Optional[Graph]
artifacts: Optional[FunctionDebugArtifacts]
_is_direct: bool
_parameter_values: Dict[str, ValueDescription]

def __init__(
Expand Down Expand Up @@ -95,26 +93,30 @@ def __init__(
param: EncryptionStatus(status.lower())
for param, status in parameter_encryption_statuses.items()
}
self.artifacts = None
self.inputset = []
self.graph = None
self.name = function.__name__
self._is_direct = False
self._parameter_values = {}

def trace(self, sample: Union[Any, Tuple[Any, ...]]):
def trace(
self,
sample: Union[Any, Tuple[Any, ...]],
artifacts: Optional[FunctionDebugArtifacts] = None,
):
"""
Trace the function and fuse the resulting graph with a sample input.
Args:
sample (Union[Any, Tuple[Any, ...]]):
sample to use for tracing
artifacts: Optiona[FunctionDebugArtifacts]:
the object to store artifacts in
"""

if self.artifacts is not None:
self.artifacts.add_source_code(self.function)
if artifacts is not None:
artifacts.add_source_code(self.function)
for param, encryption_status in self.parameter_encryption_statuses.items():
self.artifacts.add_parameter_encryption_status(param, encryption_status)
artifacts.add_parameter_encryption_status(param, encryption_status)

parameters = {
param: ValueDescription.of(arg, is_encrypted=(status == EncryptionStatus.ENCRYPTED))
Expand All @@ -129,10 +131,10 @@ def trace(self, sample: Union[Any, Tuple[Any, ...]]):
}

self.graph = Tracer.trace(self.function, parameters, name=self.name)
if self.artifacts is not None:
self.artifacts.add_graph("initial", self.graph)
if artifacts is not None:
artifacts.add_graph("initial", self.graph)

fuse(self.graph, self.artifacts)
fuse(self.graph, artifacts)

def evaluate(
self,
Expand All @@ -158,15 +160,6 @@ def evaluate(
artifact object to store informations in
"""

if self._is_direct:
self.graph = Tracer.trace(
self.function, self._parameter_values, is_direct=True, name=self.name
)
artifacts.add_graph("initial", self.graph) # pragma: no cover
fuse(self.graph, artifacts)
artifacts.add_graph("final", self.graph) # pragma: no cover
return

if inputset is not None:
previous_inputset_length = len(self.inputset)
for index, sample in enumerate(iter(inputset)):
Expand Down Expand Up @@ -209,7 +202,7 @@ def evaluate(
)
raise RuntimeError(message) from error

self.trace(first_sample)
self.trace(first_sample, artifacts)
assert self.graph is not None

bounds = self.graph.measure_bounds(self.inputset)
Expand Down Expand Up @@ -504,10 +497,12 @@ def compile(
raise RuntimeError(error)

module_artifacts = (
module_artifacts
if module_artifacts is not None
else ModuleDebugArtifacts(list(self.functions.keys()))
module_artifacts if module_artifacts is not None else ModuleDebugArtifacts()
)
if not module_artifacts.functions:
module_artifacts.functions = {
f: FunctionDebugArtifacts() for f in self.functions.keys()
}

dbg = DebugManager(configuration)

Expand All @@ -524,9 +519,7 @@ def compile(
mlir_context = self.compilation_context.mlir_context()
graphs = {}
for name, function in self.functions.items():
if function.graph is None:
error = "Expected graph to be set."
raise RuntimeError(error)
assert function.graph is not None
graphs[name] = function.graph
mlir_module = GraphConverter(configuration).convert_many(graphs, mlir_context)
mlir_str = str(mlir_module).strip()
Expand Down Expand Up @@ -558,7 +551,7 @@ def compile(
if configuration.dump_artifacts_on_unexpected_failures:
module_artifacts.export()

traceback_path = self.artifacts.output_directory.joinpath("traceback.txt")
traceback_path = module_artifacts.output_directory.joinpath("traceback.txt")
with open(traceback_path, "w", encoding="utf-8") as f:
f.write(traceback.format_exc())

Expand Down
Loading

0 comments on commit 34de883

Please sign in to comment.