Skip to content

Commit

Permalink
Handle OpVectorShuffle with differing vector sizes (intel#2391)
Browse files Browse the repository at this point in the history
The SPIR-V to LLVM conversion would bail out when encountering an
`OpVectorShuffle` whose vector operands differ in size.  SPIR-V
allows differing vector sizes, but LLVM's `shufflevector` does not.

Remove the assert and insert an additional `shufflevector` to align
the vector operands when needed.

Original commit:
KhronosGroup/SPIRV-LLVM-Translator@3df5e38250a6d7c
  • Loading branch information
svenvh authored and sys-ce-bb committed Mar 7, 2024
1 parent 4a58a77 commit 2483d62
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 9 deletions.
5 changes: 5 additions & 0 deletions llvm-spirv/lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ using namespace llvm;

namespace llvm {
class IntrinsicInst;
class IRBuilderBase;
}

namespace SPIRV {
Expand Down Expand Up @@ -551,6 +552,10 @@ std::string mapLLVMTypeToOCLType(const Type *Ty, bool Signed,
Type *PointerElementType = nullptr);
SPIRVDecorate *mapPostfixToDecorate(StringRef Postfix, SPIRVEntry *Target);

/// Return vector V extended with poison elements to match the number of
/// components of NewType.
Value *extendVector(Value *V, FixedVectorType *NewType, IRBuilderBase &Builder);

/// Add decorations to a SPIR-V entry.
/// \param Decs Each string is a postfix without _ at the beginning.
SPIRVValue *addDecorations(SPIRVValue *Target,
Expand Down
35 changes: 31 additions & 4 deletions llvm-spirv/lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2309,10 +2309,37 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
if (BB) {
Builder.SetInsertPoint(BB);
}
return mapValue(BV, Builder.CreateShuffleVector(
transValue(VS->getVector1(), F, BB),
transValue(VS->getVector2(), F, BB),
ConstantVector::get(Components), BV->getName()));
Value *Vec1 = transValue(VS->getVector1(), F, BB);
Value *Vec2 = transValue(VS->getVector2(), F, BB);
auto *Vec1Ty = cast<FixedVectorType>(Vec1->getType());
auto *Vec2Ty = cast<FixedVectorType>(Vec2->getType());
if (Vec1Ty->getNumElements() != Vec2Ty->getNumElements()) {
// LLVM's shufflevector requires that the two vector operands have the
// same type; SPIR-V's OpVectorShuffle allows the vector operands to
// differ in the number of components. Adjust for that by extending
// the smaller vector.
if (Vec1Ty->getNumElements() < Vec2Ty->getNumElements()) {
Vec1 = extendVector(Vec1, Vec2Ty, Builder);
// Extending Vec1 requires offsetting any Vec2 indices in Components by
// the number of new elements.
unsigned Offset = Vec2Ty->getNumElements() - Vec1Ty->getNumElements();
unsigned Vec2Start = Vec1Ty->getNumElements();
for (auto &C : Components) {
if (auto *CI = dyn_cast<ConstantInt>(C)) {
uint64_t V = CI->getZExtValue();
if (V >= Vec2Start) {
// This is a Vec2 index; add the offset to it.
C = ConstantInt::get(Int32Ty, V + Offset);
}
}
}
} else {
Vec2 = extendVector(Vec2, Vec1Ty, Builder);
}
}
return mapValue(
BV, Builder.CreateShuffleVector(
Vec1, Vec2, ConstantVector::get(Components), BV->getName()));
}

case OpBitReverse: {
Expand Down
18 changes: 18 additions & 0 deletions llvm-spirv/lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,24 @@ void removeFnAttr(CallInst *Call, Attribute::AttrKind Attr) {
Call->removeFnAttr(Attr);
}

Value *extendVector(Value *V, FixedVectorType *NewType,
IRBuilderBase &Builder) {
unsigned OldSize = cast<FixedVectorType>(V->getType())->getNumElements();
unsigned NewSize = NewType->getNumElements();
assert(OldSize < NewSize);
std::vector<Constant *> Components;
IntegerType *Int32Ty = Builder.getInt32Ty();
for (unsigned I = 0; I < NewSize; I++) {
if (I < OldSize)
Components.push_back(ConstantInt::get(Int32Ty, I));
else
Components.push_back(PoisonValue::get(Int32Ty));
}

return Builder.CreateShuffleVector(V, PoisonValue::get(V->getType()),
ConstantVector::get(Components), "vecext");
}

void saveLLVMModule(Module *M, const std::string &OutputFile) {
std::error_code EC;
ToolOutputFile Out(OutputFile.c_str(), EC, sys::fs::OF_None);
Expand Down
6 changes: 1 addition & 5 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2213,15 +2213,11 @@ class SPIRVVectorShuffleBase : public SPIRVInstTemplateBase {
protected:
void validate() const override {
SPIRVInstruction::validate();
SPIRVId Vector1 = Ops[0];
SPIRVId Vector2 = Ops[1];
[[maybe_unused]] SPIRVId Vector1 = Ops[0];
assert(OpCode == OpVectorShuffle);
assert(Type->isTypeVector());
assert(Type->getVectorComponentType() ==
getValueType(Vector1)->getVectorComponentType());
if (getValue(Vector1)->isForward() || getValue(Vector2)->isForward())
return;
assert(getValueType(Vector1) == getValueType(Vector2));
assert(Ops.size() - 2 == Type->getVectorComponentCount());
}
};
Expand Down
36 changes: 36 additions & 0 deletions llvm-spirv/test/OpVectorShuffle.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
; REQUIRES: spirv-as
; RUN: spirv-as --target-env spv1.0 -o %t.spv %s
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r -o - %t.spv | llvm-dis | FileCheck %s
OpCapability Addresses
OpCapability Kernel
OpMemoryModel Physical32 OpenCL
OpEntryPoint Kernel %1 "testVecShuffle"
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%uintv2 = OpTypeVector %uint 2
%uintv3 = OpTypeVector %uint 3
%uintv4 = OpTypeVector %uint 4
%func = OpTypeFunction %void %uintv2 %uintv3

%1 = OpFunction %void None %func
%pv2 = OpFunctionParameter %uintv2
%pv3 = OpFunctionParameter %uintv3
%entry = OpLabel

; Same vector lengths
%vs1 = OpVectorShuffle %uintv4 %pv3 %pv3 0 1 3 5
; CHECK: shufflevector <3 x i32> %[[#]], <3 x i32> %[[#]], <4 x i32> <i32 0, i32 1, i32 3, i32 5>

; vec1 smaller than vec2
%vs2 = OpVectorShuffle %uintv4 %pv2 %pv3 0 1 3 4
; CHECK: %[[VS2EXT:[0-9a-z]+]] = shufflevector <2 x i32> %0, <2 x i32> poison, <3 x i32> <i32 0, i32 1, i32 poison>
; CHECK: shufflevector <3 x i32> %[[VS2EXT]], <3 x i32> %[[#]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>

; vec1 larger than vec2
%vs3 = OpVectorShuffle %uintv4 %pv3 %pv2 0 1 3 4
; CHECK: %[[VS3EXT:[0-9a-z]+]] = shufflevector <2 x i32> %0, <2 x i32> poison, <3 x i32> <i32 0, i32 1, i32 poison>
; CHECK: shufflevector <3 x i32> %[[#]], <3 x i32> %[[VS3EXT]], <4 x i32> <i32 0, i32 1, i32 3, i32 4>

OpReturn
OpFunctionEnd

0 comments on commit 2483d62

Please sign in to comment.