Skip to content

Commit

Permalink
Move vertex shader out parameters to return type for Metal
Browse files Browse the repository at this point in the history
Closes #6025
  • Loading branch information
expipiplus1 committed Feb 26, 2025
2 parents f7b9745 + 651fc8b commit 5a879ee
Show file tree
Hide file tree
Showing 5 changed files with 253 additions and 0 deletions.
26 changes: 26 additions & 0 deletions source/slang/slang-ir-legalize-varying-params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
#include "slang-ir-lower-out-parameters.h"
#include "slang-ir-util.h"
#include "slang-parameter-binding.h"

Expand Down Expand Up @@ -3925,11 +3926,36 @@ class LegalizeWGSLEntryPointContext : public LegalizeShaderEntryPointContext
const UnownedStringSlice userSemanticName = toSlice("user_semantic");
};

void legalizeVertexShaderOutputParamsForMetal(DiagnosticSink* sink, EntryPointInfo& entryPoint)
{
const auto oldFunc = entryPoint.entryPointFunc;
entryPoint.entryPointFunc = lowerOutParameters(oldFunc, sink);

// Since this will no longer be the entry point function, remove those decorations
List<IRDecoration*> ds;
for (auto decor : oldFunc->getDecorations())
{
if (as<IRKeepAliveDecoration>(decor) || as<IREntryPointDecoration>(decor))
{
ds.add(decor);
}
}

for (auto decor : ds)
{
decor->removeFromParent();
}
}

void legalizeEntryPointVaryingParamsForMetal(
IRModule* module,
DiagnosticSink* sink,
List<EntryPointInfo>& entryPoints)
{
for (auto& e : entryPoints)
{
legalizeVertexShaderOutputParamsForMetal(sink, e);
}
LegalizeMetalEntryPointContext context(module, sink);
context.legalizeEntryPoints(entryPoints);
}
Expand Down
196 changes: 196 additions & 0 deletions source/slang/slang-ir-lower-out-parameters.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
#include "slang-ir-lower-out-parameters.h"

#include "slang-ir-clone.h"
#include "slang-ir-inline.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
#include "slang-ir.h"

namespace Slang
{
IRFunc* lowerOutParameters(IRFunc* func, DiagnosticSink*)
{
IRBuilder builder(func->getModule());
IRCloneEnv cloneEnv;

// Collect types for the new function
List<IRType*> returnTypes;
List<IRType*> paramTypes;

struct VarInfo
{
IRVar* var;
IRParam* origParam;
bool isInOut;
};
List<VarInfo> outVars;

// If original function returns non-void, add it to tuple types
if (!as<IRVoidType>(func->getResultType()))
returnTypes.add(func->getResultType());

// Process parameters
for (auto param : func->getParams())
{
if (auto outType = as<IROutTypeBase>(param->getDataType()))
{
returnTypes.add(outType->getValueType());

if (outType->getOp() == kIROp_InOutType)
{
paramTypes.add(outType->getValueType());
}

outVars.add(VarInfo{nullptr, param, outType->getOp() == kIROp_InOutType});
}
else
{
paramTypes.add(param->getDataType());
}
}

// Create new function
auto newFunc = builder.createFunc();

// Copy all decorations except name hint
for (auto decor : func->getDecorations())
{
cloneDecoration(&cloneEnv, decor, newFunc, builder.getModule());
}

// Copy modifiers
// newFunc->setModifiers(func->getModifiers());

// Determine result type
IRType* resultType;
if (returnTypes.getCount() > 1)
{
resultType = builder.getTupleType(returnTypes);
}
else if (returnTypes.getCount() == 1)
{
resultType = returnTypes[0];
}
else
{
resultType = builder.getVoidType();
}

auto funcType = builder.getFuncType(paramTypes, resultType);
newFunc->setFullType(funcType);

auto firstBlock = builder.createBlock();
newFunc->addBlock(firstBlock);
builder.setInsertInto(firstBlock);

// Create parameters and track them
List<IRParam*> newParams;
for (auto param : func->getParams())
{
if (auto outType = as<IROutTypeBase>(param->getDataType()))
{
if (outType->getOp() == kIROp_InOutType)
{
auto newParam = builder.emitParam(outType->getValueType());
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
builder.addNameHintDecoration(newParam, nameHint->getName());
newParams.add(newParam);
}
}
else
{
auto newParam = builder.emitParam(param->getDataType());
if (auto nameHint = param->findDecoration<IRNameHintDecoration>())
builder.addNameHintDecoration(newParam, nameHint->getName());
newParams.add(newParam);
}
}

// Create vars for out/inout parameters
for (auto& varInfo : outVars)
{
auto outType = as<IROutTypeBase>(varInfo.origParam->getDataType());
varInfo.var = builder.emitVar(outType->getValueType());

if (varInfo.isInOut)
{
for (auto newParam : newParams)
{
if (auto nameHint = varInfo.origParam->findDecoration<IRNameHintDecoration>())
{
if (auto newNameHint = newParam->findDecoration<IRNameHintDecoration>())
{
if (nameHint->getName() == newNameHint->getName())
{
builder.emitStore(varInfo.var, newParam);
break;
}
}
}
}
}
}

// Build call to original function
List<IRInst*> args;
int newParamIndex = 0;
int outVarIndex = 0;

for (auto param : func->getParams())
{
if (auto outType = as<IROutTypeBase>(param->getDataType()))
{
args.add(outVars[outVarIndex++].var);
}
else
{
args.add(newParams[newParamIndex++]);
}
}

IRCall* callResult = builder.emitCallInst(func->getResultType(), func, args);

// If original function has only one use, inline it
int useCount = 0;
for (auto use = func->firstUse; use; use = use->nextUse)
{
useCount++;
}
if (useCount == 1)
{
inlineCall(callResult);
}

// Construct return tuple
List<IRInst*> tupleValues;

if (!as<IRVoidType>(func->getResultType()))
{
tupleValues.add(callResult);
}

for (auto& varInfo : outVars)
{
tupleValues.add(builder.emitLoad(varInfo.var));
}

IRInst* returnValue;
if (tupleValues.getCount() > 1)
{
returnValue = builder.emitMakeTuple(tupleValues);
}
else if (tupleValues.getCount() == 1)
{
returnValue = tupleValues[0];
}
else
{
returnValue = nullptr;
}

builder.emitReturn(returnValue);

return newFunc;
}

} // namespace Slang
12 changes: 12 additions & 0 deletions source/slang/slang-ir-lower-out-parameters.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include "slang-ir.h"

namespace Slang
{
struct IRModule;
class DiagnosticSink;

IRFunc* lowerOutParameters(IRFunc* func, DiagnosticSink* sink);

} // namespace Slang
3 changes: 3 additions & 0 deletions source/slang/slang-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3970,7 +3970,10 @@ static TypeCastStyle _getTypeStyleId(IRType* type)
{
return _getTypeStyleId(matrixType->getElementType());
}
// Try to simplify style if we can, otherwise just handle it unsimplified
auto style = getTypeStyle(type->getOp());
if (style == kIROp_Invalid)
style = type->getOp();
switch (style)
{
case kIROp_IntType:
Expand Down
16 changes: 16 additions & 0 deletions tests/metal/simple-vertex-position.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//TEST:SIMPLE(filecheck=METAL): -target metal -stage vertex -entry vertexMain
//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage vertex -entry vertexMain
//TEST:SIMPLE(filecheck=WGSL): -target wgsl -stage vertex -entry vertexMain
//TEST:SIMPLE(filecheck=WGSLSPIRV): -target wgsl-spirv-asm -stage vertex -entry vertexMain

//METAL: position
//METALLIB: @vertexMain

//WGSL: @builtin(position)
//WGSLSPIRV: %vertexMain = OpFunction

// Vertex Shader which writes to position
void vertexMain(out float4 position : SV_Position)
{
position = float4(0.6, 0.1, 0.6, 0.33);
}

0 comments on commit 5a879ee

Please sign in to comment.