Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into remove_array2
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Nov 20, 2023
2 parents 1e7b0ed + 42912cc commit 8234894
Show file tree
Hide file tree
Showing 28 changed files with 410 additions and 191 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ markers = [
'requires_dace: tests that require `dace` package',
'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)',
'uses_applied_shifts: tests that require backend support for applied-shifts',
'uses_can_deref: tests that require backend support for can_deref',
'uses_constant_fields: tests that require backend support for constant fields',
'uses_dynamic_offsets: tests that require backend support for dynamic offsets',
'uses_if_stmts: tests that require backend support for if-statements',
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,9 @@ def shape(self) -> tuple[int, ...]:
def dtype(self) -> Any:
...

def astype(self, dtype: npt.DTypeLike) -> NDArrayObject:
...

def __getitem__(self, item: Any) -> NDArrayObject:
...

Expand Down
9 changes: 6 additions & 3 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,24 @@ def __getitem__(self, index: int | slice) -> int | UnitRange:
else:
raise IndexError("UnitRange index out of range")

def __and__(self, other: Set[Any]) -> UnitRange:
def __and__(self, other: Set[int]) -> UnitRange:
if isinstance(other, UnitRange):
start = max(self.start, other.start)
stop = min(self.stop, other.stop)
return UnitRange(start, stop)
else:
raise NotImplementedError("Can only find the intersection between UnitRange instances.")

def __le__(self, other: Set[Any]):
def __le__(self, other: Set[int]):
if isinstance(other, UnitRange):
# required for infinity comparison
return self.start >= other.start and self.stop <= other.stop
elif len(self) == Infinity.positive():
return False
else:
return Set.__le__(self, other)

__ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented

def __str__(self) -> str:
return f"({self.start}:{self.stop})"

Expand Down
13 changes: 8 additions & 5 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ def ndarray(self) -> core_defs.NDArrayObject:
return self._ndarray

def asnumpy(self) -> np.ndarray:
return np.asarray(self._ndarray)
if self.array_ns == cp:
return cp.asnumpy(self._ndarray)
else:
return np.asarray(self._ndarray)

@property
def dtype(self) -> core_defs.DType[core_defs.ScalarT]:
Expand Down Expand Up @@ -286,7 +289,7 @@ def _np_cp_setitem(
_nd_array_implementations = [np]


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class NumPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = np

Expand All @@ -299,7 +302,7 @@ class NumPyArrayField(NdArrayField):
if cp:
_nd_array_implementations.append(cp)

@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class CuPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = cp

Expand All @@ -311,7 +314,7 @@ class CuPyArrayField(NdArrayField):
if jnp:
_nd_array_implementations.append(jnp)

@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class JaxArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = jnp

Expand Down Expand Up @@ -353,7 +356,7 @@ def _builtins_broadcast(


def _astype(field: NdArrayField, type_: type) -> NdArrayField:
return field.__class__.from_array(field.ndarray, domain=field.domain, dtype=type_)
return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain)


NdArrayField.register_builtin_func(fbuiltins.astype, _astype) # type: ignore[arg-type] # TODO(havogt) the registry should not be for any Field
Expand Down
46 changes: 19 additions & 27 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,14 @@ class Program:
past_node: past.Program
closure_vars: dict[str, Any]
definition: Optional[types.FunctionType] = None
backend: None | eve_utils.NOTHING | ppi.ProgramExecutor = (
eve_utils.NOTHING
) # TODO(havogt): temporary change, remove once `None` is default backend
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None

@classmethod
def from_function(
cls,
definition: types.FunctionType,
backend: Optional[ppi.ProgramExecutor] = None,
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND,
grid_type: Optional[GridType] = None,
) -> Program:
source_def = SourceDefinition.from_function(definition)
Expand Down Expand Up @@ -287,23 +285,16 @@ def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> No
rewritten_args, size_args, kwargs = self._process_args(args, kwargs)

if self.backend is None:
self.definition(*rewritten_args, **kwargs)
return

backend = self.backend
if self.backend is eve_utils.NOTHING:
warnings.warn(
UserWarning(
f"Field View Program '{self.itir.id}': Using default ({DEFAULT_BACKEND}) backend."
f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend."
)
)
backend = DEFAULT_BACKEND

ppi.ensure_processor_kind(backend, ppi.ProgramExecutor)
if "debug" in kwargs:
debug(self.itir)
self.definition(*rewritten_args, **kwargs)
return

backend(
self.backend(
self.itir,
*rewritten_args,
*size_args,
Expand Down Expand Up @@ -548,14 +539,14 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
foast_node: OperatorNodeT
closure_vars: dict[str, Any]
definition: Optional[types.FunctionType] = None
backend: Optional[ppi.ProgramExecutor] = None
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None

@classmethod
def from_function(
cls,
definition: types.FunctionType,
backend: Optional[ppi.ProgramExecutor] = None,
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND,
grid_type: Optional[GridType] = None,
*,
operator_node_cls: type[OperatorNodeT] = foast.FieldOperator,
Expand Down Expand Up @@ -706,7 +697,7 @@ def __call__(
)
else:
# "out" -> field_operator called from program in embedded execution
# TODO put offset_provider in ctxt var
# TODO(egparedes): put offset_provider in ctxt var here when implementing remap
domain = kwargs.pop("domain", None)
res = self.definition(*args, **kwargs)
_tuple_assign_field(
Expand All @@ -715,22 +706,23 @@ def __call__(
return
else:
# field_operator called from other field_operator in embedded execution
assert self.backend is None
return self.definition(*args, **kwargs)


def _tuple_assign_field(
tgt: tuple[common.Field | tuple, ...] | common.Field,
src: tuple[common.Field | tuple, ...] | common.Field,
domain: common.Domain,
target: tuple[common.Field | tuple, ...] | common.Field,
source: tuple[common.Field | tuple, ...] | common.Field,
domain: Optional[common.Domain],
):
if isinstance(tgt, tuple):
if not isinstance(src, tuple):
raise RuntimeError(f"Cannot assign {src} to {tgt}.")
for t, s in zip(tgt, src):
if isinstance(target, tuple):
if not isinstance(source, tuple):
raise RuntimeError(f"Cannot assign {source} to {target}.")
for t, s in zip(target, source):
_tuple_assign_field(t, s, domain)
else:
domain = domain or tgt.domain
tgt[domain] = src[domain]
domain = domain or target.domain
target[domain] = source[domain]


@typing.overload
Expand Down
24 changes: 17 additions & 7 deletions src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def broadcast(
dims: tuple[common.Dimension, ...],
/,
) -> common.Field:
assert core_defs.is_scalar_type(field)
assert core_defs.is_scalar_type(
field
) # default implementation for scalars, Fields are handled via dispatch
return common.field(
np.asarray(field)[
tuple([np.newaxis] * len(dims))
Expand All @@ -212,9 +214,16 @@ def where(


@BuiltInFunction
def astype(field: common.Field | core_defs.ScalarT, type_: type, /) -> common.Field:
assert core_defs.is_scalar_type(field)
return type_(field)
def astype(
value: Field | core_defs.ScalarT | Tuple,
type_: type,
/,
) -> Field | core_defs.ScalarT | Tuple:
if isinstance(value, tuple):
return tuple(astype(v, type_) for v in value)
# default implementation for scalars, Fields are handled via dispatch
assert core_defs.is_scalar_type(value)
return core_defs.dtype(type_).scalar_type(value)


UNARY_MATH_NUMBER_BUILTIN_NAMES = ["abs"]
Expand Down Expand Up @@ -248,7 +257,7 @@ def astype(field: common.Field | core_defs.ScalarT, type_: type, /) -> common.Fi
def _make_unary_math_builtin(name):
def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT:
# TODO(havogt): enable once we have a failing test (see `test_math_builtin_execution.py`)
# assert core_defs.is_scalar_type(value) # noqa: E800 # commented code
# assert core_defs.is_scalar_type(value) # default implementation for scalars, Fields are handled via dispatch # noqa: E800 # commented code
# return getattr(math, name)(value)# noqa: E800 # commented code
raise NotImplementedError()

Expand Down Expand Up @@ -278,9 +287,10 @@ def impl(
rhs: common.Field | core_defs.ScalarT,
/,
) -> common.Field | core_defs.ScalarT:
# default implementation for scalars, Fields are handled via dispatch
assert core_defs.is_scalar_type(lhs)
assert core_defs.is_scalar_type(rhs)
return BINARY_MATH_NUMBER_BUILTIN_TO_PYTHON_SCALAR_FUNCTION[name](lhs, rhs) # type: ignore[operator] # Cannot call function of unknown type
return getattr(np, name)(lhs, rhs)

impl.__name__ = name
globals()[name] = BuiltInFunction(impl)
Expand Down Expand Up @@ -325,7 +335,7 @@ class FieldOffset(runtime.Offset):

def __post_init__(self):
if len(self.target) == 2 and self.target[1].kind != common.DimensionKind.LOCAL:
raise ValueError("Second Dimension in offset must be a local Dimension.")
raise ValueError("Second dimension in offset must be a local dimension.")

def __gt_type__(self):
return ts.OffsetType(source=self.source, target=self.target)
11 changes: 8 additions & 3 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,10 +823,12 @@ def _visit_min_over(self, node: foast.Call, **kwargs) -> foast.Call:
return self._visit_reduction(node, **kwargs)

def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
return_type: ts.TupleType | ts.ScalarType | ts.FieldType
value, new_type = node.args
assert isinstance(
value.type, (ts.FieldType, ts.ScalarType)
value.type, (ts.FieldType, ts.ScalarType, ts.TupleType)
) # already checked using generic mechanism

if not isinstance(new_type, foast.Name) or new_type.id.upper() not in [
kind.name for kind in ts.ScalarKind
]:
Expand All @@ -835,8 +837,11 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call:
f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.",
)

return_type = with_altered_scalar_kind(
value.type, getattr(ts.ScalarKind, new_type.id.upper())
return_type = type_info.apply_to_primitive_constituents(
value.type,
lambda primitive_type: with_altered_scalar_kind(
primitive_type, getattr(ts.ScalarKind, new_type.id.upper())
),
)

return foast.Call(
Expand Down
35 changes: 29 additions & 6 deletions src/gt4py/next/ffront/foast_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,9 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr:

def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall:
assert len(node.args) == 2 and isinstance(node.args[1], foast.Name)
obj, dtype = node.args[0], node.args[1].id

# TODO check that we test astype that results in a itir.map_ operation
return self._map(
im.lambda_("it")(im.call("cast_")("it", str(dtype))),
obj,
obj, new_type = node.args[0], node.args[1].id
return self._process_elements(
lambda x: im.call("cast_")(x, str(new_type)), obj, obj.type, **kwargs
)

def _visit_where(self, node: foast.Call, **kwargs) -> itir.FunCall:
Expand Down Expand Up @@ -403,6 +400,32 @@ def _map(self, op, *args, **kwargs):

return im.promote_to_lifted_stencil(im.call(op))(*lowered_args)

def _process_elements(
self,
process_func: Callable[[itir.Expr], itir.Expr],
obj: foast.Expr,
current_el_type: ts.TypeSpec,
current_el_expr: itir.Expr = im.ref("expr"),
):
"""Recursively applies a processing function to all primitive constituents of a tuple."""
if isinstance(current_el_type, ts.TupleType):
# TODO(ninaburg): Refactor to avoid duplicating lowered obj expression for each tuple element.
return im.promote_to_lifted_stencil(lambda *elts: im.make_tuple(*elts))(
*[
self._process_elements(
process_func,
obj,
current_el_type.types[i],
im.tuple_get(i, current_el_expr),
)
for i in range(len(current_el_type.types))
]
)
elif type_info.contains_local_field(current_el_type):
raise NotImplementedError("Processing fields with local dimension is not implemented.")
else:
return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj)


class FieldOperatorLoweringError(Exception):
...
Loading

0 comments on commit 8234894

Please sign in to comment.