Skip to content

Commit

Permalink
Merge origin/main
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Dec 10, 2024
2 parents fc46edf + ae62965 commit 4e12195
Show file tree
Hide file tree
Showing 135 changed files with 7,347 additions and 7,751 deletions.
4 changes: 2 additions & 2 deletions .github/pull_request_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Delete this comment and add a proper description of the changes contained in thi
- test: Adding missing tests or correcting existing tests
<scope>: cartesian | eve | next | storage
# ONLY if changes are limited to a specific subsytem
# ONLY if changes are limited to a specific subsystem
- PR Description:
Expand All @@ -27,7 +27,7 @@ Delete this comment and add a proper description of the changes contained in thi
## Requirements

- [ ] All fixes and/or new features come with corresponding tests.
- [ ] Important design decisions have been documented in the approriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder.
- [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder.

If this PR contains code authored by new contributors please make sure:

Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ repos:
- devtools==0.12.2
- diskcache==5.6.3
- factory-boy==3.3.1
- filelock==3.16.1
- frozendict==2.4.6
- gridtools-cpp==2.3.8
- importlib-resources==6.4.5
Expand Down
8 changes: 4 additions & 4 deletions constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ executing==2.1.0 # via devtools, stack-data
factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy
faker==33.0.0 # via factory-boy
fastjsonschema==2.20.0 # via nbformat
filelock==3.16.1 # via tox, virtualenv
filelock==3.16.1 # via gt4py (pyproject.toml), tox, virtualenv
fonttools==4.55.0 # via matplotlib
fparser==0.1.4 # via dace
frozendict==2.4.6 # via gt4py (pyproject.toml)
Expand Down Expand Up @@ -113,8 +113,8 @@ psutil==6.1.0 # via -r requirements-dev.in, ipykernel, pytest-xdist
ptyprocess==0.7.0 # via pexpect
pure-eval==0.2.3 # via stack-data
pybind11==2.13.6 # via gt4py (pyproject.toml)
pydantic==2.9.2 # via bump-my-version, pydantic-settings
pydantic-core==2.23.4 # via pydantic
pydantic==2.10.0 # via bump-my-version, pydantic-settings
pydantic-core==2.27.0 # via pydantic
pydantic-settings==2.6.1 # via bump-my-version
pydot==3.0.2 # via tach
pygments==2.18.0 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx
Expand Down Expand Up @@ -159,7 +159,7 @@ stack-data==0.6.3 # via ipython
stdlib-list==0.10.0 # via tach
sympy==1.13.3 # via dace
tabulate==0.9.0 # via gt4py (pyproject.toml)
tach==0.14.3 # via -r requirements-dev.in
tach==0.14.4 # via -r requirements-dev.in
tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox
tomli-w==1.0.0 # via tach
tomlkit==0.13.2 # via bump-my-version
Expand Down
2 changes: 1 addition & 1 deletion docs/user/next/advanced/HackTheToolchain.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from gt4py import eve

```python
cached_lowering_toolchain = gtx.backend.DEFAULT_TRANSFORMS.replace(
past_to_itir=gtx.ffront.past_to_itir.past_to_itir_factory(cached=False)
past_to_itir=gtx.ffront.past_to_itir.past_to_gtir_factory(cached=False)
)
```

Expand Down
1 change: 1 addition & 0 deletions min-extra-requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ deepdiff==5.6.0
devtools==0.6
diskcache==5.6.3
factory-boy==3.3.0
filelock==3.0.0
frozendict==2.3
gridtools-cpp==2.3.8
hypothesis==6.0.0
Expand Down
1 change: 1 addition & 0 deletions min-requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ deepdiff==5.6.0
devtools==0.6
diskcache==5.6.3
factory-boy==3.3.0
filelock==3.0.0
frozendict==2.3
gridtools-cpp==2.3.8
hypothesis==6.0.0
Expand Down
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
'devtools>=0.6',
'diskcache>=5.6.3',
'factory-boy>=3.3.0',
'filelock>=3.0.0',
'frozendict>=2.3',
'gridtools-cpp>=2.3.8,==2.*',
"importlib-resources>=5.0;python_version<'3.9'",
Expand Down Expand Up @@ -239,17 +240,14 @@ markers = [
'requires_atlas: tests that require `atlas4py` bindings package',
'requires_dace: tests that require `dace` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'starts_from_gtir_program: tests that require backend to start lowering from GTIR program',
'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_constant_fields: tests that require backend support for constant fields',
'uses_dynamic_offsets: tests that require backend support for dynamic offsets',
'uses_floordiv: tests that require backend support for floor division',
'uses_if_stmts: tests that require backend support for if-statements',
'uses_index_fields: tests that require backend support for index fields',
'uses_lift_expressions: tests that require backend support for lift expressions',
'uses_negative_modulo: tests that require backend support for modulo on negative numbers',
'uses_origin: tests that require backend support for domain origin',
'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions',
'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields',
'uses_scan: tests that uses scan',
'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
Expand Down
8 changes: 4 additions & 4 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ executing==2.1.0 # via -c constraints.txt, devtools, stack-data
factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy
faker==33.0.0 # via -c constraints.txt, factory-boy
fastjsonschema==2.20.0 # via -c constraints.txt, nbformat
filelock==3.16.1 # via -c constraints.txt, tox, virtualenv
filelock==3.16.1 # via -c constraints.txt, gt4py (pyproject.toml), tox, virtualenv
fonttools==4.55.0 # via -c constraints.txt, matplotlib
fparser==0.1.4 # via -c constraints.txt, dace
frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml)
Expand Down Expand Up @@ -113,8 +113,8 @@ psutil==6.1.0 # via -c constraints.txt, -r requirements-dev.in, ipyk
ptyprocess==0.7.0 # via -c constraints.txt, pexpect
pure-eval==0.2.3 # via -c constraints.txt, stack-data
pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml)
pydantic==2.9.2 # via -c constraints.txt, bump-my-version, pydantic-settings
pydantic-core==2.23.4 # via -c constraints.txt, pydantic
pydantic==2.10.0 # via -c constraints.txt, bump-my-version, pydantic-settings
pydantic-core==2.27.0 # via -c constraints.txt, pydantic
pydantic-settings==2.6.1 # via -c constraints.txt, bump-my-version
pydot==3.0.2 # via -c constraints.txt, tach
pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx
Expand Down Expand Up @@ -158,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython
stdlib-list==0.10.0 # via -c constraints.txt, tach
sympy==1.13.3 # via -c constraints.txt, dace
tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml)
tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in
tach==0.14.4 # via -c constraints.txt, -r requirements-dev.in
tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox
tomli-w==1.0.0 # via -c constraints.txt, tach
tomlkit==0.13.2 # via -c constraints.txt, bump-my-version
Expand Down
16 changes: 8 additions & 8 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,17 @@


def _specialize_transient_strides(sdfg: dace.SDFG, layout_map):
repldict = replace_strides(
replacement_dictionary = replace_strides(
[array for array in sdfg.arrays.values() if array.transient], layout_map
)
sdfg.replace_dict(repldict)
sdfg.replace_dict(replacement_dictionary)
for state in sdfg.nodes():
for node in state.nodes():
if isinstance(node, dace.nodes.NestedSDFG):
for k, v in repldict.items():
for k, v in replacement_dictionary.items():
if k in node.symbol_mapping:
node.symbol_mapping[k] = v
for k in repldict.keys():
for k in replacement_dictionary.keys():
if k in sdfg.symbols:
sdfg.remove_symbol(k)

Expand Down Expand Up @@ -143,7 +143,7 @@ def _to_device(sdfg: dace.SDFG, device: str) -> None:
node.device = dace.DeviceType.GPU


def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map):
def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map):
args_data = make_args_data_from_gtir(gtir_pipeline)

# stencils without effect
Expand All @@ -164,7 +164,7 @@ def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map)
return sdfg


def _post_expand_trafos(sdfg: dace.SDFG):
def _post_expand_transformations(sdfg: dace.SDFG):
# DaCe "standard" clean-up transformations
sdfg.simplify(validate=False)

Expand Down Expand Up @@ -355,7 +355,7 @@ def _unexpanded_sdfg(self):
sdfg = OirSDFGBuilder().visit(oir_node)

_to_device(sdfg, self.builder.backend.storage_info["device"])
_pre_expand_trafos(
_pre_expand_transformations(
self.builder.gtir_pipeline,
sdfg,
self.builder.backend.storage_info["layout_map"],
Expand All @@ -371,7 +371,7 @@ def unexpanded_sdfg(self):
def _expanded_sdfg(self):
sdfg = self._unexpanded_sdfg()
sdfg.expand_library_nodes()
_post_expand_trafos(sdfg)
_post_expand_transformations(sdfg)
return sdfg

def expanded_sdfg(self):
Expand Down
7 changes: 7 additions & 0 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,6 +1460,13 @@ def visit_Assign(self, node: ast.Assign) -> list:
loc=nodes.Location.from_ast_node(t),
)

if self.backend_name in ["gt:gpu"]:
raise GTScriptSyntaxError(
message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} as an unsolved bug remains."
"Please refer to https://github.com/GridTools/gt4py/issues/1754.",
loc=nodes.Location.from_ast_node(t),
)

if not self._is_known(name):
if name in self.temp_decls:
field_decl = self.temp_decls[name]
Expand Down
50 changes: 25 additions & 25 deletions src/gt4py/cartesian/gtc/dace/expansion_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def get_expansion_order_index(expansion_order, axis):
for idx, item in enumerate(expansion_order):
if isinstance(item, Iteration) and item.axis == axis:
return idx
elif isinstance(item, Map):

if isinstance(item, Map):
for it in item.iterations:
if it.kind == "contiguous" and it.axis == axis:
return idx
Expand Down Expand Up @@ -136,7 +137,9 @@ def _choose_loop_or_map(node, eo):
return eo


def _order_as_spec(computation_node, expansion_order):
def _order_as_spec(
computation_node: StencilComputation, expansion_order: Union[List[str], List[ExpansionItem]]
) -> List[ExpansionItem]:
expansion_order = list(_choose_loop_or_map(computation_node, eo) for eo in expansion_order)
expansion_specification = []
for item in expansion_order:
Expand Down Expand Up @@ -170,7 +173,7 @@ def _order_as_spec(computation_node, expansion_order):
return expansion_specification


def _populate_strides(node, expansion_specification):
def _populate_strides(node: StencilComputation, expansion_specification: List[ExpansionItem]):
"""Fill in `stride` attribute of `Iteration` and `Loop` dataclasses.
For loops, stride is set to either -1 or 1, based on iteration order.
Expand All @@ -185,10 +188,7 @@ def _populate_strides(node, expansion_specification):
for it in iterations:
if isinstance(it, Loop):
if it.stride is None:
if node.oir_node.loop_order == common.LoopOrder.BACKWARD:
it.stride = -1
else:
it.stride = 1
it.stride = -1 if node.oir_node.loop_order == common.LoopOrder.BACKWARD else 1
else:
if it.stride is None:
if it.kind == "tiling":
Expand All @@ -204,7 +204,7 @@ def _populate_strides(node, expansion_specification):
it.stride = 1


def _populate_storages(self, expansion_specification):
def _populate_storages(expansion_specification: List[ExpansionItem]):
assert all(isinstance(es, ExpansionItem) for es in expansion_specification)
innermost_axes = set(dcir.Axis.dims_3d())
tiled_axes = set()
Expand All @@ -222,7 +222,7 @@ def _populate_storages(self, expansion_specification):
tiled_axes.remove(it.axis)


def _populate_cpu_schedules(self, expansion_specification):
def _populate_cpu_schedules(expansion_specification: List[ExpansionItem]):
is_outermost = True
for es in expansion_specification:
if isinstance(es, Map):
Expand All @@ -234,7 +234,7 @@ def _populate_cpu_schedules(self, expansion_specification):
es.schedule = dace.ScheduleType.Default


def _populate_gpu_schedules(self, expansion_specification):
def _populate_gpu_schedules(expansion_specification: List[ExpansionItem]):
# On GPU if any dimension is tiled and has a contiguous map in the same axis further in
# pick those two maps as Device/ThreadBlock maps. If not, Make just device map with
# default blocksizes
Expand Down Expand Up @@ -267,16 +267,16 @@ def _populate_gpu_schedules(self, expansion_specification):
es.schedule = dace.ScheduleType.Default


def _populate_schedules(self, expansion_specification):
def _populate_schedules(node: StencilComputation, expansion_specification: List[ExpansionItem]):
assert all(isinstance(es, ExpansionItem) for es in expansion_specification)
assert hasattr(self, "_device")
if self.device == dace.DeviceType.GPU:
_populate_gpu_schedules(self, expansion_specification)
assert hasattr(node, "_device")
if node.device == dace.DeviceType.GPU:
_populate_gpu_schedules(expansion_specification)
else:
_populate_cpu_schedules(self, expansion_specification)
_populate_cpu_schedules(expansion_specification)


def _collapse_maps_gpu(self, expansion_specification):
def _collapse_maps_gpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]:
def _union_map_items(last_item, next_item):
if last_item.schedule == next_item.schedule:
return (
Expand Down Expand Up @@ -307,7 +307,7 @@ def _union_map_items(last_item, next_item):
),
)

res_items = []
res_items: List[ExpansionItem] = []
for item in expansion_specification:
if isinstance(item, Map):
if not res_items or not isinstance(res_items[-1], Map):
Expand All @@ -324,8 +324,8 @@ def _union_map_items(last_item, next_item):
return res_items


def _collapse_maps_cpu(self, expansion_specification):
res_items = []
def _collapse_maps_cpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]:
res_items: List[ExpansionItem] = []
for item in expansion_specification:
if isinstance(item, Map):
if (
Expand Down Expand Up @@ -360,12 +360,12 @@ def _collapse_maps_cpu(self, expansion_specification):
return res_items


def _collapse_maps(self, expansion_specification):
assert hasattr(self, "_device")
if self.device == dace.DeviceType.GPU:
res_items = _collapse_maps_gpu(self, expansion_specification)
def _collapse_maps(node: StencilComputation, expansion_specification: List[ExpansionItem]):
assert hasattr(node, "_device")
if node.device == dace.DeviceType.GPU:
res_items = _collapse_maps_gpu(expansion_specification)
else:
res_items = _collapse_maps_cpu(self, expansion_specification)
res_items = _collapse_maps_cpu(expansion_specification)
expansion_specification.clear()
expansion_specification.extend(res_items)

Expand All @@ -387,7 +387,7 @@ def make_expansion_order(
_populate_strides(node, expansion_specification)
_populate_schedules(node, expansion_specification)
_collapse_maps(node, expansion_specification)
_populate_storages(node, expansion_specification)
_populate_storages(expansion_specification)
return expansion_specification


Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/cartesian/gtc/dace/oir_to_dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def visit_VerticalLoop(
state.add_edge(
access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset)
)

for field in access_collection.write_fields():
access_node = state.add_access(field, debuginfo=dace.DebugInfo(0))
library_node.add_out_connector("__out_" + field)
Expand All @@ -131,8 +132,6 @@ def visit_VerticalLoop(
library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset)
)

return

def visit_Stencil(self, node: oir.Stencil, **kwargs):
ctx = OirSDFGBuilder.SDFGContext(stencil=node)
for param in node.params:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def array_dimensions(array: dace.data.Array):
return dims


def replace_strides(arrays, get_layout_map):
def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str, str]:
symbol_mapping = {}
for array in arrays:
dims = array_dimensions(array)
Expand Down
1 change: 0 additions & 1 deletion src/gt4py/eve/.gitignore

This file was deleted.

Loading

0 comments on commit 4e12195

Please sign in to comment.