diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 5552f76848..597708ac65 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -325,7 +325,7 @@ class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ common.ConnectivityField[common.DimsT, common.DimT], NdArrayField[common.DimsT, core_defs.IntegralScalar], ): - _codomain: common.DimT + _codomain: common.DimT | fbuiltins.FieldOffset @functools.cached_property def _cache(self) -> dict: @@ -337,7 +337,7 @@ def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ig @property # type: ignore[override] # TODO(havogt): instead of inheriting from NdArrayField, steal implementation or common base - def codomain(self) -> common.DimT: + def codomain(self) -> common.DimT | fbuiltins.FieldOffset: return self._codomain @functools.cached_property @@ -346,7 +346,11 @@ def kind(self) -> common.ConnectivityKind: if self.domain.ndim > 1: kind |= common.ConnectivityKind.MODIFY_RANK kind |= common.ConnectivityKind.MODIFY_DIMS - if self.domain.dim_index(self.codomain) is None: + if isinstance(self.codomain, fbuiltins.FieldOffset): + codomain_new = self.codomain.source + else: + codomain_new = self.codomain + if self.domain.dim_index(codomain_new) is None: kind |= common.ConnectivityKind.MODIFY_DIMS return kind @@ -356,7 +360,7 @@ def from_array( # type: ignore[override] cls, data: npt.ArrayLike | core_defs.NDArrayObject, /, - codomain: common.DimT, + codomain: common.DimT | fbuiltins.FieldOffset, *, domain: common.DomainLike, dtype: Optional[core_defs.DTypeLike] = None, @@ -618,7 +622,7 @@ def _as_offset(offset_: fbuiltins.FieldOffset, field: common.Field) -> NdArrayCo ) -NdArrayField.register_builtin_func(experimental.as_offset, _as_offset) +NdArrayField.register_builtin_func(experimental.as_offset, _as_offset) # type: ignore[has-type] #type specified in experimental def _get_slices_from_domain_slice(