diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index a2e1b62569..c82af50a66 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -15,15 +15,15 @@ import diskcache import factory import filelock + import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py.eve import utils from gt4py.eve.utils import content_hash -from gt4py.next import NeighborTableOffsetProvider, backend, common, config +from gt4py.next import backend, common, config from gt4py.next.embedded import nd_array_field -from gt4py.next.iterator import ir as itir -from gt4py.next.otf import arguments -from gt4py.next.otf import recipes, stages, workflow +from gt4py.next.iterator import embedded, ir as itir +from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler from gt4py.next.otf.compilation.build_systems import compiledb @@ -42,15 +42,13 @@ def handle_field(arg: Any) -> tuple: type_handlers_convert_args = { tuple: handle_tuple, - NumPyArrayField: handle_field, + nd_array_field.NumPyArrayField: handle_field, } try: import cupy as cp - from gt4py.next.embedded.nd_array_field import CuPyArrayField - - type_handlers_convert_args[CuPyArrayField] = handle_field + type_handlers_convert_args[nd_array_field.CuPyArrayField] = handle_field except ImportError: cp = None @@ -60,12 +58,12 @@ def handle_default(arg: Any) -> Any: def convert_args( - inp: stages.ExtendedCompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU + inp: stages.ExtendedCompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU ) -> stages.CompiledProgram: def decorated_program( - *args: Any, - offset_provider: dict[str, common.Connectivity | common.Dimension], - out: Any = None, + *args: Any, + offset_provider: dict[str, common.Connectivity | common.Dimension], + out: Any = None, ) -> None: if out is not None: args = (*args, out) @@ -93,20 +91,21 @@ def convert_arg(arg: Any) -> Any: def handle_connectivity( - conn: NeighborTableOffsetProvider, zero_tuple: tuple[int, ...], device: core_defs.DeviceType, copy: bool + conn: embedded.NeighborTableOffsetProvider, # type: ignore + zero_tuple: tuple[int, ...], + device: core_defs.DeviceType, + copy: bool, ) -> ConnectivityArg: if not copy: - return (conn.table, zero_tuple) - return (_ensure_is_on_device(conn.table, device), zero_tuple) + return (conn.table, zero_tuple) # type: ignore + return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore def handle_dimension(*args: Any, **kwargs: Any) -> None: return None -def handle_invalid_type( - conn: Any, *args: Any, **kwargs: Any -) -> None: +def handle_invalid_type(conn: Any, *args: Any, **kwargs: Any) -> None: raise AssertionError( f"Unsupported offset provider type '{type(conn).__name__}'. " "Expected 'Connectivity' or 'Dimension'." @@ -114,15 +113,15 @@ def handle_invalid_type( type_handlers_connectivity_args = { - NeighborTableOffsetProvider: handle_connectivity, + embedded.NeighborTableOffsetProvider: handle_connectivity, nd_array_field.NumPyArrayConnectivityField: handle_connectivity, - nd_array_field.CuPyArrayConnectivityField: handle_connectivity, + nd_array_field.CuPyArrayConnectivityField: handle_connectivity, common.Dimension: handle_dimension, } def _ensure_is_on_device( - connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType + connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType ) -> core_defs.NDArrayObject: if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]: import cupy as cp @@ -137,12 +136,13 @@ def _ensure_is_on_device( def extract_connectivity_args( - offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType, - copy: bool = False + offset_provider: dict[str, common.Connectivity | common.Dimension], + device: core_defs.DeviceType, + copy: bool = False, ) -> list[ConnectivityArg]: zero_tuple = (0, 0) args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [] - for name, conn in offset_provider.items(): + for conn in offset_provider.values(): handler = type_handlers_connectivity_args.get(type(conn), handle_invalid_type) result = handler(conn, zero_tuple, device, copy) # type: ignore if result: