Skip to content

Commit

Permalink
Merge branch 'main' into index_builtin
Browse files Browse the repository at this point in the history
  • Loading branch information
SF-N authored Nov 4, 2024
2 parents f558e35 + 44d6224 commit 0c50850
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,14 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG:

sdfg = dace.SDFG(node.id)
sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo)

# DaCe requires C-compatible strings for the names of data containers,
# such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name
# separator in the SSA pass, which generates invalid symbols for DaCe.
# Here we find new names for invalid symbols present in the IR.
node = dace_gtir_utils.replace_invalid_symbols(sdfg, node)

# start block of the stateful graph
entry_state = sdfg.add_state("program_entry", is_start_block=True)

# declarations of temporaries result in transient array definitions in the SDFG
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@
from __future__ import annotations

import itertools
from typing import Any
from typing import Any, Dict, TypeVar

import dace

from gt4py import eve
from gt4py.next import common as gtx_common
from gt4py.next.iterator import ir as gtir
from gt4py.next.type_system import type_specifications as ts


Expand Down Expand Up @@ -66,3 +70,40 @@ def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType:
return ts.TupleType(
types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_dtype for d in data]
)


def replace_invalid_symbols(sdfg: dace.SDFG, ir: gtir.Program) -> gtir.Program:
"""
Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings).
If any invalid symbol present, this funtion returns a copy of the input IR where
the invalid symbols have been replaced with new names. If all symbols are valid,
the input IR is returned without copying it.
"""

class ReplaceSymbols(eve.PreserveLocationVisitor, eve.NodeTranslator):
T = TypeVar("T", gtir.Sym, gtir.SymRef)

def _replace_sym(self, node: T, symtable: Dict[str, str]) -> T:
sym = str(node.id)
return type(node)(id=symtable.get(sym, sym), type=node.type)

def visit_Sym(self, node: gtir.Sym, *, symtable: Dict[str, str]) -> gtir.Sym:
return self._replace_sym(node, symtable)

def visit_SymRef(self, node: gtir.SymRef, *, symtable: Dict[str, str]) -> gtir.SymRef:
return self._replace_sym(node, symtable)

# program arguments are checked separetely, because they cannot be replaced
if not all(dace.dtypes.validate_name(str(sym.id)) for sym in ir.params):
raise ValueError("Invalid symbol in program parameters.")

invalid_symbols_mapping = {
sym_id: sdfg.temp_data_name()
for sym in eve.walk_values(ir).if_isinstance(gtir.Sym).to_set()
if not dace.dtypes.validate_name(sym_id := str(sym.id))
}
if len(invalid_symbols_mapping) != 0:
return ReplaceSymbols().visit(ir, symtable=invalid_symbols_mapping)
else:
return ir
Original file line number Diff line number Diff line change
Expand Up @@ -1629,20 +1629,20 @@ def test_gtir_let_lambda():
declarations=[],
body=[
gtir.SetAt(
# `x1` is a let-lambda expression representing `x * 3`
# `x2` is a let-lambda expression representing `x * 4`
# - note that the let-symbol `x2` is used twice, in a nested let-expression, to test aliasing of the symbol
# `x3` is a let-lambda expression simply accessing `x` field symref
expr=im.let("x1", im.op_as_fieldop("multiplies", subdomain)(3.0, "x"))(
# `xᐞ1` is a let-lambda expression representing `x * 3`
# `xᐞ2` is a let-lambda expression representing `x * 4`
# - note that the let-symbol `xᐞ2` is used twice, in a nested let-expression, to test aliasing of the symbol
# `xᐞ3` is a let-lambda expression simply accessing `x` field symref
expr=im.let("xᐞ1", im.op_as_fieldop("multiplies", subdomain)(3.0, "x"))(
im.let(
"x2",
im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))(
im.op_as_fieldop("plus", subdomain)("x2", "x2")
"xᐞ2",
im.let("xᐞ2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))(
im.op_as_fieldop("plus", subdomain)("xᐞ2", "xᐞ2")
),
)(
im.let("x3", "x")(
im.let("xᐞ3", "x")(
im.op_as_fieldop("plus", subdomain)(
"x1", im.op_as_fieldop("plus", subdomain)("x2", "x3")
"xᐞ1", im.op_as_fieldop("plus", subdomain)("xᐞ2", "xᐞ3")
)
)
)
Expand Down

0 comments on commit 0c50850

Please sign in to comment.