Skip to content

Commit

Permalink
[GPUDistributionPatterns] Propagate predicate attribute for cmpf op (i…
Browse files Browse the repository at this point in the history
…ree-org#17664)

The predicate attribute of a cmpf op was not being propagated to the
distributed op when present. This fixes the missing predicate attribute
for cmpf.

---------

Signed-off-by: aviator19941 <[email protected]>
  • Loading branch information
aviator19941 authored Jun 18, 2024
1 parent 1f954b2 commit 2b3c46c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,7 @@ struct DistributeElementwise final
}

// Replace the original op with the distributed op.
Operation *distributedOp = rewriter.create(
op->getLoc(), op->getName().getIdentifier(), operands, resultTypes);

// Propagate known attributes.
StringRef fastmathAttrName = arith::FastMathFlagsAttr::getMnemonic();
if (Attribute attr = op->getAttr(fastmathAttrName)) {
distributedOp->setAttr(fastmathAttrName, attr);
}
Operation *distributedOp = mlir::clone(rewriter, op, resultTypes, operands);

DistributionPattern::replaceOpWithDistributedValues(
rewriter, op, distributedOp->getResults());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,24 @@
#layout = #iree_vector_ext.layout<<[VECTORY, LANEY], [4, 4]>, <[VECTORX, LANEX], [4, 4]>>

// CHECK-LABEL: @distribute_elementwise_f16
func.func @distribute_elementwise_f16(%a: vector<16x16xf16>, %b: vector<16x16xf16>, %denom: vector<16x16xf16>) -> vector<16x16xf16> {
func.func @distribute_elementwise_f16(%a: vector<16x16xf16>, %b: vector<16x16xf16>, %denom: vector<16x16xf16>) -> vector<16x16xi1> {
%c0 = arith.constant 0 : index
%cst_0 = arith.constant 0.0 : f16
// CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16>
%root = arith.constant {"__vector_layout_test_anchor_result_0" = #layout} dense<0.0> : vector<16x16xf16>
// CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<16xf16>
// CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] : vector<16xf16>
// CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] {{.*}} : vector<16xf16>
%c = arith.mulf %root, %b : vector<16x16xf16>
// CHECK-DAG: %[[DENOM:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<16xf16>
// CHECK-DAG: %[[DIVD:.*]] = arith.divf %[[C]], %[[DENOM]] : vector<16xf16>
// CHECK-DAG: %[[DIVD:.*]] = arith.divf %[[C]], %[[DENOM]] {{.*}} : vector<16xf16>
%divd = arith.divf %c, %denom : vector<16x16xf16>
// CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xf16> -> vector<16xf16>
// CHECK-DAG: %[[D:.*]] = arith.addf %[[DIVD]], %[[A]] fastmath<reassoc,nnan> : vector<16xf16>
// CHECK-DAG: %[[D:.*]] = arith.addf %[[DIVD]], %[[A]] fastmath<reassoc,nnan> {{.*}} : vector<16xf16>
%d = arith.addf %divd, %a fastmath<reassoc,nnan> : vector<16x16xf16>
// CHECK: iree_vector_ext.to_simd %[[D]] : vector<16xf16> -> vector<16x16xf16>
return %d : vector<16x16xf16>
// CHECK-DAG: %[[R:.*]] = arith.cmpf ult, %[[D]], %[[ROOT]] {{.*}} : vector<16xf16>
%r = arith.cmpf ult, %d, %root : vector<16x16xf16>
// CHECK: iree_vector_ext.to_simd %[[R]] : vector<16xi1> -> vector<16x16xi1>
return %r : vector<16x16xi1>
}

// CHECK-LABEL: @distribute_elementwise_i32
Expand All @@ -28,10 +30,10 @@ func.func @distribute_elementwise_i32(%a: vector<16x16xi32>, %b: vector<16x16xi3
// CHECK: %[[ROOT:.*]] = arith.constant dense<2> : vector<16xi32>
%root = arith.constant {"__vector_layout_test_anchor_result_0" = #layout} dense<2> : vector<16x16xi32>
// CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<16xi32>
// CHECK-DAG: %[[C:.*]] = arith.muli %[[B]], %[[ROOT]] : vector<16xi32>
// CHECK-DAG: %[[C:.*]] = arith.muli %[[B]], %[[ROOT]] {{.*}} : vector<16xi32>
%c = arith.muli %root, %b : vector<16x16xi32>
// CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<16xi32>
// CHECK-DAG: %[[D:.*]] = arith.addi %[[C]], %[[A]] : vector<16xi32>
// CHECK-DAG: %[[D:.*]] = arith.addi %[[C]], %[[A]] {{.*}} : vector<16xi32>
%d = arith.addi %c, %a : vector<16x16xi32>
// CHECK: iree_vector_ext.to_simd %[[D]] : vector<16xi32> -> vector<16x16xi32>
return %d : vector<16x16xi32>
Expand All @@ -58,10 +60,10 @@ func.func @distribute_elementwise_nested_layout_f16(%a: vector<128x128x128xf16>,
// CHECK: %[[ROOT:.*]] = arith.constant dense<0.000000e+00> : vector<8x2x4x1x4x4x1x8x2xf16>
%root = arith.constant {"__vector_layout_test_anchor_result_0" = #nested} dense<0.0> : vector<128x128x128xf16>
// CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x8x2xf16>
// CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] : vector<8x2x4x1x4x4x1x8x2xf16>
// CHECK-DAG: %[[C:.*]] = arith.mulf %[[B]], %[[ROOT]] {{.*}} : vector<8x2x4x1x4x4x1x8x2xf16>
%c = arith.mulf %root, %b : vector<128x128x128xf16>
// CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<128x128x128xf16> -> vector<8x2x4x1x4x4x1x8x2xf16>
// CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> : vector<8x2x4x1x4x4x1x8x2xf16>
// CHECK-DAG: %[[D:.*]] = arith.addf %[[C]], %[[A]] fastmath<reassoc,nnan> {{.*}} : vector<8x2x4x1x4x4x1x8x2xf16>
%d = arith.addf %c, %a fastmath<reassoc,nnan> : vector<128x128x128xf16>
// CHECK: iree_vector_ext.to_simd %[[D]] : vector<8x2x4x1x4x4x1x8x2xf16> -> vector<128x128x128xf16>
return %d : vector<128x128x128xf16>
Expand All @@ -81,10 +83,10 @@ func.func @distribute_scf_for(%a: vector<16x16xi32>, %b: vector<16x16xi32>) -> v
// Canonicalization currently breaks other tests. If canonicalization
// is ever run, this should be updated.
// CHECK-DAG: %[[B:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<16xi32>
// CHECK-DAG: %[[C:.*]] = arith.muli %[[ARG0]], %[[B]] : vector<16xi32>
// CHECK-DAG: %[[C:.*]] = arith.muli %[[ARG0]], %[[B]] {{.*}} : vector<16xi32>
%c = arith.muli %arg0, %b : vector<16x16xi32>
// CHECK-DAG: %[[A:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<16x16xi32> -> vector<16xi32>
// CHECK-DAG: %[[D:.*]] = arith.addi %[[C]], %[[A]] : vector<16xi32>
// CHECK-DAG: %[[D:.*]] = arith.addi %[[C]], %[[A]] {{.*}} : vector<16xi32>
%d = arith.addi %c, %a : vector<16x16xi32>
// CHECK: scf.yield %[[D]] : vector<16xi32>
scf.yield %d : vector<16x16xi32>
Expand Down Expand Up @@ -616,7 +618,7 @@ func.func @resolved_layout_conflict(%a : memref<32x16xf16>, %b : memref<32x16xf1
// CHECK: %[[R1:.*]] = vector.insert_strided_slice %[[R0]], %[[CST0]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x1x4xf16> into vector<2x1x4xf16>
// CHECK: %[[R2:.*]] = vector.extract_strided_slice %[[V0]] {offsets = [0, 0, 4], sizes = [1, 1, 4], strides = [1, 1, 1]} : vector<1x1x8xf16> to vector<1x1x4xf16>
// CHECK: %[[R3:.*]] = vector.insert_strided_slice %[[R2]], %[[R1]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x1x4xf16> into vector<2x1x4xf16>
// CHECK: %[[R4:.*]] = arith.addf %[[R3]], %[[R3]] : vector<2x1x4xf16>
// CHECK: %[[R4:.*]] = arith.addf %[[R3]], %[[R3]] {{.*}} : vector<2x1x4xf16>
%vec2 = arith.addf %vec, %vec : vector<32x16xf16>
// CHECK-COUNT-8: vector.store {{.*}}, vector<1xf16>
vector.transfer_write %vec2, %b[%c0, %c0] {in_bounds = [true, true],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
// CHECK-LABEL: func.func @matmul_256x256x256_f16_f16()
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: scf.for {{.*}} = %c0 to %c256 step %c128 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<2x2x1x1x4x1xf16>)
// CHECK: arith.extf %[[ARG]] : vector<2x2x1x1x4x1xf16> to vector<2x2x1x1x4x1xf32>
// CHECK: arith.extf %[[ARG]] {{.*}} : vector<2x2x1x1x4x1xf16> to vector<2x2x1x1x4x1xf32>
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<2x2x1x1x4x1xf32> to vector<2x2x1x1x4x1xf16>
// CHECK: scf.yield %[[TRUNC]] : vector<2x2x1x1x4x1xf16>
Expand Down Expand Up @@ -161,7 +161,7 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {
// This has more than 2 iterations, so we have prefetching enabled for this case. Due to
// prefetching, we have one iteration peeled off so upper bound is 2048 - 128 = 1920.
// CHECK: scf.for {{.*}} = %c0 to %c15 step %c1 iter_args(%[[ARG:.+]] = {{.*}}) -> (vector<4x1x1x1x4x1xf16>)
// CHECK: arith.extf %[[ARG]] : vector<4x1x1x1x4x1xf16> to vector<4x1x1x1x4x1xf32>
// CHECK: arith.extf %[[ARG]] {{.*}} : vector<4x1x1x1x4x1xf16> to vector<4x1x1x1x4x1xf32>
// CHECK-COUNT-32: amdgpu.mfma {{.*}} {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: %[[TRUNC:.+]] = arith.truncf %{{.*}} : vector<4x1x1x1x4x1xf32> to vector<4x1x1x1x4x1xf16>
// CHECK: scf.yield %[[TRUNC]] : vector<4x1x1x1x4x1xf16>
Expand Down

0 comments on commit 2b3c46c

Please sign in to comment.