diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 4abd8f156a..53159008f0 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -775,6 +775,7 @@ def scan_operator( forward: bool, init: core_defs.Scalar, backend: Optional[str], + grid_type: GridType, ) -> FieldOperator[foast.ScanOperator]: ... @@ -786,6 +787,7 @@ def scan_operator( forward: bool, init: core_defs.Scalar, backend: Optional[str], + grid_type: GridType, ) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... @@ -797,6 +799,7 @@ def scan_operator( forward: bool = True, init: core_defs.Scalar = 0.0, backend=None, + grid_type: GridType = None, ) -> ( FieldOperator[foast.ScanOperator] | Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]] @@ -834,6 +837,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: return FieldOperator.from_function( definition, backend, + grid_type, operator_node_cls=foast.ScanOperator, operator_attributes={"axis": axis, "forward": forward, "init": init}, )