From c3640953fb40389306578c6f73d641857e17edae Mon Sep 17 00:00:00 2001 From: feltroid Prime <96737978+feltroidprime@users.noreply.github.com> Date: Fri, 18 Oct 2024 11:44:17 +0700 Subject: [PATCH] Rework high level compilation logic. (#232) --- hydra/garaga/modulo_circuit_structs.py | 4 +- .../precompiled_circuits/all_circuits.py | 168 ++++++++++++++---- .../compilable_circuits/base.py | 110 ++++++------ .../cairo1_mpcheck_circuits.py | 7 +- .../common_cairo_fustat_circuits.py | 21 --- .../compilable_circuits/isogeny.py | 1 - 6 files changed, 197 insertions(+), 114 deletions(-) diff --git a/hydra/garaga/modulo_circuit_structs.py b/hydra/garaga/modulo_circuit_structs.py index 5131e9e4..6c9097c2 100644 --- a/hydra/garaga/modulo_circuit_structs.py +++ b/hydra/garaga/modulo_circuit_structs.py @@ -453,8 +453,7 @@ def struct_name(self) -> str: def extract_from_circuit_output( self, offset_to_reference_map: dict[int, str] ) -> str: - assert len(self.elmts) == 1 - return f"let {self.name}:{self.struct_name} = array![{','.join([f'outputs.get_output({offset_to_reference_map[elmt.offset]})' for elmt in self.elmts])}];" + return f"let {self.name} = array![{','.join([f'outputs.get_output({offset_to_reference_map[elmt.offset]})' for elmt in self.elmts])}];" def dump_to_circuit_input(self) -> str: bits = self.bits @@ -570,7 +569,6 @@ def struct_name(self) -> str: def extract_from_circuit_output( self, offset_to_reference_map: dict[int, str] ) -> str: - assert len(self.elmts) == 1 return f"let {self.name}:{self.struct_name} = array![{','.join([f'outputs.get_output({offset_to_reference_map[elmt.offset]})' for elmt in self.elmts])}].span();" def dump_to_circuit_input(self) -> str: diff --git a/hydra/garaga/precompiled_circuits/all_circuits.py b/hydra/garaga/precompiled_circuits/all_circuits.py index 65e6236a..27832e64 100644 --- a/hydra/garaga/precompiled_circuits/all_circuits.py +++ b/hydra/garaga/precompiled_circuits/all_circuits.py @@ -5,6 +5,7 @@ cairo1_tests_header, compilation_mode_to_file_header, compile_circuit, + create_cairo1_test, format_cairo_files_in_parallel, ) from garaga.precompiled_circuits.compilable_circuits.cairo1_mpcheck_circuits import ( @@ -228,72 +229,173 @@ class CircuitID(Enum): } -def main( - PRECOMPILED_CIRCUITS_DIR: str, - CIRCUITS_TO_COMPILE: dict[CircuitID, dict], - compilation_mode: int = 1, -): - """Compiles and writes all circuits to .cairo files""" +def initialize_compilation( + PRECOMPILED_CIRCUITS_DIR: str, CIRCUITS_TO_COMPILE: dict +) -> tuple[ + set[str], + dict[str, set[str]], + dict[str, set[str]], + dict[str, set[str]], + dict[str, open], +]: + """ + Initialize the compilation process by creating the necessary directories and files. + Returns : + - filenames_used: set of all filenames that will be used + - codes: dict of sets of strings, where each set contains the compiled circuits for a given filename + - cairo1_tests_functions: dict of sets of strings, where each set contains the cairo1 tests for a given filename + - cairo1_full_function_names: dict of sets of strings, where each set contains the full function names for a given filename + - files: dict of open files, where each file is for a given filename + """ create_directory(PRECOMPILED_CIRCUITS_DIR) - # Ensure the 'codes' dict keys match the filenames used for file creation. - # Using sets to remove potential duplicates filenames_used = set([v["filename"] for v in CIRCUITS_TO_COMPILE.values()]) codes = {filename: set() for filename in filenames_used} cairo1_tests_functions = {filename: set() for filename in filenames_used} cairo1_full_function_names = {filename: set() for filename in filenames_used} - files = { f: open(f"{PRECOMPILED_CIRCUITS_DIR}{f}.cairo", "w") for f in filenames_used } + return ( + filenames_used, + codes, + cairo1_tests_functions, + cairo1_full_function_names, + files, + ) - # Write the header to each file - HEADER = compilation_mode_to_file_header(compilation_mode) +def write_headers(files: dict[str, open], compilation_mode: int) -> None: + """ + Write the header to the files. + """ + HEADER = compilation_mode_to_file_header(compilation_mode) for file in files.values(): file.write(HEADER) - # Instantiate and compile circuits for each curve +def compile_circuits( + CIRCUITS_TO_COMPILE: dict, + compilation_mode: int, + codes: dict[str, set[str]], + cairo1_full_function_names: dict[str, set[str]], + cairo1_tests_functions: dict[str, set[str]], +) -> None: + """ + Compile the circuits and write them to the files. + """ for circuit_id, circuit_info in CIRCUITS_TO_COMPILE.items(): for curve_id in circuit_info.get( "curve_ids", [CurveID.BN254, CurveID.BLS12_381] ): filename_key = circuit_info["filename"] - compiled_circuits, full_function_names = compile_circuit( + compiled_circuits, full_function_names, circuit_instances = compile_circuit( curve_id, circuit_info["class"], circuit_info["params"], compilation_mode, - cairo1_tests_functions, filename_key, ) codes[filename_key].update(compiled_circuits) if compilation_mode == 1: - cairo1_full_function_names[filename_key].update(full_function_names) + generate_cairo1_tests( + circuit_instances, + full_function_names, + curve_id, + cairo1_tests_functions, + filename_key, + ) + + +def generate_cairo1_tests( + circuit_instances, + full_function_names, + curve_id, + cairo1_tests_functions, + filename_key, +): + for circuit_instance, full_function_name in zip( + circuit_instances, full_function_names + ): + circuit_input = circuit_instance.full_input_cairo1 + circuit_output = ( + circuit_instance.circuit.output_structs + if sum([len(x.elmts) for x in circuit_instance.circuit.output_structs]) + == len(circuit_instance.circuit.output) + else circuit_instance.circuit.output + ) + cairo1_tests_functions[filename_key].add( + create_cairo1_test( + full_function_name, + circuit_input, + circuit_output, + curve_id.value, + ) + ) - # Write selector functions and compiled circuit codes to their respective files + +def write_compiled_circuits( + files: dict[str, open], + codes: dict[str, set[str]], + cairo1_full_function_names: dict[str, set[str]], + cairo1_tests_functions: dict[str, set[str]], + compilation_mode: int, +) -> None: + """ + Write the compiled circuits and the cairo1 tests to the files. + """ print("Writing circuits and selectors to .cairo files...") - for filename in filenames_used: - if filename in files: - # Write the compiled circuit codes - for compiled_circuit in sorted(codes[filename]): - files[filename].write(compiled_circuit + "\n") + for filename, file in files.items(): + for compiled_circuit in sorted(codes[filename]): + file.write(compiled_circuit + "\n") + + if compilation_mode == 1: + write_cairo1_tests( + file, filename, cairo1_full_function_names, cairo1_tests_functions + ) - if compilation_mode == 1: - files[filename].write(cairo1_tests_header() + "\n") - fns_to_import = sorted(cairo1_full_function_names[filename]) - if "" in fns_to_import: - fns_to_import.remove("") - files[filename].write(f"use super::{{{','.join(fns_to_import)}}};\n") - for cairo1_test in sorted(cairo1_tests_functions[filename]): - files[filename].write(cairo1_test + "\n") - files[filename].write("}\n") - else: - print(f"Warning: No file associated with filename '{filename}'") +def write_cairo1_tests( + file: open, + filename: str, + cairo1_full_function_names: dict[str, set[str]], + cairo1_tests_functions: dict[str, set[str]], +) -> None: + file.write(cairo1_tests_header() + "\n") + fns_to_import = sorted(cairo1_full_function_names[filename]) + if "" in fns_to_import: + fns_to_import.remove("") + file.write(f"use super::{{{','.join(fns_to_import)}}};\n") + for cairo1_test in sorted(cairo1_tests_functions[filename]): + file.write(cairo1_test + "\n") + file.write("}\n") + + +def main( + PRECOMPILED_CIRCUITS_DIR: str, + CIRCUITS_TO_COMPILE: dict[CircuitID, dict], + compilation_mode: int = 1, +): + """Compiles and writes all circuits to .cairo files""" + filenames_used, codes, cairo1_tests_functions, cairo1_full_function_names, files = ( + initialize_compilation(PRECOMPILED_CIRCUITS_DIR, CIRCUITS_TO_COMPILE) + ) + write_headers(files, compilation_mode) + compile_circuits( + CIRCUITS_TO_COMPILE, + compilation_mode, + codes, + cairo1_full_function_names, + cairo1_tests_functions, + ) + write_compiled_circuits( + files, + codes, + cairo1_full_function_names, + cairo1_tests_functions, + compilation_mode, + ) - # Close all files for file in files.values(): file.close() diff --git a/hydra/garaga/precompiled_circuits/compilable_circuits/base.py b/hydra/garaga/precompiled_circuits/compilable_circuits/base.py index 611087a2..95c2d2e3 100644 --- a/hydra/garaga/precompiled_circuits/compilable_circuits/base.py +++ b/hydra/garaga/precompiled_circuits/compilable_circuits/base.py @@ -2,6 +2,7 @@ import subprocess from abc import ABC, abstractmethod from concurrent.futures import ProcessPoolExecutor +from typing import Type from garaga.definitions import CurveID, get_base_field from garaga.hints.io import int_array_to_u384_array @@ -14,9 +15,6 @@ class BaseModuloCircuit(ABC): Base class for all modulo circuits that will be compiled to Cairo code. Parameters: - name: str, the name of the circuit - - input_len: int, the number of input elements. - The actual number set here is not always used, - - curve_id: int, the id of the curve - auto_run: bool, whether to run the circuit automatically at initialization. When compiling, this flag is set to true so the ModuloCircuit class inside the @@ -31,7 +29,6 @@ class BaseModuloCircuit(ABC): def __init__( self, name: str, - input_len: int, curve_id: int, auto_run: bool = True, compilation_mode: int = 0, @@ -39,7 +36,6 @@ def __init__( self.name = name self.curve_id = curve_id self.field = get_base_field(curve_id) - self.input_len = input_len self.init_hash = None self.generic_over_curve = False self.compilation_mode = compilation_mode @@ -100,13 +96,12 @@ class BaseEXTFCircuit(BaseModuloCircuit): def __init__( self, name: str, - input_len: int, curve_id: int, auto_run: bool = True, init_hash: int = None, compilation_mode: int = 0, ): - super().__init__(name, input_len, curve_id, auto_run, compilation_mode) + super().__init__(name, curve_id, auto_run, compilation_mode) self.init_hash = init_hash @@ -154,68 +149,81 @@ def to_snake_case(s: str) -> str: return re.sub(r"(?<=[a-z])(?=[A-Z])|[^a-zA-Z0-9]", "_", s).lower() -def compile_circuit( +def create_circuit_instances( + circuit_class: Type[BaseModuloCircuit], curve_id: CurveID, - circuit_class: BaseModuloCircuit, params: list[dict], compilation_mode: int, - cairo1_tests_functions: dict[str, set[str]], - filename_key: str, -) -> tuple[list[str], str]: - # print( - # f"Compiling {curve_id.name}:{circuit_class.__name__} {f'with params {params}' if params else ''}..." - # ) - - circuits: list[BaseModuloCircuit] = [] - compiled_circuits: list[str] = [] - full_function_names: list[str] = [] - +) -> list[BaseModuloCircuit]: + """ + Create a list of circuit instances from a given circuit class, curve id, params and compilation mode. + """ + circuits = [] if params is None: circuit_instance = circuit_class( - curve_id=curve_id.value, compilation_mode=compilation_mode + curve_id=curve_id.value, compilation_mode=compilation_mode, auto_run=True ) circuits.append(circuit_instance) else: for param in params: circuit_instance = circuit_class( - curve_id=curve_id.value, compilation_mode=compilation_mode, **param + curve_id=curve_id.value, + compilation_mode=compilation_mode, + auto_run=True, + **param, ) circuits.append(circuit_instance) + return circuits - for i, circuit_instance in enumerate(circuits): - function_name = ( - f"{circuit_instance.name.upper()}" - if circuit_instance.circuit.generic_circuit - else f"{curve_id.name}_{circuit_instance.name.upper()}" - ) - compiled_circuit, full_function_name = circuit_instance.circuit.compile_circuit( - function_name=function_name - ) - compiled_circuits.append(compiled_circuit) +def compile_single_circuit( + circuit_instance: BaseModuloCircuit, +) -> tuple[ModuloCircuit, str]: + """ + Compile a single circuit instance to Cairo code. + Returns the compiled circuit and the full function name. + """ + curve_id = CurveID(circuit_instance.curve_id) + function_name = ( + f"{circuit_instance.name.upper()}" + if circuit_instance.circuit.generic_circuit + else f"{curve_id.name}_{circuit_instance.name.upper()}" + ) + compiled_circuit, full_function_name = circuit_instance.circuit.compile_circuit( + function_name=function_name + ) + return compiled_circuit, full_function_name - if compilation_mode == 1: - circuit_input = circuit_instance.full_input_cairo1 - if sum( - [len(x.elmts) for x in circuit_instance.circuit.output_structs] - ) == len(circuit_instance.circuit.output): - circuit_output = circuit_instance.circuit.output_structs - else: - circuit_output = circuit_instance.circuit.output - full_function_names.append(full_function_name) - cairo1_tests_functions[filename_key].add( - write_cairo1_test( - full_function_name, - circuit_input, - circuit_output, - curve_id.value, - ) - ) - return compiled_circuits, full_function_names +def compile_circuit( + curve_id: CurveID, + circuit_class: BaseModuloCircuit, + params: list[dict], + compilation_mode: int, + filename_key: str, +) -> tuple[list[str], list[str], list[BaseModuloCircuit]]: + """ + Compile a list of circuit instances to Cairo code. + Returns : + - compiled_circuits: list of compiled circuits as strings + - full_function_names: list of full function names as strings + - circuit_instances: list of circuit instances that have been compiled + """ + circuits = create_circuit_instances( + circuit_class, curve_id, params, compilation_mode + ) + compiled_circuits = [] + full_function_names = [] + + for circuit_instance in circuits: + compiled_circuit, full_function_name = compile_single_circuit(circuit_instance) + compiled_circuits.append(compiled_circuit) + full_function_names.append(full_function_name) + + return compiled_circuits, full_function_names, circuits -def write_cairo1_test(function_name: str, input: list, output: list, curve_id: int): +def create_cairo1_test(function_name: str, input: list, output: list, curve_id: int): return "" if function_name == "": # print(f"passing test") diff --git a/hydra/garaga/precompiled_circuits/compilable_circuits/cairo1_mpcheck_circuits.py b/hydra/garaga/precompiled_circuits/compilable_circuits/cairo1_mpcheck_circuits.py index 69c0ab27..51fc33c0 100644 --- a/hydra/garaga/precompiled_circuits/compilable_circuits/cairo1_mpcheck_circuits.py +++ b/hydra/garaga/precompiled_circuits/compilable_circuits/cairo1_mpcheck_circuits.py @@ -82,7 +82,6 @@ def __init__( self.n_fixed_g2 = n_fixed_g2 super().__init__( name=name, - input_len=None, curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -988,7 +987,7 @@ def __init__( compilation_mode: int = 0, ): super().__init__( - "fp12_mul_assert_one", None, curve_id, auto_run, init_hash, compilation_mode + "fp12_mul_assert_one", curve_id, auto_run, init_hash, compilation_mode ) def build_input(self) -> list[PyFelt]: @@ -1048,9 +1047,7 @@ def __init__( init_hash: int = None, compilation_mode: int = 0, ): - super().__init__( - "eval_e12d", None, curve_id, auto_run, init_hash, compilation_mode - ) + super().__init__("eval_e12d", curve_id, auto_run, init_hash, compilation_mode) def build_input(self) -> list[PyFelt]: input = [] diff --git a/hydra/garaga/precompiled_circuits/compilable_circuits/common_cairo_fustat_circuits.py b/hydra/garaga/precompiled_circuits/compilable_circuits/common_cairo_fustat_circuits.py index 8f5e4182..deb8b737 100644 --- a/hydra/garaga/precompiled_circuits/compilable_circuits/common_cairo_fustat_circuits.py +++ b/hydra/garaga/precompiled_circuits/compilable_circuits/common_cairo_fustat_circuits.py @@ -20,7 +20,6 @@ def __init__( ) -> None: super().__init__( name="dummy", - input_len=2, curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -52,7 +51,6 @@ class IsOnCurveG1G2Circuit(BaseModuloCircuit): def __init__(self, curve_id: int, auto_run: bool = True, compilation_mode: int = 0): super().__init__( name="is_on_curve_g1_g2", - input_len=(2 + 4), curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -101,7 +99,6 @@ class IsOnCurveG1Circuit(BaseModuloCircuit): def __init__(self, curve_id: int, auto_run: bool = True, compilation_mode: int = 0): super().__init__( name="is_on_curve_g1", - input_len=(2 + 1), curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -135,7 +132,6 @@ class IsOnCurveG2Circuit(BaseModuloCircuit): def __init__(self, curve_id: int, auto_run: bool = True, compilation_mode: int = 0): super().__init__( name="is_on_curve_g2", - input_len=(2 + 1), curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -178,7 +174,6 @@ def __init__( ) -> None: super().__init__( name="slope_intercept_same_point", - input_len=3, # P(Px, Py), A in y^2 = x^3 + Ax + B curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -221,7 +216,6 @@ def __init__( ) -> None: super().__init__( name="acc_eval_point_challenge_signed", - input_len=8, # Eval_Accumulator + (m,b) + xA + (Px, Py) + ep + en curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -279,7 +273,6 @@ def __init__( ) -> None: super().__init__( name="rhs_finalize_acc", - input_len=6, # Eval_Accumulator + (m,b) + xA + (Qx, Qy) curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -329,15 +322,6 @@ def __init__( self.n_points = n_points super().__init__( name=f"eval_fn_challenge_dupl_{n_points}P", - input_len=( - (2 + 2 + 2) # 2 EC challenge points (x,y) + 2 coefficients - + ( # F=a(x) + y b(x). - (1 + n_points) # Number of coefficients in a's numerator - + (1 + n_points + 1) # Number of coefficients in a's denominator - + (1 + n_points + 1) # Number of coefficients in b's numerator - + (1 + n_points + 4) # Number of coefficients in b's denominator - ) - ), curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -435,7 +419,6 @@ def __init__( self.n_points = n_points super().__init__( name=f"init_fn_challenge_dupl_{n_points}P", - input_len=None, curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -519,7 +502,6 @@ def __init__( ): super().__init__( name="acc_function_challenge_dupl", - input_len=None, curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -600,7 +582,6 @@ def __init__( ): super().__init__( name="finalize_fn_challenge_dupl", - input_len=None, curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -645,7 +626,6 @@ def __init__( ): super().__init__( name="add_ec_point", - input_len=4, # xP, yP, xQ, yQ curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, @@ -682,7 +662,6 @@ def __init__( ): super().__init__( name="double_ec_point", - input_len=3, # xP, yP, A curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode, diff --git a/hydra/garaga/precompiled_circuits/compilable_circuits/isogeny.py b/hydra/garaga/precompiled_circuits/compilable_circuits/isogeny.py index 41409a57..5912f87b 100644 --- a/hydra/garaga/precompiled_circuits/compilable_circuits/isogeny.py +++ b/hydra/garaga/precompiled_circuits/compilable_circuits/isogeny.py @@ -14,7 +14,6 @@ def __init__( ) -> None: super().__init__( name=f"apply_isogeny_{CurveID(curve_id).name.lower()}", - input_len=2, curve_id=curve_id, auto_run=auto_run, compilation_mode=compilation_mode,