Skip to content

Commit

Permalink
use SPIR-V headers for SPV_INTEL_bfloat16_conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Jan 25, 2025
1 parent cec12d6 commit eff2aef
Show file tree
Hide file tree
Showing 7 changed files with 13 additions and 20 deletions.
4 changes: 2 additions & 2 deletions lib/SPIRV/OCLToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1893,7 +1893,7 @@ void OCLToSPIRVBase::visitCallConvertBFloat16AsUshort(CallInst *CI,
}
}

mutateCallInst(CI, internal::OpConvertFToBF16INTEL);
mutateCallInst(CI, OpConvertFToBF16INTEL);
}

void OCLToSPIRVBase::visitCallConvertAsBFloat16Float(CallInst *CI,
Expand Down Expand Up @@ -1936,7 +1936,7 @@ void OCLToSPIRVBase::visitCallConvertAsBFloat16Float(CallInst *CI,
}
}

mutateCallInst(CI, internal::OpConvertBF16ToFINTEL);
mutateCallInst(CI, OpConvertBF16ToFINTEL);
}
} // namespace SPIRV

Expand Down
8 changes: 4 additions & 4 deletions lib/SPIRV/SPIRVToOCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ void SPIRVToOCLBase::visitCallInst(CallInst &CI) {
visitCallSPIRVReadClockKHR(&CI);
return;
}
if (OC == internal::OpConvertFToBF16INTEL ||
OC == internal::OpConvertBF16ToFINTEL) {
if (OC == OpConvertFToBF16INTEL ||
OC == OpConvertBF16ToFINTEL) {
visitCallSPIRVBFloat16Conversions(&CI, OC);
return;
}
Expand Down Expand Up @@ -928,10 +928,10 @@ void SPIRVToOCLBase::visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC) {
: "";
std::string Name;
switch (static_cast<uint32_t>(OC)) {
case internal::OpConvertFToBF16INTEL:
case OpConvertFToBF16INTEL:
Name = "intel_convert_bfloat16" + N + "_as_ushort" + N;
break;
case internal::OpConvertBF16ToFINTEL:
case OpConvertBF16ToFINTEL:
Name = "intel_convert_as_bfloat16" + N + "_float" + N;
break;
default:
Expand Down
8 changes: 4 additions & 4 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3558,9 +3558,9 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
SPIRVCapVec getRequiredCapability() const override {
SPIRVType *ResCompTy = this->getType();
if (ResCompTy->isTypeCooperativeMatrixKHR())
return getVec(internal::CapabilityBfloat16ConversionINTEL,
return getVec(CapabilityBFloat16ConversionINTEL,
internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
return getVec(internal::CapabilityBfloat16ConversionINTEL);
return getVec(CapabilityBFloat16ConversionINTEL);
}

std::optional<ExtensionID> getRequiredExtension() const override {
Expand Down Expand Up @@ -3614,7 +3614,7 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
InCompTy =
static_cast<SPIRVTypeCooperativeMatrixKHR *>(InCompTy)->getCompType();
}
if (OC == internal::OpConvertFToBF16INTEL) {
if (OC == OpConvertFToBF16INTEL) {
SPVErrLog.checkError(
ResCompTy->isTypeInt(16), SPIRVEC_InvalidInstruction,
InstName + "\nResult value must be a scalar or vector of integer "
Expand Down Expand Up @@ -3642,7 +3642,7 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
};

#define _SPIRV_OP(x) \
typedef SPIRVBfloat16ConversionINTELInstBase<internal::Op##x> SPIRV##x;
typedef SPIRVBfloat16ConversionINTELInstBase<Op##x> SPIRV##x;
_SPIRV_OP(ConvertFToBF16INTEL)
_SPIRV_OP(ConvertBF16ToFINTEL)
#undef _SPIRV_OP
Expand Down
2 changes: 1 addition & 1 deletion lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(CapabilityOptNoneEXT, "OptNoneEXT");
add(CapabilityAtomicFloat16AddEXT, "AtomicFloat16AddEXT");
add(CapabilityDebugInfoModuleINTEL, "DebugInfoModuleINTEL");
add(CapabilityBFloat16ConversionINTEL, "Bfloat16ConversionINTEL");
add(CapabilitySplitBarrierINTEL, "SplitBarrierINTEL");
add(CapabilityGlobalVariableFPGADecorationsINTEL,
"GlobalVariableFPGADecorationsINTEL");
Expand All @@ -644,7 +645,6 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
// From spirv_internal.hpp
add(internal::CapabilityFastCompositeINTEL, "FastCompositeINTEL");
add(internal::CapabilityTokenTypeINTEL, "TokenTypeINTEL");
add(internal::CapabilityBfloat16ConversionINTEL, "Bfloat16ConversionINTEL");
add(internal::CapabilityJointMatrixINTEL, "JointMatrixINTEL");
add(internal::CapabilityHWThreadQueryINTEL, "HWThreadQueryINTEL");
add(internal::CapabilityGlobalVariableDecorationsINTEL,
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,8 @@ _SPIRV_OP(TypeBufferSurfaceINTEL, 6086)
_SPIRV_OP(TypeStructContinuedINTEL, 6090)
_SPIRV_OP(ConstantCompositeContinuedINTEL, 6091)
_SPIRV_OP(SpecConstantCompositeContinuedINTEL, 6092)
_SPIRV_OP(ConvertFToBF16INTEL, 6116)
_SPIRV_OP(ConvertBF16ToFINTEL, 6117)
_SPIRV_OP(ControlBarrierArriveINTEL, 6142)
_SPIRV_OP(ControlBarrierWaitINTEL, 6143)
_SPIRV_OP(ArithmeticFenceEXT, 6145)
Expand Down
2 changes: 0 additions & 2 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

_SPIRV_OP_INTERNAL(Forward, internal::OpForward)
_SPIRV_OP_INTERNAL(TypeTokenINTEL, internal::OpTypeTokenINTEL)
_SPIRV_OP_INTERNAL(ConvertFToBF16INTEL, internal::OpConvertFToBF16INTEL)
_SPIRV_OP_INTERNAL(ConvertBF16ToFINTEL, internal::OpConvertBF16ToFINTEL)
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
Expand Down
7 changes: 0 additions & 7 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ enum InternalLinkageType {

enum InternalOp {
IOpTypeTokenINTEL = 6113,
IOpConvertFToBF16INTEL = 6116,
IOpConvertBF16ToFINTEL = 6117,
IOpTypeJointMatrixINTEL = 6119,
IOpJointMatrixLoadINTEL = 6120,
IOpJointMatrixStoreINTEL = 6121,
Expand Down Expand Up @@ -107,7 +105,6 @@ enum InternalDecoration {
enum InternalCapability {
ICapFastCompositeINTEL = 6093,
ICapTokenTypeINTEL = 6112,
ICapBfloat16ConversionINTEL = 6115,
ICapabilityJointMatrixINTEL = 6118,
ICapabilityHWThreadQueryINTEL = 6134,
ICapGlobalVariableDecorationsINTEL = 6146,
Expand Down Expand Up @@ -267,8 +264,6 @@ constexpr SourceLanguage SourceLanguageCPP20 =

constexpr Op OpForward = static_cast<Op>(IOpForward);
constexpr Op OpTypeTokenINTEL = static_cast<Op>(IOpTypeTokenINTEL);
constexpr Op OpConvertFToBF16INTEL = static_cast<Op>(IOpConvertFToBF16INTEL);
constexpr Op OpConvertBF16ToFINTEL = static_cast<Op>(IOpConvertBF16ToFINTEL);

constexpr Decoration DecorationCallableFunctionINTEL =
static_cast<Decoration>(IDecCallableFunctionINTEL);
Expand All @@ -287,8 +282,6 @@ constexpr Capability CapabilityFastCompositeINTEL =
static_cast<Capability>(ICapFastCompositeINTEL);
constexpr Capability CapabilityTokenTypeINTEL =
static_cast<Capability>(ICapTokenTypeINTEL);
constexpr Capability CapabilityBfloat16ConversionINTEL =
static_cast<Capability>(ICapBfloat16ConversionINTEL);
constexpr Capability CapabilityGlobalVariableDecorationsINTEL =
static_cast<Capability>(ICapGlobalVariableDecorationsINTEL);

Expand Down

0 comments on commit eff2aef

Please sign in to comment.