Skip to content

Commit

Permalink
Add dynamic stride and bound support
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 9, 2024
1 parent e649cb7 commit dab0bf8
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 28 deletions.
20 changes: 15 additions & 5 deletions compiler/ir/tsl/stride.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,29 @@ class Stride:
used to generate all values within the bound.
Args:
stride (int): The stride of the Stride
bound (int): The bound of the Stride
stride (int | None): The stride of the Stride
None represents a dynamic stride
bound (int | None): The bound of the Stride
None represents a dynamic bound
"""

stride: int
bound: int
stride: int | None
bound: int | None

def is_dynamic(self) -> bool:
"""Check if the Stride is dynamic"""
return self.stride is None or self.bound is None

def all_values(self) -> list[int]:
"""Get all values within the bound of the Stride"""
if self.is_dynamic():
raise ValueError("Cannot get all values of a dynamic stride")
return list(range(0, self.stride * self.bound, self.stride))

def __str__(self) -> str:
return f"{self.stride} x {self.bound}"
stride = "?" if self.stride is None else str(self.stride)
bound = "?" if self.bound is None else str(self.bound)
return f"{stride} x {bound}"

def __eq__(self, other: object) -> bool:
if not isinstance(other, Stride):
Expand Down
12 changes: 10 additions & 2 deletions compiler/ir/tsl/tiled_stride.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@ def __init__(self, strides: list[Stride]):
self.strides = list(strides)

def __str__(self) -> str:
strides = ", ".join(str(stride.stride) for stride in self.strides)
bounds = ", ".join(str(stride.bound) for stride in self.strides)
strides = ", ".join(
str(stride.stride) if stride.stride else "?" for stride in self.strides
)
bounds = ", ".join(
str(stride.bound) if stride.bound else "?" for stride in self.strides
)
return f"[{strides}] * [{bounds}]"

def __iter__(self) -> Iterator[list[int, Stride]]:
Expand All @@ -43,6 +47,10 @@ def __iter__(self) -> Iterator[list[int, Stride]]:
"""
yield from zip(range(self.depth()), self.strides)

def is_dynamic(self) -> bool:
"""Check if the Tiled Stride is dynamic"""
return any(stride.is_dynamic() for _, stride in self.strides)

def depth(self) -> int:
"""Get the number of strides in the Tiled Stride
Expand Down
4 changes: 4 additions & 0 deletions compiler/ir/tsl/tiled_strided_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def __iter__(self) -> Iterator[tuple[int, int, Stride]]:
# return result
return iter(result)

def is_dynamic(self) -> bool:
"""Check if the Tiled Strided Layout is dynamic"""
return any(stride.is_dynamic() for _, _, stride in self)

def dimension(self) -> int:
"""Get the number of dimensions in the Tiled Strided Layout"""
return len(self.tstrides)
Expand Down
12 changes: 10 additions & 2 deletions compiler/parser/tsl_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@ class TSLParser(BaseParser):
def __init__(self, state: ParserState) -> None:
self._resume_from(state)

def _parse_int_or_question(self, context_msg: str = "") -> int | None:
"""Parse either an integer literal, or a '?'."""
if self._parse_optional_token(Token.Kind.QUESTION) is not None:
return None
if (v := self.parse_optional_integer(allow_boolean=False)) is not None:
return v
self.raise_error("Expected an integer literal or `?`" + context_msg)

def _parse_stride(self) -> list[int]:
"""
strides ::== `[` stride (`,` stride)* `]`
"""
self._parse_token(Token.Kind.L_SQUARE, "Expected opening bracket")
strides: list[int] = []
while not self._parse_optional_token(Token.Kind.R_SQUARE):
strides.append(self.parse_integer())
strides.append(self._parse_int_or_question())
self._parse_optional_token(Token.Kind.COMMA)
return strides

Expand All @@ -29,7 +37,7 @@ def _parse_bound(self) -> list[int]:
self._parse_token(Token.Kind.L_SQUARE, "Expected opening bracket")
bounds: list[int] = []
while not self._parse_optional_token(Token.Kind.R_SQUARE):
bounds.append(self.parse_integer())
bounds.append(self._parse_int_or_question())
self._parse_optional_token(Token.Kind.COMMA)
return bounds

Expand Down
2 changes: 2 additions & 0 deletions tests/filecheck/dialects/tsl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@
"builtin.module"() ({
//CHECK: #tsl.tsl<([2, 4] * [4, 2], [16, 32] * [4, 2], offset: 5)>
%0 = "memref.alloc"() <{"operandSegmentSizes" = array<i32: 0, 0>}> : () -> memref<64x64xindex, #tsl.tsl<([2, 4] * [4, 2], [16, 32] * [4, 2], offset: 5)>, 2 : i32>
//CHECK: #tsl.tsl<([2, 4] * [4, ?], [16, ?] * [4, ?], offset: 7)>
%1 = "memref.alloc"() <{"operandSegmentSizes" = array<i32: 0, 0>}> : () -> memref<64x64xindex, #tsl.tsl<([2, 4] * [4, ?], [16, ?] * [4, ?], offset: 7)>, 2 : i32>
}) : () -> ()
13 changes: 9 additions & 4 deletions tests/ir/tsl/test_stride.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,30 @@

@pytest.fixture()
def example_strides():
return (Stride(1, 4), Stride(4, 6), Stride(24, 2))
return (Stride(1, 4), Stride(4, 6), Stride(24, 2), Stride(None, None))


def test_stride_constructor(example_strides):
stride1, stride2, _ = example_strides
stride1, stride2, _, dynamic_stride = example_strides
assert stride1.stride == 1
assert stride1.bound == 4
assert stride2.stride == 4
assert stride2.bound == 6
assert dynamic_stride.stride is None
assert dynamic_stride.bound is None


def test_stride_all_values(example_strides):
stride1, stride2, _ = example_strides
stride1, stride2, _, dynamic_stride = example_strides
assert stride1.all_values() == [0, 1, 2, 3]
assert stride2.all_values() == [0, 4, 8, 12, 16, 20]
with pytest.raises(ValueError):
dynamic_stride.all_values()


def test_stride_str(example_strides):
stride1, stride2, stride3 = example_strides
stride1, stride2, stride3, dynamic_stride = example_strides
assert str(stride1) == "1 x 4"
assert str(stride2) == "4 x 6"
assert str(stride3) == "24 x 2"
assert str(dynamic_stride) == "? x ?"
19 changes: 11 additions & 8 deletions tests/ir/tsl/test_tiled_stride.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@

@pytest.fixture()
def example_strides():
return (Stride(1, 4), Stride(4, 6), Stride(24, 2))
return (Stride(1, 4), Stride(4, 6), Stride(24, 2), Stride(None, None))


@pytest.fixture()
def example_tiled_strides(example_strides):
stride1, stride2, stride3 = example_strides
stride1, stride2, stride3, dynamic_stride = example_strides
tiledStride1 = TiledStride([stride1, stride2])
tiledStride2 = TiledStride([stride1, stride2, stride3])
return tiledStride1, tiledStride2
tiledStride3 = TiledStride([stride1, dynamic_stride])
return tiledStride1, tiledStride2, tiledStride3


def test_tiled_stride_constructor(example_strides, example_tiled_strides):
stride1, stride2, stride3 = example_strides
tiledStride1, tiledStride2 = example_tiled_strides
stride1, stride2, stride3, _ = example_strides
tiledStride1, tiledStride2, _ = example_tiled_strides
assert tiledStride1.strides[0] == stride1
assert tiledStride1.strides[1] == stride2
assert tiledStride2.strides[0] == stride1
Expand All @@ -28,20 +29,22 @@ def test_tiled_stride_constructor(example_strides, example_tiled_strides):


def test_tiled_stride_depth(example_tiled_strides):
tiledStride1, tiledStride2 = example_tiled_strides
tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides
assert tiledStride1.depth() == 2
assert tiledStride2.depth() == 3
assert tiledStride3.depth() == 2


def test_tiled_stride_str(example_tiled_strides):
tiledStride1, tiledStride2 = example_tiled_strides
tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides
assert str(tiledStride1) == "[1, 4] * [4, 6]"
assert str(tiledStride2) == "[1, 4, 24] * [4, 6, 2]"
assert str(tiledStride3) == "[1, ?] * [4, ?]"


def test_tiled_stride_iter(example_strides, example_tiled_strides):
strides = example_strides
_, tiledStride2 = example_tiled_strides
_, tiledStride2, _ = example_tiled_strides

for depth, stride in tiledStride2:
assert isinstance(depth, int)
Expand Down
24 changes: 17 additions & 7 deletions tests/ir/tsl/test_tiled_strided_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,31 @@ def example_tsl():
Stride(16, 2),
]
)
tiledStride3 = TiledStride(
[
Stride(1, 4),
Stride(None, None),
]
)
tsl = TiledStridedLayout([tiledStride1, tiledStride2], offset=5)
return tsl
tsl2 = TiledStridedLayout([tiledStride1, tiledStride3], offset=7)
return tsl, tsl2


def test_tsl_constructor(example_tsl):
tsl = example_tsl
tsl, _ = example_tsl
assert isinstance(tsl.tstrides[0], TiledStride)
assert isinstance(tsl.tstrides[1], TiledStride)


def test_tsl_str(example_tsl):
tsl = example_tsl
tsl, tsl2 = example_tsl
assert str(tsl) == "([4, 32] * [4, 2], [1, 16] * [4, 2], offset: 5)"
assert str(tsl2) == "([4, 32] * [4, 2], [1, ?] * [4, ?], offset: 7)"


def test_tsl_iter(example_tsl):
tsl = example_tsl
tsl, _ = example_tsl
count = 0
for dim, depth, stride in tsl:
count += 1
Expand All @@ -47,12 +55,14 @@ def test_tsl_iter(example_tsl):


def test_tsl_all_values(example_tsl):
tsl = example_tsl
tsl, tsl2 = example_tsl
assert set(tsl.all_values()) == set(range(64))
with pytest.raises(ValueError):
tsl2.all_values()


def test_tsl_self_overlaps(example_tsl):
tsl = example_tsl
tsl, _ = example_tsl
assert not tsl.self_overlaps()

tiledStride1 = TiledStride(
Expand All @@ -73,7 +83,7 @@ def test_tsl_self_overlaps(example_tsl):


def test_tsl_is_dense(example_tsl):
tsl = example_tsl
tsl, _ = example_tsl
assert tsl.is_dense()

tiledStride1 = TiledStride(
Expand Down

0 comments on commit dab0bf8

Please sign in to comment.