diff --git a/include/AST/Types.h b/include/AST/Types.h index 1e81b48..f17c3e9 100644 --- a/include/AST/Types.h +++ b/include/AST/Types.h @@ -128,6 +128,18 @@ class Type { return other.kind == kind; } + virtual bool isIntLike() { + return kind == TypeKind::Integer || kind == TypeKind::UnsignedInteger; + } + + virtual bool isUintLike() { + return kind == TypeKind::UnsignedInteger; + } + + virtual bool isFloatLike() { + return kind == TypeKind::Float || kind == TypeKind::Double; + } + virtual std::string toString() { switch (kind) { case TypeKind::Integer: @@ -217,6 +229,18 @@ class VectorType : public Type { return false; } + bool isIntLike() override { + return elementType->getKind() == TypeKind::Integer || elementType->getKind() == TypeKind::UnsignedInteger; + } + + bool isUintLike() override { + return elementType->getKind() == TypeKind::UnsignedInteger; + } + + bool isFloatLike() override { + return elementType->getKind() == TypeKind::Float || elementType->getKind() == TypeKind::Double; + } + std::string toString() override { std::string prefix = ""; diff --git a/include/CodeGen/MLIRCodeGen.h b/include/CodeGen/MLIRCodeGen.h index cbbffc9..b6c6e50 100644 --- a/include/CodeGen/MLIRCodeGen.h +++ b/include/CodeGen/MLIRCodeGen.h @@ -97,6 +97,7 @@ class MLIRCodeGen : public ASTVisitor { mlir::Value popExpressionStack(); mlir::Value currentBasePointer; + Type* typeContext; }; }; // namespace codegen diff --git a/lib/CodeGen/MLIRCodeGen.cpp b/lib/CodeGen/MLIRCodeGen.cpp index 0c91743..987cf8a 100644 --- a/lib/CodeGen/MLIRCodeGen.cpp +++ b/lib/CodeGen/MLIRCodeGen.cpp @@ -77,23 +77,49 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) { switch (binExp->getOp()) { case BinaryOperator::Add: - val = builder.create(loc, lhs, rhs); + if (typeContext->isIntLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } expressionStack.push_back(val); break; case BinaryOperator::Sub: - val = builder.create(loc, lhs, rhs); + if (typeContext->isIntLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } + expressionStack.push_back(val); break; case BinaryOperator::Mul: - val = builder.create(loc, lhs, rhs); + if (typeContext->isIntLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } + expressionStack.push_back(val); break; case BinaryOperator::Div: - val = builder.create(loc, lhs, rhs); + if (typeContext->isUintLike()) { + val = builder.create(loc, lhs, rhs); + } else if (typeContext->isIntLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } + expressionStack.push_back(val); break; case BinaryOperator::Mod: - val = builder.create(loc, lhs, rhs); + if (typeContext->isIntLike()) { + val = builder.create(loc, lhs, rhs); + } else { + val = builder.create(loc, lhs, rhs); + } + expressionStack.push_back(val); break; case BinaryOperator::ShiftLeft: @@ -240,6 +266,7 @@ void MLIRCodeGen::visit(VariableDeclarationList *varDeclList) { void MLIRCodeGen::createVariable(shaderpulse::Type *type, VariableDeclaration *varDecl) { shaderpulse::Type *varType = (type) ? type : varDecl->getType(); + typeContext = varType; if (inGlobalScope) { std::cout << "In global scope" << std::endl;