diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index f415c95b63..09f53be600 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -64,7 +64,7 @@ def func_to_past(inp: DSL_PRG) -> PRG: ) -def func_to_past_factory(cached: bool = False) -> workflow.Workflow[DSL_PRG, PRG]: +def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PRG]: """ Wrap `func_to_past` in a chainable and optionally cached workflow step. diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index bf3bee4b56..fed02da305 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -100,6 +100,7 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No @add_content_to_fingerprint.register(FieldOperatorDefinition) @add_content_to_fingerprint.register(FoastOperatorDefinition) +@add_content_to_fingerprint.register(ProgramDefinition) @add_content_to_fingerprint.register(PastProgramDefinition) @add_content_to_fingerprint.register(toolchain.CompilableProgram) @add_content_to_fingerprint.register(arguments.CompileTimeArgs) @@ -121,6 +122,10 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo for item in sourcedef: add_content_to_fingerprint(item, hasher) + closure_vars = source_utils.get_closure_vars_from_function(obj) + for item in sorted(closure_vars.items(), key=lambda x: x[0]): + add_content_to_fingerprint(item, hasher) + @add_content_to_fingerprint.register def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 36d6debf9d..b91527b492 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -825,10 +825,16 @@ def simple_scan_operator( ) -> tuple[int32, tuple[int32, int32]]: return (carry[0] + 1, (carry[1][0] + 1, carry[1][1] + 1)) - @gtx.program + # @gtx.program def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): simple_scan_operator(out=out) + print(type(testee)) + dsl_definition = gtx.ffront.stages.ProgramDefinition(definition=testee) + print(type(dsl_definition)) + print(f"{dsl_definition=}") + past_definition = gtx.ffront.func_to_past.func_to_past(dsl_definition) + cases.verify_with_default_data( cartesian_case, testee,