diff --git a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
index 930c9f50d133..4f3e3c9fbc5b 100644
--- a/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/BlockDynamicDimensions.cpp
@@ -102,7 +102,11 @@ getTensorDivisibilityInfo(const TensorDynamicDimAnalysis &dynamicDimAnalysis,
 /// inverses of each other. The `util.optimization.barrier` avoid these from
 /// getting folded away during reshape propagation. Return the result of the
 /// `tensor.collapse_shape generated.
-static std::optional<Value>
+struct ReshapeOps {
+  tensor::ExpandShapeOp expandShapeOp;
+  tensor::CollapseShapeOp collapseShapeOp;
+};
+static std::optional<ReshapeOps>
 blockDynamicDimensionsOfValue(RewriterBase &rewriter,
                               const TensorDivisibilityInfo &divisibilityInfo,
                               Value v) {
@@ -154,18 +158,23 @@ blockDynamicDimensionsOfValue(RewriterBase &rewriter,
   auto outputType = RankedTensorType::get(
       staticOutputShape, tensorType.getElementType(), tensorType.getEncoding());
-  Value expandShape = rewriter.create<tensor::ExpandShapeOp>(
+  auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
       loc, outputType, v, reassociation, outputShape);
-  Value barrier =
-      rewriter.create<IREE::Util::OptimizationBarrierOp>(loc, expandShape)
-          .getResult(0);
-  Value collapseShape = rewriter.create<tensor::CollapseShapeOp>(
+  Value barrier = rewriter
+                      .create<IREE::Util::OptimizationBarrierOp>(
+                          loc, expandShapeOp.getResult())
+                      .getResult(0);
+  auto collapseShapeOp = rewriter.create<tensor::CollapseShapeOp>(
       loc, tensorType, barrier, reassociation);
-  return collapseShape;
+  return ReshapeOps{expandShapeOp, collapseShapeOp};
 }
 
+//===---------------------------------------------------------------------===//
+// Methods for blocking operands of operations
+//===---------------------------------------------------------------------===//
+
 /// For an operation, replace the operands at indices specified in
-/// `limitToOperandIndices` with the result of
+/// `limitToOperandNumbers` with the result of
 /// `tensor.expand_shape`/`tensor.collapse_shape` pair to materialize the
 /// information about dynamic dimensions that are known to be a multiple of a
 /// compile-time static value. For example,
@@ -186,11 +195,10 @@ blockDynamicDimensionsOfValue(RewriterBase &rewriter,
 /// ```
 static LogicalResult blockDynamicDimensions(
     RewriterBase &rewriter, const TensorDynamicDimAnalysis &dynamicDimAnalysis,
-    Operation *operation, llvm::SmallDenseSet<unsigned> limitToOperandIndices) {
-  OpBuilder::InsertionGuard g(rewriter);
-
+    Operation *operation, llvm::SmallDenseSet<unsigned> limitToOperandNumbers,
+    llvm::SmallDenseSet<unsigned> limitToResultNumbers) {
   for (OpOperand &operand : operation->getOpOperands()) {
-    if (!limitToOperandIndices.contains(operand.getOperandNumber()))
+    if (!limitToOperandNumbers.contains(operand.getOperandNumber()))
       continue;
     if (operand.get().getDefiningOp())
       continue;
@@ -198,56 +206,93 @@ static LogicalResult blockDynamicDimensions(
         getTensorDivisibilityInfo(dynamicDimAnalysis, operand.get());
     if (operandDivisibilityInfo.empty())
       continue;
-    std::optional<Value> newOperand = blockDynamicDimensionsOfValue(
+    std::optional<ReshapeOps> reshapes = blockDynamicDimensionsOfValue(
         rewriter, operandDivisibilityInfo, operand.get());
-    if (newOperand) {
-      rewriter.modifyOpInPlace(operation,
-                               [&]() { operand.set(newOperand.value()); });
+    if (reshapes) {
+      rewriter.modifyOpInPlace(
+          operation, [&]() { operand.set(reshapes->collapseShapeOp); });
+    }
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointAfter(operation);
+  for (OpResult result : operation->getResults()) {
+    if (!limitToResultNumbers.contains(result.getResultNumber()))
+      continue;
+    TensorDivisibilityInfo resultDivisibilityInfo =
+        getTensorDivisibilityInfo(dynamicDimAnalysis, result);
+    if (resultDivisibilityInfo.empty())
+      continue;
+    std::optional<ReshapeOps> reshapes =
+        blockDynamicDimensionsOfValue(rewriter, resultDivisibilityInfo, result);
+    if (reshapes) {
+      llvm::SmallPtrSet<Operation *, 1> ignoreUses;
+      ignoreUses.insert(reshapes->expandShapeOp);
+      rewriter.replaceAllUsesExcept(
+          result, reshapes->collapseShapeOp.getResult(), ignoreUses);
     }
   }
   return success();
 }
 
-/// Insert `tensor.expand_shape` operations to materialize in IR information
-/// about dynamic dimensions that are known to be a multiple of a compile-time
-/// know value, for the operands of `iree_linalg_ext.attention` operation.
+/// Generic method for blocking all tensor operands and results of an operation.
+static LogicalResult blockDynamicDimensionsOfAllTensorOperandsAndResults(
+    RewriterBase &rewriter, const TensorDynamicDimAnalysis &dynamicDimAnalysis,
+    Operation *op) {
+  llvm::SmallDenseSet<unsigned> tensorOperandsList, tensorResultsList;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    if (isa<RankedTensorType>(opOperand.get().getType())) {
+      tensorOperandsList.insert(opOperand.getOperandNumber());
+    }
+  }
+  for (OpResult result : op->getResults()) {
+    if (isa<RankedTensorType>(result.getType())) {
+      tensorResultsList.insert(result.getResultNumber());
+    }
+  }
+  return blockDynamicDimensions(rewriter, dynamicDimAnalysis, op,
+                                tensorOperandsList, tensorResultsList);
+}
+
+/// Block dynamic dimensions in operands and results of `LinalgOp`.
+static LogicalResult
+blockDynamicDimensions(RewriterBase &rewriter,
+                       const TensorDynamicDimAnalysis &dynamicDimAnalysis,
+                       linalg::LinalgOp linalgOp) {
+  if (linalg::isaContractionOpInterface(linalgOp)) {
+    return blockDynamicDimensionsOfAllTensorOperandsAndResults(
+        rewriter, dynamicDimAnalysis, linalgOp);
+  }
+  return success();
+}
+
+/// Block dynamic dimensions in operands of `AttentionOp`.
 static LogicalResult
 blockDynamicDimensions(RewriterBase &rewriter,
                        const TensorDynamicDimAnalysis &dynamicDimAnalysis,
                        IREE::LinalgExt::AttentionOp attentionOp) {
   // Only block the q and k values.
-  llvm::SmallDenseSet<unsigned> prunedOperandsList;
+  llvm::SmallDenseSet<unsigned> prunedOperandsList, prunedResultsList;
   prunedOperandsList.insert(attentionOp.getQueryMutable().getOperandNumber());
   prunedOperandsList.insert(attentionOp.getKeyMutable().getOperandNumber());
   return blockDynamicDimensions(rewriter, dynamicDimAnalysis, attentionOp,
-                                prunedOperandsList);
+                                prunedOperandsList, prunedResultsList);
 }
 
-/// Generic method to block dynamic dimensions for all tensor operands.
-/// Only used for testing for now
+/// Dispatch to methods that block dynamic dimensions of operations.
 static LogicalResult
 blockDynamicDimensions(RewriterBase &rewriter,
                        const TensorDynamicDimAnalysis &dynamicDimAnalysis,
-                       Operation *operation, bool test) {
+                       Operation *operation) {
   return TypeSwitch<Operation *, LogicalResult>(operation)
       .Case<IREE::LinalgExt::AttentionOp>([&](auto attentionOp) {
         return blockDynamicDimensions(rewriter, dynamicDimAnalysis,
                                       attentionOp);
       })
-      .Default([&](Operation *op) {
-        if (!test) {
-          return success();
-        }
-        // The default path here is for now only for testing.
-        llvm::SmallDenseSet<unsigned> tensorOperandsList;
-        for (OpOperand &opOperand : operation->getOpOperands()) {
-          if (isa<RankedTensorType>(opOperand.get().getType())) {
-            tensorOperandsList.insert(opOperand.getOperandNumber());
-          }
-        }
-        return blockDynamicDimensions(rewriter, dynamicDimAnalysis, operation,
-                                      tensorOperandsList);
-      });
+      .Case<linalg::LinalgOp>([&](auto linalgOp) {
+        return blockDynamicDimensions(rewriter, dynamicDimAnalysis, linalgOp);
+      })
+      .Default([&](Operation *op) { return success(); });
 }
 
 void BlockDynamicDimensionsPass::runOnOperation() {
@@ -261,7 +306,7 @@ void BlockDynamicDimensionsPass::runOnOperation() {
   IRRewriter rewriter(context);
   auto walkResult = operation->walk([&](Operation *op) -> WalkResult {
     rewriter.setInsertionPoint(op);
-    return blockDynamicDimensions(rewriter, dynamicDimAnalysis, op, test);
+    return blockDynamicDimensions(rewriter, dynamicDimAnalysis, op);
   });
   if (walkResult.wasInterrupted()) {
     return signalPassFailure();
   }
@@ -278,7 +323,11 @@ void BlockDynamicDimensionsPass::runOnOperation() {
   // Add patterns to "push down" the `tensor.collapse_shape` patterns (which
   // are the dual of the patterns to "bubble up" `tensor.expand_shape`
   // patterns)
-  linalg::ControlFusionFn controlFn = [](OpOperand *) { return true; };
+  linalg::ControlFusionFn controlFn = [](OpOperand *opOperand) {
+    // Avoid fusion with fills/empty using the propagation patterns.
+    return !isa_and_nonnull<linalg::FillOp, tensor::EmptyOp>(
+        opOperand->get().getDefiningOp());
+  };
   linalg::populateFoldReshapeOpsByExpansionPatterns(bubbleExpandShapePatterns,
                                                     controlFn);
   IREE::LinalgExt::populateFoldReshapeOpsByExpansionPatterns(
@@ -288,6 +337,8 @@ void BlockDynamicDimensionsPass::runOnOperation() {
   // bindings or `tensor.empty` operations.
   populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
   tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns);
+  linalg::FillOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
+                                              context);
   // Add some additional patterns that can simplify the IR and remove dead
   // operations.
   memref::populateResolveRankedShapedTypeResultDimsPatterns(
@@ -315,6 +366,11 @@
       context);
   tensor::CollapseShapeOp::getCanonicalizationPatterns(
       removeBarrierOpsPatterns, context);
+  tensor::populateFoldTensorEmptyPatterns(removeBarrierOpsPatterns);
+  linalg::FillOp::getCanonicalizationPatterns(removeBarrierOpsPatterns,
+                                              context);
+  memref::populateResolveRankedShapedTypeResultDimsPatterns(
+      removeBarrierOpsPatterns);
   if (failed(applyPatternsAndFoldGreedily(
           operation, std::move(removeBarrierOpsPatterns)))) {
     operation->emitOpError("failed in cleanup patterns");
diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td
index a649f3703ae8..493afa843f81 100644
--- a/compiler/src/iree/compiler/Codegen/Common/Passes.td
+++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -23,9 +23,6 @@
 def BlockDynamicDimensionsPass : Pass<"iree-codegen-block-dynamic-dimensions"> {
   let summary = "Expand dynamic dimensions that are known to be multiples of "
                 "statically known values.";
-  let options = [
-    Option<"test", "test", "bool", /*default=*/"false", "Enable test mode">
-  ];
 }
 
 def BubbleUpOrdinalOpsPass : Pass<"iree-codegen-bubble-up-ordinal-ops", ""> {
diff --git a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
index 4dab261953cd..7f7ba90493de 100644
--- a/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/test/block_dynamic_dims.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions{test}, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-block-dynamic-dimensions, cse))" --split-input-file --mlir-print-local-scope %s | FileCheck %s
 
 #pipeline_layout = #hal.pipeline.layout,
@@ -102,33 +102,129 @@ func.func @block_attention_dims() {
 
 // -----
 
-func.func @basic_blocking_test(%arg0 : index) -> tensor {
+func.func @basic_blocking_test(%arg0 : index) -> tensor {
   %0 = util.assume.int %arg0 : index
-  %1 = tensor.empty(%0) : tensor
-  return %1 : tensor
+  %lhs = tensor.empty(%0) : tensor
+  %rhs = tensor.empty() : tensor<2048x4096xf32>
+  %init = tensor.empty(%0) : tensor
+  %matmul = linalg.matmul ins(%lhs, %rhs : tensor, tensor<2048x4096xf32>)
+      outs(%init : tensor) -> tensor
+  return %matmul : tensor
 }
 // CHECK-LABEL: func @basic_blocking_test(
-// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor
-// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EMPTY]]
+// CHECK-DAG: %[[LHS:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK: %[[MATMUL:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]],
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[MATMUL]]
 // CHECK: return %[[COLLAPSE]]
 
 // -----
 
-func.func @no_blocking(%arg0 : index) -> tensor {
-  %1 = tensor.empty(%arg0) : tensor
-  return %1 : tensor
+func.func @no_blocking(%arg0 : index) -> tensor {
+  %lhs = tensor.empty(%arg0) : tensor
+  %rhs = tensor.empty() : tensor<2048x4096xf32>
+  %init = tensor.empty(%arg0) : tensor
+  %matmul = linalg.matmul ins(%lhs, %rhs : tensor, tensor<2048x4096xf32>)
+      outs(%init : tensor) -> tensor
+  return %matmul : tensor
 }
 // CHECK-LABEL: func @no_blocking(
-// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor
-// CHECK: return %[[EMPTY]]
+// CHECK-DAG: %[[LHS:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS]],
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK: return %[[MATMUL]]
 
 // -----
 
 func.func @no_unit_blocking(%arg0 : index) -> tensor {
   %0 = util.assume.int %arg0 : index
-  %1 = tensor.empty(%0) : tensor
-  return %1 : tensor
+  %lhs = tensor.empty(%0) : tensor
+  %rhs = tensor.empty() : tensor<2048x4096xf32>
+  %init = tensor.empty(%0) : tensor
+  %matmul = linalg.matmul ins(%lhs, %rhs : tensor, tensor<2048x4096xf32>)
+      outs(%init : tensor) -> tensor
+  return %matmul : tensor
 }
 // CHECK-LABEL: func @no_unit_blocking(
-// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor
-// CHECK: return %[[EMPTY]]
+// CHECK-DAG: %[[LHS:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK: %[[MATMUL:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS]],
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK: return %[[MATMUL]]
+
+
+// -----
+
+func.func @contract_op_interface_op(%rhs : tensor<2048x4096xf16>, %m : index)
+    -> tensor {
+  %0 = util.assume.int %m : index
+  %lhs = tensor.empty(%0) : tensor
+  %init = tensor.empty(%0) : tensor
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                       affine_map<(d0, d1, d2) -> (d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%lhs, %rhs : tensor, tensor<2048x4096xf16>)
+      outs(%init : tensor) {
+    ^bb0(%in: f16, %in_0: f16, %out: f32):
+      %17 = arith.extf %in : f16 to f32
+      %18 = arith.extf %in_0 : f16 to f32
+      %19 = arith.mulf %17, %18 : f32
+      %20 = arith.addf %out, %19 : f32
+      linalg.yield %20 : f32
+  } -> tensor
+  return %1 : tensor
+}
+// CHECK-LABEL: func @contract_op_interface_op(
+// CHECK-DAG: %[[LHS:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]],
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[GENERIC]] {{\[}}[0, 1], [2]{{\]}}
+// CHECK: return %[[COLLAPSED]]
+
+// -----
+
+func.func @reshape_propagation_test(%rhs : tensor<2048x4096xf16>, %m : index)
+    -> tensor {
+  %cst = arith.constant 0.0 : f32
+  %0 = util.assume.int %m : index
+  %lhs = tensor.empty(%0) : tensor
+  %init = tensor.empty(%0) : tensor
+  %init2 = tensor.empty(%0) : tensor
+  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor
+  %1 = linalg.matmul_transpose_b
+      ins(%lhs, %rhs : tensor, tensor<2048x4096xf16>)
+      outs(%fill : tensor) -> tensor
+  %2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%1 : tensor) outs(%init2 : tensor) {
+    ^bb0(%b0 : f32, %b1 : f16):
+      %3 = arith.truncf %b0 : f32 to f16
+      linalg.yield %3 : f16
+  } -> tensor
+  return %2 : tensor
+}
+// CHECK-LABEL: func @reshape_propagation_test(
+// CHECK-DAG: %[[LHS:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]],
+// CHECK-SAME: outs(%[[FILL]] :
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%{{.+}}) : tensor
+// CHECK: %[[TRUNC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[GENERIC]] :
+// CHECK-SAME: outs(%[[EMPTY]] :
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[TRUNC]]
+// CHECK: return %[[COLLAPSED]]
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
index 0c80b4830c53..520fbe9a378d 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
@@ -48,12 +48,55 @@ util.func @missing_udiv_skipped(%arg0 : index) -> index {
 
 // -----
 
-util.func @muli_divisibility(%arg0 : index) -> index {
-  %cst = arith.constant 16 : index
-  %0 = arith.muli %arg0, %cst : index
-  %1 = arith.remui %0, %cst : index
-  util.return %1 : index
+util.func @muli_divisibility(%arg0 : index) -> (index, index) {
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %0 = arith.muli %arg0, %c16 : index
+  %1 = arith.remui %0, %c16 : index
+  %2 = arith.remui %0, %c32 : index
+  util.return %1, %2 : index, index
 }
 // CHECK-LABEL: @muli_divisibility
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: return %[[C0]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: %[[V:.+]] = arith.muli
+// CHECK: %[[REM:.+]] = arith.remui %[[V]], %[[C32]]
+// CHECK: return %[[C0]], %[[REM]]
+
+// -----
+
+util.func @muli_compounded_divisibility(%arg0 : index) -> (index, index) {
+  %c16 = arith.constant 16 : index
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %0 = util.assume.int %arg0 : index
+  %1 = arith.muli %0, %c16 : index
+  %2 = arith.remui %1, %c64 : index
+  %3 = arith.remui %1, %c128 : index
+  util.return %2, %3 : index, index
+}
+// CHECK-LABEL: @muli_compounded_divisibility
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CHECK: %[[V:.+]] = arith.muli
+// CHECK: %[[REM:.+]] = arith.remui %[[V]], %[[C128]]
+// CHECK: return %[[C0]], %[[REM]]
+
+// -----
+
+util.func @divui_divisibility(%arg0 : index) -> (index, index) {
+  %c4 = arith.constant 4 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %0 = util.assume.int %arg0 : index
+  %1 = arith.divui %0, %c4 : index
+  %2 = arith.remui %1, %c16 : index
+  %3 = arith.remui %1, %c32 : index
+  util.return %2, %3 : index, index
+}
+// CHECK-LABEL: @divui_divisibility
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+// CHECK: %[[V:.+]] = arith.divui
+// CHECK: %[[REM:.+]] = arith.remui %[[V]], %[[C32]]
+// CHECK: return %[[C0]], %[[REM]]
diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
index 1ce8e18119ab..db5fa4007269 100644
--- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
+++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp
@@ -27,6 +27,21 @@ namespace {
 // InferIntDivisibilityOpInterface
 //===----------------------------------------------------------------------===//
 
+static IREE::Util::ConstantIntDivisibility
+getDivisibilityOfOperand(Value v,
+                         IREE::Util::IntegerDivisibility divisibility) {
+  if (!divisibility.isUninitialized()) {
+    return divisibility.getValue();
+  }
+  APInt intVal;
+  if (matchPattern(v, m_ConstantInt(&intVal))) {
+    uint64_t udiv = intVal.getZExtValue();
+    uint64_t sdiv = std::abs(intVal.getSExtValue());
+    return IREE::Util::ConstantIntDivisibility(udiv, sdiv);
+  }
+  return IREE::Util::ConstantIntDivisibility(1, 1);
+}
+
 struct ArithConstantInferIntDivisibilityOpInterface
     : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
           ArithConstantInferIntDivisibilityOpInterface, arith::ConstantOp> {
@@ -54,17 +69,38 @@ struct ArithMulIInferIntDivisibilityOpInterface
       Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
       IREE::Util::SetIntDivisibilityFn setResultDivs) const {
     auto mulOp = cast<arith::MulIOp>(op);
+
+    auto lhsDivisibility = getDivisibilityOfOperand(mulOp.getLhs(), argDivs[0]);
+    auto rhsDivisibility = getDivisibilityOfOperand(mulOp.getRhs(), argDivs[1]);
+
+    uint64_t mulUDiv = lhsDivisibility.udiv() * rhsDivisibility.udiv();
+    uint64_t mulSDiv = lhsDivisibility.sdiv() * rhsDivisibility.sdiv();
+
+    setResultDivs(mulOp.getResult(),
+                  IREE::Util::ConstantIntDivisibility(mulUDiv, mulSDiv));
+  }
+};
+
+struct ArithDivUIInferIntDivisibilityOpInterface
+    : public IREE::Util::InferIntDivisibilityOpInterface::ExternalModel<
+          ArithDivUIInferIntDivisibilityOpInterface, arith::DivUIOp> {
+
+  void inferResultDivisibility(
+      Operation *op, ArrayRef<IREE::Util::IntegerDivisibility> argDivs,
+      IREE::Util::SetIntDivisibilityFn setResultDivs) const {
+    auto divOp = cast<arith::DivUIOp>(op);
+
     APInt intVal;
-    if (!matchPattern(mulOp.getLhs(), m_ConstantInt(&intVal))) {
-      if (!matchPattern(mulOp.getRhs(), m_ConstantInt(&intVal))) {
-        return;
-      }
+    if (!matchPattern(divOp.getRhs(), m_ConstantInt(&intVal))) {
+      return;
     }
-    uint64_t udiv = intVal.getZExtValue();
-    uint64_t sdiv = std::abs(intVal.getSExtValue());
-    setResultDivs(mulOp.getResult(),
-                  IREE::Util::ConstantIntDivisibility(udiv, sdiv));
+    auto lhsDivisibility = getDivisibilityOfOperand(divOp.getLhs(), argDivs[0]);
+
+    uint64_t divUDiv = lhsDivisibility.udiv() / intVal.getZExtValue();
+    uint64_t divSDiv = lhsDivisibility.sdiv() / std::abs(intVal.getSExtValue());
+
+    setResultDivs(divOp, IREE::Util::ConstantIntDivisibility(divUDiv, divSDiv));
   }
 };
 
@@ -353,6 +389,8 @@ void registerUtilExternalModels(DialectRegistry &registry) {
          ArithConstantInferIntDivisibilityOpInterface>(*context);
    arith::MulIOp::attachInterface<ArithMulIInferIntDivisibilityOpInterface>(
        *context);
+    arith::DivUIOp::attachInterface<ArithDivUIInferIntDivisibilityOpInterface>(
+        *context);
   });
 
   registry.addExtension(