Skip to content

Commit

Permalink
feat[next][dace]: Lowering to SDFG of index builtin (#1751)
Browse files Browse the repository at this point in the history
Implements the lowering to SDFG of the GTIR index builtin.
  • Loading branch information
edopao authored Nov 27, 2024
1 parent f6c219b commit f6c0498
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 16 deletions.
14 changes: 14 additions & 0 deletions src/gt4py/next/iterator/ir_utils/ir_makers.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,20 @@ def _impl(it: itir.Expr) -> itir.FunCall:
return _impl


def index(dim: common.Dimension) -> itir.FunCall:
"""
Create a call to the `index` builtin, shorthand for `call("index")(axis)`,
after converting the given dimension to `itir.AxisLiteral`.
Args:
dim: the dimension corresponding to the index axis.
Returns:
A function that constructs a Field of indices in the given dimension.
"""
return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind))


def map_(op):
"""Create a `map_` call."""
return call(call("map_")(op))
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from gt4py.next import common as gtx_common, utils as gtx_utils
from gt4py.next.ffront import fbuiltins as gtx_fbuiltins
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils
from gt4py.next.iterator.type_system import type_specifications as itir_ts
from gt4py.next.program_processors.runners.dace_common import utility as dace_utils
from gt4py.next.program_processors.runners.dace_fieldview import (
Expand Down Expand Up @@ -277,20 +277,31 @@ def extract_domain(node: gtir.Node) -> FieldopDomain:
the corresponding lower and upper bounds. The returned lower bound is inclusive,
the upper bound is exclusive: [lower_bound, upper_bound[
"""
assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain"))

domain = []
for named_range in node.args:
assert cpm.is_call_to(named_range, "named_range")
assert len(named_range.args) == 3
axis = named_range.args[0]
assert isinstance(axis, gtir.AxisLiteral)
lower_bound, upper_bound = (
dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg))
for arg in named_range.args[1:3]
)
dim = gtx_common.Dimension(axis.value, axis.kind)
domain.append((dim, lower_bound, upper_bound))

def parse_range_boundary(expr: gtir.Expr) -> str:
return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(expr))

if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")):
for named_range in node.args:
assert cpm.is_call_to(named_range, "named_range")
assert len(named_range.args) == 3
axis = named_range.args[0]
assert isinstance(axis, gtir.AxisLiteral)
lower_bound, upper_bound = (parse_range_boundary(arg) for arg in named_range.args[1:3])
dim = gtx_common.Dimension(axis.value, axis.kind)
domain.append((dim, lower_bound, upper_bound))

elif isinstance(node, domain_utils.SymbolicDomain):
assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"}
for dim, drange in node.ranges.items():
domain.append(
(dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop))
)

else:
raise ValueError(f"Invalid domain {node}.")

return domain

Expand Down Expand Up @@ -545,6 +556,51 @@ def construct_output(inner_data: FieldopData) -> FieldopData:
return result_temps


def translate_index(
node: gtir.Node,
sdfg: dace.SDFG,
state: dace.SDFGState,
sdfg_builder: gtir_sdfg.SDFGBuilder,
) -> FieldopResult:
"""
Lowers the `index` builtin function to a mapped tasklet that writes the dimension
index values to a transient array. The extent of the index range is taken from
the domain information that should be present in the node annex.
"""
assert "domain" in node.annex
domain = extract_domain(node.annex.domain)
assert len(domain) == 1
dim, lower_bound, upper_bound = domain[0]
dim_index = dace_gtir_utils.get_map_variable(dim)

field_dims, field_offset, field_shape = _get_field_layout(domain)
field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE))

output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE)
output_node = state.add_access(output)

sdfg_builder.add_mapped_tasklet(
"index",
state,
map_ranges={
dim_index: f"{lower_bound}:{upper_bound}",
},
inputs={},
code=f"__val = {dim_index}",
outputs={
"__val": dace.Memlet(
data=output_node.data,
subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)),
)
},
input_nodes={},
output_nodes={output_node.data: output_node},
external_edges=True,
)

return FieldopData(output_node, field_type, field_offset)


def _get_data_nodes(
sdfg: dace.SDFG,
state: dace.SDFGState,
Expand Down Expand Up @@ -777,6 +833,7 @@ def translate_symbol_ref(
translate_as_fieldop,
translate_broadcast_scalar,
translate_if,
translate_index,
translate_literal,
translate_make_tuple,
translate_tuple_get,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,8 @@ def visit_FunCall(
# use specialized dataflow builder classes for each builtin function
if cpm.is_call_to(node, "if_"):
return gtir_builtin_translators.translate_if(node, sdfg, head_state, self)
elif cpm.is_call_to(node, "index"):
return gtir_builtin_translators.translate_index(node, sdfg, head_state, self)
elif cpm.is_call_to(node, "make_tuple"):
return gtir_builtin_translators.translate_make_tuple(node, sdfg, head_state, self)
elif cpm.is_call_to(node, "tuple_get"):
Expand Down
1 change: 0 additions & 1 deletion tests/next_tests/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
(STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE),
]
GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [
(USES_INDEX_BUILTIN, XFAIL, UNSUPPORTED_MESSAGE),
(USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE),
(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
Note: this test module covers the fieldview flavour of ITIR.
"""

import copy
import functools

import numpy as np
import pytest

from gt4py.next import common as gtx_common, constructors
from gt4py.next import common as gtx_common
from gt4py.next.iterator import ir as gtir
from gt4py.next.iterator.ir_utils import ir_makers as im
from gt4py.next.iterator.transforms import infer_domain
from gt4py.next.type_system import type_specifications as ts

from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
Expand Down Expand Up @@ -1973,3 +1973,49 @@ def test_gtir_if_values():

sdfg(a, b, c, **FSYMBOLS)
assert np.allclose(c, np.where(a < b, a, b))


def test_gtir_index():
MARGIN = 2
assert MARGIN < N
domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")})
subdomain = im.domain(
gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))}
)

testee = gtir.Program(
id="gtir_cast",
function_definitions=[],
params=[
gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)),
gtir.Sym(id="size", type=SIZE_TYPE),
],
declarations=[],
body=[
gtir.SetAt(
expr=im.let("i", im.index(IDim))(
im.op_as_fieldop("plus", domain)(
"i",
im.as_fieldop(
im.lambda_("a")(im.deref(im.shift(IDim.value, 1)("a"))), subdomain
)("i"),
)
),
domain=subdomain,
target=gtir.SymRef(id="x"),
)
],
)

v = np.empty(N, dtype=np.int32)

# we need to run domain inference in order to add the domain annex information to the index node.
testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS)
sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS)

ref = np.concatenate(
(v[:MARGIN], np.arange(MARGIN, N - MARGIN, dtype=np.int32), v[N - MARGIN :])
)

sdfg(v, **FSYMBOLS)
np.allclose(v, ref)

0 comments on commit f6c0498

Please sign in to comment.