diff --git a/dace/libraries/blas/nodes/gemm.py b/dace/libraries/blas/nodes/gemm.py index 90dbe4dc38..028d24dbba 100644 --- a/dace/libraries/blas/nodes/gemm.py +++ b/dace/libraries/blas/nodes/gemm.py @@ -17,8 +17,9 @@ def _is_complex(dtype): if hasattr(dtype, "is_complex") and callable(dtype.is_complex): return dtype.is_complex() - else: - return dtype in [dtypes.complex64, dtypes.complex128] + if not isinstance(dtype, dtypes.typeclass): + dtype = dace.dtype_to_typeclass(dtype) + return dtype in [dtypes.complex64, dtypes.complex128] def _cast_to_dtype_str(value, dtype: dace.dtypes.typeclass) -> str: