-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor[next]: cleaner toolchain (#1537)
## Changed: Toolchain does not own input arguments to DSL programs anymore, instead the input datastructure owns them and the toolchain can dispatch them to a subset of the steps. ## Toolchain migration: Old: ```python FieldopTransformWorkflow().replace(foast_inject_args=FopArgsInjector(*args, **kwargs))(fieldoperator_definition) ProgramTransformWorkflow().replace(program_inject_args=ProgArgsInjector(*args, **kwargs))(program_definition) ``` New: ```python FieldopTransformWorkflow()(InputWithArgs(fieldoperator_definition, args, kwargs)) ProgramTransformWorkflow()(InputWithArgs(program_definition, args, kwargs)) ``` ## Added: - `otf.workflow`: - new workflow type: `NamedStepSequenceWithArgs` takes an `InputWithArgs` and dispatches `.args` and `.kwargs` to steps that set `take_args = True` in the field metadata - new data type `InputWithArgs` wraps a workflow stage and call args - `backend`: Replace `*ArgsInjector` using the new `NamedStepSequenceWithArgs` infrastructure
- Loading branch information
Showing
7 changed files
with
967 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
```python | ||
import dataclasses | ||
import typing | ||
|
||
from gt4py import next as gtx | ||
from gt4py.next.otf import workflow | ||
from gt4py import eve | ||
``` | ||
|
||
<link href="https://fonts.googleapis.com/icon?family=Material+Icons" rel="stylesheet"><script src="https://spcl.github.io/dace/webclient2/dist/sdfv.js"></script> | ||
<link href="https://spcl.github.io/dace/webclient2/sdfv.css" rel="stylesheet"> | ||
|
||
## Replace Steps | ||
|
||
```python | ||
cached_lowering_toolchain = gtx.backend.DEFAULT_PROG_TRANSFORMS.replace( | ||
past_to_itir=workflow.CachedStep( | ||
step=gtx.ffront.past_to_itir.PastToItirFactory(), | ||
hash_function=eve.utils.content_hash | ||
) | ||
) | ||
``` | ||
|
||
## Skip Steps / Change Order | ||
|
||
```python | ||
gtx.backend.DEFAULT_PROG_TRANSFORMS.step_order | ||
``` | ||
|
||
['func_to_past', | ||
'past_lint', | ||
'past_inject_args', | ||
'past_transform_args', | ||
'past_to_itir'] | ||
|
||
```python | ||
@dataclasses.dataclass(frozen=True) | ||
class SkipLinting(gtx.backend.ProgramTransformWorkflow): | ||
@property | ||
def step_order(self): | ||
return [ | ||
"func_to_past", | ||
# not running "past_lint" | ||
"past_inject_args", | ||
"past_transform_args", | ||
"past_to_itir", | ||
] | ||
|
||
same_steps = dataclasses.asdict(gtx.backend.DEFAULT_PROG_TRANSFORMS) | ||
skip_linting_transforms = SkipLinting( | ||
**same_steps | ||
) | ||
``` | ||
|
||
## Alternative Factory | ||
|
||
```python | ||
class MyCodeGen: | ||
... | ||
|
||
class Cpp2BindingsGen: | ||
... | ||
|
||
class PureCpp2WorkflowFactory(gtx.program_processors.runners.gtfn.GTFNCompileWorkflowFactory): | ||
translation: workflow.Workflow[ | ||
gtx.otf.stages.ProgramCall, gtx.otf.stages.ProgramSource] = MyCodeGen() | ||
bindings: workflow.Workflow[ | ||
gtx.otf.stages.ProgramSource, gtx.otf.stages.CompilableSource] = Cpp2BindingsGen() | ||
|
||
PureCpp2WorkflowFactory(cmake_build_type=gtx.config.CMAKE_BUILD_TYPE.DEBUG) | ||
``` | ||
|
||
## Invent new Workflow Types | ||
|
||
````mermaid | ||
graph LR | ||
IN_T --> i{{split}} --> A_T --> a{{track_a}} --> B_T --> o{{combine}} --> OUT_T | ||
i --> X_T --> x{{track_x}} --> Y_T --> o | ||
```python | ||
IN_T = typing.TypeVar("IN_T") | ||
A_T = typing.TypeVar("A_T") | ||
B_T = typing.TypeVar("B_T") | ||
X_T = typing.TypeVar("X_T") | ||
Y_T = typing.TypeVar("Y_T") | ||
OUT_T = typing.TypeVar("OUT_T") | ||
@dataclasses.dataclass(frozen=True) | ||
class FullyModularDiamond( | ||
workflow.ChainableWorkflowMixin[IN_T, OUT_T], | ||
workflow.ReplaceEnabledWorkflowMixin[IN_T, OUT_T], | ||
typing.Protocol[IN_T, OUT_T, A_T, B_T, X_T, Y_T] | ||
): | ||
split: workflow.Workflow[IN_T, tuple[A_T, X_T]] | ||
track_a: workflow.Workflow[A_T, B_T] | ||
track_x: workflow.Workflow[X_T, Y_T] | ||
combine: workflow.Workflow[tuple[B_T, Y_T], OUT_T] | ||
def __call__(self, inp: IN_T) -> OUT_T: | ||
a, x = self.split(inp) | ||
b = self.track_a(a) | ||
y = self.track_x(x) | ||
return self.combine((b, y)) | ||
@dataclasses.dataclass(frozen=True) | ||
class PartiallyModularDiamond( | ||
workflow.ChainableWorkflowMixin[IN_T, OUT_T], | ||
workflow.ReplaceEnabledWorkflowMixin[IN_T, OUT_T], | ||
typing.Protocol[IN_T, OUT_T, A_T, B_T, X_T, Y_T] | ||
): | ||
track_a: workflow.Workflow[A_T, B_T] | ||
track_x: workflow.Workflow[X_T, Y_T] | ||
def split(inp: IN_T) -> tuple[A_T, X_T]: | ||
... | ||
def combine(b: B_T, y: Y_T) -> OUT_T: | ||
... | ||
def __call__(inp: IN_T) -> OUT_T: | ||
a, x = self.split(inp) | ||
return self.combine( | ||
b=self.track_a(a), | ||
y=self.track_x(x) | ||
) | ||
```` |
Oops, something went wrong.