Skip to content

Commit

Permalink
feat[cartesian]: K offset write (#1452)
Browse files Browse the repository at this point in the history
## Description

Looking at enabling offset write in K to help with physics.
Write is allowed on `FORWARD` and `BACKWARD`, disallowed for `PARALLEL`.

TODO:
- [x] Make tests for `conditional`
- [ ] Explore auto `extend` calculation
- [x] Fix `dace:X` backends

Link to #131 
Discussion happened on [GridTools concept
](GridTools/concepts#34)

## Requirements

- [x] All fixes and/or new features come with corresponding tests.
- [x] Important design decisions have been documented in the approriate
ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md)
folder.

---------

Co-authored-by: Hannes Vogt <[email protected]>
Co-authored-by: Florian Deconinck <[email protected]>
  • Loading branch information
3 people authored Sep 23, 2024
1 parent b6d4427 commit e29016d
Show file tree
Hide file tree
Showing 11 changed files with 354 additions and 73 deletions.
1 change: 1 addition & 0 deletions src/gt4py/cartesian/frontend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def generate(
externals: Dict[str, Any],
dtypes: Dict[Type, Type],
options: BuildOptions,
backend_name: str,
) -> gtir.Stencil:
"""
Generate a StencilDefinition from a stencil Python function.
Expand Down
28 changes: 23 additions & 5 deletions src/gt4py/cartesian/frontend/gtscript_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,7 @@ def __init__(
fields: dict,
parameters: dict,
local_symbols: dict,
backend_name: str,
*,
domain: nodes.Domain,
temp_decls: Optional[Dict[str, nodes.FieldDecl]] = None,
Expand All @@ -721,6 +722,7 @@ def __init__(
isinstance(value, (type, np.dtype)) for value in local_symbols.values()
)

self.backend_name = backend_name
self.fields = fields
self.parameters = parameters
self.local_symbols = local_symbols
Expand Down Expand Up @@ -1432,11 +1434,26 @@ def visit_Assign(self, node: ast.Assign) -> list:
for t in node.targets[0].elts if isinstance(node.targets[0], ast.Tuple) else node.targets:
name, spatial_offset, data_index = self._parse_assign_target(t)
if spatial_offset:
if any(offset != 0 for offset in spatial_offset):
if spatial_offset[0] != 0 or spatial_offset[1] != 0:
raise GTScriptSyntaxError(
message="Assignment to non-zero offsets is not supported.",
message="Assignment to non-zero offsets is not supported in IJ.",
loc=nodes.Location.from_ast_node(t),
)
# Case of K-offset
if len(spatial_offset) == 3 and spatial_offset[2] != 0:
if self.iteration_order == nodes.IterationOrder.PARALLEL:
raise GTScriptSyntaxError(
message="Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.",
loc=nodes.Location.from_ast_node(t),
)
if self.backend_name in ["gt:gpu", "dace:gpu"]:
import cupy as cp

if cp.cuda.runtime.runtimeGetVersion() < 12000:
raise GTScriptSyntaxError(
message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} for CUDA<12. Please update CUDA.",
loc=nodes.Location.from_ast_node(t),
)

if not self._is_known(name):
if name in self.temp_decls:
Expand Down Expand Up @@ -1997,7 +2014,7 @@ def extract_arg_descriptors(self):

return api_signature, fields_decls, parameter_decls

def run(self):
def run(self, backend_name: str):
assert (
isinstance(self.ast_root, ast.Module)
and "body" in self.ast_root._fields
Expand Down Expand Up @@ -2055,6 +2072,7 @@ def run(self):
fields=fields_decls,
parameters=parameter_decls,
local_symbols={}, # Not used
backend_name=backend_name,
domain=domain,
temp_decls=temp_decls,
dtypes=self.dtypes,
Expand Down Expand Up @@ -2110,14 +2128,14 @@ def prepare_stencil_definition(cls, definition, externals):
return GTScriptParser.annotate_definition(definition, externals)

@classmethod
def generate(cls, definition, externals, dtypes, options):
def generate(cls, definition, externals, dtypes, options, backend_name):
if options.build_info is not None:
start_time = time.perf_counter()

if not hasattr(definition, "_gtscript_"):
cls.prepare_stencil_definition(definition, externals)
translator = GTScriptParser(definition, externals=externals, dtypes=dtypes, options=options)
definition_ir = translator.run()
definition_ir = translator.run(backend_name)

# GTIR only supports LatLonGrids
if definition_ir.domain != nodes.Domain.LatLonGrid():
Expand Down
31 changes: 28 additions & 3 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,17 +321,29 @@ def visit_FieldAccess(
is_target: bool,
targets: Set[eve.SymbolRef],
var_offset_fields: Set[eve.SymbolRef],
K_write_with_offset: Set[eve.SymbolRef],
**kwargs: Any,
) -> Union[dcir.IndexAccess, dcir.ScalarAccess]:
"""Generate the relevant accessor to match the memlet that was previously setup.
When a Field is written in K, we force the usage of the OUT memlet throughout the stencil
to make sure all side effects are being properly resolved. Frontend checks ensure that no
parallel code issues sips here.
"""

res: Union[dcir.IndexAccess, dcir.ScalarAccess]
if node.name in var_offset_fields:
if node.name in var_offset_fields.union(K_write_with_offset):
# If write in K, we consider the variable to always be a target
is_target = is_target or node.name in targets or node.name in K_write_with_offset
name = get_tasklet_symbol(node.name, node.offset, is_target=is_target)
res = dcir.IndexAccess(
name=node.name + "__",
name=name,
offset=self.visit(
node.offset,
is_target=False,
is_target=is_target,
targets=targets,
var_offset_fields=var_offset_fields,
K_write_with_offset=K_write_with_offset,
**kwargs,
),
data_index=node.data_index,
Expand Down Expand Up @@ -799,11 +811,23 @@ def visit_VerticalLoop(
)
)

# Variable offsets
var_offset_fields = {
acc.name
for acc in node.walk_values().if_isinstance(oir.FieldAccess)
if isinstance(acc.offset, oir.VariableKOffset)
}

# We book keep - all write offset to K
K_write_with_offset = set()
for assign_node in node.walk_values().if_isinstance(oir.AssignStmt):
if isinstance(assign_node.left, oir.FieldAccess):
if (
isinstance(assign_node.left.offset, common.CartesianOffset)
and assign_node.left.offset.k != 0
):
K_write_with_offset.add(assign_node.left.name)

sections_idx = next(
idx
for idx, item in enumerate(global_ctx.library_node.expansion_specification)
Expand All @@ -821,6 +845,7 @@ def visit_VerticalLoop(
iteration_ctx=iteration_ctx,
symbol_collector=symbol_collector,
var_offset_fields=var_offset_fields,
K_write_with_offset=K_write_with_offset,
**kwargs,
)
)
Expand Down
7 changes: 6 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ def visit_IndexAccess(
# if this node is not a target, it will still use the symbol of the write memlet if the
# field was previously written in the same memlet.
memlets = kwargs["read_memlets"] + kwargs["write_memlets"]
memlet = next(mem for mem in memlets if mem.connector == node.name)
try:
memlet = next(mem for mem in memlets if mem.connector == node.name)
except StopIteration:
raise ValueError(
"Memlet connector and tasklet variable mismatch, DaCe IR error."
) from None

index_strs = []
if node.offset is not None:
Expand Down
10 changes: 7 additions & 3 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def get_tasklet_symbol(
name: eve.SymbolRef, offset: Union[CartesianOffset, VariableKOffset], is_target: bool
):
if is_target:
return f"__{name}"
return f"gtOUT__{name}"

acc_name = name + "__"
acc_name = f"gtIN__{name}"
if offset is not None:
offset_strs = []
for axis in dcir.Axis.dims_3d():
Expand Down Expand Up @@ -230,9 +230,12 @@ def _make_access_info(
region,
he_grid,
grid_subset,
is_write,
) -> dcir.FieldAccessInfo:
# Check we have expression offsets in K
# OR write offsets in K
offset = [offset_node.to_dict()[k] for k in "ijk"]
if isinstance(offset_node, oir.VariableKOffset):
if isinstance(offset_node, oir.VariableKOffset) or (offset[2] != 0 and is_write):
variable_offset_axes = [dcir.Axis.K]
else:
variable_offset_axes = []
Expand Down Expand Up @@ -291,6 +294,7 @@ def visit_FieldAccess(
region=region,
he_grid=he_grid,
grid_subset=grid_subset,
is_write=is_write,
)
ctx.access_infos[node.name] = access_info.union(
ctx.access_infos.get(node.name, access_info)
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/cartesian/stencil_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ def gtir_pipeline(self) -> GtirPipeline:
return self._build_data.get("gtir_pipeline") or self._build_data.setdefault(
"gtir_pipeline",
GtirPipeline(
self.frontend.generate(self.definition, self.externals, self.dtypes, self.options),
self.frontend.generate(
self.definition, self.externals, self.dtypes, self.options, self.backend.name
),
self.stencil_id,
),
)
Expand Down
12 changes: 12 additions & 0 deletions tests/cartesian_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import datetime

import numpy as np
import pytest

from gt4py import cartesian as gt4pyc
Expand Down Expand Up @@ -54,3 +55,14 @@ def _get_backends_with_storage_info(storage_info_kind: str):
@pytest.fixture()
def id_version():
return gt_utils.shashed_id(str(datetime.datetime.now()))


def get_array_library(backend: str):
"""Return device ready array maker library"""
backend_cls = gt4pyc.backend.from_name(backend)
assert backend_cls is not None
if backend_cls.storage_info["device"] == "gpu":
assert cp is not None
return cp
else:
return np
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import pytest

from gt4py import cartesian as gt4pyc, storage as gt_storage
from gt4py.cartesian import gtscript

from cartesian_tests.definitions import ALL_BACKENDS, PERFORMANCE_BACKENDS
from cartesian_tests.definitions import ALL_BACKENDS, PERFORMANCE_BACKENDS, get_array_library
from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import copy_stencil


Expand All @@ -22,20 +21,10 @@
cp = None


def _get_array_library(backend: str):
backend_cls = gt4pyc.backend.from_name(backend)
assert backend_cls is not None
if backend_cls.storage_info["device"] == "gpu":
assert cp is not None
return cp
else:
return np


@pytest.mark.parametrize("backend", ALL_BACKENDS)
@pytest.mark.parametrize("order", ["C", "F"])
def test_numpy_allocators(backend, order):
xp = _get_array_library(backend)
xp = get_array_library(backend)
shape = (20, 10, 5)
inp = xp.array(xp.random.randn(*shape), order=order, dtype=xp.float_)
outp = xp.zeros(shape=shape, order=order, dtype=xp.float_)
Expand All @@ -48,7 +37,7 @@ def test_numpy_allocators(backend, order):

@pytest.mark.parametrize("backend", PERFORMANCE_BACKENDS)
def test_bad_layout_warns(backend):
xp = _get_array_library(backend)
xp = get_array_library(backend)
backend_cls = gt4pyc.backend.from_name(backend)
assert backend_cls is not None

Expand Down
Loading

0 comments on commit e29016d

Please sign in to comment.