diff --git a/tests/filecheck/dialects/csl/csl-stencil-ops.mlir b/tests/filecheck/dialects/csl/csl-stencil-ops.mlir index cd34692b93..986c640bab 100644 --- a/tests/filecheck/dialects/csl/csl-stencil-ops.mlir +++ b/tests/filecheck/dialects/csl/csl-stencil-ops.mlir @@ -81,9 +81,12 @@ builtin.module { // CHECK-GENERIC-NEXT: %16 = "arith.addf"(%15, %6) <{"fastmath" = #arith.fastmath}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32> // CHECK-GENERIC-NEXT: %17 = "arith.addf"(%16, %5) <{"fastmath" = #arith.fastmath}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32> // CHECK-GENERIC-NEXT: %18 = "tensor.empty"() : () -> tensor<510xf32> -// CHECK-GENERIC-NEXT: %19 = "linalg.fill"(%4, %18) <{"operandSegmentSizes" = array}> : (f32, tensor<510xf32>) -> tensor<510xf32> -// CHECK-GENERIC-NEXT: %20 = "arith.mulf"(%17, %19) <{"fastmath" = #arith.fastmath}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32> -// CHECK-GENERIC-NEXT: "stencil.return"(%20) : (tensor<510xf32>) -> () +// CHECK-GENERIC-NEXT: %19 = "linalg.fill"(%4, %18) <{"operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^2(%20 : f32, %21 : f32): +// CHECK-GENERIC-NEXT: "linalg.yield"(%20) : (f32) -> () +// CHECK-GENERIC-NEXT: }) : (f32, tensor<510xf32>) -> tensor<510xf32> +// CHECK-GENERIC-NEXT: %22 = "arith.mulf"(%17, %19) <{"fastmath" = #arith.fastmath}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32> +// CHECK-GENERIC-NEXT: "stencil.return"(%22) : (tensor<510xf32>) -> () // CHECK-GENERIC-NEXT: }) : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, tensor<4x510xf32>) -> !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> // CHECK-GENERIC-NEXT: "stencil.store"(%1, %b) {"bounds" = #stencil.bounds<[0, 0], [1, 1]>} : (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> () // CHECK-GENERIC-NEXT: "func.return"() : () -> () @@ -195,9 +198,12 @@ builtin.module { // CHECK-GENERIC-NEXT: %17 = "arith.addf"(%16, %15) <{"fastmath" = #arith.fastmath}> : 
(tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32> // CHECK-GENERIC-NEXT: %18 = "arith.constant"() <{"value" = 1.666600e-01 : f32}> : () -> f32 // CHECK-GENERIC-NEXT: %19 = "tensor.empty"() : () -> tensor<510xf32> -// CHECK-GENERIC-NEXT: %20 = "linalg.fill"(%18, %19) <{"operandSegmentSizes" = array}> : (f32, tensor<510xf32>) -> tensor<510xf32> -// CHECK-GENERIC-NEXT: %21 = "arith.mulf"(%17, %20) <{"fastmath" = #arith.fastmath}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32> -// CHECK-GENERIC-NEXT: "csl_stencil.yield"(%21) : (tensor<510xf32>) -> () +// CHECK-GENERIC-NEXT: %20 = "linalg.fill"(%18, %19) <{"operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^3(%21 : f32, %22 : f32): +// CHECK-GENERIC-NEXT: "linalg.yield"(%21) : (f32) -> () +// CHECK-GENERIC-NEXT: }) : (f32, tensor<510xf32>) -> tensor<510xf32> +// CHECK-GENERIC-NEXT: %23 = "arith.mulf"(%17, %20) <{"fastmath" = #arith.fastmath}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32> +// CHECK-GENERIC-NEXT: "csl_stencil.yield"(%23) : (tensor<510xf32>) -> () // CHECK-GENERIC-NEXT: }) : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, tensor<510xf32>) -> !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> // CHECK-GENERIC-NEXT: "stencil.store"(%2, %b) {"bounds" = #stencil.bounds<[0, 0], [1, 1]>} : (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> () // CHECK-GENERIC-NEXT: "func.return"() : () -> () diff --git a/tests/filecheck/dialects/linalg/linalg_ops.mlir b/tests/filecheck/dialects/linalg/linalg_ops.mlir index 313e8fd8f8..96734abf13 100644 --- a/tests/filecheck/dialects/linalg/linalg_ops.mlir +++ b/tests/filecheck/dialects/linalg/linalg_ops.mlir @@ -28,69 +28,97 @@ linalg.add ins(%m1, %m2 : memref<4x16xf32>, memref<4x16xf32>) outs(%m3 : memref< %mul = linalg.mul ins(%t1, %t2 : tensor<4x16xf32>, tensor<4x16xf32>) outs(%t3 : tensor<4x16xf32>) -> tensor<4x16xf32> linalg.mul ins(%m1, %m2 : memref<4x16xf32>, memref<4x16xf32>) outs(%m3 : 
memref<4x16xf32>) - %2, %3 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) %4 = "test.op"() : () -> (memref<64x4096xf32>) linalg.matmul {id} ins(%2, %3 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%4 : memref<64x4096xf32>) +%fill = linalg.fill ins(%0 : f32) outs(%t3 : tensor<4x16xf32>) -> tensor<4x16xf32> +linalg.fill ins(%0 : f32) outs(%m3 : memref<4x16xf32>) + // CHECK: module { -// CHECK-NEXT: %0, %1 = "test.op"() : () -> (f32, memref<1x256xf32>) -// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0 : f32) outs(%1 : memref<1x256xf32>) { -// CHECK-NEXT: ^0(%{{.*}}: f32, %{{.*}}: f32): +// CHECK-NEXT: %{{.*}} %{{.*}} = "test.op"() : () -> (f32, memref<1x256xf32>) +// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%{{.*}} : f32) outs(%{{.*}} : memref<1x256xf32>) { +// CHECK-NEXT: ^0(%{{.*}} f32, %{{.*}} f32): // CHECK-NEXT: linalg.yield %{{.*}} : f32 // CHECK-NEXT: } -// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], doc = "a_docstring", library_call = "a_library_call"} ins(%0 : f32) outs(%1 : memref<1x256xf32>) { -// CHECK-NEXT: ^1(%arg3_1 : f32, %arg4_1 : f32): -// CHECK-NEXT: linalg.yield %arg3_1 : f32 +// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], doc = "a_docstring", library_call = "a_library_call"} ins(%{{.*}} : f32) outs(%{{.*}} : memref<1x256xf32>) { +// CHECK-NEXT: ^1(%{{.*}} : f32, %{{.*}} : f32): +// CHECK-NEXT: linalg.yield %{{.*}} : f32 // CHECK-NEXT: } -// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} 
ins(%0 : f32) outs(%1 : memref<1x256xf32>) attrs = {"hello" = "world"} { -// CHECK-NEXT: ^{{.*}}(%{{.*}}: f32, %{{.*}}: f32): +// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%{{.*}} : f32) outs(%{{.*}} : memref<1x256xf32>) attrs = {"hello" = "world"} { +// CHECK-NEXT: ^{{.*}}(%{{.*}} f32, %{{.*}} f32): // CHECK-NEXT: linalg.yield %{{.*}} : f32 // CHECK-NEXT: } -// CHECK-NEXT: %t1, %t2, %t3 = "test.op"() : () -> (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>) -// CHECK-NEXT: %m1, %m2, %m3 = "test.op"() : () -> (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>) -// CHECK-NEXT: %sum = linalg.add ins(%t1, %t2 : tensor<4x16xf32>, tensor<4x16xf32>) outs(%t3 : tensor<4x16xf32>) -> tensor<4x16xf32> -// CHECK-NEXT: linalg.add ins(%m1, %m2 : memref<4x16xf32>, memref<4x16xf32>) outs(%m3 : memref<4x16xf32>) -// CHECK-NEXT: %mul = linalg.mul ins(%t1, %t2 : tensor<4x16xf32>, tensor<4x16xf32>) outs(%t3 : tensor<4x16xf32>) -> tensor<4x16xf32> -// CHECK-NEXT: linalg.mul ins(%m1, %m2 : memref<4x16xf32>, memref<4x16xf32>) outs(%m3 : memref<4x16xf32>) -// CHECK-NEXT: %2, %3 = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) -// CHECK-NEXT: %4 = "test.op"() : () -> memref<64x4096xf32> -// CHECK-NEXT: linalg.matmul {"id"} ins(%2, %3 : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%4 : memref<64x4096xf32>) +// CHECK-NEXT: %{{.*}} %{{.*}} %{{.*}} = "test.op"() : () -> (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>) +// CHECK-NEXT: %{{.*}} %{{.*}} %{{.*}} = "test.op"() : () -> (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>) +// CHECK-NEXT: %{{.*}} = linalg.add ins(%{{.*}} %{{.*}} : tensor<4x16xf32>, tensor<4x16xf32>) outs(%{{.*}} : tensor<4x16xf32>) -> tensor<4x16xf32> +// CHECK-NEXT: linalg.add ins(%{{.*}} %{{.*}} : memref<4x16xf32>, memref<4x16xf32>) outs(%{{.*}} : memref<4x16xf32>) +// CHECK-NEXT: %{{.*}} = linalg.mul ins(%{{.*}} 
%{{.*}} : tensor<4x16xf32>, tensor<4x16xf32>) outs(%{{.*}} : tensor<4x16xf32>) -> tensor<4x16xf32> +// CHECK-NEXT: linalg.mul ins(%{{.*}} %{{.*}} : memref<4x16xf32>, memref<4x16xf32>) outs(%{{.*}} : memref<4x16xf32>) +// CHECK-NEXT: %{{.*}} %{{.*}} = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) +// CHECK-NEXT: %{{.*}} = "test.op"() : () -> memref<64x4096xf32> +// CHECK-NEXT: linalg.matmul {"id"} ins(%{{.*}} %{{.*}} : memref<64x9216xf32>, memref<9216x4096xf32>) outs(%{{.*}} : memref<64x4096xf32>) +// CHECK-NEXT: %{{.*}} = linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<4x16xf32>) -> tensor<4x16xf32> +// CHECK-NEXT: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : memref<4x16xf32>) // CHECK-NEXT: } -// CHECK-GENERIC: "linalg.generic"(%0, %1) <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type, #linalg.iterator_type], "operandSegmentSizes" = array}> ({ -// CHECK-GENERIC-NEXT: ^0(%{{.*}}: f32, %{{.*}}: f32): -// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f32) -> () +// CHECK-GENERIC: "linalg.generic"(%{{.*}} %{{.*}} <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type, #linalg.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^0(%{{.*}} f32, %{{.*}} f32): +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> () // CHECK-GENERIC-NEXT: }) : (f32, memref<1x256xf32>) -> () -// CHECK-GENERIC-NEXT: "linalg.generic"(%0, %1) <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type, #linalg.iterator_type], "doc" = "a_docstring", "library_call" = "a_library_call", "operandSegmentSizes" = array}> ({ -// CHECK-GENERIC-NEXT: ^1(%arg3_1 : f32, %arg4_1 : f32): -// CHECK-GENERIC-NEXT: "linalg.yield"(%arg3_1) : (f32) -> () +// CHECK-GENERIC-NEXT: "linalg.generic"(%{{.*}} %{{.*}} <{"indexing_maps" = [affine_map<(d0, d1) 
-> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type, #linalg.iterator_type], "doc" = "a_docstring", "library_call" = "a_library_call", "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^1(%{{.*}} : f32, %{{.*}} : f32): +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> () // CHECK-GENERIC-NEXT: }) : (f32, memref<1x256xf32>) -> () -// CHECK-GENERIC: "linalg.generic"(%0, %1) <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type, #linalg.iterator_type], "operandSegmentSizes" = array}> ({ -// CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}}: f32, %{{.*}}: f32): -// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f32) -> () +// CHECK-GENERIC: "linalg.generic"(%{{.*}} %{{.*}} <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type, #linalg.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}} f32, %{{.*}} f32): +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> () // CHECK-GENERIC-NEXT: }) {"hello" = "world"} : (f32, memref<1x256xf32>) -> () -// CHECK-GENERIC-NEXT: %t1, %t2, %t3 = "test.op"() : () -> (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>) -// CHECK-GENERIC-NEXT: %m1, %m2, %m3 = "test.op"() : () -> (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>) -// CHECK-GENERIC-NEXT: %sum = "linalg.add"(%t1, %t2, %t3) <{"operandSegmentSizes" = array}> ({ -// CHECK-GENERIC-NEXT: ^3(%2 : f32, %3 : f32, %4 : f32): -// CHECK-GENERIC-NEXT: %5 = "arith.addf"(%2, %3) : (f32, f32) -> f32 -// CHECK-GENERIC-NEXT: "linalg.yield"(%5) : (f32) -> () +// CHECK-GENERIC-NEXT: %{{.*}} %{{.*}} %{{.*}} = "test.op"() : () -> (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>) +// CHECK-GENERIC-NEXT: %{{.*}} %{{.*}} %{{.*}} = "test.op"() : () -> (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>) + +// CHECK-GENERIC-NEXT: %{{.*}} = "linalg.add"(%{{.*}} %{{.*}} 
%{{.*}} <{"operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^3(%{{.*}} : f32, %{{.*}} : f32, %{{.*}} : f32): +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}} %{{.*}} : (f32, f32) -> f32 +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> () // CHECK-GENERIC-NEXT: }) : (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> -// CHECK-GENERIC-NEXT: "linalg.add"(%m1, %m2, %m3) <{"operandSegmentSizes" = array}> ({ -// CHECK-GENERIC-NEXT: ^4(%6 : f32, %7 : f32, %8 : f32): -// CHECK-GENERIC-NEXT: %9 = "arith.addf"(%6, %7) : (f32, f32) -> f32 -// CHECK-GENERIC-NEXT: "linalg.yield"(%9) : (f32) -> () + +// CHECK-GENERIC-NEXT: "linalg.add"(%{{.*}} %{{.*}} %{{.*}} <{"operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^4(%{{.*}} : f32, %{{.*}} : f32, %{{.*}} : f32): +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}} %{{.*}} : (f32, f32) -> f32 +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> () // CHECK-GENERIC-NEXT: }) : (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>) -> () -// CHECK-GENERIC-NEXT: %mul = "linalg.mul"(%t1, %t2, %t3) <{"operandSegmentSizes" = array}> ({ -// CHECK-GENERIC-NEXT: ^5(%10 : f32, %11 : f32, %12 : f32): -// CHECK-GENERIC-NEXT: %13 = "arith.mulf"(%10, %11) : (f32, f32) -> f32 -// CHECK-GENERIC-NEXT: "linalg.yield"(%13) : (f32) -> () + +// CHECK-GENERIC-NEXT: %{{.*}} = "linalg.mul"(%{{.*}} %{{.*}} %{{.*}} <{"operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^5(%{{.*}} : f32, %{{.*}} : f32, %{{.*}} : f32): +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.mulf"(%{{.*}} %{{.*}} : (f32, f32) -> f32 +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> () // CHECK-GENERIC-NEXT: }) : (tensor<4x16xf32>, tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32> -// CHECK-GENERIC-NEXT: "linalg.mul"(%m1, %m2, %m3) <{"operandSegmentSizes" = array}> ({ -// CHECK-GENERIC-NEXT: ^6(%14 : f32, %15 : f32, %16 : f32): -// CHECK-GENERIC-NEXT: %17 = "arith.mulf"(%14, %15) : (f32, f32) -> f32 -// 
CHECK-GENERIC-NEXT: "linalg.yield"(%17) : (f32) -> () + +// CHECK-GENERIC-NEXT: "linalg.mul"(%{{.*}} %{{.*}} %{{.*}} <{"operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^6(%{{.*}} : f32, %{{.*}} : f32, %{{.*}} : f32): +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.mulf"(%{{.*}} %{{.*}} : (f32, f32) -> f32 +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> () // CHECK-GENERIC-NEXT: }) : (memref<4x16xf32>, memref<4x16xf32>, memref<4x16xf32>) -> () + +// CHECK-GENERIC-NEXT: %{{.*}} %{{.*}} = "test.op"() : () -> (memref<64x9216xf32>, memref<9216x4096xf32>) +// CHECK-GENERIC-NEXT: %{{.*}} = "test.op"() : () -> memref<64x4096xf32> + +// CHECK-GENERIC-NEXT: "linalg.matmul"(%{{.*}} %{{.*}} %{{.*}} <{"operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^7(%{{.*}} : f32, %{{.*}} : f32, %{{.*}} : f32): +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.mulf"(%{{.*}}, %{{.*}} : (f32, f32) -> f32 +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}} : (f32, f32) -> f32 +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> () +// CHECK-GENERIC-NEXT: }) {"id", "linalg.memoized_indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>]} : (memref<64x9216xf32>, memref<9216x4096xf32>, memref<64x4096xf32>) -> () + +// CHECK-GENERIC-NEXT: %{{.*}} = "linalg.fill"(%{{.*}}, %{{.*}} <{"operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^8(%{{.*}} : f32, %{{.*}} : f32): +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> () +// CHECK-GENERIC-NEXT: }) : (f32, tensor<4x16xf32>) -> tensor<4x16xf32> + +// CHECK-GENERIC-NEXT: "linalg.fill"(%{{.*}}, %{{.*}} <{"operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^9(%{{.*}} : f32, %{{.*}} : f32): +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}} : (f32) -> () +// CHECK-GENERIC-NEXT: }) : (f32, memref<4x16xf32>) -> () diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir 
b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir index f6899827e0..8efdb8b3cb 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/dialects/linalg/invalid_ops.mlir @@ -5,7 +5,7 @@ builtin.module { %0, %1 = "test.op"() : () -> (tensor<2x3xf32>, tensor<2x3xf32>) // CHECK: Input type is tensor<2x3xf32> but must be an instance of AnyFloat or IntegerType. - %res_fill = "linalg.fill"(%0, %1) <{"operandSegmentSizes" = array}> : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %res_fill = linalg.fill ins (%0: tensor<2x3xf32>) outs (%1: tensor<2x3xf32>) -> tensor<2x3xf32> } diff --git a/tests/interpreters/test_linalg_interpreter.py b/tests/interpreters/test_linalg_interpreter.py index a35d12f4c3..ebdeda1089 100644 --- a/tests/interpreters/test_linalg_interpreter.py +++ b/tests/interpreters/test_linalg_interpreter.py @@ -1,5 +1,3 @@ -from typing import cast - import pytest from xdsl.builder import ImplicitBuilder @@ -22,7 +20,7 @@ from xdsl.interpreters.linalg import LinalgFunctions from xdsl.interpreters.ptr import TypedPtr from xdsl.interpreters.shaped_array import ShapedArray -from xdsl.ir import Attribute, Block, Region +from xdsl.ir import Block, Region from xdsl.ir.affine import AffineExpr, AffineMap from xdsl.utils.test_value import TestSSAValue @@ -199,9 +197,8 @@ def test_fill_op(): interpreter.register_implementations(ArithFunctions()) interpreter.register_implementations(LinalgFunctions()) constant = arith.Constant(FloatAttr(1.0, f32)) - constant = cast(Attribute, constant) op = linalg.FillOp( - (TestSSAValue(constant),), + (TestSSAValue(constant.result.type),), (TestSSAValue(TensorType(f32, [2, 3])),), (TensorType(f32, [2, 3]),), ) diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 0b9f75302f..5daff4f980 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -3,10 +3,11 @@ from abc import ABC from collections.abc 
import Mapping, Sequence from enum import auto -from typing import cast +from typing import ClassVar, cast from typing_extensions import Self +from xdsl.builder import Builder from xdsl.dialects import arith from xdsl.dialects.builtin import ( AffineMapAttr, @@ -29,10 +30,9 @@ ) from xdsl.ir import ( Attribute, - Block, + BlockArgument, Dialect, EnumAttribute, - Operation, Region, SSAValue, ) @@ -391,9 +391,9 @@ class YieldOp(AbstractYieldOperation[Attribute]): traits = frozenset([IsTerminator()]) -class NamedOpBase(IRDLOperation): +class NamedOpBase(IRDLOperation, ABC): """ - Base class for named ops with hidden region. + Abstract base class for named ops with hidden region. """ inputs = var_operand_def() @@ -405,6 +405,8 @@ class NamedOpBase(IRDLOperation): irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()] + PRINT_ATTRS_IN_FRONT: ClassVar[bool] = False + def __init__( self, ins: Sequence[SSAValue], @@ -413,24 +415,7 @@ def __init__( properties: Mapping[str, Attribute | None] | None = None, attributes: Mapping[str, Attribute | None] | None = None, hidden_region: Region | None = None, - arith_op: type[Operation] | None = None, ): - if (hidden_region is None) == (arith_op is None): - raise ValueError("Specify either hidden_region or arith_op but not both") - - if hidden_region is None: - assert arith_op is not None - hidden_region = Region(Block(arg_types=[t.type for t in ins])) - hidden_region.block.add_ops( - ( - op := arith_op( - operands=[arg for arg in ins], - result_types=[t.type for t in outs], - ), - YieldOp(*op.results), - ) - ) - super().__init__( operands=[ins, outs], result_types=( @@ -446,6 +431,10 @@ def __init__( @classmethod def parse(cls, parser: Parser): pos = parser.pos + if cls.PRINT_ATTRS_IN_FRONT: + attrs = parser.parse_optional_attr_dict() + else: + attrs = {} if parser.parse_optional_characters("ins"): parser.parse_punctuation("(") unresolved_ins = parser.parse_comma_separated_list( @@ -475,13 +464,14 @@ def 
parse(cls, parser: Parser): else: outs = () - if parser.parse_optional_keyword("attrs"): - parser.parse_punctuation("=") - attrs = parser.expect( - parser.parse_optional_attr_dict, "expect extra attributes" - ) - else: - attrs = {} + if not cls.PRINT_ATTRS_IN_FRONT: + if parser.parse_optional_keyword("attrs"): + parser.parse_punctuation("=") + attrs = parser.expect( + parser.parse_optional_attr_dict, "expect extra attributes" + ) + else: + attrs = {} if parser.parse_optional_punctuation("->"): res_types = parser.parse_optional_comma_separated_list( @@ -495,6 +485,21 @@ def parse(cls, parser: Parser): return cls(ins, outs, res_types, attrs) def print(self, printer: Printer): + + extra_attrs = self.attributes.copy() + if "indexing_maps" in extra_attrs: + del extra_attrs["indexing_maps"] + if "linalg.memoized_indexing_maps" in extra_attrs: + del extra_attrs["linalg.memoized_indexing_maps"] + if "iterator_types" in extra_attrs: + del extra_attrs["iterator_types"] + if "doc" in extra_attrs: + del extra_attrs["doc"] + if "library_call" in extra_attrs: + del extra_attrs["library_call"] + + if extra_attrs and self.PRINT_ATTRS_IN_FRONT: + printer.print_op_attributes(extra_attrs) if self.inputs: printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) @@ -509,17 +514,7 @@ def print(self, printer: Printer): printer.print_list((o.type for o in self.outputs), printer.print_attribute) printer.print_string(")") - extra_attrs = self.attributes.copy() - if "indexing_maps" in extra_attrs: - del extra_attrs["indexing_maps"] - if "iterator_types" in extra_attrs: - del extra_attrs["iterator_types"] - if "doc" in extra_attrs: - del extra_attrs["doc"] - if "library_call" in extra_attrs: - del extra_attrs["library_call"] - - if extra_attrs: + if extra_attrs and not self.PRINT_ATTRS_IN_FRONT: printer.print(" attrs = ") printer.print_op_attributes(extra_attrs) @@ -534,6 +529,29 @@ def print(self, printer: Printer): ) printer.print(")") + @staticmethod + def 
body_arg_types( + operands: Sequence[SSAValue], + ) -> Sequence[AnyFloat | IntegerType]: + """ + Return the element types of the arguments of the body of this operation + """ + + result: Sequence[AnyFloat | IntegerType] = [] + + for op in operands: + op_type = op.type + if isa(op_type, MemRefType[Attribute]): + element_type = op_type.get_element_type() + elif isa(op_type, TensorType[Attribute]): + element_type = op_type.get_element_type() + else: # int or float + element_type = op_type + assert isinstance(element_type, AnyFloat | IntegerType) + result.append(element_type) + + return result + @irdl_op_definition class AddOp(NamedOpBase): @@ -557,20 +575,13 @@ def __init__( else: result_types = res - assert len(outputs) == 1 - assert isa(outputs[0].type, TensorType[Attribute] | MemRefType[Attribute]) - element_t = outputs[0].type.get_element_type() - hidden_region = Region(Block(arg_types=(element_t, element_t, element_t))) - hidden_region.block.add_ops( - ( - op := arith.Addf( - hidden_region.block.args[0], - hidden_region.block.args[1], - result_type=element_t, - ), - YieldOp(*op.results), - ) - ) + arg_types = self.body_arg_types((*inputs, *outputs)) + add = arith.Addf if isinstance(arg_types[-1], AnyFloat) else arith.Addi + + @Builder.implicit_region(arg_types) + def hidden_region(args: tuple[BlockArgument, ...]) -> None: + result = add(args[0], args[1]) + YieldOp(result) super().__init__( ins=inputs, @@ -603,20 +614,13 @@ def __init__( else: result_types = res - assert len(outputs) == 1 - assert isa(outputs[0].type, TensorType[Attribute] | MemRefType[Attribute]) - element_t = outputs[0].type.get_element_type() - hidden_region = Region(Block(arg_types=(element_t, element_t, element_t))) - hidden_region.block.add_ops( - ( - op := arith.Subf( - hidden_region.block.args[0], - hidden_region.block.args[1], - result_type=element_t, - ), - YieldOp(*op.results), - ) - ) + arg_types = self.body_arg_types((*inputs, *outputs)) + sub = arith.Subf if isinstance(arg_types[-1], 
AnyFloat) else arith.Subi + + @Builder.implicit_region(arg_types) + def hidden_region(args: tuple[BlockArgument, ...]) -> None: + result = sub(args[0], args[1]) + YieldOp(result) super().__init__( ins=inputs, @@ -628,7 +632,7 @@ def __init__( @irdl_op_definition -class FillOp(IRDLOperation): +class FillOp(NamedOpBase): """ Fills the output tensor with the given value. @@ -641,23 +645,12 @@ class FillOp(IRDLOperation): name = "linalg.fill" - inputs = var_operand_def() - outputs = var_operand_def(AnyShapedType()) - - res = var_result_def(AnyTensorType) - - assembly_format = ( - "`ins` `(` $inputs `:` type($inputs) `)` ` ` " - "`outs` `(` $outputs `:` type($outputs) `)` (`->` type($res)^)? attr-dict" - ) - - irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()] - def __init__( self, - inputs: Sequence[SSAValue | Operation], - outputs: Sequence[SSAValue | Operation] = (), + inputs: Sequence[SSAValue], + outputs: Sequence[SSAValue] = (), res: Sequence[Attribute] | None = None, + attributes: dict[str, Attribute] | None = None, ): if res is None: assert isa(outputs, Sequence[SSAValue]), "cannot infer result_types" @@ -665,9 +658,18 @@ def __init__( else: result_types = res + arg_types = self.body_arg_types((*inputs, *outputs)) + + @Builder.implicit_region(arg_types) + def hidden_region(args: tuple[BlockArgument, ...]) -> None: + YieldOp(args[0]) + super().__init__( - operands=(inputs, outputs), + ins=inputs, + outs=outputs, result_types=result_types, + attributes=attributes, + hidden_region=hidden_region, ) def verify_(self) -> None: @@ -701,20 +703,13 @@ def __init__( else: result_types = res - assert len(outputs) == 1 - assert isa(outputs[0].type, TensorType[Attribute] | MemRefType[Attribute]) - element_t = outputs[0].type.get_element_type() - hidden_region = Region(Block(arg_types=(element_t, element_t, element_t))) - hidden_region.block.add_ops( - ( - op := arith.Mulf( - hidden_region.block.args[0], - hidden_region.block.args[1], - 
result_type=element_t, - ), - YieldOp(*op.results), - ) - ) + arg_types = self.body_arg_types((*inputs, *outputs)) + mul = arith.Mulf if isinstance(arg_types[-1], AnyFloat) else arith.Muli + + @Builder.implicit_region(arg_types) + def hidden_region(args: tuple[BlockArgument, ...]) -> None: + result = mul(args[0], args[1]) + YieldOp(result) super().__init__( ins=inputs, @@ -831,7 +826,7 @@ def parse(cls, parser: Parser) -> Self: @irdl_op_definition -class MatmulOp(IRDLOperation): +class MatmulOp(NamedOpBase): """ Performs a matrix multiplication of two 2D inputs. @@ -841,23 +836,14 @@ class MatmulOp(IRDLOperation): name = "linalg.matmul" - inputs = var_operand_def() - outputs = var_operand_def(AnyShapedType()) - - res = var_result_def(AnyTensorType) - - assembly_format = ( - "attr-dict `ins` `(` $inputs `:` type($inputs) `)` ` ` " - "`outs` `(` $outputs `:` type($outputs) `)` (`->` type($res)^)?" - ) - - irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()] + PRINT_ATTRS_IN_FRONT: ClassVar[bool] = True def __init__( self, inputs: Sequence[SSAValue], outputs: Sequence[SSAValue] = (), res: Sequence[Attribute] | None = None, + attributes: dict[str, Attribute] | None = None, ): if res is None: result_types = tuple( @@ -867,9 +853,38 @@ def __init__( ) else: result_types = res + + arg_types = self.body_arg_types((*inputs, *outputs)) + add, mul = ( + (arith.Addf, arith.Mulf) + if isinstance(arg_types[-1], AnyFloat) + else (arith.Addi, arith.Muli) + ) + + @Builder.implicit_region(arg_types) + def hidden_region(args: tuple[BlockArgument, ...]) -> None: + result = mul(args[0], args[1]) + mac = add(result, args[2]) + YieldOp(mac) + + # add linalg.memoized_indexing_maps attribute + if not attributes: + attributes = {} + if "linalg.memoized_indexing_maps" not in attributes: + attributes["linalg.memoized_indexing_maps"] = ArrayAttr( + [ + AffineMapAttr(AffineMap.from_callable(lambda i, _, k: (i, k))), + AffineMapAttr(AffineMap.from_callable(lambda _,
j, k: (k, j))), + AffineMapAttr(AffineMap.from_callable(lambda i, j, _: (i, j))), + ] + ) + super().__init__( - operands=(inputs, outputs), - result_types=(result_types,), + ins=inputs, + outs=outputs, + result_types=result_types, + attributes=attributes, + hidden_region=hidden_region, ) diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index 3f8ae6286b..9882bd3c8a 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -149,7 +149,7 @@ def arithBinaryOpTensorize( ) elif is_tensor(op.lhs.type) and is_scalar(op.rhs.type): emptyop = EmptyOp((), op.lhs.type) - fillop = FillOp((op.rhs,), (emptyop,), (op.lhs.type,)) + fillop = FillOp((op.rhs,), (emptyop.results[0],), (op.lhs.type,)) rewriter.insert_op(emptyop, InsertPoint.before(op)) rewriter.insert_op(fillop, InsertPoint.before(op)) rewriter.replace_matched_op( @@ -157,7 +157,7 @@ def arithBinaryOpTensorize( ) elif is_scalar(op.lhs.type) and is_tensor(op.rhs.type): emptyop = EmptyOp((), op.rhs.type) - fillop = FillOp((op.lhs,), (emptyop,), (op.rhs.type,)) + fillop = FillOp((op.lhs,), (emptyop.results[0],), (op.rhs.type,)) rewriter.insert_op(emptyop, InsertPoint.before(op)) rewriter.insert_op(fillop, InsertPoint.before(op)) rewriter.replace_matched_op(