Skip to content

Commit

Permalink
Swizzle struct member
Browse files Browse the repository at this point in the history
  • Loading branch information
wpmed92 committed Sep 12, 2024
1 parent 51e34a1 commit 9406b1f
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 20 deletions.
2 changes: 1 addition & 1 deletion include/CodeGen/Swizzle.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace shaderpulse {
namespace codegen {

extern std::unordered_map<char, int> swizzleMap;
mlir::Value swizzle(mlir::OpBuilder &builder, mlir::Value composite, ast::MemberAccessExpression* memberAccess);
mlir::Value swizzle(mlir::OpBuilder &builder, mlir::Value composite, ast::MemberAccessExpression* memberAccess, int startIndex = 0);

};

Expand Down
14 changes: 13 additions & 1 deletion lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,18 @@ void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) {
std::vector<mlir::Value> memberIndicesAcc;

if (currentBaseComposite) {
for (auto &member : memberAccess->getMembers()) {
std::pair<int, VariableDeclaration*> prevMemberIndexPair;
for (int i = 0; i < memberAccess->getMembers().size(); i++) {
auto &member = memberAccess->getMembers()[i];
if (auto var = dynamic_cast<VariableExpression*>(member.get())) {
// Swizzle detected
if (prevMemberIndexPair.second && prevMemberIndexPair.second->getType()->getKind() == shaderpulse::TypeKind::Vector) {
mlir::Value accessChain = builder.create<spirv::AccessChainOp>(builder.getUnknownLoc(), baseCompositeValue, memberIndicesAcc);
mlir::Value swizzled = swizzle(builder, load(accessChain), memberAccess, i);
expressionStack.push_back(swizzled);
return;
}

auto memberIndexPair = currentBaseComposite->getMemberWithIndex(var->getName());
memberIndicesAcc.push_back(builder.create<spirv::ConstantOp>(builder.getUnknownLoc(), mlir::IntegerType::get(&context, 32, mlir::IntegerType::Signless), builder.getI32IntegerAttr(memberIndexPair.first)));

Expand All @@ -832,6 +842,8 @@ void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) {
currentBaseComposite = structDeclarations[structName];
}
}

prevMemberIndexPair = memberIndexPair;
// This is a duplicate of ArrayAccessExpression, idially we want to reuse that.
} else if (auto arrayAccess = dynamic_cast<ArrayAccessExpression*>(member.get())) {
auto varName = dynamic_cast<VariableExpression*>(arrayAccess->getArray())->getName();
Expand Down
37 changes: 19 additions & 18 deletions lib/CodeGen/Swizzle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,29 @@ std::unordered_map<char, int> swizzleMap = {
{'a', 3}
};

mlir::Value swizzle(mlir::OpBuilder &builder, mlir::Value composite, ast::MemberAccessExpression* memberAccess) {
mlir::Value swizzle(mlir::OpBuilder &builder, mlir::Value composite, ast::MemberAccessExpression* memberAccess, int startIndex) {
mlir::Value currentComposite = composite;

for (auto &member : memberAccess->getMembers()) {
if (auto var = dynamic_cast<ast::VariableExpression*>(member.get())) {
std::vector<int> indices;
auto swizzle = var->getName();

if (swizzle.length() == 1) {
indices.push_back(swizzleMap.find(swizzle[0])->second);
return builder.create<mlir::spirv::CompositeExtractOp>(builder.getUnknownLoc(), currentComposite, indices);
} else {
for (auto c : swizzle) {
indices.push_back(swizzleMap.find(c)->second);
for (int i = startIndex; i < memberAccess->getMembers().size(); i++) {
auto &member = memberAccess->getMembers()[i];
if (auto var = dynamic_cast<ast::VariableExpression*>(member.get())) {
std::vector<int> indices;
auto swizzle = var->getName();

if (swizzle.length() == 1) {
indices.push_back(swizzleMap.find(swizzle[0])->second);
return builder.create<mlir::spirv::CompositeExtractOp>(builder.getUnknownLoc(), currentComposite, indices);
} else {
for (auto c : swizzle) {
indices.push_back(swizzleMap.find(c)->second);
}

llvm::ArrayRef<int64_t> shape(static_cast<int64_t>(swizzle.length()));
mlir::Type elementType = currentComposite.getType().dyn_cast<mlir::VectorType>().getElementType();
mlir::VectorType shuffleType = mlir::VectorType::get(shape, elementType);
currentComposite = builder.create<mlir::spirv::VectorShuffleOp>(builder.getUnknownLoc(), shuffleType, currentComposite, currentComposite, builder.getI32ArrayAttr(indices));
}

llvm::ArrayRef<int64_t> shape(static_cast<int64_t>(swizzle.length()));
mlir::Type elementType = currentComposite.getType().dyn_cast<mlir::VectorType>().getElementType();
mlir::VectorType shuffleType = mlir::VectorType::get(shape, elementType);
currentComposite = builder.create<mlir::spirv::VectorShuffleOp>(builder.getUnknownLoc(), shuffleType, currentComposite, currentComposite, builder.getI32ArrayAttr(indices));
}
}
}

return currentComposite;
Expand Down
13 changes: 13 additions & 0 deletions test/CodeGen/swizzle.glsl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
struct StructSwizzle {
vec3 color;
vec3 pos;
};

void main() {
vec2 _vec2 = vec2(0.1, 0.2);

Expand Down Expand Up @@ -33,4 +38,12 @@ void main() {

// CHECK: %30 = spirv.VectorShuffle [2 : i32, 1 : i32, 0 : i32] %29 : vector<4xf32>, %29 : vector<4xf32> -> vector<3xf32>
_swizz_vec3 = _vec4.bgr;

StructSwizzle structSwizz = StructSwizzle(vec3(1.0, 0.0, 0.0), vec3(0.5, 0.0, 1.0));

// CHECK: %cst0_i32 = spirv.Constant 0 : i32
// CHECK-NEXT: %35 = spirv.AccessChain %34[%cst0_i32] : !spirv.ptr<!spirv.struct<(vector<3xf32>, vector<3xf32>)>, Function>, i32
// CHECK-NEXT: %36 = spirv.Load "Function" %35 : vector<3xf32>
// CHECK-NEXT: %37 = spirv.VectorShuffle [1 : i32, 0 : i32] %36 : vector<3xf32>, %36 : vector<3xf32> -> vector<2xf32>
_swizz_chain = structSwizz.color.yx;
}

0 comments on commit 9406b1f

Please sign in to comment.