Skip to content

Commit

Permalink
dialects: dmp: make strategies attributes and carry them in IR (#3050)
Browse files Browse the repository at this point in the history
As of now, the decomposition strategies are dataclasses and the
information to construct them is expected to match in multiple bits of
the pipeline (actually fused in `distribute-stencil`, which happens to
do shape inference too, but I'm trying to decouple things here)

So instead, define them as attributes, and make the initial distribution
pass pop that attribute on the swaps. The swap shape inference then do
not need extra information to carry on.
  • Loading branch information
PapyChacal authored Aug 19, 2024
1 parent f591b3b commit f9df366
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 204 deletions.
6 changes: 2 additions & 4 deletions tests/dialects/test_dmp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from xdsl.dialects.experimental import dmp
from xdsl.dialects.experimental.dmp import ExchangeDeclarationAttr, ShapeAttr
from xdsl.transforms.experimental.dmp import decompositions


def flat_face_exchanges(
Expand All @@ -9,9 +9,7 @@ def flat_face_exchanges(
# since this is a private function, and pyright will yell whenever it's accessed,
# we have this wrapper function here that takes care of making the private publicly
# accessible in the context of this test.
func = (
decompositions._flat_face_exchanges_for_dim # pyright: ignore[reportPrivateUsage]
)
func = dmp._flat_face_exchanges_for_dim # pyright: ignore[reportPrivateUsage]
return func(shape, dim)


Expand Down
6 changes: 3 additions & 3 deletions tests/filecheck/dialects/dmp/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ builtin.module {
%ref = "test.op"() : () -> (memref<1024x1024xf32>)

"dmp.swap"(%ref) {
topo = #dmp.topo<2x2>,
strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>,
swaps = [
#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>,
#dmp.exchange<at [4, 104] size [100, 0] source offset [0, -4] to [1, 0]>
]
} : (memref<1024x1024xf32>) -> ()

// CHECK: "dmp.swap"(%ref) {"topo" = #dmp.topo<2x2>, "swaps" = [#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>]} : (memref<1024x1024xf32>) -> ()
// CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>]} : (memref<1024x1024xf32>) -> ()

"dmp.swap"(%ref) {
topo = #dmp.topo<2x2>,
strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>,
swaps = [
#dmp.exchange<at [4, 0] size [100, 0] source offset [0, 4] to [-1, 0]>,
#dmp.exchange<at [4, 104] size [100, 0] source offset [0, -4] to [1, 0]>
Expand Down
4 changes: 2 additions & 2 deletions tests/filecheck/dialects/dmp/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ builtin.module {
%ref = "test.op"() : () -> (memref<1024x1024xf32>)

"dmp.swap"(%ref) {
topo = #dmp.topo<2x2>,
strategy = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>,
swaps = [
#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>,
#dmp.exchange<at [4, 104] size [100, 4] source offset [0, -4] to [1, 0]>
]
} : (memref<1024x1024xf32>) -> ()

// CHECK: "dmp.swap"(%ref) {"topo" = #dmp.topo<2x2>, "swaps" = [#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>, #dmp.exchange<at [4, 104] size [100, 4] source offset [0, -4] to [1, 0]>]} : (memref<1024x1024xf32>) -> ()
// CHECK: "dmp.swap"(%ref) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<2x2>, false>, "swaps" = [#dmp.exchange<at [4, 0] size [100, 4] source offset [0, 4] to [-1, 0]>, #dmp.exchange<at [4, 104] size [100, 4] source offset [0, -4] to [1, 0]>]} : (memref<1024x1024xf32>) -> ()
}
2 changes: 1 addition & 1 deletion tests/filecheck/transforms/distribute-stencil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ builtin.module {

// CHECK: func.func @offsets(%0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %1 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>, %2 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64>) {
// CHECK-NEXT: %3 = stencil.load %0 : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> -> !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// CHECK-NEXT: "dmp.swap"(%3) {"topo" = #dmp.topo<2x2x2>, "swaps" = [#dmp.exchange<at [32, 0, 0] size [1, 32, 32] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 32, 32] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 32, 0] size [32, 1, 32] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [32, 1, 32] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> ()
// CHECK-NEXT: "dmp.swap"(%3) {"strategy" = #dmp.grid_slice_3d<#dmp.topo<2x2x2>, false>, "swaps" = [#dmp.exchange<at [32, 0, 0] size [1, 32, 32] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 32, 32] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 32, 0] size [32, 1, 32] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [32, 1, 32] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> ()
// CHECK-NEXT: %4, %5 = stencil.apply(%6 = %3 : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>) -> (!stencil.temp<[0,32]x[0,32]x[0,32]xf64>, !stencil.temp<[0,32]x[0,32]x[0,32]xf64>) {
// CHECK-NEXT: %7 = stencil.access %6[-1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
// CHECK-NEXT: %8 = stencil.access %6[1, 0, 0] : !stencil.temp<[-1,33]x[-1,33]x[0,32]xf64>
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/transforms/stencil-to-csl-stencil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
builtin.module {
func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
%0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
"dmp.swap"(%0) {"topo" = #dmp.topo<1022x510>, "swaps" = [#dmp.exchange<at [1, 0, 0] size [1, 1, 510] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 1, 510] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 1, 0] size [1, 1, 510] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [1, 1, 510] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> ()
"dmp.swap"(%0) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<1022x510>, false>, "swaps" = [#dmp.exchange<at [1, 0, 0] size [1, 1, 510] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 1, 510] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 1, 0] size [1, 1, 510] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [1, 1, 510] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> ()
%1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) {
%3 = arith.constant 1.666600e-01 : f32
%4 = stencil.access %2[1, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
Expand Down
173 changes: 164 additions & 9 deletions xdsl/dialects/experimental/dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from __future__ import annotations

from collections.abc import Sequence
from abc import ABC
from collections.abc import Iterable, Sequence
from math import prod
from typing import Literal

Expand All @@ -20,11 +21,11 @@
IRDLOperation,
Operand,
ParameterDef,
attr_def,
base,
irdl_attr_definition,
irdl_op_definition,
operand_def,
opt_attr_def,
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer
Expand Down Expand Up @@ -392,7 +393,7 @@ class RankTopoAttr(ParametrizedAttribute):

shape: ParameterDef[builtin.DenseArrayBase]

def __init__(self, shape: Sequence[int]):
def __init__(self, shape: Sequence[int] | Sequence[builtin.IntAttr]):
if len(shape) < 1:
raise ValueError("dmp.grid must have at least one dimension!")
super().__init__([builtin.DenseArrayBase.from_list(builtin.i64, shape)])
Expand Down Expand Up @@ -426,6 +427,160 @@ def print_parameters(self, printer: Printer) -> None:
printer.print_string(">")


class DomainDecompositionStrategy(ParametrizedAttribute, ABC):
def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]:
raise NotImplementedError("SlicingStrategy must implement calc_resize!")

def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]:
raise NotImplementedError("SlicingStrategy must implement halo_exchange_defs!")

def comm_layout(self) -> RankTopoAttr:
raise NotImplementedError("SlicingStrategy must implement comm_count!")


@irdl_attr_definition
class GridSlice2dAttr(DomainDecompositionStrategy):
"""
Takes a grid with two or more dimensions, slices it along the first two into equally
sized segments.
"""

name = "dmp.grid_slice_2d"

topology: ParameterDef[RankTopoAttr]

diagonals: ParameterDef[builtin.BoolAttr]

def __init__(self, topo: tuple[int, ...]):
super().__init__(
[RankTopoAttr(topo), builtin.BoolAttr.from_int_and_width(0, 1)]
)

def _verify(self):
assert (
len(self.topology.as_tuple()) >= 2
), "GridSlice2d requires at least two dimensions"

def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]:
assert len(shape) >= 2, "GridSlice2d requires at least two dimensions"
for size, node_count in zip(shape, self.topology.as_tuple()):
assert (
size % node_count == 0
), "GridSlice2d requires domain be neatly divisible by shape"
return (
*(
size // node_count
for size, node_count in zip(shape, self.topology.as_tuple())
),
*(size for size in shape[2:]),
)

def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]:
yield from _flat_face_exchanges_for_dim(shape, 0)

yield from _flat_face_exchanges_for_dim(shape, 1)

if self.diagonals.value.data:
raise NotImplementedError("Diagonals support not implemented yet")

def comm_layout(self) -> RankTopoAttr:
return RankTopoAttr(self.topology.as_tuple())


@irdl_attr_definition
class GridSlice3dAttr(DomainDecompositionStrategy):
"""
Takes a grid with two or more dimensions, slices it along the first three.
"""

name = "dmp.grid_slice_3d"

topology: ParameterDef[RankTopoAttr]

diagonals: ParameterDef[builtin.BoolAttr]

def __init__(self, topo: tuple[int, ...]):
super().__init__(
[RankTopoAttr(topo), builtin.BoolAttr.from_int_and_width(0, 1)]
)

def _verify(self):
assert (
len(self.topology.as_tuple()) >= 3
), "GridSlice3d requires at least three dimensions"

def calc_resize(self, shape: tuple[int, ...]) -> tuple[int, ...]:
assert len(shape) >= 3, "GridSlice3d requires at least two dimensions"
for size, node_count in zip(shape, self.topology.as_tuple()):
assert (
size % node_count == 0
), "GridSlice3d requires domain be neatly divisible by shape"
return (
*(
size // node_count
for size, node_count in zip(shape, self.topology.as_tuple())
),
*(size for size in shape[3:]),
)

def halo_exchange_defs(self, shape: ShapeAttr) -> Iterable[ExchangeDeclarationAttr]:
yield from _flat_face_exchanges_for_dim(shape, 0)

yield from _flat_face_exchanges_for_dim(shape, 1)

yield from _flat_face_exchanges_for_dim(shape, 2)

if self.diagonals.value.data:
raise NotImplementedError("Diagonals support not implemented yet")

def comm_layout(self) -> RankTopoAttr:
return RankTopoAttr(self.topology.as_tuple())


def _flat_face_exchanges_for_dim(
shape: ShapeAttr, axis: int
) -> tuple[ExchangeDeclarationAttr, ExchangeDeclarationAttr]:
"""
Generate the two exchange delcarations to exchange the faces on the
axis "axis".
"""
dimensions = shape.dims
assert axis <= dimensions

def coords(where: Literal["start", "end"]):
for d in range(dimensions):
# for the dim we want to exchange, return either start or end halo region
if d == axis:
if where == "start":
# "start" halo goes from buffer start to core start
yield shape.buffer_start(d), shape.core_start(d)
else:
# "end" halo goes from core end to buffer end
yield shape.core_end(d), shape.buffer_end(d)
else:
# for the sliced regions, "extrude" from core
# this way we don't exchange edges
yield shape.core_start(d), shape.core_end(d)

ex1_coords = tuple(coords("end"))
ex2_coords = tuple(coords("start"))

return (
# towards positive dim:
ExchangeDeclarationAttr.from_points(
ex1_coords,
axis,
dir_sign=1,
),
# towards negative dim:
ExchangeDeclarationAttr.from_points(
ex2_coords,
axis,
dir_sign=-1,
),
)


@irdl_op_definition
class SwapOp(IRDLOperation):
"""
Expand All @@ -438,15 +593,13 @@ class SwapOp(IRDLOperation):
base(stencil.AnyTempType) | base(builtin.AnyMemRefType)
)

swaps: builtin.ArrayAttr[ExchangeDeclarationAttr] | None = opt_attr_def(
builtin.ArrayAttr[ExchangeDeclarationAttr]
)
swaps = attr_def(builtin.ArrayAttr[ExchangeDeclarationAttr])

topo: RankTopoAttr | None = opt_attr_def(RankTopoAttr)
strategy = attr_def(DomainDecompositionStrategy)

@staticmethod
def get(input_stencil: SSAValue | Operation):
return SwapOp.build(operands=[input_stencil])
def get(input_stencil: SSAValue | Operation, strategy: DomainDecompositionStrategy):
return SwapOp.build(operands=[input_stencil], attributes={"strategy": strategy})


DMP = Dialect(
Expand All @@ -458,5 +611,7 @@ def get(input_stencil: SSAValue | Operation):
ExchangeDeclarationAttr,
ShapeAttr,
RankTopoAttr,
GridSlice2dAttr,
GridSlice3dAttr,
],
)
2 changes: 0 additions & 2 deletions xdsl/transforms/canonicalize_dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ class CanonicalizeDmpSwap(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
keeps: list[dmp.ExchangeDeclarationAttr] = []
if op.swaps is None:
return
for swap in op.swaps:
if swap.elem_count > 0:
keeps.append(swap)
Expand Down
Loading

0 comments on commit f9df366

Please sign in to comment.