From 32dde792bde505807a5729261e4f1d12a1451bdb Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Tue, 21 May 2024 10:07:27 +0200 Subject: [PATCH] refactor[next]: cleaner toolchain (#1537) ## Changed: Toolchain does not own input arguments to DSL programs anymore, instead the input datastructure owns them and the toolchain can dispatch them to a subset of the steps. ## Toolchain migration: Old: ```python FieldopTransformWorkflow().replace(foast_inject_args=FopArgsInjector(*args, **kwargs))(fieldoperator_definition) ProgramTransformWorkflow().replace(program_inject_args=ProgArgsInjector(*args, **kwargs))(program_definition) ``` New: ```python FieldopTransformWorkflow()(InputWithArgs(fieldoperator_definition, args, kwargs)) ProgramTransformWorkflow()(InputWithArgs(program_definition, args, kwargs)) ``` ## Added: - `otf.workflow`: - new workflow type: `NamedStepSequenceWithArgs` takes an `InputWithArgs` and dispatches `.args` and `.kwargs` to steps that set `take_args = True` in the field metadata - new data type `InputWithArgs` wraps a workflow stage and call args - `backend`: Replace `*ArgsInjector` using the new `NamedStepSequenceWithArgs` infrastructure --- docs/user/next/advanced/HackTheToolchain.md | 129 +++++ .../ToolchainWalkthrough.md} | 286 ++++++++-- docs/user/next/advanced/WorkflowPatterns.md | 492 ++++++++++++++++++ src/gt4py/next/backend.py | 84 ++- src/gt4py/next/otf/workflow.py | 23 + .../unit_tests/ffront_tests/test_stages.py | 76 ++- tox.ini | 4 +- 7 files changed, 967 insertions(+), 127 deletions(-) create mode 100644 docs/user/next/advanced/HackTheToolchain.md rename docs/user/next/{Advanced_ToolchainWalkthrough.md => advanced/ToolchainWalkthrough.md} (64%) create mode 100644 docs/user/next/advanced/WorkflowPatterns.md diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md new file mode 100644 index 0000000000..70681796ee --- /dev/null +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -0,0 +1,129 @@ +```python +import dataclasses +import typing + +from gt4py import next as gtx +from gt4py.next.otf import workflow +from gt4py import eve +``` + + + + +## Replace Steps + +```python +cached_lowering_toolchain = gtx.backend.DEFAULT_PROG_TRANSFORMS.replace( + past_to_itir=workflow.CachedStep( + step=gtx.ffront.past_to_itir.PastToItirFactory(), + hash_function=eve.utils.content_hash + ) +) +``` + +## Skip Steps / Change Order + +```python +gtx.backend.DEFAULT_PROG_TRANSFORMS.step_order +``` + + ['func_to_past', + 'past_lint', + 'past_inject_args', + 'past_transform_args', + 'past_to_itir'] + +```python +@dataclasses.dataclass(frozen=True) +class SkipLinting(gtx.backend.ProgramTransformWorkflow): + @property + def step_order(self): + return [ + "func_to_past", + # not running "past_lint" + "past_inject_args", + "past_transform_args", + "past_to_itir", + ] + +same_steps = dataclasses.asdict(gtx.backend.DEFAULT_PROG_TRANSFORMS) +skip_linting_transforms = SkipLinting( + **same_steps +) +``` + +## Alternative Factory + +```python +class MyCodeGen: + ... + +class Cpp2BindingsGen: + ... 
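+
+# Note: MyCodeGen and Cpp2BindingsGen are only placeholders; real replacements would have
+# to satisfy the workflow.Workflow protocol for the stage types named in the annotations
+# below (ProgramCall -> ProgramSource and ProgramSource -> CompilableSource).
+# Redeclaring these attributes on the factory subclass overrides the parent factory's defaults.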
+ +class PureCpp2WorkflowFactory(gtx.program_processors.runners.gtfn.GTFNCompileWorkflowFactory): + translation: workflow.Workflow[ + gtx.otf.stages.ProgramCall, gtx.otf.stages.ProgramSource] = MyCodeGen() + bindings: workflow.Workflow[ + gtx.otf.stages.ProgramSource, gtx.otf.stages.CompilableSource] = Cpp2BindingsGen() + +PureCpp2WorkflowFactory(cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG) +``` + +## Invent new Workflow Types + +````mermaid +graph LR + +IN_T --> i{{split}} --> A_T --> a{{track_a}} --> B_T --> o{{combine}} --> OUT_T +i --> X_T --> x{{track_x}} --> Y_T --> o + + +```python +IN_T = typing.TypeVar("IN_T") +A_T = typing.TypeVar("A_T") +B_T = typing.TypeVar("B_T") +X_T = typing.TypeVar("X_T") +Y_T = typing.TypeVar("Y_T") +OUT_T = typing.TypeVar("OUT_T") + +@dataclasses.dataclass(frozen=True) +class FullyModularDiamond( + workflow.ChainableWorkflowMixin[IN_T, OUT_T], + workflow.ReplaceEnabledWorkflowMixin[IN_T, OUT_T], + typing.Protocol[IN_T, OUT_T, A_T, B_T, X_T, Y_T] +): + split: workflow.Workflow[IN_T, tuple[A_T, X_T]] + track_a: workflow.Workflow[A_T, B_T] + track_x: workflow.Workflow[X_T, Y_T] + combine: workflow.Workflow[tuple[B_T, Y_T], OUT_T] + + def __call__(self, inp: IN_T) -> OUT_T: + a, x = self.split(inp) + b = self.track_a(a) + y = self.track_x(x) + return self.combine((b, y)) + + +@dataclasses.dataclass(frozen=True) +class PartiallyModularDiamond( + workflow.ChainableWorkflowMixin[IN_T, OUT_T], + workflow.ReplaceEnabledWorkflowMixin[IN_T, OUT_T], + typing.Protocol[IN_T, OUT_T, A_T, B_T, X_T, Y_T] +): + track_a: workflow.Workflow[A_T, B_T] + track_x: workflow.Workflow[X_T, Y_T] + + def split(inp: IN_T) -> tuple[A_T, X_T]: + ... + + def combine(b: B_T, y: Y_T) -> OUT_T: + ... + + def __call__(inp: IN_T) -> OUT_T: + a, x = self.split(inp) + return self.combine( + b=self.track_a(a), + y=self.track_x(x) + ) +```` diff --git a/docs/user/next/Advanced_ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md similarity index 64% rename from docs/user/next/Advanced_ToolchainWalkthrough.md rename to docs/user/next/advanced/ToolchainWalkthrough.md index 94a7bfa7e2..d44663a72c 100644 --- a/docs/user/next/Advanced_ToolchainWalkthrough.md +++ b/docs/user/next/advanced/ToolchainWalkthrough.md @@ -24,14 +24,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta ``` # Walkthrough from Field Operator @@ -71,14 +83,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| 
pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style fdef fill:red style foast fill:red @@ -114,14 +138,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style foast fill:red style itir_expr fill:red @@ -147,34 +183,53 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) -style foast fill:red +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta + +style foasta fill:red style fclos fill:red linkStyle 2 stroke:red,stroke-width:4px,color:pink ``` -Here we have to dynamically generate a workflow step, because the arguments were not known before. +Here we have to manually combine the previous result with the call arguments. When we call the toolchain as a whole later we will only have to do this once at the beginning. 
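+
+`InputWithArgs` is a small container defined in `gt4py.next.otf.workflow` that bundles a workflow stage (`data`) with the call arguments (`args`, `kwargs`); steps that set `takes_args` in their field metadata receive the wrapped input, all other steps receive the unwrapped stage.
+
+```python
+gtx.otf.workflow.InputWithArgs??
+```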
+ +```python +fclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_foast_closure( + gtx.otf.workflow.InputWithArgs( + data=foast, + args=(gtx.ones(domain={I: 10}, dtype=gtx.float64),), + kwargs={ + "out": gtx.zeros(domain={I: 10}, dtype=gtx.float64), + "from_fieldop": example_fo + }, + ) +) +``` ```python -fclos = backend.DEFAULT_FIELDOP_TRANSFORMS.foast_inject_args.__class__( - args=(gtx.ones(domain={I: 10}, dtype=gtx.float64),), - kwargs={ - "out": gtx.zeros(domain={I: 10}, dtype=gtx.float64) - }, - from_fieldop=example_fo -)(foast) +fclos.closure_vars["example_fo"].backend ``` ```python -gtx.ffront.stages.FoastClosure? +gtx.ffront.stages.FoastClosure?? ``` Init signature: @@ -185,6 +240,13 @@ gtx.ffront.stages.FoastClosure?  closure_vars: 'dict[str, Any]', ) -> None Docstring: FoastClosure(foast_op_def: 'FoastOperatorDefinition[OperatorNodeT]', args: 'tuple[Any, ...]', kwargs: 'dict[str, Any]', closure_vars: 'dict[str, Any]') + Source: + @dataclasses.dataclass(frozen=True) + class FoastClosure(Generic[OperatorNodeT]): +  foast_op_def: FoastOperatorDefinition[OperatorNodeT] +  args: tuple[Any, ...] +  kwargs: dict[str, Any] +  closure_vars: dict[str, Any] File: ~/Code/gt4py/src/gt4py/next/ffront/stages.py Type: type Subclasses: @@ -198,14 +260,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style fclos fill:red style pclos fill:red @@ -242,14 +316,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style pclos fill:red %%style pclos fill:red @@ -260,6 +346,12 @@ linkStyle 4 stroke:red,stroke-width:4px,color:pink pclost = backend.DEFAULT_PROG_TRANSFORMS.past_transform_args(pclos) ``` +```python +pclost.kwargs +``` + + {} + ## Lower PAST -> ITIR still forwarding the call arguments @@ -269,14 +361,26 @@ graph LR fdef(FieldOperatorDefinition) 
-->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style pclos fill:red style pcall fill:red @@ -326,30 +430,46 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta + +style fdefa fill:red +style fuwr fill:red style fdef fill:red +style fargs fill:red style foast fill:red +style fiwr fill:red +style foasta fill:red style fclos fill:red style pclos fill:red style pcall fill:red -linkStyle 0,2,3,4,5 stroke:red,stroke-width:4px,color:pink +linkStyle 0,2,3,4,5,9,10,11,12,13,14 stroke:red,stroke-width:4px,color:pink ``` ### Starting from DSL ```python -foast_toolchain = backend.DEFAULT_FIELDOP_TRANSFORMS.replace( - foast_inject_args=backend.FopArgsInjector(args=fclos.args, kwargs=fclos.kwargs, from_fieldop=example_fo) +pitir2 = backend.DEFAULT_FIELDOP_TRANSFORMS( + gtx.otf.workflow.InputWithArgs(data=start, args=fclos.args, kwargs=fclos.kwargs | {"from_fieldop": example_fo}) ) -pitir2 = foast_toolchain(start) assert pitir2 == pitir ``` @@ -365,22 +485,39 @@ example_compiled = gtx.program_processors.runners.roundtrip.executor.otf_workflo example_compiled(*pitir2.args, offset_provider=OFFSET_PROVIDER) ``` +We can re-run with the output from the previous run as in- and output. 
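+Note that `pitir2.args` holds the flattened call arguments, i.e. the input field, the `out` field and then the domain size arguments (see the cells below), so passing `pitir2.args[1]` in the first position feeds the previous output back in as the new input.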
+ ```python example_compiled(pitir2.args[1], *pitir2.args[1:], offset_provider=OFFSET_PROVIDER) ``` ```python -pitir2.args[1].asnumpy() +pitir2.args[2] ``` - array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.]) + 10 + +```python +pitir.args +``` + + (NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), + NumPyArrayField(_domain=Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(0, 10),)), _ndarray=array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3.])), + 10, + 10) ### Starting from FOAST Note that it is the exact same call but with a different input stage ```python -pitir3 = foast_toolchain(foast) +pitir3 = backend.DEFAULT_FIELDOP_TRANSFORMS( + gtx.otf.workflow.InputWithArgs( + data=foast, + args=fclos.args, + kwargs=fclos.kwargs | {"from_fieldop": example_fo} + ) +) assert pitir3 == pitir ``` @@ -419,14 +556,26 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta style pdef fill:red style past fill:red @@ -444,27 +593,40 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) -style past fill:red +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta + +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta + +style pasta fill:red style pclos fill:red -linkStyle 7 stroke:red,stroke-width:4px,color:pink +linkStyle 8 stroke:red,stroke-width:4px,color:pink ``` ```python -pclos = backend.DEFAULT_PROG_TRANSFORMS.replace( - past_inject_args=backend.ProgArgsInjector( +pclos = backend.DEFAULT_PROG_TRANSFORMS( + gtx.otf.workflow.InputWithArgs( + data=p_past, args=fclos.args, kwargs=fclos.kwargs ) -)(p_past) +) ``` ## Full Program Toolchain @@ -474,27 +636,45 @@ graph LR fdef(FieldOperatorDefinition) -->|func_to_foast| foast(FoastOperatorDefinition) foast -->|foast_to_itir| itir_expr(itir.Expr) -foast -->|foast_inject_args| fclos(FoastClosure) +foasta -->|foast_to_foast_closure| 
fclos(FoastClosure) fclos -->|foast_to_past_closure| pclos(PastClosure) pclos -->|past_process_args| pclos pclos -->|past_to_itir| pcall(ProgramCall) pdef(ProgramDefinition) -->|func_to_past| past(PastProgramDefinition) past -->|past_lint| past -past -->|past_inject_args| pclos(ProgramClosure) +pasta -->|past_to_past_closure| pclos(ProgramClosure) + +fdefa(InputWithArgs) --> fuwr{{"internal unwrapping"}} --> fdef +fuwr --> fargs(args, kwargs) + +foast --> fiwr{{"internal wrapping"}} --> foasta(InputWithArgs) +fargs --> foasta +pdefa(InputWithArgs) --> puwr{{"internal unwrapping"}} --> pdef +puwr --> pargs(args, kwargs) + +past --> piwr{{"internal wrapping"}} --> pasta(InputWithArgs) +pargs --> pasta + +style pdefa fill:red +style puwr fill:red style pdef fill:red +style pargs fill:red style past fill:red +style piwr fill:red +style pasta fill:red style pclos fill:red style pcall fill:red -linkStyle 4,5,6,7 stroke:red,stroke-width:4px,color:pink +linkStyle 4,5,6,7,8,15,16,17,18,19,20 stroke:red,stroke-width:4px,color:pink ``` ### Starting from DSL ```python -toolchain = backend.DEFAULT_PROG_TRANSFORMS.replace( - past_inject_args=backend.ProgArgsInjector( +p_itir1 = backend.DEFAULT_PROG_TRANSFORMS( + gtx.otf.workflow.InputWithArgs( + data=p_start, args=fclos.args, kwargs=fclos.kwargs ) @@ -502,11 +682,13 @@ toolchain = backend.DEFAULT_PROG_TRANSFORMS.replace( ``` ```python -p_itir1 = toolchain(p_start) -``` - -```python -p_itir2 = toolchain(p_past) +p_itir2 = backend.DEFAULT_PROG_TRANSFORMS( + gtx.otf.workflow.InputWithArgs( + data=p_past, + args=fclos.args, + kwargs=fclos.kwargs + ) +) ``` ```python diff --git a/docs/user/next/advanced/WorkflowPatterns.md b/docs/user/next/advanced/WorkflowPatterns.md new file mode 100644 index 0000000000..76880d86f0 --- /dev/null +++ b/docs/user/next/advanced/WorkflowPatterns.md @@ -0,0 +1,492 @@ +--- +jupyter: + jupytext: + formats: ipynb,md + text_representation: + extension: .md + format_name: markdown + format_version: "1.3" + jupytext_version: 1.16.1 + kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +```python editable=true slideshow={"slide_type": ""} +import dataclasses +import re + +import factory + +import gt4py.next as gtx + +import devtools +``` + + + +# How to read (toolchain) workflows + + + + + +## Basic workflow (single step) + +```mermaid +graph LR + +StageA -->|basic workflow| StageB +``` + +Where "Stage" describes any data structure, and where `StageA` contains all the input data and `StageB` contains all the output data. + + + + + +### Simplest possible + + + +```python editable=true slideshow={"slide_type": ""} +def simple_add_1(inp: int) -> int: + return inp + 1 + +simple_add_1(1) +``` + + + +This is already a (single step) workflow. We can build a more complex one by chaining it multiple times. + +```mermaid +graph LR + +inp(A: int) -->|simple_add_1| b(A + 1) -->|simple_add_1| c(A + 2) -->|simple_add_1| out(A + 3) +``` + + + +```python editable=true slideshow={"slide_type": ""} +manual_add_3 = gtx.otf.workflow.StepSequence.start( + simple_add_1 +).chain(simple_add_1).chain(simple_add_1) + +manual_add_3(1) +``` + + + +### Simplest Composable Step + +All we have to do for chaining to work out of the box is add the `make_step` decorator! 
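+`make_step` wraps the plain function in a workflow step object, so combinators like `.chain(...)` become available without writing a class by hand.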
+ + + +```python editable=true slideshow={"slide_type": ""} +@gtx.otf.workflow.make_step +def chainable_add_1(inp: int) -> int: + return inp + 1 +``` + +```python editable=true slideshow={"slide_type": ""} +add_3 = chainable_add_1.chain(chainable_add_1).chain(chainable_add_1) +add_3(1) +``` + +### Example in the Wild + +```python jupyter={"outputs_hidden": true} +gtx.ffront.func_to_past.func_to_past.steps.inner[0]?? +``` + + + +### Step with Parameters + +Sometimes we want to allow for different configurations of a step. + + + +```python editable=true slideshow={"slide_type": ""} +@dataclasses.dataclass(frozen=True) +class MathOp(gtx.otf.workflow.ChainableWorkflowMixin[int, int]): + op: str + rhs: int = 0 + + def __call__(self, inp: int) -> int: + return getattr(self, self.op)(inp, self.rhs) + + def add(self, lhs: int, rhs: int) -> int: + return lhs + rhs + + def mul(self, lhs: int, rhs: int) -> int: + return lhs * rhs + +add_3_times_2 = ( + MathOp("add", 3) + .chain(MathOp("mul", 2)) +) +add_3_times_2(1) +``` + +### Example in the Wild + +```python jupyter={"outputs_hidden": true} +gtx.program_processors.runners.roundtrip.Roundtrip?? +``` + + + +### Wrapper Steps + +Sometimes we want to make a step behave slightly differently without modifying the step itself. In this case we can wrap it into a wrapper step. These behave a little bit like (limited) decorators. +Below we will go through the existing wrapper steps, which you might encounter. + +#### Caching / memoizing + +For example we might want to cach the output (memoize) for which we need to add a way of hashing the input: + +```mermaid +graph LR + + +inp --> calc +inp(A: int) --> ha{{"hash_function(A)"}} --> h("hash(A)") --> ck{{"check cache"}} -->|miss| miss("not in cache") --> calc{{add_3_times_2}} --> out(result) +ck -->|hit| hit("in cache") --> out +``` + +For this we can use the `CachedStep`, you will see something like below + + + +```python editable=true slideshow={"slide_type": ""} +@gtx.otf.workflow.make_step +def debug_print(inp: int) -> int: + print("cache miss!") + return inp + +cached_calc = gtx.otf.workflow.CachedStep( + step=debug_print.chain(add_3_times_2), + hash_function=lambda i: str(i) # using ints as their own hash +) + +cached_calc(1) +cached_calc(1) +cached_calc(1) +``` + +### Example in the Wild + +```python jupyter={"outputs_hidden": true} +gtx.backend.DEFAULT_PROG_TRANSFORMS.past_lint?? +``` + + + +Though we execute the workflow three times we only get the debug print once, it worked! Btw, hashing is rarely that easy in the wild... + +#### Conditionally skipping steps + +The `SkippableStep` pattern can be used to skip a step under a given condition. A main use case is when you might want to run a workflow either from the start or from further along (with the same interface). + +Let's say we want to make our calculation workflow compatible with string input. We can add a conversion step (which only works with strings). + + + +```python editable=true slideshow={"slide_type": ""} +@gtx.otf.workflow.make_step +def to_int(inp: str) -> int: + assert isinstance(inp, str), "Can not work with 'int'!" # yes, this is horribly contrived + return int(inp) + +str_calc = to_int.chain(add_3_times_2) + +str_calc("1") +``` + + + +Now we can start from a string that contains an int. But if we already have an int, it will fail. + + + +```python editable=true slideshow={"slide_type": ""} +try: + str_calc(1) +except AssertionError as err: + print(err) +``` + + + +What to do? 
What we want is a to conditionally skip the first step, so we replace it with a `SkippableStep`: + +```python +class OptionalStrToInt(SkippableStep[str | int, int]): + step: Workflow[str, int] + + def skip_condition(self, inp: str | int) -> bool: + ... # return True to skip (if we get an int) or False to run the conversion (str case) + +``` + +```mermaid +graph LR + +int(A: int = 1) --> calc{{"add_3_times_2(1)"}} --> result(8) +int --> ski{{"skip_condition(1)"}} -->|True| calc +str("B: str = '1'") --> sks{{"skip_condition('1')"}} -->|False| conv{{to_int}} --> b2("int(B) = 1") --> calc +``` + + + +```python editable=true slideshow={"slide_type": ""} +@dataclasses.dataclass(frozen=True) +class OptionalStrToInt(gtx.otf.workflow.SkippableStep[str | int, int]): + step: gtx.otf.workflow.Workflow[str, int] = to_int + + def skip_condition(self, inp: str | int) -> bool: + match inp: + case int(): + return True + case str(): + return False + case _: + # optionally raise an error with good advice + return False + +strint_calc = OptionalStrToInt().chain(add_3_times_2) +strint_calc(1) == strint_calc("1") +``` + + + +### Example in the Wild + + + +```python jupyter={"outputs_hidden": true} editable=true slideshow={"slide_type": ""} +gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past?? +``` + + + +### Step with factory (builder) + +If a step can be useful with different combinations of parameters and wrappers, it should have a factory. In this case we will add a neutral wrapper around it, so we can put any combination of wrappers into that: + + + +```python editable=true slideshow={"slide_type": ""} +@dataclasses.dataclass(frozen=True) +class AnyStrToInt(gtx.otf.workflow.ChainableWorkflowMixin[str | int, int]): + inner_step: gtx.otf.workflow.Workflow[str, int] = to_int + + def __call__(self, inp: str | int) -> int: + return self.inner_step(inp) + + +class StrToIntFactory(factory.Factory): + class Meta: + model = AnyStrToInt + + class Params: + default_step = to_int + optional: bool = False + optional_or_not = factory.LazyAttribute(lambda o: OptionalStrToInt(step=o.default_step) if o.optional else o.default_step) + cached = factory.Trait( + inner_step = factory.LazyAttribute( + lambda o: gtx.otf.workflow.CachedStep(step=(o.optional_or_not), hash_function=str) + ) + ) + inner_step = factory.LazyAttribute(lambda o: o.optional_or_not) + +cached = StrToIntFactory(cached=True) +optional = StrToIntFactory(optional=True) +both = StrToIntFactory(cached=True, optional=True) +neither = StrToIntFactory() +neither.inner_step +``` + +### Example in the Wild + +```python +gtx.ffront.past_passes.linters.LinterFactory?? +``` + + + +## Composition 1: Chaining + +So far we have only seen compsition of workflows by chaining. Any sequence of steps can be represented as a chain. Chains can be built of smaller chains, so a Workflow could be composed and then reused in a bigger workflow. + +However, chains are of limited use in the real world, because it's a pain to access a specific step. This we might want to do in order to: + +- run that step in isolation for debugging or other purposes +- build a new chain with a step swapped out (workflows are immutable). 
+ +Imagine swapping out `sub_third` in `complicated_workflow` below (without copy pasting code): + +```python +complicated_workflow = ( + start_step + .chain(first_sub_first.chain(first_sub_second).chain(first_sub_third)) + .chain(second_sub_first.chain(second_sub_second)) + .chain(last) +) +``` + +```mermaid +graph TD +c{{complicated_workflow}} --> 0 --> s{{start_step}} +c --> 1 -->|0| a1{{first_sub_first}} +1 -->|1| a2{{first_sub_second}} +1 -->|2| a3{{first_sub_third}} +c --> 2 -->|0| b1{{second_sub_first}} +2 -->|1| b2{{second_sub_second}} +c --> 3 -->|0| l{{last}} +``` + + + + + +## Composition 2: Sequence of Named Steps + +Let's say we want a string processing workflow where the intermediate stages are also of value on their own. We would want to access individual steps, specifically each step as it was configured for this workflow (with parameters, caching, etc identical). + +For this we can use `NamedStepSequence`, giving each step a name, by which we can access it later. For this we have to create a dataclass and derive from `NamedStepSequence`. Each step is then a field of the dataclass, type hinted as a `Workflow`. The resulting workflow will run the steps in order of their apperance in the class body. + +To use the same "complicated workflow" example from above: + +```python +@dataclasses.dataclass(frozen=True) +class FirstSub(gtx.otf.workflow.NamedStepSequence[B, E]): + first: Workflow[B, C] + second: Workflow[C, D] + third: Workflow[D, E] + + +@dataclasses.dataclass(frozen=True) +class SecondSub(gtx.otf.workflow.NamedStepSequence[E, G]): + first: Workflow[E, F] + second: Workflow[F, G] + + +@dataclasses.dataclass(frozen=True) +class ComplicatedWorkflow(gtx.otf.workflow.NamedStepSequence[A, F]): + start_step: Workflow[A, B] + first_sub: Workflow[B, E] + second_sub: Workflow[E, G] + last: Workflow[G, F] + +complicated_workflow = ComplicatedWorkflow( + start_step=start_step, + first_sub=FirstSub( + first=first_sub_first, + second=first_sub_second, + third=first_sub_third + ), + second_sub=SecondSub( + first=second_sub_first, + second=second_sub_second + ), + last=last +) + +``` + +```mermaid +graph TD + +w{{complicated_workflow: ComplicatedWorkflow}} -->|".start_step"| a{{start_step}} +w -->|".first_sub.first"| b{{first_sub_first}} +w -->|".first_sub.second"| c{{first_sub_second}} +w -->|".first_sub.third"| d{{first_sub_third}} +w -->|".second_sub.first"| e{{second_sub_first}} +w -->|".second_sub_second"| f{{second_sub_second}} +w -->|".last"| g{{last}} +``` + + + +```python editable=true slideshow={"slide_type": ""} +## Here we define how the steps are composed +@dataclasses.dataclass(frozen=True) +class StrProcess(gtx.otf.workflow.NamedStepSequence): + hexify_colors: gtx.otf.workflow.Workflow[str, str] + replace_tabs: gtx.otf.workflow.Workflow[str, str] + + +## Here we define the steps themselves +@dataclasses.dataclass(frozen=True) +class HexifyColors(gtx.otf.workflow.ChainableWorkflowMixin): + color_scheme: dict[str, str] = dataclasses.field( + default_factory=lambda: {"blue": "#0000ff", "green": "#00ff00", "red": "#ff0000"} + ) + + def __call__(self, inp: str) -> str: + result = inp + for color, hexcode in self.color_scheme.items(): + result = result.replace(color, hexcode) + return result + + +def spaces_to_tabs(inp: str) -> str: + return re.sub(r" ", r"\t", inp) +``` + + + +Note that with all this there comes an extra feature: We can easily create variants with different steps, without having to change the code that will use the composed workflow. 
Even if the calling code calls steps in isolation! + + + +```python editable=true slideshow={"slide_type": ""} +CUSTOM_COLORS = {"blue": "#55aaff", "green": "#00ff00", "red": "#ff0000"} + +proc = StrProcess( + hexify_colors=HexifyColors( + color_scheme=CUSTOM_COLORS + ), + replace_tabs=spaces_to_tabs +) + +proc(""" +p { + background-color: blue; + color: red; +} +""") +``` + +```python editable=true slideshow={"slide_type": ""} +proc.hexify_colors("blue") +``` + + + +`NamedStepSequence`s still work with wrapper steps, parameters and chaining. They can also be nested. So for a complex workflow there would be innumerous possible variants. Therefore expect to often see them paired with factories. + + + +### Example in the Wild + +```python editable=true slideshow={"slide_type": ""} +gtx.backend.DEFAULT_PROG_TRANSFORMS?? +``` + +```python +gtx.program_processors.runners.gtfn.run_gtfn_gpu.executor.otf_workflow?? +``` + +```python +gtx.program_processors.runners.gtfn.GTFNBackendFactory?? +``` + +```python + +``` diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 3d3c7a27e1..3c0d19853e 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -34,23 +34,21 @@ from gt4py.next.program_processors import processor_interface as ppi -@dataclasses.dataclass(frozen=True) -class FopArgsInjector(workflow.Workflow): - args: tuple[Any, ...] = dataclasses.field(default_factory=tuple) - kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) - from_fieldop: Any = None - - def __call__(self, inp: ffront_stages.FoastOperatorDefinition) -> ffront_stages.FoastClosure: - return ffront_stages.FoastClosure( - foast_op_def=inp, - args=self.args, - kwargs=self.kwargs, - closure_vars={inp.foast_node.id: self.from_fieldop}, - ) +@workflow.make_step +def foast_to_foast_closure( + inp: workflow.InputWithArgs[ffront_stages.FoastOperatorDefinition], +) -> ffront_stages.FoastClosure: + from_fieldop = inp.kwargs.pop("from_fieldop") + return ffront_stages.FoastClosure( + foast_op_def=inp.data, + args=inp.args, + kwargs=inp.kwargs, + closure_vars={inp.data.foast_node.id: from_fieldop}, + ) @dataclasses.dataclass(frozen=True) -class FieldopTransformWorkflow(workflow.NamedStepSequence): +class FieldopTransformWorkflow(workflow.NamedStepSequenceWithArgs): """Modular workflow for transformations with access to intermediates.""" func_to_foast: workflow.SkippableStep[ @@ -59,9 +57,9 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): ] = dataclasses.field( default_factory=lambda: func_to_foast.OptionalFuncToFoastFactory(cached=True) ) - foast_inject_args: workflow.Workflow[ - ffront_stages.FoastOperatorDefinition, ffront_stages.FoastClosure - ] = dataclasses.field(default_factory=FopArgsInjector) + foast_to_foast_closure: workflow.Workflow[ + workflow.InputWithArgs[ffront_stages.FoastOperatorDefinition], ffront_stages.FoastClosure + ] = dataclasses.field(default=foast_to_foast_closure, metadata={"takes_args": True}) foast_to_past_closure: workflow.Workflow[ ffront_stages.FoastClosure, ffront_stages.PastClosure ] = dataclasses.field( @@ -90,7 +88,7 @@ class FieldopTransformWorkflow(workflow.NamedStepSequence): def step_order(self) -> list[str]: return [ "func_to_foast", - "foast_inject_args", + "foast_to_foast_closure", "foast_to_past_closure", "past_transform_args", "past_to_itir", @@ -101,22 +99,7 @@ def step_order(self) -> list[str]: @dataclasses.dataclass(frozen=True) -class ProgArgsInjector(workflow.Workflow): - args: tuple[Any, ...] 
= dataclasses.field(default_factory=tuple) - kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) - - def __call__(self, inp: ffront_stages.PastProgramDefinition) -> ffront_stages.PastClosure: - return ffront_stages.PastClosure( - past_node=inp.past_node, - closure_vars=inp.closure_vars, - grid_type=inp.grid_type, - args=self.args, - kwargs=self.kwargs, - ) - - -@dataclasses.dataclass(frozen=True) -class ProgramTransformWorkflow(workflow.NamedStepSequence): +class ProgramTransformWorkflow(workflow.NamedStepSequenceWithArgs): """Modular workflow for transformations with access to intermediates.""" func_to_past: workflow.SkippableStep[ @@ -128,11 +111,22 @@ class ProgramTransformWorkflow(workflow.NamedStepSequence): past_lint: workflow.Workflow[ ffront_stages.PastProgramDefinition, ffront_stages.PastProgramDefinition ] = dataclasses.field(default_factory=past_linters.LinterFactory) - past_inject_args: workflow.Workflow[ + past_to_past_closure: workflow.Workflow[ ffront_stages.PastProgramDefinition, ffront_stages.PastClosure - ] = dataclasses.field(default_factory=ProgArgsInjector) + ] = dataclasses.field( + default=lambda inp: ffront_stages.PastClosure( + past_node=inp.data.past_node, + closure_vars=inp.data.closure_vars, + grid_type=inp.data.grid_type, + args=inp.args, + kwargs=inp.kwargs, + ), + metadata={"takes_args": True}, + ) past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = ( - dataclasses.field(default=past_process_args.past_process_args) + dataclasses.field( + default=past_process_args.past_process_args, metadata={"takes_args": False} + ) ) past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( dataclasses.field(default_factory=past_to_itir.PastToItirFactory) @@ -152,28 +146,22 @@ class Backend(Generic[core_defs.DeviceTypeT]): def __call__( self, program: ffront_stages.ProgramDefinition | ffront_stages.FieldOperatorDefinition, - *args: tuple[Any], - **kwargs: dict[str, Any], + *args: Any, + **kwargs: Any, ) -> None: if isinstance( program, (ffront_stages.FieldOperatorDefinition, ffront_stages.FoastOperatorDefinition) ): offset_provider = kwargs.pop("offset_provider") from_fieldop = kwargs.pop("from_fieldop") - transforms_fop = self.transforms_fop.replace( - foast_inject_args=FopArgsInjector( - args=args, kwargs=kwargs, from_fieldop=from_fieldop - ) + program_call = self.transforms_fop( + workflow.InputWithArgs(program, args, kwargs | {"from_fieldop": from_fieldop}) ) - program_call = transforms_fop(program) program_call = dataclasses.replace( program_call, kwargs=program_call.kwargs | {"offset_provider": offset_provider} ) else: - transforms_prog = self.transforms_prog.replace( - past_inject_args=ProgArgsInjector(args=args, kwargs=kwargs) - ) - program_call = transforms_prog(program) + program_call = self.transforms_prog(workflow.InputWithArgs(program, args, kwargs)) self.executor(program_call.program, *program_call.args, **program_call.kwargs) @property diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index c83748dece..2ab46e4cf9 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -265,3 +265,26 @@ def __call__(self, inp: StartT) -> EndT: def skip_condition(self, inp: StartT) -> bool: raise NotImplementedError() + + +@dataclasses.dataclass +class InputWithArgs(Generic[StartT]): + data: StartT + args: tuple[Any] + kwargs: dict[str, Any] + + +@dataclasses.dataclass(frozen=True) +class 
NamedStepSequenceWithArgs(NamedStepSequence[InputWithArgs[StartT], EndT]): + def __call__(self, inp: InputWithArgs[StartT]) -> EndT: + args = inp.args + kwargs = inp.kwargs + step_result: Any = inp.data + fields = {f.name: f for f in dataclasses.fields(self)} + for step_name in self.step_order: + step = getattr(self, step_name) + if fields[step_name].metadata.get("takes_args", False): + step_result = step(InputWithArgs(step_result, args, kwargs)) + else: + step_result = step(step_result) + return step_result diff --git a/tests/next_tests/unit_tests/ffront_tests/test_stages.py b/tests/next_tests/unit_tests/ffront_tests/test_stages.py index 67ac96d653..29dcda9e1d 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_stages.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_stages.py @@ -12,9 +12,13 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses + import pytest + from gt4py import next as gtx from gt4py.next.ffront import stages +from gt4py.next.otf import workflow @pytest.fixture @@ -87,7 +91,7 @@ def copy_program(a: gtx.Field[[jdim], gtx.int32], out: gtx.Field[[jdim], gtx.int yield copy_program -def test_cache_key_field_op_def(fieldop, samecode_fieldop, different_fieldop): +def test_fingerprint_stage_field_op_def(fieldop, samecode_fieldop, different_fieldop): assert stages.fingerprint_stage(samecode_fieldop.definition_stage) != stages.fingerprint_stage( fieldop.definition_stage ) @@ -96,7 +100,7 @@ def test_cache_key_field_op_def(fieldop, samecode_fieldop, different_fieldop): ) -def test_cache_key_foast_op_def(fieldop, samecode_fieldop, different_fieldop): +def test_fingerprint_stage_foast_op_def(fieldop, samecode_fieldop, different_fieldop): foast = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast(fieldop.definition_stage) samecode = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast( samecode_fieldop.definition_stage @@ -109,42 +113,64 @@ def test_cache_key_foast_op_def(fieldop, samecode_fieldop, different_fieldop): assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast) -def test_cache_key_foast_closure(fieldop, samecode_fieldop, different_fieldop, idim, jdim): - foast_closure = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( - gtx.backend.FopArgsInjector( +@dataclasses.dataclass(frozen=True) +class ToFoastClosure(workflow.NamedStepSequenceWithArgs): + func_to_foast: workflow.Workflow = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast + foast_to_closure: workflow.Workflow = dataclasses.field( + default=gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.foast_to_foast_closure, + metadata={"takes_args": True}, + ) + + +def test_fingerprint_stage_foast_closure(fieldop, samecode_fieldop, different_fieldop, idim, jdim): + toolchain = ToFoastClosure() + foast_closure = toolchain( + workflow.InputWithArgs( + data=fieldop.definition_stage, args=(gtx.zeros({idim: 10}, gtx.int32),), - kwargs={"out": gtx.zeros({idim: 10}, gtx.int32)}, - from_fieldop=fieldop, + kwargs={ + "out": gtx.zeros({idim: 10}, gtx.int32), + "from_fieldop": fieldop, + }, ), - )(fieldop.definition_stage) - samecode = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( - gtx.backend.FopArgsInjector( + ) + samecode = toolchain( + workflow.InputWithArgs( + data=samecode_fieldop.definition_stage, args=(gtx.zeros({idim: 10}, gtx.int32),), - kwargs={"out": gtx.zeros({idim: 10}, gtx.int32)}, - from_fieldop=samecode_fieldop, + kwargs={ + "out": gtx.zeros({idim: 10}, gtx.int32), + "from_fieldop": samecode_fieldop, + }, ) - )(samecode_fieldop.definition_stage) - 
different = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( - gtx.backend.FopArgsInjector( + ) + different = toolchain( + workflow.InputWithArgs( + data=different_fieldop.definition_stage, args=(gtx.zeros({jdim: 10}, gtx.int32),), - kwargs={"out": gtx.zeros({jdim: 10}, gtx.int32)}, - from_fieldop=different_fieldop, + kwargs={ + "out": gtx.zeros({jdim: 10}, gtx.int32), + "from_fieldop": different_fieldop, + }, ) - )(different_fieldop.definition_stage) - different_args = gtx.backend.DEFAULT_FIELDOP_TRANSFORMS.func_to_foast.chain( - gtx.backend.FopArgsInjector( + ) + different_args = toolchain( + workflow.InputWithArgs( + data=fieldop.definition_stage, args=(gtx.zeros({idim: 11}, gtx.int32),), - kwargs={"out": gtx.zeros({idim: 11}, gtx.int32)}, - from_fieldop=fieldop, + kwargs={ + "out": gtx.zeros({idim: 11}, gtx.int32), + "from_fieldop": fieldop, + }, ) - )(fieldop.definition_stage) + ) assert stages.fingerprint_stage(samecode) != stages.fingerprint_stage(foast_closure) assert stages.fingerprint_stage(different) != stages.fingerprint_stage(foast_closure) assert stages.fingerprint_stage(different_args) != stages.fingerprint_stage(foast_closure) -def test_cache_key_program_def(program, samecode_program, different_program): +def test_fingerprint_stage_program_def(program, samecode_program, different_program): assert stages.fingerprint_stage(samecode_program.definition_stage) != stages.fingerprint_stage( program.definition_stage ) @@ -153,7 +179,7 @@ def test_cache_key_program_def(program, samecode_program, different_program): ) -def test_cache_key_past_def(program, samecode_program, different_program): +def test_fingerprint_stage_past_def(program, samecode_program, different_program): past = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(program.definition_stage) samecode = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(samecode_program.definition_stage) different = gtx.backend.DEFAULT_PROG_TRANSFORMS.func_to_past(different_program.definition_stage) diff --git a/tox.ini b/tox.ini index 6a7623c704..b78a432c6d 100644 --- a/tox.ini +++ b/tox.ini @@ -110,12 +110,12 @@ commands = description = Run notebooks commands_pre = jupytext docs/user/next/QuickstartGuide.md --to .ipynb - jupytext docs/user/next/Advanced_ToolchainWalkthrough.md --to .ipynb + jupytext docs/user/next/advanced/*.md --to .ipynb commands = python -m pytest --nbmake docs/user/next/workshop/slides -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake docs/user/next/workshop/exercises -k 'solutions' -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake docs/user/next/QuickstartGuide.ipynb -v -n {env:NUM_PROCESSES:1} - python -m pytest --nbmake docs/user/next/Advanced_ToolchainWalkthrough.ipynb -v -n {env:NUM_PROCESSES:1} + python -m pytest --nbmake docs/user/next/advanced -v -n {env:NUM_PROCESSES:1} python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1} # -- Other artefacts --