Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL][NFCI] Finalize switch to SPV_KHR_cooperative_matrix #16045

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 1 addition & 76 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,34 +350,6 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) {
return ResultType;
}

template <bool NeedTypeInterpret = false>
llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
ArrayRef<TemplateArgument> TemplateArgs,
const unsigned Val = 0) {
// TODO: we should actually have exactly 5 template parameters: 1 for
// type and 4 for type parameters. But in previous version of the SPIR-V
// spec we have Layout matrix type parameter, that was later removed.
// Once we update to the newest version of the spec - this should be updated.
assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) &&
"Wrong JointMatrixINTEL template parameters number");
// This is required to represent optional 'Component Type Interpretation'
// parameter
std::vector<unsigned> Params;
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
"Wrong JointMatrixINTEL template parameter");
Params.push_back(TemplateArgs[I].getAsIntegral().getExtValue());
}
// Don't add type interpretation for legacy matrices.
// Legacy matrices has 5 template parameters, while new representation
// has 6.
if (NeedTypeInterpret && TemplateArgs.size() != 5)
Params.push_back(Val);

return llvm::TargetExtType::get(CompTy->getContext(),
"spirv.JointMatrixINTEL", {CompTy}, Params);
}

llvm::Type *
getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
ArrayRef<TemplateArgument> TemplateArgs) {
Expand All @@ -394,49 +366,6 @@ getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
CompTy->getContext(), "spirv.CooperativeMatrixKHR", {CompTy}, Params);
}

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
/// The expected representation is:
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
"1st JointMatrixINTEL template parameter must be type");
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());

// Per JointMatrixINTEL spec the type can have an optional
// 'Component Type Interpretation' parameter. We should emit it in case
// if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as
// matrix's components. Yet 'bfloat16' should be represented as 'int16' and
// 'tf32' as 'float' types.
if (CompTy->isStructTy()) {
StringRef LlvmTyName = CompTy->getStructName();
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.starts_with("class.sycl::") ||
LlvmTyName.starts_with("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName == "half") {
CompTy = llvm::Type::getHalfTy(getLLVMContext());
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
} else if (LlvmTyName == "tf32") {
CompTy = llvm::Type::getFloatTy(getLLVMContext());
// 'tf32' interpretation is mapped to '0'
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0);
} else if (LlvmTyName == "bfloat16") {
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
// 'bfloat16' interpretation is mapped to '1'
return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1);
} else {
llvm_unreachable("Wrong matrix base type!");
}
}
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
}

/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
Expand Down Expand Up @@ -733,11 +662,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
if (ClangETy && ClangETy->isStructureOrClassType()) {
RecordDecl *RD = ClangETy->getAsCXXRecordDecl();
if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_JointMatrixINTEL") {
ResultType = ConvertSYCLJointMatrixINTELType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_CooperativeMatrixKHR") {
"__spv::__spirv_CooperativeMatrixKHR") {
ResultType = ConvertSPVCooperativeMatrixType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
Expand Down
8 changes: 0 additions & 8 deletions clang/lib/CodeGen/CodeGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,6 @@ class CodeGenTypes {
/// load/store type are the same.
llvm::Type *convertTypeForLoadStore(QualType T, llvm::Type *LLVMTy = nullptr);

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
/// The expected representation is:
/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%,
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);

/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
Expand Down
41 changes: 0 additions & 41 deletions clang/test/CodeGenSYCL/joint_matrix.cpp

This file was deleted.

150 changes: 0 additions & 150 deletions sycl/include/sycl/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,155 +27,6 @@

extern __DPCPP_SYCL_EXTERNAL float __spirv_RoundFToTF32INTEL(float a);

#ifndef __SPIRV_USE_COOPERATIVE_MATRIX
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_JointMatrixLoadINTEL(T *Ptr, std::size_t Stride,
__spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
std::size_t Stride, __spv::MatrixLayout Layout = L,
__spv::Scope::Flag Sc = S, int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_CooperativeMatrixConstructCheckedINTEL(int32_t CoordX,
int32_t CoordY,
uint32_t Height,
uint32_t Width,
const T Value);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_CooperativeMatrixLoadCheckedINTEL(
T *Ptr, int32_t CoordX, int32_t CoordY, __spv::MatrixLayout Layout = L,
uint32_t Height = 0, uint32_t Width = 0, std::size_t Stride = 0,
int MemOperand = 0);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
T *Ptr, int32_t CoordX, int32_t CoordY,
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
__spv::MatrixLayout Layout = L, uint32_t Height = 0, uint32_t Width = 0,
std::size_t Stride = 0, int MemOperand = 0);

template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<TC, M, N, LC, S, UC> *
__spirv_JointMatrixMadINTEL(
__spv::__spirv_JointMatrixINTEL<TA, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<TB, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<TC, M, N, LC, S, UC> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
__spirv_JointMatrixUUMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
__spirv_JointMatrixUSMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,
__spv::MatrixUse UC,
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *
__spirv_JointMatrixSUMadINTEL(
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S, UA> *A,
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S, UB> *B,
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S, UC> *C,
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);

template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
__spirv_CompositeConstruct(const T v);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *);

template <typename Ts, typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL Ts __spirv_VectorExtractDynamic(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);

template <typename Ts, typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_VectorInsertDynamic(__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *,
Ts val, size_t i);
#else // __SPIRV_USE_COOPERATIVE_MATRIX
template <typename T, typename Tp, std::size_t R, std::size_t C,
__spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
Expand Down Expand Up @@ -304,7 +155,6 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
__spv::__spirv_CooperativeMatrixKHR<Tp, S, R, C, U> *Object,
__spv::MatrixLayout Layout = L, uint32_t Height = 0, uint32_t Width = 0,
std::size_t Stride = 0, int MemOperand = 0);
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

template <typename T>
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixPrefetchINTEL(
Expand Down
10 changes: 0 additions & 10 deletions sycl/include/sycl/__spirv/spirv_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ enum class MatrixLayout : uint32_t {

enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };

#ifdef __SPIRV_USE_COOPERATIVE_MATRIX
enum class MatrixOperands : uint32_t {
// SPV_KHR_cooperative_matrix operands
NoneKHR = 0,
Expand All @@ -133,19 +132,10 @@ enum class MatrixOperands : uint32_t {
MatrixCBFloat16ComponentsINTEL = 0x80,
MatrixResultBFloat16ComponentsINTEL = 0x100
};
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

#ifndef __SPIRV_USE_COOPERATIVE_MATRIX

template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
Scope::Flag S = Scope::Flag::Subgroup,
MatrixUse U = MatrixUse::MatrixA>
struct __spirv_JointMatrixINTEL;
#else
template <typename T, Scope::Flag S = Scope::Flag::Subgroup, std::size_t R = 1,
std::size_t C = 1, MatrixUse U = MatrixUse::MatrixA>
struct __spirv_CooperativeMatrixKHR;
#endif // __SPIRV_USE_COOPERATIVE_MATRIX

struct __spirv_TaskSequenceINTEL;

Expand Down
Loading
Loading