From 8e644585361aac30bf97f753e652117c0884bde5 Mon Sep 17 00:00:00 2001
From: edopao
Date: Tue, 12 Dec 2023 13:04:03 +0100
Subject: [PATCH 01/13] feat[next][dace]: Add support for if expressions with
 tuple argument (#1393)

Some icon4py stencils require support for if expressions with tuple
arguments. This PR adds that support to the DaCe backend, in the visitor
of the builtin_if function. Additionally, this PR fixes the result of
builtin_tuple_get, which should return a list.
---
 .../runners/dace_iterator/__init__.py         |  1 -
 .../runners/dace_iterator/itir_to_tasklet.py  | 42 ++++++++++++++-----
 2 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index 34ba2d2d95..acfa06b456 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -247,7 +247,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
     neighbor_tables = filter_neighbor_tables(offset_provider)
     cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
 
-    sdfg: Optional[dace.SDFG] = None
     if build_cache is not None and cache_id in build_cache:
         # retrieve SDFG program from build cache
         sdfg_program = build_cache[cache_id]
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
index f6f197859b..32b8cbf2b1 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
@@ -321,16 +321,36 @@ def builtin_can_deref(
 def builtin_if(
     transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
 ) -> list[ValueExpr]:
-    args = [arg for li in transformer.visit(node_args) for arg in li]
-    expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)]
-    internals = [
-        arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args
+    args = transformer.visit(node_args)
+    assert len(args) == 3
+    if_node = args[0][0] if isinstance(args[0], list) else args[0]
+
+    # the argument could be a list of elements on each branch representing the result of `make_tuple`
+    # however, the normal case is to find one value expression
+    assert len(args[1]) == len(args[2])
+    if_expr_args = [
+        (a[0] if isinstance(a, list) else a, b[0] if isinstance(b, list) else b)
+        for a, b in zip(args[1], args[2])
     ]
-    expr = "({1} if {0} else {2})".format(*internals)
-    node_type = transformer.node_types[id(node)]
-    assert isinstance(node_type, itir_typing.Val)
-    type_ = itir_type_as_dace_type(node_type.dtype)
-    return transformer.add_expr_tasklet(expr_args, expr, type_, "if")
+
+    # in case of tuple arguments, generate one if-tasklet for each element of the output tuple
+    if_expr_values = []
+    for a, b in if_expr_args:
+        assert a.dtype == b.dtype
+        expr_args = [
+            (arg, f"{arg.value.data}_v")
+            for arg in (if_node, a, b)
+            if not isinstance(arg, SymbolExpr)
+        ]
+        internals = [
+            arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v"
+            for arg in (if_node, a, b)
+        ]
+        expr = "({1} if {0} else {2})".format(*internals)
+        if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if")
+        if_expr_values.append(if_expr[0])
+
+    return if_expr_values
 
 
 def builtin_list_get(
@@ -356,7 +376,7 @@ def 
builtin_list_get( def builtin_cast( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: - args = [transformer.visit(node_args[0])[0]] + args = transformer.visit(node_args[0]) internals = [f"{arg.value.data}_v" for arg in args] target_type = node_args[1] assert isinstance(target_type, itir.SymRef) @@ -380,7 +400,7 @@ def builtin_tuple_get( elements = transformer.visit(node_args[1]) index = node_args[0] if isinstance(index, itir.Literal): - return elements[int(index.value)] + return [elements[int(index.value)]] raise ValueError("Tuple can only be subscripted with compile-time constants") From a14ad09f6dd3043114238fc820d68621480cfc4e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 12 Dec 2023 13:24:51 +0100 Subject: [PATCH 02/13] feat[next]: Embedded field scan (#1365) Adds the scalar scan operator for embedded field view. --- .gitpod.yml | 2 +- src/gt4py/next/embedded/common.py | 17 ++ src/gt4py/next/embedded/context.py | 4 +- src/gt4py/next/embedded/nd_array_field.py | 8 +- src/gt4py/next/embedded/operators.py | 168 ++++++++++++++++++ src/gt4py/next/ffront/decorator.py | 95 ++++------ src/gt4py/next/field_utils.py | 22 +++ src/gt4py/next/iterator/embedded.py | 19 +- src/gt4py/next/utils.py | 22 ++- tests/next_tests/exclusion_matrices.py | 1 - tests/next_tests/integration_tests/cases.py | 6 +- .../ffront_tests/test_execution.py | 80 +++++++++ .../iterator_tests/test_column_stencil.py | 4 +- .../unit_tests/embedded_tests/test_common.py | 14 +- .../iterator_tests/test_embedded_internals.py | 8 +- 15 files changed, 372 insertions(+), 98 deletions(-) create mode 100644 src/gt4py/next/embedded/operators.py create mode 100644 src/gt4py/next/field_utils.py diff --git a/.gitpod.yml b/.gitpod.yml index 1d579d88eb..802d87796a 100644 --- a/.gitpod.yml +++ b/.gitpod.yml @@ -5,7 +5,7 @@ image: tasks: - name: Setup venv and dev tools init: | - ln -s /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode + ln -sfn /workspace/gt4py/.gitpod/.vscode /workspace/gt4py/.vscode python -m venv .venv source .venv/bin/activate pip install --upgrade pip setuptools wheel diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index d796189ab3..558730cb82 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -14,6 +14,10 @@ from __future__ import annotations +import functools +import itertools +import operator + from gt4py.eve.extended_typing import Any, Optional, Sequence, cast from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions @@ -90,6 +94,19 @@ def _absolute_sub_domain( return common.Domain(*named_ranges) +def intersect_domains(*domains: common.Domain) -> common.Domain: + return functools.reduce( + operator.and_, + domains, + common.Domain(dims=tuple(), ranges=tuple()), + ) + + +def iterate_domain(domain: common.Domain): + for i in itertools.product(*[list(r) for r in domain.ranges]): + yield tuple(zip(domain.dims, i)) + + def _expand_ellipsis( indices: common.RelativeIndexSequence, target_size: int ) -> tuple[common.IntIndex | slice, ...]: diff --git a/src/gt4py/next/embedded/context.py b/src/gt4py/next/embedded/context.py index 5fbdbc6f25..93942a5959 100644 --- a/src/gt4py/next/embedded/context.py +++ b/src/gt4py/next/embedded/context.py @@ -24,7 +24,7 @@ #: Column range used in column mode (`column_axis != None`) in the current embedded iterator #: closure execution context. 
-closure_column_range: cvars.ContextVar[range] = cvars.ContextVar("column_range") +closure_column_range: cvars.ContextVar[common.NamedRange] = cvars.ContextVar("column_range") _undefined_offset_provider: common.OffsetProvider = {} @@ -37,7 +37,7 @@ @contextlib.contextmanager def new_context( *, - closure_column_range: range | eve.NothingType = eve.NOTHING, + closure_column_range: common.NamedRange | eve.NothingType = eve.NOTHING, offset_provider: common.OffsetProvider | eve.NothingType = eve.NOTHING, ): import gt4py.next.embedded.context as this_module diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ff6a2ceac7..6b69e8f8cc 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -16,7 +16,6 @@ import dataclasses import functools -import operator from collections.abc import Callable, Sequence from types import ModuleType from typing import ClassVar @@ -49,11 +48,10 @@ def _builtin_op(*fields: common.Field | core_defs.Scalar) -> NdArrayField: xp = first.__class__.array_ns op = getattr(xp, array_builtin_name) - domain_intersection = functools.reduce( - operator.and_, - [f.domain for f in fields if common.is_field(f)], - common.Domain(dims=tuple(), ranges=tuple()), + domain_intersection = embedded_common.intersect_domains( + *[f.domain for f in fields if common.is_field(f)] ) + transformed: list[core_defs.NDArrayObject | core_defs.Scalar] = [] for f in fields: if common.is_field(f): diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py new file mode 100644 index 0000000000..f50ace7687 --- /dev/null +++ b/src/gt4py/next/embedded/operators.py @@ -0,0 +1,168 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import dataclasses +from typing import Any, Callable, Generic, ParamSpec, Sequence, TypeVar + +from gt4py import eve +from gt4py._core import definitions as core_defs +from gt4py.next import common, constructors, utils +from gt4py.next.embedded import common as embedded_common, context as embedded_context + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +@dataclasses.dataclass(frozen=True) +class EmbeddedOperator(Generic[_R, _P]): + fun: Callable[_P, _R] + + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: + return self.fun(*args, **kwargs) + + +@dataclasses.dataclass(frozen=True) +class ScanOperator(EmbeddedOperator[_R, _P]): + forward: bool + init: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...] 
+ axis: common.Dimension + + def __call__(self, *args: common.Field | core_defs.Scalar, **kwargs: common.Field | core_defs.Scalar) -> common.Field: # type: ignore[override] # we cannot properly type annotate relative to self.fun + scan_range = embedded_context.closure_column_range.get() + assert self.axis == scan_range[0] + scan_axis = scan_range[0] + domain_intersection = _intersect_scan_args(*args, *kwargs.values()) + non_scan_domain = common.Domain(*[nr for nr in domain_intersection if nr[0] != scan_axis]) + + out_domain = common.Domain( + *[scan_range if nr[0] == scan_axis else nr for nr in domain_intersection] + ) + if scan_axis not in out_domain.dims: + # even if the scan dimension is not in the input, we can scan over it + out_domain = common.Domain(*out_domain, (scan_range)) + + res = _construct_scan_array(out_domain)(self.init) + + def scan_loop(hpos): + acc = self.init + for k in scan_range[1] if self.forward else reversed(scan_range[1]): + pos = (*hpos, (scan_axis, k)) + new_args = [_tuple_at(pos, arg) for arg in args] + new_kwargs = {k: _tuple_at(pos, v) for k, v in kwargs.items()} + acc = self.fun(acc, *new_args, **new_kwargs) + _tuple_assign_value(pos, res, acc) + + if len(non_scan_domain) == 0: + # if we don't have any dimension orthogonal to scan_axis, we need to do one scan_loop + scan_loop(()) + else: + for hpos in embedded_common.iterate_domain(non_scan_domain): + scan_loop(hpos) + + return res + + +def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): + if "out" in kwargs: + # called from program or direct field_operator as program + offset_provider = kwargs.pop("offset_provider", None) + + new_context_kwargs = {} + if embedded_context.within_context(): + # called from program + assert offset_provider is None + else: + # field_operator as program + new_context_kwargs["offset_provider"] = offset_provider + + out = kwargs.pop("out") + domain = kwargs.pop("domain", None) + + flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,)) + assert all(f.domain == flattened_out[0].domain for f in flattened_out) + + out_domain = common.domain(domain) if domain is not None else flattened_out[0].domain + + new_context_kwargs["closure_column_range"] = _get_vertical_range(out_domain) + + with embedded_context.new_context(**new_context_kwargs) as ctx: + res = ctx.run(op, *args, **kwargs) + _tuple_assign_field( + out, + res, + domain=out_domain, + ) + else: + # called from other field_operator + return op(*args, **kwargs) + + +def _get_vertical_range(domain: common.Domain) -> common.NamedRange | eve.NothingType: + vertical_dim_filtered = [nr for nr in domain if nr[0].kind == common.DimensionKind.VERTICAL] + assert len(vertical_dim_filtered) <= 1 + return vertical_dim_filtered[0] if vertical_dim_filtered else eve.NOTHING + + +def _tuple_assign_field( + target: tuple[common.MutableField | tuple, ...] | common.MutableField, + source: tuple[common.Field | tuple, ...] | common.Field, + domain: common.Domain, +): + @utils.tree_map + def impl(target: common.MutableField, source: common.Field): + target[domain] = source[domain] + + impl(target, source) + + +def _intersect_scan_args( + *args: core_defs.Scalar | common.Field | tuple[core_defs.Scalar | common.Field | tuple, ...] 
+) -> common.Domain: + return embedded_common.intersect_domains( + *[arg.domain for arg in utils.flatten_nested_tuple(args) if common.is_field(arg)] + ) + + +def _construct_scan_array(domain: common.Domain): + @utils.tree_map + def impl(init: core_defs.Scalar) -> common.Field: + return constructors.empty(domain, dtype=type(init)) + + return impl + + +def _tuple_assign_value( + pos: Sequence[common.NamedIndex], + target: common.MutableField | tuple[common.MutableField | tuple, ...], + source: core_defs.Scalar | tuple[core_defs.Scalar | tuple, ...], +) -> None: + @utils.tree_map + def impl(target: common.MutableField, source: core_defs.Scalar): + target[pos] = source + + impl(target, source) + + +def _tuple_at( + pos: Sequence[common.NamedIndex], + field: common.Field | core_defs.Scalar | tuple[common.Field | core_defs.Scalar | tuple, ...], +) -> core_defs.Scalar | tuple[core_defs.ScalarT | tuple, ...]: + @utils.tree_map + def impl(field: common.Field | core_defs.Scalar) -> core_defs.Scalar: + res = field[pos] if common.is_field(field) else field + assert core_defs.is_scalar_type(res) + return res + + return impl(field) # type: ignore[return-value] diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index e06c651b13..8202cda6f5 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -32,8 +32,9 @@ from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, common, embedded as next_embedded +from gt4py.next import allocators as next_allocators, embedded as next_embedded from gt4py.next.common import Dimension, DimensionKind, GridType +from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( dialect_ast_enums, field_operator_ast as foast, @@ -550,6 +551,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): definition: Optional[types.FunctionType] = None backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND grid_type: Optional[GridType] = None + operator_attributes: Optional[dict[str, Any]] = None _program_cache: dict = dataclasses.field(default_factory=dict) @classmethod @@ -586,6 +588,7 @@ def from_function( definition=definition, backend=backend, grid_type=grid_type, + operator_attributes=operator_attributes, ) def __gt_type__(self) -> ts.CallableType: @@ -692,68 +695,38 @@ def __call__( *args, **kwargs, ) -> None: - # TODO(havogt): Don't select mode based on existence of kwargs, - # because now we cannot provide nice error messages. E.g. set context var - # if we are reaching this from a program call. - if "out" in kwargs: - out = kwargs.pop("out") + if not next_embedded.context.within_context() and self.backend is not None: + # non embedded execution offset_provider = kwargs.pop("offset_provider", None) - if self.backend is not None: - # "out" and "offset_provider" -> field_operator as program - # When backend is None, we are in embedded execution and for now - # we disable the program generation since it would involve generating - # Python source code from a PAST node. 
- args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) - # TODO(tehrengruber): check all offset providers are given - # deduce argument types - arg_types = [] - for arg in args: - arg_types.append(type_translation.from_value(arg)) - kwarg_types = {} - for name, arg in kwargs.items(): - kwarg_types[name] = type_translation.from_value(arg) - - return self.as_program(arg_types, kwarg_types)( - *args, out, offset_provider=offset_provider, **kwargs - ) - else: - # "out" -> field_operator called from program in embedded execution or - # field_operator called directly from Python in embedded execution - domain = kwargs.pop("domain", None) - if not next_embedded.context.within_context(): - # field_operator from Python in embedded execution - with next_embedded.context.new_context(offset_provider=offset_provider) as ctx: - res = ctx.run(self.definition, *args, **kwargs) - else: - # field_operator from program in embedded execution (offset_provicer is already set) - assert ( - offset_provider is None - or next_embedded.context.offset_provider.get() is offset_provider - ) - res = self.definition(*args, **kwargs) - _tuple_assign_field( - out, res, domain=None if domain is None else common.domain(domain) - ) - return + out = kwargs.pop("out") + args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) + # TODO(tehrengruber): check all offset providers are given + # deduce argument types + arg_types = [] + for arg in args: + arg_types.append(type_translation.from_value(arg)) + kwarg_types = {} + for name, arg in kwargs.items(): + kwarg_types[name] = type_translation.from_value(arg) + + return self.as_program(arg_types, kwarg_types)( + *args, out, offset_provider=offset_provider, **kwargs + ) else: - # field_operator called from other field_operator in embedded execution - assert self.backend is None - return self.definition(*args, **kwargs) - - -def _tuple_assign_field( - target: tuple[common.Field | tuple, ...] | common.Field, - source: tuple[common.Field | tuple, ...] | common.Field, - domain: Optional[common.Domain], -): - if isinstance(target, tuple): - if not isinstance(source, tuple): - raise RuntimeError(f"Cannot assign {source} to {target}.") - for t, s in zip(target, source): - _tuple_assign_field(t, s, domain) - else: - domain = domain or target.domain - target[domain] = source[domain] + if self.operator_attributes is not None and any( + has_scan_op_attribute := [ + attribute in self.operator_attributes + for attribute in ["init", "axis", "forward"] + ] + ): + assert all(has_scan_op_attribute) + forward = self.operator_attributes["forward"] + init = self.operator_attributes["init"] + axis = self.operator_attributes["axis"] + op = embedded_operators.ScanOperator(self.definition, forward, init, axis) + else: + op = embedded_operators.EmbeddedOperator(self.definition) + return embedded_operators.field_operator_call(op, args, kwargs) @typing.overload diff --git a/src/gt4py/next/field_utils.py b/src/gt4py/next/field_utils.py new file mode 100644 index 0000000000..14b7c3838c --- /dev/null +++ b/src/gt4py/next/field_utils.py @@ -0,0 +1,22 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. 
See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np + +from gt4py.next import common, utils + + +@utils.tree_map +def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: + return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index b02d6c8d72..b00e53bfd9 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -196,7 +196,7 @@ def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: #: Column range used in column mode (`column_axis != None`) in the current closure execution context. -column_range_cvar: cvars.ContextVar[range] = next_embedded.context.closure_column_range +column_range_cvar: cvars.ContextVar[common.NamedRange] = next_embedded.context.closure_column_range #: Offset provider dict in the current closure execution context. offset_provider_cvar: cvars.ContextVar[OffsetProvider] = next_embedded.context.offset_provider @@ -211,8 +211,8 @@ class Column(np.lib.mixins.NDArrayOperatorsMixin): def __init__(self, kstart: int, data: np.ndarray | Scalar) -> None: self.kstart = kstart assert isinstance(data, (np.ndarray, Scalar)) # type: ignore # mypy bug #11673 - column_range = column_range_cvar.get() - self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range), data) + column_range: common.NamedRange = column_range_cvar.get() + self.data = data if isinstance(data, np.ndarray) else np.full(len(column_range[1]), data) def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] @@ -746,7 +746,7 @@ def _make_tuple( except embedded_exceptions.IndexOutOfBounds: return _UNDEFINED else: - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] assert column_range is not None col: list[ @@ -823,7 +823,7 @@ def deref(self) -> Any: assert isinstance(k_pos, int) # the following range describes a range in the field # (negative values are relative to the origin, not relative to the size) - slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range)) + slice_column[self.column_axis] = range(k_pos, k_pos + len(column_range[1])) assert _is_concrete_position(shifted_pos) position = {**shifted_pos, **slice_column} @@ -864,7 +864,7 @@ def make_in_iterator( init = [None] * sparse_dimensions.count(sparse_dim) new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused if column_axis is not None: - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted assert column_range is not None new_pos[column_axis] = column_range.start @@ -1479,7 +1479,7 @@ def _column_dtype(elem: Any) -> np.dtype: @builtins.scan.register(EMBEDDED) def scan(scan_pass, is_forward: bool, init): def impl(*iters: ItIterator): - column_range = column_range_cvar.get() + column_range = column_range_cvar.get()[1] if column_range is None: raise RuntimeError("Column range is not defined, cannot scan.") @@ -1532,7 +1532,10 @@ def closure( column = ColumnDescriptor(column_axis.value, domain[column_axis.value]) del domain[column_axis.value] - column_range = column.col_range + column_range = ( + column_axis, + common.UnitRange(column.col_range.start, 
column.col_range.stop), + ) out = as_tuple_field(out) if is_tuple_of_field(out) else _wrap_field(out) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index baae8361c5..ec459906e0 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -15,10 +15,6 @@ import functools from typing import Any, Callable, ClassVar, ParamSpec, TypeGuard, TypeVar, cast -import numpy as np - -from gt4py.next import common - class RecursionGuard: """ @@ -57,7 +53,6 @@ def __exit__(self, *exc): _T = TypeVar("_T") - _P = ParamSpec("_P") _R = TypeVar("_R") @@ -66,8 +61,17 @@ def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: return isinstance(v, tuple) and all(isinstance(e, t) for e in v) +# TODO(havogt): remove flatten duplications in the whole codebase +def flatten_nested_tuple(value: tuple[_T | tuple, ...]) -> tuple[_T, ...]: + if isinstance(value, tuple): + return sum((flatten_nested_tuple(v) for v in value), start=()) # type: ignore[arg-type] # cannot properly express nesting + else: + return (value,) + + def tree_map(fun: Callable[_P, _R]) -> Callable[..., _R | tuple[_R | tuple, ...]]: - """Apply `fun` to each entry of (possibly nested) tuples. + """ + Apply `fun` to each entry of (possibly nested) tuples. Examples: >>> tree_map(lambda x: x + 1)(((1, 2), 3)) @@ -88,9 +92,3 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: ) # mypy doesn't understand that `args` at this point is of type `_P.args` return impl - - -# TODO(havogt): consider moving to module like `field_utils` -@tree_map -def asnumpy(field: common.Field | np.ndarray) -> np.ndarray: - return field.asnumpy() if common.is_field(field) else field # type: ignore[return-value] # mypy doesn't understand the condition diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 84287e209f..3c42a180dd 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -130,7 +130,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), ] EMBEDDED_SKIP_LIST = [ - (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), ] diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 81f216397b..b1e26b40cb 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -28,7 +28,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping from gt4py.eve.extended_typing import Self -from gt4py.next import common, constructors, utils +from gt4py.next import common, constructors, field_utils from gt4py.next.ffront import decorator from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_specifications as ts, type_translation @@ -436,8 +436,8 @@ def verify( out_comp = out or inout assert out_comp is not None - out_comp_ndarray = utils.asnumpy(out_comp) - ref_ndarray = utils.asnumpy(ref) + out_comp_ndarray = field_utils.asnumpy(out_comp) + ref_ndarray = field_utils.asnumpy(ref) assert comparison(ref_ndarray, out_comp_ndarray), ( f"Verification failed:\n" f"\tcomparison={comparison.__name__}(ref, out)\n" diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 
7f37b41383..51f853d41d 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -292,6 +292,7 @@ def testee_op( cases.verify(cartesian_case, testee_op, qc, tuple_scalar, out=qc, ref=expected) +@pytest.mark.uses_cartesian_shift @pytest.mark.uses_scan @pytest.mark.uses_index_fields def test_scalar_scan_vertical_offset(cartesian_case): # noqa: F811 # fixtures @@ -802,6 +803,85 @@ def simple_scan_operator(carry: float, a: tuple[float, float]) -> float: cases.verify(cartesian_case, simple_scan_operator, (inp1, inp2), out=out, ref=expected) +@pytest.mark.uses_scan +def test_scan_different_domain_in_tuple(cartesian_case): + init = 1.0 + i_size = cartesian_case.default_sizes[IDim] + k_size = cartesian_case.default_sizes[KDim] + + inp1_np = np.ones( + ( + i_size + 1, + k_size, + ) + ) # i_size bigger than in the other argument + inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float) + inp1 = cartesian_case.as_field([IDim, KDim], inp1_np) + inp2 = cartesian_case.as_field([IDim, KDim], inp2_np) + out = cartesian_case.as_field([IDim, KDim], np.zeros((i_size, k_size))) + + def prev_levels_iterator(i): + return range(i + 1) + + expected = np.asarray( + [ + reduce( + lambda prev, k: prev + inp1_np[:-1, k] + inp2_np[:, k], + prev_levels_iterator(k), + init, + ) + for k in range(k_size) + ] + ).transpose() + + @gtx.scan_operator(axis=KDim, forward=True, init=init) + def scan_op(carry: float, a: tuple[float, float]) -> float: + return carry + a[0] + a[1] + + @gtx.field_operator + def foo( + inp1: gtx.Field[[IDim, KDim], float], inp2: gtx.Field[[IDim, KDim], float] + ) -> gtx.Field[[IDim, KDim], float]: + return scan_op((inp1, inp2)) + + cases.verify(cartesian_case, foo, inp1, inp2, out=out, ref=expected) + + +@pytest.mark.uses_scan +def test_scan_tuple_field_scalar_mixed(cartesian_case): + init = 1.0 + i_size = cartesian_case.default_sizes[IDim] + k_size = cartesian_case.default_sizes[KDim] + + inp2_np = np.fromfunction(lambda i, k: k, shape=(i_size, k_size), dtype=float) + inp2 = cartesian_case.as_field([IDim, KDim], inp2_np) + out = cartesian_case.as_field([IDim, KDim], np.zeros((i_size, k_size))) + + def prev_levels_iterator(i): + return range(i + 1) + + expected = np.asarray( + [ + reduce( + lambda prev, k: prev + 1.0 + inp2_np[:, k], + prev_levels_iterator(k), + init, + ) + for k in range(k_size) + ] + ).transpose() + + @gtx.scan_operator(axis=KDim, forward=True, init=init) + def scan_op(carry: float, a: tuple[float, float]) -> float: + return carry + a[0] + a[1] + + @gtx.field_operator + def foo(inp1: float, inp2: gtx.Field[[IDim, KDim], float]) -> gtx.Field[[IDim, KDim], float]: + return scan_op((inp1, inp2)) + + cases.verify(cartesian_case, foo, 1.0, inp2, out=out, ref=expected) + + def test_docstring(cartesian_case): @gtx.field_operator def fieldop_with_docstring(a: cases.IField) -> cases.IField: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index fd571514ac..9ba8eef3a3 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -16,7 +16,7 @@ import pytest import gt4py.next as gtx -from gt4py.next import utils +from gt4py.next import field_utils from 
gt4py.next.iterator.builtins import * from gt4py.next.iterator.runtime import closure, fendef, fundef, offset @@ -158,7 +158,7 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct k_size = 5 inp = inp_function(k_size) - ref = ref_function(utils.asnumpy(inp)) + ref = ref_function(field_utils.asnumpy(inp)) out = gtx.as_field([KDim], np.zeros((5,), dtype=np.int32)) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 640ed326bb..de511fdabb 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -19,7 +19,7 @@ from gt4py.next import common from gt4py.next.common import UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions -from gt4py.next.embedded.common import _slice_range, sub_domain +from gt4py.next.embedded.common import _slice_range, iterate_domain, sub_domain @pytest.mark.parametrize( @@ -135,3 +135,15 @@ def test_sub_domain(domain, index, expected): expected = common.domain(expected) result = sub_domain(domain, index) assert result == expected + + +def test_iterate_domain(): + domain = common.domain({I: 2, J: 3}) + ref = [] + for i in domain[I][1]: + for j in domain[J][1]: + ref.append(((I, i), (J, j))) + + testee = list(iterate_domain(domain)) + + assert testee == ref diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py index 3a35570ca2..9238cd4f7a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_internals.py @@ -19,13 +19,14 @@ import numpy as np import pytest +from gt4py.next import common from gt4py.next.iterator import embedded def _run_within_context( func: Callable[[], Any], *, - column_range: Optional[range] = None, + column_range: Optional[common.NamedRange] = None, offset_provider: Optional[embedded.OffsetProvider] = None, ) -> Any: def wrapped_func(): @@ -59,7 +60,10 @@ def test_func(data_a: int, data_b: int): # Setting an invalid column_range here shouldn't affect other contexts embedded.column_range_cvar.set(range(2, 999)) - _run_within_context(lambda: test_func(2, 3), column_range=range(0, 3)) + _run_within_context( + lambda: test_func(2, 3), + column_range=(common.Dimension("K", kind=common.DimensionKind.VERTICAL), range(0, 3)), + ) def test_column_ufunc_with_scalar(): From 3f595ffd6206b5bf3344b7288f98ac8e82adba52 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 12 Dec 2023 17:29:32 +0100 Subject: [PATCH 03/13] feat[next][dace]: Fix for broken DaCe test (#1396) Fix for broken DaCe test in baseline: - use `flatten_list` to get `ValueExpr` arguments to numeric builtin function Additionally, enable test for DaCe backend (left-over from PR #1393). 
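
For context, visiting the argument nodes can yield nested lists of value
expressions (for example when an argument expands to a tuple), while the
tasklet code generation expects one flat list. Below is a minimal sketch of
the flattening behavior relied on here; it is a simplified stand-in for the
helper in dace_iterator's utility module, not its verbatim code:

```python
from typing import Any


def flatten_list(node_list: list[Any]) -> list[Any]:
    # Recursively flatten nested lists, e.g. [[a], [b, c]] -> [a, b, c],
    # so that every element is a single value expression.
    flat: list[Any] = []
    for item in node_list:
        if isinstance(item, list):
            flat.extend(flatten_list(item))
        else:
            flat.append(item)
    return flat


assert flatten_list([[1], [2, 3]]) == [1, 2, 3]
```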
--- .../runners/dace_iterator/itir_to_tasklet.py | 4 +--- .../feature_tests/iterator_tests/test_conditional.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 32b8cbf2b1..d10a14a1ee 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -1010,9 +1010,7 @@ def _visit_reduce(self, node: itir.FunCall): def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args: list[SymbolExpr | ValueExpr] = list( - itertools.chain(*[self.visit(arg) for arg in node.args]) - ) + args = flatten_list(self.visit(node.args)) expr_args = [ (arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr) ] diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index 8536dbea90..db7776b2f4 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -31,7 +31,6 @@ def stencil_conditional(inp): return tuple_get(0, tmp) + tuple_get(1, tmp) -@pytest.mark.uses_tuple_returns def test_conditional_w_tuple(program_processor): program_processor, validate = program_processor From a5b2450e282add00fe90b8cf98cd68d96d42b1ea Mon Sep 17 00:00:00 2001 From: Rico Haeuselmann Date: Wed, 13 Dec 2023 11:41:16 +0100 Subject: [PATCH 04/13] style[next]: standardize error messages. 
(#1386)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- add style guide to the coding guidelines
- fix existing error messages in next
- deal with ensuing QA errors / test updates
- unshadow one test and fix the code it wasn't testing

Co-authored-by: Rico Häuselmann
Co-authored-by: Enrique González Paredes
---
 CODING_GUIDELINES.md                          |  38 +++++
 src/gt4py/_core/definitions.py                |  22 ++-
 src/gt4py/next/allocators.py                  |   8 +-
 src/gt4py/next/common.py                      |  91 ++++++------
 src/gt4py/next/constructors.py                |  15 +-
 src/gt4py/next/embedded/common.py             |   6 +-
 src/gt4py/next/embedded/nd_array_field.py     |  37 ++---
 src/gt4py/next/errors/exceptions.py           |  10 +-
 .../next/ffront/ast_passes/simple_assign.py   |   2 +-
 .../ffront/ast_passes/single_static_assign.py |   2 +-
 src/gt4py/next/ffront/decorator.py            |  29 ++--
 src/gt4py/next/ffront/fbuiltins.py            |   7 +-
 src/gt4py/next/ffront/foast_introspection.py  |   2 +-
 .../foast_passes/closure_var_folding.py       |   2 +-
 .../ffront/foast_passes/type_deduction.py     | 133 +++++++++---------
 src/gt4py/next/ffront/foast_pretty_printer.py |   2 +-
 src/gt4py/next/ffront/foast_to_itir.py        |  10 +-
 src/gt4py/next/ffront/func_to_foast.py        |  23 +--
 src/gt4py/next/ffront/func_to_past.py         |   4 +-
 .../next/ffront/past_passes/type_deduction.py |  42 +++---
 src/gt4py/next/ffront/past_to_itir.py         |  36 ++---
 src/gt4py/next/ffront/source_utils.py         |   8 +-
 src/gt4py/next/ffront/type_info.py            |   2 +-
 src/gt4py/next/iterator/dispatcher.py         |   2 +-
 src/gt4py/next/iterator/embedded.py           |  26 ++--
 src/gt4py/next/iterator/ir.py                 |   8 +-
 src/gt4py/next/iterator/ir_utils/ir_makers.py |   2 +-
 src/gt4py/next/iterator/runtime.py            |   2 +-
 src/gt4py/next/iterator/tracing.py            |   6 +-
 src/gt4py/next/iterator/transforms/cse.py     |   4 +-
 .../next/iterator/transforms/pass_manager.py  |   4 +-
 .../next/iterator/transforms/unroll_reduce.py |   6 +-
 src/gt4py/next/iterator/type_inference.py     |  29 ++--
 src/gt4py/next/otf/binding/nanobind.py        |   2 +-
 .../compilation/build_systems/cmake_lists.py  |   4 +-
 src/gt4py/next/otf/compilation/compiler.py    |   2 +-
 src/gt4py/next/otf/stages.py                  |   2 +-
 src/gt4py/next/otf/workflow.py                |   6 +-
 .../codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py    |   6 +-
 .../codegens/gtfn/gtfn_module.py              |   6 +-
 .../codegens/gtfn/itir_to_gtfn_ir.py          |  20 +--
 .../program_processors/processor_interface.py |  88 ++++++++----
 .../runners/dace_iterator/__init__.py         |   4 +-
 .../runners/dace_iterator/itir_to_tasklet.py  |   4 +-
 .../runners/dace_iterator/utility.py          |   2 +-
 .../next/program_processors/runners/gtfn.py   |   6 +-
 src/gt4py/next/type_system/type_info.py       |  48 ++++---
 .../next/type_system/type_translation.py      |  34 ++---
 tests/next_tests/integration_tests/cases.py   |  16 +--
 .../ffront_tests/ffront_test_utils.py         |   3 +-
 .../ffront_tests/test_arg_call_interface.py   |   8 +-
 .../ffront_tests/test_execution.py            |   8 +-
 .../test_math_builtin_execution.py            |   2 +-
 .../ffront_tests/test_math_unary_builtins.py  |   4 +-
 .../ffront_tests/test_program.py              |   4 +-
 .../ffront_tests/test_scalar_if.py            |   4 +-
 .../ffront_tests/test_type_deduction.py       |  68 ++++-----
 .../iterator_tests/test_builtins.py           |   2 +-
 tests/next_tests/unit_tests/conftest.py       |   2 +-
 .../embedded_tests/test_nd_array_field.py     |   2 +-
 .../ffront_tests/test_func_to_foast.py        |  14 +-
 .../ffront_tests/test_func_to_past.py         |  18 +--
 .../ffront_tests/test_past_to_itir.py         |   4 +-
 .../iterator_tests/test_runtime_domain.py     |   2 +-
 .../test_processor_interface.py               |   4 +-
 .../next_tests/unit_tests/test_allocators.py  |   2 +-
 tests/next_tests/unit_tests/test_common.py    |   2 +-
 .../unit_tests/test_constructors.py           |   4 +- 
.../test_type_translation.py                 |   2 +-
 69 files changed, 571 insertions(+), 458 deletions(-)

diff --git a/CODING_GUIDELINES.md b/CODING_GUIDELINES.md
index 957df0fb04..9376644064 100644
--- a/CODING_GUIDELINES.md
+++ b/CODING_GUIDELINES.md
@@ -51,6 +51,44 @@ We deviate from the [Google Python Style Guide][google-style-guide] only in the
 - Client code (like tests, doctests and examples) should use the above style for public FieldView API
 - Library code should always import the defining module and use qualified names.

### Error messages

Error messages should be written as sentences, starting with a capital letter and ending with a period (avoid exclamation marks). Try to be informative without being verbose. Code objects such as 'ClassNames' and 'function_names' should be enclosed in single quotes, and so should string values used for message interpolation.

Examples:

```python
raise ValueError(f"Invalid argument 'dimension': should be of type 'Dimension', got '{dimension.type}'.")
```

Interpolated integer values do not need quotes if they indicate an amount. Example:

```python
raise ValueError(f"Invalid number of arguments: expected 3 arguments, got {len(args)}.")
```

The quotes can also be dropped when presenting a sequence of values. In this case the message should be rephrased so the sequence is separated from the text by a colon ':'.

```python
raise ValueError(f"Unexpected keyword arguments: {', '.join(set(kwarg_names) - set(expected_kwarg_names))}.")
```

The message should be kept to one sentence if reasonably possible. Ideally the sentence should be kept short and avoid unnecessary words. Examples:

```python
# too many sentences
raise ValueError(f"Received an unexpected number of arguments. Should receive 5 arguments, but got {len(args)}. Please provide the correct number of arguments.")
# better
raise ValueError(f"Wrong number of arguments: expected 5, got {len(args)}.")

# less extreme
raise TypeError(f"Wrong argument type. Can only accept 'int's, got '{type(arg)}' instead.")
# but can still be improved
raise TypeError(f"Wrong argument type: 'int' expected, got '{type(arg)}'.")
```

The terseness vs. helpfulness tradeoff should be more in favor of terseness for internal error messages and more in favor of helpfulness for `DSLError` and its subclasses, where additional sentences are encouraged if they point out likely hidden sources of the problem or common fixes.

 ### Docstrings
 
 We generate the API documentation automatically from the docstrings using [Sphinx][sphinx] and some extensions such as [Sphinx-autodoc][sphinx-autodoc] and [Sphinx-napoleon][sphinx-napoleon]. These follow the Google Python Style Guide docstring conventions to automatically format the generated documentation. A complete overview can be found here: [Example Google Style Python Docstrings](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html#example-google). 
diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 0e6301ae0f..091fa77e3f 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -73,17 +73,23 @@ BoolScalar: TypeAlias = Union[bool_, bool] BoolT = TypeVar("BoolT", bound=BoolScalar) -BOOL_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], BoolScalar.__args__) # type: ignore[attr-defined] +BOOL_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], BoolScalar.__args__ # type: ignore[attr-defined] +) IntScalar: TypeAlias = Union[int8, int16, int32, int64, int] IntT = TypeVar("IntT", bound=IntScalar) -INT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], IntScalar.__args__) # type: ignore[attr-defined] +INT_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], IntScalar.__args__ # type: ignore[attr-defined] +) UnsignedIntScalar: TypeAlias = Union[uint8, uint16, uint32, uint64] UnsignedIntT = TypeVar("UnsignedIntT", bound=UnsignedIntScalar) -UINT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], UnsignedIntScalar.__args__) # type: ignore[attr-defined] +UINT_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], UnsignedIntScalar.__args__ # type: ignore[attr-defined] +) IntegralScalar: TypeAlias = Union[IntScalar, UnsignedIntScalar] @@ -93,7 +99,9 @@ FloatingScalar: TypeAlias = Union[float32, float64, float] FloatingT = TypeVar("FloatingT", bound=FloatingScalar) -FLOAT_TYPES: Final[Tuple[type, ...]] = cast(Tuple[type, ...], FloatingScalar.__args__) # type: ignore[attr-defined] +FLOAT_TYPES: Final[Tuple[type, ...]] = cast( + Tuple[type, ...], FloatingScalar.__args__ # type: ignore[attr-defined] +) #: Type alias for all scalar types supported by GT4Py @@ -195,7 +203,7 @@ def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: if issubclass(sc_type, numbers.Complex): return DTypeKind.COMPLEX - raise TypeError("Unknown scalar type kind") + raise TypeError("Unknown scalar type kind.") @dataclasses.dataclass(frozen=True) @@ -491,10 +499,10 @@ def __rtruediv__(self, other: Any) -> NDArrayObject: def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy want to return `bool` + def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy wants to return `bool` ... - def __ne__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy want to return `bool` + def __ne__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy wants to return `bool` ... def __gt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 58600d8cda..97e83276fe 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -142,7 +142,9 @@ def get_allocator( elif not strict or is_field_allocator(default): return default else: - raise TypeError(f"Object {obj} is neither a field allocator nor a field allocator factory") + raise TypeError( + f"Object '{obj}' is neither a field allocator nor a field allocator factory." 
+ ) @dataclasses.dataclass(frozen=True) @@ -331,7 +333,7 @@ def allocate( """ if device is None and allocator is None: - raise ValueError("No 'device' or 'allocator' specified") + raise ValueError("No 'device' or 'allocator' specified.") actual_allocator = get_allocator(allocator) if actual_allocator is None: assert device is not None # for mypy @@ -339,7 +341,7 @@ def allocate( elif device is None: device = core_defs.Device(actual_allocator.__gt_device_type__, 0) elif device.device_type != actual_allocator.__gt_device_type__: - raise ValueError(f"Device {device} and allocator {actual_allocator} are incompatible") + raise ValueError(f"Device '{device}' and allocator '{actual_allocator}' are incompatible.") return actual_allocator.__gt_allocate__( domain=common.domain(domain), diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 7f1ad8c0bb..3e1fe52f31 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -125,7 +125,7 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re if isinstance(index, slice): start, stop, step = index.indices(len(self)) if step != 1: - raise ValueError("UnitRange: step required to be `1`.") + raise ValueError("'UnitRange': step required to be '1'.") new_start = self.start + (start or 0) new_stop = (self.start if stop > 0 else self.stop) + stop return UnitRange(new_start, new_stop) @@ -136,7 +136,7 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re if 0 <= index < len(self): return self.start + index else: - raise IndexError("UnitRange index out of range") + raise IndexError("'UnitRange' index out of range") def __and__(self, other: Set[int]) -> UnitRange: if isinstance(other, UnitRange): @@ -144,7 +144,9 @@ def __and__(self, other: Set[int]) -> UnitRange: stop = min(self.stop, other.stop) return UnitRange(start, stop) else: - raise NotImplementedError("Can only find the intersection between UnitRange instances.") + raise NotImplementedError( + "Can only find the intersection between 'UnitRange' instances." 
+        )
 
     def __le__(self, other: Set[int]):
         if isinstance(other, UnitRange):
@@ -167,7 +169,7 @@
             )
         )
         else:
-            raise NotImplementedError("Can only compute union with int instances.")
+            raise NotImplementedError("Can only compute union with 'int' instances.")
 
     def __sub__(self, other: int | Set[int]) -> UnitRange:
         if isinstance(other, int):
@@ -178,7 +180,7 @@
         else:
             return self + (-other)
         else:
-            raise NotImplementedError("Can only compute substraction with int instances.")
+            raise NotImplementedError("Can only compute subtraction with 'int' instances.")
 
     __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented
@@ -199,7 +201,7 @@ def unit_range(r: RangeLike) -> UnitRange:
         return r
     if isinstance(r, range):
         if r.step != 1:
-            raise ValueError(f"`UnitRange` requires step size 1, got `{r.step}`.")
+            raise ValueError(f"'UnitRange' requires step size 1, got '{r.step}'.")
         return UnitRange(r.start, r.stop)
     # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604)
     # once the related mypy bug (#16358) gets fixed
@@ -211,7 +213,7 @@
         return UnitRange(r[0], r[1])
     if isinstance(r, core_defs.INTEGRAL_TYPES):
         return UnitRange(0, cast(core_defs.IntegralScalar, r))
-    raise ValueError(f"`{r!r}` cannot be interpreted as `UnitRange`.")
+    raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.")
 
 
 IntIndex: TypeAlias = int | core_defs.IntegralScalar
@@ -296,20 +298,20 @@ def __init__(
     ) -> None:
         if dims is not None or ranges is not None:
             if dims is None and ranges is None:
-                raise ValueError("Either both none of `dims` and `ranges` must be specified.")
+                raise ValueError("Either both or none of 'dims' and 'ranges' must be specified.")
             if len(args) > 0:
                 raise ValueError(
-                    "No extra `args` allowed when constructing fomr `dims` and `ranges`."
+                    "No extra 'args' allowed when constructing from 'dims' and 'ranges'."
                 )
 
             assert dims is not None and ranges is not None  # for mypy
             if not all(isinstance(dim, Dimension) for dim in dims):
                 raise ValueError(
-                    f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`."
+                    f"'dims' argument needs to be a 'tuple[Dimension, ...]', got '{dims}'."
                 )
             if not all(isinstance(rng, UnitRange) for rng in ranges):
                 raise ValueError(
-                    f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`."
+                    f"'ranges' argument needs to be a 'tuple[UnitRange, ...]', got '{ranges}'."
                 )
             if len(dims) != len(ranges):
                 raise ValueError(
 
             object.__setattr__(self, "dims", tuple(dims))
             object.__setattr__(self, "ranges", tuple(ranges))
         else:
             if not all(is_named_range(arg) for arg in args):
-                raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.")
+                raise ValueError(
+                    f"Elements of 'Domain' need to be instances of 'NamedRange', got '{args}'." 
+                )
             dims, ranges = zip(*args) if args else ((), ())
             object.__setattr__(self, "dims", tuple(dims))
             object.__setattr__(self, "ranges", tuple(ranges))
 
         if len(set(self.dims)) != len(self.dims):
-            raise NotImplementedError(f"Domain dimensions must be unique, not {self.dims}.")
+            raise NotImplementedError(f"Domain dimensions must be unique, not '{self.dims}'.")
 
     def __len__(self) -> int:
         return len(self.ranges)
@@ -365,7 +369,7 @@ def __getitem__(  # noqa: F811 # redefine unused
             index_pos = self.dims.index(index)
             return self.dims[index_pos], self.ranges[index_pos]
         except ValueError:
-            raise KeyError(f"No Dimension of type {index} is present in the Domain.")
+            raise KeyError(f"No Dimension of type '{index}' is present in the Domain.")
         else:
             raise KeyError("Invalid index type, must be either int, slice, or Dimension.")
@@ -415,10 +419,12 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:
         if isinstance(index, Dimension):
             dim_index = self.dim_index(index)
             if dim_index is None:
-                raise ValueError(f"Dimension {index} not found in Domain.")
+                raise ValueError(f"Dimension '{index}' not found in Domain.")
             index = dim_index
         if not (-len(self.dims) <= index < len(self.dims)):
-            raise IndexError(f"Index {index} out of bounds for Domain of length {len(self.dims)}.")
+            raise IndexError(
+                f"Index '{index}' out of bounds for Domain of length {len(self.dims)}."
+            )
         if index < 0:
             index += len(self.dims)
         new_dims, new_ranges = zip(*named_ranges) if len(named_ranges) > 0 else ((), ())
@@ -462,13 +468,16 @@ def domain(domain_like: DomainLike) -> Domain:
     if all(isinstance(elem, core_defs.INTEGRAL_TYPES) for elem in domain_like.values()):
         return Domain(
             dims=tuple(domain_like.keys()),
-            ranges=tuple(UnitRange(0, s) for s in domain_like.values()),  # type: ignore[arg-type] # type of `s` is checked in condition
+            ranges=tuple(
+                UnitRange(0, s)  # type: ignore[arg-type] # type of `s` is checked in condition
+                for s in domain_like.values()
+            ),
         )
     return Domain(
         dims=tuple(domain_like.keys()),
         ranges=tuple(unit_range(r) for r in domain_like.values()),
     )
-    raise ValueError(f"`{domain_like}` is not `DomainLike`.")
+    raise ValueError(f"'{domain_like}' is not 'DomainLike'.")
 
 
 def _broadcast_ranges(
@@ -670,7 +679,8 @@ class ConnectivityKind(enum.Flag):
 
 @extended_runtime_checkable
-class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]):  # type: ignore[misc] # DimT should be covariant, but break in another place
+# DimT should be covariant, but breaks in another place
+class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]):  # type: ignore[misc]
     @property
     @abc.abstractmethod
     def codomain(self) -> DimT:
@@ -690,61 +700,61 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa
 
     # Operators
     def __abs__(self) -> Never:
-        raise TypeError("ConnectivityField does not support this operation")
+        raise TypeError("'ConnectivityField' does not support this operation.")
 
     def __neg__(self) -> Never:
-        raise TypeError("ConnectivityField does not support this operation")
+        raise TypeError("'ConnectivityField' does not support this operation.")
 
     def __invert__(self) -> Never:
-        raise TypeError("ConnectivityField does not support this operation")
+        raise TypeError("'ConnectivityField' does not support this operation.")
 
     def __eq__(self, other: Any) -> Never:
-        raise TypeError("ConnectivityField does not support this operation")
+        raise TypeError("'ConnectivityField' does not support this operation.")
 
     def __ne__(self, other: Any) -> 
Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __add__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __mul__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rmul__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __truediv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rtruediv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __floordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __rfloordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __pow__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __and__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __or__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("ConnectivityField does not support this operation") + raise TypeError("'ConnectivityField' does not support this operation.") def is_connectivity_field( @@ -845,7 +855,7 @@ def __gt_dims__(self) -> tuple[Dimension, ...]: @property def 
__gt_origin__(self) -> Never: - raise TypeError("CartesianConnectivity does not support this operation") + raise TypeError("'CartesianConnectivity' does not support this operation.") @property def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: @@ -877,7 +887,7 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa if not isinstance(image_range, UnitRange): if image_range[0] != self.codomain: raise ValueError( - f"Dimension {image_range[0]} does not match the codomain dimension {self.codomain}" + f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." ) image_range = image_range[1] diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 63fde1cfde..9bb4cf17e5 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -254,12 +254,12 @@ def as_field( domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: raise ValueError( - f"Cannot construct `Field` from array of shape `{data.shape}` and domain `{domain}` " + f"Cannot construct 'Field' from array of shape '{data.shape}' and domain '{domain}'." ) if origin: domain_dims = set(domain) if unknown_dims := set(origin.keys()) - domain_dims: - raise ValueError(f"Origin keys {unknown_dims} not in domain {domain}") + raise ValueError(f"Origin keys {unknown_dims} not in domain {domain}.") else: origin = {} actual_domain = common.domain( @@ -277,7 +277,7 @@ def as_field( # already the correct layout and device. shape = storage_utils.asarray(data).shape if shape != actual_domain.shape: - raise ValueError(f"Cannot construct `Field` from array of shape `{shape}` ") + raise ValueError(f"Cannot construct 'Field' from array of shape '{shape}'.") if dtype is None: dtype = storage_utils.asarray(data).dtype dtype = core_defs.dtype(dtype) @@ -334,20 +334,20 @@ def as_connectivity( domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: raise ValueError( - f"Cannot construct `Field` from array of shape `{data.shape}` and domain `{domain}` " + f"Cannot construct 'Field' from array of shape '{data.shape}' and domain '{domain}'." ) actual_domain = common.domain([(d, (0, s)) for d, s in zip(domain, data.shape)]) else: actual_domain = common.domain(cast(common.DomainLike, domain)) if not isinstance(codomain, common.Dimension): - raise ValueError(f"Invalid codomain dimension `{codomain}`") + raise ValueError(f"Invalid codomain dimension '{codomain}'.") # TODO(egparedes): allow zero-copy construction (no reallocation) if buffer has # already the correct layout and device. shape = storage_utils.asarray(data).shape if shape != actual_domain.shape: - raise ValueError(f"Cannot construct `Field` from array of shape `{shape}` ") + raise ValueError(f"Cannot construct 'Field' from array of shape '{shape}'.") if dtype is None: dtype = storage_utils.asarray(data).dtype dtype = core_defs.dtype(dtype) @@ -356,7 +356,8 @@ def as_connectivity( if (allocator is None) and (device is None) and xtyping.supports_dlpack(data): device = core_defs.Device(*data.__dlpack_device__()) buffer = next_allocators.allocate(actual_domain, dtype, allocator=allocator, device=device) - buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] # TODO(havogt): consider addin MutableNDArrayObject + # TODO(havogt): consider adding MutableNDArrayObject + buffer.ndarray[...] = storage_utils.asarray(data) # type: ignore[index] connectivity_field = common.connectivity( buffer.ndarray, codomain=codomain, domain=actual_domain ) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 558730cb82..87e0800a10 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -32,7 +32,7 @@ def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Doma if common.is_relative_index_sequence(index_sequence): return _relative_sub_domain(domain, index_sequence) - raise IndexError(f"Unsupported index type: {index}") + raise IndexError(f"Unsupported index type: '{index}'.") def _relative_sub_domain( @@ -42,7 +42,9 @@ def _relative_sub_domain( expanded = _expand_ellipsis(index, len(domain)) if len(domain) < len(expanded): - raise IndexError(f"Trying to index a `Field` with {len(domain)} dimensions with {index}.") + raise IndexError( + f"Can not access dimension with index {index} of 'Field' with {len(domain)} dimensions." + ) expanded += (slice(None),) * (len(domain) - len(expanded)) for (dim, rng), idx in zip(domain, expanded, strict=True): if isinstance(idx, slice): diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 6b69e8f8cc..fbfe64ac42 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -174,7 +174,7 @@ def remap( dim = connectivity.codomain dim_idx = self.domain.dim_index(dim) if dim_idx is None: - raise ValueError(f"Incompatible index field, expected a field with dimension {dim}.") + raise ValueError(f"Incompatible index field, expected a field with dimension '{dim}'.") current_range: common.UnitRange = self.domain[dim_idx][1] new_ranges = connectivity.inverse_image(current_range) @@ -226,7 +226,7 @@ def __setitem__( if common.is_field(value): if not value.domain == target_domain: raise ValueError( - f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." + f"Incompatible 'Domain' in assignment. Source domain = '{value.domain}', target domain = '{target_domain}'."
) value = value.ndarray @@ -268,28 +268,28 @@ def __setitem__( def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_and", "logical_and")(self, other) - raise NotImplementedError("`__and__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__and__' not implemented for non-'bool' fields.") __rand__ = __and__ def __or__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_or", "logical_or")(self, other) - raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__or__' not implemented for non-'bool' fields.") __ror__ = __or__ def __xor__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("logical_xor", "logical_xor")(self, other) - raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__xor__' not implemented for non-'bool' fields.") __rxor__ = __xor__ def __invert__(self) -> NdArrayField: if self.dtype == core_defs.BoolDType(): return _make_builtin("invert", "invert")(self) - raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") + raise NotImplementedError("'__invert__' not implemented for non-'bool' fields.") def _slice( self, index: common.AnyIndexSpec @@ -322,7 +322,8 @@ def __gt_builtin_func__(cls, _: fbuiltins.BuiltInFunction) -> Never: # type: ig raise NotImplementedError() @property - def codomain(self) -> common.DimT: # type: ignore[override] # TODO(havogt): instead of inheriting from NdArrayField, steal implementation or common base + # type: ignore[override] # TODO(havogt): instead of inheriting from NdArrayField, steal implementation or common base + def codomain(self) -> common.DimT: return self._codomain @functools.cached_property @@ -378,7 +379,7 @@ def inverse_image( ): # TODO(havogt): cleanup duplication with CartesianConnectivity if image_range[0] != self.codomain: raise ValueError( - f"Dimension {image_range[0]} does not match the codomain dimension {self.codomain}" + f"Dimension '{image_range[0]}' does not match the codomain dimension '{self.codomain}'." ) image_range = image_range[1] @@ -423,7 +424,7 @@ def inverse_image( if non_contiguous_dims: raise ValueError( - f"Restriction generates non-contiguous dimensions {non_contiguous_dims}" + f"Restriction generates non-contiguous dimensions '{non_contiguous_dims}'." ) return new_dims @@ -446,8 +447,12 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ # -- Specialized implementations for builtin operations on array fields -- -NdArrayField.register_builtin_func(fbuiltins.abs, NdArrayField.__abs__) # type: ignore[attr-defined] -NdArrayField.register_builtin_func(fbuiltins.power, NdArrayField.__pow__) # type: ignore[attr-defined] +NdArrayField.register_builtin_func( + fbuiltins.abs, NdArrayField.__abs__ # type: ignore[attr-defined] +) +NdArrayField.register_builtin_func( + fbuiltins.power, NdArrayField.__pow__ # type: ignore[attr-defined] +) # TODO gamma for name in ( @@ -480,7 +485,7 @@ def _builtin_op( if not axis.kind == common.DimensionKind.LOCAL: raise ValueError("Can only reduce local dimensions.") if axis not in field.domain.dims: - raise ValueError(f"Field doesn't have dimension {axis}. 
Cannot reduce.") + raise ValueError(f"Field can not be reduced as it doesn't have dimension '{axis}'.") reduce_dim_index = field.domain.dims.index(axis) new_domain = common.Domain(*[nr for nr in field.domain if nr[0] != axis]) return field.__class__.from_array( @@ -547,7 +552,7 @@ def __setitem__( value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` - raise NotImplementedError("`__setitem__` for JaxArrayField not yet implemented.") + raise NotImplementedError("'__setitem__' for JaxArrayField not yet implemented.") common.field.register(jnp.ndarray, JaxArrayField.from_array) @@ -572,7 +577,7 @@ def _builtins_broadcast( ) -> common.Field: # separated for typing reasons if common.is_field(field): return _broadcast(field, new_dimensions) - raise AssertionError("Scalar case not reachable from `fbuiltins.broadcast`.") + raise AssertionError("Scalar case not reachable from 'fbuiltins.broadcast'.") NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast) @@ -581,7 +586,7 @@ def _builtins_broadcast( def _astype(field: common.Field | core_defs.ScalarT | tuple, type_: type) -> NdArrayField: if isinstance(field, NdArrayField): return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain) - raise AssertionError("This is the NdArrayField implementation of `fbuiltins.astype`.") + raise AssertionError("This is the NdArrayField implementation of 'fbuiltins.astype'.") NdArrayField.register_builtin_func(fbuiltins.astype, _astype) @@ -643,4 +648,4 @@ def _compute_slice( elif common.is_int_index(rng): return rng - domain.ranges[pos].start else: - raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") + raise ValueError(f"Can only use integer or UnitRange ranges, provided type: '{type(rng)}'.") diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index e956858549..081453c023 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -61,7 +61,7 @@ class UnsupportedPythonFeatureError(DSLError): feature: str def __init__(self, location: Optional[SourceLocation], feature: str) -> None: - super().__init__(location, f"unsupported Python syntax: '{feature}'") + super().__init__(location, f"Unsupported Python syntax: '{feature}'.") self.feature = feature @@ -69,7 +69,7 @@ class UndefinedSymbolError(DSLError): sym_name: str def __init__(self, location: Optional[SourceLocation], name: str) -> None: - super().__init__(location, f"name '{name}' is not defined") + super().__init__(location, f"Name '{name}' is not defined.") self.sym_name = name @@ -77,7 +77,7 @@ class MissingAttributeError(DSLError): attr_name: str def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None: - super().__init__(location, f"object does not have attribute '{attr_name}'") + super().__init__(location, f"Object does not have attribute '{attr_name}'.") self.attr_name = attr_name @@ -90,7 +90,7 @@ class MissingParameterAnnotationError(TypeError_): param_name: str def __init__(self, location: Optional[SourceLocation], param_name: str) -> None: - super().__init__(location, f"parameter '{param_name}' is missing type annotations") + super().__init__(location, f"Parameter '{param_name}' is missing type annotations.") self.param_name = param_name @@ -100,7 +100,7 @@ class InvalidParameterAnnotationError(TypeError_): def __init__(self, location: Optional[SourceLocation], param_name: 
str, type_: Any) -> None: super().__init__( - location, f"parameter '{param_name}' has invalid type annotation '{type_}'" + location, f"Parameter '{param_name}' has invalid type annotation '{type_}'." ) self.param_name = param_name self.annotated_type = type_ diff --git a/src/gt4py/next/ffront/ast_passes/simple_assign.py b/src/gt4py/next/ffront/ast_passes/simple_assign.py index e2e6439e37..8b079bb8c1 100644 --- a/src/gt4py/next/ffront/ast_passes/simple_assign.py +++ b/src/gt4py/next/ffront/ast_passes/simple_assign.py @@ -22,7 +22,7 @@ class NodeYielder(ast.NodeTransformer): def apply(cls, node: ast.AST) -> ast.AST: result = list(cls().visit(node)) if len(result) != 1: - raise ValueError("AST was split or lost during the pass. Use `.visit()` instead.") + raise ValueError("AST was split or lost during the pass, use '.visit()' instead.") return result[0] def visit(self, node: ast.AST) -> Iterator[ast.AST]: diff --git a/src/gt4py/next/ffront/ast_passes/single_static_assign.py b/src/gt4py/next/ffront/ast_passes/single_static_assign.py index 4181d7f449..ee1e29a8e8 100644 --- a/src/gt4py/next/ffront/ast_passes/single_static_assign.py +++ b/src/gt4py/next/ffront/ast_passes/single_static_assign.py @@ -65,7 +65,7 @@ class _AssignmentTracker: def define(self, name: str) -> None: if name in self.names(): - raise ValueError(f"Variable {name} is already defined.") + raise ValueError(f"Variable '{name}' is already defined.") # -1 signifies a self._counts[name] = -1 diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 8202cda6f5..4abd8f156a 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -88,7 +88,7 @@ def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any raise NotImplementedError( f"Using closure vars with same name but different value " f"across functions is not implemented yet. \n" - f"Collisions: {'`, `'.join(collisions)}" + f"Collisions: '{', '.join(collisions)}'." ) all_closure_vars = collections.ChainMap(all_closure_vars, all_child_closure_vars) @@ -125,7 +125,7 @@ def is_cartesian_offset(o: FieldOffset): if requested_grid_type == GridType.CARTESIAN and deduced_grid_type == GridType.UNSTRUCTURED: raise ValueError( - "grid_type == GridType.CARTESIAN was requested, but unstructured `FieldOffset` or local `Dimension` was found." + "'grid_type == GridType.CARTESIAN' was requested, but unstructured 'FieldOffset' or local 'Dimension' was found." ) return deduced_grid_type if requested_grid_type is None else requested_grid_type @@ -147,7 +147,7 @@ def _field_constituents_shape_and_dims( elif isinstance(arg_type, ts.ScalarType): yield (None, []) else: - raise ValueError("Expected `FieldType` or `TupleType` thereof.") + raise ValueError("Expected 'FieldType' or 'TupleType' thereof.") # TODO(tehrengruber): Decide if and how programs can call other programs. As a @@ -208,7 +208,7 @@ def __post_init__(self): ] if misnamed_functions: raise RuntimeError( - f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}" + f"The following symbols resolve to a function with a mismatching name: {','.join(misnamed_functions)}." ) undefined_symbols = [ @@ -218,7 +218,7 @@ def __post_init__(self): ] if undefined_symbols: raise RuntimeError( - f"The following closure variables are undefined: {', '.join(undefined_symbols)}" + f"The following closure variables are undefined: {', '.join(undefined_symbols)}." 
) @functools.cached_property @@ -228,7 +228,7 @@ def __gt_allocator__( if self.backend: return self.backend.__gt_allocator__ else: - raise RuntimeError(f"Program {self} does not have a backend set.") + raise RuntimeError(f"Program '{self}' does not have a backend set.") def with_backend(self, backend: ppi.ProgramExecutor) -> Program: return dataclasses.replace(self, backend=backend) @@ -263,7 +263,7 @@ def with_bound_args(self, **kwargs) -> ProgramWithBoundArgs: """ for key in kwargs.keys(): if all(key != param.id for param in self.past_node.params): - raise TypeError(f"Keyword argument `{key}` is not a valid program parameter.") + raise TypeError(f"Keyword argument '{key}' is not a valid program parameter.") return ProgramWithBoundArgs( bound_args=kwargs, @@ -344,7 +344,7 @@ def _validate_args(self, *args, **kwargs) -> None: raise_exception=True, ) except ValueError as err: - raise TypeError(f"Invalid argument types in call to `{self.past_node.id}`!") from err + raise TypeError(f"Invalid argument types in call to '{self.past_node.id}'.") from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: self._validate_args(*args, **kwargs) @@ -397,9 +397,10 @@ def _column_axis(self): ] raise TypeError( - "Only `ScanOperator`s defined on the same axis " - + "can be used in a `Program`, but found:\n" + "Only 'ScanOperator's defined on the same axis " + + "can be used in a 'Program', found:\n" + "\n".join(scanops_per_axis_strs) + + "." ) return iter(scanops_per_axis.keys()).__next__() @@ -436,7 +437,7 @@ def _process_args(self, args: tuple, kwargs: dict): # a better error message. for name in self.bound_args.keys(): if name in kwargs: - raise ValueError(f"Parameter `{name}` already set as a bound argument.") + raise ValueError(f"Parameter '{name}' already set as a bound argument.") type_info.accepts_args( new_type, @@ -445,10 +446,10 @@ def _process_args(self, args: tuple, kwargs: dict): raise_exception=True, ) except ValueError as err: - bound_arg_names = ", ".join([f"`{bound_arg}`" for bound_arg in self.bound_args.keys()]) + bound_arg_names = ", ".join([f"'{bound_arg}'" for bound_arg in self.bound_args.keys()]) raise TypeError( - f"Invalid argument types in call to program `{self.past_node.id}` with " - f"bound arguments {bound_arg_names}!" + f"Invalid argument types in call to program '{self.past_node.id}' with " + f"bound arguments '{bound_arg_names}'." ) from err full_args = [*args] diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 8230e35a35..93f17b1eb8 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -139,13 +139,16 @@ def __call__(self, mask: MaskT, true_field: FieldT, false_field: FieldT) -> _R: if isinstance(true_field, tuple) or isinstance(false_field, tuple): if not (isinstance(true_field, tuple) and isinstance(false_field, tuple)): raise ValueError( - f"Either both or none can be tuple in {true_field=} and {false_field=}." # TODO(havogt) find a strategy to unify parsing and embedded error messages + # TODO(havogt) find a strategy to unify parsing and embedded error messages + f"Either both or none can be tuple in '{true_field=}' and '{false_field=}'." ) if len(true_field) != len(false_field): raise ValueError( "Tuple of different size not allowed." 
) # TODO(havogt) find a strategy to unify parsing and embedded error messages - return tuple(where(mask, t, f) for t, f in zip(true_field, false_field)) # type: ignore[return-value] # `tuple` is not `_R` + return tuple( + where(mask, t, f) for t, f in zip(true_field, false_field) + ) # type: ignore[return-value] # `tuple` is not `_R` return super().__call__(mask, true_field, false_field) diff --git a/src/gt4py/next/ffront/foast_introspection.py b/src/gt4py/next/ffront/foast_introspection.py index 805df465b8..404b99d1a0 100644 --- a/src/gt4py/next/ffront/foast_introspection.py +++ b/src/gt4py/next/ffront/foast_introspection.py @@ -73,4 +73,4 @@ def deduce_stmt_return_kind(node: foast.Stmt) -> StmtReturnKind: elif isinstance(node, (foast.Assign, foast.TupleTargetAssign)): return StmtReturnKind.NO_RETURN else: - raise AssertionError(f"Statements of type `{type(node).__name__}` not understood.") + raise AssertionError(f"Statements of type '{type(node).__name__}' not understood.") diff --git a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py index 9afd22de2c..0561a80659 100644 --- a/src/gt4py/next/ffront/foast_passes/closure_var_folding.py +++ b/src/gt4py/next/ffront/foast_passes/closure_var_folding.py @@ -56,7 +56,7 @@ def visit_Attribute(self, node: foast.Attribute, **kwargs) -> foast.Constant: if hasattr(value.value, node.attr): return foast.Constant(value=getattr(value.value, node.attr), location=node.location) raise errors.MissingAttributeError(node.location, node.attr) - raise errors.DSLError(node.location, "attribute access only applicable to constants") + raise errors.DSLError(node.location, "Attribute access only applicable to constants.") def visit_FunctionDefinition( self, node: foast.FunctionDefinition, **kwargs diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 95c9128f87..639e5ff009 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -53,7 +53,7 @@ def with_altered_scalar_kind( elif isinstance(type_spec, ts.ScalarType): return ts.ScalarType(kind=new_scalar_kind, shape=type_spec.shape) else: - raise ValueError(f"Expected field or scalar type, but got {type_spec}.") + raise ValueError(f"Expected field or scalar type, got '{type_spec}'.") def construct_tuple_type( @@ -113,7 +113,9 @@ def promote_to_mask_type( item in input_type.dims for item in mask_type.dims ): return_dtype = input_type.dtype if isinstance(input_type, ts.FieldType) else input_type - return type_info.promote(input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype)) # type: ignore + return type_info.promote( + input_type, ts.FieldType(dims=mask_type.dims, dtype=return_dtype) + ) # type: ignore else: return input_type @@ -148,7 +150,7 @@ def deduce_stmt_return_type( else: raise errors.DSLError( stmt.location, - f"If statement contains return statements with inconsistent types:" + "If statement contains return statements with inconsistent types:" f"{return_types[0]} != {return_types[1]}", ) return_type = return_types[0] or return_types[1] @@ -160,12 +162,12 @@ def deduce_stmt_return_type( elif isinstance(stmt, (foast.Assign, foast.TupleTargetAssign)): return_type = None else: - raise AssertionError(f"Nodes of type `{type(stmt).__name__}` not supported.") + raise AssertionError(f"Nodes of type '{type(stmt).__name__}' not supported.") if conditional_return_type and return_type and return_type != 
conditional_return_type: raise errors.DSLError( stmt.location, - f"If statement contains return statements with inconsistent types:" - f"{conditional_return_type} != {conditional_return_type}", + "If statement contains return statements with inconsistent types: " + f"{return_type} != {conditional_return_type}.", ) @@ -179,7 +181,7 @@ def deduce_stmt_return_type( # If the node was constructed by the foast parsing we should never get here, but instead # we should have gotten an error there. raise AssertionError( - "Malformed block statement. Expected a return statement in this context, " + "Malformed block statement: expected a return statement in this context, " "but none was found. Please submit a bug report." ) @@ -195,7 +197,7 @@ def apply(cls, node: foast.LocatedNode) -> None: cls().visit(node, incomplete_nodes=incomplete_nodes) if incomplete_nodes: - raise AssertionError("FOAST expression is not fully typed.") + raise AssertionError("'FOAST' expression is not fully typed.") def visit_LocatedNode( self, node: foast.LocatedNode, *, incomplete_nodes: list[foast.LocatedNode] @@ -251,7 +253,7 @@ def visit_FunctionDefinition(self, node: foast.FunctionDefinition, **kwargs): if not isinstance(return_type, (ts.DataType, ts.DeferredType, ts.VoidType)): raise errors.DSLError( node.location, - f"Function must return `DataType`, `DeferredType`, or `VoidType`, got `{return_type}`.", + f"Function must return 'DataType', 'DeferredType', or 'VoidType', got '{return_type}'.", ) new_type = ts.FunctionType( pos_only_args=[], @@ -283,17 +285,17 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp if not isinstance(new_axis.type, ts.DimensionType): raise errors.DSLError( node.location, - f"Argument `axis` to scan operator `{node.id}` must be a dimension.", + f"Argument 'axis' to scan operator '{node.id}' must be a dimension.", ) if not new_axis.type.dim.kind == DimensionKind.VERTICAL: raise errors.DSLError( node.location, - f"Argument `axis` to scan operator `{node.id}` must be a vertical dimension.", + f"Argument 'axis' to scan operator '{node.id}' must be a vertical dimension.", ) new_forward = self.visit(node.forward, **kwargs) if not new_forward.type.kind == ts.ScalarKind.BOOL: raise errors.DSLError( - node.location, f"Argument `forward` to scan operator `{node.id}` must be a boolean." + node.location, f"Argument 'forward' to scan operator '{node.id}' must be a boolean." ) new_init = self.visit(node.init, **kwargs) if not all( @@ -302,8 +304,8 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp ): raise errors.DSLError( node.location, - f"Argument `init` to scan operator `{node.id}` must " - f"be an arithmetic type or a logical type or a composite of arithmetic and logical types.", + f"Argument 'init' to scan operator '{node.id}' must " + "be an arithmetic type or a logical type or a composite of arithmetic and logical types.", ) new_definition = self.visit(node.definition, **kwargs) new_def_type = new_definition.type @@ -311,15 +313,15 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp if new_init.type != new_def_type.returns: raise errors.DSLError( node.location, - f"Argument `init` to scan operator `{node.id}` must have same type as its return. " - f"Expected `{new_def_type.returns}`, but got `{new_init.type}`", + f"Argument 'init' to scan operator '{node.id}' must have same type as its return: " + f"expected '{new_def_type.returns}', got '{new_init.type}'.", ) elif new_init.type != carry_type: carry_arg_name = list(new_def_type.pos_or_kw_args.keys())[0] raise errors.DSLError( node.location, - f"Argument `init` to scan operator `{node.id}` must have same type as `{carry_arg_name}` argument. " - f"Expected `{carry_type}`, but got `{new_init.type}`", + f"Argument 'init' to scan operator '{node.id}' must have same type as '{carry_arg_name}' argument: " + f"expected '{carry_type}', got '{new_init.type}'.", ) new_type = ts_ffront.ScanOperatorType( @@ -339,7 +341,7 @@ def visit_ScanOperator(self, node: foast.ScanOperator, **kwargs) -> foast.ScanOp def visit_Name(self, node: foast.Name, **kwargs) -> foast.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise errors.DSLError(node.location, f"Undeclared symbol `{node.id}`.") + raise errors.DSLError(node.location, f"Undeclared symbol '{node.id}'.") symbol = symtable[node.id] return foast.Name(id=node.id, type=symbol.type, location=node.location) @@ -362,9 +364,9 @@ def visit_TupleTargetAssign( targets: TargetType = node.targets indices: list[tuple[int, int] | int] = compute_assign_indices(targets, num_elts) - if not any(isinstance(i, tuple) for i in indices) and len(indices) != num_elts: + if not any(isinstance(i, tuple) for i in indices) and len(targets) != num_elts: raise errors.DSLError( - node.location, f"Too many values to unpack (expected {len(indices)})." + node.location, f"Too many values to unpack (expected {len(targets)})." ) new_targets: TargetType = [] @@ -396,7 +398,7 @@ def visit_TupleTargetAssign( new_targets.append(new_target) else: raise errors.DSLError( - node.location, f"Assignment value must be of type tuple! Got: {values.type}" + node.location, f"Assignment value must be of type tuple, got '{values.type}'." ) return foast.TupleTargetAssign(targets=new_targets, value=values, location=node.location) @@ -416,15 +418,14 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt: if not isinstance(new_node.condition.type, ts.ScalarType): raise errors.DSLError( node.location, - "Condition for `if` must be scalar. " f"But got `{new_node.condition.type}` instead.", + "Condition for 'if' must be scalar, " f"got '{new_node.condition.type}' instead.", ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: raise errors.DSLError( node.location, - "Condition for `if` must be of boolean type. " - f"But got `{new_node.condition.type}` instead.", + "Condition for 'if' must be of boolean type, " + f"got '{new_node.condition.type}' instead.", ) for sym in node.annex.propagated_symbols.keys(): if not (true_type := node.body_true.annex.propagated_symbols[sym].type) == ( false_type := node.body_false.annex.propagated_symbols[sym].type ): raise errors.DSLError( node.location, - f"Inconsistent types between two branches for variable `{sym}`. " - f"Got types `{true_type}` and `{false_type}.", + f"Inconsistent types between two branches for variable '{sym}': " + f"got types '{true_type}' and '{false_type}'.", ) # TODO: properly patch symtable (new node?) 
symtable[sym].type = new_node.annex.propagated_symbols[ @@ -455,8 +456,8 @@ def visit_Symbol( raise errors.DSLError( node.location, ( - "type inconsistency: expression was deduced to be " - f"of type {refine_type}, instead of the expected type {node.type}" + "Type inconsistency: expression was deduced to be " + f"of type '{refine_type}', instead of the expected type '{node.type}'." ), ) new_node: foast.Symbol = foast.Symbol( @@ -490,7 +491,7 @@ def visit_Subscript(self, node: foast.Subscript, **kwargs) -> foast.Subscript: new_type = new_value.type case _: raise errors.DSLError( - new_value.location, "Could not deduce type of subscript expression!" + new_value.location, "Could not deduce type of subscript expression." ) return foast.Subscript( @@ -531,13 +532,13 @@ def _deduce_ternaryexpr_type( if condition.type != ts.ScalarType(kind=ts.ScalarKind.BOOL): raise errors.DSLError( condition.location, - f"Condition is of type `{condition.type}` " f"but should be of type `bool`.", + f"Condition is of type '{condition.type}', should be of type 'bool'.", ) if true_expr.type != false_expr.type: raise errors.DSLError( node.location, - f"Left and right types are not the same: `{true_expr.type}` and `{false_expr.type}`", + f"Left and right types are not the same: '{true_expr.type}' and '{false_expr.type}'", ) return true_expr.type @@ -556,7 +557,7 @@ def _deduce_compare_type( for arg in (left, right): if not type_info.is_arithmetic(arg.type): raise errors.DSLError( - arg.location, f"Type {arg.type} can not be used in operator '{node.op}'!" + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) self._check_operand_dtypes_match(node, left=left, right=right) @@ -571,8 +572,8 @@ def _deduce_compare_type( except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote `{left.type}` and `{right.type}` to common type" - f" in call to `{node.op}`.", + f"Could not promote '{left.type}' and '{right.type}' to common type" + f" in call to '{node.op}'.", ) from ex def _deduce_binop_type( @@ -594,7 +595,7 @@ def _deduce_binop_type( for arg in (left, right): if not is_compatible(arg.type): raise errors.DSLError( - arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." 
) left_type = cast(ts.FieldType | ts.ScalarType, left.type) @@ -608,7 +609,7 @@ def _deduce_binop_type( ): raise errors.DSLError( arg.location, - f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", + f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", ) try: @@ -616,8 +617,8 @@ def _deduce_binop_type( except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote `{left_type}` and `{right_type}` to common type" - f" in call to `{node.op}`.", + f"Could not promote '{left_type}' and '{right_type}' to common type" + f" in call to '{node.op}'.", ) from ex def _check_operand_dtypes_match( @@ -627,7 +628,7 @@ def _check_operand_dtypes_match( if not type_info.extract_dtype(left.type) == type_info.extract_dtype(right.type): raise errors.DSLError( node.location, - f"Incompatible datatypes in operator `{node.op}`: {left.type} and {right.type}!", + f"Incompatible datatypes in operator '{node.op}': '{left.type}' and '{right.type}'.", ) def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: @@ -644,7 +645,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> foast.UnaryOp: if not is_compatible(new_operand.type): raise errors.DSLError( node.location, - f"Incompatible type for unary operator `{node.op}`: `{new_operand.type}`!", + f"Incompatible type for unary operator '{node.op}': '{new_operand.type}'.", ) return foast.UnaryOp( op=node.op, operand=new_operand, location=node.location, type=new_operand.type @@ -674,13 +675,13 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: new_func, (foast.FunctionDefinition, foast.FieldOperator, foast.ScanOperator, foast.Name), ): - raise errors.DSLError(node.location, "Functions can only be called directly!") + raise errors.DSLError(node.location, "Functions can only be called directly.") elif isinstance(new_func.type, ts.FieldType): pass else: raise errors.DSLError( node.location, - f"Expression of type `{new_func.type}` is not callable, must be a `Function`, `FieldOperator`, `ScanOperator` or `Field`.", + f"Expression of type '{new_func.type}' is not callable, must be a 'Function', 'FieldOperator', 'ScanOperator' or 'Field'.", ) # ensure signature is valid @@ -693,7 +694,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: ) except ValueError as err: raise errors.DSLError( - node.location, f"Invalid argument types in call to `{new_func}`!" + node.location, f"Invalid argument types in call to '{new_func}'." ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) @@ -727,7 +728,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: func_name = cast(foast.Name, node.func).id # validate arguments - error_msg_preamble = f"Incompatible argument in call to `{func_name}`." + error_msg_preamble = f"Incompatible argument in call to '{func_name}'." 
error_msg_for_validator = { type_info.is_arithmetic: "an arithmetic", type_info.is_floating_point: "a floating point", @@ -741,13 +742,13 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: elif func_name in fbuiltins.BINARY_MATH_NUMBER_BUILTIN_NAMES: arg_validator = type_info.is_arithmetic else: - raise AssertionError(f"Unknown math builtin `{func_name}`.") + raise AssertionError(f"Unknown math builtin '{func_name}'.") error_msgs = [] for i, arg in enumerate(node.args): if not arg_validator(arg.type): error_msgs.append( - f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, but got `{arg.type}`." + f"Expected {i}-th argument to be {error_msg_for_validator[arg_validator]} type, got '{arg.type}'." ) if error_msgs: raise errors.DSLError( @@ -756,7 +757,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: ) if func_name == "power" and all(type_info.is_integral(arg.type) for arg in node.args): - print(f"Warning: return type of {func_name} might be inconsistent (not implemented).") + print(f"Warning: return type of '{func_name}' might be inconsistent (not implemented).") # deduce return type return_type: Optional[ts.FieldType | ts.ScalarType] = None @@ -777,7 +778,7 @@ def _visit_math_built_in(self, node: foast.Call, **kwargs) -> foast.Call: except ValueError as ex: raise errors.DSLError(node.location, error_msg_preamble) from ex else: - raise AssertionError(f"Unknown math builtin `{func_name}`.") + raise AssertionError(f"Unknown math builtin '{func_name}'.") return foast.Call( func=node.func, @@ -796,9 +797,9 @@ def _visit_reduction(self, node: foast.Call, **kwargs) -> foast.Call: field_dims_str = ", ".join(str(dim) for dim in field_type.dims) raise errors.DSLError( node.location, - f"Incompatible field argument in call to `{str(node.func)}`. " - f"Expected a field with dimension {reduction_dim}, but got " - f"{field_dims_str}.", + f"Incompatible field argument in call to '{str(node.func)}'. " + f"Expected a field with dimension '{reduction_dim}', got " + f"'{field_dims_str}'.", ) return_type = ts.FieldType( dims=[dim for dim in field_type.dims if dim != reduction_dim], @@ -834,7 +835,7 @@ def _visit_astype(self, node: foast.Call, **kwargs) -> foast.Call: ]: raise errors.DSLError( node.location, - f"Invalid call to `astype`. Second argument must be a scalar type, but got {new_type}.", + f"Invalid call to 'astype': second argument must be a scalar type, got '{new_type}'.", ) return_type = type_info.apply_to_primitive_constituents( @@ -860,16 +861,16 @@ def _visit_as_offset(self, node: foast.Call, **kwargs) -> foast.Call: if not type_info.is_integral(arg_1): raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`. " - f"Excepted integer for offset field dtype, but got {arg_1.dtype}" + f"Incompatible argument in call to '{str(node.func)}': " + f"expected integer for offset field dtype, got '{arg_1.dtype}'. " f"{node.location}", ) if arg_0.source not in arg_1.dims: raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`. " - f"{arg_0.source} not in list of offset field dimensions {arg_1.dims}. " + f"Incompatible argument in call to '{str(node.func)}': " + f"'{arg_0.source}' not in list of offset field dimensions '{arg_1.dims}'. 
" f"{node.location}", ) @@ -889,8 +890,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: if not type_info.is_logical(mask_type): raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`. Expected " - f"a field with dtype `bool`, but got `{mask_type}`.", + f"Incompatible argument in call to '{str(node.func)}': expected " + f"a field with dtype 'bool', got '{mask_type}'.", ) try: @@ -907,8 +908,8 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: ): raise errors.DSLError( node.location, - f"Return arguments need to be of same type in {str(node.func)}, but got: " - f"{node.args[1].type} and {node.args[2].type}", + f"Return arguments need to be of same type in '{str(node.func)}', got " + f"'{node.args[1].type}' and '{node.args[2].type}'.", ) else: true_branch_fieldtype = cast(ts.FieldType, true_branch_type) @@ -919,7 +920,7 @@ def _visit_where(self, node: foast.Call, **kwargs) -> foast.Call: except ValueError as ex: raise errors.DSLError( node.location, - f"Incompatible argument in call to `{str(node.func)}`.", + f"Incompatible argument in call to '{str(node.func)}'.", ) from ex return foast.Call( @@ -937,8 +938,8 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: if any([not (isinstance(elt.type, ts.DimensionType)) for elt in broadcast_dims_expr]): raise errors.DSLError( node.location, - f"Incompatible broadcast dimension type in {str(node.func)}. Expected " - f"all broadcast dimensions to be of type Dimension.", + f"Incompatible broadcast dimension type in '{str(node.func)}': expected " + f"all broadcast dimensions to be of type 'Dimension'.", ) broadcast_dims = [cast(ts.DimensionType, elt.type).dim for elt in broadcast_dims_expr] @@ -946,8 +947,8 @@ def _visit_broadcast(self, node: foast.Call, **kwargs) -> foast.Call: if not set((arg_dims := type_info.extract_dims(arg_type))).issubset(set(broadcast_dims)): raise errors.DSLError( node.location, - f"Incompatible broadcast dimensions in {str(node.func)}. Expected " - f"broadcast dimension is missing {set(arg_dims).difference(set(broadcast_dims))}", + f"Incompatible broadcast dimensions in '{str(node.func)}': expected " + f"broadcast dimension(s) '{set(arg_dims).difference(set(broadcast_dims))}' missing", ) return_type = ts.FieldType( diff --git a/src/gt4py/next/ffront/foast_pretty_printer.py b/src/gt4py/next/ffront/foast_pretty_printer.py index 3b81c85265..9275cdda95 100644 --- a/src/gt4py/next/ffront/foast_pretty_printer.py +++ b/src/gt4py/next/ffront/foast_pretty_printer.py @@ -110,7 +110,7 @@ def apply(cls, node: foast.LocatedNode, **kwargs) -> str: # type: ignore[overri node_type_name = type(node).__name__ if not hasattr(cls, node_type_name) and not hasattr(cls, f"visit_{node_type_name}"): raise NotImplementedError( - f"Pretty printer does not support nodes of type " f"`{node_type_name}`." + f"Pretty printer does not support nodes of type '{node_type_name}'." 
) return cls().visit(node, **kwargs) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 3030c03fd1..c4d518d279 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -230,7 +230,7 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs) -> itir.Expr: dtype = type_info.extract_dtype(node.type) if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: - raise NotImplementedError(f"{node.op} is only supported on `bool`s.") + raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") return self._map("not_", node.operand) return self._map( @@ -313,7 +313,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> itir.Expr: return im.call(self.visit(node.func, **kwargs))(*lowered_args, *lowered_kwargs.values()) raise AssertionError( - f"Call to object of type {type(node.func.type).__name__} not understood." + f"Call to object of type '{type(node.func.type).__name__}' not understood." ) def _visit_astype(self, node: foast.Call, **kwargs) -> itir.FunCall: @@ -371,7 +371,9 @@ def _visit_type_constr(self, node: foast.Call, **kwargs) -> itir.Expr: im.literal(str(bool(source_type(node.args[0].value))), "bool") ) return im.promote_to_const_iterator(im.literal(str(node.args[0].value), node_kind)) - raise FieldOperatorLoweringError(f"Encountered a type cast, which is not supported: {node}") + raise FieldOperatorLoweringError( + f"Encountered a type cast, which is not supported: {node}." + ) def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; @@ -388,7 +390,7 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: elif isinstance(type_, ts.ScalarType): typename = type_.kind.name.lower() return im.promote_to_const_iterator(im.literal(str(val), typename)) - raise ValueError(f"Unsupported literal type {type_}.") + raise ValueError(f"Unsupported literal type '{type_}'.") def visit_Constant(self, node: foast.Constant, **kwargs) -> itir.Expr: return self._make_literal(node.value, node.type) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index c7c4c3a23f..0fd263308e 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -107,8 +107,9 @@ def _postprocess_dialect_ast( if annotated_return_type != foast_node.type.returns: # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented raise errors.DSLError( foast_node.location, - f"Annotated return type does not match deduced return type. 
Expected `{foast_node.type.returns}`" # type: ignore[union-attr] # revisit when `type_info.return_type` is implemented - f", but got `{annotated_return_type}`.", + "Annotated return type does not match deduced return type: expected " + f"'{foast_node.type.returns}'" # type: ignore[union-attr] # revisit when 'type_info.return_type' is implemented + f", got '{annotated_return_type}'.", ) return foast_node @@ -167,7 +168,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, **kwargs) -> foast.FunctionDe new_body = self._visit_stmts(node.body, self.get_location(node), **kwargs) if deduce_stmt_return_kind(new_body) == StmtReturnKind.NO_RETURN: - raise errors.DSLError(loc, "Function is expected to return a value.") + raise errors.DSLError(loc, "'Function' is expected to return a value.") return foast.FunctionDefinition( id=node.name, @@ -224,7 +225,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple ) if not isinstance(target, ast.Name): - raise errors.DSLError(self.get_location(node), "can only assign to names") + raise errors.DSLError(self.get_location(node), "Can only assign to names.") new_value = self.visit(node.value) constraint_type: Type[ts.DataType] = ts.DataType if isinstance(new_value, foast.TupleExpr): @@ -246,7 +247,7 @@ def visit_Assign(self, node: ast.Assign, **kwargs) -> foast.Assign | foast.Tuple def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs) -> foast.Assign: if not isinstance(node.target, ast.Name): - raise errors.DSLError(self.get_location(node), "can only assign to names") + raise errors.DSLError(self.get_location(node), "Can only assign to names.") if node.annotation is not None: assert isinstance( @@ -281,14 +282,14 @@ def _match_index(node: ast.expr) -> int: return -node.operand.value if isinstance(node.op, ast.UAdd): return node.operand.value - raise ValueError(f"Not an index: {node}") + raise ValueError(f"Not an index: '{node}'.") def visit_Subscript(self, node: ast.Subscript, **kwargs) -> foast.Subscript: try: index = self._match_index(node.slice) except ValueError: raise errors.DSLError( - self.get_location(node.slice), "expected an integral index" + self.get_location(node.slice), "Expected an integral index." ) from None return foast.Subscript( @@ -310,7 +311,7 @@ def visit_Tuple(self, node: ast.Tuple, **kwargs) -> foast.TupleExpr: def visit_Return(self, node: ast.Return, **kwargs) -> foast.Return: loc = self.get_location(node) if not node.value: - raise errors.DSLError(loc, "must return a value, not None") + raise errors.DSLError(loc, "Must return a value, not 'None'.") return foast.Return(value=self.visit(node.value), location=loc) def visit_Expr(self, node: ast.Expr) -> foast.Expr: @@ -442,11 +443,11 @@ def _verify_builtin_type_constructor(self, node: ast.Call): if len(node.args) > 0 and not isinstance(node.args[0], ast.Constant): raise errors.DSLError( self.get_location(node), - f"{self._func_name(node)}() only takes literal arguments!", + f"'{self._func_name(node)}()' only takes literal arguments.", ) def _func_name(self, node: ast.Call) -> str: - return node.func.id # type: ignore[attr-defined] # We want this to fail if the attribute does not exist unexpectedly. + return node.func.id # type: ignore[attr-defined]  # We want this to fail if the attribute does not exist unexpectedly. def visit_Call(self, node: ast.Call, **kwargs) -> foast.Call: # TODO(tehrengruber): is this still needed or redundant with the checks in type deduction? 
@@ -468,7 +469,7 @@ def visit_Constant(self, node: ast.Constant, **kwargs) -> foast.Constant: type_ = type_translation.from_value(node.value) except ValueError: raise errors.DSLError( - loc, f"constants of type {type(node.value)} are not permitted" + loc, f"Constants of type {type(node.value)} are not permitted." ) from None return foast.Constant( diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index 7b04e90902..5b4dd934b9 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -129,7 +129,7 @@ def visit_Call(self, node: ast.Call) -> past.Call: new_func = self.visit(node.func) if not isinstance(new_func, past.Name): raise errors.DSLError( - loc, "functions must be referenced by their name in function calls" + loc, "Functions must be referenced by their name in function calls." ) return past.Call( @@ -166,7 +166,7 @@ def visit_UnaryOp(self, node: ast.UnaryOp) -> past.Constant: if isinstance(node.op, ast.USub) and isinstance(node.operand, ast.Constant): symbol_type = type_translation.from_value(node.operand.value) return past.Constant(value=-node.operand.value, type=symbol_type, location=loc) - raise errors.DSLError(loc, "unary operators are only applicable to literals") + raise errors.DSLError(loc, "Unary operators are only applicable to literals.") def visit_Constant(self, node: ast.Constant) -> past.Constant: symbol_type = type_translation.from_value(node.value) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index ed3bdae3ff..fc353d64e4 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -33,7 +33,7 @@ def _ensure_no_sliced_field(entry: past.Expr): For example, if argument is of type past.Subscript, this function will throw an error as both slicing and domain are being applied """ if not isinstance(entry, past.Name) and not isinstance(entry, past.TupleExpr): - raise ValueError("Either only domain or slicing allowed") + raise ValueError("Either only domain or slicing allowed.") elif isinstance(entry, past.TupleExpr): for param in entry.elts: _ensure_no_sliced_field(param) @@ -57,20 +57,18 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict): (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), ): raise ValueError( - f"Only calls `FieldOperator`s and `ScanOperator`s " - f"allowed in `Program`, but got `{new_func.type}`." + f"Only calls to 'FieldOperators' and 'ScanOperators' " + f"allowed in 'Program', got '{new_func.type}'." ) if "out" not in new_kwargs: - raise ValueError("Missing required keyword argument(s) `out`.") + raise ValueError("Missing required keyword argument 'out'.") if "domain" in new_kwargs: _ensure_no_sliced_field(new_kwargs["out"]) domain_kwarg = new_kwargs["domain"] if not isinstance(domain_kwarg, past.Dict): - raise ValueError( - f"Only Dictionaries allowed in domain, but got `{type(domain_kwarg)}`." - ) + raise ValueError(f"Only Dictionaries allowed in 'domain', got '{type(domain_kwarg)}'.") if len(domain_kwarg.values_) == 0 and len(domain_kwarg.keys_) == 0: raise ValueError("Empty domain not allowed.") @@ -78,18 +76,18 @@ def _validate_operator_call(new_func: past.Name, new_kwargs: dict): for dim in domain_kwarg.keys_: if not isinstance(dim.type, ts.DimensionType): raise ValueError( - f"Only Dimension allowed in domain dictionary keys, but got `{dim}` which is of type `{dim.type}`." 
+ f"Only 'Dimension' allowed in domain dictionary keys, got '{dim}' which is of type '{dim.type}'." ) for domain_values in domain_kwarg.values_: if len(domain_values.elts) != 2: raise ValueError( - f"Only 2 values allowed in domain range, but got `{len(domain_values.elts)}`." + f"Only 2 values allowed in domain range, got {len(domain_values.elts)}." ) if not _is_integral_scalar(domain_values.elts[0]) or not _is_integral_scalar( domain_values.elts[1] ): raise ValueError( - f"Only integer values allowed in domain range, but got {domain_values.elts[0].type} and {domain_values.elts[1].type}." + f"Only integer values allowed in domain range, got '{domain_values.elts[0].type}' and '{domain_values.elts[1].type}'." ) @@ -149,7 +147,7 @@ def _deduce_binop_type( for arg in (left, right): if not isinstance(arg.type, ts.ScalarType) or not is_compatible(arg.type): raise errors.DSLError( - arg.location, f"Type {arg.type} can not be used in operator `{node.op}`!" + arg.location, f"Type '{arg.type}' can not be used in operator '{node.op}'." ) left_type = cast(ts.ScalarType, left.type) @@ -163,7 +161,7 @@ def _deduce_binop_type( ): raise errors.DSLError( arg.location, - f"Type {right_type} can not be used in operator `{node.op}`, it can only accept ints", + f"Type '{right_type}' can not be used in operator '{node.op}', it only accepts 'int'.", ) try: @@ -171,8 +169,8 @@ def _deduce_binop_type( except ValueError as ex: raise errors.DSLError( node.location, - f"Could not promote `{left_type}` and `{right_type}` to common type" - f" in call to `{node.op}`.", + f"Could not promote '{left_type}' and '{right_type}' to common type" + f" in call to '{node.op}'.", ) from ex def visit_BinOp(self, node: past.BinOp, **kwargs) -> past.BinOp: @@ -214,24 +212,24 @@ def visit_Call(self, node: past.Call, **kwargs): ) if operator_return_type != new_kwargs["out"].type: raise ValueError( - f"Expected keyword argument `out` to be of " - f"type {operator_return_type}, but got " - f"{new_kwargs['out'].type}." + "Expected keyword argument 'out' to be of " + f"type '{operator_return_type}', got " + f"'{new_kwargs['out'].type}'." ) elif new_func.id in ["minimum", "maximum"]: if new_args[0].type != new_args[1].type: raise ValueError( - f"First and second argument in {new_func.id} must be the same type." - f"Got `{new_args[0].type}` and `{new_args[1].type}`." + f"First and second argument in '{new_func.id}' must be of the same type." + f"Got '{new_args[0].type}' and '{new_args[1].type}'." ) return_type = new_args[0].type else: raise AssertionError( - "Only calls `FieldOperator`s, `ScanOperator`s or minimum and maximum builtins allowed" + "Only calls to 'FieldOperator', 'ScanOperator' or 'minimum' and 'maximum' builtins allowed." 
) except ValueError as ex: - raise errors.DSLError(node.location, f"Invalid call to `{node.func.id}`.") from ex + raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.") from ex return past.Call( func=new_func, @@ -244,6 +242,6 @@ def visit_Call(self, node: past.Call, **kwargs): def visit_Name(self, node: past.Name, **kwargs) -> past.Name: symtable = kwargs["symtable"] if node.id not in symtable or symtable[node.id].type is None: - raise errors.DSLError(node.location, f"Undeclared or untyped symbol `{node.id}`.") + raise errors.DSLError(node.location, f"Undeclared or untyped symbol '{node.id}'.") return past.Name(id=node.id, type=symtable[node.id].type, location=node.location) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 2c5dfc6e2f..709912077b 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -37,7 +37,7 @@ def _flatten_tuple_expr( for e in node.elts: result.extend(_flatten_tuple_expr(e)) return result - raise ValueError("Only `past.Name`, `past.Subscript` or `past.TupleExpr`s thereof are allowed.") + raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator): @@ -174,7 +174,7 @@ def _visit_slice_bound( else: lowered_bound = self.visit(slice_bound, **kwargs) else: - raise AssertionError("Expected `None` or `past.Constant`.") + raise AssertionError("Expected 'None' or 'past.Constant'.") return lowered_bound def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: @@ -189,8 +189,8 @@ def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: ) else: raise ValueError( - "Unexpected `out` argument. Must be a `past.Name`, `past.Subscript`" - " or a `past.TupleExpr` thereof." + "Unexpected 'out' argument. Must be a 'past.Name', 'past.Subscript'" + " or a 'past.TupleExpr' thereof." ) def _construct_itir_domain_arg( @@ -209,9 +209,9 @@ def _construct_itir_domain_arg( for out_field_type in out_field_types ): raise AssertionError( - f"Expected constituents of `{out_field.id}` argument to be" - f" fields defined on the same dimensions. This error should be " - f" caught in type deduction already." + f"Expected constituents of '{out_field.id}' argument to be" + " fields defined on the same dimensions. This error should be " + " caught in type deduction already." ) for dim_i, dim in enumerate(out_dims): @@ -232,7 +232,7 @@ def _construct_itir_domain_arg( ) if dim.kind == DimensionKind.LOCAL: - raise ValueError(f"Dimension {dim.value} must not be local.") + raise ValueError(f"Dimension '{dim.value}' must not be local.") domain_args.append( itir.FunCall( fun=itir.SymRef(id="named_range"), @@ -259,8 +259,8 @@ def _construct_itir_initialized_domain_arg( keys_dims_types = cast(ts.DimensionType, node_domain.keys_[dim_i].type).dim if keys_dims_types != dim: raise ValueError( - f"Dimensions in out field and field domain are not equivalent" - f"Expected {dim}, but got {keys_dims_types} " + "Dimensions in out field and field domain are not equivalent:" + f"expected '{dim}', got '{keys_dims_types}'." ) return [self.visit(bound) for bound in node_domain.values_[dim_i].elts] @@ -277,13 +277,13 @@ def _compute_field_slice(node: past.Subscript): out_field_slice_ = [node.slice_] else: raise AssertionError( - "Unexpected `out` argument. Must be tuple of slices or slice expression." + "Unexpected 'out' argument, must be tuple of slices or slice expression." 
            )
        node_dims_ls = cast(ts.FieldType, node.type).dims
        assert isinstance(node_dims_ls, list)
        if isinstance(node.type, ts.FieldType) and len(out_field_slice_) != len(node_dims_ls):
            raise ValueError(
-                f"Too many indices for field {out_field_name}: field is {len(node_dims_ls)}"
+                f"Too many indices for field '{out_field_name}': field is {len(node_dims_ls)}"
                 f"-dimensional, but {len(out_field_slice_)} were indexed."
             )
         return out_field_slice_
@@ -321,7 +321,11 @@ def _visit_stencil_call_out_arg(
                 isinstance(field, past.Subscript) for field in flattened
             ), "Incompatible field in tuple: either all fields or no field must be sliced."
             assert all(
-                concepts.eq_nonlocated(first_field.slice_, field.slice_) for field in flattened  # type: ignore[union-attr] # mypy cannot deduce type
+                concepts.eq_nonlocated(
+                    first_field.slice_,
+                    field.slice_,  # type: ignore[union-attr] # mypy cannot deduce type
+                )
+                for field in flattened
             ), "Incompatible field in tuple: all fields must be sliced in the same way."
             field_slice = self._compute_field_slice(first_field)
             first_field = first_field.value
@@ -332,7 +336,7 @@ def _visit_stencil_call_out_arg(
             )
         else:
             raise AssertionError(
-                "Unexpected `out` argument. Must be a `past.Subscript`, `past.Name` or `past.TupleExpr` node."
+                "Unexpected 'out' argument. Must be a 'past.Subscript', 'past.Name' or 'past.TupleExpr' node."
             )
 
     def visit_Constant(self, node: past.Constant, **kwargs) -> itir.Literal:
@@ -340,7 +344,7 @@ def visit_Constant(self, node: past.Constant, **kwargs) -> itir.Literal:
             match node.type.kind:
                 case ts.ScalarKind.STRING:
                     raise NotImplementedError(
-                        f"Scalars of kind {node.type.kind} not supported currently."
+                        f"Scalars of kind '{node.type.kind}' not supported currently."
                     )
             typename = node.type.kind.name.lower()
             return itir.Literal(value=str(node.value), type=typename)
@@ -373,5 +377,5 @@ def visit_Call(self, node: past.Call, **kwargs) -> itir.FunCall:
             )
         else:
             raise AssertionError(
-                "Only `minimum` and `maximum` builtins supported supported currently."
+                "Only 'minimum' and 'maximum' builtins supported currently."
             )
diff --git a/src/gt4py/next/ffront/source_utils.py b/src/gt4py/next/ffront/source_utils.py
index 17b2050b1b..baf3037d5e 100644
--- a/src/gt4py/next/ffront/source_utils.py
+++ b/src/gt4py/next/ffront/source_utils.py
@@ -37,7 +37,7 @@ def make_source_definition_from_function(func: Callable) -> SourceDefinition:
         filename = str(pathlib.Path(inspect.getabsfile(func)).resolve())
         if not filename:
             raise ValueError(
-                "Can not create field operator from a function that is not in a source file!"
+                "Can not create field operator from a function that is not in a source file."
            )
        source_lines, line_offset = inspect.getsourcelines(func)
        source_code = textwrap.dedent(inspect.getsource(func))
@@ -47,7 +47,7 @@ def make_source_definition_from_function(func: Callable) -> SourceDefinition:
 
        return SourceDefinition(source_code, filename, line_offset - 1, column_offset)
    except OSError as err:
-        raise ValueError(f"Can not get source code of passed function ({func})") from err
+        raise ValueError(f"Can not get source code of passed function '{func}'.") from err
 
 
 def make_symbol_names_from_source(source: str, filename: str = MISSING_FILENAME) -> SymbolNames:
@@ -55,13 +55,13 @@ def make_symbol_names_from_source(source: str, filename: str = MISSING_FILENAME)
         mod_st = symtable.symtable(source, filename, "exec")
     except SyntaxError as err:
         raise ValueError(
-            f"Unexpected error when parsing provided source code (\n{source}\n)"
+            f"Unexpected error when parsing provided source code: \n{source}\n"
         ) from err
 
     assert mod_st.get_type() == "module"
     if len(children := mod_st.get_children()) != 1:
         raise ValueError(
-            f"Sources with multiple function definitions are not yet supported (\n{source}\n)"
+            f"Sources with multiple function definitions are not yet supported: \n{source}\n"
         )
 
     assert children[0].get_type() == "function"
diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py
index 7f56f5d92b..affae8fbca 100644
--- a/src/gt4py/next/ffront/type_info.py
+++ b/src/gt4py/next/ffront/type_info.py
@@ -51,7 +51,7 @@ def _as_field(arg_el: ts.TypeSpec, path: tuple[int, ...]) -> ts.TypeSpec:
             if type_info.extract_dtype(param_el) == type_info.extract_dtype(arg_el):
                 return param_el
             else:
-                raise ValueError(f"{arg_el} is not compatible with {param_el}.")
+                raise ValueError(f"'{arg_el}' is not compatible with '{param_el}'.")
         return arg_el
 
     return type_info.apply_to_primitive_constituents(arg, _as_field, with_path_arg=True)
diff --git a/src/gt4py/next/iterator/dispatcher.py b/src/gt4py/next/iterator/dispatcher.py
index b2ca39df04..626c51ed1c 100644
--- a/src/gt4py/next/iterator/dispatcher.py
+++ b/src/gt4py/next/iterator/dispatcher.py
@@ -57,7 +57,7 @@ def register_key(self, key):
 
     def push_key(self, key):
         if key not in self._funs:
-            raise RuntimeError(f"Key {key} not registered")
+            raise RuntimeError(f"Key '{key}' not registered.")
         self.key_stack.append(key)
 
     def pop_key(self):
diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py
index b00e53bfd9..a4f32929db 100644
--- a/src/gt4py/next/iterator/embedded.py
+++ b/src/gt4py/next/iterator/embedded.py
@@ -238,7 +238,7 @@ def _validate_kstart(self, args):
             set(arg.kstart for arg in args if isinstance(arg, Column)) - {self.kstart}
         ):
             raise ValueError(
-                "Incompatible Column.kstart: it should be '{self.kstart}' but found other values: {wrong_kstarts}"
+                f"Incompatible 'Column.kstart': it should be '{self.kstart}' but found other values: {wrong_kstarts}."
             )
 
     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs) -> Column:
@@ -486,7 +486,7 @@ def promote_scalars(val: CompositeOfScalarOrField):
         return constant_field(val)
     else:
         raise ValueError(
-            f"Expected a `Field` or a number (`float`, `np.int64`, ...), but got {val_type}."
+            f"Expected a 'Field' or a number ('float', 'np.int64', ...), got '{val_type}'."
            )
 
@@ -566,7 +566,7 @@ def execute_shift(
 
         return new_pos
 
-    raise AssertionError("Unknown object in `offset_provider`")
+    raise AssertionError("Unknown object in 'offset_provider'.")
 
 
 def _is_list_of_complete_offsets(
@@ -878,7 +878,7 @@ def make_in_iterator(
             return SparseListIterator(it, sparse_dimensions[0])
         else:
             raise NotImplementedError(
-                f"More than one local dimension is currently not supported, got {sparse_dimensions}"
+                f"More than one local dimension is currently not supported, got {sparse_dimensions}."
             )
     else:
         return it
@@ -925,7 +925,7 @@ def field_setitem(self, named_indices: NamedFieldIndices, value: Any):
         if common.is_mutable_field(self._ndarrayfield):
             self._ndarrayfield[self._translate_named_indices(named_indices)] = value
         else:
-            raise RuntimeError("Assigment into a non-mutable Field.")
+            raise RuntimeError("Assignment into a non-mutable Field is not allowed.")
 
     @property
     def __gt_origin__(self) -> tuple[int, ...]:
@@ -1023,7 +1023,7 @@ def np_as_located_field(
 
     def _maker(a) -> common.Field:
         if a.ndim != len(axes):
-            raise TypeError("ndarray.ndim incompatible with number of given dimensions")
+            raise TypeError("'ndarray.ndim' is incompatible with number of given dimensions.")
         ranges = []
         for d, s in zip(axes, a.shape):
             offset = origin.get(d, 0)
@@ -1071,7 +1071,7 @@ def dtype(self) -> core_defs.Int32DType:
 
     @property
     def ndarray(self) -> core_defs.NDArrayObject:
-        raise AttributeError("Cannot get `ndarray` of an infinite Field.")
+        raise AttributeError("Cannot get 'ndarray' of an infinite 'Field'.")
 
     def asnumpy(self) -> np.ndarray:
         raise NotImplementedError()
@@ -1190,7 +1190,7 @@ def codomain(self) -> type[core_defs.ScalarT]:
 
     @property
     def ndarray(self) -> core_defs.NDArrayObject:
-        raise AttributeError("Cannot get `ndarray` of an infinite Field.")
+        raise AttributeError("Cannot get 'ndarray' of an infinite 'Field'.")
 
     def asnumpy(self) -> np.ndarray:
         raise NotImplementedError()
@@ -1440,7 +1440,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices:
     if isinstance(field, tuple):
         if len(field) != len(value):
             raise RuntimeError(
-                f"Tuple of incompatible size, expected tuple of len={len(field)}, got len={len(value)}"
+                f"Tuple of incompatible size, expected tuple of 'len={len(field)}', got 'len={len(value)}'."
             )
         for f, v in zip(field, value):
             _tuple_assign(f, v, named_indices)
@@ -1459,7 +1459,7 @@ def field_getitem(self, named_indices: NamedFieldIndices) -> Any:
 
     def field_setitem(self, named_indices: NamedFieldIndices, value: Any):
         if not isinstance(value, tuple):
-            raise RuntimeError(f"Value needs to be tuple, got `{value}`.")
+            raise RuntimeError(f"Value needs to be tuple, got '{value}'.")
         _tuple_assign(self.data, value, named_indices)
 
 
@@ -1503,13 +1503,13 @@ def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None:
     if isinstance(domain, runtime.CartesianDomain):
         if any(isinstance(o, common.Connectivity) for o in offset_provider.values()):
             raise RuntimeError(
-                "Got a `CartesianDomain`, but found a `Connectivity` in `offset_provider`, expected `UnstructuredDomain`."
+                "Got a 'CartesianDomain', but found a 'Connectivity' in 'offset_provider', expected 'UnstructuredDomain'."
            )
 
 
 def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any):
     if "offset_provider" not in kwargs:
-        raise RuntimeError("offset_provider not provided")
+        raise RuntimeError("'offset_provider' not provided.")
 
     offset_provider = kwargs["offset_provider"]
 
@@ -1523,7 +1523,7 @@ def closure(
         _validate_domain(domain_, kwargs["offset_provider"])
         domain: dict[Tag, range] = _dimension_to_tag(domain_)
         if not (common.is_field(out) or is_tuple_of_field(out)):
-            raise TypeError("Out needs to be a located field.")
+            raise TypeError("'out' needs to be a located field.")
 
         column_range = None
         column: Optional[ColumnDescriptor] = None
diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py
index 535648cc47..e6ee20e227 100644
--- a/src/gt4py/next/iterator/ir.py
+++ b/src/gt4py/next/iterator/ir.py
@@ -49,13 +49,13 @@ class Sym(Node):  # helper
     @datamodels.validator("kind")
     def _kind_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str):
         if value and value not in ["Iterator", "Value"]:
-            raise ValueError(f"Invalid kind `{value}`, must be one of `Iterator`, `Value`.")
+            raise ValueError(f"Invalid kind '{value}', must be one of 'Iterator', 'Value'.")
 
     @datamodels.validator("dtype")
     def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str):
         if value and value[0] not in TYPEBUILTINS:
             raise ValueError(
-                f"Invalid dtype `{value}`, must be one of `{'`, `'.join(TYPEBUILTINS)}`."
+                f"Invalid dtype '{value}', must be one of '{', '.join(TYPEBUILTINS)}'."
             )
 
 
@@ -71,7 +71,7 @@ class Literal(Expr):
     @datamodels.validator("type")
     def _type_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value):
         if value not in TYPEBUILTINS:
-            raise ValueError(f"{value} is not a valid builtin type.")
+            raise ValueError(f"'{value}' is not a valid builtin type.")
 
 
 class NoneLiteral(Expr):
@@ -115,7 +115,7 @@ class StencilClosure(Node):
     @datamodels.validator("output")
     def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value):
         if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"):
-            raise ValueError("Only FunCall to `make_tuple` allowed.")
+            raise ValueError("Only FunCall to 'make_tuple' allowed.")
 
 
 UNARY_MATH_NUMBER_BUILTINS = {"abs"}
diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py
index f7086ada0c..94a2646422 100644
--- a/src/gt4py/next/iterator/ir_utils/ir_makers.py
+++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py
@@ -295,7 +295,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal:
         Literal(value='True', type='bool')
     """
     if not isinstance(val, core_defs.Scalar):  # type: ignore[arg-type] # mypy bug #11673
-        raise ValueError(f"Value must be a scalar, but got {type(val).__name__}")
+        raise ValueError(f"Value must be a scalar, got '{type(val).__name__}'.")
 
     # At the time this has been written the iterator module has its own type system that is
    # uncoupled from the one used in the frontend. 
However since we decided to eventually replace diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index ffc00e474b..e12ae84dbc 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -96,7 +96,7 @@ def __call__(self, *args, backend: Optional[ProgramExecutor] = None, **kwargs): backend(self.itir(*args, **kwargs), *args, **kwargs) else: if fendef_embedded is None: - raise RuntimeError("Embedded execution is not registered") + raise RuntimeError("Embedded execution is not registered.") fendef_embedded(self.function, *args, **kwargs) def format_itir(self, *args, formatter: ProgramFormatter, **kwargs) -> str: diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index d1f6bba8d6..30fec1f9fd 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -164,7 +164,7 @@ def make_node(o): return NoneLiteral() if hasattr(o, "fun"): return SymRef(id=o.fun.__name__) - raise NotImplementedError(f"Cannot handle {o}") + raise NotImplementedError(f"Cannot handle '{o}'.") def trace_function_call(fun, *, args=None): @@ -269,7 +269,7 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: # the last parameter info might also be a keyword or variadic keyword argument, but # they are not supported. raise NotImplementedError( - "Only `POSITIONAL_OR_KEYWORD` or `VAR_POSITIONAL` parameters are supported." + "Only 'POSITIONAL_OR_KEYWORD' or 'VAR_POSITIONAL' parameters are supported." ) param_info = param_infos[-1] @@ -279,7 +279,7 @@ def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: param_name = param_info.name else: raise NotImplementedError( - "Only `POSITIONAL_OR_KEYWORD` or `VAR_POSITIONAL` parameters are supported." + "Only 'POSITIONAL_OR_KEYWORD' or 'VAR_POSITIONAL' parameters are supported." ) kind, dtype = None, None diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index cc70e11413..034a39d68f 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -123,7 +123,7 @@ def generic_visit(self, *args, **kwargs): depth = kwargs.pop("depth") return super().generic_visit(*args, depth=depth + 1, **kwargs) - def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here. + def visit(self, node: ir.Node, **kwargs) -> None: # type: ignore[override] # supertype accepts any node, but we want to be more specific here. if not isinstance(node, SymbolTableTrait) and not _is_collectable_expr(node): return super().visit(node, **kwargs) @@ -289,7 +289,7 @@ def extract_subexpression( # `_subexpr_2`: `x + y + (x + y)` raise NotImplementedError( "Results of the current implementation not meaningful for " - "`deepest_expr_first == True` and `once_only == True`." + "'deepest_expr_first == True' and 'once_only == True'." 
) ignored_children = False diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index e2feb79c44..2e05391634 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -68,7 +68,7 @@ def _inline_into_scan(ir, *, max_iter=10): break ir = inlined else: - raise RuntimeError(f"Inlining into scan did not converge with {max_iter} iterations.") + raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") return ir @@ -117,7 +117,7 @@ def apply_common_transforms( break ir = inlined else: - raise RuntimeError("Inlining lift and lambdas did not converge.") + raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 60a5db7e96..861052bb25 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -81,7 +81,7 @@ def _get_connectivity( ) -> common.Connectivity: """Return single connectivity that is compatible with the arguments of the reduce.""" if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a `reduce` object, i.e. `reduce(...)(...)`.") + raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") connectivities: list[common.Connectivity] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): @@ -90,11 +90,11 @@ def _get_connectivity( connectivities.append(conn) if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of reduce.") + raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. - raise RuntimeError("Arguments to reduce have incompatible partial shifts.") + raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") return connectivities[0] diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 14f3e95e10..2375118cd1 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -74,7 +74,7 @@ def from_elems(cls: typing.Type[T], *elems: Type) -> typing.Union[T, EmptyTuple] def __iter__(self) -> abc.Iterator[Type]: yield self.front if not isinstance(self.others, (Tuple, EmptyTuple)): - raise ValueError(f"Can not iterate over partially defined tuple {self}") + raise ValueError(f"Can not iterate over partially defined tuple '{self}'.") yield from self.others def __len__(self) -> int: @@ -286,7 +286,7 @@ def handle_constraint( if self.name != other.name: raise TypeError( - f"Can not satisfy constraint on primitive types: {self.name} ≡ {other.name}" + f"Can not satisfy constraint on primitive types: '{self.name}' ≡ '{other.name}'." 
) return True @@ -300,7 +300,7 @@ def handle_constraint( self, other: Type, add_constraint: abc.Callable[[Type, Type], None] ) -> bool: if isinstance(other, UnionPrimitive): - raise AssertionError("`UnionPrimitive` may only appear on one side of a constraint.") + raise AssertionError("'UnionPrimitive' may only appear on one side of a constraint.") if not isinstance(other, Primitive): return False @@ -551,7 +551,8 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints): current_loc_out = current_loc_in for arg in shift_args: if not isinstance(arg, ir.OffsetLiteral): - continue # probably some dynamically computed offset, thus we assume it’s a number not an axis and just ignore it (see comment below) + # probably some dynamically computed offset, thus we assume it’s a number not an axis and just ignore it (see comment below) + continue offset = arg.value if isinstance(offset, int): continue # ignore ‘application’ of (partial) shifts @@ -639,7 +640,7 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type: elif node.id in ir.GRAMMAR_BUILTINS: raise TypeError( f"Builtin '{node.id}' is only allowed as applied/called function by the type " - f"inference." + "inference." ) elif node.id in ir.TYPEBUILTINS: # TODO(tehrengruber): Implement propagating types of values referring to types, e.g. @@ -649,10 +650,10 @@ def visit_SymRef(self, node: ir.SymRef, *, symtable, **kwargs) -> Type: # `typing.Type`. raise NotImplementedError( f"Type builtin '{node.id}' is only supported as literal argument by the " - f"type inference." + "type inference." ) else: - raise NotImplementedError(f"Missing type definition for builtin '{node.id}'") + raise NotImplementedError(f"Missing type definition for builtin '{node.id}'.") elif node.id in symtable: sym_decl = symtable[node.id] assert isinstance(sym_decl, TYPED_IR_NODES) @@ -696,13 +697,13 @@ def _visit_make_tuple(self, node: ir.FunCall, **kwargs) -> Type: def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type: # Calls to `tuple_get` are handled as being part of the grammar, not as function calls. if len(node.args) != 2: - raise TypeError("`tuple_get` requires exactly two arguments.") + raise TypeError("'tuple_get' requires exactly two arguments.") if ( not isinstance(node.args[0], ir.Literal) or node.args[0].type != ir.INTEGER_INDEX_BUILTIN ): raise TypeError( - f"The first argument to `tuple_get` must be a literal of type `{ir.INTEGER_INDEX_BUILTIN}`." + f"The first argument to 'tuple_get' must be a literal of type '{ir.INTEGER_INDEX_BUILTIN}'." 
) self.visit(node.args[0], **kwargs) # visit index so that its type is collected idx = int(node.args[0].value) @@ -725,9 +726,9 @@ def _visit_tuple_get(self, node: ir.FunCall, **kwargs) -> Type: def _visit_neighbors(self, node: ir.FunCall, **kwargs) -> Type: if len(node.args) != 2: - raise TypeError("`neighbors` requires exactly two arguments.") + raise TypeError("'neighbors' requires exactly two arguments.") if not (isinstance(node.args[0], ir.OffsetLiteral) and isinstance(node.args[0].value, str)): - raise TypeError("The first argument to `neighbors` must be an `OffsetLiteral` tag.") + raise TypeError("The first argument to 'neighbors' must be an 'OffsetLiteral' tag.") # Visit arguments such that their type is also inferred self.visit(node.args, **kwargs) @@ -766,11 +767,11 @@ def _visit_neighbors(self, node: ir.FunCall, **kwargs) -> Type: def _visit_cast_(self, node: ir.FunCall, **kwargs) -> Type: if len(node.args) != 2: - raise TypeError("`cast_` requires exactly two arguments.") + raise TypeError("'cast_' requires exactly two arguments.") val_arg_type = self.visit(node.args[0], **kwargs) type_arg = node.args[1] if not isinstance(type_arg, ir.SymRef) or type_arg.id not in ir.TYPEBUILTINS: - raise TypeError("The second argument to `cast_` must be a type literal.") + raise TypeError("The second argument to 'cast_' must be a type literal.") size = TypeVar.fresh() @@ -964,7 +965,7 @@ def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None: and child_node.id in ir.GRAMMAR_BUILTINS | ir.TYPEBUILTINS ): raise AssertionError( - f"Expected a type to be inferred for node `{child_node}`, but none was found." + f"Expected a type to be inferred for node '{child_node}', but none was found." ) diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 5d54512bd0..bfb3b0d474 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -206,7 +206,7 @@ def create_bindings( """ if program_source.language not in [languages.Cpp, languages.Cuda]: raise ValueError( - f"Can only create bindings for C++ program sources, received {program_source.language}." + f"Can only create bindings for C++ program sources, received '{program_source.language}'." 
            )
        wrapper_name = program_source.entry_point.name + "_wrapper"
 
diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py
index 5ea4ba0519..2c0511ebf4 100644
--- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py
+++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py
@@ -101,7 +101,7 @@ def visit_FindDependency(self, dep: FindDependency):
                 return f"find_package(GridTools REQUIRED PATHS {gridtools_cpp.get_cmake_dir()} NO_DEFAULT_PATH)"
             case _:
-                raise ValueError("Library {name} is not supported".format(name=dep.name))
+                raise ValueError(f"Library '{dep.name}' is not supported.")
 
     def visit_LinkDependency(self, dep: LinkDependency):
         # TODO(ricoh): do not add more libraries here
@@ -115,7 +115,7 @@ def visit_LinkDependency(self, dep: LinkDependency):
             case "gridtools_gpu":
                 lib_name = "GridTools::fn_gpu"
             case _:
-                raise ValueError("Library {name} is not supported".format(name=dep.name))
+                raise ValueError(f"Library '{dep.name}' is not supported.")
 
         cfg = ""
         if dep.name == "nanobind":
diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py
index dacb444207..9fd20b16e2 100644
--- a/src/gt4py/next/otf/compilation/compiler.py
+++ b/src/gt4py/next/otf/compilation/compiler.py
@@ -80,7 +80,7 @@ def __call__(
 
         if not new_data or not is_compiled(new_data) or not module_exists(new_data, src_dir):
             raise CompilationError(
-                "On-the-fly compilation unsuccessful for {inp.source_module.entry_point.name}!"
+                f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'."
             )
 
         return getattr(
diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py
index 0370b5eeb3..a21bc83c0b 100644
--- a/src/gt4py/next/otf/stages.py
+++ b/src/gt4py/next/otf/stages.py
@@ -59,7 +59,7 @@ class ProgramSource(Generic[SrcL, SettingT]):
     def __post_init__(self):
         if not isinstance(self.language_settings, self.language.settings_class):
             raise TypeError(
-                f"Wrong language settings type for {self.language}, must be subclass of {self.language.settings_class}"
+                f"Wrong language settings type for '{self.language}', must be subclass of '{self.language.settings_class}'."
             )
 
diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py
index 6b6b91a310..ed8b768972 100644
--- a/src/gt4py/next/otf/workflow.py
+++ b/src/gt4py/next/otf/workflow.py
@@ -80,7 +80,7 @@ def replace(self, **kwargs: Any) -> Self:
             TypeError: If `self` is not a dataclass.
         """
         if not dataclasses.is_dataclass(self):
-            raise TypeError(f"{self.__class__} is not a dataclass")
+            raise TypeError(f"'{self.__class__}' is not a dataclass.")
         assert not isinstance(self, type)
         return dataclasses.replace(self, **kwargs)  # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`?
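Side note on the workflow.py hunk above: `dataclasses.replace` only accepts dataclass instances, which is why `replace` first guards with `dataclasses.is_dataclass` and raises a readable TypeError. A minimal, self-contained sketch of that guarded copy-with-overrides pattern (the `Step` class below is illustrative only, not part of gt4py):

    import dataclasses

    @dataclasses.dataclass(frozen=True)
    class Step:
        name: str
        retries: int = 0

        def replace(self, **kwargs):
            # Same guard as in the hunk above: fail with a clear message
            # instead of letting dataclasses.replace raise a generic error.
            if not dataclasses.is_dataclass(self):
                raise TypeError(f"'{self.__class__}' is not a dataclass.")
            return dataclasses.replace(self, **kwargs)

    step = Step(name="compile")
    assert step.replace(retries=3) == Step(name="compile", retries=3)
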
@@ -242,7 +242,9 @@ class CachedStep( """ step: Workflow[StartT, EndT] - hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] + hash_function: Callable[[StartT], HashT] = dataclasses.field( + default=hash + ) # type: ignore[assignment] _cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index f412386bb3..74fbbfc93f 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -99,7 +99,7 @@ def _get_connectivity( ) -> common.Connectivity: """Return single connectivity that is compatible with the arguments of the reduce.""" if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a `reduce` object, i.e. `reduce(...)(...)`.") + raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") connectivities: list[common.Connectivity] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): @@ -108,11 +108,11 @@ def _get_connectivity( connectivities.append(conn) if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of reduce.") + raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. - raise RuntimeError("Arguments to reduce have incompatible partial shifts.") + raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") return connectivities[0] diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 7bf310f4e1..4abdaa6eea 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -135,7 +135,7 @@ def _process_connectivity_args( if isinstance(connectivity, Connectivity): if connectivity.index_type not in [np.int32, np.int64]: raise ValueError( - "Neighbor table indices must be of type `np.int32` or `np.int64`." + "Neighbor table indices must be of type 'np.int32' or 'np.int64'." ) # parameter @@ -165,8 +165,8 @@ def _process_connectivity_args( pass else: raise AssertionError( - f"Expected offset provider `{name}` to be a `Connectivity` or `Dimension`, " - f"but got {type(connectivity).__name__}." + f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " + f"got '{type(connectivity).__name__}'." ) return parameters, arg_exprs diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index f78a052679..842080f8ae 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -59,7 +59,7 @@ def pytype_to_cpptype(t: str): "axis_literal": None, # TODO: domain? 
        }[t]
    except KeyError:
-        raise TypeError(f"Unsupported type '{t}'") from None
+        raise TypeError(f"Unsupported type '{t}'.") from None
 
 
 _vertical_dimension = "gtfn::unstructured::dim::vertical"
@@ -83,7 +83,7 @@ def _get_gridtype(closures: list[itir.StencilClosure]) -> common.GridType:
     grid_types = {_extract_grid_type(d) for d in domains}
     if len(grid_types) != 1:
         raise ValueError(
-            f"Found StencilClosures with more than one GridType: {grid_types}. This is currently not supported."
+            f"Found 'StencilClosures' with more than one 'GridType': '{grid_types}'. This is currently not supported."
         )
     return grid_types.pop()
 
@@ -109,7 +109,7 @@ def _collect_dimensions_from_domain(
             offset_definitions[dim_name] = TagDefinition(name=Sym(id=dim_name))
     elif domain.fun == itir.SymRef(id="unstructured_domain"):
         if len(domain.args) > 2:
-            raise ValueError("unstructured_domain must not have more than 2 arguments.")
+            raise ValueError("'unstructured_domain' must not have more than 2 arguments.")
         if len(domain.args) > 0:
             horizontal_range = domain.args[0]
             assert isinstance(horizontal_range, itir.FunCall)
@@ -126,7 +126,7 @@ def _collect_dimensions_from_domain(
             )
     else:
         raise AssertionError(
-            "Expected either a call to `cartesian_domain` or to `unstructured_domain`."
+            "Expected either a call to 'cartesian_domain' or to 'unstructured_domain'."
         )
     return offset_definitions
 
@@ -181,7 +181,7 @@ def _collect_offset_definitions(
         )
     else:
         raise AssertionError(
-            "Elements of offset provider need to be either `Dimension` or `Connectivity`."
+            "Elements of offset provider need to be either 'Dimension' or 'Connectivity'."
         )
     return offset_definitions
 
@@ -233,7 +233,7 @@ def apply(
         fencil_definition = node
     else:
         raise TypeError(
-            f"Expected a `FencilDefinition` or `FencilWithTemporaries`, but got `{type(node).__name__}`."
+            f"Expected a 'FencilDefinition' or 'FencilWithTemporaries', got '{type(node).__name__}'."
        )
    grid_type = _get_gridtype(fencil_definition.closures)
@@ -303,7 +303,7 @@ def _make_domain(self, node: itir.FunCall):
                 isinstance(named_range, itir.FunCall)
                 and named_range.fun == itir.SymRef(id="named_range")
             ):
-                raise ValueError("Arguments to `domain` need to be calls to `named_range`.")
+                raise ValueError("Arguments to 'domain' need to be calls to 'named_range'.")
             tags.append(self.visit(named_range.args[0]))
             sizes.append(
                 BinaryExpr(
@@ -410,9 +410,9 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs: Any) -> Node:
                 # special handling of applied builtins is handled in `_visit_`
                 return getattr(self, visit_method)(node, **kwargs)
             elif node.fun.id == "shift":
-                raise ValueError("unapplied shift call not supported: {node}")
+                raise ValueError(f"Unapplied shift call not supported: '{node}'.")
             elif node.fun.id == "scan":
-                raise ValueError("scans are only supported at the top level of a stencil closure")
+                raise ValueError("Scans are only supported at the top level of a stencil closure.")
         if isinstance(node.fun, itir.FunCall):
             if node.fun.fun == itir.SymRef(id="shift"):
                 assert len(node.args) == 1
@@ -440,7 +440,7 @@ def _visit_output_argument(self, node: itir.Expr):
             return self.visit(node)
         elif isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="make_tuple"):
             return SidComposite(values=[self._visit_output_argument(v) for v in node.args])
-        raise ValueError("Expected `SymRef` or `make_tuple` in output argument.")
+        raise ValueError("Expected 'SymRef' or 'make_tuple' in output argument.")
 
     @staticmethod
     def _bool_from_literal(node: itir.Node):
diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py
index d9f8b36301..95d3d2ca35 100644
--- a/src/gt4py/next/program_processors/processor_interface.py
+++ b/src/gt4py/next/program_processors/processor_interface.py
@@ -56,6 +56,62 @@ def kind(self) -> type[ProgramFormatter]:
         return ProgramFormatter
 
 
+def _make_arg_filter(
+    accept_args: None | int | Literal["all"] = "all",
+) -> Callable[[tuple[Any, ...]], tuple[Any, ...]]:
+    match accept_args:
+        case None:
+
+            def arg_filter(args: tuple[Any, ...]) -> tuple[Any, ...]:
+                return ()
+
+        case "all":
+
+            def arg_filter(args: tuple[Any, ...]) -> tuple[Any, ...]:
+                return args
+
+        case int():
+            if accept_args < 0:
+                raise ValueError(
+                    f"Number of accepted arguments cannot be a negative number, got {accept_args}."
+                )
+
+            def arg_filter(args: tuple[Any, ...]) -> tuple[Any, ...]:
+                return args[:accept_args]
+
+        case _:
+            raise ValueError(f"Invalid 'accept_args' value: {accept_args}.")
+    return arg_filter
+
+
+def _make_kwarg_filter(
+    accept_kwargs: None | Sequence[str] | Literal["all"] = "all",
+) -> Callable[[dict[str, Any]], dict[str, Any]]:
+    match accept_kwargs:
+        case None:
+
+            def kwarg_filter(kwargs: dict[str, Any]) -> dict[str, Any]:
+                return {}
+
+        case "all":
+
+            def kwarg_filter(kwargs: dict[str, Any]) -> dict[str, Any]:
+                return kwargs
+
+        case Sequence():
+            if not all(isinstance(a, str) for a in accept_kwargs):
+                raise ValueError(
+                    f"Provided invalid list of keyword argument names: '{accept_kwargs}'."
+                )
+
+            def kwarg_filter(kwargs: dict[str, Any]) -> dict[str, Any]:
+                return {key: value for key, value in kwargs.items() if key in accept_kwargs}
+
+        case _:
+            raise ValueError(f"Invalid 'accept_kwargs' value: {accept_kwargs}.")
+    return kwarg_filter
+
+
 def make_program_processor(
     func: ProgramProcessorCallable[OutputT],
     kind: type[ProcessorKindT],
@@ -80,33 +136,9 @@ def make_program_processor(
     Raises:
         ValueError: If the value of `accept_args` or `accept_kwargs` is invalid.
     """
-    args_filter: Callable[[Sequence], Sequence]
-    if accept_args is None:
-        args_filter = lambda args: ()  # noqa: E731 # use def instead of named lambdas
-    elif accept_args == "all":
-        args_filter = lambda args: args  # noqa: E731
-    elif isinstance(accept_args, int):
-        if accept_args < 0:
-            raise ValueError(
-                f"Number of accepted arguments cannot be a negative number ({accept_args})"
-            )
-        args_filter = lambda args: args[:accept_args]  # type: ignore[misc] # noqa: E731
-    else:
-        raise ValueError(f"Invalid ({accept_args}) accept_args value")
-
-    filtered_kwargs: Callable[[dict[str, Any]], dict[str, Any]]
-    if accept_kwargs is None:
-        filtered_kwargs = lambda kwargs: {}  # noqa: E731 # use def instead of named lambdas
-    elif accept_kwargs == "all":  # don't swap with 'isinstance(..., Sequence)'
-        filtered_kwargs = lambda kwargs: kwargs  # noqa: E731
-    elif isinstance(accept_kwargs, Sequence):
-        if not all(isinstance(a, str) for a in accept_kwargs):
-            raise ValueError(f"Provided invalid list of keyword argument names ({accept_args})")
-        filtered_kwargs = lambda kwargs: {  # noqa: E731
-            key: value for key, value in kwargs.items() if key in accept_kwargs  # type: ignore[operator] # key in accept_kwargs
-        }
-    else:
-        raise ValueError(f"Invalid ({accept_kwargs}) 'accept_kwargs' value")
+    args_filter = _make_arg_filter(accept_args)
+
+    filtered_kwargs = _make_kwarg_filter(accept_kwargs)
 
     @functools.wraps(func)
     def _wrapper(program: itir.FencilDefinition, *args, **kwargs) -> OutputT:
@@ -195,7 +227,7 @@ def ensure_processor_kind(
     obj: ProgramProcessor[OutputT, ProcessorKindT], kind: type[ProcessorKindT]
 ) -> None:
     if not is_processor_kind(obj, kind):
-        raise TypeError(f"{obj} is not a {kind.__name__}!")
+        raise TypeError(f"'{obj}' is not a '{kind.__name__}'.")
 
 
 class ProgramBackend(
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index acfa06b456..65f9d9d71a 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -148,7 +148,7 @@ def get_stride_args(
         stride, remainder = divmod(stride_size, value.itemsize)
         if remainder != 0:
             raise ValueError(
-                f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)"
+                f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)."
) stride_args[str(sym)] = stride @@ -334,7 +334,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: else: def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: - raise RuntimeError("Missing `cupy` dependency for GPU execution.") + raise RuntimeError("Missing 'cupy' dependency for GPU execution.") run_dace_gpu = otf_exec.OTFBackend( diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index d10a14a1ee..d08476847f 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -401,7 +401,7 @@ def builtin_tuple_get( index = node_args[0] if isinstance(index, itir.Literal): return [elements[int(index.value)]] - raise ValueError("Tuple can only be subscripted with compile-time constants") + raise ValueError("Tuple can only be subscripted with compile-time constants.") _GENERAL_BUILTIN_MAPPING: dict[ @@ -640,7 +640,7 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: elif builtin_name in _GENERAL_BUILTIN_MAPPING: return self._visit_general_builtin(node) else: - raise NotImplementedError(f"{builtin_name} not implemented") + raise NotImplementedError(f"'{builtin_name}' not implemented.") return self._visit_call(node) def _visit_call(self, node: itir.FunCall): diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index cb14b89e8a..55717326a3 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -32,7 +32,7 @@ def as_dace_type(type_: ts.ScalarType): return dace.float32 elif type_.kind == ts.ScalarKind.FLOAT64: return dace.float64 - raise ValueError(f"scalar type {type_} not supported") + raise ValueError(f"Scalar type '{type_}' not supported.") def filter_neighbor_tables(offset_provider: dict[str, Any]): diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 5d4b450d39..baa45ddc0e 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -83,7 +83,7 @@ def extract_connectivity_args( if isinstance(conn, common.Connectivity): if not isinstance(conn, common.NeighborTable): raise NotImplementedError( - "Only `NeighborTable` connectivities implemented at this point." + "Only 'NeighborTable' connectivities implemented at this point." ) # copying to device here is a fallback for easy testing and might be removed later conn_arg = _ensure_is_on_device(conn.table, device) @@ -92,8 +92,8 @@ def extract_connectivity_args( pass else: raise AssertionError( - f"Expected offset provider `{name}` to be a `Connectivity` or `Dimension`, " - f"but got {type(conn).__name__}." + f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " + f"but got '{type(conn).__name__}'." 
            )
    return args
diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py
index 564df7fd1a..20fa8bd791 100644
--- a/src/gt4py/next/type_system/type_info.py
+++ b/src/gt4py/next/type_system/type_info.py
@@ -75,13 +75,15 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]:
     match symbol_type:
         case ts.DeferredType(constraint):
             if constraint is None:
-                raise ValueError(f"No type information available for {symbol_type}!")
+                raise ValueError(f"No type information available for '{symbol_type}'.")
             elif isinstance(constraint, tuple):
-                raise ValueError(f"Not sufficient type information available for {symbol_type}!")
+                raise ValueError(f"Not sufficient type information available for '{symbol_type}'.")
             return constraint
         case ts.TypeSpec() as concrete_type:
             return concrete_type.__class__
-    raise ValueError(f"Invalid type for TypeInfo: requires {ts.TypeSpec}, got {type(symbol_type)}!")
+    raise ValueError(
+        f"Invalid type for TypeInfo: requires '{ts.TypeSpec}', got '{type(symbol_type)}'."
+    )
 
 
 def primitive_constituents(
@@ -163,7 +165,7 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType:
             return dtype
         case ts.ScalarType() as dtype:
             return dtype
-    raise ValueError(f"Can not unambiguosly extract data type from {symbol_type}!")
+    raise ValueError(f"Can not unambiguously extract data type from '{symbol_type}'.")
 
 
 def is_floating_point(symbol_type: ts.TypeSpec) -> bool:
@@ -320,7 +322,7 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[common.Dimension]:
             return []
         case ts.FieldType(dims):
             return dims
-    raise ValueError(f"Can not extract dimensions from {symbol_type}!")
+    raise ValueError(f"Can not extract dimensions from '{symbol_type}'.")
 
 
 def is_local_field(type_: ts.FieldType) -> bool:
@@ -435,7 +437,7 @@ def promote(*types: ts.FieldType | ts.ScalarType) -> ts.FieldType | ts.ScalarTyp
         dtype = cast(ts.ScalarType, promote(*(extract_dtype(type_) for type_ in types)))
         return ts.FieldType(dims=dims, dtype=dtype)
 
-    raise TypeError("Expected a FieldType or ScalarType.")
+    raise TypeError("Expected a 'FieldType' or 'ScalarType'.")
 
 
 @functools.singledispatch
@@ -446,7 +448,7 @@ def return_type(
     with_kwargs: dict[str, ts.TypeSpec],
 ):
     raise NotImplementedError(
-        f"Return type deduction of type " f"{type(callable_type).__name__} not implemented."
+        f"Return type deduction of type '{type(callable_type).__name__}' not implemented."
     )
 
 
@@ -473,7 +475,7 @@ def return_type_field(
         raise ValueError("Could not deduce return type of invalid remap operation.") from ex
 
     if not isinstance(with_args[0], ts.OffsetType):
-        raise ValueError(f"First argument must be of type {ts.OffsetType}, got {with_args[0]}.")
+        raise ValueError(f"First argument must be of type '{ts.OffsetType}', got '{with_args[0]}'.")
 
     source_dim = with_args[0].source
     target_dims = with_args[0].target
@@ -500,7 +502,7 @@ def canonicalize_arguments(
     ignore_errors=False,
     use_signature_ordering=False,
 ) -> tuple[list, dict]:
-    raise NotImplementedError(f"Not implemented for type {type(func_type).__name__}.")
+    raise NotImplementedError(f"Not implemented for type '{type(func_type).__name__}'.")
 
 
 @canonicalize_arguments.register
@@ -526,7 +528,7 @@ def canonicalize_function_arguments(
                 cargs[args_idx] = ckwargs.pop(name)
             elif not ignore_errors:
                 raise AssertionError(
-                    f"Error canonicalizing function arguments. Got multiple values for argument `{name}`."
+                    f"Error canonicalizing function arguments. Got multiple values for argument '{name}'."
) a, b = set(func_type.kw_only_args.keys()), set(ckwargs.keys()) @@ -534,7 +536,7 @@ def canonicalize_function_arguments( if invalid_kw_args and (not ignore_errors or use_signature_ordering): # this error can not be ignored as otherwise the invariant that no arguments are dropped # is invalidated. - raise AssertionError(f"Invalid keyword arguments {', '.join(invalid_kw_args)}.") + raise AssertionError(f"Invalid keyword arguments '{', '.join(invalid_kw_args)}'.") if use_signature_ordering: ckwargs = {k: ckwargs[k] for k in func_type.kw_only_args.keys() if k in ckwargs} @@ -566,7 +568,7 @@ def structural_function_signature_incompatibilities( if args_idx < len(args): # remove the argument here such that later errors stay comprehensible kwargs.pop(name) - yield f"Got multiple values for argument `{name}`." + yield f"Got multiple values for argument '{name}'." num_pos_params = len(func_type.pos_only_args) + len(func_type.pos_or_kw_args) num_pos_args = len(args) - args.count(UNDEFINED_ARG) @@ -582,17 +584,17 @@ def structural_function_signature_incompatibilities( range(len(func_type.pos_only_args), num_pos_params), func_type.pos_or_kw_args.keys() ): if args[i] is UNDEFINED_ARG: - missing_positional_args.append(f"`{arg_type}`") + missing_positional_args.append(f"'{arg_type}'") if missing_positional_args: yield f"Missing {len(missing_positional_args)} required positional argument{'s' if len(missing_positional_args) != 1 else ''}: {', '.join(missing_positional_args)}" # check for missing or extra keyword arguments kw_a_m_b = set(func_type.kw_only_args.keys()) - set(kwargs.keys()) if len(kw_a_m_b) > 0: - yield f"Missing required keyword argument{'s' if len(kw_a_m_b) != 1 else ''} `{'`, `'.join(kw_a_m_b)}`." + yield f"Missing required keyword argument{'s' if len(kw_a_m_b) != 1 else ''} '{', '.join(kw_a_m_b)}'." kw_b_m_a = set(kwargs.keys()) - set(func_type.kw_only_args.keys()) if len(kw_b_m_a) > 0: - yield f"Got unexpected keyword argument{'s' if len(kw_b_m_a) != 1 else ''} `{'`, `'.join(kw_b_m_a)}`." + yield f"Got unexpected keyword argument{'s' if len(kw_b_m_a) != 1 else ''} '{', '.join(kw_b_m_a)}'." @functools.singledispatch @@ -604,7 +606,7 @@ def function_signature_incompatibilities( Note that all types must be concrete/complete. """ - raise NotImplementedError(f"Not implemented for type {type(func_type).__name__}.") + raise NotImplementedError(f"Not implemented for type '{type(func_type).__name__}'.") @function_signature_incompatibilities.register @@ -639,14 +641,14 @@ def function_signature_incompatibilities_func( # noqa: C901 if i < len(func_type.pos_only_args): arg_repr = f"{_number_to_ordinal_number(i+1)} argument" else: - arg_repr = f"argument `{list(func_type.pos_or_kw_args.keys())[i - len(func_type.pos_only_args)]}`" - yield f"Expected {arg_repr} to be of type `{a_arg}`, but got `{b_arg}`." + arg_repr = f"argument '{list(func_type.pos_or_kw_args.keys())[i - len(func_type.pos_only_args)]}'" + yield f"Expected {arg_repr} to be of type '{a_arg}', got '{b_arg}'." for kwarg in set(func_type.kw_only_args.keys()) & set(kwargs.keys()): if (a_kwarg := func_type.kw_only_args[kwarg]) != ( b_kwarg := kwargs[kwarg] ) and not is_concretizable(a_kwarg, to_type=b_kwarg): - yield f"Expected keyword argument `{kwarg}` to be of type `{func_type.kw_only_args[kwarg]}`, but got `{kwargs[kwarg]}`." + yield f"Expected keyword argument '{kwarg}' to be of type '{func_type.kw_only_args[kwarg]}', got '{kwargs[kwarg]}'." 
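As an aside, the convention in this file — a `functools.singledispatch` generator that yields one human-readable message per signature problem instead of raising on the first — can be sketched standalone as follows (the types, names, and messages here are illustrative, not gt4py's actual type system):

    import functools
    from typing import Iterator

    @functools.singledispatch
    def incompatibilities(func_type, args) -> Iterator[str]:
        raise NotImplementedError(f"Not implemented for type '{type(func_type).__name__}'.")

    @incompatibilities.register
    def _(func_type: tuple, args) -> Iterator[str]:
        # Yield every mismatch so callers can report all errors at once.
        if len(args) != len(func_type):
            yield f"Expected {len(func_type)} arguments, got {len(args)}."
        for i, (expected, actual) in enumerate(zip(func_type, args)):
            if expected is not actual:
                yield f"Expected argument {i} to be of type '{expected.__name__}', got '{actual.__name__}'."

    assert list(incompatibilities((int, str), (int, float))) == [
        "Expected argument 1 to be of type 'str', got 'float'."
    ]
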
@function_signature_incompatibilities.register @@ -660,11 +662,11 @@ def function_signature_incompatibilities_field( return if not isinstance(args[0], ts.OffsetType): - yield f"Expected first argument to be of type {ts.OffsetType}, but got {args[0]}." + yield f"Expected first argument to be of type '{ts.OffsetType}', got '{args[0]}'." return if kwargs: - yield f"Got unexpected keyword argument(s) `{'`, `'.join(kwargs.keys())}`." + yield f"Got unexpected keyword argument(s) '{', '.join(kwargs.keys())}'." return source_dim = args[0].source @@ -705,7 +707,7 @@ def accepts_args( """ if not isinstance(callable_type, ts.CallableType): if raise_exception: - raise ValueError(f"Expected a callable type, but got `{callable_type}`.") + raise ValueError(f"Expected a callable type, got '{callable_type}'.") return False errors = function_signature_incompatibilities(callable_type, with_args, with_kwargs) @@ -713,7 +715,7 @@ def accepts_args( error_list = list(errors) if len(error_list) > 0: raise ValueError( - f"Invalid call to function of type `{callable_type}`:\n" + f"Invalid call to function of type '{callable_type}':\n" + ("\n".join([f" - {error}" for error in error_list])) ) return True diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 007a83844c..88a8347fe4 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -37,7 +37,7 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: try: dt = np.dtype(dtype) except TypeError as err: - raise ValueError(f"Invalid scalar type definition ({dtype})") from err + raise ValueError(f"Invalid scalar type definition ('{dtype}').") from err if dt.shape == () and dt.fields is None: match dt: @@ -54,9 +54,9 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: case np.str_: return ts.ScalarKind.STRING case _: - raise ValueError(f"Impossible to map '{dtype}' value to a ScalarKind") + raise ValueError(f"Impossible to map '{dtype}' value to a 'ScalarKind'.") else: - raise ValueError(f"Non-trivial dtypes like '{dtype}' are not yet supported") + raise ValueError(f"Non-trivial dtypes like '{dtype}' are not yet supported.") def from_type_hint( @@ -76,7 +76,7 @@ def from_type_hint( type_hint = xtyping.eval_forward_ref(type_hint, globalns=globalns, localns=localns) except Exception as error: raise ValueError( - f"Type annotation ({type_hint}) has undefined forward references!" + f"Type annotation '{type_hint}' has undefined forward references." ) from error # Annotated @@ -98,50 +98,50 @@ def from_type_hint( case builtins.tuple: if not args: - raise ValueError(f"Tuple annotation ({type_hint}) requires at least one argument!") + raise ValueError(f"Tuple annotation '{type_hint}' requires at least one argument.") if Ellipsis in args: - raise ValueError(f"Unbound tuples ({type_hint}) are not allowed!") + raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") return ts.TupleType(types=[recursive_make_symbol(arg) for arg in args]) case common.Field: if (n_args := len(args)) != 2: - raise ValueError(f"Field type requires two arguments, got {n_args}! 
({type_hint})")
+                raise ValueError(f"Field type requires two arguments, got {n_args}: '{type_hint}'.")
 
             dims: Union[Ellipsis, list[common.Dimension]] = []
             dim_arg, dtype_arg = args
             if isinstance(dim_arg, list):
                 for d in dim_arg:
                     if not isinstance(d, common.Dimension):
-                        raise ValueError(f"Invalid field dimension definition '{d}'")
+                        raise ValueError(f"Invalid field dimension definition '{d}'.")
                     dims.append(d)
             elif dim_arg is Ellipsis:
                 dims = dim_arg
             else:
-                raise ValueError(f"Invalid field dimensions '{dim_arg}'")
+                raise ValueError(f"Invalid field dimensions '{dim_arg}'.")
 
             try:
                 dtype = recursive_make_symbol(dtype_arg)
             except ValueError as error:
                 raise ValueError(
-                    f"Field dtype argument must be a scalar type (got '{dtype_arg}')!"
+                    f"Field dtype argument must be a scalar type (got '{dtype_arg}')."
                 ) from error
             if not isinstance(dtype, ts.ScalarType) or dtype.kind == ts.ScalarKind.STRING:
-                raise ValueError("Field dtype argument must be a scalar type (got '{dtype}')!")
+                raise ValueError(f"Field dtype argument must be a scalar type (got '{dtype}').")
             return ts.FieldType(dims=dims, dtype=dtype)
 
         case collections.abc.Callable:
             if not args:
-                raise ValueError("Not annotated functions are not supported!")
+                raise ValueError("Unannotated functions are not supported.")
 
             try:
                 arg_types, return_type = args
                 args = [recursive_make_symbol(arg) for arg in arg_types]
             except Exception as error:
-                raise ValueError(f"Invalid callable annotations in {type_hint}") from error
+                raise ValueError(f"Invalid callable annotations in '{type_hint}'.") from error
 
             kwargs_info = [arg for arg in extra_args if isinstance(arg, xtyping.CallableKwargsInfo)]
             if len(kwargs_info) != 1:
-                raise ValueError(f"Invalid callable annotations in {type_hint}")
+                raise ValueError(f"Invalid callable annotations in '{type_hint}'.")
             kwargs = {
                 arg: recursive_make_symbol(arg_type)
                 for arg, arg_type in kwargs_info[0].data.items()
@@ -155,7 +155,7 @@ def from_type_hint(
                 returns=recursive_make_symbol(return_type),
             )
 
-    raise ValueError(f"'{type_hint}' type is not supported")
+    raise ValueError(f"'{type_hint}' type is not supported.")
 
 
 def from_value(value: Any) -> ts.TypeSpec:
@@ -178,7 +178,7 @@ def from_value(value: Any) -> ts.TypeSpec:
                 break
         if not symbol_type:
             raise ValueError(
-                f"Value `{value}` is out of range to be representable as `INT32` or `INT64`."
+                f"Value '{value}' is out of range to be representable as 'INT32' or 'INT64'."
             )
         return candidate_type
     elif isinstance(value, common.Dimension):
@@ -200,4 +200,4 @@ def from_value(value: Any) -> ts.TypeSpec:
     if isinstance(symbol_type, (ts.DataType, ts.CallableType, ts.OffsetType, ts.DimensionType)):
         return symbol_type
     else:
-        raise ValueError(f"Impossible to map '{value}' value to a Symbol")
+        raise ValueError(f"Impossible to map '{value}' value to a 'Symbol'.")
diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py
index b1e26b40cb..6217d3c782 100644
--- a/tests/next_tests/integration_tests/cases.py
+++ b/tests/next_tests/integration_tests/cases.py
@@ -127,7 +127,7 @@ class ConstInitializer(DataInitializer):
     def __init__(self, value: ScalarValue):
         if not core_defs.is_scalar_type(value):
             raise ValueError(
-                "`ConstInitializer` can not be used with non-scalars. Use `Case.as_field` instead."
+                "'ConstInitializer' can not be used with non-scalars. Use 'Case.as_field' instead."
) self.value = value @@ -162,7 +162,7 @@ class IndexInitializer(DataInitializer): @property def scalar_value(self) -> ScalarValue: - raise AttributeError("`scalar_value` not supported in `IndexInitializer`.") + raise AttributeError("'scalar_value' not supported in 'IndexInitializer'.") def field( self, @@ -172,7 +172,7 @@ def field( ) -> FieldValue: if len(sizes) > 1: raise ValueError( - f"`IndexInitializer` only supports fields with a single `Dimension`, got {sizes}." + f"'IndexInitializer' only supports fields with a single 'Dimension', got {sizes}." ) n_data = list(sizes.values())[0] return constructors.as_field( @@ -244,7 +244,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.partial(*args, **kwargs) def __getattr__(self, name: str) -> Any: - raise AttributeError(f"No setter for argument {name}.") + raise AttributeError(f"No setter for argument '{name}'.") @typing.overload @@ -323,7 +323,7 @@ class NewBuilder(Builder): if 0 < len(args) <= 1 and args[0] is not None: return make_builder_inner(args[0]) if len(args) > 1: - raise ValueError(f"make_builder takes only one positional argument, {len(args)} received!") + raise ValueError(f"make_builder takes only one positional argument, {len(args)} received.") return make_builder_inner @@ -533,7 +533,7 @@ def _allocate_from_type( ) case _: raise TypeError( - f"Can not allocate for type {arg_type} with initializer {strategy or 'default'}" + f"Can not allocate for type '{arg_type}' with initializer '{strategy or 'default'}'." ) @@ -542,7 +542,7 @@ def get_param_types( ) -> dict[str, ts.TypeSpec]: if fieldview_prog.definition is None: raise ValueError( - f"test cases do not support {type(fieldview_prog)} with empty .definition attribute (as you would get from .as_program())!" + f"test cases do not support '{type(fieldview_prog)}' with empty .definition attribute (as you would get from .as_program())." ) annotations = xtyping.get_type_hints(fieldview_prog.definition) return { @@ -559,7 +559,7 @@ def get_param_size(param_type: ts.TypeSpec, sizes: dict[gtx.Dimension, int]) -> case ts.TupleType(types): return sum([get_param_size(t, sizes=sizes) for t in types]) case _: - raise TypeError(f"Can not get size for parameter of type {param_type}") + raise TypeError(f"Can not get size for parameter of type '{param_type}'.") def extend_sizes( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index f8a3f6a975..e25576ebde 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -22,7 +22,6 @@ import gt4py.next as gtx from gt4py.next.ffront import decorator from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.runners import gtfn, roundtrip try: @@ -39,7 +38,7 @@ def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: """Temporary default backend to not accidentally test the wrong backend.""" - raise ValueError("No backend selected! Backend selection is mandatory in tests.") + raise ValueError("No backend selected. 
Backend selection is mandatory in tests.") OPTIONAL_PROCESSORS = [] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index 6293ff76bd..b41696a36b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -226,7 +226,7 @@ def testee( def test_scan_wrong_return_type(cartesian_case): with pytest.raises( errors.DSLError, - match=(r"Argument `init` to scan operator `testee_scan` must have same type as its return"), + match=(r"Argument 'init' to scan operator 'testee_scan' must have same type as its return"), ): @scan_operator(axis=KDim, forward=True, init=0) @@ -245,7 +245,7 @@ def test_scan_wrong_state_type(cartesian_case): with pytest.raises( errors.DSLError, match=( - r"Argument `init` to scan operator `testee_scan` must have same type as `state` argument" + r"Argument 'init' to scan operator 'testee_scan' must have same type as 'state' argument" ), ): @@ -276,7 +276,7 @@ def program_bound_args(arg1: bool, arg2: bool, out: cases.IField): def test_bind_invalid_arg(cartesian_case, bound_args_testee): with pytest.raises( - TypeError, match="Keyword argument `inexistent_arg` is not a valid program parameter." + TypeError, match="Keyword argument 'inexistent_arg' is not a valid program parameter." ): bound_args_testee.with_bound_args(inexistent_arg=1) @@ -306,7 +306,7 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te assert ( re.search( - "Parameter `arg2` already set as a bound argument.", exc_info.value.__cause__.args[0] + "Parameter 'arg2' already set as a bound argument.", exc_info.value.__cause__.args[0] ) is not None ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 51f853d41d..a08931628b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -1188,7 +1188,7 @@ def unpack( def test_tuple_unpacking_too_many_values(cartesian_case): with pytest.raises( errors.DSLError, - match=(r"Could not deduce type: Too many values to unpack \(expected 3\)"), + match=(r"Too many values to unpack \(expected 3\)."), ): @gtx.field_operator(backend=cartesian_case.backend) @@ -1197,8 +1197,10 @@ def _star_unpack() -> tuple[int32, float64, int32]: return a, b, c -def test_tuple_unpacking_too_many_values(cartesian_case): - with pytest.raises(errors.DSLError, match=(r"Assignment value must be of type tuple!")): +def test_tuple_unpacking_too_few_values(cartesian_case): + with pytest.raises( + errors.DSLError, match=(r"Assignment value must be of type tuple, got 'int32'.") + ): @gtx.field_operator(backend=cartesian_case.backend) def _invalid_unpack() -> tuple[int32, float64, int32]: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 8cfcff160c..167ccbb0a5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -57,7 +57,7 @@ def 
make_builtin_field_operator(builtin_name: str): "return": cases.IFloatField, } else: - raise AssertionError(f"Unknown builtin `{builtin_name}`") + raise AssertionError(f"Unknown builtin '{builtin_name}'.") closure_vars = {"IDim": IDim, builtin_name: getattr(fbuiltins, builtin_name)} diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index c2ab43773f..f5bf453a09 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -147,9 +147,9 @@ def tilde_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: def test_unary_not(cartesian_case): pytest.xfail( - "We accidentally supported `not` on fields. This is wrong, we should raise an error." + "We accidentally supported 'not' on fields. This is wrong, we should raise an error." ) - with pytest.raises: # TODO `not` on a field should be illegal + with pytest.raises: # TODO 'not' on a field should be illegal @gtx.field_operator def not_fieldop(inp1: cases.IBoolField) -> cases.IBoolField: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index 4c0613a33c..c86881ab7c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -228,8 +228,8 @@ def test_wrong_argument_type(cartesian_case, copy_program_def): copy_program(inp, out, offset_provider={}) msgs = [ - r"- Expected argument `in_field` to be of type `Field\[\[IDim], float64\]`," - r" but got `Field\[\[JDim\], float64\]`.", + r"- Expected argument 'in_field' to be of type 'Field\[\[IDim], float64\]'," + r" got 'Field\[\[JDim\], float64\]'.", ] for msg in msgs: assert re.search(msg, exc_info.value.__cause__.args[0]) is not None diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 84b480a23d..af06da3e29 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -334,7 +334,7 @@ def if_without_else( def test_if_non_scalar_condition(): - with pytest.raises(errors.DSLError, match="Condition for `if` must be scalar."): + with pytest.raises(errors.DSLError, match="Condition for 'if' must be scalar"): @field_operator def if_non_scalar_condition( @@ -347,7 +347,7 @@ def if_non_scalar_condition( def test_if_non_boolean_condition(): - with pytest.raises(errors.DSLError, match="Condition for `if` must be of boolean type."): + with pytest.raises(errors.DSLError, match="Condition for 'if' must be of boolean type"): @field_operator def if_non_boolean_condition( diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py index d1a5f24f79..2174871f89 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py @@ -88,7 +88,7 @@ def type_info_cases() -> list[tuple[Optional[ts.TypeSpec], dict]]: def 
callable_type_info_cases(): # reuse all the other test cases not_callable = [ - (symbol_type, [], {}, [r"Expected a callable type, but got "], None) + (symbol_type, [], {}, [r"Expected a callable type, got "], None) for symbol_type, attributes in type_info_cases() if not isinstance(symbol_type, ts.CallableType) ] @@ -165,7 +165,7 @@ def callable_type_info_cases(): nullary_func_type, [], {"foo": bool_type}, - [r"Got unexpected keyword argument `foo`."], + [r"Got unexpected keyword argument 'foo'."], None, ), ( @@ -180,7 +180,7 @@ def callable_type_info_cases(): unary_func_type, [float_type], {}, - [r"Expected 1st argument to be of type `bool`, but got `float64`."], + [r"Expected 1st argument to be of type 'bool', got 'float64'."], None, ), ( @@ -188,7 +188,7 @@ def callable_type_info_cases(): [], {}, [ - r"Missing 1 required positional argument: `foo`", + r"Missing 1 required positional argument: 'foo'", r"Function takes 1 positional argument, but 0 were given.", ], None, @@ -199,31 +199,31 @@ def callable_type_info_cases(): kw_or_pos_arg_func_type, [], {"foo": float_type}, - [r"Expected argument `foo` to be of type `bool`, but got `float64`."], + [r"Expected argument 'foo' to be of type 'bool', got 'float64'."], None, ), ( kw_or_pos_arg_func_type, [], {"bar": bool_type}, - [r"Got unexpected keyword argument `bar`."], + [r"Got unexpected keyword argument 'bar'."], None, ), # function with keyword-only argument - (kw_only_arg_func_type, [], {}, [r"Missing required keyword argument `foo`."], None), + (kw_only_arg_func_type, [], {}, [r"Missing required keyword argument 'foo'."], None), (kw_only_arg_func_type, [], {"foo": bool_type}, [], ts.VoidType()), ( kw_only_arg_func_type, [], {"foo": float_type}, - [r"Expected keyword argument `foo` to be of type `bool`, but got `float64`."], + [r"Expected keyword argument 'foo' to be of type 'bool', got 'float64'."], None, ), ( kw_only_arg_func_type, [], {"bar": bool_type}, - [r"Got unexpected keyword argument `bar`."], + [r"Got unexpected keyword argument 'bar'."], None, ), # function with positional, keyword-or-positional, and keyword-only argument @@ -232,9 +232,9 @@ def callable_type_info_cases(): [], {}, [ - r"Missing 1 required positional argument: `foo`", + r"Missing 1 required positional argument: 'foo'", r"Function takes 2 positional arguments, but 0 were given.", - r"Missing required keyword argument `bar`", + r"Missing required keyword argument 'bar'", ], None, ), @@ -244,7 +244,7 @@ def callable_type_info_cases(): {}, [ r"Function takes 2 positional arguments, but 1 were given.", - r"Missing required keyword argument `bar`", + r"Missing required keyword argument 'bar'", ], None, ), @@ -252,14 +252,14 @@ def callable_type_info_cases(): pos_arg_and_kw_or_pos_arg_and_kw_only_arg_func_type, [bool_type], {"foo": int_type}, - [r"Missing required keyword argument `bar`"], + [r"Missing required keyword argument 'bar'"], None, ), ( pos_arg_and_kw_or_pos_arg_and_kw_only_arg_func_type, [bool_type], {"foo": int_type}, - [r"Missing required keyword argument `bar`"], + [r"Missing required keyword argument 'bar'"], None, ), ( @@ -274,9 +274,9 @@ def callable_type_info_cases(): [int_type], {"bar": bool_type, "foo": bool_type}, [ - r"Expected 1st argument to be of type `bool`, but got `int64`", - r"Expected argument `foo` to be of type `int64`, but got `bool`", - r"Expected keyword argument `bar` to be of type `float64`, but got `bool`", + r"Expected 1st argument to be of type 'bool', got 'int64'", + r"Expected argument 'foo' to be of type 'int64', got 
'bool'", + r"Expected keyword argument 'bar' to be of type 'float64', got 'bool'", ], None, ), @@ -299,7 +299,7 @@ def callable_type_info_cases(): [ts.TupleType(types=[float_type, field_type])], {}, [ - r"Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `tuple\[float64, Field\[\[I\], float64\]\]`" + r"Expected 1st argument to be of type 'tuple\[bool, Field\[\[I\], float64\]\]', got 'tuple\[float64, Field\[\[I\], float64\]\]'" ], ts.VoidType(), ), @@ -308,7 +308,7 @@ def callable_type_info_cases(): [int_type], {}, [ - r"Expected 1st argument to be of type `tuple\[bool, Field\[\[I\], float64\]\]`, but got `int64`" + r"Expected 1st argument to be of type 'tuple\[bool, Field\[\[I\], float64\]\]', got 'int64'" ], ts.VoidType(), ), @@ -330,8 +330,8 @@ def callable_type_info_cases(): ], {}, [ - r"Expected argument `a` to be of type `Field\[\[K\], int64\]`, but got `Field\[\[K\], float64\]`", - r"Expected argument `b` to be of type `Field\[\[K\], int64\]`, but got `Field\[\[K\], float64\]`", + r"Expected argument 'a' to be of type 'Field\[\[K\], int64\]', got 'Field\[\[K\], float64\]'", + r"Expected argument 'b' to be of type 'Field\[\[K\], int64\]', got 'Field\[\[K\], float64\]'", ], ts.FieldType(dims=[KDim], dtype=float_type), ), @@ -393,8 +393,8 @@ def callable_type_info_cases(): ], {}, [ - r"Expected argument `a` to be of type `tuple\[Field\[\[I, J, K\], int64\], " - r"Field\[\[\.\.\.\], int64\]\]`, but got `tuple\[Field\[\[I, J, K\], int64\]\]`." + r"Expected argument 'a' to be of type 'tuple\[Field\[\[I, J, K\], int64\], " + r"Field\[\[\.\.\.\], int64\]\]', got 'tuple\[Field\[\[I, J, K\], int64\]\]'." ], ts.FieldType(dims=[IDim, JDim, KDim], dtype=float_type), ), @@ -491,7 +491,7 @@ def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): with pytest.raises( errors.DSLError, - match=(r"Type Field\[\[TDim\], bool\] can not be used in operator `\+`!"), + match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'."), ): _ = FieldOperatorParser.apply_to_function(add_bools) @@ -507,7 +507,7 @@ def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): with pytest.raises( errors.DSLError, match=( - r"Could not promote `Field\[\[X], float64\]` and `Field\[\[Y\], float64\]` to common type in call to +." + r"Could not promote 'Field\[\[X], float64\]' and 'Field\[\[Y\], float64\]' to common type in call to +." 
), ): _ = FieldOperatorParser.apply_to_function(nonmatching) @@ -519,7 +519,7 @@ def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=(r"Type Field\[\[TDim\], float64\] can not be used in operator `\&`!"), + match=(r"Type 'Field\[\[TDim\], float64\]' can not be used in operator '\&'."), ): _ = FieldOperatorParser.apply_to_function(float_bitop) @@ -530,7 +530,7 @@ def sign_bool(a: Field[[TDim], bool]): with pytest.raises( errors.DSLError, - match=r"Incompatible type for unary operator `\-`: `Field\[\[TDim\], bool\]`!", + match=r"Incompatible type for unary operator '\-': 'Field\[\[TDim\], bool\]'.", ): _ = FieldOperatorParser.apply_to_function(sign_bool) @@ -541,7 +541,7 @@ def not_int(a: Field[[TDim], int64]): with pytest.raises( errors.DSLError, - match=r"Incompatible type for unary operator `not`: `Field\[\[TDim\], int64\]`!", + match=r"Incompatible type for unary operator 'not': 'Field\[\[TDim\], int64\]'.", ): _ = FieldOperatorParser.apply_to_function(not_int) @@ -613,7 +613,7 @@ def mismatched_lit() -> Field[[TDim], "float32"]: with pytest.raises( errors.DSLError, - match=(r"Could not promote `float32` and `float64` to common type in call to +."), + match=(r"Could not promote 'float32' and 'float64' to common type in call to +."), ): _ = FieldOperatorParser.apply_to_function(mismatched_lit) @@ -643,7 +643,7 @@ def disjoint_broadcast(a: Field[[ADim], float64]): with pytest.raises( errors.DSLError, - match=r"Expected broadcast dimension is missing", + match=r"expected broadcast dimension\(s\) \'.*\' missing", ): _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) @@ -658,7 +658,7 @@ def badtype_broadcast(a: Field[[ADim], float64]): with pytest.raises( errors.DSLError, - match=r"Expected all broadcast dimensions to be of type Dimension.", + match=r"expected all broadcast dimensions to be of type 'Dimension'.", ): _ = FieldOperatorParser.apply_to_function(badtype_broadcast) @@ -778,7 +778,7 @@ def simple_astype(a: Field[[TDim], float64]): with pytest.raises( errors.DSLError, - match=r"Invalid call to `astype`. 
Second argument must be a scalar type, but got.", + match=r"Invalid call to 'astype': second argument must be a scalar type, got.", ): _ = FieldOperatorParser.apply_to_function(simple_astype) @@ -806,7 +806,7 @@ def modulo_floats(inp: Field[[TDim], float]): with pytest.raises( errors.DSLError, - match=r"Type float64 can not be used in operator `%`", + match=r"Type 'float64' can not be used in operator '%'", ): _ = FieldOperatorParser.apply_to_function(modulo_floats) @@ -844,6 +844,6 @@ def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): with pytest.raises( errors.DSLError, - match=f"Excepted integer for offset field dtype", + match=f"expected integer for offset field dtype", ): _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index c0d565bbf4..d3f3f35699 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -126,7 +126,7 @@ def fenimpl(size, arg0, arg1, arg2, out): closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1, arg2]) else: - raise AssertionError("Add overload") + raise AssertionError("Add overload.") return run_processor(fenimpl, processor, out.shape[0], *inps, out) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 6f91557e46..4177a5aeee 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -109,7 +109,7 @@ def run_processor( elif ppi.is_processor_kind(processor, ppi.ProgramFormatter): print(program.format_itir(*args, formatter=processor, **kwargs)) else: - raise TypeError(f"program processor kind not recognized: {processor}!") + raise TypeError(f"program processor kind not recognized: '{processor}'.") @dataclasses.dataclass diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 2b78eb9114..1a38e5245e 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -631,5 +631,5 @@ def test_setitem_wrong_domain(): np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(-5, 5))) ) - with pytest.raises(ValueError, match=r"Incompatible `Domain`.*"): + with pytest.raises(ValueError, match=r"Incompatible 'Domain'.*"): field[(1, slice(None))] = value_incompatible diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py index e5bbed19fd..96ecc19c0b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_foast.py @@ -88,7 +88,7 @@ def mistyped(inp: gtx.Field): with pytest.raises( ValueError, - match="Field type requires two arguments, got 0!", + match="Field type requires two arguments, got 0.", ): _ = FieldOperatorParser.apply_to_function(mistyped) @@ -245,7 +245,7 @@ def conditional_wrong_mask_type( ) -> gtx.Field[[TDim], float64]: return where(a, a, a) - msg = r"Expected a field with dtype `bool`." 
+    msg = r"expected a field with dtype 'bool'"
     with pytest.raises(errors.DSLError, match=msg):
         _ = FieldOperatorParser.apply_to_function(conditional_wrong_mask_type)

@@ -269,7 +269,7 @@ def test_ternary_with_field_condition():
     def ternary_with_field_condition(cond: gtx.Field[[], bool]):
         return 1 if cond else 2

-    with pytest.raises(errors.DSLError, match=r"should be .* `bool`"):
+    with pytest.raises(errors.DSLError, match=r"should be .* 'bool'"):
         _ = FieldOperatorParser.apply_to_function(ternary_with_field_condition)

@@ -288,7 +288,7 @@ def test_adr13_wrong_return_type_annotation():
     def wrong_return_type_annotation() -> gtx.Field[[], float]:
         return 1.0

-    with pytest.raises(errors.DSLError, match=r"Expected `float.*`"):
+    with pytest.raises(errors.DSLError, match=r"expected 'float.*'"):
         _ = FieldOperatorParser.apply_to_function(wrong_return_type_annotation)

@@ -395,8 +395,6 @@ def zero_dims_ternary(
     ):
         return a if cond == 1 else b

-    msg = r"Incompatible datatypes in operator `==`"
-    with pytest.raises(errors.DSLError) as exc_info:
+    msg = r"Incompatible datatypes in operator '=='"
+    with pytest.raises(errors.DSLError, match=msg):
         _ = FieldOperatorParser.apply_to_function(zero_dims_ternary)
-
-    assert re.search(msg, exc_info.value.args[0]) is not None
diff --git a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py
index 1d1a1efad4..cca05f9917 100644
--- a/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py
+++ b/tests/next_tests/unit_tests/ffront_tests/test_func_to_past.py
@@ -113,7 +113,7 @@ def undefined_field_program(in_field: gtx.Field[[IDim], "float64"]):

     with pytest.raises(
         errors.DSLError,
-        match=(r"Undeclared or untyped symbol `out_field`."),
+        match=(r"Undeclared or untyped symbol 'out_field'."),
     ):
         ProgramParser.apply_to_function(undefined_field_program)

@@ -165,10 +165,10 @@ def domain_format_1_program(in_field: gtx.Field[[IDim], float64]):
     ) as exc_info:
         ProgramParser.apply_to_function(domain_format_1_program)

-    assert exc_info.match("Invalid call to `domain_format_1`")
+    assert exc_info.match("Invalid call to 'domain_format_1'")

     assert (
-        re.search("Only Dictionaries allowed in domain", exc_info.value.__cause__.args[0])
+        re.search("Only Dictionaries allowed in 'domain'", exc_info.value.__cause__.args[0])
         is not None
     )

@@ -184,7 +184,7 @@ def domain_format_2_program(in_field: gtx.Field[[IDim], float64]):
     ) as exc_info:
         ProgramParser.apply_to_function(domain_format_2_program)

-    assert exc_info.match("Invalid call to `domain_format_2`")
+    assert exc_info.match("Invalid call to 'domain_format_2'")

     assert (
         re.search("Only 2 values allowed in domain range", exc_info.value.__cause__.args[0])
@@ -203,10 +203,10 @@ def domain_format_3_program(in_field: gtx.Field[[IDim], float64]):
     ) as exc_info:
         ProgramParser.apply_to_function(domain_format_3_program)

-    assert exc_info.match("Invalid call to `domain_format_3`")
+    assert exc_info.match("Invalid call to 'domain_format_3'")

     assert (
-        re.search(r"Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0])
+        re.search(r"Missing required keyword argument 'out'", exc_info.value.__cause__.args[0])
         is not None
     )

@@ -224,7 +224,7 @@ def domain_format_4_program(in_field: gtx.Field[[IDim], float64]):
     ) as exc_info:
         ProgramParser.apply_to_function(domain_format_4_program)

-    assert exc_info.match("Invalid call to `domain_format_4`")
+    assert exc_info.match("Invalid call to 'domain_format_4'")

     assert (
         re.search("Either only domain or slicing 
allowed", exc_info.value.__cause__.args[0]) @@ -243,7 +243,7 @@ def domain_format_5_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_5_program) - assert exc_info.match("Invalid call to `domain_format_5`") + assert exc_info.match("Invalid call to 'domain_format_5'") assert ( re.search("Only integer values allowed in domain range", exc_info.value.__cause__.args[0]) @@ -262,6 +262,6 @@ def domain_format_6_program(in_field: gtx.Field[[IDim], float64]): ) as exc_info: ProgramParser.apply_to_function(domain_format_6_program) - assert exc_info.match("Invalid call to `domain_format_6`") + assert exc_info.match("Invalid call to 'domain_format_6'") assert re.search("Empty domain not allowed.", exc_info.value.__cause__.args[0]) is not None diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py index c4fe30c596..a1a7b79cec 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py @@ -177,7 +177,7 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): grid_type=gtx.GridType.CARTESIAN, ) - assert exc_info.match("Invalid call to `identity`") + assert exc_info.match("Invalid call to 'identity'") # TODO(tehrengruber): re-enable again when call signature check doesn't return # immediately after missing `out` argument # assert ( @@ -187,6 +187,6 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): # is not None # ) assert ( - re.search(r"Missing required keyword argument\(s\) `out`", exc_info.value.__cause__.args[0]) + re.search(r"Missing required keyword argument 'out'", exc_info.value.__cause__.args[0]) is not None ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 232995be58..73ad24f42b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -56,7 +56,7 @@ def test_embedded_error_on_wrong_domain(): 1, ), ) - with pytest.raises(RuntimeError, match="expected `UnstructuredDomain`"): + with pytest.raises(RuntimeError, match="expected 'UnstructuredDomain'"): foo[dom]( gtx.as_field([I], np.zeros((1,))), out=out, diff --git a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py index 05e982cf0c..1ba35da7c6 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py +++ b/tests/next_tests/unit_tests/program_processor_tests/test_processor_interface.py @@ -74,12 +74,12 @@ def test_undecorated_formatter_function_is_not_recognized(): def undecorated_formatter(fencil: itir.FencilDefinition, *args, **kwargs) -> str: return "" - with pytest.raises(TypeError, match="is not a ProgramFormatter"): + with pytest.raises(TypeError, match="is not a 'ProgramFormatter'"): ensure_processor_kind(undecorated_formatter, ProgramFormatter) def test_wrong_processor_type_is_caught_at_runtime(dummy_formatter): - with pytest.raises(TypeError, match="is not a ProgramExecutor"): + with pytest.raises(TypeError, match="is not a 'ProgramExecutor'"): ensure_processor_kind(dummy_formatter, ProgramExecutor) diff --git a/tests/next_tests/unit_tests/test_allocators.py b/tests/next_tests/unit_tests/test_allocators.py index 456654c1d0..599bea75e7 100644 --- 
a/tests/next_tests/unit_tests/test_allocators.py +++ b/tests/next_tests/unit_tests/test_allocators.py @@ -108,7 +108,7 @@ def test_get_allocator(): with pytest.raises( TypeError, - match=f"Object {invalid_obj} is neither a field allocator nor a field allocator factory", + match=f"Object '{invalid_obj}' is neither a field allocator nor a field allocator factory", ): next_allocators.get_allocator(invalid_obj, strict=True) diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index da63536953..bafabfb56e 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -96,7 +96,7 @@ def test_unit_range_slice_error(rng): def test_unit_range_set_intersection(rng): with pytest.raises( - NotImplementedError, match="Can only find the intersection between UnitRange instances." + NotImplementedError, match="Can only find the intersection between 'UnitRange' instances." ): rng & {1, 5} diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index e8b070f0c0..8d95c9951f 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -139,7 +139,7 @@ def test_as_field_origin(): def test_field_wrong_dims(): with pytest.raises( ValueError, - match=(r"Cannot construct `Field` from array of shape"), + match=(r"Cannot construct 'Field' from array of shape"), ): gtx.as_field([I, J], np.random.rand(sizes[I]).astype(gtx.float32)) @@ -147,7 +147,7 @@ def test_field_wrong_dims(): def test_field_wrong_domain(): with pytest.raises( ValueError, - match=(r"Cannot construct `Field` from array of shape"), + match=(r"Cannot construct 'Field' from array of shape"), ): domain = common.Domain( dims=(I, J), diff --git a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py index d281f5cd90..0a0b747a28 100644 --- a/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py +++ b/tests/next_tests/unit_tests/type_system_tests/test_type_translation.py @@ -158,7 +158,7 @@ def test_invalid_symbol_types(): type_translation.from_type_hint(common.Field[[IDim], None]) # Functions - with pytest.raises(ValueError, match="Not annotated functions are not supported"): + with pytest.raises(ValueError, match="Unannotated functions are not supported"): type_translation.from_type_hint(typing.Callable) with pytest.raises(ValueError, match="Invalid callable annotations"): From 0d66829d8c68b89a620c87fa3fbc8f5b64287d27 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Thu, 14 Dec 2023 11:45:16 +0100 Subject: [PATCH 05/13] docs[next]: Partially fix Quickstart Guide (#1390) Changes to the quickstart guide to use `field.asnumpy()` (introduced in #1366) instead of `np.asarray(field)`. The quickstart guide is still broken though since the embedded backend (used by default) does not support skip neighbors connectivities. 
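As a minimal sketch of the substitution made throughout the guide (the dimension
and data below are illustrative, not taken from the guide itself):

```python
import numpy as np
import gt4py.next as gtx

CellDim = gtx.Dimension("Cell")
field = gtx.as_field([CellDim], np.arange(5.0))

# before: go through NumPy's array protocol
mean_before = np.average(np.asarray(field))

# after: explicit conversion via the method introduced in #1366
mean_after = np.average(field.asnumpy())
assert mean_before == mean_after
```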
--- docs/user/next/QuickstartGuide.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index 1ae1db4d92..dc70f804fd 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -102,7 +102,7 @@ You can call field operators from [programs](#Programs), other field operators, result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) add(a, b, out=result, offset_provider={}) -print("{} + {} = {} ± {}".format(a_value, b_value, np.average(np.asarray(result)), np.std(np.asarray(result)))) +print("{} + {} = {} ± {}".format(a_value, b_value, np.average(result.asnumpy()), np.std(result.asnumpy()))) ``` #### Programs @@ -128,7 +128,7 @@ You can execute the program by simply calling it: result = gtx.as_field([CellDim, KDim], np.zeros(shape=grid_shape)) run_add(a, b, result, offset_provider={}) -print("{} + {} = {} ± {}".format(b_value, (a_value + b_value), np.average(np.asarray(result)), np.std(np.asarray(result)))) +print("{} + {} = {} ± {}".format(b_value, (a_value + b_value), np.average(result.asnumpy()), np.std(result.asnumpy()))) ``` #### Composing field operators and programs @@ -256,7 +256,7 @@ def run_nearest_cell_to_edge(cell_values: gtx.Field[[CellDim], float64], out : g run_nearest_cell_to_edge(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider}) -print("0th adjacent cell's value: {}".format(np.asarray(edge_values))) +print("0th adjacent cell's value: {}".format(edge_values.asnumpy())) ``` Running the above snippet results in the following edge field: @@ -283,7 +283,7 @@ def run_sum_adjacent_cells(cells : gtx.Field[[CellDim], float64], out : gtx.Fiel run_sum_adjacent_cells(cell_values, edge_values, offset_provider={"E2C": E2C_offset_provider}) -print("sum of adjacent cells: {}".format(np.asarray(edge_values))) +print("sum of adjacent cells: {}".format(edge_values.asnumpy())) ``` For the border edges, the results are unchanged compared to the previous example, but the inner edges now contain the sum of the two adjacent cells: @@ -317,7 +317,7 @@ def conditional(mask: gtx.Field[[CellDim, KDim], bool], a: gtx.Field[[CellDim, K return where(mask, a, b) conditional(mask, a, b, out=result_where, offset_provider={}) -print("where return: {}".format(np.asarray(result_where))) +print("where return: {}".format(result_where.asnumpy())) ``` **Tuple implementation:** @@ -340,7 +340,7 @@ result_1: gtx.Field[[CellDim, KDim], float64], result_2: gtx.Field[[CellDim, KDi _conditional_tuple(mask, a, b, out=(result_1, result_2)) conditional_tuple(mask, a, b, result_1, result_2, offset_provider={}) -print("where tuple return: {}".format((np.asarray(result_1), np.asarray(result_2)))) +print("where tuple return: {}".format((result_1.asnumpy(), result_2.asnumpy()))) ``` The `where` builtin also allows for nesting of tuples. 
In this scenario, it will first perform an unrolling: @@ -375,7 +375,7 @@ def conditional_tuple_nested( _conditional_tuple_nested(mask, a, b, c, d, out=((result_1, result_2), (result_2, result_1))) conditional_tuple_nested(mask, a, b, c, d, result_1, result_2, offset_provider={}) -print("where nested tuple return: {}".format(((np.asarray(result_1), np.asarray(result_2)), (np.asarray(result_2), np.asarray(result_1))))) +print("where nested tuple return: {}".format(((result_1.asnumpy(), result_2.asnumpy()), (result_2.asnumpy(), result_1.asnumpy())))) ``` #### Implementing the pseudo-laplacian @@ -447,7 +447,7 @@ run_pseudo_laplacian(cell_values, result_pseudo_lap, offset_provider={"E2C": E2C_offset_provider, "C2E": C2E_offset_provider}) -print("pseudo-laplacian: {}".format(np.asarray(result_pseudo_lap))) +print("pseudo-laplacian: {}".format(result_pseudo_lap.asnumpy())) ``` As a closure, here is an example of chaining field operators, which is very simple to do when working with fields. The field operator below executes the pseudo-laplacian, and then calls the pseudo-laplacian on the result of the first, in effect, calculating the laplacian of a laplacian. From cdcd6537bbc05b050a25ae6abea5b69490ed87db Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Mon, 18 Dec 2023 08:52:29 +0100 Subject: [PATCH 06/13] feat[next]: Add missing UnitRange comparison functions (#1363) - Introduce a better Infinity - Make UnitRange Generic to express finite, infinite, left-finite, right-finite properly. - Remove `Set` from UnitRange --- src/gt4py/next/common.py | 228 ++++++++++++------ src/gt4py/next/embedded/common.py | 1 + src/gt4py/next/embedded/nd_array_field.py | 25 +- src/gt4py/next/ffront/fbuiltins.py | 2 +- src/gt4py/next/iterator/embedded.py | 2 +- .../runners/dace_iterator/__init__.py | 44 ++-- .../embedded_tests/test_nd_array_field.py | 9 +- tests/next_tests/unit_tests/test_common.py | 138 ++++++++--- 8 files changed, 305 insertions(+), 144 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 3e1fe52f31..29d606ccc0 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -20,9 +20,8 @@ import enum import functools import numbers -import sys import types -from collections.abc import Mapping, Sequence, Set +from collections.abc import Mapping, Sequence import numpy as np import numpy.typing as npt @@ -33,10 +32,12 @@ Any, Callable, ClassVar, + Generic, Never, Optional, ParamSpec, Protocol, + Self, TypeAlias, TypeGuard, TypeVar, @@ -52,16 +53,6 @@ DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True) -class Infinity(int): - @classmethod - def positive(cls) -> Infinity: - return cls(sys.maxsize) - - @classmethod - def negative(cls) -> Infinity: - return cls(-sys.maxsize) - - Tag: TypeAlias = str @@ -84,31 +75,86 @@ def __str__(self): return f"{self.value}[{self.kind}]" +class Infinity(enum.Enum): + """Describes an unbounded `UnitRange`.""" + + NEGATIVE = enum.auto() + POSITIVE = enum.auto() + + def __add__(self, _: int) -> Self: + return self + + __radd__ = __add__ + + def __sub__(self, _: int) -> Self: + return self + + __rsub__ = __sub__ + + def __le__(self, other: int | Infinity) -> bool: + return self is self.NEGATIVE or other is self.POSITIVE + + def __lt__(self, other: int | Infinity) -> bool: + return self is self.NEGATIVE and other is not self + + def __ge__(self, other: int | Infinity) -> bool: + return self is self.POSITIVE or other is self.NEGATIVE + + def __gt__(self, other: int | Infinity) -> bool: + return self is self.POSITIVE 
and other is not self + + +def _as_int(v: core_defs.IntegralScalar | Infinity) -> int | Infinity: + return v if isinstance(v, Infinity) else int(v) + + +_Left = TypeVar("_Left", int, Infinity) +_Right = TypeVar("_Right", int, Infinity) + + @dataclasses.dataclass(frozen=True, init=False) -class UnitRange(Sequence[int], Set[int]): +class UnitRange(Sequence[int], Generic[_Left, _Right]): """Range from `start` to `stop` with step size one.""" - start: int - stop: int + start: _Left + stop: _Right - def __init__(self, start: core_defs.IntegralScalar, stop: core_defs.IntegralScalar) -> None: + def __init__( + self, start: core_defs.IntegralScalar | Infinity, stop: core_defs.IntegralScalar | Infinity + ) -> None: if start < stop: - object.__setattr__(self, "start", int(start)) - object.__setattr__(self, "stop", int(stop)) + object.__setattr__(self, "start", _as_int(start)) + object.__setattr__(self, "stop", _as_int(stop)) else: # make UnitRange(0,0) the single empty UnitRange object.__setattr__(self, "start", 0) object.__setattr__(self, "stop", 0) - # TODO: the whole infinity idea and implementation is broken and should be replaced @classmethod - def infinity(cls) -> UnitRange: - return cls(Infinity.negative(), Infinity.positive()) + def infinite( + cls, + ) -> UnitRange: + return cls(Infinity.NEGATIVE, Infinity.POSITIVE) def __len__(self) -> int: - if Infinity.positive() in (abs(self.start), abs(self.stop)): - return Infinity.positive() - return max(0, self.stop - self.start) + if UnitRange.is_finite(self): + return max(0, self.stop - self.start) + raise ValueError("Cannot compute length of open 'UnitRange'.") + + @classmethod + def is_finite(cls, obj: UnitRange) -> TypeGuard[FiniteUnitRange]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.start is not Infinity.NEGATIVE and obj.stop is not Infinity.POSITIVE + + @classmethod + def is_right_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[_Left, int]]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.stop is not Infinity.POSITIVE + + @classmethod + def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.start is not Infinity.NEGATIVE def __repr__(self) -> str: return f"UnitRange({self.start}, {self.stop})" @@ -122,6 +168,7 @@ def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unuse ... def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # redefine unused + assert UnitRange.is_finite(self) if isinstance(index, slice): start, stop, step = index.indices(len(self)) if step != 1: @@ -138,61 +185,60 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re else: raise IndexError("'UnitRange' index out of range") - def __and__(self, other: Set[int]) -> UnitRange: - if isinstance(other, UnitRange): - start = max(self.start, other.start) - stop = min(self.stop, other.stop) - return UnitRange(start, stop) - else: - raise NotImplementedError( - "Can only find the intersection between 'UnitRange' instances." 
- ) + def __and__(self, other: UnitRange) -> UnitRange: + return UnitRange(max(self.start, other.start), min(self.stop, other.stop)) + + def __contains__(self, value: Any) -> bool: + return ( + isinstance(value, core_defs.INTEGRAL_TYPES) + and value >= self.start + and value < self.stop + ) + + def __le__(self, other: UnitRange) -> bool: + return self.start >= other.start and self.stop <= other.stop + + def __lt__(self, other: UnitRange) -> bool: + return (self.start > other.start and self.stop <= other.stop) or ( + self.start >= other.start and self.stop < other.stop + ) + + def __ge__(self, other: UnitRange) -> bool: + return self.start <= other.start and self.stop >= other.stop - def __le__(self, other: Set[int]): + def __gt__(self, other: UnitRange) -> bool: + return (self.start < other.start and self.stop >= other.stop) or ( + self.start <= other.start and self.stop > other.stop + ) + + def __eq__(self, other: Any) -> bool: if isinstance(other, UnitRange): - return self.start >= other.start and self.stop <= other.stop - elif len(self) == Infinity.positive(): - return False - else: - return Set.__le__(self, other) - - def __add__(self, other: int | Set[int]) -> UnitRange: - if isinstance(other, int): - if other == Infinity.positive(): - return UnitRange.infinity() - elif other == Infinity.negative(): - return UnitRange(0, 0) - return UnitRange( - *( - s if s in [Infinity.negative(), Infinity.positive()] else s + other - for s in (self.start, self.stop) - ) - ) - else: - raise NotImplementedError("Can only compute union with 'int' instances.") - - def __sub__(self, other: int | Set[int]) -> UnitRange: - if isinstance(other, int): - if other == Infinity.negative(): - return self + Infinity.positive() - elif other == Infinity.positive(): - return self + Infinity.negative() - else: - return self + (-other) + return self.start == other.start and self.stop == other.stop else: - raise NotImplementedError("Can only compute substraction with 'int' instances.") + return False + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) - __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented + def __add__(self, other: int) -> UnitRange: + return UnitRange(self.start + other, self.stop + other) + + def __sub__(self, other: int) -> UnitRange: + return UnitRange(self.start - other, self.stop - other) def __str__(self) -> str: return f"({self.start}:{self.stop})" +FiniteUnitRange: TypeAlias = UnitRange[int, int] + + RangeLike: TypeAlias = ( UnitRange | range | tuple[core_defs.IntegralScalar, core_defs.IntegralScalar] | core_defs.IntegralScalar + | None ) @@ -207,18 +253,23 @@ def unit_range(r: RangeLike) -> UnitRange: # once the related mypy bug (#16358) gets fixed if ( isinstance(r, tuple) - and isinstance(r[0], core_defs.INTEGRAL_TYPES) - and isinstance(r[1], core_defs.INTEGRAL_TYPES) + and (isinstance(r[0], core_defs.INTEGRAL_TYPES) or r[0] in (None, Infinity.NEGATIVE)) + and (isinstance(r[1], core_defs.INTEGRAL_TYPES) or r[1] in (None, Infinity.POSITIVE)) ): - return UnitRange(r[0], r[1]) + start = r[0] if r[0] is not None else Infinity.NEGATIVE + stop = r[1] if r[1] is not None else Infinity.POSITIVE + return UnitRange(start, stop) if isinstance(r, core_defs.INTEGRAL_TYPES): return UnitRange(0, cast(core_defs.IntegralScalar, r)) + if r is None: + return UnitRange.infinite() raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.") IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple 
NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple +FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement @@ -245,6 +296,10 @@ def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: ) +def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]: + return UnitRange.is_finite(v[1]) + + def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) @@ -283,18 +338,27 @@ def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: return (v[0], unit_range(v[1])) +_Rng = TypeVar( + "_Rng", + UnitRange[int, int], + UnitRange[Infinity, int], + UnitRange[int, Infinity], + UnitRange[Infinity, Infinity], +) + + @dataclasses.dataclass(frozen=True, init=False) -class Domain(Sequence[NamedRange]): +class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]): """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" dims: tuple[Dimension, ...] - ranges: tuple[UnitRange, ...] + ranges: tuple[_Rng, ...] def __init__( self, - *args: NamedRange, + *args: tuple[Dimension, _Rng], dims: Optional[Sequence[Dimension]] = None, - ranges: Optional[Sequence[UnitRange]] = None, + ranges: Optional[Sequence[_Rng]] = None, ) -> None: if dims is not None or ranges is not None: if dims is None and ranges is None: @@ -343,16 +407,23 @@ def ndim(self) -> int: def shape(self) -> tuple[int, ...]: return tuple(len(r) for r in self.ranges) + @classmethod + def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return all(UnitRange.is_finite(rng) for rng in obj.ranges) + @overload - def __getitem__(self, index: int) -> NamedRange: + def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... @overload - def __getitem__(self, index: slice) -> Domain: # noqa: F811 # redefine unused + def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused ... @overload - def __getitem__(self, index: Dimension) -> NamedRange: # noqa: F811 # redefine unused + def __getitem__( # noqa: F811 # redefine unused + self, index: Dimension + ) -> tuple[Dimension, _Rng]: ... 
def __getitem__( # noqa: F811 # redefine unused @@ -434,6 +505,9 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: return Domain(dims=dims, ranges=ranges) +FiniteDomain: TypeAlias = Domain[FiniteUnitRange] + + DomainLike: TypeAlias = ( Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] ) # `Domain` is `Sequence[NamedRange]` and therefore a subset @@ -484,7 +558,7 @@ def _broadcast_ranges( broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange] ) -> tuple[UnitRange, ...]: return tuple( - ranges[dims.index(d)] if d in dims else UnitRange.infinity() for d in broadcast_dims + ranges[dims.index(d)] if d in dims else UnitRange.infinite() for d in broadcast_dims ) @@ -847,7 +921,7 @@ def asnumpy(self) -> Never: @functools.cached_property def domain(self) -> Domain: - return Domain(dims=(self.dimension,), ranges=(UnitRange.infinity(),)) + return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) @property def __gt_dims__(self) -> tuple[Dimension, ...]: diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 87e0800a10..94efe4d61d 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -58,6 +58,7 @@ def _relative_sub_domain( else: # not in new domain assert common.is_int_index(idx) + assert common.UnitRange.is_finite(rng) new_index = (rng.start if idx >= 0 else rng.stop) + idx if new_index < rng.start or new_index >= rng.stop: raise embedded_exceptions.IndexOutOfBounds( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index fbfe64ac42..8bd2673db9 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -113,6 +113,7 @@ def __gt_dims__(self) -> tuple[common.Dimension, ...]: @property def __gt_origin__(self) -> tuple[int, ...]: + assert common.Domain.is_finite(self._domain) return tuple(-r.start for _, r in self._domain) @property @@ -386,6 +387,7 @@ def inverse_image( assert isinstance(image_range, common.UnitRange) + assert common.UnitRange.is_finite(image_range) restricted_mask = (self._ndarray >= image_range.start) & ( self._ndarray < image_range.stop ) @@ -566,9 +568,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...] named_ranges.append((dim, field.domain[pos][1])) else: domain_slice.append(np.newaxis) - named_ranges.append( - (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive())) - ) + named_ranges.append((dim, common.UnitRange.infinite())) return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) @@ -638,14 +638,19 @@ def _compute_slice( ValueError: If `new_rng` is not an integer or a UnitRange. 
""" if isinstance(rng, common.UnitRange): - if domain.ranges[pos] == common.UnitRange.infinity(): - return slice(None) - else: - return slice( - rng.start - domain.ranges[pos].start, - rng.stop - domain.ranges[pos].start, - ) + start = ( + rng.start - domain.ranges[pos].start + if common.UnitRange.is_left_finite(domain.ranges[pos]) + else None + ) + stop = ( + rng.stop - domain.ranges[pos].start + if common.UnitRange.is_right_finite(domain.ranges[pos]) + else None + ) + return slice(start, stop) elif common.is_int_index(rng): + assert common.Domain.is_finite(domain) return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: '{type(rng)}'.") diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 93f17b1eb8..278dde9180 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -192,7 +192,7 @@ def broadcast( np.asarray(field)[ tuple([np.newaxis] * len(dims)) ], # TODO(havogt) use FunctionField once available - domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinity()] * len(dims))), + domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinite()] * len(dims))), ) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index a4f32929db..ef70a2e645 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1059,7 +1059,7 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override @property def domain(self) -> common.Domain: - return common.Domain((self._dimension, common.UnitRange.infinity())) + return common.Domain((self._dimension, common.UnitRange.infinite())) @property def codomain(self) -> type[core_defs.int32]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 65f9d9d71a..037c4f3e4d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -24,10 +24,9 @@ import gt4py.next.iterator.ir as itir import gt4py.next.program_processors.otf_compile_executor as otf_exec import gt4py.next.program_processors.processor_interface as ppi -from gt4py.next.common import Dimension, Domain, UnitRange, is_field -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider -from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms -from gt4py.next.otf.compilation import cache +from gt4py.next import common +from gt4py.next.iterator import embedded as itir_embedded, transforms as itir_transforms +from gt4py.next.otf.compilation import cache as compilation_cache from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG @@ -40,7 +39,8 @@ cp = None -def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: +def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.FiniteUnitRange]: + assert common.Domain.is_finite(domain) sorted_dims = get_sorted_dims(domain.dims) return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] @@ -54,7 +54,7 @@ def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: def convert_arg(arg: Any): - if is_field(arg): + if common.is_field(arg): sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) dim_indices = [dim_index for dim_index, _ in sorted_dims] @@ -67,9 +67,11 @@ def 
convert_arg(arg: Any):


 def preprocess_program(
-    program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: LiftMode
+    program: itir.FencilDefinition,
+    offset_provider: Mapping[str, Any],
+    lift_mode: itir_transforms.LiftMode,
 ):
-    node = apply_common_transforms(
+    node = itir_transforms.apply_common_transforms(
         program,
         common_subexpression_elimination=False,
         lift_mode=lift_mode,
@@ -81,7 +83,7 @@ def preprocess_program(
     if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]):
         fencil_definition = node
     else:
-        fencil_definition = apply_common_transforms(
+        fencil_definition = itir_transforms.apply_common_transforms(
             program,
             common_subexpression_elimination=False,
             force_inline_lambda_args=True,
@@ -109,7 +111,7 @@ def _ensure_is_on_device(


 def get_connectivity_args(
-    neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]],
+    neighbor_tables: Sequence[tuple[str, itir_embedded.NeighborTableOffsetProvider]],
     device: dace.dtypes.DeviceType,
 ) -> dict[str, Any]:
     return {
@@ -134,7 +136,7 @@ def get_offset_args(
     return {
         str(sym): -drange.start
         for param, arg in zip(params, args)
-        if is_field(arg)
+        if common.is_field(arg)
         for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain))
     }
@@ -162,13 +164,19 @@ def get_stride_args(


 def get_cache_id(
     program: itir.FencilDefinition,
     arg_types: Sequence[ts.TypeSpec],
-    column_axis: Optional[Dimension],
+    column_axis: Optional[common.Dimension],
     offset_provider: Mapping[str, Any],
 ) -> str:
     max_neighbors = [
         (k, v.max_neighbors)
         for k, v in offset_provider.items()
-        if isinstance(v, (NeighborTableOffsetProvider, StridedNeighborOffsetProvider))
+        if isinstance(
+            v,
+            (
+                itir_embedded.NeighborTableOffsetProvider,
+                itir_embedded.StridedNeighborOffsetProvider,
+            ),
+        )
     ]
     cache_id_args = [
         str(arg)
@@ -191,8 +199,8 @@ def build_sdfg_from_itir(
     offset_provider: dict[str, Any],
     auto_optimize: bool = False,
     on_gpu: bool = False,
-    column_axis: Optional[Dimension] = None,
-    lift_mode: LiftMode = LiftMode.FORCE_INLINE,
+    column_axis: Optional[common.Dimension] = None,
+    lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE,
 ) -> dace.SDFG:
     """Translate a Fencil into an SDFG.

@@ -210,7 +218,7 @@ def build_sdfg_from_itir(
     """
     # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force
     # `lift_mode` to `FORCE_INLINE` mode.
- lift_mode = LiftMode.FORCE_INLINE + lift_mode = itir_transforms.LiftMode.FORCE_INLINE arg_types = [type_translation.from_value(arg) for arg in args] device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU @@ -237,7 +245,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) auto_optimize = kwargs.get("auto_optimize", False) - lift_mode = kwargs.get("lift_mode", LiftMode.FORCE_INLINE) + lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] @@ -263,7 +271,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): lift_mode=lift_mode, ) - sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): dace.config.Config.set("compiler", "build_type", value=build_type) dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 1a38e5245e..6863b09c12 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -11,7 +11,6 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -import dataclasses import itertools import math import operator @@ -20,7 +19,7 @@ import numpy as np import pytest -from gt4py.next import common, embedded +from gt4py.next import common from gt4py.next.common import Dimension, Domain, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice @@ -353,7 +352,7 @@ def test_cartesian_remap_implementation(): common.field( np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinity())), + Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinite())), ) ), ( @@ -362,7 +361,7 @@ def test_cartesian_remap_implementation(): common.field( np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange.infinity(), UnitRange(0, 10))), + Domain(dims=(IDim, JDim), ranges=(UnitRange.infinite(), UnitRange(0, 10))), ) ), ( @@ -373,7 +372,7 @@ def test_cartesian_remap_implementation(): ), Domain( dims=(IDim, JDim, KDim), - ranges=(UnitRange.infinity(), UnitRange(0, 10), UnitRange.infinity()), + ranges=(UnitRange.infinite(), UnitRange(0, 10), UnitRange.infinite()), ), ) ), diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index bafabfb56e..7650e90c3c 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -14,6 +14,7 @@ import operator from typing import Optional, Pattern +import numpy as np import pytest from gt4py.next.common import ( @@ -41,6 +42,56 @@ def a_domain(): return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) +@pytest.fixture(params=[Infinity.POSITIVE, Infinity.NEGATIVE]) +def unbounded(request): + yield request.param + + +def test_unbounded_add_sub(unbounded): + assert unbounded + 1 == 
unbounded + assert unbounded - 1 == unbounded + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +@pytest.mark.parametrize("op", [operator.le, operator.lt]) +def test_unbounded_comparison_less(value, op): + assert not op(Infinity.POSITIVE, value) + assert op(value, Infinity.POSITIVE) + + assert op(Infinity.NEGATIVE, value) + assert not op(value, Infinity.NEGATIVE) + + assert op(Infinity.NEGATIVE, Infinity.POSITIVE) + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +@pytest.mark.parametrize("op", [operator.ge, operator.gt]) +def test_unbounded_comparison_greater(value, op): + assert op(Infinity.POSITIVE, value) + assert not op(value, Infinity.POSITIVE) + + assert not op(Infinity.NEGATIVE, value) + assert op(value, Infinity.NEGATIVE) + + assert not op(Infinity.NEGATIVE, Infinity.POSITIVE) + + +def test_unbounded_eq(unbounded): + assert unbounded == unbounded + assert unbounded <= unbounded + assert unbounded >= unbounded + assert not unbounded < unbounded + assert not unbounded > unbounded + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +def test_unbounded_max_min(value): + assert max(Infinity.POSITIVE, value) == Infinity.POSITIVE + assert min(Infinity.POSITIVE, value) == value + assert max(Infinity.NEGATIVE, value) == value + assert min(Infinity.NEGATIVE, value) == Infinity.NEGATIVE + + def test_empty_range(): expected = UnitRange(0, 0) assert UnitRange(1, 1) == expected @@ -58,9 +109,20 @@ def test_unit_range_length(rng): assert len(rng) == 10 -@pytest.mark.parametrize("rng_like", [(2, 4), range(2, 4), UnitRange(2, 4)]) -def test_unit_range_like(rng_like): - assert unit_range(rng_like) == UnitRange(2, 4) +@pytest.mark.parametrize( + "rng_like, expected", + [ + ((2, 4), UnitRange(2, 4)), + (range(2, 4), UnitRange(2, 4)), + (UnitRange(2, 4), UnitRange(2, 4)), + ((None, None), UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE)), + ((2, None), UnitRange(2, Infinity.POSITIVE)), + ((None, 4), UnitRange(Infinity.NEGATIVE, 4)), + (None, UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE)), + ], +) +def test_unit_range_like(rng_like, expected): + assert unit_range(rng_like) == expected def test_unit_range_repr(rng): @@ -94,13 +156,6 @@ def test_unit_range_slice_error(rng): rng[1:2:5] -def test_unit_range_set_intersection(rng): - with pytest.raises( - NotImplementedError, match="Can only find the intersection between 'UnitRange' instances." 
-    ):
-        rng & {1, 5}
-
-
 @pytest.mark.parametrize(
     "rng1, rng2, expected",
     [
@@ -121,46 +176,65 @@ def test_unit_range_intersection(rng1, rng2, expected):
 @pytest.mark.parametrize(
     "rng1, rng2, expected",
     [
-        (UnitRange(20, Infinity.positive()), UnitRange(10, 15), UnitRange(0, 0)),
-        (UnitRange(Infinity.negative(), 0), UnitRange(5, 10), UnitRange(0, 0)),
-        (UnitRange(Infinity.negative(), 0), UnitRange(-10, 0), UnitRange(-10, 0)),
-        (UnitRange(0, Infinity.positive()), UnitRange(Infinity.negative(), 5), UnitRange(0, 5)),
+        (UnitRange(20, Infinity.POSITIVE), UnitRange(10, 15), UnitRange(0, 0)),
+        (UnitRange(Infinity.NEGATIVE, 0), UnitRange(5, 10), UnitRange(0, 0)),
+        (UnitRange(Infinity.NEGATIVE, 0), UnitRange(-10, 0), UnitRange(-10, 0)),
+        (UnitRange(0, Infinity.POSITIVE), UnitRange(Infinity.NEGATIVE, 5), UnitRange(0, 5)),
         (
-            UnitRange(Infinity.negative(), 0),
-            UnitRange(Infinity.negative(), 5),
-            UnitRange(Infinity.negative(), 0),
+            UnitRange(Infinity.NEGATIVE, 0),
+            UnitRange(Infinity.NEGATIVE, 5),
+            UnitRange(Infinity.NEGATIVE, 0),
         ),
         (
-            UnitRange(Infinity.negative(), Infinity.positive()),
-            UnitRange(Infinity.negative(), Infinity.positive()),
-            UnitRange(Infinity.negative(), Infinity.positive()),
+            UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE),
+            UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE),
+            UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE),
         ),
     ],
 )
-def test_unit_range_infinite_intersection(rng1, rng2, expected):
+def test_unit_range_unbounded_intersection(rng1, rng2, expected):
     result = rng1 & rng2
     assert result == expected
 
 
-def test_positive_infinity_range():
-    pos_inf_range = UnitRange(Infinity.positive(), Infinity.positive())
-    assert len(pos_inf_range) == 0
+@pytest.mark.parametrize(
+    "rng",
+    [
+        UnitRange(Infinity.NEGATIVE, 0),
+        UnitRange(0, Infinity.POSITIVE),
+        UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE),
+    ],
+)
+def test_unbounded_range_len(rng):
+    with pytest.raises(ValueError, match=r".*open.*"):
+        len(rng)
 
 
-def test_mixed_infinity_range():
-    mixed_inf_range = UnitRange(Infinity.negative(), Infinity.positive())
-    assert len(mixed_inf_range) == Infinity.positive()
+def test_range_contains():
+    assert 1 in UnitRange(0, 2)
+    assert 1 not in UnitRange(0, 1)
+    assert 1 in UnitRange(0, Infinity.POSITIVE)
+    assert 1 in UnitRange(Infinity.NEGATIVE, 2)
+    assert 1 in UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE)
+    assert "s" not in UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE)
 
 
 @pytest.mark.parametrize(
     "op, rng1, rng2, expected",
     [
         (operator.le, UnitRange(-1, 2), UnitRange(-2, 3), True),
-        (operator.le, UnitRange(-1, 2), {-1, 0, 1}, True),
-        (operator.le, UnitRange(-1, 2), {-1, 0}, False),
-        (operator.le, UnitRange(-1, 2), {-2, -1, 0, 1, 2}, True),
-        (operator.le, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 3), True),
-        (operator.le, UnitRange(Infinity.negative(), 2), {1, 2, 3}, False),
+        (operator.le, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True),
+        (operator.ge, UnitRange(-2, 3), UnitRange(-1, 2), True),
+        (operator.ge, UnitRange(Infinity.NEGATIVE, 3), UnitRange(Infinity.NEGATIVE, 2), True),
+        (operator.lt, UnitRange(-1, 2), UnitRange(-2, 2), True),
+        (operator.lt, UnitRange(-2, 1), UnitRange(-2, 2), True),
+        (operator.lt, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True),
+        (operator.gt, UnitRange(-2, 2), UnitRange(-1, 2), True),
+        (operator.gt, UnitRange(-2, 2), UnitRange(-2, 1), True),
+        (operator.gt, UnitRange(Infinity.NEGATIVE, 3), UnitRange(Infinity.NEGATIVE, 2), True),
+        (operator.eq, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 2), True),
+        (operator.ne, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True),
+        (operator.ne, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 2), False),
     ],
 )
 def test_range_comparison(op, rng1, rng2, expected):

From 6c7c5d51b440c40175a25fb75fcbde7c919afbd4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com>
Date: Mon, 18 Dec 2023 10:58:51 +0100
Subject: [PATCH 07/13] feat[dace]: Buildflags to the `ITIR -> SDFG`
 translation (#1389)

Made it possible to also pass build flags to the `ITIR -> SDFG` translator.
---
 .../runners/dace_iterator/__init__.py         | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index 037c4f3e4d..59569de30b 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -47,10 +47,6 @@ def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.FiniteUnitRa
 """ Default build configuration in DaCe backend """
 _build_type = "Release"
-# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins
-_cpu_args = (
-    "-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label"
-)
 
 
 def convert_arg(arg: Any):
@@ -242,6 +238,7 @@ def build_sdfg_from_itir(
 def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
     # build parameters
     build_cache = kwargs.get("build_cache", None)
+    compiler_args = kwargs.get("compiler_args", None)  # `None` will take default.
     build_type = kwargs.get("build_type", "RelWithDebInfo")
     on_gpu = kwargs.get("on_gpu", False)
     auto_optimize = kwargs.get("auto_optimize", False)
@@ -274,7 +271,10 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
     sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache"
     with dace.config.temporary_config():
         dace.config.Config.set("compiler", "build_type", value=build_type)
-        dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args)
+        if compiler_args is not None:
+            dace.config.Config.set(
+                "compiler", "cuda" if on_gpu else "cpu", "args", value=compiler_args
+            )
         sdfg_program = sdfg.compile(validate=False)
 
         # store SDFG program in build cache
@@ -312,12 +312,21 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
 
 
 def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
+    compiler_args = dace.config.Config.get("compiler", "cpu", "args")
+
+    # disable finite-math-only in order to support isfinite/isinf/isnan builtins
+    if "-ffast-math" in compiler_args:
+        compiler_args += " -fno-finite-math-only"
+    if "-ffinite-math-only" in compiler_args:
+        compiler_args = compiler_args.replace("-ffinite-math-only", "")
+
     run_dace_iterator(
         program,
         *args,
        **kwargs,
         build_cache=_build_cache_cpu,
         build_type=_build_type,
+        compiler_args=compiler_args,
         on_gpu=False,
     )

From 315d9203bb667baa3daaea4b797a0846a2b70887 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com>
Date: Tue, 19 Dec 2023 07:35:51 +0100
Subject: [PATCH 08/13] feat[dace]: Computing SDFG call arguments (#1398)

Added a function to get the arguments to call an SDFG.
This commit adds a function that generates the arguments needed to call an SDFG; previously this logic was part of `run_dace_iterator()`,
which made it very complex to run an SDFG outside that function.

One should consider this an amendment to [PR #1379](https://github.com/GridTools/gt4py/pull/1379).
---
 .../runners/dace_iterator/__init__.py         | 79 ++++++++++++-------
 1 file changed, 49 insertions(+), 30 deletions(-)

diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index 59569de30b..97dd90eb54 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -90,8 +90,9 @@ def preprocess_program(
     return fencil_definition
 
 
-def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]:
-    return {name.id: convert_arg(arg) for name, arg in zip(params, args)}
+def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
+    sdfg_params: Sequence[str] = sdfg.arg_names
+    return {sdfg_param: convert_arg(arg) for sdfg_param, arg in zip(sdfg_params, args)}
 
 
 def _ensure_is_on_device(
@@ -127,13 +128,16 @@ def get_shape_args(
 
 
 def get_offset_args(
-    arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any]
+    sdfg: dace.SDFG,
+    args: Sequence[Any],
 ) -> Mapping[str, int]:
+    sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays
+    sdfg_params: Sequence[str] = sdfg.arg_names
     return {
         str(sym): -drange.start
-        for param, arg in zip(params, args)
+        for sdfg_param, arg in zip(sdfg_params, args)
         if common.is_field(arg)
-        for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain))
+        for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain))
     }
 
 
@@ -189,6 +193,45 @@ def get_cache_id(
     return m.hexdigest()
 
 
+def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]:
+    """Extracts the arguments needed to call the SDFG.
+
+    This function can handle the same arguments that are passed to `run_dace_iterator()`.
+
+    Args:
+        sdfg: The SDFG for which we want to get the arguments.
+ """ # noqa: D401 + offset_provider = kwargs["offset_provider"] + on_gpu = kwargs.get("on_gpu", False) + + neighbor_tables = filter_neighbor_tables(offset_provider) + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + + dace_args = get_args(sdfg, args) + dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} + dace_conn_args = get_connectivity_args(neighbor_tables, device) + dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) + dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) + dace_strides = get_stride_args(sdfg.arrays, dace_field_args) + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg, args) + all_args = { + **dace_args, + **dace_conn_args, + **dace_shapes, + **dace_conn_shapes, + **dace_strides, + **dace_conn_strides, + **dace_offsets, + } + expected_args = { + key: value + for key, value in all_args.items() + if key in sdfg.signature_arglist(with_types=False) + } + return expected_args + + def build_sdfg_from_itir( program: itir.FencilDefinition, *args, @@ -248,8 +291,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) if build_cache is not None and cache_id in build_cache: @@ -281,29 +322,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): if build_cache is not None: build_cache[cache_id] = sdfg_program - dace_args = get_args(program.params, args) - dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} - dace_conn_args = get_connectivity_args(neighbor_tables, device) - dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) - dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) - dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) - dace_offsets = get_offset_args(sdfg.arrays, program.params, args) - - all_args = { - **dace_args, - **dace_conn_args, - **dace_shapes, - **dace_conn_shapes, - **dace_strides, - **dace_conn_strides, - **dace_offsets, - } - expected_args = { - key: value - for key, value in all_args.items() - if key in sdfg.signature_arglist(with_types=False) - } + expected_args = get_sdfg_args(sdfg, *args, **kwargs) with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True) From 15a7bd627d9fc818befd5f6ff6e795868563ff37 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 19 Dec 2023 08:43:41 +0100 Subject: [PATCH 09/13] fix[next][dace]: Fix memlet for array slicing (#1399) Implementation of array slicing in DaCe backend changed to a mapped tasklet. Tested on GPU. CUDA code generation did not support the previous implementation, based on memlet in nested-SDFG. 
---
 .../runners/dace_iterator/itir_to_tasklet.py  | 66 ++++++-------------
 1 file changed, 21 insertions(+), 45 deletions(-)

diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
index d08476847f..4c202b1fe8 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
@@ -18,7 +18,6 @@
 import dace
 import numpy as np
-from dace import subsets
 from dace.transformation.dataflow import MapFusion
 
 import gt4py.eve.codegen
@@ -754,52 +753,29 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
             dace.Memlet.simple(node.data, "0") for node in deref_nodes[1:]
         ]
 
-        # we create a nested sdfg in order to access the index scalar values as symbols in a memlet subset
-        deref_sdfg = dace.SDFG("deref")
-        deref_sdfg.add_array(
-            "_inp", field_array.shape, iterator.dtype, strides=field_array.strides
-        )
-        for connector in deref_connectors[1:]:
-            deref_sdfg.add_scalar(connector, _INDEX_DTYPE)
-        deref_sdfg.add_array("_out", result_shape, iterator.dtype)
-        deref_init_state = deref_sdfg.add_state("init", True)
-        deref_access_state = deref_sdfg.add_state("access")
-        deref_sdfg.add_edge(
-            deref_init_state,
-            deref_access_state,
-            dace.InterstateEdge(
-                assignments={f"_sym{inp}": inp for inp in deref_connectors[1:]}
-            ),
-        )
-        # we access the size in source field shape as symbols set on the nested sdfg
-        source_subset = tuple(
-            f"_sym_i_{dim}" if dim in iterator.indices else f"0:{size}"
+        # we create a mapped tasklet for array slicing
+        map_ranges = {
+            f"_i_{dim}": f"0:{size}"
             for dim, size in zip(sorted_dims, field_array.shape)
+            if dim not in iterator.indices
+        }
+        src_subset = ",".join([f"_i_{dim}" for dim in sorted_dims])
+        dst_subset = ",".join(
+            [f"_i_{dim}" for dim in sorted_dims if dim not in iterator.indices]
         )
-        deref_access_state.add_nedge(
-            deref_access_state.add_access("_inp"),
-            deref_access_state.add_access("_out"),
-            dace.Memlet(
-                data="_out",
-                subset=subsets.Range.from_array(result_array),
-                other_subset=",".join(source_subset),
-            ),
-        )
-
-        deref_node = self.context.state.add_nested_sdfg(
-            deref_sdfg,
-            self.context.body,
-            inputs=set(deref_connectors),
-            outputs={"_out"},
-        )
-        for connector, node, memlet in zip(deref_connectors, deref_nodes, deref_memlets):
-            self.context.state.add_edge(node, None, deref_node, connector, memlet)
-        self.context.state.add_edge(
-            deref_node,
-            "_out",
-            result_node,
-            None,
-            dace.Memlet.from_array(result_name, result_array),
+        self.context.state.add_mapped_tasklet(
+            "deref",
+            map_ranges,
+            inputs={k: v for k, v in zip(deref_connectors, deref_memlets)},
+            outputs={
+                "_out": dace.Memlet.from_array(result_name, result_array),
+            },
+            code=f"_out[{dst_subset}] = _inp[{src_subset}]",
+            external_edges=True,
+            input_nodes={node.data: node for node in deref_nodes},
+            output_nodes={
+                result_name: result_node,
+            },
         )
         return [ValueExpr(result_node, iterator.dtype)]

From af33e21fab16fb3de13ec5721b050dada63e220c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com>
Date: Tue, 19 Dec 2023 11:31:23 +0100
Subject: [PATCH 10/13] fix[dace]: Fixed SDFG args (#1400)

Modified how the SDFG arguments are computed.
It was noticed that some transformations applied to the SDFG, especially `SDFG.apply_gpu_transformation()`, added new arguments to the SDFG.
But since a lot of functions build on the `SDFG.arg_names` member, and this member was populated before the transformation, an error occurred.
The code was therefore changed such that `SDFG.arg_names` is populated only with the arguments also known to the Fencil.
---
 .../runners/dace_iterator/__init__.py         | 17 ++++++++---------
 .../runners/dace_iterator/itir_to_sdfg.py     | 11 +++--------
 2 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index 97dd90eb54..7fd4794e57 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -207,6 +207,7 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]:
 
     neighbor_tables = filter_neighbor_tables(offset_provider)
     device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
+    sdfg_sig = sdfg.signature_arglist(with_types=False)
 
     dace_args = get_args(sdfg, args)
     dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
@@ -224,11 +225,8 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]:
         **dace_conn_strides,
         **dace_offsets,
     }
-    expected_args = {
-        key: value
-        for key, value in all_args.items()
-        if key in sdfg.signature_arglist(with_types=False)
-    }
+    expected_args = {key: all_args[key] for key in sdfg_sig}
+
     return expected_args
 
 
@@ -258,21 +256,22 @@ def build_sdfg_from_itir(
     # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force
     #  `lift_more` to `FORCE_INLINE` mode.
     lift_mode = itir_transforms.LiftMode.FORCE_INLINE
-
     arg_types = [type_translation.from_value(arg) for arg in args]
-    device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
 
     # visit ITIR and generate SDFG
     program = preprocess_program(program, offset_provider, lift_mode)
+    # TODO: According to Lex one should build the SDFG first in a general manner.
+    #       Specialisation to a particular device should happen only at the end.
     sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu)
     sdfg = sdfg_genenerator.visit(program)
     sdfg.simplify()
 
     # run DaCe auto-optimization heuristics
     if auto_optimize:
-        # TODO Investigate how symbol definitions improve autoopt transformations,
-        # in which case the cache table should take the symbols map into account.
+        # TODO: Investigate how symbol definitions improve autoopt transformations,
+        #       in which case the cache table should take the symbols map into account.
         symbols: dict[str, int] = {}
+        device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
         sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)
 
     return sdfg
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
index b3e6662623..e3b5ddf2ac 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
@@ -209,14 +209,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
         last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet)
 
         # Create the call signature for the SDFG.
-        # All arguments required by the SDFG, regardless if explicit and implicit, are added
-        # as positional arguments. In the front are all arguments to the Fencil, in that
-        # order, they are followed by the arguments created by the translation process,
-        arg_list = [str(a) for a in node.params]
-        sig_list = program_sdfg.signature_arglist(with_types=False)
-        implicit_args = set(sig_list) - set(arg_list)
-        call_params = arg_list + [ia for ia in sig_list if ia in implicit_args]
-        program_sdfg.arg_names = call_params
+        # Only the arguments required by the Fencil, i.e. `node.params`, are added as positional arguments.
+        # The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keyword-only arguments.
+        program_sdfg.arg_names = [str(a) for a in node.params]
 
         program_sdfg.validate()
         return program_sdfg

From b21dd566bcbd805279d94f36a20c5ea34a300d97 Mon Sep 17 00:00:00 2001
From: Hannes Vogt
Date: Tue, 19 Dec 2023 12:06:59 +0100
Subject: [PATCH 11/13] feat[next]: Test for local dimension in output (#1392)

Currently only supported in field view embedded.
---
 pyproject.toml                                |  1 +
 tests/next_tests/exclusion_matrices.py        |  3 +++
 .../ffront_tests/test_external_local_field.py | 19 +++++++++++++++++++
 3 files changed, 23 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 2cf4fb12e2..5d7a2f2cb6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -342,6 +342,7 @@ markers = [
     'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields',
     'uses_scan_in_field_operator: tests that require backend support for scan in field operator',
     'uses_sparse_fields: tests that require backend support for sparse fields',
+    'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields',
     'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
     'uses_tuple_args: tests that require backend support for tuple arguments',
     'uses_tuple_returns: tests that require backend support for tuple results',
diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py
index 3c42a180dd..f6d2b10a14 100644
--- a/tests/next_tests/exclusion_matrices.py
+++ b/tests/next_tests/exclusion_matrices.py
@@ -95,6 +95,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
 USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions"
 USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator"
 USES_SPARSE_FIELDS = "uses_sparse_fields"
+USES_SPARSE_FIELDS_AS_OUTPUT = "uses_sparse_fields_as_output"
 USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields"
 USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset"
 USES_TUPLE_ARGS = "uses_tuple_args"
@@ -119,6 +120,7 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
     (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE),
     (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
     (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE),
+    (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE),
 ]
 DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [
     (USES_CONSTANT_FIELDS, XFAIL, UNSUPPORTED_MESSAGE),
@@ -159,4 +161,5 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
     ProgramFormatterId.GTFN_CPP_FORMATTER: [
         (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
     ],
+    ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)],
 }
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
index 42938e2f4b..698dce2b5c 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
@@ -82,3 +82,22 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32
         out=cases.allocate(unstructured_case, testee, cases.RETURN)(),
         ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1),
     )
+
+
+@pytest.mark.uses_sparse_fields_as_output
+def test_write_local_field(unstructured_case):
+    @gtx.field_operator
+    def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]:
+        return inp(V2E)
+
+    out = unstructured_case.as_field(
+        [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].table)
+    )
+    inp = cases.allocate(unstructured_case, testee, "inp")()
+    cases.verify(
+        unstructured_case,
+        testee,
+        inp,
+        out=out,
+        ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].table],
+    )

From 100bc7fee17e9235da070e1bbf0fedd615de541f Mon Sep 17 00:00:00 2001
From: Till Ehrengruber
Date: Wed, 3 Jan 2024 12:09:23 +0100
Subject: [PATCH 12/13] Add missing grid_type argument to scan operator decorator (#1404)

---
 src/gt4py/next/ffront/decorator.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py
index 4abd8f156a..53159008f0 100644
--- a/src/gt4py/next/ffront/decorator.py
+++ b/src/gt4py/next/ffront/decorator.py
@@ -775,6 +775,7 @@ def scan_operator(
     forward: bool,
     init: core_defs.Scalar,
     backend: Optional[str],
+    grid_type: Optional[GridType],
 ) -> FieldOperator[foast.ScanOperator]:
     ...
 
 
@@ -786,6 +787,7 @@ def scan_operator(
     forward: bool,
     init: core_defs.Scalar,
     backend: Optional[str],
+    grid_type: Optional[GridType],
 ) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]:
     ...
@@ -797,6 +799,7 @@ def scan_operator(
     forward: bool = True,
     init: core_defs.Scalar = 0.0,
     backend=None,
+    grid_type: Optional[GridType] = None,
 ) -> (
     FieldOperator[foast.ScanOperator]
     | Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]
@@ -834,6 +837,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator:
         return FieldOperator.from_function(
             definition,
             backend,
+            grid_type,
             operator_node_cls=foast.ScanOperator,
             operator_attributes={"axis": axis, "forward": forward, "init": init},
         )

From 7a9489f73ddddd6aff219fc3890bed23e791a9a8 Mon Sep 17 00:00:00 2001
From: Till Ehrengruber
Date: Thu, 4 Jan 2024 00:47:33 +0100
Subject: [PATCH 13/13] Fix size check in CollapseTuple pass (#1405)

---
 src/gt4py/next/iterator/transforms/collapse_tuple.py | 3 +++
 src/gt4py/next/iterator/type_inference.py            | 6 ++++++
 2 files changed, 9 insertions(+)

diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py
index 7d710fc919..30457f2246 100644
--- a/src/gt4py/next/iterator/transforms/collapse_tuple.py
+++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py
@@ -41,6 +41,9 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t
     ):
         return UnknownLength
 
+    if not type_.dtype.has_known_length:
+        return UnknownLength
+
     return len(type_.dtype)
 
 
diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py
index 2375118cd1..68627cfd89 100644
--- a/src/gt4py/next/iterator/type_inference.py
+++ b/src/gt4py/next/iterator/type_inference.py
@@ -77,6 +77,12 @@ def __iter__(self) -> abc.Iterator[Type]:
             raise ValueError(f"Can not iterate over partially defined tuple '{self}'.")
         yield from self.others
 
+    @property
+    def has_known_length(self) -> bool:
+        return isinstance(self.others, EmptyTuple) or (
+            isinstance(self.others, Tuple) and self.others.has_known_length
+        )
+
     def __len__(self) -> int:
         return sum(1 for _ in self)
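
For illustration, the effect of `has_known_length` can be sketched as
follows (a rough, hypothetical snippet: the `Primitive`, `Tuple` and
`TypeVar` constructors from `gt4py.next.iterator.type_inference` are
assumed, only the property itself comes from this patch):

    from gt4py.next.iterator.type_inference import Primitive, Tuple, TypeVar

    int_t = Primitive(name="int")

    # fully defined tuple: the chain of `others` ends in an EmptyTuple
    complete = Tuple.from_elems(int_t, int_t)
    assert complete.has_known_length
    assert len(complete) == 2

    # partially defined tuple: the tail is still a type variable, so
    # _get_tuple_size now returns UnknownLength instead of calling len()
    partial = Tuple(front=int_t, others=TypeVar(idx=0))
    assert not partial.has_known_length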