Skip to content

Commit

Permalink
Always check if it is on device
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Dec 11, 2024
1 parent 0b2ca6c commit 2607a81
Showing 1 changed file with 1 addition and 7 deletions.
8 changes: 1 addition & 7 deletions src/gt4py/next/program_processors/runners/gtfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@ def handle_neighbortable(
conn: embedded.NeighborTableOffsetProvider, # type: ignore
zero_tuple: tuple[int, ...],
device: core_defs.DeviceType,
copy: bool,
) -> ConnectivityArg:
if not copy:
return (conn.table, zero_tuple) # type: ignore
return (_ensure_is_on_device(conn.table, device), zero_tuple) # type: ignore


Expand All @@ -65,8 +62,6 @@ def handle_connectivity_field(
device: core_defs.DeviceType,
copy: bool,
) -> ConnectivityArg:
if not copy:
return (conn.ndarray, zero_tuple)
return (_ensure_is_on_device(conn.ndarray, device), zero_tuple)


Expand Down Expand Up @@ -150,13 +145,12 @@ def _ensure_is_on_device(
def extract_connectivity_args(
offset_provider: dict[str, common.Connectivity | common.Dimension],
device: core_defs.DeviceType,
copy: bool = False,
) -> list[ConnectivityArg]:
zero_tuple = (0, 0)
args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = []
for conn in offset_provider.values():
handler = type_handlers_connectivity_args.get(type(conn), handle_invalid_type)
result = handler(conn, zero_tuple, device, copy) # type: ignore
result = handler(conn, zero_tuple, device) # type: ignore
if result:
args.append(result)
return args
Expand Down

0 comments on commit 2607a81

Please sign in to comment.