diff --git a/tests/filecheck/transforms/memref_streamify.mlir b/tests/filecheck/transforms/memref_streamify.mlir index 817701331c..af2df183a3 100644 --- a/tests/filecheck/transforms/memref_streamify.mlir +++ b/tests/filecheck/transforms/memref_streamify.mlir @@ -6,6 +6,38 @@ // CHECK: builtin.module { +func.func @fill_empty_shape(%scalar: memref) { + %zero_float = arith.constant 0.000000e+00 : f64 + memref_stream.generic { + bounds = [], + indexing_maps = [ + affine_map<() -> ()>, + affine_map<() -> ()> + ], + iterator_types = [] + } ins(%zero_float : f64) outs(%scalar : memref) { + ^bb0(%in: f64, %out: f64): + linalg.yield %in : f64 + } + return +} + +// CHECK-NEXT: func.func @fill_empty_shape(%scalar : memref) { +// CHECK-NEXT: %zero_float = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: memref_stream.generic { +// CHECK-NEXT: bounds = [], +// CHECK-NEXT: indexing_maps = [ +// CHECK-NEXT: affine_map<() -> ()>, +// CHECK-NEXT: affine_map<() -> ()> +// CHECK-NEXT: ], +// CHECK-NEXT: iterator_types = [] +// CHECK-NEXT: } ins(%zero_float : f64) outs(%scalar : memref) { +// CHECK-NEXT: ^0(%in : f64, %out : f64): +// CHECK-NEXT: linalg.yield %in : f64 +// CHECK-NEXT: } +// CHECK-NEXT: func.return +// CHECK-NEXT: } + func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 : memref<8x16xf64>) -> memref<8x16xf64> { memref_stream.generic { bounds = [8, 16], diff --git a/xdsl/transforms/memref_streamify.py b/xdsl/transforms/memref_streamify.py index fb6a51be68..b00694d7a3 100644 --- a/xdsl/transforms/memref_streamify.py +++ b/xdsl/transforms/memref_streamify.py @@ -41,15 +41,17 @@ def match_and_rewrite( for index, (i, arg) in enumerate( zip(op.inputs, op.body.block.args[:input_count]) ) - if isinstance(i.type, memref.MemRefType) and arg.uses + if isinstance(i_type := i.type, memref.MemRefType) and arg.uses + if i_type.get_shape() ) streamable_output_indices = tuple( (index, arg.type) for index, (o, arg) in enumerate( zip(op.outputs, op.body.block.args[input_count:]) ) - if isinstance(o.type, memref.MemRefType) + if isinstance(o_type := o.type, memref.MemRefType) if index in init_indices or not arg.uses + if o_type.get_shape() ) if not streamable_input_indices and not streamable_output_indices: # No memrefs to convert to streams