diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 8c48183024..622747b52b 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -161,13 +161,15 @@ def from_array( return cls(domain, array) def remap( - self: NdArrayField, connectivity: common.ConnectivityField + self: NdArrayField, connectivity: common.ConnectivityField | fbuiltins.FieldOffset ) -> NdArrayField: # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField if not common.is_connectivity_field(connectivity): + assert isinstance(connectivity, fbuiltins.FieldOffset) connectivity = connectivity.as_connectivity_field() assert common.is_connectivity_field(connectivity) + # Compute the new domain dim = connectivity.codomain dim_idx = self.domain.dim_index(dim)