Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move vertex shader out paramters to return type for Metal #6464

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions source/slang/slang-ir-legalize-varying-params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
#include "slang-ir-lower-out-parameters.h"
#include "slang-ir-util.h"
#include "slang-parameter-binding.h"

Expand Down Expand Up @@ -3925,11 +3926,36 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext
const UnownedStringSlice userSemanticName = toSlice("user_semantic");
};

void legalizeVertexShaderOutputParamsForMetal(DiagnosticSink* sink, EntryPointInfo& entryPoint)
{
const auto oldFunc = entryPoint.entryPointFunc;
entryPoint.entryPointFunc = lowerOutParameters(oldFunc, sink);

// Since this will no longer be the entry point function, remove those decorations
List<IRDecoration*> ds;
for (auto decor : oldFunc->getDecorations())
{
if (as<IRKeepAliveDecoration>(decor) || as<IREntryPointDecoration>(decor))
{
ds.add(decor);
}
}

for (auto decor : ds)
{
decor->removeFromParent();
}
}

void legalizeEntryPointVaryingParamsForMetal(
IRModule* module,
DiagnosticSink* sink,
List<EntryPointInfo>& entryPoints)
{
for (auto& e : entryPoints)
{
legalizeVertexShaderOutputParamsForMetal(sink, e);
}
LegalizeMetalEntryPointContext context(module, sink);
context.legalizeEntryPoints(entryPoints);
}
Expand Down
193 changes: 193 additions & 0 deletions source/slang/slang-ir-lower-out-parameters.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
#include "slang-ir-lower-out-parameters.h"

#include "slang-ir-clone.h"
#include "slang-ir-inline.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
#include "slang-ir.h"

namespace Slang
{
IRFunc* lowerOutParameters(IRFunc* func, DiagnosticSink*)
{
IRBuilder builder(func->getModule());
IRCloneEnv cloneEnv;

// Collect types for the new function
List<IRType*> returnTypes;
List<IRType*> paramTypes;

struct VarInfo
{
IRVar* var;
IRParam* origParam;
bool isInOut;
};
List<VarInfo> outVars;

// If original function returns non-void, add it to tuple types
if (!as<IRVoidType>(func->getResultType()))
returnTypes.add(func->getResultType());

// Process parameters
for (auto param : func->getParams())
{
if (auto outType = as<IROutTypeBase>(param->getDataType()))
{
returnTypes.add(outType->getValueType());

if (outType->getOp() == kIROp_InOutType)
{
paramTypes.add(outType->getValueType());
}

outVars.add(VarInfo{nullptr, param, outType->getOp() == kIROp_InOutType});
}
else
{
paramTypes.add(param->getDataType());
}
}

// Create new function
auto newFunc = builder.createFunc();

// Copy all decorations except name hint
for (auto decor : func->getDecorations())
{
cloneDecoration(&cloneEnv, decor, newFunc, builder.getModule());
}

// Determine result type
IRType* resultType;
if (returnTypes.getCount() > 1)
{
resultType = builder.getTupleType(returnTypes);
}
else if (returnTypes.getCount() == 1)
{
resultType = returnTypes[0];
}
else
{
resultType = builder.getVoidType();
}

auto funcType = builder.getFuncType(paramTypes, resultType);
newFunc->setFullType(funcType);

auto firstBlock = builder.createBlock();
newFunc->addBlock(firstBlock);
builder.setInsertInto(firstBlock);

// Create parameters and track them
List<IRParam*> newParams;
for (auto param : func->getParams())
{
if (auto outType = as<IROutTypeBase>(param->getDataType()))
{
if (outType->getOp() == kIROp_InOutType)
{
auto newParam = builder.emitParam(outType->getValueType());
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
builder.addNameHintDecoration(newParam, nameHint->getName());
newParams.add(newParam);
}
}
else
{
auto newParam = builder.emitParam(param->getDataType());
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
builder.addNameHintDecoration(newParam, nameHint->getName());
newParams.add(newParam);
}
}

// Create vars for out/inout parameters
for (auto& varInfo : outVars)
{
auto outType = as<IROutTypeBase>(varInfo.origParam->getDataType());
varInfo.var = builder.emitVar(outType->getValueType());

if (varInfo.isInOut)
{
for (auto newParam : newParams)
{
if (auto nameHint = varInfo.origParam->findDecoration<IRNameHintDecoration>())
{
if (auto newNameHint = newParam->findDecoration<IRNameHintDecoration>())
{
if (nameHint->getName() == newNameHint->getName())
{
builder.emitStore(varInfo.var, newParam);
break;
}
}
}
}
}
}

// Build call to original function
List<IRInst*> args;
int newParamIndex = 0;
int outVarIndex = 0;

for (auto param : func->getParams())
{
if (as<IROutTypeBase>(param->getDataType()))
{
args.add(outVars[outVarIndex++].var);
}
else
{
args.add(newParams[newParamIndex++]);
}
}

IRCall* callResult = builder.emitCallInst(func->getResultType(), func, args);

// If original function has only one use, inline it
int useCount = 0;
for (auto use = func->firstUse; use; use = use->nextUse)
{
useCount++;
}
if (useCount == 1)
{
inlineCall(callResult);
}

// Construct return tuple
List<IRInst*> tupleValues;

if (!as<IRVoidType>(func->getResultType()))
{
tupleValues.add(callResult);
}

for (auto& varInfo : outVars)
{
tupleValues.add(builder.emitLoad(varInfo.var));
}

IRInst* returnValue;
if (tupleValues.getCount() > 1)
{
returnValue = builder.emitMakeTuple(tupleValues);
}
else if (tupleValues.getCount() == 1)
{
returnValue = tupleValues[0];
}
else
{
returnValue = nullptr;
}

builder.emitReturn(returnValue);

return newFunc;
}

} // namespace Slang
12 changes: 12 additions & 0 deletions source/slang/slang-ir-lower-out-parameters.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include "slang-ir.h"

namespace Slang
{
struct IRModule;
class DiagnosticSink;

IRFunc* lowerOutParameters(IRFunc* func, DiagnosticSink* sink);

} // namespace Slang
3 changes: 3 additions & 0 deletions source/slang/slang-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3970,7 +3970,10 @@ static TypeCastStyle _getTypeStyleId(IRType* type)
{
return _getTypeStyleId(matrixType->getElementType());
}
// Try to simplify style if we can, otherwise just handle it unsimplified
auto style = getTypeStyle(type->getOp());
if (style == kIROp_Invalid)
style = type->getOp();
switch (style)
{
case kIROp_IntType:
Expand Down
16 changes: 16 additions & 0 deletions tests/metal/simple-vertex-position.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//TEST:SIMPLE(filecheck=METAL): -target metal -stage vertex -entry vertexMain
//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage vertex -entry vertexMain
//TEST:SIMPLE(filecheck=WGSL): -target wgsl -stage vertex -entry vertexMain
//TEST:SIMPLE(filecheck=WGSLSPIRV): -target wgsl-spirv-asm -stage vertex -entry vertexMain

//METAL: position
//METALLIB: @vertexMain

//WGSL: @builtin(position)
//WGSLSPIRV: %vertexMain = OpFunction

// Vertex Shader which writes to position
void vertexMain(out float4 position : SV_Position)
{
position = float4(0.6, 0.1, 0.6, 0.33);
}
Loading