Skip to content

Commit

Permalink
feat[next]: Remove dace_iterator backend and pass_manager_legacy (#1753)
Browse files Browse the repository at this point in the history
The dace orchestration tests are temporarily skipped until #1742 is
merged.
The dace backend with SDFG optimization is temporarily disabled in unit
tests until #1639 is merged.
A second PR will reorganize the files in the dace backend module.
  • Loading branch information
edopao authored Dec 3, 2024
1 parent e5abcd2 commit a2551ac
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 3,341 deletions.
181 changes: 0 additions & 181 deletions src/gt4py/next/iterator/transforms/pass_manager_legacy.py

This file was deleted.

62 changes: 28 additions & 34 deletions src/gt4py/next/program_processors/runners/dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,34 @@

import factory

import gt4py._core.definitions as core_defs
import gt4py.next.allocators as next_allocators
from gt4py.next import backend
from gt4py.next.otf import workflow
from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow
from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow
from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory


class DaCeIteratorBackendFactory(GTFNBackendFactory):
class DaCeFieldviewBackendFactory(GTFNBackendFactory):
class Meta:
model = backend.Backend

class Params:
otf_workflow = factory.SubFactory(
dace_iterator_workflow.DaCeWorkflowFactory,
device_type=factory.SelfAttribute("..device_type"),
use_field_canonical_representation=factory.SelfAttribute(
"..use_field_canonical_representation"
),
name_device = "cpu"
name_cached = ""
name_postfix = ""
gpu = factory.Trait(
allocator=next_allocators.StandardGPUFieldBufferAllocator(),
device_type=next_allocators.CUPY_DEVICE or core_defs.DeviceType.CUDA,
name_device="gpu",
)
auto_optimize = factory.Trait(
otf_workflow__translation__auto_optimize=True, name_postfix="_opt"
cached = factory.Trait(
executor=factory.LazyAttribute(
lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function)
),
name_cached="_cached",
)
use_field_canonical_representation: bool = False

name = factory.LazyAttribute(
lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.itir"
)

transforms = backend.LEGACY_TRANSFORMS


run_dace_cpu = DaCeIteratorBackendFactory(cached=True, auto_optimize=True)
run_dace_cpu_noopt = DaCeIteratorBackendFactory(cached=True, auto_optimize=False)

run_dace_gpu = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=True)
run_dace_gpu_noopt = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=False)

itir_cpu = run_dace_cpu
itir_gpu = run_dace_gpu


class DaCeFieldviewBackendFactory(GTFNBackendFactory):
class Params:
device_type = core_defs.DeviceType.CPU
otf_workflow = factory.SubFactory(
dace_fieldview_workflow.DaCeWorkflowFactory,
device_type=factory.SelfAttribute("..device_type"),
Expand All @@ -55,11 +44,16 @@ class Params:
auto_optimize = factory.Trait(name_postfix="_opt")

name = factory.LazyAttribute(
lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.gtir"
lambda o: f"run_dace_{o.name_device}{o.name_cached}{o.name_postfix}"
)

executor = factory.LazyAttribute(lambda o: o.otf_workflow)
allocator = next_allocators.StandardCPUFieldBufferAllocator()
transforms = backend.DEFAULT_TRANSFORMS


gtir_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False)
gtir_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False)
run_dace_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=True)
run_dace_cpu_noopt = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False)

run_dace_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=True)
run_dace_gpu_noopt = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False)
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
cp = None


def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool) -> Any:
def _convert_arg(arg: Any, sdfg_param: str) -> Any:
if not isinstance(arg, gtx_common.Field):
return arg
if len(arg.domain.dims) == 0:
Expand All @@ -41,26 +41,14 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation:
raise RuntimeError(
f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}."
)
if not use_field_canonical_representation:
return arg.ndarray
# the canonical representation requires alphabetical ordering of the dimensions in field domain definition
sorted_dims = dace_utils.get_sorted_dims(arg.domain.dims)
ndim = len(sorted_dims)
dim_indices = [dim_index for dim_index, _ in sorted_dims]
if isinstance(arg.ndarray, np.ndarray):
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
else:
assert cp is not None and isinstance(arg.ndarray, cp.ndarray)
return cp.moveaxis(arg.ndarray, range(ndim), dim_indices)


def _get_args(
sdfg: dace.SDFG, args: Sequence[Any], use_field_canonical_representation: bool
) -> dict[str, Any]:
return arg.ndarray


def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
sdfg_params: Sequence[str] = sdfg.arg_names
flat_args: Iterable[Any] = gtx_utils.flatten_nested_tuple(tuple(args))
return {
sdfg_param: _convert_arg(arg, sdfg_param, use_field_canonical_representation)
sdfg_param: _convert_arg(arg, sdfg_param)
for sdfg_param, arg in zip(sdfg_params, flat_args, strict=True)
}

Expand Down Expand Up @@ -154,10 +142,10 @@ def get_sdfg_conn_args(

def get_sdfg_args(
sdfg: dace.SDFG,
offset_provider: gtx_common.OffsetProvider,
*args: Any,
check_args: bool = False,
on_gpu: bool = False,
use_field_canonical_representation: bool = True,
**kwargs: Any,
) -> dict[str, Any]:
"""Extracts the arguments needed to call the SDFG.
Expand All @@ -166,10 +154,10 @@ def get_sdfg_args(
Args:
sdfg: The SDFG for which we want to get the arguments.
offset_provider: Offset provider.
"""
offset_provider = kwargs["offset_provider"]

dace_args = _get_args(sdfg, args, use_field_canonical_representation)
dace_args = _get_args(sdfg, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu)
dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import annotations

import re
from typing import Final, Literal, Optional, Sequence
from typing import Final, Literal, Optional

import dace

Expand Down Expand Up @@ -96,10 +96,3 @@ def filter_connectivity_types(
for offset, conn in offset_provider_type.items()
if isinstance(conn, gtx_common.NeighborConnectivityType)
}


def get_sorted_dims(
dims: Sequence[gtx_common.Dimension],
) -> Sequence[tuple[int, gtx_common.Dimension]]:
"""Sort list of dimensions in alphabetical order."""
return sorted(enumerate(dims), key=lambda v: v[1].value)
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ def decorated_program(

sdfg_args = dace_backend.get_sdfg_args(
sdfg,
offset_provider,
*args,
check_args=False,
offset_provider=offset_provider,
on_gpu=on_gpu,
use_field_canonical_representation=use_field_canonical_representation,
)
Expand Down
Loading

0 comments on commit a2551ac

Please sign in to comment.