diff --git a/tests/dialects/test_dmp.py b/tests/dialects/test_dmp.py index 0859ed9769..0b95e5195e 100644 --- a/tests/dialects/test_dmp.py +++ b/tests/dialects/test_dmp.py @@ -1,5 +1,5 @@ +from xdsl.dialects.experimental import dmp from xdsl.dialects.experimental.dmp import ExchangeDeclarationAttr, ShapeAttr -from xdsl.transforms.experimental.dmp import decompositions def flat_face_exchanges( @@ -9,9 +9,7 @@ def flat_face_exchanges( # since this is a private function, and pyright will yell whenever it's accessed, # we have this wrapper function here that takes care of making the private publicly # accessible in the context of this test. - func = ( - decompositions._flat_face_exchanges_for_dim # pyright: ignore[reportPrivateUsage] - ) + func = dmp._flat_face_exchanges_for_dim # pyright: ignore[reportPrivateUsage] return func(shape, dim) diff --git a/tests/filecheck/dialects/dmp/canonicalize.mlir b/tests/filecheck/dialects/dmp/canonicalize.mlir index 36a6fa3e7a..200c83ba4c 100644 --- a/tests/filecheck/dialects/dmp/canonicalize.mlir +++ b/tests/filecheck/dialects/dmp/canonicalize.mlir @@ -4,17 +4,17 @@ builtin.module { %ref = "test.op"() : () -> (memref<1024x1024xf32>) "dmp.swap"(%ref) { - topo = #dmp.topo<2x2>, + strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, swaps = [ #dmp.exchange, #dmp.exchange ] } : (memref<1024x1024xf32>) -> () - // CHECK: "dmp.swap"(%ref) {"topo" = #dmp.topo<2x2>, "swaps" = [#dmp.exchange]} : (memref<1024x1024xf32>) -> () + // CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange]} : (memref<1024x1024xf32>) -> () "dmp.swap"(%ref) { - topo = #dmp.topo<2x2>, + strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, swaps = [ #dmp.exchange, #dmp.exchange diff --git a/tests/filecheck/dialects/dmp/ops.mlir b/tests/filecheck/dialects/dmp/ops.mlir index 30dbfc370e..31f1de33fb 100644 --- a/tests/filecheck/dialects/dmp/ops.mlir +++ b/tests/filecheck/dialects/dmp/ops.mlir @@ -4,12 +4,12 @@ builtin.module { %ref = "test.op"() : () -> (memref<1024x1024xf32>) "dmp.swap"(%ref) { - topo = #dmp.topo<2x2>, + strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, swaps = [ #dmp.exchange, #dmp.exchange ] } : (memref<1024x1024xf32>) -> () - // CHECK: "dmp.swap"(%ref) {"topo" = #dmp.topo<2x2>, "swaps" = [#dmp.exchange, #dmp.exchange]} : (memref<1024x1024xf32>) -> () + // CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange, #dmp.exchange]} : (memref<1024x1024xf32>) -> () } diff --git a/tests/filecheck/transforms/distribute-stencil.mlir b/tests/filecheck/transforms/distribute-stencil.mlir index 467b97447d..80c97c7f0d 100644 --- a/tests/filecheck/transforms/distribute-stencil.mlir +++ b/tests/filecheck/transforms/distribute-stencil.mlir @@ -26,7 +26,7 @@ builtin.module { // CHECK: func.func @offsets(%0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %1 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %2 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) { // CHECK-NEXT: %3 = stencil.load %0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -> !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> -// CHECK-NEXT: "dmp.swap"(%3) {"topo" = #dmp.topo<2x2x2>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> () +// CHECK-NEXT: "dmp.swap"(%3) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> () // CHECK-NEXT: %4, %5 = stencil.apply(%6 = %3 : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> (!stencil.temp<[0,32]x[0,32]x[0,32]xf64>, !stencil.temp<[0,32]x[0,32]x[0,32]xf64>) { // CHECK-NEXT: %7 = stencil.access %6[-1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> // CHECK-NEXT: %8 = stencil.access %6[1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64> diff --git a/tests/filecheck/transforms/stencil-to-csl-stencil.mlir b/tests/filecheck/transforms/stencil-to-csl-stencil.mlir index b438945507..7868c2c279 100644 --- a/tests/filecheck/transforms/stencil-to-csl-stencil.mlir +++ b/tests/filecheck/transforms/stencil-to-csl-stencil.mlir @@ -3,7 +3,7 @@ builtin.module { func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> - "dmp.swap"(%0) {"topo" = #dmp.topo<1022x510>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> () + "dmp.swap"(%0) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange, #dmp.exchange, #dmp.exchange, #dmp.exchange]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> () %1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) { %3 = arith.constant 1.666600e-01 : f32 %4 = stencil.access %2[1, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> diff --git a/xdsl/dialects/experimental/dmp.py b/xdsl/dialects/experimental/dmp.py index e32c87147d..615e35ce81 100644 --- a/xdsl/dialects/experimental/dmp.py +++ b/xdsl/dialects/experimental/dmp.py @@ -10,7 +10,8 @@ from __future__ import annotations -from collections.abc import Sequence +from abc import ABC +from collections.abc import Iterable, Sequence from math import prod from typing import Literal @@ -20,11 +21,11 @@ IRDLOperation, Operand, ParameterDef, + attr_def, base, irdl_attr_definition, irdl_op_definition, operand_def, - opt_attr_def, ) from xdsl.parser import AttrParser from xdsl.printer import Printer @@ -392,7 +393,7 @@ class RankTopoAttr(ParametrizedAttribute): shape: ParameterDef[builtin.DenseArrayBase] - def __init__(self, shape: Sequence[int]): + def __init__(self, shape: Sequence[int] | Sequence[builtin.IntAttr]): if len(shape) < 1: raise ValueError("dmp.grid must have at least one dimension!") super().__init__([builtin.DenseArrayBase.from_list(builtin.i64, shape)]) @@ -426,6 +427,160 @@ def print_parameters(self, printer: Printer) -> None: printer.print_string(">") +class DomainDecompositionStrategy(ParametrizedAttribute, ABC): + def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]: + raise NotImplementedError("SlicingStrategy must implement calc_resize!") + + def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]: + raise NotImplementedError("SlicingStrategy must implement halo_exchange_defs!") + + def comm_layout(self) -> RankTopoAttr: + raise NotImplementedError("SlicingStrategy must implement comm_count!") + + +@irdl_attr_definition +class GridSlice2dAttr(DomainDecompositionStrategy): + """ + Takes a grid with two or more dimensions, slices it along the first two into equally + sized segments. + """ + + name = "dmp.grid_slice_2d" + + topology: ParameterDef[RankTopoAttr] + + diagonals: ParameterDef[builtin.BoolAttr] + + def __init__(self, topo: tuple[int, ...]): + super().__init__( + [RankTopoAttr(topo), builtin.BoolAttr.from_int_and_width(0, 1)] + ) + + def _verify(self): + assert ( + len(self.topology.as_tuple()) >= 2 + ), "GridSlice2d requires at least two dimensions" + + def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]: + assert len(shape) >= 2, "GridSlice2d requires at least two dimensions" + for size, node_count in zip(shape, self.topology.as_tuple()): + assert ( + size % node_count == 0 + ), "GridSlice2d requires domain be neatly divisible by shape" + return ( + *( + size // node_count + for size, node_count in zip(shape, self.topology.as_tuple()) + ), + *(size for size in shape[2:]), + ) + + def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]: + yield from _flat_face_exchanges_for_dim(shape, 0) + + yield from _flat_face_exchanges_for_dim(shape, 1) + + if self.diagonals.value.data: + raise NotImplementedError("Diagonals support not implemented yet") + + def comm_layout(self) -> RankTopoAttr: + return RankTopoAttr(self.topology.as_tuple()) + + +@irdl_attr_definition +class GridSlice3dAttr(DomainDecompositionStrategy): + """ + Takes a grid with two or more dimensions, slices it along the first three. + """ + + name = "dmp.grid_slice_3d" + + topology: ParameterDef[RankTopoAttr] + + diagonals: ParameterDef[builtin.BoolAttr] + + def __init__(self, topo: tuple[int, ...]): + super().__init__( + [RankTopoAttr(topo), builtin.BoolAttr.from_int_and_width(0, 1)] + ) + + def _verify(self): + assert ( + len(self.topology.as_tuple()) >= 3 + ), "GridSlice3d requires at least three dimensions" + + def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]: + assert len(shape) >= 3, "GridSlice3d requires at least two dimensions" + for size, node_count in zip(shape, self.topology.as_tuple()): + assert ( + size % node_count == 0 + ), "GridSlice3d requires domain be neatly divisible by shape" + return ( + *( + size // node_count + for size, node_count in zip(shape, self.topology.as_tuple()) + ), + *(size for size in shape[3:]), + ) + + def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]: + yield from _flat_face_exchanges_for_dim(shape, 0) + + yield from _flat_face_exchanges_for_dim(shape, 1) + + yield from _flat_face_exchanges_for_dim(shape, 2) + + if self.diagonals.value.data: + raise NotImplementedError("Diagonals support not implemented yet") + + def comm_layout(self) -> RankTopoAttr: + return RankTopoAttr(self.topology.as_tuple()) + + +def _flat_face_exchanges_for_dim( + shape: ShapeAttr, axis: int +) -> tuple[ExchangeDeclarationAttr, ExchangeDeclarationAttr]: + """ + Generate the two exchange delcarations to exchange the faces on the + axis "axis". + """ + dimensions = shape.dims + assert axis <= dimensions + + def coords(where: Literal["start", "end"]): + for d in range(dimensions): + # for the dim we want to exchange, return either start or end halo region + if d == axis: + if where == "start": + # "start" halo goes from buffer start to core start + yield shape.buffer_start(d), shape.core_start(d) + else: + # "end" halo goes from core end to buffer end + yield shape.core_end(d), shape.buffer_end(d) + else: + # for the sliced regions, "extrude" from core + # this way we don't exchange edges + yield shape.core_start(d), shape.core_end(d) + + ex1_coords = tuple(coords("end")) + ex2_coords = tuple(coords("start")) + + return ( + # towards positive dim: + ExchangeDeclarationAttr.from_points( + ex1_coords, + axis, + dir_sign=1, + ), + # towards negative dim: + ExchangeDeclarationAttr.from_points( + ex2_coords, + axis, + dir_sign=-1, + ), + ) + + @irdl_op_definition class SwapOp(IRDLOperation): """ @@ -438,15 +593,13 @@ class SwapOp(IRDLOperation): base(stencil.AnyTempType) | base(builtin.AnyMemRefType) ) - swaps: builtin.ArrayAttr[ExchangeDeclarationAttr] | None = opt_attr_def( - builtin.ArrayAttr[ExchangeDeclarationAttr] - ) + swaps = attr_def(builtin.ArrayAttr[ExchangeDeclarationAttr]) - topo: RankTopoAttr | None = opt_attr_def(RankTopoAttr) + strategy = attr_def(DomainDecompositionStrategy) @staticmethod - def get(input_stencil: SSAValue | Operation): - return SwapOp.build(operands=[input_stencil]) + def get(input_stencil: SSAValue | Operation, strategy: DomainDecompositionStrategy): + return SwapOp.build(operands=[input_stencil], attributes={"strategy": strategy}) DMP = Dialect( @@ -458,5 +611,7 @@ def get(input_stencil: SSAValue | Operation): ExchangeDeclarationAttr, ShapeAttr, RankTopoAttr, + GridSlice2dAttr, + GridSlice3dAttr, ], ) diff --git a/xdsl/transforms/canonicalize_dmp.py b/xdsl/transforms/canonicalize_dmp.py index 016df1b3b1..01b9a5b74d 100644 --- a/xdsl/transforms/canonicalize_dmp.py +++ b/xdsl/transforms/canonicalize_dmp.py @@ -13,8 +13,6 @@ class CanonicalizeDmpSwap(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): keeps: list[dmp.ExchangeDeclarationAttr] = [] - if op.swaps is None: - return for swap in op.swaps: if swap.elem_count > 0: keeps.append(swap) diff --git a/xdsl/transforms/experimental/dmp/decompositions.py b/xdsl/transforms/experimental/dmp/decompositions.py deleted file mode 100644 index c639d48141..0000000000 --- a/xdsl/transforms/experimental/dmp/decompositions.py +++ /dev/null @@ -1,149 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Iterable -from dataclasses import dataclass -from typing import Literal - -from xdsl.dialects.experimental import dmp - - -@dataclass -class DomainDecompositionStrategy(ABC): - def __init__(self, _: tuple[int, ...]): - pass - - @abstractmethod - def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]: - raise NotImplementedError("SlicingStrategy must implement calc_resize!") - - @abstractmethod - def halo_exchange_defs( - self, shape: dmp.ShapeAttr - ) -> Iterable[dmp.ExchangeDeclarationAttr]: - raise NotImplementedError("SlicingStrategy must implement halo_exchange_defs!") - - @abstractmethod - def comm_layout(self) -> dmp.RankTopoAttr: - raise NotImplementedError("SlicingStrategy must implement comm_count!") - - -@dataclass -class GridSlice2d(DomainDecompositionStrategy): - """ - Takes a grid with two or more dimensions, slices it along the first two into equally - sized segments. - """ - - topology: tuple[int, int] - - diagonals: bool = False - - def __post_init__(self): - assert len(self.topology) >= 2, "GridSlice2d requires at least two dimensions" - - def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]: - assert len(shape) >= 2, "GridSlice2d requires at least two dimensions" - for size, node_count in zip(shape, self.topology): - assert ( - size % node_count == 0 - ), "GridSlice2d requires domain be neatly divisible by shape" - return ( - *(size // node_count for size, node_count in zip(shape, self.topology)), - *(size for size in shape[2:]), - ) - - def halo_exchange_defs( - self, shape: dmp.ShapeAttr - ) -> Iterable[dmp.ExchangeDeclarationAttr]: - yield from _flat_face_exchanges_for_dim(shape, 0) - - yield from _flat_face_exchanges_for_dim(shape, 1) - - # TOOD: add diagonals - assert not self.diagonals - - def comm_layout(self) -> dmp.RankTopoAttr: - return dmp.RankTopoAttr(self.topology) - - -@dataclass -class GridSlice3d(DomainDecompositionStrategy): - """ - Takes a grid with two or more dimensions, slices it along the first three. - """ - - topology: tuple[int, int, int] - - diagonals: bool = False - - def __post_init__(self): - assert len(self.topology) >= 3, "GridSlice3d requires at least three dimensions" - - def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]: - assert len(shape) >= 3, "GridSlice3d requires at least two dimensions" - for size, node_count in zip(shape, self.topology): - assert ( - size % node_count == 0 - ), "GridSlice3d requires domain be neatly divisible by shape" - return ( - *(size // node_count for size, node_count in zip(shape, self.topology)), - *(size for size in shape[3:]), - ) - - def halo_exchange_defs( - self, shape: dmp.ShapeAttr - ) -> Iterable[dmp.ExchangeDeclarationAttr]: - yield from _flat_face_exchanges_for_dim(shape, 0) - - yield from _flat_face_exchanges_for_dim(shape, 1) - - yield from _flat_face_exchanges_for_dim(shape, 2) - - # TOOD: add diagonals - assert not self.diagonals - - def comm_layout(self) -> dmp.RankTopoAttr: - return dmp.RankTopoAttr(self.topology) - - -def _flat_face_exchanges_for_dim( - shape: dmp.ShapeAttr, axis: int -) -> tuple[dmp.ExchangeDeclarationAttr, dmp.ExchangeDeclarationAttr]: - """ - Generate the two exchange delcarations to exchange the faces on the - axis "axis". - """ - dimensions = shape.dims - assert axis <= dimensions - - def coords(where: Literal["start", "end"]): - for d in range(dimensions): - # for the dim we want to exchange, return either start or end halo region - if d == axis: - if where == "start": - # "start" halo goes from buffer start to core start - yield shape.buffer_start(d), shape.core_start(d) - else: - # "end" halo goes from core end to buffer end - yield shape.core_end(d), shape.buffer_end(d) - else: - # for the sliced regions, "extrude" from core - # this way we don't exchange edges - yield shape.core_start(d), shape.core_end(d) - - ex1_coords = tuple(coords("end")) - ex2_coords = tuple(coords("start")) - - return ( - # towards positive dim: - dmp.ExchangeDeclarationAttr.from_points( - ex1_coords, - axis, - dir_sign=1, - ), - # towards negative dim: - dmp.ExchangeDeclarationAttr.from_points( - ex2_coords, - axis, - dir_sign=-1, - ), - ) diff --git a/xdsl/transforms/experimental/dmp/stencil_global_to_local.py b/xdsl/transforms/experimental/dmp/stencil_global_to_local.py index fa7d8f882a..357ab130e7 100644 --- a/xdsl/transforms/experimental/dmp/stencil_global_to_local.py +++ b/xdsl/transforms/experimental/dmp/stencil_global_to_local.py @@ -1,6 +1,6 @@ from abc import ABC from collections.abc import Callable, Iterable -from dataclasses import dataclass, field +from dataclasses import dataclass from math import prod from typing import ClassVar, TypeVar, cast @@ -18,11 +18,6 @@ op_type_rewrite_pattern, ) from xdsl.rewriter import InsertPoint, Rewriter -from xdsl.transforms.experimental.dmp.decompositions import ( - DomainDecompositionStrategy, - GridSlice2d, - GridSlice3d, -) from xdsl.transforms.shape_inference import ShapeInferencePass from xdsl.utils.hints import isa @@ -33,7 +28,7 @@ @dataclass class ChangeStoreOpSizes(RewritePattern): - strategy: DomainDecompositionStrategy + strategy: dmp.DomainDecompositionStrategy @op_type_rewrite_pattern def match_and_rewrite(self, op: stencil.StoreOp, rewriter: PatternRewriter, /): @@ -58,12 +53,11 @@ class AddHaloExchangeOps(RewritePattern): This rewrite adds a `stencil.halo_exchange` after each `stencil.load` op """ - strategy: DomainDecompositionStrategy + strategy: dmp.DomainDecompositionStrategy @op_type_rewrite_pattern def match_and_rewrite(self, op: stencil.LoadOp, rewriter: PatternRewriter, /): - swap_op = dmp.SwapOp.get(op.res) - swap_op.topo = self.strategy.comm_layout() + swap_op = dmp.SwapOp.get(op.res, self.strategy) rewriter.insert_op_after_matched_op(swap_op) @@ -74,8 +68,6 @@ class LowerHaloExchangeToMpi(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): - assert op.swaps is not None - assert op.topo is not None exchanges = list(op.swaps) input_type = cast(ContainerType[Attribute], op.input_stencil.type) @@ -86,7 +78,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): op.input_stencil, exchanges, input_type.get_element_type(), - op.topo, + op.strategy.comm_layout(), emit_init=self.init, emit_debug=self.debug_prints, ) @@ -566,19 +558,15 @@ def collect_args_recursive(op: Operation) -> Iterable[Operation]: @dataclass -class DmpSwapShapeInference: +class DmpSwapShapeInference(RewritePattern): """ - Not a rewrite pattern, as it's a bit more involved. - This is applied after stencil shape inference has run. It will find the HaloSwapOps again, and use the results of the shape inference pass to attach the swap declarations. """ - strategy: DomainDecompositionStrategy - rewriter: Rewriter = field(default_factory=Rewriter) - - def match_and_rewrite(self, op: dmp.SwapOp): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: dmp.SwapOp, rewrite: PatternRewriter): core_lb: stencil.IndexAttr | None = None core_ub: stencil.IndexAttr | None = None @@ -609,7 +597,7 @@ def match_and_rewrite(self, op: dmp.SwapOp): # drop 0 element exchanges op.swaps = builtin.ArrayAttr( exchange - for exchange in self.strategy.halo_exchange_defs( + for exchange in op.strategy.halo_exchange_defs( dmp.ShapeAttr.from_index_attrs( buff_lb=buff_lb, core_lb=core_lb, @@ -620,11 +608,6 @@ def match_and_rewrite(self, op: dmp.SwapOp): if exchange.elem_count > 0 ) - def apply(self, module: builtin.ModuleOp): - for op in module.walk(): - if isinstance(op, dmp.SwapOp): - self.match_and_rewrite(op) - @dataclass(frozen=True) class DmpDecompositionPass(ModulePass, ABC): @@ -643,9 +626,9 @@ class DistributeStencilPass(DmpDecompositionPass): name = "distribute-stencil" - STRATEGIES: ClassVar[dict[str, type[DomainDecompositionStrategy]]] = { - "2d-grid": GridSlice2d, - "3d-grid": GridSlice3d, + STRATEGIES: ClassVar[dict[str, type[dmp.GridSlice2dAttr | dmp.GridSlice3dAttr]]] = { + "2d-grid": dmp.GridSlice2dAttr, + "3d-grid": dmp.GridSlice3dAttr, } slices: tuple[int, ...] @@ -684,7 +667,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: # run the shape inference pass ShapeInferencePass().apply(ctx, op) - DmpSwapShapeInference(strategy).apply(op) + PatternRewriteWalker(DmpSwapShapeInference()).rewrite_module(op) @dataclass(frozen=True) diff --git a/xdsl/transforms/stencil_to_csl_stencil.py b/xdsl/transforms/stencil_to_csl_stencil.py index 2c123e7fda..afa8db29f0 100644 --- a/xdsl/transforms/stencil_to_csl_stencil.py +++ b/xdsl/transforms/stencil_to_csl_stencil.py @@ -186,7 +186,7 @@ class ConvertSwapToPrefetchPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): # remove op if it contains no swaps - if op.swaps is None or len(op.swaps) == 0: + if len(op.swaps) == 0: rewriter.erase_matched_op(False) return @@ -211,12 +211,14 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): assert isa( t_type := op.input_stencil.type.get_element_type(), TensorType[Attribute] ) - assert op.topo is not None, f"topology on {type(op)} is not given" + assert ( + op.strategy.comm_layout() is not None + ), f"topology on {type(op)} is not given" # when translating swaps, remove third dimension prefetch_op = csl_stencil.PrefetchOp( input_stencil=op.input_stencil, - topo=op.topo, + topo=op.strategy.comm_layout(), swaps=[ csl_stencil.ExchangeDeclarationAttr(swap.neighbor[:2]) for swap in op.swaps