Skip to content

Commit

Permalink
Support arrays in structs (#24)
Browse files Browse the repository at this point in the history
* Parse array in struct

* No 'isLhs' checks, load ptrs when needed

* Indexing with AccessChain. TODO: thorough testing of advanced indexing.

* Cleanup, use mlir:: perfix for Value, extend structs tests
  • Loading branch information
wpmed92 authored Aug 27, 2024
1 parent 4077b86 commit ecd5e52
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 91 deletions.
1 change: 1 addition & 0 deletions include/CodeGen/MLIRCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class MLIRCodeGen : public ASTVisitor {
void declare(StringRef name, SymbolTableEntry entry);
void createVariable(shaderpulse::Type *, VariableDeclaration *);
void insertEntryPoint();
mlir::Value load(mlir::Value);

std::pair<Type*, Value> popExpressionStack();
mlir::Value currentBasePointer;
Expand Down
136 changes: 70 additions & 66 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void MLIRCodeGen::visit(TranslationUnit *unit) {
insertEntryPoint();
}

std::pair<Type*, Value> MLIRCodeGen::popExpressionStack() {
std::pair<Type*, mlir::Value> MLIRCodeGen::popExpressionStack() {
auto val = expressionStack.back();
expressionStack.pop_back();
return val;
Expand All @@ -64,16 +64,16 @@ void MLIRCodeGen::visit(BinaryExpression *binExp) {
binExp->getLhs()->accept(this);
binExp->getRhs()->accept(this);

std::pair<Type*, Value> rhsPair = popExpressionStack();
std::pair<Type*, Value> lhsPair = popExpressionStack();
std::pair<Type*, mlir::Value> rhsPair = popExpressionStack();
std::pair<Type*, mlir::Value> lhsPair = popExpressionStack();

Value rhs = rhsPair.second;
Value lhs = lhsPair.second;
mlir::Value rhs = load(rhsPair.second);
mlir::Value lhs = load(lhsPair.second);
Type* typeContext = lhsPair.first;

// TODO: implement source location
auto loc = builder.getUnknownLoc();
Value val;
mlir::Value val;

switch (binExp->getOp()) {
case BinaryOperator::Add:
Expand Down Expand Up @@ -217,11 +217,11 @@ void MLIRCodeGen::visit(ConditionalExpression *condExp) {
condExp->getTruePart()->accept(this);
condExp->getCondition()->accept(this);

std::pair<Type*, Value> condition = popExpressionStack();
std::pair<Type*, Value> truePart = popExpressionStack();
std::pair<Type*, Value> falsePart = popExpressionStack();
std::pair<Type*, mlir::Value> condition = popExpressionStack();
std::pair<Type*, mlir::Value> truePart = popExpressionStack();
std::pair<Type*, mlir::Value> falsePart = popExpressionStack();

Value res = builder.create<spirv::SelectOp>(
mlir::Value res = builder.create<spirv::SelectOp>(
builder.getUnknownLoc(),
convertShaderPulseType(&context, truePart.first, structDeclarations),
condition.second,
Expand All @@ -241,7 +241,7 @@ void MLIRCodeGen::visit(InitializerExpression *initExp) {
* is implemented.
*/
auto type = builder.getIntegerType(32, true);
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
IntegerAttr::get(type, APInt(32, 0, true)));

Expand All @@ -250,10 +250,12 @@ void MLIRCodeGen::visit(InitializerExpression *initExp) {

void MLIRCodeGen::visit(UnaryExpression *unExp) {
unExp->getExpression()->accept(this);
std::pair<Type*, Value> rhs = popExpressionStack();
std::pair<Type*, mlir::Value> rhsPair = popExpressionStack();

auto loc = builder.getUnknownLoc();
Value val;
mlir::Value rhs = load(rhsPair.second);
mlir::Value result;
Type* rhsType = rhsPair.first;

switch (unExp->getOp()) {
case UnaryOperator::Inc:
Expand All @@ -263,23 +265,24 @@ void MLIRCodeGen::visit(UnaryExpression *unExp) {
// TODO: implement post-, pre-fix decrement
break;
case UnaryOperator::Plus:
expressionStack.push_back(rhs);
expressionStack.push_back(std::make_pair(rhsPair.first, rhs));
break;
case UnaryOperator::Dash:
if (rhs.first->isFloatLike()) {
val = builder.create<spirv::FNegateOp>(loc, rhs.second);
if (rhsType->isFloatLike()) {
result = builder.create<spirv::FNegateOp>(loc, rhs);
} else {
val = builder.create<spirv::SNegateOp>(loc, rhs.second);
result = builder.create<spirv::SNegateOp>(loc, rhs);
}
expressionStack.push_back(std::make_pair(rhs.first, val));

expressionStack.push_back(std::make_pair(rhsType, result));
break;
case UnaryOperator::Bang:
val = builder.create<spirv::LogicalNotOp>(loc, rhs.second);
expressionStack.push_back(std::make_pair(rhs.first, val));
result = builder.create<spirv::LogicalNotOp>(loc, rhs);
expressionStack.push_back(std::make_pair(rhsType, result));
break;
case UnaryOperator::Tilde:
val = builder.create<spirv::NotOp>(loc, rhs.second);
expressionStack.push_back(std::make_pair(rhs.first, val));
result = builder.create<spirv::NotOp>(loc, rhs);
expressionStack.push_back(std::make_pair(rhsType, result));
break;
}
}
Expand Down Expand Up @@ -320,7 +323,7 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type,
Operation *initializerOp = nullptr;

if (expressionStack.size() > 0) {
Value val = popExpressionStack().second;
mlir::Value val = popExpressionStack().second;
initializerOp = val.getDefiningOp();
}

Expand All @@ -346,10 +349,10 @@ void MLIRCodeGen::createVariable(shaderpulse::Type *type,
varDecl->getInitialzerExpression()->accept(this);
}

Value val;
mlir::Value val;

if (expressionStack.size() > 0) {
val = popExpressionStack().second;
val = load(popExpressionStack().second);
}

spirv::PointerType ptrType = spirv::PointerType::get(
Expand Down Expand Up @@ -404,7 +407,7 @@ void MLIRCodeGen::visit(WhileStatement *whileStmt) {

Block *merge = loopOp.getMergeBlock();
builder.create<spirv::BranchConditionalOp>(
loc, conditionOp, body, ArrayRef<Value>(), merge, ArrayRef<Value>());
loc, conditionOp, body, ArrayRef<mlir::Value>(), merge, ArrayRef<mlir::Value>());

// Emit the continue/latch block.
Block *continueBlock = loopOp.getContinueBlock();
Expand Down Expand Up @@ -484,45 +487,31 @@ void MLIRCodeGen::visit(ConstructorExpression *constructorExp) {
void MLIRCodeGen::visit(ArrayAccessExpression *arrayAccess) {
auto array = arrayAccess->getArray();
array->accept(this);
std::pair<Type*, Value> mlirArray = popExpressionStack();
std::pair<Type*, mlir::Value> mlirArray = popExpressionStack();
Type* elementType = dynamic_cast<ArrayType*>(mlirArray.first)->getElementType();
std::vector<mlir::Value> indices;

for (auto &access : arrayAccess->getAccessChain()) {
access->accept(this);
auto val = popExpressionStack().second;

// If it's a variable index, load it first
if (val.getType().isa<spirv::PointerType>()) {
auto loadedIdx = builder.create<spirv::LoadOp>(builder.getUnknownLoc(), val);
indices.push_back(loadedIdx);
} else {
indices.push_back(val);
}
indices.push_back(load(val));
}

Value accessChain = builder.create<spirv::AccessChainOp>(builder.getUnknownLoc(), mlirArray.second, indices);

if (arrayAccess->isLhs()) {
expressionStack.push_back(std::make_pair(elementType, accessChain));
} else {
expressionStack.push_back(std::make_pair(elementType, builder.create<spirv::LoadOp>(builder.getUnknownLoc(), accessChain)->getResult(0)));
}
mlir::Value accessChain = builder.create<spirv::AccessChainOp>(builder.getUnknownLoc(), mlirArray.second, indices);
expressionStack.push_back(std::make_pair(elementType, accessChain));
}

void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) {
auto baseComposite = memberAccess->getBaseComposite();
baseComposite->accept(this);
Value baseCompositeValue = popExpressionStack().second;
std::vector<int> memberIndices;
mlir::Value baseCompositeValue = popExpressionStack().second;
std::vector<mlir::Value> memberIndicesAcc;
Type* memberType;

if (currentBaseComposite) {
for (auto &member : memberAccess->getMembers()) {
if (auto var = dynamic_cast<VariableExpression*>(member.get())) {
auto memberIndexPair = currentBaseComposite->getMemberWithIndex(var->getName());
memberIndices.push_back(memberIndexPair.first);
memberIndicesAcc.push_back(builder.create<spirv::ConstantOp>(builder.getUnknownLoc(), mlir::IntegerType::get(&context, 32, mlir::IntegerType::Signless), builder.getI32IntegerAttr(memberIndexPair.first)));

if (memberIndexPair.second->getType()->getKind() == TypeKind::Struct) {
Expand All @@ -534,16 +523,23 @@ void MLIRCodeGen::visit(MemberAccessExpression *memberAccess) {
}

memberType = memberIndexPair.second->getType();
// This is a duplicate of ArrayAccessExpression, idially we want to reuse that.
} else if (auto arrayAccess = dynamic_cast<ArrayAccessExpression*>(member.get())) {
auto varName = dynamic_cast<VariableExpression*>(arrayAccess->getArray())->getName();
auto memberIndexPair = currentBaseComposite->getMemberWithIndex(varName);
memberIndicesAcc.push_back(builder.create<spirv::ConstantOp>(builder.getUnknownLoc(), mlir::IntegerType::get(&context, 32, mlir::IntegerType::Signless), builder.getI32IntegerAttr(memberIndexPair.first)));
memberType = dynamic_cast<ArrayType*>(memberIndexPair.second->getType())->getElementType();

for (auto &access : arrayAccess->getAccessChain()) {
access->accept(this);
auto val = popExpressionStack().second;
memberIndicesAcc.push_back(load(val));
}
}
}

if (memberAccess->isLhs()) {
Value accessChain = builder.create<spirv::AccessChainOp>(builder.getUnknownLoc(), baseCompositeValue, memberIndicesAcc);
expressionStack.push_back(std::make_pair(memberType, accessChain));
} else {
Value compositeElement = builder.create<spirv::CompositeExtractOp>(builder.getUnknownLoc(), baseCompositeValue, memberIndices);
expressionStack.push_back(std::make_pair(memberType, compositeElement));
}
mlir::Value accessChain = builder.create<spirv::AccessChainOp>(builder.getUnknownLoc(), baseCompositeValue, memberIndicesAcc);
expressionStack.push_back(std::make_pair(memberType, accessChain));
}
}

Expand All @@ -563,7 +559,7 @@ void MLIRCodeGen::visit(IfStatement *ifStmt) {
auto loc = builder.getUnknownLoc();

ifStmt->getCondition()->accept(this);
Value condition = popExpressionStack().second;
mlir::Value condition = popExpressionStack().second;

auto selectionOp = builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
selectionOp.addMergeBlock();
Expand Down Expand Up @@ -599,7 +595,7 @@ void MLIRCodeGen::visit(IfStatement *ifStmt) {
builder.setInsertionPointToEnd(selectionHeaderBlock);

builder.create<spirv::BranchConditionalOp>(
loc, condition, thenBlock, ArrayRef<Value>(), elseBlock, ArrayRef<Value>());
loc, condition, thenBlock, ArrayRef<mlir::Value>(), elseBlock, ArrayRef<mlir::Value>());

builder.setInsertionPointToEnd(restoreInsertionBlock);
}
Expand All @@ -608,8 +604,8 @@ void MLIRCodeGen::visit(AssignmentExpression *assignmentExp) {
assignmentExp->getUnaryExpression()->accept(this);
assignmentExp->getExpression()->accept(this);

Value val = popExpressionStack().second;
Value ptr = popExpressionStack().second;
mlir::Value val = load(popExpressionStack().second);
mlir::Value ptr = popExpressionStack().second;

builder.create<spirv::StoreOp>(builder.getUnknownLoc(), ptr, val);
}
Expand All @@ -629,7 +625,7 @@ void MLIRCodeGen::visit(CallExpression *callExp) {
if (callExp->getArguments().size() > 0) {
for (auto &arg : callExp->getArguments()) {
arg->accept(this);
operands.push_back(popExpressionStack().second);
operands.push_back(load(popExpressionStack().second));
}
}

Expand All @@ -653,18 +649,18 @@ void MLIRCodeGen::visit(VariableExpression *varExp) {
if (entry.isFunctionParam) {
expressionStack.push_back(std::make_pair(entry.type, entry.value));
} else if (entry.variable) {
Value val;
mlir::Value val;

if (entry.isGlobal) {
auto addressOfGlobal = builder.create<mlir::spirv::AddressOfOp>(builder.getUnknownLoc(), entry.ptrType, varExp->getName());
val = varExp->isLhs() ? addressOfGlobal->getResult(0) : builder.create<spirv::LoadOp>(builder.getUnknownLoc(), addressOfGlobal)->getResult(0);
val = addressOfGlobal->getResult(0);

// If we're inside the entry point function, collect the used global variables
if (insideEntryPoint) {
interface.push_back(SymbolRefAttr::get(&context, varExp->getName()));
}
} else {
val = (varExp->isLhs() || (entry.variable->getType()->getKind() == TypeKind::Array)) ? entry.value : builder.create<spirv::LoadOp>(builder.getUnknownLoc(), entry.value);
val = entry.value;
}

if (entry.variable->getType()->getKind() == TypeKind::Struct) {
Expand All @@ -683,7 +679,7 @@ void MLIRCodeGen::visit(VariableExpression *varExp) {

void MLIRCodeGen::visit(IntegerConstantExpression *intConstExp) {
auto type = builder.getIntegerType(32, true);
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
IntegerAttr::get(type, APInt(32, intConstExp->getVal(), true)));

Expand All @@ -692,7 +688,7 @@ void MLIRCodeGen::visit(IntegerConstantExpression *intConstExp) {

void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) {
auto type = builder.getIntegerType(32, false);
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
IntegerAttr::get(type, APInt(32, uintConstExp->getVal(), false)));

Expand All @@ -701,7 +697,7 @@ void MLIRCodeGen::visit(UnsignedIntegerConstantExpression *uintConstExp) {

void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) {
auto type = builder.getF32Type();
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
FloatAttr::get(type, APFloat(floatConstExp->getVal())));

Expand All @@ -710,7 +706,7 @@ void MLIRCodeGen::visit(FloatConstantExpression *floatConstExp) {

void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) {
auto type = builder.getF64Type();
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
FloatAttr::get(type, APFloat(doubleConstExp->getVal())));

Expand All @@ -719,7 +715,7 @@ void MLIRCodeGen::visit(DoubleConstantExpression *doubleConstExp) {

void MLIRCodeGen::visit(BoolConstantExpression *boolConstExp) {
auto type = builder.getIntegerType(1);
Value val = builder.create<spirv::ConstantOp>(
mlir::Value val = builder.create<spirv::ConstantOp>(
builder.getUnknownLoc(), type,
IntegerAttr::get(type, APInt(1, boolConstExp->getVal())));

Expand All @@ -733,7 +729,7 @@ void MLIRCodeGen::visit(ReturnStatement *returnStmt) {
if (expressionStack.empty()) {
builder.create<spirv::ReturnOp>(builder.getUnknownLoc());
} else {
Value val = popExpressionStack().second;
mlir::Value val = popExpressionStack().second;
builder.create<spirv::ReturnValueOp>(builder.getUnknownLoc(), val);
}
}
Expand Down Expand Up @@ -806,6 +802,14 @@ void MLIRCodeGen::visit(DefaultLabel *defaultLabel) {}

void MLIRCodeGen::visit(CaseLabel *defaultLabel) {}

mlir::Value MLIRCodeGen::load(mlir::Value val) {
if (val.getType().isa<spirv::PointerType>()) {
return builder.create<spirv::LoadOp>(builder.getUnknownLoc(), val);
}

return val;
}

}; // namespace codegen

}; // namespace shaderpulse
1 change: 0 additions & 1 deletion lib/CodeGen/TypeConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ mlir::Type convertShaderPulseType(mlir::MLIRContext *ctx, Type *shaderPulseType,

for (auto &member : structDecl->getMembers()) {
auto varMember = dynamic_cast<ast::VariableDeclaration*>(member.get());
std::cout << "Converting member type: " << varMember->getIdentifierName() << std::endl;
memberTypes.push_back(convertShaderPulseType(ctx, varMember->getType(), structDeclarations));
}

Expand Down
8 changes: 4 additions & 4 deletions lib/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ std::optional<std::vector<std::unique_ptr<Expression>>> Parser::parseArrayAccess
do {
advanceToken();

if (auto access = parsePostfixExpression(/*parsingSubExpression*/ true)) {
if (auto access = parsePostfixExpression()) {
accessChain.push_back(std::move(access));
advanceToken();

Expand Down Expand Up @@ -1297,12 +1297,12 @@ std::unique_ptr<Expression> Parser::parseUnaryExpression() {

std::unique_ptr<Expression> Parser::parsePostfixExpression(bool parsingSubExpression) {
if (auto primary = parsePrimaryExpression()) {
if (parsingSubExpression) {
if (auto access = parseArrayAccess()) {
return std::make_unique<ArrayAccessExpression>(std::move(primary), std::move(*access), parsingLhsExpression);
} else if (parsingSubExpression) {
return primary;
} else if (auto members = parseMemberAccessChain()) {
return std::make_unique<MemberAccessExpression>(std::move(primary), std::move(*members), parsingLhsExpression);
} else if (auto access = parseArrayAccess()) {
return std::make_unique<ArrayAccessExpression>(std::move(primary), std::move(*access), parsingLhsExpression);
} else {
return primary;
}
Expand Down
Loading

0 comments on commit ecd5e52

Please sign in to comment.