Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support arrays in structs #24

Merged
merged 4 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/CodeGen/MLIRCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class MLIRCodeGen : public ASTVisitor {
void declare(StringRef name, SymbolTableEntry entry);
void createVariable(shaderpulse::Type *, VariableDeclaration *);
void insertEntryPoint();
mlir::Value load(mlir::Value);

std::pair<Type*, Value> popExpressionStack();
mlir::Value currentBasePointer;
Expand Down
136 changes: 70 additions & 66 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void MLIRCodeGen::visit(TranslationUnit *unit) {
insertEntryPoint();
}

std::pair<Type*, Value> MLIRCodeGen::popExpressionStack() {
std::pair<Type*, mlir::Value> MLIRCodeGen::popExpressionStack() {
auto val = expressionStack.back();
expressionStack.pop_back();
return val;
Expand All @@ -64,16 +64,16 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) {
binExp->getLhs()->accept(this);
binExp->getRhs()->accept(this);

std::pair<Type*, Value> rhsPair = popExpressionStack();
std::pair<Type*, Value> lhsPair = popExpressionStack();
std::pair<Type*, mlir::Value> rhsPair = popExpressionStack();
std::pair<Type*, mlir::Value> lhsPair = popExpressionStack();

Value rhs = rhsPair.second;
Value lhs = lhsPair.second;
mlir::Value rhs = load(rhsPair.second);
mlir::Value lhs = load(lhsPair.second);
Type* typeContext = lhsPair.first;

// TODO: implement source location
auto loc = builder.getUnknownLoc();
Value val;
mlir::Value val;

switch (binExp->getOp()) {
case BinaryOperator::Add:
Expand Down Expand Up @@ -217,11 +217,11 @@ void MLIRCodeGen::visit(ConditionalExpression *condExp) {
condExp->getTruePart()->accept(this);
condExp->getCondition()->accept(this);

std::pair<Type*, Value> condition = popExpressionStack();
std::pair<Type*, Value> truePart = popExpressionStack();
std::pair<Type*, Value> falsePart = popExpressionStack();
std::pair<Type*, mlir::Value> condition = popExpressionStack();
std::pair<Type*, mlir::Value> truePart = popExpressionStack();
std::pair<Type*, mlir::Value> falsePart = popExpressionStack();

Value res = builder.create<spirv::SelectOp>(
mlir::Value res = builder.create<spirv::SelectOp>(
builder.getUnknownLoc(),
convertShaderPulseType(&context, truePart.first, structDeclarations),
condition.second,
Expand All @@ -241,7 +241,7 @@ void MLIRCodeGen::visit(InitializerExpression *initExp) {
* is implemented.
*/
auto type = builder.getIntegerType(32, true);
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
IntegerAttr::get(type, APInt(32, 0, true)));

Expand All @@ -250,10 +250,12 @@ void MLIRCodeGen::visit(InitializerExpression *initExp) {

void MLIRCodeGen::visit(UnaryExpression *unExp) {
unExp->getExpression()->accept(this);
std::pair<Type*, Value> rhs = popExpressionStack();
std::pair<Type*, mlir::Value> rhsPair = popExpressionStack();

auto loc = builder.getUnknownLoc();
Value val;
mlir::Value rhs = load(rhsPair.second);
mlir::Value result;
Type* rhsType = rhsPair.first;

switch (unExp->getOp()) {
case UnaryOperator::Inc:
Expand All @@ -263,23 +265,24 @@ void MLIRCodeGen::visit(UnaryExpression *unExp) {
// TODO: implement post-, pre-fix decrement
break;
case UnaryOperator::Plus:
expressionStack.push_back(rhs);
expressionStack.push_back(std::make_pair(rhsPair.first, rhs));
break;
case UnaryOperator::Dash:
if (rhs.first->isFloatLike()) {
val = builder.create<spirv::FNegateOp>(loc, rhs.second);
if (rhsType->isFloatLike()) {
result = builder.create<spirv::FNegateOp>(loc, rhs);
} else {
val = builder.create<spirv::SNegateOp>(loc, rhs.second);
result = builder.create<spirv::SNegateOp>(loc, rhs);
}
expressionStack.push_back(std::make_pair(rhs.first, val));

expressionStack.push_back(std::make_pair(rhsType, result));
break;
case UnaryOperator::Bang:
val = builder.create<spirv::LogicalNotOp>(loc, rhs.second);
expressionStack.push_back(std::make_pair(rhs.first, val));
result = builder.create<spirv::LogicalNotOp>(loc, rhs);
expressionStack.push_back(std::make_pair(rhsType, result));
break;
case UnaryOperator::Tilde:
val = builder.create<spirv::NotOp>(loc, rhs.second);
expressionStack.push_back(std::make_pair(rhs.first, val));
result = builder.create<spirv::NotOp>(loc, rhs);
expressionStack.push_back(std::make_pair(rhsType, result));
break;
}
}
Expand Down Expand Up @@ -320,7 +323,7 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type,
Operation *initializerOp = nullptr;

if (expressionStack.size() > 0) {
Value val = popExpressionStack().second;
mlir::Value val = popExpressionStack().second;
initializerOp = val.getDefiningOp();
}

Expand All @@ -346,10 +349,10 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type,
varDecl->getInitialzerExpression()->accept(this);
}

Value val;
mlir::Value val;

if (expressionStack.size() > 0) {
val = popExpressionStack().second;
val = load(popExpressionStack().second);
}

spirv::PointerType ptrType = spirv::PointerType::get(
Expand Down Expand Up @@ -404,7 +407,7 @@ void MLIRCodeGen::visit(WhileStatement *whileStmt) {

Block *merge = loopOp.getMergeBlock();
builder.create<spirv::BranchConditionalOp>(
loc, conditionOp, body, ArrayRef<Value>(), merge, ArrayRef<Value>());
loc, conditionOp, body, ArrayRef<mlir::Value>(), merge, ArrayRef<mlir::Value>());

// Emit the continue/latch block.
Block *continueBlock = loopOp.getContinueBlock();
Expand Down Expand Up @@ -484,45 +487,31 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) {
void MLIRCodeGen::visit(ArrayAccessExpression *arrayAccess) {
auto array = arrayAccess->getArray();
array->accept(this);
std::pair<Type*, Value> mlirArray = popExpressionStack();
std::pair<Type*, mlir::Value> mlirArray = popExpressionStack();
Type* elementType = dynamic_cast<ArrayType*>(mlirArray.first)->getElementType();
std::vector<mlir::Value> indices;

for (auto &access : arrayAccess->getAccessChain()) {
access->accept(this);
auto val = popExpressionStack().second;

// If it's a variable index, load it first
if (val.getType().isa<spirv::PointerType>()) {
auto loadedIdx = builder.create<spirv::LoadOp>(builder.getUnknownLoc(), val);
indices.push_back(loadedIdx);
} else {
indices.push_back(val);
}
indices.push_back(load(val));
}

Value accessChain = builder.create<spirv::AccessChainOp>(builder.getUnknownLoc(), mlirArray.second, indices);

if (arrayAccess->isLhs()) {
expressionStack.push_back(std::make_pair(elementType, accessChain));
} else {
expressionStack.push_back(std::make_pair(elementType, builder.create<spirv::LoadOp>(builder.getUnknownLoc(), accessChain)->getResult(0)));
}
mlir::Value accessChain = builder.create<spirv::AccessChainOp>(builder.getUnknownLoc(), mlirArray.second, indices);
expressionStack.push_back(std::make_pair(elementType, accessChain));
}

void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) {
auto baseComposite = memberAccess->getBaseComposite();
baseComposite->accept(this);
Value baseCompositeValue = popExpressionStack().second;
std::vector<int> memberIndices;
mlir::Value baseCompositeValue = popExpressionStack().second;
std::vector<mlir::Value> memberIndicesAcc;
Type* memberType;

if (currentBaseComposite) {
for (auto &member : memberAccess->getMembers()) {
if (auto var = dynamic_cast<VariableExpression*>(member.get())) {
auto memberIndexPair = currentBaseComposite->getMemberWithIndex(var->getName());
memberIndices.push_back(memberIndexPair.first);
memberIndicesAcc.push_back(builder.create<spirv::ConstantOp>(builder.getUnknownLoc(), mlir::IntegerType::get(&context, 32, mlir::IntegerType::Signless), builder.getI32IntegerAttr(memberIndexPair.first)));

if (memberIndexPair.second->getType()->getKind() == TypeKind::Struct) {
Expand All @@ -534,16 +523,23 @@ void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) {
}

memberType = memberIndexPair.second->getType();
// This is a duplicate of ArrayAccessExpression, idially we want to reuse that.
} else if (auto arrayAccess = dynamic_cast<ArrayAccessExpression*>(member.get())) {
auto varName = dynamic_cast<VariableExpression*>(arrayAccess->getArray())->getName();
auto memberIndexPair = currentBaseComposite->getMemberWithIndex(varName);
memberIndicesAcc.push_back(builder.create<spirv::ConstantOp>(builder.getUnknownLoc(), mlir::IntegerType::get(&context, 32, mlir::IntegerType::Signless), builder.getI32IntegerAttr(memberIndexPair.first)));
memberType = dynamic_cast<ArrayType*>(memberIndexPair.second->getType())->getElementType();

for (auto &access : arrayAccess->getAccessChain()) {
access->accept(this);
auto val = popExpressionStack().second;
memberIndicesAcc.push_back(load(val));
}
}
}

if (memberAccess->isLhs()) {
Value accessChain = builder.create<spirv::AccessChainOp>(builder.getUnknownLoc(), baseCompositeValue, memberIndicesAcc);
expressionStack.push_back(std::make_pair(memberType, accessChain));
} else {
Value compositeElement = builder.create<spirv::CompositeExtractOp>(builder.getUnknownLoc(), baseCompositeValue, memberIndices);
expressionStack.push_back(std::make_pair(memberType, compositeElement));
}
mlir::Value accessChain = builder.create<spirv::AccessChainOp>(builder.getUnknownLoc(), baseCompositeValue, memberIndicesAcc);
expressionStack.push_back(std::make_pair(memberType, accessChain));
}
}

Expand All @@ -563,7 +559,7 @@ void MLIRCodeGen::visit(IfStatement *ifStmt) {
auto loc = builder.getUnknownLoc();

ifStmt->getCondition()->accept(this);
Value condition = popExpressionStack().second;
mlir::Value condition = popExpressionStack().second;

auto selectionOp = builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
selectionOp.addMergeBlock();
Expand Down Expand Up @@ -599,7 +595,7 @@ void MLIRCodeGen::visit(IfStatement *ifStmt) {
builder.setInsertionPointToEnd(selectionHeaderBlock);

builder.create<spirv::BranchConditionalOp>(
loc, condition, thenBlock, ArrayRef<Value>(), elseBlock, ArrayRef<Value>());
loc, condition, thenBlock, ArrayRef<mlir::Value>(), elseBlock, ArrayRef<mlir::Value>());

builder.setInsertionPointToEnd(restoreInsertionBlock);
}
Expand All @@ -608,8 +604,8 @@ void MLIRCodeGen::visit(AssignmentExpression *assignmentExp) {
assignmentExp->getUnaryExpression()->accept(this);
assignmentExp->getExpression()->accept(this);

Value val = popExpressionStack().second;
Value ptr = popExpressionStack().second;
mlir::Value val = load(popExpressionStack().second);
mlir::Value ptr = popExpressionStack().second;

builder.create<spirv::StoreOp>(builder.getUnknownLoc(), ptr, val);
}
Expand All @@ -629,7 +625,7 @@ void MLIRCodeGen::visit(CallExpression *callExp) {
if (callExp->getArguments().size() > 0) {
for (auto &arg : callExp->getArguments()) {
arg->accept(this);
operands.push_back(popExpressionStack().second);
operands.push_back(load(popExpressionStack().second));
}
}

Expand All @@ -653,18 +649,18 @@ void MLIRCodeGen::visit(VariableExpression *varExp) {
if (entry.isFunctionParam) {
expressionStack.push_back(std::make_pair(entry.type, entry.value));
} else if (entry.variable) {
Value val;
mlir::Value val;

if (entry.isGlobal) {
auto addressOfGlobal = builder.create<mlir::spirv::AddressOfOp>(builder.getUnknownLoc(), entry.ptrType, varExp->getName());
val = varExp->isLhs() ? addressOfGlobal->getResult(0) : builder.create<spirv::LoadOp>(builder.getUnknownLoc(), addressOfGlobal)->getResult(0);
val = addressOfGlobal->getResult(0);

// If we're inside the entry point function, collect the used global variables
if (insideEntryPoint) {
interface.push_back(SymbolRefAttr::get(&context, varExp->getName()));
}
} else {
val = (varExp->isLhs() || (entry.variable->getType()->getKind() == TypeKind::Array)) ? entry.value : builder.create<spirv::LoadOp>(builder.getUnknownLoc(), entry.value);
val = entry.value;
}

if (entry.variable->getType()->getKind() == TypeKind::Struct) {
Expand All @@ -683,7 +679,7 @@ void MLIRCodeGen::visit(VariableExpression *varExp) {

void MLIRCodeGen::visit(IntegerConstantExpression *intConstExp) {
auto type = builder.getIntegerType(32, true);
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
IntegerAttr::get(type, APInt(32, intConstExp->getVal(), true)));

Expand All @@ -692,7 +688,7 @@ void MLIRCodeGen::visit(IntegerConstantExpression *intConstExp) {

void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) {
auto type = builder.getIntegerType(32, false);
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
IntegerAttr::get(type, APInt(32, uintConstExp->getVal(), false)));

Expand All @@ -701,7 +697,7 @@ void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) {

void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) {
auto type = builder.getF32Type();
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
FloatAttr::get(type, APFloat(floatConstExp->getVal())));

Expand All @@ -710,7 +706,7 @@ void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) {

void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) {
auto type = builder.getF64Type();
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
FloatAttr::get(type, APFloat(doubleConstExp->getVal())));

Expand All @@ -719,7 +715,7 @@ void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) {

void MLIRCodeGen::visit(BoolConstantExpression *boolConstExp) {
auto type = builder.getIntegerType(1);
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
IntegerAttr::get(type, APInt(1, boolConstExp->getVal())));

Expand All @@ -733,7 +729,7 @@ void MLIRCodeGen::visit(ReturnStatement *returnStmt) {
if (expressionStack.empty()) {
builder.create<spirv::ReturnOp>(builder.getUnknownLoc());
} else {
Value val = popExpressionStack().second;
mlir::Value val = popExpressionStack().second;
builder.create<spirv::ReturnValueOp>(builder.getUnknownLoc(), val);
}
}
Expand Down Expand Up @@ -806,6 +802,14 @@ void MLIRCodeGen::visit(DefaultLabel *defaultLabel) {}

void MLIRCodeGen::visit(CaseLabel *defaultLabel) {}

mlir::Value MLIRCodeGen::load(mlir::Value val) {
if (val.getType().isa<spirv::PointerType>()) {
return builder.create<spirv::LoadOp>(builder.getUnknownLoc(), val);
}

return val;
}

}; // namespace codegen

}; // namespace shaderpulse
1 change: 0 additions & 1 deletion lib/CodeGen/TypeConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ mlir::Type convertShaderPulseType(mlir::MLIRContext *ctx, Type *shaderPulseType,

for (auto &member : structDecl->getMembers()) {
auto varMember = dynamic_cast<ast::VariableDeclaration*>(member.get());
std::cout << "Converting member type: " << varMember->getIdentifierName() << std::endl;
memberTypes.push_back(convertShaderPulseType(ctx, varMember->getType(), structDeclarations));
}

Expand Down
8 changes: 4 additions & 4 deletions lib/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ std::optional<std::vector<std::unique_ptr<Expression>>> Parser::parseArrayAccess
do {
advanceToken();

if (auto access = parsePostfixExpression(/*parsingSubExpression*/ true)) {
if (auto access = parsePostfixExpression()) {
accessChain.push_back(std::move(access));
advanceToken();

Expand Down Expand Up @@ -1297,12 +1297,12 @@ std::unique_ptr<Expression> Parser::parseUnaryExpression() {

std::unique_ptr<Expression> Parser::parsePostfixExpression(bool parsingSubExpression) {
if (auto primary = parsePrimaryExpression()) {
if (parsingSubExpression) {
if (auto access = parseArrayAccess()) {
return std::make_unique<ArrayAccessExpression>(std::move(primary), std::move(*access), parsingLhsExpression);
} else if (parsingSubExpression) {
return primary;
} else if (auto members = parseMemberAccessChain()) {
return std::make_unique<MemberAccessExpression>(std::move(primary), std::move(*members), parsingLhsExpression);
} else if (auto access = parseArrayAccess()) {
return std::make_unique<ArrayAccessExpression>(std::move(primary), std::move(*access), parsingLhsExpression);
} else {
return primary;
}
Expand Down
Loading
Loading