Skip to content

Commit

Permalink
Create scope for if/else and while, add scope tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wpmed92 committed Sep 5, 2024
1 parent d9bf73b commit c577863
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 27 deletions.
2 changes: 1 addition & 1 deletion include/CodeGen/MLIRCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class MLIRCodeGen : public ASTVisitor {
llvm::ScopedHashTable<llvm::StringRef, SymbolTableEntry>
symbolTable;
using SymbolTableScopeT =
llvm::ScopedHashTableScope<StringRef, SymbolTableEntry>;
llvm::ScopedHashTableScope<llvm::StringRef, SymbolTableEntry>;

SymbolTableScopeT globalScope;
SmallVector<Attribute, 4> interface;
Expand Down
63 changes: 37 additions & 26 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,7 @@ void MLIRCodeGen::visit(UnaryExpression *unExp) {
}
}

void MLIRCodeGen::declare(StringRef name, SymbolTableEntry entry) {
if (symbolTable.count(name)) {
return;
}

void MLIRCodeGen::declare(llvm::StringRef name, SymbolTableEntry entry) {
symbolTable.insert(name, entry);
}

Expand Down Expand Up @@ -405,6 +401,7 @@ void MLIRCodeGen::visit(SwitchStatement *switchStmt) {
void MLIRCodeGen::visit(WhileStatement *whileStmt) {
Block *restoreInsertionBlock = builder.getInsertionBlock();

SymbolTableScopeT varScope(symbolTable);
whileStmt->getCondition()->accept(this);

auto conditionOp = load(popExpressionStack().second);
Expand Down Expand Up @@ -745,29 +742,41 @@ void MLIRCodeGen::visit(DoStatement *doStmt) {

void MLIRCodeGen::visit(IfStatement *ifStmt) {
Block *restoreInsertionBlock = builder.getInsertionBlock();

auto loc = builder.getUnknownLoc();
spirv::SelectionOp selectionOp;
mlir::Value condition;
Block* selectionHeaderBlock;
Block* thenBlock;
Block* mergeBlock;

// Scope for true part
{
SymbolTableScopeT varScope(symbolTable);
ifStmt->getCondition()->accept(this);
condition = load(popExpressionStack().second);
selectionOp = builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
selectionOp.addMergeBlock();

// Merge
mergeBlock = selectionOp.getMergeBlock();

// Selection header
selectionHeaderBlock = new Block();
selectionOp.getBody().getBlocks().push_front(selectionHeaderBlock);

// True part
thenBlock = new Block();
selectionOp.getBody().getBlocks().insert(std::next(selectionOp.getBody().begin(), 1), thenBlock);
builder.setInsertionPointToStart(thenBlock);

ifStmt->getTruePart()->accept(this);
builder.create<spirv::BranchOp>(loc, mergeBlock);

// If scope destroyed here
}

ifStmt->getCondition()->accept(this);
mlir::Value condition = load(popExpressionStack().second);

auto selectionOp = builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
selectionOp.addMergeBlock();

// Merge
auto *mergeBlock = selectionOp.getMergeBlock();

// Selection header
auto *selectionHeaderBlock = new Block();
selectionOp.getBody().getBlocks().push_front(selectionHeaderBlock);

// True part
auto *thenBlock = new Block();
selectionOp.getBody().getBlocks().insert(std::next(selectionOp.getBody().begin(), 1), thenBlock);
builder.setInsertionPointToStart(thenBlock);

ifStmt->getTruePart()->accept(this);
builder.create<spirv::BranchOp>(loc, mergeBlock);
// Scope for else part
SymbolTableScopeT varScope(symbolTable);

// False part
auto *elseBlock = new Block();
Expand All @@ -788,6 +797,8 @@ void MLIRCodeGen::visit(IfStatement *ifStmt) {
loc, condition, thenBlock, ArrayRef<mlir::Value>(), elseBlock, ArrayRef<mlir::Value>());

builder.setInsertionPointToEnd(restoreInsertionBlock);

// Else scope destroyed here
}

void MLIRCodeGen::visit(AssignmentExpression *assignmentExp) {
Expand Down
79 changes: 79 additions & 0 deletions test/CodeGen/scopes.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
void main() {
// CHECK: %0 = spirv.Variable : !spirv.ptr<si32, Function>
int a;

/*
*
* test new scope for 'if' and 'else' parts
*
*/

// CHECK: %1 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %2 = spirv.IEqual %1, %cst1_si32 : si32
if (a == 1) {
// CHECK: %7 = spirv.Variable : !spirv.ptr<si32, Function>
int a;

// CHECK: %cst2_si32_2 = spirv.Constant 2 : si32
// CHECK-NEXT: spirv.Store "Function" %7, %cst2_si32_2 : si32
a = 2;
} else {
// CHECK: %8 = spirv.Variable : !spirv.ptr<si32, Function>
int a;

// CHECK: %cst3_si32 = spirv.Constant 3 : si32
// CHECK-NEXT: spirv.Store "Function" %8, %cst3_si32 : si32
a = 3;
}

// CHECK: %cst2_si32 = spirv.Constant 2 : si32
// CHECK-NEXT: spirv.Store "Function" %0, %cst2_si32 : si32
a = 2;

/*
*
* test new scope for loop body
*
*/

// CHECK: %3 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %4 = spirv.IEqual %3, %cst1_si32_0 : si32
while (a == 1) {
// CHECK: %7 = spirv.Variable : !spirv.ptr<si32, Function>
int a;

// CHECK: %cst5_si32 = spirv.Constant 5 : si32
// CHECK-NEXT: spirv.Store "Function" %7, %cst5_si32 : si32
a = 5;
}

// CHECK: %cst4_si32 = spirv.Constant 4 : si32
// CHECK-NEXT: spirv.Store "Function" %0, %cst4_si32 : si32
a = 4;

/*
*
* test nested scopes
*
*/

// CHECK: %5 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %6 = spirv.IEqual %5, %cst1_si32_1 : si32
if (a == 1) {
// CHECK: %7 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: %cst1_si32_2 = spirv.Constant 1 : si32
// CHECK-NEXT: spirv.Store "Function" %7, %cst1_si32_2 : si32
int a;
a = 1;

// CHECK: %8 = spirv.Load "Function" %7 : si32
// CHECK-NEXT: %9 = spirv.IEqual %8, %cst2_si32_3 : si32
if (a == 2) {
// CHECK: %10 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: %cst2_si32_4 = spirv.Constant 2 : si32
// CHECK-NEXT: spirv.Store "Function" %10, %cst2_si32_4 : si32
int a;
a = 2;
}
}
}

0 comments on commit c577863

Please sign in to comment.