Skip to content

Commit

Permalink
Rework high level compilation logic. (#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
feltroidprime authored Oct 18, 2024
1 parent 584cad6 commit c364095
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 114 deletions.
4 changes: 1 addition & 3 deletions hydra/garaga/modulo_circuit_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,7 @@ def struct_name(self) -> str:
def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 1
return f"let {self.name}:{self.struct_name} = array![{','.join([f'outputs.get_output({offset_to_reference_map[elmt.offset]})' for elmt in self.elmts])}];"
return f"let {self.name} = array![{','.join([f'outputs.get_output({offset_to_reference_map[elmt.offset]})' for elmt in self.elmts])}];"

def dump_to_circuit_input(self) -> str:
bits = self.bits
Expand Down Expand Up @@ -570,7 +569,6 @@ def struct_name(self) -> str:
def extract_from_circuit_output(
self, offset_to_reference_map: dict[int, str]
) -> str:
assert len(self.elmts) == 1
return f"let {self.name}:{self.struct_name} = array![{','.join([f'outputs.get_output({offset_to_reference_map[elmt.offset]})' for elmt in self.elmts])}].span();"

def dump_to_circuit_input(self) -> str:
Expand Down
168 changes: 135 additions & 33 deletions hydra/garaga/precompiled_circuits/all_circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
cairo1_tests_header,
compilation_mode_to_file_header,
compile_circuit,
create_cairo1_test,
format_cairo_files_in_parallel,
)
from garaga.precompiled_circuits.compilable_circuits.cairo1_mpcheck_circuits import (
Expand Down Expand Up @@ -228,72 +229,173 @@ class CircuitID(Enum):
}


def main(
PRECOMPILED_CIRCUITS_DIR: str,
CIRCUITS_TO_COMPILE: dict[CircuitID, dict],
compilation_mode: int = 1,
):
"""Compiles and writes all circuits to .cairo files"""
def initialize_compilation(
PRECOMPILED_CIRCUITS_DIR: str, CIRCUITS_TO_COMPILE: dict
) -> tuple[
set[str],
dict[str, set[str]],
dict[str, set[str]],
dict[str, set[str]],
dict[str, open],
]:
"""
Initialize the compilation process by creating the necessary directories and files.
Returns :
- filenames_used: set of all filenames that will be used
- codes: dict of sets of strings, where each set contains the compiled circuits for a given filename
- cairo1_tests_functions: dict of sets of strings, where each set contains the cairo1 tests for a given filename
- cairo1_full_function_names: dict of sets of strings, where each set contains the full function names for a given filename
- files: dict of open files, where each file is for a given filename
"""
create_directory(PRECOMPILED_CIRCUITS_DIR)
# Ensure the 'codes' dict keys match the filenames used for file creation.
# Using sets to remove potential duplicates
filenames_used = set([v["filename"] for v in CIRCUITS_TO_COMPILE.values()])
codes = {filename: set() for filename in filenames_used}
cairo1_tests_functions = {filename: set() for filename in filenames_used}
cairo1_full_function_names = {filename: set() for filename in filenames_used}

files = {
f: open(f"{PRECOMPILED_CIRCUITS_DIR}{f}.cairo", "w") for f in filenames_used
}
return (
filenames_used,
codes,
cairo1_tests_functions,
cairo1_full_function_names,
files,
)

# Write the header to each file
HEADER = compilation_mode_to_file_header(compilation_mode)

def write_headers(files: dict[str, open], compilation_mode: int) -> None:
"""
Write the header to the files.
"""
HEADER = compilation_mode_to_file_header(compilation_mode)
for file in files.values():
file.write(HEADER)

# Instantiate and compile circuits for each curve

def compile_circuits(
CIRCUITS_TO_COMPILE: dict,
compilation_mode: int,
codes: dict[str, set[str]],
cairo1_full_function_names: dict[str, set[str]],
cairo1_tests_functions: dict[str, set[str]],
) -> None:
"""
Compile the circuits and write them to the files.
"""
for circuit_id, circuit_info in CIRCUITS_TO_COMPILE.items():
for curve_id in circuit_info.get(
"curve_ids", [CurveID.BN254, CurveID.BLS12_381]
):
filename_key = circuit_info["filename"]
compiled_circuits, full_function_names = compile_circuit(
compiled_circuits, full_function_names, circuit_instances = compile_circuit(
curve_id,
circuit_info["class"],
circuit_info["params"],
compilation_mode,
cairo1_tests_functions,
filename_key,
)
codes[filename_key].update(compiled_circuits)
if compilation_mode == 1:

cairo1_full_function_names[filename_key].update(full_function_names)
generate_cairo1_tests(
circuit_instances,
full_function_names,
curve_id,
cairo1_tests_functions,
filename_key,
)


def generate_cairo1_tests(
circuit_instances,
full_function_names,
curve_id,
cairo1_tests_functions,
filename_key,
):
for circuit_instance, full_function_name in zip(
circuit_instances, full_function_names
):
circuit_input = circuit_instance.full_input_cairo1
circuit_output = (
circuit_instance.circuit.output_structs
if sum([len(x.elmts) for x in circuit_instance.circuit.output_structs])
== len(circuit_instance.circuit.output)
else circuit_instance.circuit.output
)
cairo1_tests_functions[filename_key].add(
create_cairo1_test(
full_function_name,
circuit_input,
circuit_output,
curve_id.value,
)
)

# Write selector functions and compiled circuit codes to their respective files

def write_compiled_circuits(
files: dict[str, open],
codes: dict[str, set[str]],
cairo1_full_function_names: dict[str, set[str]],
cairo1_tests_functions: dict[str, set[str]],
compilation_mode: int,
) -> None:
"""
Write the compiled circuits and the cairo1 tests to the files.
"""
print("Writing circuits and selectors to .cairo files...")
for filename in filenames_used:
if filename in files:
# Write the compiled circuit codes
for compiled_circuit in sorted(codes[filename]):
files[filename].write(compiled_circuit + "\n")
for filename, file in files.items():
for compiled_circuit in sorted(codes[filename]):
file.write(compiled_circuit + "\n")

if compilation_mode == 1:
write_cairo1_tests(
file, filename, cairo1_full_function_names, cairo1_tests_functions
)

if compilation_mode == 1:
files[filename].write(cairo1_tests_header() + "\n")
fns_to_import = sorted(cairo1_full_function_names[filename])
if "" in fns_to_import:
fns_to_import.remove("")
files[filename].write(f"use super::{{{','.join(fns_to_import)}}};\n")
for cairo1_test in sorted(cairo1_tests_functions[filename]):
files[filename].write(cairo1_test + "\n")
files[filename].write("}\n")

else:
print(f"Warning: No file associated with filename '{filename}'")
def write_cairo1_tests(
file: open,
filename: str,
cairo1_full_function_names: dict[str, set[str]],
cairo1_tests_functions: dict[str, set[str]],
) -> None:
file.write(cairo1_tests_header() + "\n")
fns_to_import = sorted(cairo1_full_function_names[filename])
if "" in fns_to_import:
fns_to_import.remove("")
file.write(f"use super::{{{','.join(fns_to_import)}}};\n")
for cairo1_test in sorted(cairo1_tests_functions[filename]):
file.write(cairo1_test + "\n")
file.write("}\n")


def main(
PRECOMPILED_CIRCUITS_DIR: str,
CIRCUITS_TO_COMPILE: dict[CircuitID, dict],
compilation_mode: int = 1,
):
"""Compiles and writes all circuits to .cairo files"""
filenames_used, codes, cairo1_tests_functions, cairo1_full_function_names, files = (
initialize_compilation(PRECOMPILED_CIRCUITS_DIR, CIRCUITS_TO_COMPILE)
)
write_headers(files, compilation_mode)
compile_circuits(
CIRCUITS_TO_COMPILE,
compilation_mode,
codes,
cairo1_full_function_names,
cairo1_tests_functions,
)
write_compiled_circuits(
files,
codes,
cairo1_full_function_names,
cairo1_tests_functions,
compilation_mode,
)

# Close all files
for file in files.values():
file.close()

Expand Down
Loading

0 comments on commit c364095

Please sign in to comment.