Skip to content

Commit

Permalink
CodeGen: chose binary operator based on variable type
Browse files Browse the repository at this point in the history
  • Loading branch information
wpmed92 committed Aug 14, 2024
1 parent f99685d commit 5a0cb1b
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 5 deletions.
24 changes: 24 additions & 0 deletions include/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,18 @@ class Type {
return other.kind == kind;
}

virtual bool isIntLike() {
return kind == TypeKind::Integer || kind == TypeKind::UnsignedInteger;
}

virtual bool isUintLike() {
return kind == TypeKind::UnsignedInteger;
}

virtual bool isFloatLike() {
return kind == TypeKind::Float || kind == TypeKind::Double;
}

virtual std::string toString() {
switch (kind) {
case TypeKind::Integer:
Expand Down Expand Up @@ -217,6 +229,18 @@ class VectorType : public Type {
return false;
}

bool isIntLike() override {
return elementType->getKind() == TypeKind::Integer || elementType->getKind() == TypeKind::UnsignedInteger;
}

bool isUintLike() override {
return elementType->getKind() == TypeKind::UnsignedInteger;
}

bool isFloatLike() override {
return elementType->getKind() == TypeKind::Float || elementType->getKind() == TypeKind::Double;
}

std::string toString() override {
std::string prefix = "";

Expand Down
1 change: 1 addition & 0 deletions include/CodeGen/MLIRCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class MLIRCodeGen : public ASTVisitor {

mlir::Value popExpressionStack();
mlir::Value currentBasePointer;
Type* typeContext;
};

}; // namespace codegen
Expand Down
37 changes: 32 additions & 5 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,49 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) {

switch (binExp->getOp()) {
case BinaryOperator::Add:
val = builder.create<spirv::FAddOp>(loc, lhs, rhs);
if (typeContext->isIntLike()) {
val = builder.create<spirv::IAddOp>(loc, lhs, rhs);
} else {
val = builder.create<spirv::FAddOp>(loc, lhs, rhs);
}
expressionStack.push_back(val);
break;
case BinaryOperator::Sub:
val = builder.create<spirv::FSubOp>(loc, lhs, rhs);
if (typeContext->isIntLike()) {
val = builder.create<spirv::ISubOp>(loc, lhs, rhs);
} else {
val = builder.create<spirv::FSubOp>(loc, lhs, rhs);
}

expressionStack.push_back(val);
break;
case BinaryOperator::Mul:
val = builder.create<spirv::FMulOp>(loc, lhs, rhs);
if (typeContext->isIntLike()) {
val = builder.create<spirv::IMulOp>(loc, lhs, rhs);
} else {
val = builder.create<spirv::FMulOp>(loc, lhs, rhs);
}

expressionStack.push_back(val);
break;
case BinaryOperator::Div:
val = builder.create<spirv::FDivOp>(loc, lhs, rhs);
if (typeContext->isUintLike()) {
val = builder.create<spirv::UDivOp>(loc, lhs, rhs);
} else if (typeContext->isIntLike()) {
val = builder.create<spirv::SDivOp>(loc, lhs, rhs);
} else {
val = builder.create<spirv::FDivOp>(loc, lhs, rhs);
}

expressionStack.push_back(val);
break;
case BinaryOperator::Mod:
val = builder.create<spirv::FRemOp>(loc, lhs, rhs);
if (typeContext->isIntLike()) {
val = builder.create<spirv::SRemOp>(loc, lhs, rhs);
} else {
val = builder.create<spirv::FRemOp>(loc, lhs, rhs);
}

expressionStack.push_back(val);
break;
case BinaryOperator::ShiftLeft:
Expand Down Expand Up @@ -240,6 +266,7 @@ void MLIRCodeGen::visit(VariableDeclarationList *varDeclList) {
void MLIRCodeGen::createVariable(shaderpulse::Type *type,
VariableDeclaration *varDecl) {
shaderpulse::Type *varType = (type) ? type : varDecl->getType();
typeContext = varType;

if (inGlobalScope) {
std::cout << "In global scope" << std::endl;
Expand Down

0 comments on commit 5a0cb1b

Please sign in to comment.