Skip to content

Commit

Permalink
fix[next]: Improvements in DaCe backend (#1354)
Browse files Browse the repository at this point in the history
This PR contains some fixes and code refactoring in DaCe backend:
 * (refactoring) Use memlet API for full array subset
 * Fix for GPU execution: import cupy so that fields stored as cupy arrays are handled when sorting field dimensions.
 * Fix for symbolic analysis of memlet volume: define scalar symbols before visiting the closure domain, so that memlet volumes can be analyzed symbolically.
  • Loading branch information
edopao authored Oct 31, 2023
1 parent 0650d77 commit 3c463a6
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 42 deletions.
37 changes: 22 additions & 15 deletions src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims


# cupy is an optional dependency, needed only for GPU execution.
# When it is not installed, ``cp`` is set to None and the GPU code
# paths below detect this and raise an explicit error at runtime.
try:
    import cupy as cp
except ImportError:
    cp = None

def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]:
    """Return the unit range of every dimension of ``domain``, in sorted dimension order."""
    # get_sorted_dims yields (original_index, dimension) pairs; pick the
    # range belonging to each original index, in the sorted order.
    return [domain.ranges[original_index] for original_index, _ in get_sorted_dims(domain.dims)]
Expand All @@ -49,8 +55,11 @@ def convert_arg(arg: Any):
sorted_dims = get_sorted_dims(arg.domain.dims)
ndim = len(sorted_dims)
dim_indices = [dim_index for dim_index, _ in sorted_dims]
assert isinstance(arg.ndarray, np.ndarray)
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
if isinstance(arg.ndarray, np.ndarray):
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
else:
assert cp is not None and isinstance(arg.ndarray, cp.ndarray)
return cp.moveaxis(arg.ndarray, range(ndim), dim_indices)
return arg


Expand Down Expand Up @@ -226,24 +235,22 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:


@program_executor
def run_dace(program: itir.FencilDefinition, *args, **kwargs) -> None:
    """Execute an ITIR fencil through the DaCe backend on CPU or GPU.

    The target device is chosen implicitly from the field arguments: if any
    field is not backed by a numpy array, GPU execution is selected, and all
    fields must then be stored as cupy arrays.

    Raises:
        RuntimeError: if a non-numpy field is passed but cupy is not
            installed, or if field arguments mix numpy and cupy storage.
    """
    # A device (non-numpy) buffer among the field arguments requests GPU execution.
    run_on_gpu = any(not isinstance(arg.ndarray, np.ndarray) for arg in args if is_field(arg))
    if run_on_gpu:
        if cp is None:
            raise RuntimeError(
                f"Non-numpy field argument passed to program {program.id} but module cupy not installed"
            )

        # Mixing host and device buffers in a single call is not supported.
        if not all(isinstance(arg.ndarray, cp.ndarray) for arg in args if is_field(arg)):
            raise RuntimeError("Execution on GPU requires all fields to be stored as cupy arrays")

    run_dace_iterator(
        program,
        *args,
        **kwargs,
        build_cache=_build_cache_gpu if run_on_gpu else _build_cache_cpu,
        build_type=_build_type,
        run_on_gpu=run_on_gpu,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ def visit_StencilClosure(
# Update symbol table and get output domain of the closure
for name, type_ in self.storage_types.items():
if isinstance(type_, ts.ScalarType):
dtype = as_dace_type(type_)
closure_sdfg.add_symbol(name, dtype)
if name in input_names:
dtype = as_dace_type(type_)
closure_sdfg.add_symbol(name, dtype)
out_name = unique_var_name()
closure_sdfg.add_scalar(out_name, dtype, transient=True)
out_tasklet = closure_init_state.add_tasklet(
Expand All @@ -272,7 +272,7 @@ def visit_StencilClosure(
closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet)
program_arg_syms[name] = value
else:
program_arg_syms[name] = SymbolExpr(name, as_dace_type(type_))
program_arg_syms[name] = SymbolExpr(name, dtype)
closure_domain = self._visit_domain(node.domain, closure_ctx)

# Map SDFG tasklet arguments to parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
add_mapped_nested_sdfg,
as_dace_type,
connectivity_identifier,
create_memlet_at,
create_memlet_full,
filter_neighbor_tables,
flatten_list,
Expand Down Expand Up @@ -199,7 +200,6 @@ def builtin_neighbors(
result_access = state.add_access(result_name)

table_name = connectivity_identifier(offset_dim)
table_array = sdfg.arrays[table_name]

# generate unique map index name to avoid conflict with other maps inside same state
index_name = unique_name("__neigh_idx")
Expand All @@ -225,43 +225,40 @@ def builtin_neighbors(
state.add_access(table_name),
me,
shift_tasklet,
memlet=dace.Memlet(data=table_name, subset=",".join(f"0:{s}" for s in table_array.shape)),
memlet=create_memlet_full(table_name, sdfg.arrays[table_name]),
dst_conn="__table",
)
state.add_memlet_path(
iterator.indices[shifted_dim],
me,
shift_tasklet,
memlet=dace.Memlet(data=iterator.indices[shifted_dim].data, subset="0"),
memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0"),
dst_conn="__idx",
)
state.add_edge(
shift_tasklet,
"__result",
data_access_tasklet,
"__idx",
dace.Memlet(data=idx_name, subset="0"),
dace.Memlet.simple(idx_name, "0"),
)
# select full shape only in the neighbor-axis dimension
field_subset = [
field_subset = tuple(
f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}"
for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape)
]
)
state.add_memlet_path(
iterator.field,
me,
data_access_tasklet,
memlet=dace.Memlet(
data=iterator.field.data,
subset=",".join(field_subset),
),
memlet=create_memlet_at(iterator.field.data, field_subset),
dst_conn="__field",
)
state.add_memlet_path(
data_access_tasklet,
mx,
result_access,
memlet=dace.Memlet(data=result_name, subset=index_name),
memlet=dace.Memlet.simple(result_name, index_name),
src_conn="__result",
)

Expand Down Expand Up @@ -438,7 +435,7 @@ def visit_Lambda(
result_access,
None,
# in case of reduction lambda, the output edge from lambda tasklet performs write-conflict resolution
dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr),
dace.Memlet.simple(result_access.data, "0", wcr_str=context.reduce_wcr),
)
result = ValueExpr(value=result_access, dtype=expr.dtype)
else:
Expand Down Expand Up @@ -616,7 +613,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
deref_tasklet,
mx,
result_access,
memlet=dace.Memlet(data=result_name, subset=index_name),
memlet=dace.Memlet.simple(result_name, index_name),
src_conn="__result",
)

Expand Down Expand Up @@ -738,13 +735,13 @@ def _visit_reduce(self, node: itir.FunCall):
assert isinstance(op_name, itir.SymRef)
init = node.fun.args[1]

nreduce = self.context.body.arrays[neighbors_expr.value.data].shape[0]
reduce_array_desc = neighbors_expr.value.desc(self.context.body)

self.context.body.add_scalar(result_name, result_dtype, transient=True)
op_str = _MATH_BUILTINS_MAPPING[str(op_name)].format("__result", "__values[__idx]")
reduce_tasklet = self.context.state.add_tasklet(
"reduce",
code=f"__result = {init}\nfor __idx in range({nreduce}):\n __result = {op_str}",
code=f"__result = {init}\nfor __idx in range({reduce_array_desc.shape[0]}):\n __result = {op_str}",
inputs={"__values"},
outputs={"__result"},
)
Expand All @@ -753,14 +750,14 @@ def _visit_reduce(self, node: itir.FunCall):
None,
reduce_tasklet,
"__values",
dace.Memlet(data=neighbors_expr.value.data, subset=f"0:{nreduce}"),
create_memlet_full(neighbors_expr.value.data, reduce_array_desc),
)
self.context.state.add_edge(
reduce_tasklet,
"__result",
result_access,
None,
dace.Memlet(data=result_name, subset="0"),
dace.Memlet.simple(result_name, "0"),
)
else:
assert isinstance(node.fun, itir.FunCall)
Expand Down Expand Up @@ -973,7 +970,7 @@ def closure_to_tasklet_sdfg(
tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}")
access = state.add_access(name)
idx_accesses[dim] = access
state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0"))
state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0"))
for name, ty in inputs:
if isinstance(ty, ts.FieldType):
ndim = len(ty.dims)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,12 @@ def connectivity_identifier(name: str):


def create_memlet_full(source_identifier: str, source_array: dace.data.Array):
    """Create a memlet covering the full subset of ``source_array``.

    Uses the DaCe memlet API (``Memlet.from_array``) instead of building the
    full-range subset string by hand, so the array descriptor's shape is the
    single source of truth for the subset.
    """
    return dace.Memlet.from_array(source_identifier, source_array)


def create_memlet_at(source_identifier: str, index: tuple[str, ...]):
    """Create a memlet addressing the single element of the array at ``index``.

    ``index`` holds one subset expression (as a string) per array dimension;
    they are joined into the comma-separated subset form expected by
    ``dace.Memlet.simple``.
    """
    subset = ", ".join(index)
    return dace.Memlet.simple(source_identifier, subset)


def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import gt4py.next as gtx
from gt4py.next.iterator import embedded
from gt4py.next.program_processors.runners import gtfn
from gt4py.next.program_processors.runners import dace_iterator, gtfn

from next_tests.integration_tests import cases
from next_tests.integration_tests.cases import cartesian_case # noqa: F401
Expand All @@ -26,7 +26,7 @@


@pytest.mark.requires_gpu
@pytest.mark.parametrize("fieldview_backend", [gtfn.run_gtfn_gpu])
@pytest.mark.parametrize("fieldview_backend", [dace_iterator.run_dace, gtfn.run_gtfn_gpu])
def test_copy(cartesian_case, fieldview_backend): # noqa: F811 # fixtures
import cupy as cp # TODO(ricoh): replace with storages solution when available

Expand Down

0 comments on commit 3c463a6

Please sign in to comment.