Skip to content

Commit

Permalink
removed [0][0] indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Jan 5, 2024
1 parent 9f2bcc6 commit 90f6796
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,9 @@ def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdA
def _as_offset(offset_: fbuiltins.FieldOffset, field: common.Field) -> NdArrayConnectivityField:
if isinstance(field, NdArrayField):
# change field.ndarray from relative to absolute
offset_dim = np.where(list(map(lambda x: x == offset_.source, field.domain.dims)))[0][0]
offset_dim = np.squeeze(
np.where(list(map(lambda x: x == offset_.source, field.domain.dims)))
).item()
new_connectivity = np.indices(field.ndarray.shape)[offset_dim] + field.ndarray
return NumPyArrayConnectivityField.from_array(
new_connectivity, codomain=offset_.source, domain=field.domain
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def as_offset(
field: common.Field,
/,
) -> common.ConnectivityField:
offset_dim = np.where(list(map(lambda x: x == offset_.source, field.domain.dims)))[0][0]
offset_dim = np.squeeze(
np.where(list(map(lambda x: x == offset_.source, field.domain.dims)))
).item()
new_connectivity = np.indices(field.ndarray.shape)[offset_dim] + field.ndarray
return common.connectivity(new_connectivity, codomain=offset_.source, domain=field.domain)

Expand Down

0 comments on commit 90f6796

Please sign in to comment.