Skip to content

Commit

Permalink
Handle connecitivity fields
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Dec 11, 2024
1 parent a49e341 commit 0b2ca6c
Showing 1 changed file with 46 additions and 33 deletions.
79 changes: 46 additions & 33 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def handle_tuple(arg: Any, convert_arg: Callable) -> Any:
return tuple(convert_arg(a) for a in arg)


def handle_field(arg: Any) -> tuple:
def handle_field(arg: nd_array_field.NdArrayField) -> tuple:
arr = arg.ndarray
origin = getattr(arg, "__gt_origin__", tuple([0] * len(arg.domain)))
return arr, origin
Expand All @@ -45,10 +45,55 @@ def handle_field(arg: Any) -> tuple:
nd_array_field.NumPyArrayField: handle_field,
}

ConnectivityArg = tuple[core_defs.NDArrayObject, tuple[int, ...]]


def handle_neighbortable(
conn: embedded.NeighborTableOffsetProvider, # type: ignore
zero_tuple: tuple[int, ...],
device: core_defs.DeviceType,
copy: bool,
) -> ConnectivityArg:
if not copy:
return (conn.table, zero_tuple) # type: ignore
return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore


def handle_connectivity_field(
conn: nd_array_field.NdArrayField,
zero_tuple: tuple[int, ...],
device: core_defs.DeviceType,
copy: bool,
) -> ConnectivityArg:
if not copy:
return (conn.ndarray, zero_tuple)
return (_ensure_is_on_device(conn.ndarray, device), zero_tuple)


def handle_dimension(*args: Any, **kwargs: Any) -> None:
return None


def handle_invalid_type(conn: Any, *args: Any, **kwargs: Any) -> None:
raise AssertionError(
f"Unsupported offset provider type '{type(conn).__name__}'. "
"Expected 'Connectivity' or 'Dimension'."
)


type_handlers_connectivity_args = {
embedded.NeighborTableOffsetProvider: handle_neighbortable,
nd_array_field.NumPyArrayConnectivityField: handle_connectivity_field,
common.Dimension: handle_dimension,
}

try:
import cupy as cp

type_handlers_convert_args[nd_array_field.CuPyArrayField] = handle_field
type_handlers_connectivity_args[nd_array_field.CuPyArrayConnectivityField] = (
handle_connectivity_field
)
except ImportError:
cp = None

Expand Down Expand Up @@ -87,38 +132,6 @@ def convert_arg(arg: Any) -> Any:
return handler(arg)


ConnectivityArg = tuple[core_defs.NDArrayObject, tuple[int, ...]]


def handle_connectivity(
conn: embedded.NeighborTableOffsetProvider, # type: ignore
zero_tuple: tuple[int, ...],
device: core_defs.DeviceType,
copy: bool,
) -> ConnectivityArg:
if not copy:
return (conn.table, zero_tuple) # type: ignore
return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore


def handle_dimension(*args: Any, **kwargs: Any) -> None:
return None


def handle_invalid_type(conn: Any, *args: Any, **kwargs: Any) -> None:
raise AssertionError(
f"Unsupported offset provider type '{type(conn).__name__}'. "
"Expected 'Connectivity' or 'Dimension'."
)


type_handlers_connectivity_args = {
embedded.NeighborTableOffsetProvider: handle_connectivity,
nd_array_field.NumPyArrayConnectivityField: handle_connectivity,
common.Dimension: handle_dimension,
}


def _ensure_is_on_device(
connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType
) -> core_defs.NDArrayObject:
Expand Down

0 comments on commit 0b2ca6c

Please sign in to comment.