diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
index 89c675aae093..1140125124f9 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp
@@ -609,8 +609,6 @@ void registerBufferizationInterfaces(DialectRegistry &registry) {
         LinalgExtOpInterface>(*ctx);
     IREE::LinalgExt::WinogradOutputTransformOp::attachInterface<
         LinalgExtOpInterface>(*ctx);
-    IREE::LinalgExt::SoftmaxOp::attachInterface<
-        LinalgExtOpInterface>(*ctx);
     IREE::LinalgExt::AttentionOp::attachInterface<
         LinalgExtOpInterface>(*ctx);
   });
diff --git a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
index 934fe226e6ae..77724bf32bfd 100644
--- a/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
+++ b/compiler/src/iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.cpp
@@ -260,8 +260,6 @@ void registerPartitionableLoopsInterfaceModels(DialectRegistry &registry) {
     IREE::LinalgExt::WinogradOutputTransformOp::attachInterface<
         AllParallelAsPartitionableLoops<
            IREE::LinalgExt::WinogradOutputTransformOp>>(*ctx);
-    IREE::LinalgExt::SoftmaxOp::attachInterface<
-        AllParallelAsPartitionableLoops>(*ctx);
     IREE::LinalgExt::AttentionOp::attachInterface<
         AllParallelAsPartitionableLoops>(*ctx);
   });
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir
index 56bf6ad04bf1..099bf20e264a 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir
@@ -215,7 +215,7 @@ hal.executable.variant @cuda, target = <"cuda", "cuda-nvptx-fb"> {
       %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor>
       %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<12x128x40960xf32>
       %3 = tensor.empty() : tensor<12x128x40960xf32>
-      %4 = iree_linalg_ext.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
+      %4 = linalg.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
       flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : tensor<12x128x40960xf32> -> !flow.dispatch.tensor>
       return
     }
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir
index 565ccc5a78e2..f2af39934ae2 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir
@@ -23,7 +23,7 @@ hal.executable.variant @rocm, target = <"rocm", "rocm-hsaco-fb", {target_arch =
       %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor>
       %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<12x128x40960xf32>
       %3 = tensor.empty() : tensor<12x128x40960xf32>
-      %4 = iree_linalg_ext.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
+      %4 = linalg.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
       flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : tensor<12x128x40960xf32> -> !flow.dispatch.tensor>
       return
     }
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
index cbbac9c84bf2..ec0c5209da33 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/lowering_reduction.mlir
@@ -217,7 +217,7 @@ hal.executable.variant public @vulkan_spirv_fb, target = #executable_target_vulk
       %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor>
       %2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : !flow.dispatch.tensor> -> tensor<12x128x40960xf32>
       %3 = tensor.empty() : tensor<12x128x40960xf32>
-      %4 = iree_linalg_ext.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
+      %4 = linalg.softmax dimension(2) ins(%2 : tensor<12x128x40960xf32>) outs(%3 : tensor<12x128x40960xf32>) -> tensor<12x128x40960xf32>
       flow.dispatch.tensor.store %4, %1, offsets = [0, 0, 0], sizes = [12, 128, 40960], strides = [1, 1, 1] : tensor<12x128x40960xf32> -> !flow.dispatch.tensor>
       return
     }
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
index 6518456d5c61..c472fd9ccf69 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/OutlineDispatchRegions.cpp
@@ -88,6 +88,12 @@ static int64_t estimateLinalgExtOpCost(Operation *op) {
   return cost;
 }
 
+// Estimates the evaluation cost of a linalg.softmax op using a heuristic cost
+// model similar to the one used for LinalgExt ops.
+static int64_t estimateLinalgSoftmaxOpCost(Operation *op) {
+  return estimateLinalgExtOpCost(op);
+}
+
 // Returns a string like "512xDx128" representing loop ranges.
 static std::string loopRangesToString(ArrayRef loopRanges) {
   std::string outputString;
@@ -167,7 +173,9 @@ static std::string summarizeLinalgOp(linalg::LinalgOp op) {
 
 static std::string summarizeLinalgExtOp(Operation *op) {
   auto opName = op->getName().getStringRef();
-  if (!opName.consume_front("iree_linalg_ext."))
+  // Currently, this utility is also invoked for linalg.softmax ops.
+  if (!(opName.consume_front("iree_linalg_ext.") ||
+        opName.consume_front("linalg.")))
     return "";
   std::string suffix = "";
   if (TensorType mainTensor = getMainTensorForLinalgExtOp(op)) {
@@ -203,6 +211,15 @@ summarizeDispatchWorkgroupsOp(DispatchWorkgroupsOp regionOp) {
   int64_t bestEstimatedCost = kMinEstimatedCost;
   regionOp.getWorkgroupBody().walk([&](Operation *op) {
     TypeSwitch(op)
+        .Case<linalg::SoftmaxOp>([&](auto op) {
+          int64_t estimatedCost = estimateLinalgSoftmaxOpCost(op);
+          if (estimatedCost < bestEstimatedCost)
+            return;
+          bestEstimatedCost = estimatedCost;
+          bestOp = op;
+          LLVM_DEBUG(llvm::dbgs() << "// new best op: '" << bestOp->getName()
+                                  << "', cost: " << bestEstimatedCost << "\n");
+        })
         .Case([&](auto op) {
           int64_t estimatedCost = estimateLinalgOpCost(op);
           if (estimatedCost < bestEstimatedCost)
@@ -259,6 +276,8 @@ summarizeDispatchWorkgroupsOp(DispatchWorkgroupsOp regionOp) {
   std::string bestSummary = "";
   TypeSwitch(bestOp)
+      .Case<linalg::SoftmaxOp>(
+          [&](auto op) { bestSummary = summarizeLinalgExtOp(op); })
       .Case(
           [&](auto op) { bestSummary = summarizeLinalgOp(op); })
       .Case([&](auto op) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
index 06af02d9c501..1036321205b7 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/RaiseSpecialOps.cpp
@@ -674,8 +674,9 @@ struct RaiseSpecialOpsPass : public RaiseSpecialOpsBase {
       linalg::LinalgOp op = softmax.first;
       Value src = softmax.second;
       rewriter.setInsertionPoint(softmax.first);
-      rewriter.replaceOpWithNewOp<IREE::LinalgExt::SoftmaxOp>(
-          op, src, op.getDpsInitOperand(0)->get(), op.getNumLoops() - 1);
+      rewriter.replaceOpWithNewOp<linalg::SoftmaxOp>(
+          op, op->getResultTypes(), src, op.getDpsInitOperand(0)->get(),
+          op.getNumLoops() - 1);
     }
 
     for (std::pair aTransposeBMatmul :
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
index 509a136f6d78..e9bb837b9d88 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/outline_dispatch_regions.mlir
@@ -339,7 +339,7 @@ func.func @main(%arg0: tensor<7xf32>) -> tensor<7xf32> {
       (%arg1: !flow.dispatch.tensor>, %arg2: !flow.dispatch.tensor>) {
     %1 = flow.dispatch.tensor.load %arg1, offsets = [0], sizes = [7], strides = [1] : !flow.dispatch.tensor> -> tensor<7xf32>
     %2 = tensor.empty() : tensor<7xf32>
-    %3 = iree_linalg_ext.softmax dimension(0) ins(%1 : tensor<7xf32>) outs(%2 : tensor<7xf32>) -> tensor<7xf32>
+    %3 = linalg.softmax dimension(0) ins(%1 : tensor<7xf32>) outs(%2 : tensor<7xf32>) -> tensor<7xf32>
     flow.dispatch.tensor.store %3, %arg2, offsets = [0], sizes = [7], strides = [1] : tensor<7xf32> -> !flow.dispatch.tensor>
     flow.return
   } count(%arg1: index) -> (index, index, index) {
diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
index bb712a253889..f4a3b702567a 100644
--- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
+++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/raise_special_ops.mlir
@@ -3,7 +3,7 @@
 // CHECK-LABEL: @softmax
 // CHECK-SAME: %[[ARG:.+]]: tensor
 // CHECK: %[[E:.+]] = tensor.empty(%{{.*}}, %{{.*}}, %{{.*}}) : tensor
-// CHECK: %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor) outs(%[[E]] : tensor) -> tensor
+// CHECK: %[[S:.+]] = linalg.softmax dimension(2) ins(%[[ARG]] : tensor) outs(%[[E]] : tensor) -> tensor
 // CHECK: return %[[S]] : tensor
 
 func.func @softmax(%src : tensor) -> (tensor) {
@@ -56,7 +56,7 @@ func.func @softmax(%src : tensor) -> (tensor) {
 // CHECK-LABEL: @softmax_no_rcp
 // CHECK-SAME: %[[ARG:.+]]: tensor<10x4096x4096xf16>
 // CHECK: %[[E:.+]] = tensor.empty() : tensor<10x4096x4096xf16>
-// CHECK: %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor<10x4096x4096xf16>) outs(%[[E]] : tensor<10x4096x4096xf16>) -> tensor<10x4096x4096xf16>
+// CHECK: %[[S:.+]] = linalg.softmax dimension(2) ins(%[[ARG]] : tensor<10x4096x4096xf16>) outs(%[[E]] : tensor<10x4096x4096xf16>) -> tensor<10x4096x4096xf16>
 // CHECK: return %[[S]] : tensor<10x4096x4096xf16>
 func.func @softmax_no_rcp(%src : tensor<10x4096x4096xf16>) -> (tensor<10x4096x4096xf16>) {
   %cst_158 = arith.constant -6.550400e+04 : f16
@@ -113,7 +113,7 @@ func.func @softmax_no_rcp(%src : tensor<10x4096x4096xf16>) -> (tensor<10x4096x40
 // CHECK-LABEL: @softmax_broadcast
 // CHECK-SAME: %[[ARG:.+]]: tensor<12x128x128xf32>
 // CHECK: %[[E:.+]] = tensor.empty() : tensor<12x128x128xf32>
-// CHECK: %[[S:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG]] : tensor<12x128x128xf32>) outs(%[[E]] : tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
+// CHECK: %[[S:.+]] = linalg.softmax dimension(2) ins(%[[ARG]] : tensor<12x128x128xf32>) outs(%[[E]] : tensor<12x128x128xf32>) -> tensor<12x128x128xf32>
 // CHECK: return %[[S]] : tensor<12x128x128xf32>
 func.func @softmax_broadcast(%93 : tensor<12x128x128xf32>) -> (tensor<12x128x128xf32>) {
   %cst_16 = arith.constant 0xFF800000 : f32
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilExternalModels.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilExternalModels.cpp
index 851d42f95873..7891490c1998 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilExternalModels.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilExternalModels.cpp
@@ -168,8 +168,6 @@ void registerUtilExternalModels(DialectRegistry &registry) {
         LinalgOpTiedOpInterface>(*ctx);
     LinalgExt::WinogradOutputTransformOp::attachInterface<
         LinalgOpTiedOpInterface>(*ctx);
-    LinalgExt::SoftmaxOp::attachInterface<
-        LinalgOpTiedOpInterface>(*ctx);
     LinalgExt::AttentionOp::attachInterface<
         LinalgOpTiedOpInterface>(*ctx);
   });
diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
index e67f5e3c3f60..35c8ef9351fc 100644
--- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -538,75 +538,6 @@ def IREELinalgExt_TopkOp : IREELinalgExt_Op<"topk",[
   }];
 }
 
-//===----------------------------------------------------------------------===//
-// Softmax
-//===----------------------------------------------------------------------===//
-
-def IREELinalgExt_SoftmaxOp : IREELinalgExt_Op<"softmax",
-    [PredOpTrait<"only one input and one output", CheckNumOperands<2>>,
-     PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
-     DeclareOpInterfaceMethods,
-     DeclareOpInterfaceMethods]> {
-  let summary = "Softmax operator";
-  let description = [{
-    This op computes a numerically stable version of softmax for a given tensor.
-    For a given input tensor x and specified dimension d,
-    we first compute the max along that dimension (m). We then compute
-    f(x) = exp(x - m). Then, we sum f(x) along dimension d to get l(x). Finally,
-    we compute the softmax as f(x) / l(x).
-  }];
-
-  let arguments = (ins Variadic:$inputs,
-                       Variadic:$outputs,
-                       I64Attr:$dimension
-  );
-
-  let builders = [
-    OpBuilder<(ins "Value":$inputs, "Value":$outputs,
-      CArg<"int64_t", "0">:$dimension)>
-  ];
-
-  let results = (outs Variadic:$result);
-  let hasFolder = 1;
-  let assemblyFormat = [{
-    attr-dict
-    `dimension` `(` $dimension `)`
-    `ins` `(` $inputs `:` type($inputs) `)`
-    `outs` `(` $outputs `:` type($outputs) `)`
-    (`->` type($result)^)?
-  }];
-
-  let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
-    Value input() {
-      return getDpsInputOperand(0)->get();
-    }
-    Value output() {
-      return getDpsInitOperand(0)->get();
-    }
-    ShapedType getInputOperandType() {
-      return input().getType().cast();
-    }
-    ShapedType getOutputOperandType() {
-      return output().getType().cast();
-    }
-    int64_t getInputOperandRank() {
-      return getInputOperandType().getRank();
-    }
-    int64_t getOutputOperandRank() {
-      return getOutputOperandType().getRank();
-    }
-    // Method to implement for specifying output range for
-    // DestinationStyleOpInterface
-    MutableOperandRange getDpsInitsMutable() {
-      return getOutputsMutable();
-    }
-  }];
-}
-
 //===----------------------------------------------------------------------===//
 // Attention
 //===----------------------------------------------------------------------===//
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
index 57d638ec359d..635b325fad51 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/IR/LinalgExtOps.cpp
@@ -2413,101 +2413,6 @@ LogicalResult WinogradOutputTransformOp::reifyResultShapes(
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
-//===----------------------------------------------------------------------===//
-// SoftmaxOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult SoftmaxOp::verify() {
-  Operation *op = getOperation();
-  auto inputType = input().getType().cast();
-  auto outputType = output().getType().cast();
-  ArrayRef inputShape = inputType.getShape();
-  ArrayRef outputShape = outputType.getShape();
-  if (failed(verifyCompatibleShape(inputShape, outputShape))) {
-    return op->emitOpError("incompatible output shape");
-  }
-  int64_t inputRank = getInputOperandRank();
-  int64_t dimension = getDimension();
-  if ((dimension < 0) || (dimension >= inputRank)) {
-    return op->emitOpError("incorrect dimension specified");
-  }
-  return success();
-}
-
-SmallVector SoftmaxOp::getIterationDomain(OpBuilder &builder) {
-  int64_t operandRank = getInputOperandRank();
-  SmallVector loopBounds(operandRank);
-  Location loc = getLoc();
-  Value zero = builder.create(loc, 0);
-  Value one = builder.create(loc, 1);
-  Value source = input();
-  for (auto dim : llvm::seq(0, operandRank)) {
-    loopBounds[dim].offset = zero;
-    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
-    loopBounds[dim].stride = one;
-  }
-  return loopBounds;
-}
-
-SmallVector SoftmaxOp::getLoopIteratorTypes() {
-  SmallVector iteratorTypes(getInputOperandRank(),
-                            utils::IteratorType::parallel);
-  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
-  return iteratorTypes;
-}
-
-FailureOr
-SoftmaxOp::getTiledImplementation(OpBuilder &builder,
-                                  ArrayRef offsets,
-                                  ArrayRef sizes) {
-  int64_t rank = getInputOperandRank();
-  auto oneAttr = builder.getI64IntegerAttr(1);
-  SmallVector strides(rank, oneAttr);
-  SmallVector tiledOperands;
-  tiledOperands.emplace_back(
-      getSlice(builder, getLoc(), input(), offsets, sizes, strides));
-  tiledOperands.emplace_back(
-      getSlice(builder, getLoc(), getOutputs()[0], offsets, sizes, strides));
-
-  SmallVector resultTypes;
-  if (hasTensorSemantics()) {
-    resultTypes.push_back(tiledOperands[1].getType());
-  }
-  Operation *tiledOp =
-      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
-
-  return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())};
-}
-
-LogicalResult SoftmaxOp::getResultTilePosition(
-    OpBuilder &builder, unsigned resultNumber, ArrayRef offsets,
-    ArrayRef sizes, SmallVector &resultOffsets,
-    SmallVector &resultSizes) {
-  if (resultNumber == 0) {
-    resultOffsets.assign(offsets.begin(), offsets.end());
-    resultSizes.assign(sizes.begin(), sizes.end());
-    return success();
-  }
-  return failure();
-}
-
-LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl &) {
-  return memref::foldMemRefCast(*this);
-}
-
-LogicalResult
-SoftmaxOp::reifyResultShapes(OpBuilder &b,
-                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  return cast(getOperation())
-      .reifyResultShapes(b, reifiedReturnShapes);
-}
-
-void SoftmaxOp::build(OpBuilder &builder, OperationState &state, Value source,
-                      Value output, int64_t dimension) {
-  build(builder, state, TypeRange({output.getType()}), ValueRange(source),
-        ValueRange(output), dimension);
-}
-
 //===----------------------------------------------------------------------===//
 // AttentionOp
 //===----------------------------------------------------------------------===//
@@ -2650,7 +2555,6 @@ DEFINE_OP_GET_EFFECTS(PackOp)
 DEFINE_OP_GET_EFFECTS(UnPackOp)
 DEFINE_OP_GET_EFFECTS(WinogradInputTransformOp)
 DEFINE_OP_GET_EFFECTS(WinogradOutputTransformOp)
-DEFINE_OP_GET_EFFECTS(SoftmaxOp)
 DEFINE_OP_GET_EFFECTS(AttentionOp)
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/DecomposeSoftmax.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/DecomposeSoftmax.cpp
index 2ec81f44901b..2f829e3e0439 100644
--- a/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/DecomposeSoftmax.cpp
+++ b/llvm-external-projects/iree-dialects/lib/Dialect/LinalgExt/Passes/DecomposeSoftmax.cpp
@@ -119,44 +119,35 @@ static Value computeSoftmax(Value numerator, Value denominator, Value output,
 LogicalResult convertSoftmaxToGenerics(func::FuncOp funcOp) {
   IRRewriter rewriter(funcOp.getContext());
   SmallVector toDelete;
-  funcOp.walk([&](IREE::LinalgExt::SoftmaxOp softmaxOp) {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPoint(softmaxOp);
-    Location loc = softmaxOp.getLoc();
-    Value input = softmaxOp.input();
-    ShapedType inputType = input.getType().cast();
-    Type elementType = inputType.getElementType();
-    int64_t reductionDim = softmaxOp.getDimension();
-    SmallVector dims =
-        tensor::getMixedSizes(rewriter, loc, input);
-    Value outputNd = rewriter.create(loc, dims, elementType);
-    dims.erase(dims.begin() + reductionDim);
-    // Compute max along dim
-    Value output = rewriter.create(loc, dims, elementType);
-    Value largeNegative = rewriter.create(
-        loc, rewriter.getFloatAttr(elementType, -1.0e30));
-    Value negativeInit =
-        rewriter.create(loc, Value{largeNegative}, output)
-            .result();
-    Value max = reduce(input, negativeInit, reductionDim,
-                       loc, rewriter);
-    // Subtract max from input and exponentiate
-    linalg::GenericOp numeratorOp =
-        subtractAndExp(input, max, outputNd, reductionDim, loc, rewriter);
-    Value numerator = numeratorOp->getResult(0);
-    // Compute sum along dim
-    Value zero = rewriter.create(
-        loc, rewriter.getZeroAttr(elementType));
-    Value zeroInit =
-        rewriter.create(loc, Value{zero}, output).result();
-    Value denominator =
-        reduce(numerator, zeroInit, reductionDim, loc, rewriter);
-    // Compute softmax
-    Value result = computeSoftmax(numerator, denominator, outputNd,
-                                  reductionDim, loc, rewriter);
-    softmaxOp.getResult()[0].replaceAllUsesWith(result);
-    // Delete the op after the walk.
-    toDelete.push_back(softmaxOp.getOperation());
+  SmallVector<Operation *> softmaxOpsToDecompose;
+  funcOp.walk([&](linalg::SoftmaxOp softmaxOp) {
+    softmaxOpsToDecompose.push_back(softmaxOp);
+  });
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  for (Operation *softmaxOp : softmaxOpsToDecompose) {
+    // Cast linalg::softmax to AggregatedOpInterface since this is where
+    // `decomposeOperation` is implemented.
+    auto decomposableSoftmaxOp = cast<linalg::AggregatedOpInterface>(softmaxOp);
+
+    // Decompose linalg::softmax.
+    FailureOr<SmallVector<Value>> result =
+        decomposableSoftmaxOp.decomposeOperation(rewriter);
+    if (failed(result)) {
+      failed(rewriter.notifyMatchFailure(
+          softmaxOp, "linalg::SoftmaxOp could not be decomposed"));
+      return failure();
+    }
+
+    // Replace the result of linalg::softmax with the `result` generated via
+    // the decomposition above.
+    rewriter.replaceOp(decomposableSoftmaxOp, *result);
+
+    // Fusion later depends on a couple of ops/values, which we obtain here by
+    // backtracking through the generated value's def-chain.
+    Operation *resultOp = (*result)[0].getDefiningOp();
+    Value numerator = resultOp->getOperand(0);
+    Operation *numeratorOp = numerator.getDefiningOp();
 
     // Rematerialize operands that are marked for this.
     SmallVector uses = llvm::to_vector(llvm::map_range(
@@ -176,9 +167,7 @@ LogicalResult convertSoftmaxToGenerics(func::FuncOp funcOp) {
       }
     }
     toDelete.push_back(numeratorOp);
-
-    return WalkResult::advance();
-  });
+  }
   for (Operation *op : toDelete) {
     rewriter.eraseOp(op);
   }
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/decompose_softmax.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/decompose_softmax.mlir
index 55ce77641063..ea0717ce473e 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/decompose_softmax.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/decompose_softmax.mlir
@@ -2,7 +2,7 @@
 
 func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
   %0 = tensor.empty() : tensor<2x16x32xf32>
-  %1 = iree_linalg_ext.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
   return %1 : tensor<2x16x32xf32>
 }
 
@@ -11,7 +11,7 @@ func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
 // CHECK: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
-// CHECK: %[[CST:.+]] = arith.constant -1.000000e+30 : f32
+// CHECK: %[[CST:.+]] = arith.constant -1.401300e-45 : f32
 // CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
 // CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[MAP]], #[[MAP1]]], iterator_types = ["parallel",
 // CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
index 51b9944f32d9..54e34ad936eb 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/invalid.mlir
@@ -726,7 +726,7 @@ func.func @illegal_winograd_output_image_dimensions(%arg0: tensor<8x8x1x2x2x32xf
 func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x16xf32> {
   %0 = tensor.empty() : tensor<2x16xf32>
   // expected-error @+1 {{incompatible output shape}}
-  %1 = iree_linalg_ext.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16xf32>) -> tensor<2x16xf32>
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16xf32>) -> tensor<2x16xf32>
   return %1 : tensor<2x16xf32>
 }
 
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
index b899ea09b7b6..c93249d40e4a 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/roundtrip.mlir
@@ -1083,12 +1083,12 @@ func.func @winograd_output_transform_nchw(%arg0: tensor<8x8x1x2x2x1280xf32>) ->
 func.func @softmax(%arg0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
   %0 = tensor.empty() : tensor<2x16x32xf32>
-  %1 = iree_linalg_ext.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
+  %1 = linalg.softmax dimension(2) ins(%arg0 : tensor<2x16x32xf32>) outs(%0: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
   return %1 : tensor<2x16x32xf32>
 }
 
 // CHECK: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK: %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
-// CHECK: %[[D1:.+]] = iree_linalg_ext.softmax dimension(2) ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D0]] :
+// CHECK: %[[D1:.+]] = linalg.softmax dimension(2) ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D0]] :
 // CHECK-SAME: tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
 // CHECK: return %[[D1]] : tensor<2x16x32xf32>
 // CHECK: }
diff --git a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
index b41d22e06d10..1aa726999071 100644
--- a/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
+++ b/llvm-external-projects/iree-dialects/test/Dialect/iree_linalg_ext/tiling.mlir
@@ -1156,7 +1156,7 @@ func.func @winograd_output_transform_nchw(%arg0: tensor<8x8x1x2x2x32xf32>) -> te
 
 func.func @softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
   %0 = tensor.empty() : tensor<16x64x256xf32>
-  %1 = iree_linalg_ext.softmax {__internal_linalg_transform__ = "distribute_input"}
+  %1 = linalg.softmax {__internal_linalg_transform__ = "distribute_input"}
        dimension(1) ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
   return %1 : tensor<16x64x256xf32>
 }
@@ -1165,11 +1165,10 @@ func.func @softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)>
 // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
 // CHECK: func.func @softmax(%[[ARG0:[a-zA-Z0-9_]+]]: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
 // CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index
 // CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
-// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
 // CHECK-DAG: %[[D0:.+]] = tensor.empty() : tensor<16x64x256xf32>
 // CHECK-DAG: %[[D1:.+]] = iree_input.dispatch.workgroup.id[0] : index
 // CHECK-DAG: %[[D2:.+]] = iree_input.dispatch.workgroup.count[0] : index
@@ -1186,14 +1185,14 @@ func.func @softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
 // CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[ARG2]]) -> (tensor<16x64x256xf32>) {
 // CHECK-DAG: %[[D12:.+]] = affine.min #[[MAP3]](%[[ARG3]])[%[[C30]], %[[C256]]]
 // CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG1]], 0, %[[ARG3]]] [%[[D8]],
-// CHECK-SAME: %[[C64]], %[[D12]]] [1, 1, 1] : tensor<16x64x256xf32> to tensor
+// CHECK-SAME: 64, %[[D12]]] [1, 1, 1] : tensor<16x64x256xf32> to tensor
 // CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[D0]][%[[ARG1]], 0, %[[ARG3]]] [%[[D8]],
-// CHECK-SAME: %[[C64]], %[[D12]]] [1, 1, 1] : tensor<16x64x256xf32> to tensor
-// CHECK: %[[D13:.+]] = iree_linalg_ext.softmax {__internal_linalg_transform__ = "distribute_output"}
-// CHECK-SAME: dimension(1) ins(%[[EXTRACTED_SLICE]] : tensor) outs(%[[EXTRACTED_SLICE_0]] :
-// CHECK-SAME: tensor) -> tensor
+// CHECK-SAME: 64, %[[D12]]] [1, 1, 1] : tensor<16x64x256xf32> to tensor
+// CHECK: %[[D13:.+]] = linalg.softmax {__internal_linalg_transform__ = "distribute_output"}
+// CHECK-SAME: dimension(1) ins(%[[EXTRACTED_SLICE]] : tensor) outs(%[[EXTRACTED_SLICE_0]] :
+// CHECK-SAME: tensor) -> tensor
 // CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[D13]] into %[[ARG4]][%[[ARG1]], 0, %[[ARG3]]]
-// CHECK-SAME: [%[[D8]], %[[C64]], %[[D12]]] [1, 1, 1] : tensor into tensor<16x64x256xf32>
+// CHECK-SAME: [%[[D8]], 64, %[[D12]]] [1, 1, 1] : tensor into tensor<16x64x256xf32>
 // CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<16x64x256xf32>
 // CHECK: }
 // CHECK: scf.yield %[[D11]] : tensor<16x64x256xf32>
@@ -1204,7 +1203,7 @@ func.func @softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
 // -----
 
 func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256xf32>) {
-  iree_linalg_ext.softmax {__internal_linalg_transform__ = "distribute_input"}
+  linalg.softmax {__internal_linalg_transform__ = "distribute_input"}
     dimension(1) ins(%arg0 : memref<16x64x256xf32>) outs(%arg1 : memref<16x64x256xf32>)
   return
 }
@@ -1214,11 +1213,10 @@ func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256x
 // CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)>
 // CHECK: func.func @softmax_memref(%[[ARG0:[a-zA-Z0-9_]+]]: memref<16x64x256xf32>, %[[ARG1:[a-zA-Z0-9_]+]]:
 // CHECK-SAME: memref<16x64x256xf32>) {
-// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
-// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
 // CHECK-DAG: %[[C256:.+]] = arith.constant 256 : index
 // CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
-// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
+// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
 // CHECK: %[[D0:.+]] = iree_input.dispatch.workgroup.id[0] : index
 // CHECK: %[[D1:.+]] = iree_input.dispatch.workgroup.count[0] : index
 // CHECK: %[[D2:.+]] = iree_input.dispatch.workgroup.id[1] : index
@@ -1231,13 +1229,13 @@ func.func @softmax_memref(%arg0: memref<16x64x256xf32>, %arg1: memref<16x64x256x
 // CHECK-DAG: %[[D8:.+]] = affine.apply #[[MAP2]]()[%[[D1]]]
 // CHECK: scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[D7]] to %[[C256]] step %[[D8]] {
 // CHECK-DAG: %[[D9:.+]] = affine.min #[[MAP3]](%[[ARG3]])[%[[C30]], %[[C256]]]
-// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][%[[ARG2]], 0, %[[ARG3]]] [%[[D6]], %[[C64]], %[[D9]]]
-// CHECK-SAME: [1, 1, 1] : memref<16x64x256xf32> to memref>
-// CHECK: %[[SUBVIEW_0:.+]] = memref.subview %[[ARG1]][%[[ARG2]], 0, %[[ARG3]]] [%[[D6]], %[[C64]], %[[D9]]]
-// CHECK-SAME: [1, 1, 1] : memref<16x64x256xf32> to memref>
-// CHECK: iree_linalg_ext.softmax {__internal_linalg_transform__ = "distribute_output"} dimension(1)
-// CHECK-SAME: ins(%[[SUBVIEW]] : memref>) outs(%[[SUBVIEW_0]] :
-// CHECK-SAME: memref>)
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG0]][%[[ARG2]], 0, %[[ARG3]]] [%[[D6]], 64, %[[D9]]]
+// CHECK-SAME: [1, 1, 1] : memref<16x64x256xf32> to memref>
+// CHECK: %[[SUBVIEW_0:.+]] = memref.subview %[[ARG1]][%[[ARG2]], 0, %[[ARG3]]] [%[[D6]], 64, %[[D9]]]
+// CHECK-SAME: [1, 1, 1] : memref<16x64x256xf32> to memref>
+// CHECK: linalg.softmax {__internal_linalg_transform__ = "distribute_output"} dimension(1)
+// CHECK-SAME: ins(%[[SUBVIEW]] : memref>) outs(%[[SUBVIEW_0]] :
+// CHECK-SAME: memref>)
 // CHECK: }
 // CHECK: }
 // CHECK: return
diff --git a/tests/e2e/linalg_ext_ops/BUILD.bazel b/tests/e2e/linalg_ext_ops/BUILD.bazel
index 117fe9e3e64d..8f2aae086706 100644
--- a/tests/e2e/linalg_ext_ops/BUILD.bazel
+++ b/tests/e2e/linalg_ext_ops/BUILD.bazel
@@ -21,7 +21,6 @@ iree_check_single_backend_test_suite(
         "reverse.mlir",
         "scan.mlir",
         "scatter.mlir",
-        "softmax.mlir",
         "sort.mlir",
         "top-k.mlir",
     ],
@@ -95,7 +94,6 @@ iree_check_single_backend_test_suite(
         "reverse.mlir",
         "scan.mlir",
        "scatter.mlir",
-        "softmax.mlir",
         "sort.mlir",
         "top-k.mlir",
         "winograd_input.mlir",
@@ -127,7 +125,6 @@ iree_check_single_backend_test_suite(
        ],
        include = ["*.mlir"],
        exclude = [
-            "softmax.mlir",
            "winograd_input.mlir",
            "winograd_output.mlir",
        ],
@@ -155,7 +152,6 @@ iree_check_single_backend_test_suite(
            # Re-enable this once we have new devices with up-to-date drivers.
            "top-k.mlir",
            "scan.mlir",
-            "softmax.mlir",
        ],
    ),
    driver = "vulkan",
diff --git a/tests/e2e/linalg_ext_ops/CMakeLists.txt b/tests/e2e/linalg_ext_ops/CMakeLists.txt
index ca56c68914c1..675ad5e5bcc5 100644
--- a/tests/e2e/linalg_ext_ops/CMakeLists.txt
+++ b/tests/e2e/linalg_ext_ops/CMakeLists.txt
@@ -17,7 +17,6 @@ iree_check_single_backend_test_suite(
     "reverse.mlir"
     "scan.mlir"
     "scatter.mlir"
-    "softmax.mlir"
     "sort.mlir"
     "top-k.mlir"
   TARGET_BACKEND
@@ -80,7 +79,6 @@ iree_check_single_backend_test_suite(
     "reverse.mlir"
     "scan.mlir"
    "scatter.mlir"
-    "softmax.mlir"
    "sort.mlir"
    "top-k.mlir"
    "winograd_input.mlir"
diff --git a/tests/e2e/linalg_ext_ops/softmax.mlir b/tests/e2e/linalg_ext_ops/softmax.mlir
deleted file mode 100644
index d53fff716566..000000000000
--- a/tests/e2e/linalg_ext_ops/softmax.mlir
+++ /dev/null
@@ -1,13 +0,0 @@
-func.func @softmax() {
-  %input = util.unfoldable_constant dense<1.0> : tensor<2x8x4xf32>
-
-  %init = tensor.empty() : tensor<2x8x4xf32>
-  %1 = iree_linalg_ext.softmax dimension(2)
-         ins(%input : tensor<2x8x4xf32>)
-         outs(%init : tensor<2x8x4xf32>) -> tensor<2x8x4xf32>
-  check.expect_almost_eq_const(
-      %1,
-      dense<0.25> : tensor<2x8x4xf32>
-  ) : tensor<2x8x4xf32>
-  return
-}