Skip to content

Commit

Permalink
resolved some pre-commit errors
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Dec 13, 2023
1 parent b94e81c commit f815712
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ def field(
def connectivity(
definition: Any,
/,
codomain: Dimension,
codomain: Dimension | fbuiltins.FieldOffset,
*,
domain: Optional[DomainLike] = None,
dtype: Optional[core_defs.DType] = None,
Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,9 @@ def remap(
connectivity.codomain, fbuiltins.FieldOffset
):
new_buffer = (
np.diagonal(new_buffer).T
xp.transpose(xp.diagonal(new_buffer))
if dim.kind == "horizontal"
else np.diagonal(new_buffer.T)
else xp.diagonal(xp.transpose(new_buffer))
)

return self.__class__.from_array(new_buffer, domain=new_domain, dtype=self.dtype)
Expand Down Expand Up @@ -356,7 +356,7 @@ def from_array( # type: ignore[override]
cls,
data: npt.ArrayLike | core_defs.NDArrayObject,
/,
codomain: common.DimT | fbuiltins.FieldOffset,
codomain: common.DimT,
*,
domain: common.DomainLike,
dtype: Optional[core_defs.DTypeLike] = None,
Expand Down Expand Up @@ -603,6 +603,9 @@ def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdA
raise AssertionError("This is the NdArrayField implementation of 'fbuiltins.astype'.")


NdArrayField.register_builtin_func(fbuiltins.astype, _astype)


def _as_offset(offset_: fbuiltins.FieldOffset, field: common.Field) -> NdArrayConnectivityField:
if isinstance(field, NdArrayField):
# change field.ndarray from local to global
Expand All @@ -615,7 +618,6 @@ def _as_offset(offset_: fbuiltins.FieldOffset, field: common.Field) -> NdArrayCo
)


NdArrayField.register_builtin_func(fbuiltins.astype, _astype)
NdArrayField.register_builtin_func(experimental.as_offset, _as_offset)


Expand Down
5 changes: 4 additions & 1 deletion src/gt4py/next/ffront/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np

from gt4py.next import common
from gt4py.next.ffront.fbuiltins import BuiltInFunction, FieldOffset

Expand All @@ -22,4 +24,5 @@ def as_offset(
field: common.Field,
/,
) -> common.ConnectivityField:
raise NotImplementedError()
global_index_arr = np.arange(field.ndarray.shape[0]) + field.ndarray
return common.connectivity(global_index_arr, codomain=offset_, domain=field.domain)

0 comments on commit f815712

Please sign in to comment.