From b08d152138e82f355c12bc5909b181adde6d4b0e Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Fri, 17 Jan 2025 11:15:02 -0800 Subject: [PATCH] [HAL] Use util.assume.int for memref alignments (#19691) When bufferizing, use util.assume.int to construct memref.assume_alignment, since we can use the divisibility on those assumptions constrain the subspan offset. --- .../Common/test/iree_comprehensive_bufferize.mlir | 13 +++++++++---- .../src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp | 3 +++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir index 953d361e3361..79d383779e1f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/test/iree_comprehensive_bufferize.mlir @@ -76,7 +76,7 @@ func.func @matmul() { // ----- -#pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, #hal.pipeline.binding @@ -84,15 +84,17 @@ func.func @matmul() { func.func @matmul_fill() { %cst = arith.constant 0.0 : f32 %c0 = arith.constant 0 : index - %c1024 = arith.constant 1024 : index %m = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index %n = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index %k = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index %base_offset_i32 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) alignment(8) : i32 %base_offset = arith.index_castui %base_offset_i32 : i32 to index + %res_offset_i32 = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : i32 + %res_offset_index = arith.index_castui %res_offset_i32 : i32 to index + %res_offset = util.assume.int %res_offset_index[, , ] : index %lhs = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(32) : !flow.dispatch.tensor>{%m, %k} %rhs = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%base_offset) : !flow.dispatch.tensor>{%k, %n} - %result = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c1024) : !flow.dispatch.tensor>{%m, %n} + %result = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%res_offset) : !flow.dispatch.tensor>{%m, %n} %wg_id_y = hal.interface.workgroup.id[1] : index %wg_count_y = hal.interface.workgroup.count[1] : index %wg_size_y = hal.interface.workgroup.size[1] : index @@ -127,11 +129,14 @@ func.func @matmul_fill() { // CHECK-DAG: %[[K:.+]] = hal.interface.constant.load layout({{.+}}) ordinal(2) // CHECK-DAG: %[[BASE_OFFSET_I32:.+]] = hal.interface.constant.load layout({{.+}}) ordinal(3) // CHECK-DAG: %[[BASE_OFFSET:.+]] = arith.index_castui %[[BASE_OFFSET_I32]] +// CHECK-DAG: %[[RES_OFFSET_I32:.+]] = hal.interface.constant.load layout({{.+}}) ordinal(4) +// CHECK-DAG: %[[RES_OFFSET_INDEX:.+]] = arith.index_castui %[[RES_OFFSET_I32]] +// CHECK-DAG: %[[RES_OFFSET:.+]] = util.assume.int %[[RES_OFFSET_INDEX]] // CHECK-DAG: %[[LHS:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(32) // CHECK-DAG: memref.assume_alignment %[[LHS]], 32 // CHECK-DAG: %[[RHS:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[BASE_OFFSET]]) // CHECK-DAG: memref.assume_alignment %[[RHS]], 8 -// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%c1024) +// CHECK-DAG: %[[RESULT:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[RES_OFFSET]]) // CHECK-DAG: memref.assume_alignment %[[RESULT]], 64 // CHECK-DAG: %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1] // CHECK-DAG: %[[WG_COUNT_Y:.+]] = hal.interface.workgroup.count[1] diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp index 9e31af790e8b..ea460093e212 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALTypes.cpp @@ -55,6 +55,9 @@ std::optional lookupOffsetOrAlignment(Value value) { } } else if (auto castOp = dyn_cast(op)) { return lookupOffsetOrAlignment(castOp.getOperand()); + } else if (auto assumeOp = dyn_cast(op)) { + return assumeOp.getUnionedUnsignedDivisor( + cast(value).getResultNumber()); } // TODO(benvanik): more searching using util.align and other ops.