Skip to content

Commit

Permalink
OpenBLAS's transpose needs float and double pointers instead of std::complex.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnick83 committed Nov 13, 2023
1 parent ab39d5c commit 8eeb622
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions dace/libraries/standard/nodes/transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class ExpandTransposeOpenBLAS(ExpandTransformation):
def expansion(node, state, sdfg):
node.validate(sdfg, state)
dtype = node.dtype
cast = ""
if dtype == dace.float32:
func = "somatcopy"
alpha = "1.0f"
Expand All @@ -150,19 +151,21 @@ def expansion(node, state, sdfg):
alpha = "1.0"
elif dtype == dace.complex64:
func = "comatcopy"
alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
cast = "(float*)"
alpha = f"{cast}dace::blas::BlasConstants::Get().Complex64Pone()"
elif dtype == dace.complex128:
func = "zomatcopy"
alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
cast = "(double*)"
alpha = f"{cast}dace::blas::BlasConstants::Get().Complex128Pone()"
else:
raise ValueError("Unsupported type for OpenBLAS omatcopy extension: " + str(dtype))
# TODO: Add stride support
_, _, (m, n), _ = _get_transpose_input(node, state, sdfg)
# Adaptations for BLAS API
order = 'CblasRowMajor'
trans = 'CblasTrans'
code = ("cblas_{f}({o}, {t}, {m}, {n}, {a}, _inp, "
"{n}, _out, {m});").format(f=func, o=order, t=trans, m=m, n=n, a=alpha)
code = ("cblas_{f}({o}, {t}, {m}, {n}, {a}, {c}_inp, "
"{n}, {c}_out, {m});").format(f=func, o=order, t=trans, m=m, n=n, a=alpha, c=cast)
tasklet = dace.sdfg.nodes.Tasklet(node.name,
node.in_connectors,
node.out_connectors,
Expand Down

0 comments on commit 8eeb622

Please sign in to comment.