[LinalgExt] Remove LinalgExt::Softmax and use upstream linalg::softmax (iree-org#15021)

### Improve Codegen health - Use linalg::Softmax instead of LinalgExt::Softmax

-- This commit removes LinalgExt::Softmax and uses the upstream
   linalg::Softmax op instead.
-- It therefore uses linalg::Softmax's decomposition within the
   `--iree-linalg-ext-decompose-softmax` pass (see the IR sketch below).
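A minimal IR sketch of the op-level change, assuming a hypothetical `%src` value and illustrative shapes; the assembly syntax of the two ops is identical, as the test updates in this diff show:

```mlir
// Before: the IREE-specific extension op.
%empty = tensor.empty() : tensor<2x16xf32>
%old = iree_linalg_ext.softmax dimension(1)
         ins(%src : tensor<2x16xf32>)
         outs(%empty : tensor<2x16xf32>) -> tensor<2x16xf32>

// After: the upstream op; --iree-linalg-ext-decompose-softmax now lowers it
// via upstream linalg::SoftmaxOp's decomposition.
%new = linalg.softmax dimension(1)
         ins(%src : tensor<2x16xf32>)
         outs(%empty : tensor<2x16xf32>) -> tensor<2x16xf32>
```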
   
 Signed-off-by: Abhishek Varma <[email protected]>

Abhishek-Varma authored Oct 12, 2023
1 parent b3cd60a commit 9c424c4
Showing 20 changed files with 84 additions and 267 deletions.
@@ -609,8 +609,6 @@ void registerBufferizationInterfaces(DialectRegistry &registry) {
LinalgExtOpInterface<IREE::LinalgExt::WinogradInputTransformOp>>(*ctx);
IREE::LinalgExt::WinogradOutputTransformOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::WinogradOutputTransformOp>>(*ctx);
IREE::LinalgExt::SoftmaxOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::SoftmaxOp>>(*ctx);
IREE::LinalgExt::AttentionOp::attachInterface<
LinalgExtOpInterface<IREE::LinalgExt::AttentionOp>>(*ctx);
});
@@ -260,8 +260,6 @@ void registerPartitionableLoopsInterfaceModels(DialectRegistry &registry) {
IREE::LinalgExt::WinogradOutputTransformOp::attachInterface<
AllParallelAsPartitionableLoops<
IREE::LinalgExt::WinogradOutputTransformOp>>(*ctx);
IREE::LinalgExt::SoftmaxOp::attachInterface<
AllParallelAsPartitionableLoops<IREE::LinalgExt::SoftmaxOp>>(*ctx);
IREE::LinalgExt::AttentionOp::attachInterface<
AllParallelAsPartitionableLoops<IREE::LinalgExt::AttentionOp>>(*ctx);
});
@@ -215,7 +215,7 @@ hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> {
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<12x128x40960xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<12x128x40960xf32>> -> tensor<12x128x40960xf32>
%3 = tensor.empty() : tensor<12x128x40960xf32>
%4 = iree_linalg_ext.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
%4 = linalg.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : tensor<12x128x40960xf32> -> !flow.dispatch.tensor<writeonly:tensor<12x128x40960xf32>>
return
}
@@ -23,7 +23,7 @@ hal.executable.variant @rocm, target = <"rocm", "rocm-hsaco-fb", {target_arch =
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<12x128x40960xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<12x128x40960xf32>> -> tensor<12x128x40960xf32>
%3 = tensor.empty() : tensor<12x128x40960xf32>
%4 = iree_linalg_ext.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
%4 = linalg.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : tensor<12x128x40960xf32> -> !flow.dispatch.tensor<writeonly:tensor<12x128x40960xf32>>
return
}
@@ -217,7 +217,7 @@ hal.executable.variant public @vulkan_spirv_fb, target = #executable_target_vulk
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<12x128x40960xf32>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<12x128x40960xf32>> -> tensor<12x128x40960xf32>
%3 = tensor.empty() : tensor<12x128x40960xf32>
%4 = iree_linalg_ext.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
%4 = linalg.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : tensor<12x128x40960xf32> -> !flow.dispatch.tensor<writeonly:tensor<12x128x40960xf32>>
return
}
@@ -88,6 +88,12 @@ static int64_t estimateLinalgExtOpCost(Operation *op) {
return cost;
}

// Estimates the evaluation cost of a Linalg::Softmax op using a heuristic cost
// model similar to LinalgExt ops.
static int64_t estimateLinalgSoftmaxOpCost(Operation *op) {
return estimateLinalgExtOpCost(op);
}

// Returns a string like "512xDx128" representing loop ranges.
static std::string loopRangesToString(ArrayRef<int64_t> loopRanges) {
std::string outputString;
@@ -167,7 +173,9 @@ static std::string summarizeLinalgOp(linalg::LinalgOp op) {

static std::string summarizeLinalgExtOp(Operation *op) {
auto opName = op->getName().getStringRef();
if (!opName.consume_front("iree_linalg_ext."))
// Currently, this utility is also invoked by Linalg::SoftmaxOp.
if (!(opName.consume_front("iree_linalg_ext.") ||
opName.consume_front("linalg.")))
return "";
std::string suffix = "";
if (TensorType mainTensor = getMainTensorForLinalgExtOp(op)) {
@@ -203,6 +211,15 @@ summarizeDispatchWorkgroupsOp(DispatchWorkgroupsOp regionOp) {
int64_t bestEstimatedCost = kMinEstimatedCost;
regionOp.getWorkgroupBody().walk([&](Operation *op) {
TypeSwitch<Operation *>(op)
.Case<linalg::SoftmaxOp>([&](auto op) {
int64_t estimatedCost = estimateLinalgSoftmaxOpCost(op);
if (estimatedCost < bestEstimatedCost)
return;
bestEstimatedCost = estimatedCost;
bestOp = op;
LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName()
<< "', cost: " << bestEstimatedCost << "\n");
})
.Case<linalg::LinalgOp>([&](auto op) {
int64_t estimatedCost = estimateLinalgOpCost(op);
if (estimatedCost < bestEstimatedCost)
@@ -259,6 +276,8 @@ summarizeDispatchWorkgroupsOp(DispatchWorkgroupsOp regionOp) {

std::string bestSummary = "";
TypeSwitch<Operation *>(bestOp)
.Case<linalg::SoftmaxOp>(
[&](auto op) { bestSummary = summarizeLinalgExtOp(op); })
.Case<linalg::LinalgOp>(
[&](auto op) { bestSummary = summarizeLinalgOp(op); })
.Case<IREE::LinalgExt::SetEncodingOp>([&](auto op) {
@@ -674,8 +674,9 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase<RaiseSpecialOpsPass> {
linalg::LinalgOp op = softmax.first;
Value src = softmax.second;
rewriter.setInsertionPoint(softmax.first);
rewriter.replaceOpWithNewOp<IREE::LinalgExt::SoftmaxOp>(
op, src, op.getDpsInitOperand(0)->get(), op.getNumLoops() - 1);
rewriter.replaceOpWithNewOp<linalg::SoftmaxOp>(
op, op->getResultTypes(), src, op.getDpsInitOperand(0)->get(),
op.getNumLoops() - 1);
}

for (std::pair<linalg::MatmulOp, Value> aTransposeBMatmul :
@@ -339,7 +339,7 @@ func.func @main(%arg0: tensor<7xf32>) -> tensor<7xf32> {
(%arg1: !flow.dispatch.tensor<readonly:tensor<7xf32>>, %arg2: !flow.dispatch.tensor<writeonly:tensor<7xf32>>) {
%1 = flow.dispatch.tensor.load %arg1, offsets = [0], sizes = [7], strides = [1] : !flow.dispatch.tensor<readonly:tensor<7xf32>> -> tensor<7xf32>
%2 = tensor.empty() : tensor<7xf32>
%3 = iree_linalg_ext.softmax dimension(0) ins(%1 : tensor<7xf32>) outs(%2 : tensor<7xf32>) -> tensor<7xf32>
%3 = linalg.softmax dimension(0) ins(%1 : tensor<7xf32>) outs(%2 : tensor<7xf32>) -> tensor<7xf32>
flow.dispatch.tensor.store %3, %arg2, offsets = [0], sizes = [7], strides = [1] : tensor<7xf32> -> !flow.dispatch.tensor<writeonly:tensor<7xf32>>
flow.return
} count(%arg1: index) -> (index, index, index) {
@@ -3,7 +3,7 @@
// CHECK-LABEL: @softmax
// CHECK-SAME: %[[ARG:.+]]: tensor<?x?x?xf32>
// CHECK: %[[E:.+]] = tensor.empty(%{{.*}}, %{{.*}}, %{{.*}}) : tensor<?x?x?xf32>
// CHECK: %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor<?x?x?xf32>) outs(%[[E]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[S:.+]] = linalg.softmax dimension(2) ins(%[[ARG]] : tensor<?x?x?xf32>) outs(%[[E]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: return %[[S]] : tensor<?x?x?xf32>

func.func @softmax(%src : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
@@ -56,7 +56,7 @@ func.func @softmax(%src : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
// CHECK-LABEL: @softmax_no_rcp
// CHECK-SAME: %[[ARG:.+]]: tensor<10x4096x4096xf16>
// CHECK: %[[E:.+]] = tensor.empty() : tensor<10x4096x4096xf16>
// CHECK: %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor<10x4096x4096xf16>) outs(%[[E]] : tensor<10x4096x4096xf16>) -> tensor<10x4096x4096xf16>
// CHECK: %[[S:.+]] = linalg.softmax dimension(2) ins(%[[ARG]] : tensor<10x4096x4096xf16>) outs(%[[E]] : tensor<10x4096x4096xf16>) -> tensor<10x4096x4096xf16>
// CHECK: return %[[S]] : tensor<10x4096x4096xf16>
func.func @softmax_no_rcp(%src : tensor<10x4096x4096xf16>) -> (tensor<10x4096x4096xf16>) {
%cst_158 = arith.constant -6.550400e+04 : f16
@@ -113,7 +113,7 @@ func.func @softmax_no_rcp(%src : tensor<10x4096x4096xf16>) -> (tensor<10x4096x4096xf16>) {
// CHECK-LABEL: @softmax_broadcast
// CHECK-SAME: %[[ARG:.+]]: tensor<12x128x128xf32>
// CHECK: %[[E:.+]] = tensor.empty() : tensor<12x128x128xf32>
// CHECK: %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor<12x128x128xf32>) outs(%[[E]] : tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
// CHECK: %[[S:.+]] = linalg.softmax dimension(2) ins(%[[ARG]] : tensor<12x128x128xf32>) outs(%[[E]] : tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
// CHECK: return %[[S]] : tensor<12x128x128xf32>
func.func @softmax_broadcast(%93 : tensor<12x128x128xf32>) -> (tensor<12x128x128xf32>) {
%cst_16 = arith.constant 0xFF800000 : f32
@@ -168,8 +168,6 @@ void registerUtilExternalModels(DialectRegistry &registry) {
LinalgOpTiedOpInterface<LinalgExt::WinogradInputTransformOp>>(*ctx);
LinalgExt::WinogradOutputTransformOp::attachInterface<
LinalgOpTiedOpInterface<LinalgExt::WinogradOutputTransformOp>>(*ctx);
LinalgExt::SoftmaxOp::attachInterface<
LinalgOpTiedOpInterface<LinalgExt::SoftmaxOp>>(*ctx);
LinalgExt::AttentionOp::attachInterface<
LinalgOpTiedOpInterface<LinalgExt::AttentionOp>>(*ctx);
});
@@ -538,75 +538,6 @@ def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[
}];
}

//===----------------------------------------------------------------------===//
// Softmax
//===----------------------------------------------------------------------===//

def IREELinalgExt_SoftmaxOp : IREELinalgExt_Op<"softmax",
[PredOpTrait<"only one input and one output", CheckNumOperands<2>>,
PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Softmax operator";
let description = [{
This op computes a numerically stable version of softmax for a given tensor.
For a given input tensor x and specified dimension d,
we first compute the max along that dimension (m). We then compute
f(x) = exp(x - m). Then, we sum f(x) along dimension d to get l(x). Finally,
we compute the softmax as f(x) / l(x).
}];

let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs,
I64Attr:$dimension
);

let builders = [
OpBuilder<(ins "Value":$inputs, "Value":$outputs,
CArg<"int64_t", "0">:$dimension)>
];

let results = (outs Variadic<AnyRankedTensor>:$result);
let hasFolder = 1;
let assemblyFormat = [{
attr-dict
`dimension` `(` $dimension `)`
`ins` `(` $inputs `:` type($inputs) `)`
`outs` `(` $outputs `:` type($outputs) `)`
(`->` type($result)^)?
}];

let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
Value input() {
return getDpsInputOperand(0)->get();
}
Value output() {
return getDpsInitOperand(0)->get();
}
ShapedType getInputOperandType() {
return input().getType().cast<ShapedType>();
}
ShapedType getOutputOperandType() {
return output().getType().cast<ShapedType>();
}
int64_t getInputOperandRank() {
return getInputOperandType().getRank();
}
int64_t getOutputOperandRank() {
return getOutputOperandType().getRank();
}
// Method to implement for specifying output range for
// DestinationStyleOpInterface
MutableOperandRange getDpsInitsMutable() {
return getOutputsMutable();
}
}];
}

//===----------------------------------------------------------------------===//
// Attention
//===----------------------------------------------------------------------===//
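For reference, the numerically stable softmax documented in the removed op description above (and computed the same way by the upstream linalg.softmax op) can be written out as:

\[
  \operatorname{softmax}(x)_i
    = \frac{f(x)_i}{l(x)}
    = \frac{e^{x_i - m}}{\sum_{j} e^{x_j - m}},
  \qquad m = \max_{j} x_j ,
\]

where the max m and the sum l(x) are taken along the specified dimension d, and the exponentiation and division are elementwise.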
@@ -2413,101 +2413,6 @@ LogicalResult WinogradOutputTransformOp::reifyResultShapes(
.reifyResultShapes(b, reifiedReturnShapes);
}

//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//

LogicalResult SoftmaxOp::verify() {
Operation *op = getOperation();
auto inputType = input().getType().cast<ShapedType>();
auto outputType = output().getType().cast<ShapedType>();
ArrayRef<int64_t> inputShape = inputType.getShape();
ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(inputShape, outputShape))) {
return op->emitOpError("incompatible output shape");
}
int64_t inputRank = getInputOperandRank();
int64_t dimension = getDimension();
if ((dimension < 0) || (dimension >= inputRank)) {
return op->emitOpError("incorrect dimension specified");
}
return success();
}

SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
int64_t operandRank = getInputOperandRank();
SmallVector<Range> loopBounds(operandRank);
Location loc = getLoc();
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
Value source = input();
for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
loopBounds[dim].offset = zero;
loopBounds[dim].size = getDimValue(builder, loc, source, dim);
loopBounds[dim].stride = one;
}
return loopBounds;
}

SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
utils::IteratorType::parallel);
iteratorTypes[getDimension()] = utils::IteratorType::reduction;
return iteratorTypes;
}

FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
int64_t rank = getInputOperandRank();
auto oneAttr = builder.getI64IntegerAttr(1);
SmallVector<OpFoldResult> strides(rank, oneAttr);
SmallVector<Value> tiledOperands;
tiledOperands.emplace_back(
getSlice(builder, getLoc(), input(), offsets, sizes, strides));
tiledOperands.emplace_back(
getSlice(builder, getLoc(), getOutputs()[0], offsets, sizes, strides));

SmallVector<Type, 4> resultTypes;
if (hasTensorSemantics()) {
resultTypes.push_back(tiledOperands[1].getType());
}
Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}

LogicalResult SoftmaxOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
if (resultNumber == 0) {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
return success();
}
return failure();
}

LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}

LogicalResult
SoftmaxOp::reifyResultShapes(OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
return cast<LinalgExtOp>(getOperation())
.reifyResultShapes(b, reifiedReturnShapes);
}

void SoftmaxOp::build(OpBuilder &builder, OperationState &state, Value source,
Value output, int64_t dimension) {
build(builder, state, TypeRange({output.getType()}), ValueRange(source),
ValueRange(output), dimension);
}

//===----------------------------------------------------------------------===//
// AttentionOp
//===----------------------------------------------------------------------===//
@@ -2650,7 +2555,6 @@ DEFINE_OP_GET_EFFECTS(PackOp)
DEFINE_OP_GET_EFFECTS(UnPackOp)
DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
DEFINE_OP_GET_EFFECTS(SoftmaxOp)
DEFINE_OP_GET_EFFECTS(AttentionOp)

//===----------------------------------------------------------------------===//
