Skip to content

Commit

Permalink
Fix typing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Dec 11, 2024
1 parent 9eb29be commit 8a0793b
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
import diskcache
import factory
import filelock

import gt4py._core.definitions as core_defs
import gt4py.next.allocators as next_allocators
from gt4py.eve import utils
from gt4py.eve.utils import content_hash
from gt4py.next import NeighborTableOffsetProvider, backend, common, config
from gt4py.next import backend, common, config
from gt4py.next.embedded import nd_array_field
from gt4py.next.iterator import ir as itir
from gt4py.next.otf import arguments
from gt4py.next.otf import recipes, stages, workflow
from gt4py.next.iterator import embedded, ir as itir
from gt4py.next.otf import arguments, recipes, stages, workflow
from gt4py.next.otf.binding import nanobind
from gt4py.next.otf.compilation import compiler
from gt4py.next.otf.compilation.build_systems import compiledb
Expand All @@ -42,15 +42,13 @@ def handle_field(arg: Any) -> tuple:

type_handlers_convert_args = {
tuple: handle_tuple,
NumPyArrayField: handle_field,
nd_array_field.NumPyArrayField: handle_field,
}

try:
import cupy as cp

from gt4py.next.embedded.nd_array_field import CuPyArrayField

type_handlers_convert_args[CuPyArrayField] = handle_field
type_handlers_convert_args[nd_array_field.CuPyArrayField] = handle_field
except ImportError:
cp = None

Expand All @@ -60,12 +58,12 @@ def handle_default(arg: Any) -> Any:


def convert_args(
inp: stages.ExtendedCompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU
inp: stages.ExtendedCompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU
) -> stages.CompiledProgram:
def decorated_program(
*args: Any,
offset_provider: dict[str, common.Connectivity | common.Dimension],
out: Any = None,
*args: Any,
offset_provider: dict[str, common.Connectivity | common.Dimension],
out: Any = None,
) -> None:
if out is not None:
args = (*args, out)
Expand Down Expand Up @@ -93,36 +91,37 @@ def convert_arg(arg: Any) -> Any:


def handle_connectivity(
conn: NeighborTableOffsetProvider, zero_tuple: tuple[int, ...], device: core_defs.DeviceType, copy: bool
conn: embedded.NeighborTableOffsetProvider, # type: ignore
zero_tuple: tuple[int, ...],
device: core_defs.DeviceType,
copy: bool,
) -> ConnectivityArg:
if not copy:
return (conn.table, zero_tuple)
return (_ensure_is_on_device(conn.table, device), zero_tuple)
return (conn.table, zero_tuple) # type: ignore
return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore


def handle_dimension(*args: Any, **kwargs: Any) -> None:
return None


def handle_invalid_type(
conn: Any, *args: Any, **kwargs: Any
) -> None:
def handle_invalid_type(conn: Any, *args: Any, **kwargs: Any) -> None:
raise AssertionError(
f"Unsupported offset provider type '{type(conn).__name__}'. "
"Expected 'Connectivity' or 'Dimension'."
)


type_handlers_connectivity_args = {
NeighborTableOffsetProvider: handle_connectivity,
embedded.NeighborTableOffsetProvider: handle_connectivity,
nd_array_field.NumPyArrayConnectivityField: handle_connectivity,
nd_array_field.CuPyArrayConnectivityField: handle_connectivity,
nd_array_field.CuPyArrayConnectivityField: handle_connectivity,
common.Dimension: handle_dimension,
}


def _ensure_is_on_device(
connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType
connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType
) -> core_defs.NDArrayObject:
if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]:
import cupy as cp
Expand All @@ -137,12 +136,13 @@ def _ensure_is_on_device(


def extract_connectivity_args(
offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType,
copy: bool = False
offset_provider: dict[str, common.Connectivity | common.Dimension],
device: core_defs.DeviceType,
copy: bool = False,
) -> list[ConnectivityArg]:
zero_tuple = (0, 0)
args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = []
for name, conn in offset_provider.items():
for conn in offset_provider.values():
handler = type_handlers_connectivity_args.get(type(conn), handle_invalid_type)
result = handler(conn, zero_tuple, device, copy) # type: ignore
if result:
Expand Down

0 comments on commit 8a0793b

Please sign in to comment.