diff --git a/src/gt4py/next/embedded/operators.py b/src/gt4py/next/embedded/operators.py index f50ace7687..0992401ebb 100644 --- a/src/gt4py/next/embedded/operators.py +++ b/src/gt4py/next/embedded/operators.py @@ -17,7 +17,7 @@ from gt4py import eve from gt4py._core import definitions as core_defs -from gt4py.next import common, constructors, utils +from gt4py.next import common, constructors, errors, utils from gt4py.next.embedded import common as embedded_common, context as embedded_context @@ -77,17 +77,20 @@ def scan_loop(hpos): def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): if "out" in kwargs: # called from program or direct field_operator as program - offset_provider = kwargs.pop("offset_provider", None) - new_context_kwargs = {} if embedded_context.within_context(): # called from program - assert offset_provider is None + assert "offset_provider" not in kwargs else: # field_operator as program + if "offset_provider" not in kwargs: + raise errors.MissingArgumentError(None, "offset_provider", True) + offset_provider = kwargs.pop("offset_provider", None) + new_context_kwargs["offset_provider"] = offset_provider out = kwargs.pop("out") + domain = kwargs.pop("domain", None) flattened_out: tuple[common.Field, ...] = utils.flatten_nested_tuple((out,)) @@ -105,7 +108,10 @@ def field_operator_call(op: EmbeddedOperator, args: Any, kwargs: Any): domain=out_domain, ) else: - # called from other field_operator + # called from other field_operator or missing `out` argument + if "offset_provider" in kwargs: + # assuming we wanted to call the field_operator as program, otherwise `offset_provider` would not be there + raise errors.MissingArgumentError(None, "out", True) return op(*args, **kwargs) diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 61441e83b9..dd48d6f0f9 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -21,6 +21,7 @@ from .exceptions import ( DSLError, InvalidParameterAnnotationError, + MissingArgumentError, MissingAttributeError, MissingParameterAnnotationError, UndefinedSymbolError, @@ -33,6 +34,7 @@ "InvalidParameterAnnotationError", "MissingAttributeError", "MissingParameterAnnotationError", + "MissingArgumentError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", "set_verbose_exceptions", diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 081453c023..858f969447 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -81,6 +81,18 @@ def __init__(self, location: Optional[SourceLocation], attr_name: str) -> None: self.attr_name = attr_name +class MissingArgumentError(DSLError): + arg_name: str + is_kwarg: bool + + def __init__(self, location: Optional[SourceLocation], arg_name: str, is_kwarg: bool) -> None: + super().__init__( + location, f"Expected {'keyword-' if is_kwarg else ''}argument '{arg_name}'." + ) + self.attr_name = arg_name + self.is_kwarg = is_kwarg + + class TypeError_(DSLError): def __init__(self, location: Optional[SourceLocation], message: str) -> None: super().__init__(location, message) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 53159008f0..76a0ddcde0 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -29,10 +29,11 @@ from devtools import debug +from gt4py import eve from gt4py._core import definitions as core_defs from gt4py.eve import utils as eve_utils from gt4py.eve.extended_typing import Any, Optional -from gt4py.next import allocators as next_allocators, embedded as next_embedded +from gt4py.next import allocators as next_allocators, embedded as next_embedded, errors from gt4py.next.common import Dimension, DimensionKind, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( @@ -61,11 +62,10 @@ sym, ) from gt4py.next.program_processors import processor_interface as ppi -from gt4py.next.program_processors.runners import roundtrip from gt4py.next.type_system import type_info, type_specifications as ts, type_translation -DEFAULT_BACKEND: Callable = roundtrip.executor +DEFAULT_BACKEND: Callable = None def _get_closure_vars_recursively(closure_vars: dict[str, Any]) -> dict[str, Any]: @@ -176,15 +176,15 @@ class Program: past_node: past.Program closure_vars: dict[str, Any] - definition: Optional[types.FunctionType] = None - backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND - grid_type: Optional[GridType] = None + definition: Optional[types.FunctionType] + backend: Optional[ppi.ProgramExecutor] + grid_type: Optional[GridType] @classmethod def from_function( cls, definition: types.FunctionType, - backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND, + backend: Optional[ppi.ProgramExecutor], grid_type: Optional[GridType] = None, ) -> Program: source_def = SourceDefinition.from_function(definition) @@ -495,7 +495,7 @@ def program(*, backend: Optional[ppi.ProgramExecutor]) -> Callable[[types.Functi def program( definition=None, *, - backend=None, + backend=eve.NOTHING, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) grid_type=None, ) -> Program | Callable[[types.FunctionType], Program]: """ @@ -517,7 +517,9 @@ def program( """ def program_inner(definition: types.FunctionType) -> Program: - return Program.from_function(definition, backend, grid_type) + return Program.from_function( + definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type + ) return program_inner if definition is None else program_inner(definition) @@ -549,9 +551,9 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): foast_node: OperatorNodeT closure_vars: dict[str, Any] - definition: Optional[types.FunctionType] = None - backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND - grid_type: Optional[GridType] = None + definition: Optional[types.FunctionType] + backend: Optional[ppi.ProgramExecutor] + grid_type: Optional[GridType] operator_attributes: Optional[dict[str, Any]] = None _program_cache: dict = dataclasses.field(default_factory=dict) @@ -559,7 +561,7 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]): def from_function( cls, definition: types.FunctionType, - backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND, + backend: Optional[ppi.ProgramExecutor], grid_type: Optional[GridType] = None, *, operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, @@ -686,6 +688,7 @@ def as_program( self._program_cache[hash_] = Program( past_node=past_node, closure_vars=closure_vars, + definition=None, backend=self.backend, grid_type=self.grid_type, ) @@ -698,7 +701,12 @@ def __call__( ) -> None: if not next_embedded.context.within_context() and self.backend is not None: # non embedded execution - offset_provider = kwargs.pop("offset_provider", None) + if "offset_provider" not in kwargs: + raise errors.MissingArgumentError(None, "offset_provider", True) + offset_provider = kwargs.pop("offset_provider") + + if "out" not in kwargs: + raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs) # TODO(tehrengruber): check all offset providers are given @@ -744,7 +752,7 @@ def field_operator( ... -def field_operator(definition=None, *, backend=None, grid_type=None): +def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None): """ Generate an implementation of the field operator from a Python function object. @@ -762,7 +770,9 @@ def field_operator(definition=None, *, backend=None, grid_type=None): """ def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast.FieldOperator]: - return FieldOperator.from_function(definition, backend, grid_type) + return FieldOperator.from_function( + definition, DEFAULT_BACKEND if backend is eve.NOTHING else backend, grid_type + ) return field_operator_inner if definition is None else field_operator_inner(definition) @@ -798,7 +808,7 @@ def scan_operator( axis: Dimension, forward: bool = True, init: core_defs.Scalar = 0.0, - backend=None, + backend=eve.NOTHING, grid_type: GridType = None, ) -> ( FieldOperator[foast.ScanOperator] @@ -836,8 +846,7 @@ def scan_operator( def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: return FieldOperator.from_function( definition, - backend, - grid_type, + DEFAULT_BACKEND if backend is eve.NOTHING else backend, operator_node_cls=foast.ScanOperator, operator_attributes={"axis": axis, "forward": forward, "init": init}, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index e25576ebde..1f5a1f0c48 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -22,6 +22,8 @@ import gt4py.next as gtx from gt4py.next.ffront import decorator from gt4py.next.iterator import ir as itir +from gt4py.next.program_processors import processor_interface as ppi +from gt4py.next.program_processors.runners import gtfn, roundtrip try: @@ -36,9 +38,10 @@ import next_tests.exclusion_matrices as definitions +@ppi.program_executor def no_backend(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> None: """Temporary default backend to not accidentally test the wrong backend.""" - raise ValueError("No backend selected. Backend selection is mandatory in tests.") + raise ValueError("No backend selected! Backend selection is mandatory in tests.") OPTIONAL_PROCESSORS = [] diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py index 167ccbb0a5..4444742c66 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_builtin_execution.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import math -from typing import Callable +from typing import Callable, Optional import numpy as np import pytest @@ -22,6 +22,7 @@ from gt4py.next.ffront import dialect_ast_enums, fbuiltins, field_operator_ast as foast from gt4py.next.ffront.decorator import FieldOperator from gt4py.next.ffront.foast_passes.type_deduction import FieldOperatorTypeDeduction +from gt4py.next.program_processors import processor_interface as ppi from gt4py.next.type_system import type_translation from next_tests.integration_tests import cases @@ -39,7 +40,7 @@ # becomes easier. -def make_builtin_field_operator(builtin_name: str): +def make_builtin_field_operator(builtin_name: str, backend: Optional[ppi.ProgramExecutor]): # TODO(tehrengruber): creating a field operator programmatically should be # easier than what we need to do here. # construct annotation dictionary containing the input argument and return @@ -109,8 +110,9 @@ def make_builtin_field_operator(builtin_name: str): return FieldOperator( foast_node=typed_foast_node, closure_vars=closure_vars, - backend=None, definition=None, + backend=backend, + grid_type=None, ) @@ -129,9 +131,7 @@ def test_math_function_builtins_execution(cartesian_case, builtin_name: str, inp expected = ref_impl(*inputs) out = cartesian_case.as_field([IDim], np.zeros_like(expected)) - builtin_field_op = make_builtin_field_operator(builtin_name).with_backend( - cartesian_case.backend - ) + builtin_field_op = make_builtin_field_operator(builtin_name, cartesian_case.backend) builtin_field_op(*inps, out=out, offset_provider={}) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py new file mode 100644 index 0000000000..ba4b1b0cdb --- /dev/null +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_embedded_regression.py @@ -0,0 +1,137 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np +import pytest + +from gt4py import next as gtx +from gt4py.next import errors + +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import IField, cartesian_case # noqa: F401 # fixtures +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( # noqa: F401 # fixtures + KDim, + fieldview_backend, +) + + +def test_default_backend_is_respected_field_operator(cartesian_case): # noqa: F811 # fixtures + """Test that manually calling the field operator without setting the backend raises an error.""" + + # Important not to set the backend here! + @gtx.field_operator + def copy(a: IField) -> IField: + return a + + a = cases.allocate(cartesian_case, copy, "a")() + + with pytest.raises(ValueError, match="No backend selected!"): + # Calling this should fail if the default backend is respected + # due to `fieldview_backend` fixture (dependency of `cartesian_case`) + # setting the default backend to something invalid. + _ = copy(a, out=a, offset_provider={}) + + +def test_default_backend_is_respected_scan_operator(cartesian_case): # noqa: F811 # fixtures + """Test that manually calling the scan operator without setting the backend raises an error.""" + + # Important not to set the backend here! + @gtx.scan_operator(axis=KDim, init=0.0, forward=True) + def sum(state: float, a: float) -> float: + return state + a + + a = gtx.ones({KDim: 10}, allocator=cartesian_case.backend) + + with pytest.raises(ValueError, match="No backend selected!"): + # see comment in field_operator test + _ = sum(a, out=a, offset_provider={}) + + +def test_default_backend_is_respected_program(cartesian_case): # noqa: F811 # fixtures + """Test that manually calling the program without setting the backend raises an error.""" + + @gtx.field_operator + def copy(a: IField) -> IField: + return a + + # Important not to set the backend here! + @gtx.program + def copy_program(a: IField, b: IField) -> IField: + copy(a, out=b) + + a = cases.allocate(cartesian_case, copy_program, "a")() + b = cases.allocate(cartesian_case, copy_program, "b")() + + with pytest.raises(ValueError, match="No backend selected!"): + # see comment in field_operator test + _ = copy_program(a, b, offset_provider={}) + + +def test_missing_arg_field_operator(cartesian_case): # noqa: F811 # fixtures + """Test that calling a field_operator without required args raises an error.""" + + @gtx.field_operator(backend=cartesian_case.backend) + def copy(a: IField) -> IField: + return a + + a = cases.allocate(cartesian_case, copy, "a")() + + with pytest.raises(errors.MissingArgumentError, match="'out'"): + _ = copy(a, offset_provider={}) + + with pytest.raises(errors.MissingArgumentError, match="'offset_provider'"): + _ = copy(a, out=a) + + +def test_missing_arg_scan_operator(cartesian_case): # noqa: F811 # fixtures + """Test that calling a scan_operator without required args raises an error.""" + + @gtx.scan_operator(backend=cartesian_case.backend, axis=KDim, init=0.0, forward=True) + def sum(state: float, a: float) -> float: + return state + a + + a = cases.allocate(cartesian_case, sum, "a")() + + with pytest.raises(errors.MissingArgumentError, match="'out'"): + _ = sum(a, offset_provider={}) + + with pytest.raises(errors.MissingArgumentError, match="'offset_provider'"): + _ = sum(a, out=a) + + +def test_missing_arg_program(cartesian_case): # noqa: F811 # fixtures + """Test that calling a program without required args raises an error.""" + + @gtx.field_operator + def copy(a: IField) -> IField: + return a + + a = cases.allocate(cartesian_case, copy, "a")() + b = cases.allocate(cartesian_case, copy, cases.RETURN)() + + with pytest.raises(errors.DSLError, match="Invalid call"): + + @gtx.program(backend=cartesian_case.backend) + def copy_program(a: IField, b: IField) -> IField: + copy(a) + + _ = copy_program(a, offset_provider={}) + + with pytest.raises(TypeError, match="'offset_provider'"): + + @gtx.program(backend=cartesian_case.backend) + def copy_program(a: IField, b: IField) -> IField: + copy(a, out=b) + + _ = copy_program(a)