diff --git a/include/CodeGen/MLIRCodeGen.h b/include/CodeGen/MLIRCodeGen.h index b02e980..503a34c 100644 --- a/include/CodeGen/MLIRCodeGen.h +++ b/include/CodeGen/MLIRCodeGen.h @@ -97,6 +97,7 @@ class MLIRCodeGen : public ASTVisitor { void declare(StringRef name, SymbolTableEntry entry); void createVariable(shaderpulse::Type *, VariableDeclaration *); void insertEntryPoint(); + mlir::Value load(mlir::Value); std::pair popExpressionStack(); mlir::Value currentBasePointer; diff --git a/lib/CodeGen/MLIRCodeGen.cpp b/lib/CodeGen/MLIRCodeGen.cpp index 43641d5..adf7429 100644 --- a/lib/CodeGen/MLIRCodeGen.cpp +++ b/lib/CodeGen/MLIRCodeGen.cpp @@ -54,7 +54,7 @@ void MLIRCodeGen::visit(TranslationUnit *unit) { insertEntryPoint(); } -std::pair MLIRCodeGen::popExpressionStack() { +std::pair MLIRCodeGen::popExpressionStack() { auto val = expressionStack.back(); expressionStack.pop_back(); return val; @@ -64,16 +64,16 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) { binExp->getLhs()->accept(this); binExp->getRhs()->accept(this); - std::pair rhsPair = popExpressionStack(); - std::pair lhsPair = popExpressionStack(); + std::pair rhsPair = popExpressionStack(); + std::pair lhsPair = popExpressionStack(); - Value rhs = rhsPair.second; - Value lhs = lhsPair.second; + mlir::Value rhs = load(rhsPair.second); + mlir::Value lhs = load(lhsPair.second); Type* typeContext = lhsPair.first; // TODO: implement source location auto loc = builder.getUnknownLoc(); - Value val; + mlir::Value val; switch (binExp->getOp()) { case BinaryOperator::Add: @@ -217,11 +217,11 @@ void MLIRCodeGen::visit(ConditionalExpression *condExp) { condExp->getTruePart()->accept(this); condExp->getCondition()->accept(this); - std::pair condition = popExpressionStack(); - std::pair truePart = popExpressionStack(); - std::pair falsePart = popExpressionStack(); + std::pair condition = popExpressionStack(); + std::pair truePart = popExpressionStack(); + std::pair falsePart = popExpressionStack(); - Value res = builder.create( + mlir::Value res = builder.create( builder.getUnknownLoc(), convertShaderPulseType(&context, truePart.first, structDeclarations), condition.second, @@ -241,7 +241,7 @@ void MLIRCodeGen::visit(InitializerExpression *initExp) { * is implemented. */ auto type = builder.getIntegerType(32, true); - Value val = builder.create( + mlir::Value val = builder.create( builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(32, 0, true))); @@ -250,10 +250,12 @@ void MLIRCodeGen::visit(InitializerExpression *initExp) { void MLIRCodeGen::visit(UnaryExpression *unExp) { unExp->getExpression()->accept(this); - std::pair rhs = popExpressionStack(); + std::pair rhsPair = popExpressionStack(); auto loc = builder.getUnknownLoc(); - Value val; + mlir::Value rhs = load(rhsPair.second); + mlir::Value result; + Type* rhsType = rhsPair.first; switch (unExp->getOp()) { case UnaryOperator::Inc: @@ -263,23 +265,24 @@ void MLIRCodeGen::visit(UnaryExpression *unExp) { // TODO: implement post-, pre-fix decrement break; case UnaryOperator::Plus: - expressionStack.push_back(rhs); + expressionStack.push_back(std::make_pair(rhsPair.first, rhs)); break; case UnaryOperator::Dash: - if (rhs.first->isFloatLike()) { - val = builder.create(loc, rhs.second); + if (rhsType->isFloatLike()) { + result = builder.create(loc, rhs); } else { - val = builder.create(loc, rhs.second); + result = builder.create(loc, rhs); } - expressionStack.push_back(std::make_pair(rhs.first, val)); + + expressionStack.push_back(std::make_pair(rhsType, result)); break; case UnaryOperator::Bang: - val = builder.create(loc, rhs.second); - expressionStack.push_back(std::make_pair(rhs.first, val)); + result = builder.create(loc, rhs); + expressionStack.push_back(std::make_pair(rhsType, result)); break; case UnaryOperator::Tilde: - val = builder.create(loc, rhs.second); - expressionStack.push_back(std::make_pair(rhs.first, val)); + result = builder.create(loc, rhs); + expressionStack.push_back(std::make_pair(rhsType, result)); break; } } @@ -320,7 +323,7 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type, Operation *initializerOp = nullptr; if (expressionStack.size() > 0) { - Value val = popExpressionStack().second; + mlir::Value val = popExpressionStack().second; initializerOp = val.getDefiningOp(); } @@ -346,10 +349,10 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type, varDecl->getInitialzerExpression()->accept(this); } - Value val; + mlir::Value val; if (expressionStack.size() > 0) { - val = popExpressionStack().second; + val = load(popExpressionStack().second); } spirv::PointerType ptrType = spirv::PointerType::get( @@ -404,7 +407,7 @@ void MLIRCodeGen::visit(WhileStatement *whileStmt) { Block *merge = loopOp.getMergeBlock(); builder.create( - loc, conditionOp, body, ArrayRef(), merge, ArrayRef()); + loc, conditionOp, body, ArrayRef(), merge, ArrayRef()); // Emit the continue/latch block. Block *continueBlock = loopOp.getContinueBlock(); @@ -484,37 +487,24 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) { void MLIRCodeGen::visit(ArrayAccessExpression *arrayAccess) { auto array = arrayAccess->getArray(); array->accept(this); - std::pair mlirArray = popExpressionStack(); + std::pair mlirArray = popExpressionStack(); Type* elementType = dynamic_cast(mlirArray.first)->getElementType(); std::vector indices; for (auto &access : arrayAccess->getAccessChain()) { access->accept(this); auto val = popExpressionStack().second; - - // If it's a variable index, load it first - if (val.getType().isa()) { - auto loadedIdx = builder.create(builder.getUnknownLoc(), val); - indices.push_back(loadedIdx); - } else { - indices.push_back(val); - } + indices.push_back(load(val)); } - Value accessChain = builder.create(builder.getUnknownLoc(), mlirArray.second, indices); - - if (arrayAccess->isLhs()) { - expressionStack.push_back(std::make_pair(elementType, accessChain)); - } else { - expressionStack.push_back(std::make_pair(elementType, builder.create(builder.getUnknownLoc(), accessChain)->getResult(0))); - } + mlir::Value accessChain = builder.create(builder.getUnknownLoc(), mlirArray.second, indices); + expressionStack.push_back(std::make_pair(elementType, accessChain)); } void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) { auto baseComposite = memberAccess->getBaseComposite(); baseComposite->accept(this); - Value baseCompositeValue = popExpressionStack().second; - std::vector memberIndices; + mlir::Value baseCompositeValue = popExpressionStack().second; std::vector memberIndicesAcc; Type* memberType; @@ -522,7 +512,6 @@ void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) { for (auto &member : memberAccess->getMembers()) { if (auto var = dynamic_cast(member.get())) { auto memberIndexPair = currentBaseComposite->getMemberWithIndex(var->getName()); - memberIndices.push_back(memberIndexPair.first); memberIndicesAcc.push_back(builder.create(builder.getUnknownLoc(), mlir::IntegerType::get(&context, 32, mlir::IntegerType::Signless), builder.getI32IntegerAttr(memberIndexPair.first))); if (memberIndexPair.second->getType()->getKind() == TypeKind::Struct) { @@ -534,16 +523,23 @@ void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) { } memberType = memberIndexPair.second->getType(); + // This is a duplicate of ArrayAccessExpression, idially we want to reuse that. + } else if (auto arrayAccess = dynamic_cast(member.get())) { + auto varName = dynamic_cast(arrayAccess->getArray())->getName(); + auto memberIndexPair = currentBaseComposite->getMemberWithIndex(varName); + memberIndicesAcc.push_back(builder.create(builder.getUnknownLoc(), mlir::IntegerType::get(&context, 32, mlir::IntegerType::Signless), builder.getI32IntegerAttr(memberIndexPair.first))); + memberType = dynamic_cast(memberIndexPair.second->getType())->getElementType(); + + for (auto &access : arrayAccess->getAccessChain()) { + access->accept(this); + auto val = popExpressionStack().second; + memberIndicesAcc.push_back(load(val)); + } } } - if (memberAccess->isLhs()) { - Value accessChain = builder.create(builder.getUnknownLoc(), baseCompositeValue, memberIndicesAcc); - expressionStack.push_back(std::make_pair(memberType, accessChain)); - } else { - Value compositeElement = builder.create(builder.getUnknownLoc(), baseCompositeValue, memberIndices); - expressionStack.push_back(std::make_pair(memberType, compositeElement)); - } + mlir::Value accessChain = builder.create(builder.getUnknownLoc(), baseCompositeValue, memberIndicesAcc); + expressionStack.push_back(std::make_pair(memberType, accessChain)); } } @@ -563,7 +559,7 @@ void MLIRCodeGen::visit(IfStatement *ifStmt) { auto loc = builder.getUnknownLoc(); ifStmt->getCondition()->accept(this); - Value condition = popExpressionStack().second; + mlir::Value condition = popExpressionStack().second; auto selectionOp = builder.create(loc, spirv::SelectionControl::None); selectionOp.addMergeBlock(); @@ -599,7 +595,7 @@ void MLIRCodeGen::visit(IfStatement *ifStmt) { builder.setInsertionPointToEnd(selectionHeaderBlock); builder.create( - loc, condition, thenBlock, ArrayRef(), elseBlock, ArrayRef()); + loc, condition, thenBlock, ArrayRef(), elseBlock, ArrayRef()); builder.setInsertionPointToEnd(restoreInsertionBlock); } @@ -608,8 +604,8 @@ void MLIRCodeGen::visit(AssignmentExpression *assignmentExp) { assignmentExp->getUnaryExpression()->accept(this); assignmentExp->getExpression()->accept(this); - Value val = popExpressionStack().second; - Value ptr = popExpressionStack().second; + mlir::Value val = load(popExpressionStack().second); + mlir::Value ptr = popExpressionStack().second; builder.create(builder.getUnknownLoc(), ptr, val); } @@ -629,7 +625,7 @@ void MLIRCodeGen::visit(CallExpression *callExp) { if (callExp->getArguments().size() > 0) { for (auto &arg : callExp->getArguments()) { arg->accept(this); - operands.push_back(popExpressionStack().second); + operands.push_back(load(popExpressionStack().second)); } } @@ -653,18 +649,18 @@ void MLIRCodeGen::visit(VariableExpression *varExp) { if (entry.isFunctionParam) { expressionStack.push_back(std::make_pair(entry.type, entry.value)); } else if (entry.variable) { - Value val; + mlir::Value val; if (entry.isGlobal) { auto addressOfGlobal = builder.create(builder.getUnknownLoc(), entry.ptrType, varExp->getName()); - val = varExp->isLhs() ? addressOfGlobal->getResult(0) : builder.create(builder.getUnknownLoc(), addressOfGlobal)->getResult(0); + val = addressOfGlobal->getResult(0); // If we're inside the entry point function, collect the used global variables if (insideEntryPoint) { interface.push_back(SymbolRefAttr::get(&context, varExp->getName())); } } else { - val = (varExp->isLhs() || (entry.variable->getType()->getKind() == TypeKind::Array)) ? entry.value : builder.create(builder.getUnknownLoc(), entry.value); + val = entry.value; } if (entry.variable->getType()->getKind() == TypeKind::Struct) { @@ -683,7 +679,7 @@ void MLIRCodeGen::visit(VariableExpression *varExp) { void MLIRCodeGen::visit(IntegerConstantExpression *intConstExp) { auto type = builder.getIntegerType(32, true); - Value val = builder.create( + mlir::Value val = builder.create( builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(32, intConstExp->getVal(), true))); @@ -692,7 +688,7 @@ void MLIRCodeGen::visit(IntegerConstantExpression *intConstExp) { void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) { auto type = builder.getIntegerType(32, false); - Value val = builder.create( + mlir::Value val = builder.create( builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(32, uintConstExp->getVal(), false))); @@ -701,7 +697,7 @@ void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) { void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) { auto type = builder.getF32Type(); - Value val = builder.create( + mlir::Value val = builder.create( builder.getUnknownLoc(), type, FloatAttr::get(type, APFloat(floatConstExp->getVal()))); @@ -710,7 +706,7 @@ void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) { void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) { auto type = builder.getF64Type(); - Value val = builder.create( + mlir::Value val = builder.create( builder.getUnknownLoc(), type, FloatAttr::get(type, APFloat(doubleConstExp->getVal()))); @@ -719,7 +715,7 @@ void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) { void MLIRCodeGen::visit(BoolConstantExpression *boolConstExp) { auto type = builder.getIntegerType(1); - Value val = builder.create( + mlir::Value val = builder.create( builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(1, boolConstExp->getVal()))); @@ -733,7 +729,7 @@ void MLIRCodeGen::visit(ReturnStatement *returnStmt) { if (expressionStack.empty()) { builder.create(builder.getUnknownLoc()); } else { - Value val = popExpressionStack().second; + mlir::Value val = popExpressionStack().second; builder.create(builder.getUnknownLoc(), val); } } @@ -806,6 +802,14 @@ void MLIRCodeGen::visit(DefaultLabel *defaultLabel) {} void MLIRCodeGen::visit(CaseLabel *defaultLabel) {} +mlir::Value MLIRCodeGen::load(mlir::Value val) { + if (val.getType().isa()) { + return builder.create(builder.getUnknownLoc(), val); + } + + return val; +} + }; // namespace codegen }; // namespace shaderpulse diff --git a/lib/CodeGen/TypeConversion.cpp b/lib/CodeGen/TypeConversion.cpp index 1eb78f6..a1a2389 100644 --- a/lib/CodeGen/TypeConversion.cpp +++ b/lib/CodeGen/TypeConversion.cpp @@ -45,7 +45,6 @@ mlir::Type convertShaderPulseType(mlir::MLIRContext *ctx, Type *shaderPulseType, for (auto &member : structDecl->getMembers()) { auto varMember = dynamic_cast(member.get()); - std::cout << "Converting member type: " << varMember->getIdentifierName() << std::endl; memberTypes.push_back(convertShaderPulseType(ctx, varMember->getType(), structDeclarations)); } diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp index 854e4cb..d67aecd 100644 --- a/lib/Parser/Parser.cpp +++ b/lib/Parser/Parser.cpp @@ -1209,7 +1209,7 @@ std::optional>> Parser::parseArrayAccess do { advanceToken(); - if (auto access = parsePostfixExpression(/*parsingSubExpression*/ true)) { + if (auto access = parsePostfixExpression()) { accessChain.push_back(std::move(access)); advanceToken(); @@ -1297,12 +1297,12 @@ std::unique_ptr Parser::parseUnaryExpression() { std::unique_ptr Parser::parsePostfixExpression(bool parsingSubExpression) { if (auto primary = parsePrimaryExpression()) { - if (parsingSubExpression) { + if (auto access = parseArrayAccess()) { + return std::make_unique(std::move(primary), std::move(*access), parsingLhsExpression); + } else if (parsingSubExpression) { return primary; } else if (auto members = parseMemberAccessChain()) { return std::make_unique(std::move(primary), std::move(*members), parsingLhsExpression); - } else if (auto access = parseArrayAccess()) { - return std::make_unique(std::move(primary), std::move(*access), parsingLhsExpression); } else { return primary; } diff --git a/test/CodeGen/structs.glsl b/test/CodeGen/structs.glsl index ce7c9eb..996e86c 100644 --- a/test/CodeGen/structs.glsl +++ b/test/CodeGen/structs.glsl @@ -14,28 +14,32 @@ struct StructWithArr { int[4] a; } +struct Indices { + int idx1; + int idx2; +} + void main() { // CHECK: %0 = spirv.CompositeConstruct %cst_f32, %cst2_si32, %cst3_ui32, %true : (f32, si32, ui32, i1) -> !spirv.struct<(f32, si32, ui32, i1)> + // CHECK-NEXT: %1 = spirv.Variable : !spirv.ptr, Function> MyStruct myStruct = MyStruct(0.1, 2, 3u, true); - // CHECK: %3 = spirv.CompositeExtract %2[0 : i32] : !spirv.struct<(f32, si32, ui32, i1)> - // CHECK-NEXT: %4 = spirv.Variable : !spirv.ptr - // CHECK-NEXT: spirv.Store "Function" %4, %3 : f32 + // Basic member access + + // CHECK: %cst0_i32 = spirv.Constant 0 : i32 + // CHECK-NEXT: %2 = spirv.AccessChain %1[%cst0_i32] : !spirv.ptr, Function>, i32 float a = myStruct.a; - // CHECK: %6 = spirv.CompositeExtract %5[1 : i32] : !spirv.struct<(f32, si32, ui32, i1)> - // CHECK-NEXT: %7 = spirv.Variable : !spirv.ptr - // CHECK-NEXT: spirv.Store "Function" %7, %6 : si32 + // CHECK: %cst1_i32 = spirv.Constant 1 : i32 + // CHECK-NEXT: %5 = spirv.AccessChain %1[%cst1_i32] : !spirv.ptr, Function>, i32 int b = myStruct.b; - // CHECK: %9 = spirv.CompositeExtract %8[2 : i32] : !spirv.struct<(f32, si32, ui32, i1)> - // CHECK-NEXT: %10 = spirv.Variable : !spirv.ptr - // CHECK-NEXT: spirv.Store "Function" %10, %9 : ui32 + // CHECK: %cst2_i32 = spirv.Constant 2 : i32 + // CHECK-NEXT: %8 = spirv.AccessChain %1[%cst2_i32] : !spirv.ptr, Function>, i32 uint c = myStruct.c; - // CHECK: %12 = spirv.CompositeExtract %11[3 : i32] : !spirv.struct<(f32, si32, ui32, i1)> - // CHECK-NEXT: %13 = spirv.Variable : !spirv.ptr - // CHECK-NEXT: spirv.Store "Function" %13, %12 : i1 + // CHECK: %cst3_i32 = spirv.Constant 3 : i32 + // CHECK-NEXT: %11 = spirv.AccessChain %1[%cst3_i32] : !spirv.ptr, Function>, i32 bool d = myStruct.d; // Struct in struct @@ -45,18 +49,34 @@ void main() { // CHECK-NEXT: %15 = spirv.CompositeConstruct %14, %cst1_si32 : (!spirv.struct<(f32, si32, ui32, i1)>, si32) -> !spirv.struct<(!spirv.struct<(f32, si32, ui32, i1)>, si32)> MyStruct2 myStruct2 = MyStruct2(MyStruct(0.1, 2, 3u, true), 1); - // CHECK: %17 = spirv.Load "Function" %16 : !spirv.struct<(!spirv.struct<(f32, si32, ui32, i1)>, si32)> - // CHECK-NEXT: %cst0_i32_4 = spirv.Constant 0 : i32 - // CHECK-NEXT: %cst1_i32_5 = spirv.Constant 1 : i32 - // CHECK-NEXT: %18 = spirv.CompositeExtract %17[0 : i32, 1 : i32] : !spirv.struct<(!spirv.struct<(f32, si32, ui32, i1)>, si32)> - b = myStruct2.structMember.b; - + // CHECK: %cst0_i32_4 = spirv.Constant 0 : i32 + // CHECK-NEXT: %cst3_i32_5 = spirv.Constant 3 : i32 + // CHECK-NEXT: %17 = spirv.AccessChain %16[%cst0_i32_4, %cst3_i32_5] : !spirv.ptr, si32)>, Function>, i32, i32 + d = myStruct2.structMember.d; // Struct with array + // CHECK: %19 = spirv.CompositeConstruct %cst1_si32_6, %cst2_si32_7, %cst3_si32, %cst4_si32 : (si32, si32, si32, si32) -> !spirv.array<4 x si32> // CHECK-NEXT: %20 = spirv.CompositeConstruct %19 : (!spirv.array<4 x si32>) -> !spirv.struct<(!spirv.array<4 x si32>)> StructWithArr structWithArr = StructWithArr(int[4](1, 2, 3, 4)); - // TODO: This currently fails at the Parser level. Implement member parsing for arrays. - // int arrElemFromStruct = structWithArr.a[0]; + // CHECK: %cst0_i32_8 = spirv.Constant 0 : i32 + // CHECK-NEXT: %cst2_si32_9 = spirv.Constant 2 : si32 + // CHECK-NEXT: %22 = spirv.AccessChain %21[%cst0_i32_8, %cst2_si32_9] : !spirv.ptr)>, Function>, i32, si32 + int arrElemFromStruct = structWithArr.a[2]; + + // Member access as array index + + // CHECK: %25 = spirv.CompositeConstruct %cst1_si32_10, %cst2_si32_11 : (si32, si32) -> !spirv.array<2 x si32> + // CHECK-NEXT: %26 = spirv.Variable : !spirv.ptr, Function> + int[2] arr = int[2](1, 2); + + // CHECK: %28 = spirv.Variable : !spirv.ptr, Function> + Indices indices = Indices(0, 1); + + // CHECK: %cst1_i32_13 = spirv.Constant 1 : i32 + // CHECK-NEXT: %29 = spirv.AccessChain %28[%cst1_i32_13] : !spirv.ptr, Function>, i32 + // CHECK-NEXT:%30 = spirv.Load "Function" %29 : si32 + // CHECK-NEXT:%31 = spirv.AccessChain %26[%30] : !spirv.ptr, Function>, si32 + arr[indices.idx2] = 24; }