Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add simplify method to tiled strided layouts #360

Merged
merged 2 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions compiler/ir/tsl/tiled_stride.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Iterator
from dataclasses import dataclass
from typing import Self

from compiler.ir.tsl.stride import Stride

Expand Down Expand Up @@ -73,6 +74,33 @@ def __iter__(self) -> Iterator[tuple[int, Stride]]:
"""
return enumerate(self.strides)

def canonicalize(self) -> Self:
strides: list[Stride] = []
for stride in reversed(self.strides):
if len(strides) == 0:
# always keep the innermost one
strides.insert(0, stride)
continue

if stride.bound == 1:
# strides with a bound of 0 are useless
continue

prev_stride = strides[0]
if (
prev_stride.step
and prev_stride.bound
and prev_stride.step * prev_stride.bound == stride.step
and stride.bound
):
# we can squash this stride with the previous one
strides[0] = Stride(prev_stride.step, prev_stride.bound * stride.bound)

else:
strides.insert(0, stride)

return type(self)(strides)

def is_dynamic(self) -> bool:
"""Check if the Tiled Stride is dynamic"""
return any(stride.is_dynamic() for stride in self.strides)
Expand Down
6 changes: 6 additions & 0 deletions compiler/ir/tsl/tiled_strided_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections.abc import Iterator
from dataclasses import dataclass
from typing import Self

import numpy as np
from numpy._typing import NDArray
Expand Down Expand Up @@ -77,6 +78,11 @@ def get_stride(self, dim: int, depth: int) -> Stride:
the Tiled Strided Layout"""
return self.tstrides[dim].strides[depth]

def canonicalize(self) -> Self:
return type(self)(
[tstride.canonicalize() for tstride in self.tstrides], self.offset
)

def all_values(self) -> NDArray[np.int_]:
"""
Returns a numpy array containing all the elements in the iteration space.
Expand Down
26 changes: 26 additions & 0 deletions tests/ir/tsl/test_tiled_stride.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,29 @@ def test_tiled_stride_tile_bounds(example_tiled_strides: tuple[TiledStride, ...]
assert tiledStride1.tile_bounds() == [6, 4]
assert tiledStride2.tile_bounds() == [2, 6, 4]
assert tiledStride3.tile_bounds() == [None, 4]


def test_tiled_stride_canonicalize():
# normal tstrides are unaffected
tstride_normal = TiledStride([Stride(16, 8), Stride(1, 8)])
assert tstride_normal.canonicalize() == tstride_normal

# strides with bound 1 are evicted
tstride_bound1 = TiledStride([Stride(16, 1), Stride(1, 8)])
assert tstride_bound1.canonicalize() == TiledStride([Stride(1, 8)])

# squashable strides are squashed
tstride_squashme = TiledStride([Stride(8, 8), Stride(1, 8)])
assert tstride_squashme.canonicalize() == TiledStride([Stride(1, 64)])

# bound 1 + squash
tstride = TiledStride([Stride(8, 8), Stride(16, 1), Stride(1, 8)])
assert tstride.canonicalize() == TiledStride([Stride(1, 64)])

# squash + bound 1
tstride = TiledStride([Stride(64, 1), Stride(8, 8), Stride(1, 8)])
assert tstride.canonicalize() == TiledStride([Stride(1, 64)])

# normal remains the same
tstride = TiledStride([Stride(8, 2), Stride(16, 8), Stride(1, 8)])
assert tstride.canonicalize() == tstride