Skip to content

Commit

Permalink
Implement break and continue (#33)
Browse files Browse the repository at this point in the history
* WIP

* Break/continue tests

* Fix loop continue/break gate block ordering
  • Loading branch information
wpmed92 authored Sep 29, 2024
1 parent cec70c0 commit 57339ce
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 60 deletions.
5 changes: 5 additions & 0 deletions include/CodeGen/MLIRCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ class MLIRCodeGen : public ASTVisitor {
std::vector<mlir::Value> expressionStack;
StructDeclaration* currentBaseComposite = nullptr;
mlir::Operation *execModeOp = nullptr;
std::vector<mlir::spirv::VariableOp> breakStack;
std::vector<mlir::spirv::VariableOp> continueStack;
bool breakDetected = false;
bool continueDetected = false;

llvm::ScopedHashTable<llvm::StringRef, SymbolTableEntry>
symbolTable;
Expand All @@ -107,6 +111,7 @@ class MLIRCodeGen : public ASTVisitor {
bool callBuiltIn(CallExpression* exp);
void createBuiltinComputeVar(const std::string &varName, const std::string &mlirName);
void generateLoop(Statement* initStmt, Expression* conditionExpr, Expression* inductionExpr, Statement* bodyStmt);
void setBoolVar(mlir::spirv::VariableOp var, bool val);
mlir::Value load(mlir::Value);
mlir::Value popExpressionStack();
mlir::Value currentBasePointer;
Expand Down
92 changes: 85 additions & 7 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1119,11 +1119,25 @@ void MLIRCodeGen::visit(ReturnStatement *returnStmt) {
}
}

void MLIRCodeGen::visit(BreakStatement *breakStmt) {}
void MLIRCodeGen::visit(BreakStatement *breakStmt) {
setBoolVar(breakStack.back(), true);
breakDetected = true;
}

void MLIRCodeGen::visit(ContinueStatement *continueStmt) {
setBoolVar(continueStack.back(), true);
continueDetected = true;
}

void MLIRCodeGen::visit(ContinueStatement *continueStmt) {}
void MLIRCodeGen::setBoolVar(mlir::spirv::VariableOp var, bool val) {
auto type = builder.getIntegerType(1);
mlir::Value constTrue = builder.create<spirv::ConstantOp>(builder.getUnknownLoc(), type, IntegerAttr::get(type, APInt(1, val)));
builder.create<spirv::StoreOp>(builder.getUnknownLoc(), var, constTrue);
}

void MLIRCodeGen::visit(DiscardStatement *discardStmt) {}
void MLIRCodeGen::visit(DiscardStatement *discardStmt) {

}

void MLIRCodeGen::visit(FunctionDeclaration *funcDecl) {
insideEntryPoint = funcDecl->getName() == "main";
Expand Down Expand Up @@ -1200,6 +1214,17 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E
Block *restoreInsertionBlock = builder.getInsertionBlock();
SymbolTableScopeT varScope(symbolTable);

mlir::Type boolType = mlir::IntegerType::get(&context, 1, mlir::IntegerType::Signless);
spirv::PointerType ptrType = spirv::PointerType::get(boolType, mlir::spirv::StorageClass::Function);
breakStack.push_back(
builder.create<spirv::VariableOp>(
builder.getUnknownLoc(), ptrType, spirv::StorageClass::Function, nullptr)
);
continueStack.push_back(
builder.create<spirv::VariableOp>(
builder.getUnknownLoc(), ptrType, spirv::StorageClass::Function, nullptr)
);

if (initStmt) {
initStmt->accept(this);
}
Expand All @@ -1216,6 +1241,10 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E
Block *body = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), 2), body);

// Insert the continue block.
Block *continueBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), 3), continueBlock);

// Emit the entry code.
Block *entry = loopOp.getEntryBlock();
builder.setInsertionPointToEnd(entry);
Expand All @@ -1229,18 +1258,67 @@ void MLIRCodeGen::generateLoop(Statement* initStmt, Expression* conditionExpr, E
auto conditionOp = load(popExpressionStack());
builder.create<spirv::BranchConditionalOp>(loc, conditionOp, body, ArrayRef<mlir::Value>(), merge, ArrayRef<mlir::Value>());

// Emit the continue/latch block.
builder.setInsertionPointToStart(body);
bodyStmt->accept(this);

// Detect break/continue flag
int postGateBlockInsertionPoint = 2;

if (auto body = dynamic_cast<StatementList*>(bodyStmt)) {
for (auto &stmt : body->getStatements()) {
stmt->accept(this);

if (breakDetected || continueDetected) {
if (breakDetected && continueDetected) {
auto continueGate = continueStack.back();
auto breakGate = breakStack.back();
Block *breakCheckBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), breakCheckBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(continueGate), loopOp.getContinueBlock(), ArrayRef<mlir::Value>(), breakCheckBlock, ArrayRef<mlir::Value>());
Block *postGateBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock);
builder.setInsertionPointToStart(breakCheckBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(breakGate), merge, ArrayRef<mlir::Value>(), postGateBlock, ArrayRef<mlir::Value>());
builder.setInsertionPointToStart(postGateBlock);
} else if (continueDetected) {
auto continueGate = continueStack.back();
Block *postGateBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(continueGate), loopOp.getContinueBlock(), ArrayRef<mlir::Value>(), postGateBlock, ArrayRef<mlir::Value>());
builder.setInsertionPointToStart(postGateBlock);
} else if (breakDetected) {
auto breakGate = breakStack.back();
Block *postGateBlock = new Block();
loopOp.getBody().getBlocks().insert(std::next(loopOp.getBody().begin(), ++postGateBlockInsertionPoint), postGateBlock);
builder.create<spirv::BranchConditionalOp>(loc, load(breakGate), merge, ArrayRef<mlir::Value>(), postGateBlock, ArrayRef<mlir::Value>());
builder.setInsertionPointToStart(postGateBlock);
}

if (breakDetected) {
setBoolVar(breakStack.back(), false);
}

if (continueDetected) {
setBoolVar(continueStack.back(), false);
}
breakDetected = false;
continueDetected = false;
}
}
} else {
bodyStmt->accept(this);
}

builder.create<spirv::BranchOp>(loc, loopOp.getContinueBlock());
builder.setInsertionPointToEnd(loopOp.getContinueBlock());

if (inductionExpr) {
inductionExpr->accept(this);
}

Block *continueBlock = loopOp.getContinueBlock();
builder.setInsertionPointToEnd(continueBlock);
builder.create<spirv::BranchOp>(loc, header);
builder.setInsertionPointToEnd(restoreInsertionBlock);
breakStack.pop_back();
continueStack.pop_back();
}

mlir::Value MLIRCodeGen::load(mlir::Value val) {
Expand Down
47 changes: 20 additions & 27 deletions test/CodeGen/cf_loops_for.glsl
Original file line number Diff line number Diff line change
@@ -1,36 +1,29 @@
void main() {
// CHECK: %cst0_si32 = spirv.Constant 0 : si32
// CHECK-NEXT: %0 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: spirv.Store "Function" %0, %cst0_si32 : si32
// CHECK-NEXT: spirv.mlir.loop {
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb2
// CHECK: %2 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK: spirv.mlir.loop {
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb3
// CHECK-NEXT: %cst10_si32 = spirv.Constant 10 : si32
// CHECK-NEXT: %2 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %3 = spirv.SLessThan %2, %cst10_si32 : si32
// CHECK-NEXT: spirv.BranchConditional %3, ^bb2, ^bb3
// CHECK-NEXT: ^bb2: // pred: ^bb1
// CHECK-NEXT: %3 = spirv.Load "Function" %2 : si32
// CHECK-NEXT: %4 = spirv.SLessThan %3, %cst10_si32 : si32
// CHECK-NEXT: spirv.BranchConditional %4, ^bb2, ^bb4
// CHECK-NEXT:^bb2: // pred: ^bb1
// CHECK-NEXT: %cst1_si32 = spirv.Constant 1 : si32
// CHECK-NEXT: %4 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %5 = spirv.IAdd %4, %cst1_si32 : si32
// CHECK-NEXT: %6 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: spirv.Store "Function" %6, %5 : si32
// CHECK-NEXT: %7 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %cst1_si32_1 = spirv.Constant 1 : si32
// CHECK-NEXT: %8 = spirv.IAdd %7, %cst1_si32_1 : si32
// CHECK-NEXT: spirv.Store "Function" %0, %8 : si32
// CHECK-NEXT: %5 = spirv.Load "Function" %2 : si32
// CHECK-NEXT: %6 = spirv.IAdd %5, %cst1_si32 : si32
// CHECK-NEXT: %7 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: spirv.Store "Function" %7, %6 : si32
// CHECK-NEXT: spirv.Branch ^bb3
// CHECK-NEXT:^bb3: // pred: ^bb2
// CHECK-NEXT: %8 = spirv.Load "Function" %2 : si32
// CHECK-NEXT: %cst1_si32_0 = spirv.Constant 1 : si32
// CHECK-NEXT: %9 = spirv.IAdd %8, %cst1_si32_0 : si32
// CHECK-NEXT: spirv.Store "Function" %2, %9 : si32
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb3: // pred: ^bb1
// CHECK-NEXT:^bb4: // pred: ^bb1
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: }
// CHECK-NEXT:}
for (int i = 0; i < 10; ++i) {
int a = i + 1;
}

// TODO: file check embedded loops
for (int i = 0; i < 10; ++i) {
for (int j = 0; j < 20; ++j) {
int a = i + j;
}
}
}
16 changes: 9 additions & 7 deletions test/CodeGen/cf_loops_while.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@ void main() {

// CHECK: spirv.mlir.loop {
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb2
// CHECK-NEXT: %2 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %2, ^bb2, ^bb3
// CHECK-NEXT: ^bb1: // 2 preds: ^bb0, ^bb3
// CHECK-NEXT: %4 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %4, ^bb2, ^bb4
// CHECK-NEXT: ^bb2: // pred: ^bb1
while (a) {
int c = 2;
int d = 3;

// CHECK: spirv.Store "Function" %1, %7 : si32
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK: spirv.Store "Function" %1, %9 : si32
// CHECK-NEXT: spirv.Branch ^bb3
b = c + d;
}

// CHECK: ^bb3: // pred: ^bb1
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: ^bb3: // pred: ^bb2
// CHECK-NEXT: spirv.Branch ^bb1
// CHECK-NEXT: ^bb4: // pred: ^bb1
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: }
}
38 changes: 38 additions & 0 deletions test/CodeGen/cf_loops_while_break.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
void main() {
// Hidden break/continue control vars

// CHECK: %0 = spirv.Variable : !spirv.ptr<i1, Function>
// CHECK-NEXT: %1 = spirv.Variable : !spirv.ptr<i1, Function>
while (true) {
// CHECK: %cst1_si32 = spirv.Constant 1 : si32
// CHECK-NEXT: %2 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: spirv.Store "Function" %2, %cst1_si32 : si32
int someVarBefore = 1;

// CHECK: ^bb1: // pred: ^bb0
// CHECK-NEXT: %true_2 = spirv.Constant true
// CHECK-NEXT: spirv.Store "Function" %0, %true_2 : i1
if (true) {
break;
}

// CHECK: spirv.mlir.merge
// CHECK-NEXT: }
// CHECK-NEXT: %3 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %3, ^bb5, ^bb3

// CHECK: ^bb3: // pred: ^bb2
// CHECK-NEXT: %false = spirv.Constant false

// Reset break control var

// CHECK-NEXT: spirv.Store "Function" %0, %false : i1
// CHECK-NEXT: %cst1_si32_1 = spirv.Constant 1 : si32
// CHECK-NEXT: %4 = spirv.Variable : !spirv.ptr<si32, Function>
int someVarAfter = 1;
}

// CHECK: ^bb5: // 2 preds: ^bb1, ^bb2
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: }
}
44 changes: 44 additions & 0 deletions test/CodeGen/cf_loops_while_continue.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
void main() {
// Hidden break/continue control vars

// CHECK: %0 = spirv.Variable : !spirv.ptr<i1, Function>
// CHECK-NEXT: %1 = spirv.Variable : !spirv.ptr<i1, Function>
while (true) {
// CHECK: %cst1_si32 = spirv.Constant 1 : si32
// CHECK-NEXT: %2 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: spirv.Store "Function" %2, %cst1_si32 : si32
int someVarBefore = 1;

// CHECK: ^bb1: // pred: ^bb0
// CHECK-NEXT: %true_3 = spirv.Constant true
// CHECK-NEXT: spirv.Store "Function" %1, %true_3 : i1
if (true) {
continue;
// CHECK: ^bb2: // pred: ^bb0
// CHECK-NEXT: %true_4 = spirv.Constant true
// CHECK-NEXT: spirv.Store "Function" %0, %true_4 : i1
} else {
break;
}

// CHECK: spirv.mlir.merge
// CHECK-NEXT: }
// CHECK-NEXT: %3 = spirv.Load "Function" %1 : i1
// CHECK-NEXT: spirv.BranchConditional %3, ^bb5, ^bb3
// CHECK-NEXT: ^bb3: // pred: ^bb2
// CHECK-NEXT: %4 = spirv.Load "Function" %0 : i1
// CHECK-NEXT: spirv.BranchConditional %4, ^bb6, ^bb4

// Reset continue/break control vars
// CHECK: ^bb4: // pred: ^bb3
// CHECK-NEXT: %false = spirv.Constant false
// CHECK-NEXT: spirv.Store "Function" %0, %false : i1
// CHECK-NEXT: %false_1 = spirv.Constant false
// CHECK-NEXT: spirv.Store "Function" %1, %false_1 : i1
int someVarAfter = 1;
}

// CHECK: ^bb6: // 2 preds: ^bb1, ^bb3
// CHECK-NEXT: spirv.mlir.merge
// CHECK-NEXT: }
}
38 changes: 19 additions & 19 deletions test/CodeGen/scopes.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ void main() {
// CHECK: %1 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %2 = spirv.IEqual %1, %cst1_si32 : si32
if (a == 1) {
// CHECK: %5 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK: %7 = spirv.Variable : !spirv.ptr<si32, Function>
int a;

// CHECK: %cst2_si32_1 = spirv.Constant 2 : si32
// CHECK-NEXT: spirv.Store "Function" %5, %cst2_si32_1 : si32
// CHECK-NEXT: spirv.Store "Function" %7, %cst2_si32_1 : si32
a = 2;
} else {
// CHECK: %6 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK: %8 = spirv.Variable : !spirv.ptr<si32, Function>
int a;

// CHECK: %cst3_si32 = spirv.Constant 3 : si32
// CHECK-NEXT: spirv.Store "Function" %6, %cst3_si32 : si32
// CHECK-NEXT: spirv.Store "Function" %8, %cst3_si32 : si32
a = 3;
}

// CHECK: %cst2_si32 = spirv.Constant 2 : si32
// CHECK-NEXT: spirv.Store "Function" %0, %cst2_si32 : si32
a = 2;
Expand All @@ -36,14 +36,15 @@ void main() {
*
*/

// CHECK: %5 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %6 = spirv.IEqual %5, %cst1_si32_1 : si32
// CHECK: %7 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %8 = spirv.IEqual %7, %cst1_si32_1 : si32
while (a == 1) {
// CHECK: %7 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK: %9 = spirv.Variable : !spirv.ptr<si32, Function>
int a;


// CHECK: %cst5_si32 = spirv.Constant 5 : si32
// CHECK-NEXT: spirv.Store "Function" %7, %cst5_si32 : si32
// CHECK-NEXT: spirv.Store "Function" %9, %cst5_si32 : si32
a = 5;
}

Expand All @@ -57,22 +58,21 @@ void main() {
*
*/

// CHECK: %3 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %4 = spirv.IEqual %3, %cst1_si32_0 : si32
// CHECK: %5 = spirv.Load "Function" %0 : si32
// CHECK-NEXT: %6 = spirv.IEqual %5, %cst1_si32_0 : si32
if (a == 1) {
// CHECK: %5 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK: %7 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: %cst1_si32_1 = spirv.Constant 1 : si32
// CHECK-NEXT: spirv.Store "Function" %5, %cst1_si32_1 : si32
// CHECK-NEXT: spirv.Store "Function" %7, %cst1_si32_1 : si32
int a;
a = 1;


// CHECK: %6 = spirv.Load "Function" %5 : si32
// CHECK-NEXT: %7 = spirv.IEqual %6, %cst2_si32_2 : si32
// CHECK: %8 = spirv.Load "Function" %7 : si32
// CHECK-NEXT: %9 = spirv.IEqual %8, %cst2_si32_2 : si32
if (a == 2) {
// CHECK: %8 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK: %10 = spirv.Variable : !spirv.ptr<si32, Function>
// CHECK-NEXT: %cst2_si32_3 = spirv.Constant 2 : si32
// CHECK-NEXT: spirv.Store "Function" %8, %cst2_si32_3 : si32
// CHECK-NEXT: spirv.Store "Function" %10, %cst2_si32_3 : si32
int a;
a = 2;
}
Expand Down

0 comments on commit 57339ce

Please sign in to comment.