diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index e489f130db..48c666a363 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -399,6 +399,14 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG: sdfg = dace.SDFG(node.id) sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) + + # DaCe requires C-compatible strings for the names of data containers, + # such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name + # separator in the SSA pass, which generates invalid symbols for DaCe. + # Here we find new names for invalid symbols present in the IR. + node = dace_gtir_utils.replace_invalid_symbols(sdfg, node) + + # start block of the stateful graph entry_state = sdfg.add_state("program_entry", is_start_block=True) # declarations of temporaries result in transient array definitions in the SDFG diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index 355eaac903..baae8a6ccd 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -9,9 +9,13 @@ from __future__ import annotations import itertools -from typing import Any +from typing import Any, Dict, TypeVar +import dace + +from gt4py import eve from gt4py.next import common as gtx_common +from gt4py.next.iterator import ir as gtir from gt4py.next.type_system import type_specifications as ts @@ -66,3 +70,40 @@ def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: return ts.TupleType( types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_dtype for d in data] ) + + +def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program: + """ + Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings). + + If any invalid symbol present, this funtion returns a copy of the input IR where + the invalid symbols have been replaced with new names. If all symbols are valid, + the input IR is returned without copying it. + """ + + class ReplaceSymbols(eve.PreserveLocationVisitor, eve.NodeTranslator): + T = TypeVar("T", gtir.Sym, gtir.SymRef) + + def _replace_sym(self, node: T, symtable: Dict[str, str]) -> T: + sym = str(node.id) + return type(node)(id=symtable.get(sym, sym), type=node.type) + + def visit_Sym(self, node: gtir.Sym, *, symtable: Dict[str, str]) -> gtir.Sym: + return self._replace_sym(node, symtable) + + def visit_SymRef(self, node: gtir.SymRef, *, symtable: Dict[str, str]) -> gtir.SymRef: + return self._replace_sym(node, symtable) + + # program arguments are checked separetely, because they cannot be replaced + if not all(dace.dtypes.validate_name(str(sym.id)) for sym in ir.params): + raise ValueError("Invalid symbol in program parameters.") + + invalid_symbols_mapping = { + sym_id: sdfg.temp_data_name() + for sym in eve.walk_values(ir).if_isinstance(gtir.Sym).to_set() + if not dace.dtypes.validate_name(sym_id := str(sym.id)) + } + if len(invalid_symbols_mapping) != 0: + return ReplaceSymbols().visit(ir, symtable=invalid_symbols_mapping) + else: + return ir diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index 41f540d3cf..9f5498b4a7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -1629,20 +1629,20 @@ def test_gtir_let_lambda(): declarations=[], body=[ gtir.SetAt( - # `x1` is a let-lambda expression representing `x * 3` - # `x2` is a let-lambda expression representing `x * 4` - # - note that the let-symbol `x2` is used twice, in a nested let-expression, to test aliasing of the symbol - # `x3` is a let-lambda expression simply accessing `x` field symref - expr=im.let("x1", im.op_as_fieldop("multiplies", subdomain)(3.0, "x"))( + # `xᐞ1` is a let-lambda expression representing `x * 3` + # `xᐞ2` is a let-lambda expression representing `x * 4` + # - note that the let-symbol `xᐞ2` is used twice, in a nested let-expression, to test aliasing of the symbol + # `xᐞ3` is a let-lambda expression simply accessing `x` field symref + expr=im.let("xᐞ1", im.op_as_fieldop("multiplies", subdomain)(3.0, "x"))( im.let( - "x2", - im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( - im.op_as_fieldop("plus", subdomain)("x2", "x2") + "xᐞ2", + im.let("xᐞ2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( + im.op_as_fieldop("plus", subdomain)("xᐞ2", "xᐞ2") ), )( - im.let("x3", "x")( + im.let("xᐞ3", "x")( im.op_as_fieldop("plus", subdomain)( - "x1", im.op_as_fieldop("plus", subdomain)("x2", "x3") + "xᐞ1", im.op_as_fieldop("plus", subdomain)("xᐞ2", "xᐞ3") ) ) )