Skip to content

Commit

Permalink
Custom config for dequantized matvec
Browse files Browse the repository at this point in the history
  • Loading branch information
Max191 authored and github-actions[bot] committed Oct 9, 2023
1 parent dff25bd commit a3093f9
Showing 1 changed file with 171 additions and 0 deletions.
171 changes: 171 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1683,6 +1683,173 @@ static LogicalResult setElementwiseGenericOpRootConfig(
tileSizes, passPipeline);
}

/// Returns true if `genericOp` looks like a matvec fused with the
/// dequantization of grouped, quantized weights.
///
/// The op matches only when all of the following hold:
///   1. It has exactly 4 inputs (rhs, weights, scales, zero points) and
///      exactly 1 output.
///   2. It has exactly 4 loops, exactly 2 of which are reduction loops.
///   3. Walking back from operand 0 of linalg.yield, the body contains the
///      following chain of ops (listed here in program order):
///        arith.extui
///        arith.uitofp
///        arith.subf
///        arith.mulf
///        arith.mulf
///        arith.addf
///   4. Input #1 (the weights) and the output are both shaped types with
///      int-or-float element types, and the input element type is strictly
///      narrower than the output element type.
static bool isGroupedDequantizationMatvecOp(linalg::GenericOp genericOp) {
  // Exactly one result and four inputs are required.
  if (genericOp.getNumDpsInits() != 1) {
    LLVM_DEBUG(KD_DBGS() << "Wrong number of outputs: "
                         << genericOp.getNumDpsInits() << "\n");
    return false;
  }
  if (genericOp.getNumDpsInputs() != 4) {
    LLVM_DEBUG(KD_DBGS() << "Wrong number of inputs: "
                         << genericOp.getNumDpsInputs() << "\n");
    return false;
  }

  // The iteration space must have 4 loops, 2 of them reductions (and hence 2
  // parallel). Only the counts are checked here, not the loop positions.
  unsigned numLoops = genericOp.getNumLoops();
  if (numLoops != 4) {
    LLVM_DEBUG(KD_DBGS() << "Wrong number of loops: " << numLoops << "\n");
    return false;
  }
  unsigned numReductionLoops = genericOp.getNumReductionLoops();
  if (numReductionLoops != 2) {
    LLVM_DEBUG(KD_DBGS() << "Wrong number of reduction loops: "
                         << numReductionLoops << "\n");
    return false;
  }

  // Work back from linalg.yield and check the body of the genericOp. Each
  // step follows one fixed operand of the previous op; getDefiningOp<OpTy>()
  // yields a null op when the value is a block argument or is produced by a
  // different kind of op, in which case the pattern does not match.
  // NOTE(review): only one operand order is accepted for the commutative
  // addf/mulf ops — confirm the producer of this IR always emits this order.
  auto yieldOp = cast<linalg::YieldOp>(genericOp.getBody()->getTerminator());

  // linalg.yield <- arith.addf
  auto addOp = yieldOp->getOperand(0).getDefiningOp<arith::AddFOp>();
  if (!addOp)
    return false;

  // arith.addf (operand 0) <- arith.mulf
  auto outerMulOp = addOp->getOperand(0).getDefiningOp<arith::MulFOp>();
  if (!outerMulOp)
    return false;

  // arith.mulf (operand 1) <- arith.mulf
  auto innerMulOp = outerMulOp->getOperand(1).getDefiningOp<arith::MulFOp>();
  if (!innerMulOp)
    return false;

  // arith.mulf (operand 0) <- arith.subf
  auto subOp = innerMulOp->getOperand(0).getDefiningOp<arith::SubFOp>();
  if (!subOp)
    return false;

  // arith.subf (operand 0) <- arith.uitofp
  auto uiToFpOp = subOp->getOperand(0).getDefiningOp<arith::UIToFPOp>();
  if (!uiToFpOp)
    return false;

  // arith.uitofp (operand 0) <- arith.extui
  if (!uiToFpOp->getOperand(0).getDefiningOp<arith::ExtUIOp>())
    return false;

  // Ensure that the dequantization increases the bitwidth from the weights
  // input to the output. dyn_cast (rather than cast) is used so that a
  // non-shaped operand, e.g. a scalar input, is rejected instead of
  // triggering an assertion inside this matcher.
  auto outputType =
      llvm::dyn_cast<ShapedType>(genericOp.getOutputs()[0].getType());
  auto weightsType =
      llvm::dyn_cast<ShapedType>(genericOp.getInputs()[1].getType());
  if (!outputType || !weightsType)
    return false;

  Type elementTypeOut = outputType.getElementType();
  if (!elementTypeOut.isIntOrFloat())
    return false;
  Type elementTypeIn = weightsType.getElementType();
  if (!elementTypeIn.isIntOrFloat())
    return false;

  unsigned bitWidthOut = elementTypeOut.getIntOrFloatBitWidth();
  unsigned bitWidthIn = elementTypeIn.getIntOrFloatBitWidth();
  return bitWidthIn < bitWidthOut;
}

/// Tries to set the root lowering configuration for a linalg.generic op that
/// isGroupedDequantizationMatvecOp recognizes as a dequantization fused with
/// a matvec.
///
/// When the op matches, four tile-size levels are attached (distribution,
/// parallel, reduction, and a trailing all-zero level with one entry per
/// loop) and the CPUDoubleTilingExpert pipeline is selected; the result of
/// setOpConfigAndEntryPointFnTranslation is returned. When the op does not
/// match, failure() is returned before any configuration is set.
///
/// `linalgOpInfo` and `targetMLTransInfo` are not read by this function.
/// Asserts that `genericOp` has no lowering config yet.
static LogicalResult setDequantizationMatvecOpRootConfig(
    func::FuncOp entryPointFn, linalg::GenericOp genericOp,
    const LinalgOpInfo &linalgOpInfo,
    const TargetMLTransformInfo &targetMLTransInfo) {
  assert(!getLoweringConfig(genericOp) &&
         "expected lowering_config is not set");
  if (!isGroupedDequantizationMatvecOp(genericOp)) {
    LLVM_DEBUG(KD_DBGS() << "Failed matching for dequantized matvec\n");
    return failure();
  }

  // Fixed tile sizes, one entry per loop (the matcher above requires exactly
  // 4 loops).
  // NOTE(review): these sizes treat loops 0-1 as the parallel loops and loops
  // 2-3 as the reduction loops, but the matcher checks only the loop counts,
  // not their positions — confirm against the ops this is expected to match.
  SmallVector<int64_t> distTileSizes = {32, 32, 0, 0};
  SmallVector<int64_t> parallelTileSizes = {1, 1, 0, 0};
  SmallVector<int64_t> reductionTileSizes = {0, 0, 1, 64};

  TileSizesListType tileSizes;
  tileSizes.push_back(distTileSizes);
  tileSizes.push_back(parallelTileSizes);
  tileSizes.push_back(reductionTileSizes);
  // Trailing level: all zeros, one entry per loop.
  tileSizes.emplace_back(genericOp.getNumLoops(), 0);

  LLVM_DEBUG(KD_DBGS() << "Setting dequantized matmul config\n");
  LLVM_DEBUG(KD_DBGS() << "Distribution tile sizes: " << distTileSizes << "\n");
  LLVM_DEBUG(KD_DBGS() << "Parallel tile sizes: " << parallelTileSizes << "\n");
  LLVM_DEBUG(KD_DBGS() << "Reduction tile size: " << reductionTileSizes
                       << "\n");

  DispatchLoweringPassPipeline passPipeline =
      DispatchLoweringPassPipeline::CPUDoubleTilingExpert;

  return setOpConfigAndEntryPointFnTranslation(entryPointFn, genericOp,
                                               tileSizes, passPipeline);
}

/// Sets the lowering configuration for a generic op to use
/// CPUDoubleTilingExpert pipeline.
static LogicalResult
Expand All @@ -1705,6 +1872,10 @@ setRootConfig(func::FuncOp entryPointFn, linalg::GenericOp genericOp,
entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo))) {
return success();
}
if (succeeded(setDequantizationMatvecOpRootConfig(
entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo))) {
return success();
}
if (succeeded(setDefaultGenericOpRootConfig(
entryPointFn, genericOp, linalgOpInfo, targetMLTransInfo))) {
return success();
Expand Down

0 comments on commit a3093f9

Please sign in to comment.