diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index aa83aed163..524091cbe3 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import TYPE_CHECKING, Final, Iterable, Optional, Protocol, Sequence, TypeAlias +from typing import TYPE_CHECKING, Any, Final, Iterable, Optional, Protocol, Sequence, TypeAlias import dace from dace import subsets as sbs @@ -219,17 +219,26 @@ def _parse_fieldop_arg( ) -> ( gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr - | tuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr, ...] + | gtir_dataflow.ValueExpr + | tuple[ + gtir_dataflow.IteratorExpr + | gtir_dataflow.MemletExpr + | gtir_dataflow.ValueExpr + | tuple[Any, ...], + ..., + ] ): """Helper method to visit an expression passed as argument to a field operator.""" arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) - def get_arg_value(arg: FieldopData) -> gtir_dataflow.MemletExpr | gtir_dataflow.IteratorExpr: - # In case of scan field operator, the arguments to the vertical stencil are passed by value. + def get_arg_value( + arg: FieldopData, + ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: arg_expr = arg.get_local_view(domain) if not by_value or isinstance(arg_expr, gtir_dataflow.MemletExpr): return arg_expr + # In case of scan field operator, the arguments to the vertical stencil are passed by value. return gtir_dataflow.MemletExpr( arg_expr.field, arg_expr.gt_dtype, arg_expr.get_memlet_subset(sdfg) ) @@ -277,7 +286,8 @@ def _create_field_operator( node_type: ts.FieldType | ts.TupleType, sdfg_builder: gtir_sdfg.SDFGBuilder, input_edges: Iterable[gtir_dataflow.DataflowInputEdge], - output_edges: gtir_dataflow.DataflowOutputEdge | tuple[gtir_dataflow.DataflowOutputEdge, ...], + output_edges: gtir_dataflow.DataflowOutputEdge + | tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], scan_dim: Optional[gtx_common.Dimension] = None, ) -> FieldopResult: """ @@ -446,16 +456,18 @@ def translate_as_fieldop( if cpm.is_call_to(fieldop_expr, "scan"): return translate_scan(node, sdfg, state, sdfg_builder) - elif isinstance(fieldop_expr, gtir.Lambda): - # Default case, handled below: the argument expression is a lambda function - # representing the stencil operation to be computed over the field domain. - stencil_expr = fieldop_expr - elif cpm.is_ref_to(fieldop_expr, "deref"): + + assert isinstance(node.type, ts.FieldType) + if cpm.is_ref_to(fieldop_expr, "deref"): # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. stencil_expr = im.lambda_("a")(im.deref("a")) stencil_expr.expr.type = node.type.dtype + elif isinstance(fieldop_expr, gtir.Lambda): + # Default case, handled below: the argument expression is a lambda function + # representing the stencil operation to be computed over the field domain. + stencil_expr = fieldop_expr else: raise NotImplementedError( f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node." @@ -835,6 +847,7 @@ def translate_scan( ) -> FieldopResult: assert isinstance(node, gtir.FunCall) assert cpm.is_call_to(node.fun, "as_fieldop") + assert isinstance(node.type, (ts.FieldType, ts.TupleType)) fun_node = node.fun assert len(fun_node.args) == 2 @@ -886,7 +899,9 @@ def scan_output_name(input_name: str) -> str: # create list of params to the lambda function with associated node type lambda_symbols = {scan_state: scan_state_type} | { - str(p.id): arg.type for p, arg in zip(stencil_expr.params[1:], node.args, strict=True) + str(p.id): arg.type + for p, arg in zip(stencil_expr.params[1:], node.args, strict=True) + if isinstance(arg.type, ts.DataType) } # visit the arguments to be passed to the lambda expression @@ -900,8 +915,8 @@ def scan_output_name(input_name: str) -> str: } # parse the dataflow input and output symbols - lambda_flat_args = {} - lambda_field_offsets = {} + lambda_flat_args: dict[str, FieldopData] = {} + lambda_field_offsets: dict[str, Optional[list[dace.symbolic.SymExpr]]] = {} for param, arg in lambda_args_mapping.items(): tuple_fields = flatten_tuples(param, arg) lambda_field_offsets |= {tsym: tfield.offset for tsym, tfield in tuple_fields} @@ -986,9 +1001,10 @@ def init_scan_state(sym: gtir.Sym) -> None: nsdfg.make_array_memlet(input_state), ) - init_scan_state(scan_state_input) if isinstance( - scan_state_input, FieldopData - ) else gtx_utils.tree_map(init_scan_state)(scan_state_input) + if isinstance(scan_state_input, tuple): + gtx_utils.tree_map(init_scan_state)(scan_state_input) + else: + init_scan_state(scan_state_input) # connect the dataflow input directly to the source data nodes, without passing through a map node; # the reason is that the map for horizontal domain is outside the scan loop region @@ -1000,8 +1016,8 @@ def connect_scan_output( scan_output_edge: gtir_dataflow.DataflowOutputEdge, sym: gtir.Sym ) -> FieldopData: scan_result = scan_output_edge.result - assert isinstance(scan_result, gtir_dataflow.ValueExpr) - assert isinstance(sym.type, ts.ScalarType) and scan_result.gt_dtype == sym.type + assert isinstance(scan_result.gt_dtype, ts.ScalarType) + assert scan_result.gt_dtype == sym.type scan_result_data = scan_result.dc_node.data scan_result_desc = scan_result.dc_node.desc(nsdfg) @@ -1023,12 +1039,17 @@ def connect_scan_output( output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype) return FieldopData(output_node, output_type, scan_output_offset) - if isinstance(scan_state_input, gtir.Sym): - assert isinstance(result, gtir_dataflow.DataflowOutputEdge) - lambda_output = connect_scan_output(result, scan_state_input) - else: - assert isinstance(result, tuple) - lambda_output = gtx_utils.tree_map(connect_scan_output)(result, scan_state_input) + lambda_output = ( + gtx_utils.tree_map(connect_scan_output)(result, scan_state_input) + if (isinstance(result, tuple) and isinstance(scan_state_input, tuple)) + else connect_scan_output(result, scan_state_input) + if ( + isinstance(result, gtir_dataflow.DataflowOutputEdge) + and isinstance(scan_state_input, gtir.Sym) + ) + else None + ) + assert lambda_output # in case tuples are passed as argument, isolated non-transient nodes might be left in the state, # because not all tuple fields are necessarily accessed in the lambda scope @@ -1095,15 +1116,14 @@ def construct_output_edge(scan_data: FieldopData) -> gtir_dataflow.DataflowOutpu None, dace.Memlet.from_array(output_data, output_desc), ) - output_expr = gtir_dataflow.MemletExpr( - output_node, scan_data.gt_type.dtype, sbs.Range.from_array(output_desc) - ) + output_expr = gtir_dataflow.ValueExpr(output_node, scan_data.gt_type.dtype) return gtir_dataflow.DataflowOutputEdge(state, output_expr) - if isinstance(lambda_output, FieldopData): - output_edges = construct_output_edge(lambda_output) - else: - output_edges = gtx_utils.tree_map(construct_output_edge)(lambda_output) + output_edges = ( + construct_output_edge(lambda_output) + if isinstance(lambda_output, FieldopData) + else gtx_utils.tree_map(construct_output_edge)(lambda_output) + ) return _create_field_operator( sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges, scan_dim diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 5767a86c42..fa54942049 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -10,7 +10,7 @@ import abc import dataclasses -from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, TypeVar, Union +from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union import dace from dace import subsets as sbs @@ -290,7 +290,10 @@ class LambdaToDataflow(eve.NodeVisitor): state: dace.SDFGState subgraph_builder: gtir_sdfg.DataflowBuilder input_edges: list[DataflowInputEdge] - symbol_map: dict[str, tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]] + symbol_map: dict[ + str, + IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...], + ] def __init__( self, @@ -556,7 +559,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") - def _visit_if(self, node: gtir.FunCall) -> ValueExpr | list[ValueExpr]: + def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: assert len(node.args) == 3 # TODO(edopao): enable once DaCe supports it in next release @@ -564,9 +567,12 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | list[ValueExpr]: condition_value = self.visit(node.args[0]) assert ( - isinstance(condition_value, DataExpr) - and isinstance(condition_value.gt_dtype, ts.ScalarType) - and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL + ( + isinstance(condition_value.gt_dtype, ts.ScalarType) + and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL + ) + if isinstance(condition_value, (MemletExpr, ValueExpr)) + else (condition_value.dc_dtype == dace.dtypes.bool) ) nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt")) @@ -612,19 +618,20 @@ def _visit_if(self, node: gtir.FunCall) -> ValueExpr | list[ValueExpr]: def visit_branch( state: dace.SDFGState, expr: gtir.Expr - ) -> tuple[list[DataflowInputEdge], tuple[DataflowOutputEdge]]: + ) -> tuple[ + list[DataflowInputEdge], + DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], + ]: assert state in nsdfg.states() - T = TypeVar("T", IteratorExpr, MemletExpr, ValueExpr) - - def visit_arg(arg: T) -> T: + def visit_arg(arg: IteratorExpr | DataExpr) -> IteratorExpr | ValueExpr: if isinstance(arg, IteratorExpr): arg_node = arg.field arg_desc = arg_node.desc(self.sdfg) arg_subset = sbs.Range.from_array(arg_desc) else: - assert isinstance(arg, (MemletExpr | ValueExpr)) + assert isinstance(arg, (MemletExpr, ValueExpr)) arg_node = arg.dc_node if isinstance(arg, MemletExpr): assert set(arg.subset.size()) == {1} @@ -659,10 +666,16 @@ def visit_arg(arg: T) -> T: if isinstance(arg, IteratorExpr): return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) else: + assert isinstance(inner_desc, dace.data.Scalar) return ValueExpr(inner_node, arg.gt_dtype) lambda_params = [] - lambda_args = [] + lambda_args: list[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] + ] = [] for p in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): arg = self.symbol_map[p] if isinstance(arg, tuple): @@ -705,9 +718,9 @@ def construct_output( if isinstance(out_edge, tuple): assert isinstance(node.type, ts.TupleType) out_symbol = dace_gtir_utils.make_symbol_tuple("__output", node.type) - outer_value = gtx_utils.tree_map(lambda x, y: construct_output(state, x, y))( - out_edge, out_symbol - ) + outer_value = gtx_utils.tree_map( + lambda x, y, output_state=state: construct_output(output_state, x, y) + )(out_edge, out_symbol) else: assert isinstance(node.type, ts.FieldType | ts.ScalarType) outer_value = construct_output(state, out_edge, im.sym("__output", node.type)) @@ -725,7 +738,7 @@ def construct_output( else: result = outer_value - outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple(result)} + outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} nsdfg_node = self.state.add_nested_sdfg( nsdfg, @@ -758,10 +771,11 @@ def connect_output(inner_value: ValueExpr) -> ValueExpr: ) return ValueExpr(output_node, inner_value.gt_dtype) - if isinstance(result, tuple): - return gtx_utils.tree_map(connect_output)(result) - else: - return connect_output(result) + return ( + gtx_utils.tree_map(connect_output)(result) + if isinstance(result, tuple) + else connect_output(result) + ) def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type, itir_ts.ListType) @@ -1489,7 +1503,7 @@ def _visit_tuple_get( def visit_FunCall( self, node: gtir.FunCall - ) -> IteratorExpr | DataExpr | tuple[DataflowOutputEdge, ...]: + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: if cpm.is_call_to(node, "deref"): return self._visit_deref(node) @@ -1515,8 +1529,7 @@ def visit_FunCall( return self._visit_shift(node) elif isinstance(node.fun, gtir.Lambda): - lambda_args = [self.visit(arg) for arg in node.args] - return self.visit_Lambda(node.fun, args=lambda_args) + raise AssertionError("Lambda node should be visited with 'apply()' method.") elif isinstance(node.fun, gtir.SymRef): return self._visit_generic_builtin(node) @@ -1525,24 +1538,10 @@ def visit_FunCall( raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") def visit_Lambda( - self, node: gtir.Lambda, args: list[tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]] - ) -> DataflowOutputEdge | tuple[DataflowOutputEdge, ...]: - # lambda arguments are mapped to symbols defined in lambda scope - prev_symbols: dict[str, Optional[tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]]] = {} - for p, arg in zip(node.params, args, strict=True): - symbol_name = str(p.id) - prev_symbols[symbol_name] = self.symbol_map.get(symbol_name, None) - self.symbol_map[symbol_name] = arg - + self, node: gtir.Lambda + ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]: result = self.visit(node.expr) - # remove locally defined lambda symbols and restore previous symbols - for symbol_name, arg in prev_symbols.items(): - if arg is None: - self.symbol_map.pop(symbol_name) - else: - self.symbol_map[symbol_name] = arg - def make_output_edge( output_expr: ValueExpr | MemletExpr | SymbolExpr, ) -> DataflowOutputEdge: @@ -1576,16 +1575,19 @@ def parse_result( return r return make_output_edge(r) - if isinstance(result, tuple): - return gtx_utils.tree_map(parse_result)(result) - else: - return parse_result(result) + return ( + gtx_utils.tree_map(parse_result)(result) + if isinstance(result, tuple) + else parse_result(result) + ) def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: dc_dtype = dace_utils.as_dace_type(node.type) return SymbolExpr(node.value, dc_dtype) - def visit_SymRef(self, node: gtir.SymRef) -> tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]: + def visit_SymRef( + self, node: gtir.SymRef + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: param = str(node.id) if param in self.symbol_map: return self.symbol_map[param] @@ -1594,7 +1596,45 @@ def visit_SymRef(self, node: gtir.SymRef) -> tuple[IteratorExpr | MemletExpr | S return SymbolExpr(param, dace.string) def apply( - self, node: gtir.Lambda, args: list[tuple[IteratorExpr | MemletExpr | SymbolExpr, ...]] - ) -> tuple[list[DataflowInputEdge], tuple[DataflowOutputEdge]]: - output_edges = self.visit_Lambda(node, args=args) - return self.input_edges, output_edges + self, + node: gtir.Lambda, + args: list[ + IteratorExpr + | MemletExpr + | ValueExpr + | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...] + ], + ) -> tuple[ + list[DataflowInputEdge], + DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...], + ]: + # lambda arguments are mapped to symbols defined in lambda scope + prev_symbols: dict[ + str, + Optional[ + IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...] + ], + ] = {} + for p, arg in zip(node.params, args, strict=True): + symbol_name = str(p.id) + prev_symbols[symbol_name] = self.symbol_map.get(symbol_name, None) + self.symbol_map[symbol_name] = arg + + if cpm.is_let(node.expr): + let_node = node.expr + let_args = [self.visit(arg) for arg in let_node.args] + assert isinstance(let_node.fun, gtir.Lambda) + input_edges, output_edges = self.apply(let_node.fun, args=let_args) + + else: + output_edges = self.visit_Lambda(node) + input_edges = self.input_edges + + # remove locally defined lambda symbols and restore previous symbols + for symbol_name, prev_value in prev_symbols.items(): + if prev_value is None: + self.symbol_map.pop(symbol_name) + else: + self.symbol_map[symbol_name] = prev_value + + return input_edges, output_edges diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 640fde6236..66058f5711 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -169,9 +169,13 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): EmbeddedIds.NUMPY_EXECUTION: EMBEDDED_SKIP_LIST, EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST - + [(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE)], # TODO(edopao): result validation fails with dace optimization + + [ + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE) + ], # TODO(edopao): result validation fails with dace optimization OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST - + [(USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE)], # TODO(edopao): result validation fails with dace optimization + + [ + (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE) + ], # TODO(edopao): result validation fails with dace optimization OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST