Skip to content

Commit

Permalink
dialects: Update dmp.swap (#3056)
Browse files Browse the repository at this point in the history
Give `dmp.swap` a proper signature for both value semantics and reference semantics.
It must now return the swapped values when it takes a value (a `!stencil.temp`) as input,
and needs only an in-place operand, with no result, under reference semantics (a `!stencil.field`).

Update the distribution pass to replace uses of the loaded values with the
swapped values, building a consistent def-use chain.

---------

Co-authored-by: n-io <[email protected]>
  • Loading branch information
PapyChacal and n-io authored Aug 19, 2024
1 parent 5c4bcf4 commit 4704025
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 53 deletions.
16 changes: 9 additions & 7 deletions tests/filecheck/dialects/dmp/canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
// RUN: xdsl-opt -p canonicalize-dmp %s | filecheck %s

builtin.module {
%ref = "test.op"() : () -> (memref<1024x1024xf32>)
%ref = "test.op"() : () -> (!stencil.field<[0,1024]x[0,1024]xf32>)
%val = "test.op"() : () -> (!stencil.temp<[0,1024]x[0,1024]xf32>)

"dmp.swap"(%ref) {
strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>,
swaps = [
#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>,
#dmp.exchange<at [4, 104] size [100, 0] source offset [0, -4] to [1, 0]>
]
} : (memref<1024x1024xf32>) -> ()
} : (!stencil.field<[0,1024]x[0,1024]xf32>) -> ()

// CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>]} : (memref<1024x1024xf32>) -> ()
// CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>]} : (!stencil.field<[0,1024]x[0,1024]xf32>) -> ()

"dmp.swap"(%ref) {
%swap_val = "dmp.swap"(%val) {
strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>,
swaps = [
#dmp.exchange<at [4, 0] size [100, 0] source offset [0, 4] to [-1, 0]>,
#dmp.exchange<at [4, 104] size [100, 0] source offset [0, -4] to [1, 0]>
]
} : (memref<1024x1024xf32>) -> ()
} : (!stencil.temp<[0,1024]x[0,1024]xf32>) -> (!stencil.temp<[0,1024]x[0,1024]xf32>)

"test.op"() : () -> ()
"test.op"(%swap_val) : (!stencil.temp<[0,1024]x[0,1024]xf32>) -> ()

// this op should be completely removed since both exchanges are empty, so we expect the next op to be a test.op
// CHECK-NEXT: "test.op"() : () -> ()
// and its operand replaced by the unswapped input
// CHECK-NEXT: "test.op"(%val) : (!stencil.temp<[0,1024]x[0,1024]xf32>) -> ()
}
17 changes: 14 additions & 3 deletions tests/filecheck/dialects/dmp/ops.mlir
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
// RUN: XDSL_ROUNDTRIP

builtin.module {
%ref = "test.op"() : () -> (memref<1024x1024xf32>)
%ref = "test.op"() : () -> (!stencil.field<[0,1024]x[0,1024]xf32>)
%val = "test.op"() : () -> (!stencil.temp<[0,1024]x[0,1024]xf32>)

"dmp.swap"(%ref) {
strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>,
swaps = [
#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>,
#dmp.exchange<at [4, 104] size [100, 4] source offset [0, -4] to [1, 0]>
]
} : (memref<1024x1024xf32>) -> ()
} : (!stencil.field<[0,1024]x[0,1024]xf32>) -> ()

// CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>, #dmp.exchange<at [4, 104] size [100, 4] source offset [0, -4] to [1, 0]>]} : (memref<1024x1024xf32>) -> ()

%swap_val = "dmp.swap"(%val) {
strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>,
swaps = [
#dmp.exchange<at [4, 0] size [100, 2] source offset [0, 4] to [-1, 0]>,
#dmp.exchange<at [4, 104] size [100, 2] source offset [0, -4] to [1, 0]>
]
} : (!stencil.temp<[0,1024]x[0,1024]xf32>) -> (!stencil.temp<[0,1024]x[0,1024]xf32>)

// CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>, #dmp.exchange<at [4, 104] size [100, 4] source offset [0, -4] to [1, 0]>]} : (!stencil.field<[0,1024]x[0,1024]xf32>) -> ()
// CHECK-NEXT: %swap_val = "dmp.swap"(%val) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange<at [4, 0] size [100, 2] source offset [0, 4] to [-1, 0]>, #dmp.exchange<at [4, 104] size [100, 2] source offset [0, -4] to [1, 0]>]} : (!stencil.temp<[0,1024]x[0,1024]xf32>) -> !stencil.temp<[0,1024]x[0,1024]xf32>
}
60 changes: 30 additions & 30 deletions tests/filecheck/transforms/distribute-stencil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,45 +24,45 @@

// CHECK: func.func @offsets(%0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %1 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %2 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) {
// CHECK-NEXT: %3 = stencil.load %0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -> !stencil.temp<?x?x?xf64>
// CHECK-NEXT: "dmp.swap"(%3) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = []} : (!stencil.temp<?x?x?xf64>) -> ()
// CHECK-NEXT: %4, %5 = stencil.apply(%6 = %3 : !stencil.temp<?x?x?xf64>) -> (!stencil.temp<?x?x?xf64>, !stencil.temp<?x?x?xf64>) {
// CHECK-NEXT: %7 = stencil.access %6[-1, 0, 0] : !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %8 = stencil.access %6[1, 0, 0] : !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %9 = stencil.access %6[0, 1, 0] : !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %10 = stencil.access %6[0, -1, 0] : !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %11 = stencil.access %6[0, 0, 0] : !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %12 = arith.addf %7, %8 : f64
// CHECK-NEXT: %13 = arith.addf %9, %10 : f64
// CHECK-NEXT: %14 = arith.addf %12, %13 : f64
// CHECK-NEXT: %4 = "dmp.swap"(%3) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = []} : (!stencil.temp<?x?x?xf64>) -> !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %5, %6 = stencil.apply(%7 = %4 : !stencil.temp<?x?x?xf64>) -> (!stencil.temp<?x?x?xf64>, !stencil.temp<?x?x?xf64>) {
// CHECK-NEXT: %8 = stencil.access %7[-1, 0, 0] : !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %9 = stencil.access %7[1, 0, 0] : !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %10 = stencil.access %7[0, 1, 0] : !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %11 = stencil.access %7[0, -1, 0] : !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %12 = stencil.access %7[0, 0, 0] : !stencil.temp<?x?x?xf64>
// CHECK-NEXT: %13 = arith.addf %8, %9 : f64
// CHECK-NEXT: %14 = arith.addf %10, %11 : f64
// CHECK-NEXT: %15 = arith.addf %13, %14 : f64
// CHECK-NEXT: %cst = arith.constant -4.000000e+00 : f64
// CHECK-NEXT: %15 = arith.mulf %11, %cst : f64
// CHECK-NEXT: %16 = arith.addf %15, %14 : f64
// CHECK-NEXT: stencil.return %16, %15 : f64, f64
// CHECK-NEXT: %16 = arith.mulf %12, %cst : f64
// CHECK-NEXT: %17 = arith.addf %16, %15 : f64
// CHECK-NEXT: stencil.return %17, %16 : f64, f64
// CHECK-NEXT: }
// CHECK-NEXT: stencil.store %4 to %1(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<?x?x?xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// CHECK-NEXT: stencil.store %5 to %2(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<?x?x?xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// CHECK-NEXT: stencil.store %5 to %1(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<?x?x?xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// CHECK-NEXT: stencil.store %6 to %2(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<?x?x?xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// CHECK-NEXT: func.return
// CHECK-NEXT: }

// SHAPE: func.func @offsets(%0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %1 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %2 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) {
// SHAPE-NEXT: %3 = stencil.load %0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -> !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: "dmp.swap"(%3) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = [#dmp.exchange<at [32, 0, 0] size [1, 32, 32] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 32, 32] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 32, 0] size [32, 1, 32] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [32, 1, 32] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> ()
// SHAPE-NEXT: %4, %5 = stencil.apply(%6 = %3 : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> (!stencil.temp<[0,32]x[0,32]x[0,32]xf64>, !stencil.temp<[0,32]x[0,32]x[0,32]xf64>) {
// SHAPE-NEXT: %7 = stencil.access %6[-1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %8 = stencil.access %6[1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %9 = stencil.access %6[0, 1, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %10 = stencil.access %6[0, -1, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %11 = stencil.access %6[0, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %12 = arith.addf %7, %8 : f64
// SHAPE-NEXT: %13 = arith.addf %9, %10 : f64
// SHAPE-NEXT: %14 = arith.addf %12, %13 : f64
// SHAPE-NEXT: %4 = "dmp.swap"(%3) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = [#dmp.exchange<at [32, 0, 0] size [1, 32, 32] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 32, 32] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 32, 0] size [32, 1, 32] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [32, 1, 32] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %5, %6 = stencil.apply(%7 = %4 : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> (!stencil.temp<[0,32]x[0,32]x[0,32]xf64>, !stencil.temp<[0,32]x[0,32]x[0,32]xf64>) {
// SHAPE-NEXT: %8 = stencil.access %7[-1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %9 = stencil.access %7[1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %10 = stencil.access %7[0, 1, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %11 = stencil.access %7[0, -1, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %12 = stencil.access %7[0, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// SHAPE-NEXT: %13 = arith.addf %8, %9 : f64
// SHAPE-NEXT: %14 = arith.addf %10, %11 : f64
// SHAPE-NEXT: %15 = arith.addf %13, %14 : f64
// SHAPE-NEXT: %cst = arith.constant -4.000000e+00 : f64
// SHAPE-NEXT: %15 = arith.mulf %11, %cst : f64
// SHAPE-NEXT: %16 = arith.addf %15, %14 : f64
// SHAPE-NEXT: stencil.return %16, %15 : f64, f64
// SHAPE-NEXT: %16 = arith.mulf %12, %cst : f64
// SHAPE-NEXT: %17 = arith.addf %16, %15 : f64
// SHAPE-NEXT: stencil.return %17, %16 : f64, f64
// SHAPE-NEXT: }
// SHAPE-NEXT: stencil.store %4 to %1(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// SHAPE-NEXT: stencil.store %5 to %2(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// SHAPE-NEXT: stencil.store %5 to %1(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// SHAPE-NEXT: stencil.store %6 to %2(<[0, 0, 0], [32, 32, 32]>) : !stencil.temp<[0,32]x[0,32]x[0,32]xf64> to !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>
// SHAPE-NEXT: func.return
// SHAPE-NEXT: }

Expand Down
5 changes: 3 additions & 2 deletions tests/filecheck/transforms/stencil-to-csl-stencil.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// RUN: xdsl-opt %s -p "stencil-to-csl-stencil{num_chunks=2}" | filecheck %s

builtin.module {

func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
%0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
"dmp.swap"(%0) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange<at [1, 0, 0] size [1, 1, 510] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 1, 510] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 1, 0] size [1, 1, 510] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [1, 1, 510] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> ()
%1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) {
%24 = "dmp.swap"(%0) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange<at [1, 0, 0] size [1, 1, 510] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 1, 510] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 1, 0] size [1, 1, 510] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [1, 1, 510] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>)
%1 = stencil.apply(%2 = %24 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) {
%3 = arith.constant 1.666600e-01 : f32
%4 = stencil.access %2[1, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%5 = "tensor.extract_slice"(%4) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class PrefetchOp(IRDLOperation):
name = "csl_stencil.prefetch"

input_stencil = operand_def(
base(stencil.TempType[Attribute])
base(stencil.StencilType[Attribute])
| base(memref.MemRefType[Attribute])
| base(TensorType[Attribute])
)
Expand Down
34 changes: 28 additions & 6 deletions xdsl/dialects/experimental/dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
Operand,
ParameterDef,
attr_def,
base,
irdl_attr_definition,
irdl_op_definition,
operand_def,
opt_result_def,
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer
from xdsl.traits import HasShapeInferencePatternsTrait
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa

# helpers for named dimensions:
Expand Down Expand Up @@ -585,9 +586,12 @@ def coords(where: Literal["start", "end"]):
class SwapOpHasShapeInferencePatterns(HasShapeInferencePatternsTrait):
@classmethod
def get_shape_inference_patterns(cls):
from xdsl.transforms.shape_inference_patterns.dmp import DmpSwapShapeInference
from xdsl.transforms.shape_inference_patterns.dmp import (
DmpSwapShapeInference,
DmpSwapSwapsInference,
)

return (DmpSwapShapeInference(),)
return (DmpSwapShapeInference(), DmpSwapSwapsInference())


@irdl_op_definition
Expand All @@ -598,20 +602,38 @@ class SwapOp(IRDLOperation):

name = "dmp.swap"

input_stencil: Operand = operand_def(
base(stencil.AnyTempType) | base(builtin.AnyMemRefType)
)
input_stencil: Operand = operand_def(stencil.StencilType[Attribute])
swapped_values = opt_result_def(stencil.TempType[Attribute])

swaps = attr_def(builtin.ArrayAttr[ExchangeDeclarationAttr])

strategy = attr_def(DomainDecompositionStrategy)

traits = frozenset([SwapOpHasShapeInferencePatterns()])

def verify_(self) -> None:
if self.swapped_values:
if isinstance(self.input_stencil.type, stencil.FieldType):
raise VerifyException(
"dmp.swap_op cannot have a result if input is a field"
)
else:
if isinstance(self.input_stencil.type, stencil.TempType):
raise VerifyException(
"dmp.swap_op must have a result if input is a temporary"
)

@staticmethod
def get(input_stencil: SSAValue | Operation, strategy: DomainDecompositionStrategy):
input_type = SSAValue.get(input_stencil).type

result_types = (
input_type if isa(input_type, stencil.TempType[Attribute]) else None
)

return SwapOp.build(
operands=[input_stencil],
result_types=[result_types],
attributes={
"strategy": strategy,
"swaps": builtin.ArrayAttr[ExchangeDeclarationAttr](()),
Expand Down
9 changes: 7 additions & 2 deletions xdsl/transforms/canonicalize_dmp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from xdsl.dialects import builtin
from xdsl.dialects import builtin, stencil
from xdsl.dialects.experimental import dmp
from xdsl.passes import MLContext, ModulePass
from xdsl.pattern_rewriter import (
Expand All @@ -17,7 +17,12 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
if swap.elem_count > 0:
keeps.append(swap)
if len(keeps) == 0:
rewriter.erase_matched_op()
new_result = (
op.input_stencil
if isinstance(op.input_stencil.type, stencil.TempType)
else None
)
rewriter.replace_matched_op([], [new_result])
else:
op.swaps = builtin.ArrayAttr(keeps)

Expand Down
6 changes: 6 additions & 0 deletions xdsl/transforms/experimental/dmp/stencil_global_to_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ class AddHaloExchangeOps(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.LoadOp, rewriter: PatternRewriter, /):
swap_op = dmp.SwapOp.get(op.res, self.strategy)
assert swap_op.swapped_values
rewriter.insert_op_after_matched_op(swap_op)
for use in op.res.uses.copy():
if use.operation is swap_op:
continue
use.operation.operands[use.index] = swap_op.swapped_values
rewriter.handle_operation_modification(use.operation)


@dataclass
Expand Down
22 changes: 21 additions & 1 deletion xdsl/transforms/shape_inference_patterns/dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,23 @@


class DmpSwapShapeInference(RewritePattern):
"""
Infer the shape of the `dmp.swap` operation.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: dmp.SwapOp, rewrite: PatternRewriter):
if not op.swapped_values:
return
swap_t = op.swapped_values.type
if not isinstance(swap_t, stencil.TempType):
return
if op.input_stencil.type != swap_t:
op.input_stencil.type = swap_t
rewrite.handle_operation_modification(op)


class DmpSwapSwapsInference(RewritePattern):
"""
Infer the exact exchanges this `dmp.swap` needs to perform.
"""
Expand All @@ -19,7 +36,10 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewrite: PatternRewriter):
core_lb: stencil.IndexAttr | None = None
core_ub: stencil.IndexAttr | None = None

for use in op.input_stencil.uses:
if not op.swapped_values:
return

for use in op.swapped_values.uses:
if not isinstance(use.operation, stencil.ApplyOp):
continue
assert use.operation.res
Expand Down
2 changes: 1 addition & 1 deletion xdsl/transforms/stencil_to_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
)

# a little hack to get around a check that prevents replacing a no-results op with an n-results op
rewriter.replace_matched_op(prefetch_op, new_results=[])
rewriter.replace_matched_op(prefetch_op, new_results=[op.input_stencil])

# uses have to be retrieved *before* the loop because of the rewriting happening inside the loop
uses = list(op.input_stencil.uses)
Expand Down

0 comments on commit 4704025

Please sign in to comment.