Skip to content

Commit

Permalink
WIP remote debugging session
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Oct 16, 2024
1 parent f72c84f commit 736a924
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/func_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def func_to_past(inp: DSL_PRG) -> PRG:
)


def func_to_past_factory(cached: bool = False) -> workflow.Workflow[DSL_PRG, PRG]:
def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PRG]:
"""
Wrap `func_to_past` in a chainable and optionally cached workflow step.
Expand Down
5 changes: 5 additions & 0 deletions src/gt4py/next/ffront/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No

@add_content_to_fingerprint.register(FieldOperatorDefinition)
@add_content_to_fingerprint.register(FoastOperatorDefinition)
@add_content_to_fingerprint.register(ProgramDefinition)
@add_content_to_fingerprint.register(PastProgramDefinition)
@add_content_to_fingerprint.register(toolchain.CompilableProgram)
@add_content_to_fingerprint.register(arguments.CompileTimeArgs)
Expand All @@ -121,6 +122,10 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo
for item in sourcedef:
add_content_to_fingerprint(item, hasher)

closure_vars = source_utils.get_closure_vars_from_function(obj)
for item in sorted(closure_vars.items(), key=lambda x: x[0]):
add_content_to_fingerprint(item, hasher)


@add_content_to_fingerprint.register
def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -825,10 +825,16 @@ def simple_scan_operator(
) -> tuple[int32, tuple[int32, int32]]:
return (carry[0] + 1, (carry[1][0] + 1, carry[1][1] + 1))

@gtx.program
# @gtx.program
def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]):
simple_scan_operator(out=out)

print(type(testee))
dsl_definition = gtx.ffront.stages.ProgramDefinition(definition=testee)
print(type(dsl_definition))
print(f"{dsl_definition=}")
past_definition = gtx.ffront.func_to_past.func_to_past(dsl_definition)

cases.verify_with_default_data(
cartesian_case,
testee,
Expand Down

0 comments on commit 736a924

Please sign in to comment.