Skip to content

Commit

Permalink
converto scalar to array on nsdfg output
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Dec 10, 2024
1 parent 9bdc75b commit 746f9d8
Showing 1 changed file with 4 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,10 @@ def apply(
else:
assert isinstance(inner_desc, dace.data.Scalar)
assert len(new_strides) == 0
# we convert the scalar data to array to avoid a gpu codegen error
nsdfg_node.sdfg.arrays[inner_data] = dace.data.Array(
inner_desc.dtype, (1,), inner_desc.transient
)
for stride in new_strides:
for sym in stride.free_symbols:
nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype)
Expand Down

0 comments on commit 746f9d8

Please sign in to comment.