Skip to content

Commit

Permalink
Add builtin GL functions (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
wpmed92 authored Sep 10, 2024
1 parent 00ffcd4 commit 8c15b9a
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 2 deletions.
6 changes: 6 additions & 0 deletions include/CodeGen/MLIRCodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include <vector>
#include <map>
#include <utility>
#include <functional>
#include <unordered_map>

using namespace mlir;

Expand Down Expand Up @@ -90,13 +92,17 @@ class MLIRCodeGen : public ASTVisitor {
symbolTable;
using SymbolTableScopeT =
llvm::ScopedHashTableScope<llvm::StringRef, SymbolTableEntry>;
using BuiltInFunc = std::function<mlir::Value(mlir::OpBuilder &, mlir::ValueRange)>;

std::unordered_map<std::string, BuiltInFunc> builtInFuncMap;
SymbolTableScopeT globalScope;
SmallVector<Attribute, 4> interface;

void declare(StringRef name, SymbolTableEntry entry);
void createVariable(shaderpulse::Type *, VariableDeclaration *);
void insertEntryPoint();
void initBuiltinFuncMap();
bool callBuiltIn(CallExpression* exp);
mlir::Value load(mlir::Value);

std::pair<shaderpulse::Type*, Value> popExpressionStack();
Expand Down
138 changes: 136 additions & 2 deletions lib/CodeGen/MLIRCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,121 @@ namespace codegen {
MLIRCodeGen::MLIRCodeGen() : builder(&context), globalScope(symbolTable) {
context.getOrLoadDialect<spirv::SPIRVDialect>();
initModuleOp();
initBuiltinFuncMap();
}

void MLIRCodeGen::initBuiltinFuncMap() {
builtInFuncMap = {
{"acos", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLAcosOp>(builder.getUnknownLoc(), operands[0]);
}},
{"asin", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLAsinOp>(builder.getUnknownLoc(), operands[0]);
}},
{"atan", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLAtanOp>(builder.getUnknownLoc(), operands[0]);
}},
{"ceil", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLCeilOp>(builder.getUnknownLoc(), operands[0]);
}},
{"cos", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLCosOp>(builder.getUnknownLoc(), operands[0]);
}},
{"cosh", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLCoshOp>(builder.getUnknownLoc(), operands[0]);
}},
{"exp", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLExpOp>(builder.getUnknownLoc(), operands[0]);
}},
{"fabs", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLFAbsOp>(builder.getUnknownLoc(), operands[0]);
}},
{"fclamp", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLFClampOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1], operands[2]);
}},
{"fmax", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLFMaxOp>(builder.getUnknownLoc(),operands[0].getType(), operands[0], operands[1]);
}},
{"fmin", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLFMinOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1]);
}},
{"fmix", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLFMixOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1], operands[2]);
}},
{"fsign", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLFSignOp>(builder.getUnknownLoc(), operands[0]);
}},
{"findumsb", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLFindUMsbOp>(builder.getUnknownLoc(), operands[0]);
}},
{"floor", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLFloorOp>(builder.getUnknownLoc(), operands[0]);
}},
{"fma", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLFmaOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1], operands[2]);
}},
{"frexpstruct", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
// TODO: implement me
return mlir::Value();
}},
{"inversesqrt", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLInverseSqrtOp>(builder.getUnknownLoc(), operands[0]);
}},
{"ldexp", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLLdexpOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1]);
}},
{"log", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLLogOp>(builder.getUnknownLoc(), operands[0]);
}},
{"pow", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLPowOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1]);
}},
{"roundeven", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLRoundEvenOp>(builder.getUnknownLoc(), operands[0]);
}},
{"round", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLRoundOp>(builder.getUnknownLoc(), operands[0]);
}},
{"sabs", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLSAbsOp>(builder.getUnknownLoc(), operands[0]);
}},
{"sclamp", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLSClampOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1], operands[2]);
}},
{"smax", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLSMaxOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1]);
}},
{"smin", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLSMinOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1]);
}},
{"ssign", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLSSignOp>(builder.getUnknownLoc(), operands[0]);
}},
{"sin", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLSinOp>(builder.getUnknownLoc(), operands[0]);
}},
{"sinh", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLSinhOp>(builder.getUnknownLoc(), operands[0]);
}},
{"sqrt", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLSqrtOp>(builder.getUnknownLoc(), operands[0]);
}},
{"tan", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLTanOp>(builder.getUnknownLoc(), operands[0]);
}},
{"tanh", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLTanhOp>(builder.getUnknownLoc(), operands[0]);
}},
{"uclamp", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLUClampOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1], operands[2]);
}},
{"umax", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLUMaxOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1]);
}},
{"umin", [](mlir::OpBuilder &builder, mlir::ValueRange operands) {
return builder.create<spirv::GLUMinOp>(builder.getUnknownLoc(), operands[0].getType(), operands[0], operands[1]);
}}
};
}

void MLIRCodeGen::initModuleOp() {
Expand Down Expand Up @@ -839,8 +954,7 @@ void MLIRCodeGen::visit(CallExpression *callExp) {
// TODO: get return type of callee
expressionStack.push_back(std::make_pair(nullptr, funcCall.getResult(0)));
} else {
std::cout << "Function not found." << callExp->getFunctionName()
<< std::endl;
assert(callBuiltIn(callExp) && "Function not found");
}
}

Expand Down Expand Up @@ -1011,6 +1125,26 @@ mlir::Value MLIRCodeGen::load(mlir::Value val) {
return val;
}

bool MLIRCodeGen::callBuiltIn(CallExpression* exp) {
auto builtinFuncIt = builtInFuncMap.find(exp->getFunctionName());

if (builtinFuncIt != builtInFuncMap.end()) {
std::vector<mlir::Value> operands;
std::vector<shaderpulse::Type*> types;

for (auto &arg : exp->getArguments()) {
arg->accept(this);
auto typeValPair = popExpressionStack();
types.push_back(typeValPair.first);
operands.push_back(load(typeValPair.second));
}
expressionStack.push_back(std::make_pair(types[0], builtinFuncIt->second(builder, operands)));
return true;
} else {
return false;
}
}

}; // namespace codegen

}; // namespace shaderpulse
20 changes: 20 additions & 0 deletions test/CodeGen/functions_builtin.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// CHECK: spirv.func @main() "None" {
void main() {
float a = sqrt(2.0);

// CHECK: %cst_f32_0 = spirv.Constant 2.000000e+00 : f32
// CHECK-NEXT: %2 = spirv.GL.InverseSqrt %cst_f32_0 : f32
a = inversesqrt(2.0);

// CHECK: %cst_f32_1 = spirv.Constant 1.500000e+00 : f32
// CHECK-NEXT: %3 = spirv.GL.Sin %cst_f32_1 : f32
a = sin(1.5);

// CHECK: %cst_f32_2 = spirv.Constant 3.140000e+00 : f32
// CHECK-NEXT: %4 = spirv.GL.Cos %cst_f32_2 : f32
a = cos(3.14);

// CHECK: %cst_f32_3 = spirv.Constant 1.000000e+00 : f32
// CHECK-NEXT: %5 = spirv.GL.Tan %cst_f32_3 : f32
a = tan(1.0);
}

0 comments on commit 8c15b9a

Please sign in to comment.