From b8b2a5a4b7f0a93ccff36f03a2dad44ec0e859c6 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Tue, 10 Sep 2024 23:25:16 +0200 Subject: [PATCH] Refactor type handling in CodeGen (#29) * Refactor codegen types * Namespace cleanup, add function test case that was previously not possible --- include/CodeGen/MLIRCodeGen.h | 8 +- include/CodeGen/TypeConversion.h | 14 +- lib/CodeGen/MLIRCodeGen.cpp | 329 ++++++++++++++----------------- lib/CodeGen/TypeConversion.cpp | 43 +++- test/CodeGen/functions.glsl | 6 +- 5 files changed, 210 insertions(+), 190 deletions(-) diff --git a/include/CodeGen/MLIRCodeGen.h b/include/CodeGen/MLIRCodeGen.h index 47fc417..71d9e19 100644 --- a/include/CodeGen/MLIRCodeGen.h +++ b/include/CodeGen/MLIRCodeGen.h @@ -14,7 +14,6 @@ #include "llvm/ADT/ScopedHashTable.h" #include #include -#include #include #include @@ -85,7 +84,7 @@ class MLIRCodeGen : public ASTVisitor { bool inGlobalScope = true; llvm::StringMap functionMap; llvm::StringMap structDeclarations; - std::vector> expressionStack; + std::vector expressionStack; StructDeclaration* currentBaseComposite = nullptr; llvm::ScopedHashTable @@ -104,10 +103,9 @@ class MLIRCodeGen : public ASTVisitor { void initBuiltinFuncMap(); bool callBuiltIn(CallExpression* exp); mlir::Value load(mlir::Value); - - std::pair popExpressionStack(); + mlir::Value popExpressionStack(); mlir::Value currentBasePointer; - mlir::Value convertOp(ConstructorExpression* constructorExp, std::pair operand); + mlir::Value convertOp(ConstructorExpression* constructorExp, mlir::Value val); }; }; // namespace codegen diff --git a/include/CodeGen/TypeConversion.h b/include/CodeGen/TypeConversion.h index 7086c6c..f99f930 100644 --- a/include/CodeGen/TypeConversion.h +++ b/include/CodeGen/TypeConversion.h @@ -8,9 +8,17 @@ namespace shaderpulse { namespace codegen { -mlir::Type convertShaderPulseType(mlir::MLIRContext *, Type *, llvm::StringMap &); -std::optional getSpirvStorageClass(TypeQualifier *); -std::optional getLocationFromTypeQualifier(mlir::MLIRContext *ctx, TypeQualifier *); +mlir::Type convertShaderPulseType(mlir::MLIRContext *, shaderpulse::Type *, llvm::StringMap &); +std::optional getSpirvStorageClass(shaderpulse::TypeQualifier *); +std::optional getLocationFromTypeQualifier(mlir::MLIRContext *ctx, shaderpulse::TypeQualifier *); +mlir::Type getElementType(mlir::Type type); +bool isBoolLike(mlir::Type type); +bool isIntLike(mlir::Type type); +bool isSIntLike(mlir::Type type); +bool isUIntLike(mlir::Type type); +bool isFloatLike(mlir::Type type); +bool isF32Like(mlir::Type type); +bool isF64Like(mlir::Type type); }; // namespace codegen }; // namespace shaderpulse diff --git a/lib/CodeGen/MLIRCodeGen.cpp b/lib/CodeGen/MLIRCodeGen.cpp index 35b2324..fab48e6 100644 --- a/lib/CodeGen/MLIRCodeGen.cpp +++ b/lib/CodeGen/MLIRCodeGen.cpp @@ -177,7 +177,7 @@ void MLIRCodeGen::visit(TranslationUnit *unit) { insertEntryPoint(); } -std::pair MLIRCodeGen::popExpressionStack() { +mlir::Value MLIRCodeGen::popExpressionStack() { assert(expressionStack.size() > 0 && "Expression stack is empty"); auto val = expressionStack.back(); expressionStack.pop_back(); @@ -188,12 +188,9 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) { binExp->getLhs()->accept(this); binExp->getRhs()->accept(this); - std::pair rhsPair = popExpressionStack(); - std::pair lhsPair = popExpressionStack(); - - mlir::Value rhs = load(rhsPair.second); - mlir::Value lhs = load(lhsPair.second); - shaderpulse::Type* typeContext = lhsPair.first; + mlir::Value rhs = load(popExpressionStack()); + mlir::Value lhs = load(popExpressionStack()); + mlir::Type typeContext = rhs.getType(); // TODO: implement source location auto loc = builder.getUnknownLoc(); @@ -201,137 +198,137 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) { switch (binExp->getOp()) { case BinaryOperator::Add: - if (typeContext->isIntLike()) { + if (isIntLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::Sub: - if (typeContext->isIntLike()) { + if (isIntLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::Mul: - if (typeContext->isIntLike()) { + if (isIntLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::Div: - if (typeContext->isUIntLike()) { + if (isUIntLike(typeContext)) { val = builder.create(loc, lhs, rhs); - } else if (typeContext->isIntLike()) { + } else if (isIntLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::Mod: - if (typeContext->isIntLike()) { + if (isIntLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::ShiftLeft: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::ShiftRight: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::Lt: - if (typeContext->isFloatLike()) { + if (isFloatLike(typeContext)) { val = builder.create(loc, lhs, rhs); - } else if (typeContext->isUIntLike()) { + } else if (isUIntLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::Gt: - if (typeContext->isFloatLike()) { + if (isFloatLike(typeContext)) { val = builder.create(loc, lhs, rhs); - } else if (typeContext->isUIntLike()) { + } else if (isUIntLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::LtEq: - if (typeContext->isFloatLike()) { + if (isFloatLike(typeContext)) { val = builder.create(loc, lhs, rhs); - } else if (typeContext->isUIntLike()) { + } else if (isUIntLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::GtEq: - if (typeContext->isFloatLike()) { + if (isFloatLike(typeContext)) { val = builder.create(loc, lhs, rhs); - } else if (typeContext->isUIntLike()) { + } else if (isUIntLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::Eq: - if (typeContext->isFloatLike()) { + if (isFloatLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::Neq: - if (typeContext->isFloatLike()) { + if (isFloatLike(typeContext)) { val = builder.create(loc, lhs, rhs); } else { val = builder.create(loc, lhs, rhs); } - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::BitAnd: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::BitXor: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::BitIor: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::LogAnd: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; case BinaryOperator::LogXor: // TODO: not implemented in current spirv dialect break; case BinaryOperator::LogOr: val = builder.create(loc, lhs, rhs); - expressionStack.push_back(std::make_pair(typeContext, val)); + expressionStack.push_back(val); break; } } @@ -341,18 +338,18 @@ void MLIRCodeGen::visit(ConditionalExpression *condExp) { condExp->getTruePart()->accept(this); condExp->getCondition()->accept(this); - std::pair condition = popExpressionStack(); - std::pair truePart = popExpressionStack(); - std::pair falsePart = popExpressionStack(); + mlir::Value condition = load(popExpressionStack()); + mlir::Value truePart = load(popExpressionStack()); + mlir::Value falsePart = load(popExpressionStack()); mlir::Value res = builder.create( builder.getUnknownLoc(), - convertShaderPulseType(&context, truePart.first, structDeclarations), - condition.second, - truePart.second, - falsePart.second); + truePart.getType(), + condition, + truePart, + falsePart); - expressionStack.push_back(std::make_pair(truePart.first, res)); + expressionStack.push_back(res); } void MLIRCodeGen::visit(ForStatement *forStmt) { @@ -369,30 +366,27 @@ void MLIRCodeGen::visit(InitializerExpression *initExp) { builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(32, 0, true))); - expressionStack.push_back(std::make_pair(nullptr, val)); + expressionStack.push_back(val); } void MLIRCodeGen::visit(UnaryExpression *unExp) { unExp->getExpression()->accept(this); - std::pair rhsPair = popExpressionStack(); + mlir::Value ptrRhs = popExpressionStack(); + mlir::Value rhs = load(ptrRhs); + mlir::Value result {}; + mlir::Type rhsType = rhs.getType(); auto loc = builder.getUnknownLoc(); - mlir::Value rhs = load(rhsPair.second); - mlir::Value result; - shaderpulse::Type* rhsType = rhsPair.first; - auto op = unExp->getOp(); switch (op) { case UnaryOperator::Inc: case UnaryOperator::Dec: { - mlir::Value ptrRhs = rhsPair.second; - - if (rhsType->isIntLike()) { + if (isIntLike(rhsType)) { auto one = builder.create( builder.getUnknownLoc(), - mlir::IntegerType::get(&context, 32, rhsType->isUIntLike() ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed), - rhsType->isUIntLike() ? builder.getUI32IntegerAttr(1) :builder.getSI32IntegerAttr(1) + mlir::IntegerType::get(&context, 32, isUIntLike(rhsType) ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed), + isUIntLike(rhsType) ? builder.getUI32IntegerAttr(1) :builder.getSI32IntegerAttr(1) ); if (op == UnaryOperator::Inc) { @@ -411,28 +405,28 @@ void MLIRCodeGen::visit(UnaryExpression *unExp) { } builder.create(builder.getUnknownLoc(), ptrRhs, result); - expressionStack.push_back(std::make_pair(rhsType, result)); + expressionStack.push_back(result); break; } case UnaryOperator::Plus: - expressionStack.push_back(std::make_pair(rhsPair.first, rhs)); + expressionStack.push_back(rhs); break; case UnaryOperator::Dash: - if (rhsType->isFloatLike()) { + if (isFloatLike(rhsType)) { result = builder.create(loc, rhs); } else { result = builder.create(loc, rhs); } - expressionStack.push_back(std::make_pair(rhsType, result)); + expressionStack.push_back(result); break; case UnaryOperator::Bang: result = builder.create(loc, rhs); - expressionStack.push_back(std::make_pair(rhsType, result)); + expressionStack.push_back(result); break; case UnaryOperator::Tilde: result = builder.create(loc, rhs); - expressionStack.push_back(std::make_pair(rhsType, result)); + expressionStack.push_back(result); break; } } @@ -495,7 +489,7 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type, if (varDecl->getInitialzerExpression()) { varDecl->getInitialzerExpression()->accept(this); - val = load(popExpressionStack().second); + val = load(popExpressionStack()); } spirv::PointerType ptrType = spirv::PointerType::get( @@ -549,7 +543,7 @@ void MLIRCodeGen::visit(WhileStatement *whileStmt) { Block *merge = loopOp.getMergeBlock(); whileStmt->getCondition()->accept(this); - auto conditionOp = load(popExpressionStack().second); + auto conditionOp = load(popExpressionStack()); builder.create( loc, conditionOp, body, ArrayRef(), merge, ArrayRef()); @@ -568,14 +562,11 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) { auto constructorType = constructorExp->getType(); std::vector operands; - std::vector operandTypes; if (constructorExp->getArguments().size() > 0) { for (auto &arg : constructorExp->getArguments()) { arg->accept(this); - auto typeValPair = popExpressionStack(); - operands.push_back(load(typeValPair.second)); - operandTypes.push_back(typeValPair.first); + operands.push_back(load(popExpressionStack())); } } @@ -588,7 +579,7 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) { if (structDeclarations.find(structName) != structDeclarations.end()) { mlir::Value val = builder.create( builder.getUnknownLoc(), convertShaderPulseType(&context, constructorType, structDeclarations), operands); - expressionStack.push_back(std::make_pair(constructorType, val)); + expressionStack.push_back(val); } break; @@ -598,21 +589,20 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) { case shaderpulse::TypeKind::Array: { // If the vector constructor has a single argument, and it is the same length as the current vector, // but the element type is different, than it is a type conversion and not a composite construction. - if (constructorTypeKind == shaderpulse::TypeKind::Vector && (operands.size() == 1) && (operandTypes[0]->getKind() == shaderpulse::TypeKind::Vector)) { - auto argVecType = dynamic_cast(operandTypes[0]); - auto constrVecType = dynamic_cast(constructorType); + if (constructorTypeKind == shaderpulse::TypeKind::Vector && (operands.size() == 1) && (operands[0].getType().isa())) { + auto argVecType = operands[0].getType().dyn_cast(); + auto constrVecType = convertShaderPulseType(&context, constructorType, structDeclarations).dyn_cast(); - if ((argVecType->getLength() == constrVecType->getLength()) && !argVecType->getElementType()->isEqual(*constrVecType->getElementType())) { - convertOp(constructorExp, std::make_pair(operandTypes[0], operands[0])); + if ((argVecType.getShape()[0] == constrVecType.getShape()[0]) && (argVecType.getElementType() != constrVecType.getElementType())) { + convertOp(constructorExp, operands[0]); } else { - mlir::Value val = builder.create( - builder.getUnknownLoc(), convertShaderPulseType(&context, constructorType, structDeclarations), operands); - expressionStack.push_back(std::make_pair(constructorType, val)); + mlir::Value val = builder.create(builder.getUnknownLoc(), constrVecType, operands); + expressionStack.push_back(val); } } else { mlir::Value val = builder.create( builder.getUnknownLoc(), convertShaderPulseType(&context, constructorType, structDeclarations), operands); - expressionStack.push_back(std::make_pair(constructorType, val)); + expressionStack.push_back(val); } break; } @@ -638,60 +628,56 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) { mlir::Value val = builder.create( builder.getUnknownLoc(), convertShaderPulseType(&context, constructorType, structDeclarations), columnVectors); - expressionStack.push_back(std::make_pair(constructorType, val)); + expressionStack.push_back(val); break; } // Scalar type conversions default: - convertOp(constructorExp, std::make_pair(operandTypes[0], operands[0])); + convertOp(constructorExp, operands[0]); break; } } -mlir::Value MLIRCodeGen::convertOp(ConstructorExpression* constructorExp, std::pair operand) { - shaderpulse::Type* toType = constructorExp->getType(); - shaderpulse::Type* fromType = operand.first; - mlir::Value val = operand.second; - mlir::Type resultType = convertShaderPulseType(&context, toType, structDeclarations); - - if (fromType->isUIntLike() && toType->isFloatLike()) { - expressionStack.push_back(std::make_pair(toType, builder.create(builder.getUnknownLoc(), resultType, val))); - } else if (fromType->isIntLike() && toType->isFloatLike()) { - expressionStack.push_back(std::make_pair(toType, builder.create(builder.getUnknownLoc(), resultType, val))); - } else if (fromType->isFloatLike() && toType->isUIntLike()) { - expressionStack.push_back(std::make_pair(toType, builder.create(builder.getUnknownLoc(), resultType, val))); - } else if (fromType->isFloatLike() && toType->isIntLike()) { - expressionStack.push_back(std::make_pair(toType, builder.create(builder.getUnknownLoc(), resultType, val))); - } else if ((fromType->isSIntLike() && toType->isUIntLike()) || (fromType->isUIntLike() && toType->isSIntLike())) { - expressionStack.push_back(std::make_pair(toType, builder.create(builder.getUnknownLoc(), resultType, val))); - } else if (fromType->isBoolLike() && toType->isIntLike()) { +mlir::Value MLIRCodeGen::convertOp(ConstructorExpression* constructorExp, mlir::Value val) { + mlir::Type toType = convertShaderPulseType(&context, constructorExp->getType(), structDeclarations); + mlir::Type fromType = val.getType(); + + if (isUIntLike(fromType) && isFloatLike(toType)) { + expressionStack.push_back(builder.create(builder.getUnknownLoc(), toType, val)); + } else if (isIntLike(fromType) && isFloatLike(toType)) { + expressionStack.push_back(builder.create(builder.getUnknownLoc(), toType, val)); + } else if (isFloatLike(fromType) && isUIntLike(toType)) { + expressionStack.push_back(builder.create(builder.getUnknownLoc(), toType, val)); + } else if (isFloatLike(fromType) && isIntLike(toType)) { + expressionStack.push_back(builder.create(builder.getUnknownLoc(), toType, val)); + } else if ((isSIntLike(fromType) && isUIntLike(toType)) || (isUIntLike(fromType) && isSIntLike(toType))) { + expressionStack.push_back(builder.create(builder.getUnknownLoc(), toType, val)); + } else if (isBoolLike(fromType) && isIntLike(toType)) { mlir::Value one; auto constOne = builder.create( builder.getUnknownLoc(), - mlir::IntegerType::get(&context, 32, toType->isUIntLike() ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed), - toType->isUIntLike() ? builder.getUI32IntegerAttr(1) : builder.getSI32IntegerAttr(1) + mlir::IntegerType::get(&context, 32, isUIntLike(toType) ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed), + isUIntLike(toType) ? builder.getUI32IntegerAttr(1) : builder.getSI32IntegerAttr(1) ); mlir::Value zero; auto constZero = builder.create( builder.getUnknownLoc(), - mlir::IntegerType::get(&context, 32, toType->isUIntLike() ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed), - toType->isUIntLike() ? builder.getUI32IntegerAttr(0) : builder.getSI32IntegerAttr(0) + mlir::IntegerType::get(&context, 32, isUIntLike(toType) ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed), + isUIntLike(toType) ? builder.getUI32IntegerAttr(0) : builder.getSI32IntegerAttr(0) ); - if (fromType->getKind() == shaderpulse::TypeKind::Vector) { + if (fromType.isa()) { std::vector operandsZero; std::vector operandsOne; - for (int i = 0; i < dynamic_cast(fromType)->getLength(); i++) { + for (int i = 0; i < fromType.dyn_cast().getShape()[0]; i++) { operandsZero.push_back(constZero); operandsOne.push_back(constOne); } - zero = builder.create( - builder.getUnknownLoc(), convertShaderPulseType(&context, toType, structDeclarations), operandsZero); - one = builder.create( - builder.getUnknownLoc(), convertShaderPulseType(&context, toType, structDeclarations), operandsOne); + zero = builder.create(builder.getUnknownLoc(), toType, operandsZero); + one = builder.create(builder.getUnknownLoc(), toType, operandsOne); } else { one = constOne; zero = constZero; @@ -699,39 +685,37 @@ mlir::Value MLIRCodeGen::convertOp(ConstructorExpression* constructorExp, std::p mlir::Value res = builder.create( builder.getUnknownLoc(), - resultType, + toType, val, one, zero ); - expressionStack.push_back(std::make_pair(toType, res)); - } else if (fromType->isBoolLike() && toType->isFloatLike()) { + expressionStack.push_back(res); + } else if (isBoolLike(fromType) && isFloatLike(toType)) { mlir::Value one; auto constOne = builder.create( builder.getUnknownLoc(), - toType->isF32Like() ? mlir::FloatType::getF32(&context) : mlir::FloatType::getF64(&context), - toType->isF32Like() ? builder.getF32FloatAttr(1.0f) : builder.getF64FloatAttr(1.0) + isF32Like(toType) ? mlir::FloatType::getF32(&context) : mlir::FloatType::getF64(&context), + isF32Like(toType) ? builder.getF32FloatAttr(1.0f) : builder.getF64FloatAttr(1.0) ); mlir::Value zero; auto constZero = builder.create( builder.getUnknownLoc(), - toType->isF32Like() ? mlir::FloatType::getF32(&context) : mlir::FloatType::getF64(&context), - toType->isF32Like() ? builder.getF32FloatAttr(0.0f) : builder.getF64FloatAttr(0.0) + isF32Like(toType) ? mlir::FloatType::getF32(&context) : mlir::FloatType::getF64(&context), + isF32Like(toType) ? builder.getF32FloatAttr(0.0f) : builder.getF64FloatAttr(0.0) ); - if (fromType->getKind() == shaderpulse::TypeKind::Vector) { + if (fromType.isa()) { std::vector operandsZero; std::vector operandsOne; - for (int i = 0; i < dynamic_cast(fromType)->getLength(); i++) { + for (int i = 0; i < fromType.dyn_cast().getShape()[0]; i++) { operandsZero.push_back(constZero); operandsOne.push_back(constOne); } - zero = builder.create( - builder.getUnknownLoc(), convertShaderPulseType(&context, toType, structDeclarations), operandsZero); - one = builder.create( - builder.getUnknownLoc(), convertShaderPulseType(&context, toType, structDeclarations), operandsOne); + zero = builder.create(builder.getUnknownLoc(), toType, operandsZero); + one = builder.create(builder.getUnknownLoc(), toType, operandsOne); } else { one = constOne; zero = constZero; @@ -739,83 +723,78 @@ mlir::Value MLIRCodeGen::convertOp(ConstructorExpression* constructorExp, std::p mlir::Value res = builder.create( builder.getUnknownLoc(), - resultType, + toType, val, one, zero ); - expressionStack.push_back(std::make_pair(toType, res)); - } else if ((fromType->isF32Like() && toType->isF64Like()) || (fromType->isF64Like() && toType->isF32Like())) { - expressionStack.push_back(std::make_pair(toType, builder.create(builder.getUnknownLoc(), resultType, val))); - } else if (toType->isBoolLike()) { + expressionStack.push_back(res); + } else if ((isF32Like(fromType) && isF64Like(toType)) || (isF64Like(fromType) && isF32Like(toType))) { + expressionStack.push_back(builder.create(builder.getUnknownLoc(), toType, val)); + } else if (isBoolLike(toType)) { mlir::Value zero; - if (fromType->isIntLike()) { + if (isIntLike(fromType)) { zero = builder.create( builder.getUnknownLoc(), - mlir::IntegerType::get(&context, 32, fromType->isUIntLike() ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed), - fromType->isUIntLike() ? builder.getUI32IntegerAttr(0) : builder.getSI32IntegerAttr(0) + mlir::IntegerType::get(&context, 32, isUIntLike(fromType) ? mlir::IntegerType::Unsigned : mlir::IntegerType::Signed), + isUIntLike(fromType) ? builder.getUI32IntegerAttr(0) : builder.getSI32IntegerAttr(0) ); - if (fromType->getKind() == shaderpulse::TypeKind::Vector) { + if (fromType.isa()) { std::vector operandsZero; - for (int i = 0; i < dynamic_cast(fromType)->getLength(); i++) { + for (int i = 0; i < fromType.dyn_cast().getShape()[0]; i++) { operandsZero.push_back(zero); } - zero = builder.create( - builder.getUnknownLoc(), convertShaderPulseType(&context, fromType, structDeclarations), operandsZero); + zero = builder.create(builder.getUnknownLoc(), fromType, operandsZero); } - expressionStack.push_back(std::make_pair(toType, builder.create(builder.getUnknownLoc(), val, zero))); - } else if (fromType->isFloatLike()) { + expressionStack.push_back(builder.create(builder.getUnknownLoc(), val, zero)); + } else if (isFloatLike(fromType)) { zero = builder.create( builder.getUnknownLoc(), - fromType->isF32Like() ? mlir::FloatType::getF32(&context) : mlir::FloatType::getF64(&context), - fromType->isF32Like() ? builder.getF32FloatAttr(0.0f) : builder.getF64FloatAttr(0.0) + isF32Like(fromType) ? mlir::FloatType::getF32(&context) : mlir::FloatType::getF64(&context), + isF32Like(fromType) ? builder.getF32FloatAttr(0.0f) : builder.getF64FloatAttr(0.0) ); - if (fromType->getKind() == shaderpulse::TypeKind::Vector) { + if (fromType.isa()) { std::vector operandsZero; - for (int i = 0; i < dynamic_cast(fromType)->getLength(); i++) { + for (int i = 0; i < fromType.dyn_cast().getShape()[0]; i++) { operandsZero.push_back(zero); } - zero = builder.create( - builder.getUnknownLoc(), convertShaderPulseType(&context, fromType, structDeclarations), operandsZero); + zero = builder.create(builder.getUnknownLoc(), fromType, operandsZero); } - expressionStack.push_back(std::make_pair(toType, builder.create(builder.getUnknownLoc(), val, zero))); + expressionStack.push_back(builder.create(builder.getUnknownLoc(), val, zero)); } } else { - expressionStack.push_back(operand); + expressionStack.push_back(val); } } void MLIRCodeGen::visit(ArrayAccessExpression *arrayAccess) { auto array = arrayAccess->getArray(); array->accept(this); - std::pair mlirArray = popExpressionStack(); - shaderpulse::Type* elementType = dynamic_cast(mlirArray.first)->getElementType(); + mlir::Value mlirArray = popExpressionStack(); std::vector indices; for (auto &access : arrayAccess->getAccessChain()) { access->accept(this); - auto val = popExpressionStack().second; - indices.push_back(load(val)); + indices.push_back(load(popExpressionStack())); } - mlir::Value accessChain = builder.create(builder.getUnknownLoc(), mlirArray.second, indices); - expressionStack.push_back(std::make_pair(elementType, accessChain)); + mlir::Value accessChain = builder.create(builder.getUnknownLoc(), mlirArray, indices); + expressionStack.push_back(accessChain); } void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) { auto baseComposite = memberAccess->getBaseComposite(); baseComposite->accept(this); - mlir::Value baseCompositeValue = popExpressionStack().second; + mlir::Value baseCompositeValue = popExpressionStack(); std::vector memberIndicesAcc; - shaderpulse::Type* memberType; if (currentBaseComposite) { for (auto &member : memberAccess->getMembers()) { @@ -830,25 +809,21 @@ void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) { currentBaseComposite = structDeclarations[structName]; } } - - memberType = memberIndexPair.second->getType(); // This is a duplicate of ArrayAccessExpression, idially we want to reuse that. } else if (auto arrayAccess = dynamic_cast(member.get())) { auto varName = dynamic_cast(arrayAccess->getArray())->getName(); auto memberIndexPair = currentBaseComposite->getMemberWithIndex(varName); memberIndicesAcc.push_back(builder.create(builder.getUnknownLoc(), mlir::IntegerType::get(&context, 32, mlir::IntegerType::Signless), builder.getI32IntegerAttr(memberIndexPair.first))); - memberType = dynamic_cast(memberIndexPair.second->getType())->getElementType(); for (auto &access : arrayAccess->getAccessChain()) { access->accept(this); - auto val = popExpressionStack().second; - memberIndicesAcc.push_back(load(val)); + memberIndicesAcc.push_back(load(popExpressionStack())); } } } mlir::Value accessChain = builder.create(builder.getUnknownLoc(), baseCompositeValue, memberIndicesAcc); - expressionStack.push_back(std::make_pair(memberType, accessChain)); + expressionStack.push_back(accessChain); } } @@ -875,7 +850,7 @@ void MLIRCodeGen::visit(IfStatement *ifStmt) { { SymbolTableScopeT varScope(symbolTable); ifStmt->getCondition()->accept(this); - condition = load(popExpressionStack().second); + condition = load(popExpressionStack()); selectionOp = builder.create(loc, spirv::SelectionControl::None); selectionOp.addMergeBlock(); @@ -927,8 +902,8 @@ void MLIRCodeGen::visit(AssignmentExpression *assignmentExp) { assignmentExp->getUnaryExpression()->accept(this); assignmentExp->getExpression()->accept(this); - mlir::Value val = load(popExpressionStack().second); - mlir::Value ptr = popExpressionStack().second; + mlir::Value val = load(popExpressionStack()); + mlir::Value ptr = popExpressionStack(); builder.create(builder.getUnknownLoc(), ptr, val); } @@ -948,7 +923,7 @@ void MLIRCodeGen::visit(CallExpression *callExp) { if (callExp->getArguments().size() > 0) { for (auto &arg : callExp->getArguments()) { arg->accept(this); - operands.push_back(load(popExpressionStack().second)); + operands.push_back(load(popExpressionStack())); } } @@ -958,8 +933,7 @@ void MLIRCodeGen::visit(CallExpression *callExp) { builder.getUnknownLoc(), calledFunc.getFunctionType().getResults(), SymbolRefAttr::get(&context, calledFunc.getSymName()), operands); - // TODO: get return type of callee - expressionStack.push_back(std::make_pair(nullptr, funcCall.getResult(0))); + expressionStack.push_back(funcCall.getResult(0)); } else { assert(callBuiltIn(callExp) && "Function not found"); } @@ -969,7 +943,7 @@ void MLIRCodeGen::visit(VariableExpression *varExp) { auto entry = symbolTable.lookup(varExp->getName()); if (entry.isFunctionParam) { - expressionStack.push_back(std::make_pair(entry.type, entry.value)); + expressionStack.push_back(entry.value); } else if (entry.variable) { mlir::Value val; @@ -993,7 +967,7 @@ void MLIRCodeGen::visit(VariableExpression *varExp) { } } - expressionStack.push_back(std::make_pair(entry.variable->getType(), val)); + expressionStack.push_back(val); } else { std::cout << "Unable to find variable: " << varExp->getName() << std::endl; } @@ -1005,7 +979,7 @@ void MLIRCodeGen::visit(IntegerConstantExpression *intConstExp) { builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(32, intConstExp->getVal(), true))); - expressionStack.push_back(std::make_pair(intConstExp->getType(), val)); + expressionStack.push_back(val); } void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) { @@ -1014,7 +988,7 @@ void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) { builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(32, uintConstExp->getVal(), false))); - expressionStack.push_back(std::make_pair(uintConstExp->getType(), val)); + expressionStack.push_back(val); } void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) { @@ -1023,7 +997,7 @@ void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) { builder.getUnknownLoc(), type, FloatAttr::get(type, APFloat(floatConstExp->getVal()))); - expressionStack.push_back(std::make_pair(floatConstExp->getType(), val)); + expressionStack.push_back(val); } void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) { @@ -1032,7 +1006,7 @@ void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) { builder.getUnknownLoc(), type, FloatAttr::get(type, APFloat(doubleConstExp->getVal()))); - expressionStack.push_back(std::make_pair(doubleConstExp->getType(), val)); + expressionStack.push_back(val); } void MLIRCodeGen::visit(BoolConstantExpression *boolConstExp) { @@ -1041,7 +1015,7 @@ void MLIRCodeGen::visit(BoolConstantExpression *boolConstExp) { builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(1, boolConstExp->getVal()))); - expressionStack.push_back(std::make_pair(boolConstExp->getType(), val)); + expressionStack.push_back(val); } void MLIRCodeGen::visit(ReturnStatement *returnStmt) { @@ -1051,7 +1025,7 @@ void MLIRCodeGen::visit(ReturnStatement *returnStmt) { if (expressionStack.empty()) { builder.create(builder.getUnknownLoc()); } else { - mlir::Value val = popExpressionStack().second; + mlir::Value val = popExpressionStack(); builder.create(builder.getUnknownLoc(), val); } } @@ -1137,15 +1111,12 @@ bool MLIRCodeGen::callBuiltIn(CallExpression* exp) { if (builtinFuncIt != builtInFuncMap.end()) { std::vector operands; - std::vector types; for (auto &arg : exp->getArguments()) { arg->accept(this); - auto typeValPair = popExpressionStack(); - types.push_back(typeValPair.first); - operands.push_back(load(typeValPair.second)); + operands.push_back(load(popExpressionStack())); } - expressionStack.push_back(std::make_pair(types[0], builtinFuncIt->second(context, builder, operands))); + expressionStack.push_back(builtinFuncIt->second(context, builder, operands)); return true; } else { return false; diff --git a/lib/CodeGen/TypeConversion.cpp b/lib/CodeGen/TypeConversion.cpp index a1a2389..3ed39c0 100644 --- a/lib/CodeGen/TypeConversion.cpp +++ b/lib/CodeGen/TypeConversion.cpp @@ -5,7 +5,7 @@ namespace shaderpulse { namespace codegen { -mlir::Type convertShaderPulseType(mlir::MLIRContext *ctx, Type *shaderPulseType, llvm::StringMap &structDeclarations) { +mlir::Type convertShaderPulseType(mlir::MLIRContext *ctx, shaderpulse::Type *shaderPulseType, llvm::StringMap &structDeclarations) { switch (shaderPulseType->getKind()) { case TypeKind::Void: return mlir::NoneType::get(ctx); @@ -64,7 +64,7 @@ mlir::Type convertShaderPulseType(mlir::MLIRContext *ctx, Type *shaderPulseType, } std::optional -getSpirvStorageClass(TypeQualifier *typeQualifier) { +getSpirvStorageClass(shaderpulse::TypeQualifier *typeQualifier) { if (!typeQualifier) { return std::nullopt; } @@ -110,6 +110,45 @@ getLocationFromTypeQualifier(mlir::MLIRContext *ctx, TypeQualifier *typeQualifie return std::nullopt; } +mlir::Type getElementType(mlir::Type type) { + if (type.isa() || type.isa() || type.isa()) { + auto compositeType = type.dyn_cast(); + return compositeType.getElementType(0); + } else { + return type; + } +} + +// Emulate the Like concept from shaderpulse::Type +bool isBoolLike(mlir::Type type) { + return getElementType(type).isSignlessInteger(1); +} + +bool isIntLike(mlir::Type type) { + return getElementType(type).isInteger(32); +} + +bool isSIntLike(mlir::Type type) { + return getElementType(type).isSignedInteger(32); +} + +bool isUIntLike(mlir::Type type) { + return getElementType(type).isUnsignedInteger(32); +} + +bool isFloatLike(mlir::Type type) { + auto _type = getElementType(type); + return _type.isF32() || _type.isF64(); +} + +bool isF32Like(mlir::Type type) { + return getElementType(type).isF32(); +} + +bool isF64Like(mlir::Type type) { + return getElementType(type).isF64(); +} + }; // namespace codegen }; // namespace shaderpulse diff --git a/test/CodeGen/functions.glsl b/test/CodeGen/functions.glsl index f7aaddd..3e49daf 100644 --- a/test/CodeGen/functions.glsl +++ b/test/CodeGen/functions.glsl @@ -13,5 +13,9 @@ void main() { // CHECK: %2 = spirv.Load "Function" %0 : si32 // CHECK-NEXT: %3 = spirv.Load "Function" %1 : si32 // CHECK-NEXT: %4 = spirv.FunctionCall @add(%2, %3) : (si32, si32) -> si32 - int c = add(a, b); + // CHECK-NEXT: %5 = spirv.Load "Function" %0 : si32 + // CHECK-NEXT: %6 = spirv.Load "Function" %1 : si32 + // CHECK-NEXT: %7 = spirv.FunctionCall @add(%5, %6) : (si32, si32) -> si32 + // CHECK-NEXT: %8 = spirv.IMul %4, %7 : si32 + int c = add(a, b) * add(a, b); }