Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

testing: raise ddot bottom_up test to linalg.generic #3678

Merged
merged 4 commits into from
Dec 31, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions tests/filecheck/projects/riscv-backend-paper/bottom_up_f64.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -127,34 +127,35 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
// CHECK-NEXT: ret

func.func @ddot(
%X : memref<128xf64>,
%Y : memref<128xf64>,
%G : memref<f64>
%X: memref<128xf64> {llvm.noalias},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need this attribute here but not other tests?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically it's always needed it, and we make some incorrect assumptions of non-aliasing in our lowering. In the riscv-paper-experiments repo, we have them on all the kernels, as LLVM doesn't do the same optimisations without these annotations.

%Y: memref<128xf64> {llvm.noalias},
%Z: memref<f64> {llvm.noalias}
) {
memref_stream.streaming_region {
patterns = [
#memref_stream.stride_pattern<ub = [128], index_map = (d0) -> (d0)>,
#memref_stream.stride_pattern<ub = [128], index_map = (d0) -> (d0)>
]
} ins(%X, %Y : memref<128xf64>, memref<128xf64>) {
^0(%x_stream : !memref_stream.readable<f64>, %y_stream : !memref_stream.readable<f64>):
%zero_float = arith.constant 0.0 : f64

%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%c128 = arith.constant 128 : i32
%g = scf.for %i = %c0 to %c128 step %c1 iter_args(%acc = %zero_float) -> (f64) : i32 {
%x = memref_stream.read from %x_stream : f64
%y = memref_stream.read from %y_stream : f64
%prod = arith.mulf %x, %y fastmath<fast> : f64
%res = arith.addf %prod, %acc fastmath<fast> : f64
scf.yield %res : f64
}

memref.store %g, %G[] : memref<f64>
%zero_float = arith.constant 0.000000e+00 : f64
linalg.generic {
indexing_maps = [
affine_map<() -> ()>,
affine_map<() -> ()>
],
iterator_types = []
} ins(%zero_float : f64) outs(%Z : memref<f64>) {
^bb0(%in: f64, %out: f64):
linalg.yield %in : f64
}

func.return
linalg.generic {
indexing_maps = [
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> ()>
],
iterator_types = ["reduction"]
} ins(%X, %Y : memref<128xf64>, memref<128xf64>) outs(%Z : memref<f64>) {
^bb0(%x: f64, %y: f64, %acc: f64):
%prod = arith.mulf %x, %y fastmath<fast> : f64
%acc_new = arith.addf %prod, %acc fastmath<fast> : f64
linalg.yield %acc_new : f64
}
return
}


Expand All @@ -166,6 +167,7 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
// CHECK-NEXT: mv t2, a0
// CHECK-NEXT: mv t1, a1
// CHECK-NEXT: mv t0, a2
// CHECK-NEXT: fcvt.d.w ft3, zero
// CHECK-NEXT: li t3, 127
// CHECK-NEXT: scfgwi t3, 95 # dm 31 dim 0 bound
// CHECK-NEXT: li t3, 8
Expand All @@ -174,7 +176,6 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
// CHECK-NEXT: scfgwi t2, 768 # dm 0 dim 0 source
// CHECK-NEXT: scfgwi t1, 769 # dm 1 dim 0 source
// CHECK-NEXT: csrrsi zero, 1984, 1 # SSR enable
// CHECK-NEXT: fcvt.d.w ft3, zero
// CHECK-NEXT: li t1, 127
// CHECK-NEXT: frep.o t1, 1, 0, 0
// CHECK-NEXT: fmadd.d ft3, ft0, ft1, ft3
Expand Down
Loading