Skip to content

Commit

Permalink
Add CLI args for ast, codegen, use type->toString() in printer, fix v…
Browse files Browse the repository at this point in the history
…oid toString()
  • Loading branch information
wpmed92 committed Aug 7, 2024
1 parent b15ce4b commit 722fe13
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 15 deletions.
4 changes: 3 additions & 1 deletion include/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class Type {
return "mat";
case TypeKind::Struct:
return "struct";
case TypeKind::Void:
return "void";
default:
return "opaque";
}
Expand Down Expand Up @@ -351,7 +353,7 @@ class StructType : public Type {
}

std::string toString() override {
return structName;
return "struct '" + structName + "'";
}

private:
Expand Down
9 changes: 4 additions & 5 deletions lib/AST/PrinterASTVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
#include "AST/Util.h"
#include "AST/PrinterASTVisitor.h"
#include <iostream>
#include "../../utils/include/magic_enum.hpp"

namespace shaderpulse {

Expand Down Expand Up @@ -42,7 +41,7 @@ void PrinterASTVisitor::visit(TranslationUnit *unit) {


void PrinterASTVisitor::visit(VariableDeclarationList *varDeclList) {
print("|-VariableDeclarationList: type=" + std::string(magic_enum::enum_name(varDeclList->getType()->getKind())));
print("|-VariableDeclarationList: type=" + varDeclList->getType()->toString());
indent();

for (auto &var : varDeclList->getDeclarations()) {
Expand All @@ -53,7 +52,7 @@ void PrinterASTVisitor::visit(VariableDeclarationList *varDeclList) {
}

void PrinterASTVisitor::visit(VariableDeclaration *varDecl) {
auto typeName = varDecl->getType() != nullptr ? std::string(magic_enum::enum_name(varDecl->getType()->getKind())) : "";
auto typeName = varDecl->getType() != nullptr ? varDecl->getType()->toString() : "";
print("|-VariableDeclaration: name=" + varDecl->getIdentifierName() + ((typeName != "") ? (", type=" + typeName) : ""));

indent();
Expand Down Expand Up @@ -269,14 +268,14 @@ void PrinterASTVisitor::visit(DiscardStatement *discardStmt) {
}

void PrinterASTVisitor::visit(FunctionDeclaration *funcDecl) {
print("|-FunctionDeclaration: name=" + funcDecl->getName() + ", return type=" + std::string(magic_enum::enum_name(funcDecl->getReturnType()->getKind())));
print("|-FunctionDeclaration: name=" + funcDecl->getName() + ", return type=" + funcDecl->getReturnType()->toString());
indent();

print("|-Args:");
indent();

for (auto &arg : funcDecl->getParams()) {
print("name=" + arg->getName() + ", type=" + std::string(magic_enum::enum_name(arg->getType()->getKind())));
print("name=" + arg->getName() + ", type=" + arg->getType()->toString());
}

resetIndent();
Expand Down
41 changes: 32 additions & 9 deletions standalone/shaderpulse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ int main(int argc, char** argv) {
return -1;
}

bool printAST = false;
bool codeGen = true;

for (size_t i = 2; i < argc; i++) {
std::string arg = argv[i];

if (arg == "--print-ast") {
printAST = true;
} else if (arg == "--no-codegen") {
codeGen = false;
} else {
std::cout << "Unrecognized argument: '" << arg << "'." << std::endl;
return -1;
}
}


std::ifstream glslIn(argv[1]);
std::stringstream shaderCodeBuffer;
shaderCodeBuffer << glslIn.rdbuf();
Expand All @@ -59,17 +76,23 @@ int main(int argc, char** argv) {
auto &tokens = (*resp).get();
auto parser = Parser(tokens);
auto translationUnit = parser.parseTranslationUnit();
auto mlirCodeGen = std::make_unique<codegen::MLIRCodeGen>();

translationUnit->accept(mlirCodeGen.get());
mlirCodeGen->dump();

if (mlirCodeGen->verify()) {
std::cout << "SPIR-V module verified" << std::endl;

if (printAST) {
auto printer = PrinterASTVisitor();
translationUnit->accept(&printer);
}

auto checker = std::make_unique<SemanticAnalyzer>();
translationUnit->accept(checker.get());
if (codeGen) {
auto analyzer = SemanticAnalyzer();
translationUnit->accept(&analyzer);
auto mlirCodeGen = codegen::MLIRCodeGen();
translationUnit->accept(&mlirCodeGen);
mlirCodeGen.dump();

if (mlirCodeGen.verify()) {
std::cout << "SPIR-V module verified" << std::endl;
}
}

return 0;
}

0 comments on commit 722fe13

Please sign in to comment.