Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pyright: fix tsl-related stuff #327

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions compiler/ir/tsl/tiled_strided_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass

import numpy as np
from numpy._typing import NDArray

from compiler.ir.tsl.stride import Stride
from compiler.ir.tsl.tiled_stride import TiledStride
Expand Down Expand Up @@ -76,14 +77,14 @@ def get_stride(self, dim: int, depth: int) -> Stride:
the Tiled Strided Layout"""
return self.tstrides[dim].strides[depth]

def all_values(self) -> np.ndarray:
def all_values(self) -> NDArray[np.int_]:
"""
Returns a numpy array containing all the elements in the iteration space.
"""
result = np.array([0])

for _, _, stride in self:
next_stride = np.array(stride.all_values())
next_stride = np.array(stride.all_values(), dtype=np.int_)
# for every stride, add a dimension and broadcast sum
result = np.squeeze(
np.expand_dims(result, -1) + np.expand_dims(next_stride, 0)
Expand Down
14 changes: 8 additions & 6 deletions compiler/parser/tsl_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,23 @@ def _parse_int_or_question(self, context_msg: str = "") -> int | None:
return v
self.raise_error("Expected an integer literal or `?`" + context_msg)

def _parse_step(self) -> list[int]:
def _parse_step(self) -> list[int | None]:
"""
steps ::== `(` steps (`,` steps)* `)`
"""
self._parse_token(Token.Kind.L_PAREN, "Expected opening bracket")
steps: list[int] = []
steps: list[int | None] = []
while not self._parse_optional_token(Token.Kind.R_PAREN):
steps.append(self._parse_int_or_question())
self._parse_optional_token(Token.Kind.COMMA)
return steps

def _parse_bound(self) -> list[int]:
def _parse_bound(self) -> list[int | None]:
"""
bounds ::== `[` bound (`,` bound)* `]`
"""
self._parse_token(Token.Kind.L_SQUARE, "Expected opening bracket")
bounds: list[int] = []
bounds: list[int | None] = []
while not self._parse_optional_token(Token.Kind.R_SQUARE):
bounds.append(self._parse_int_or_question())
self._parse_optional_token(Token.Kind.COMMA)
Expand All @@ -49,15 +49,17 @@ def _parse_tiled_stride(self) -> TiledStride:
self._parse_token(Token.Kind.ARROW, "Expected arrow")
steps = self._parse_step()
if len(steps) != len(bounds):
raise ParseError("Expected same number of steps and bounds")
raise ParseError(
self._current_token.span, "Expected same number of steps and bounds"
)
# construct the tiledstrides
return TiledStride([Stride(step, bound) for step, bound in zip(steps, bounds)])

def parse(self) -> TiledStridedLayout:
"""
tsl ::= tiled-stride (`,` tiled-stride)*` (, offset: ` offset)?
"""
tstrides = []
tstrides: list[TiledStride] = []
offset = 0
while True:
if self._current_token.kind == Token.Kind.GREATER:
Expand Down
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 0 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ typeCheckingMode = "strict"
"compiler/inference/helpers.py",
"compiler/inference/scoped_setups.py",
"compiler/inference/trace_acc_state.py",
"compiler/ir/tsl/tiled_strided_layout.py",
"compiler/parser/tsl_parser.py",
"compiler/transforms/accfg_dedup.py",
"compiler/transforms/clear_memory_space.py",
"compiler/transforms/convert_linalg_to_accfg.py",
Expand All @@ -114,11 +112,7 @@ typeCheckingMode = "strict"
"compiler/util/memref_descriptor.py",
"tests/benchmark/test_snax_benchmark.py",
"tests/dialects/test_snax.py",
"tests/dialects/test_tsl.py",
"tests/inference/test_accfg_state_tracing.py",
"tests/ir/tsl/test_stride.py",
"tests/ir/tsl/test_tiled_stride.py",
"tests/ir/tsl/test_tiled_strided_layout.py",
"tests/util/",
]

Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_tsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ def example_tsl_attr():
return tsl_attr


def test_tsl_attr_constructor(example_tsl_attr):
def test_tsl_attr_constructor(example_tsl_attr: TiledStridedLayoutAttr):
tsl = example_tsl_attr
assert isinstance(tsl, TiledStridedLayoutAttr)
assert isinstance(tsl.data, TiledStridedLayout)


def test_tsl_attr_get_affine(example_tsl_attr):
def test_tsl_attr_get_affine(example_tsl_attr: TiledStridedLayoutAttr):
tsl = example_tsl_attr
map = canonicalize_map(tsl.get_affine_map())
assert map == canonicalize_map(
Expand Down
6 changes: 3 additions & 3 deletions tests/ir/tsl/test_stride.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def example_strides():
return (Stride(1, 4), Stride(4, 6), Stride(24, 2), Stride(None, None))


def test_stride_constructor(example_strides):
def test_stride_constructor(example_strides: tuple[Stride, ...]):
stride1, stride2, _, dynamic_stride = example_strides
assert stride1.step == 1
assert stride1.bound == 4
Expand All @@ -18,15 +18,15 @@ def test_stride_constructor(example_strides):
assert dynamic_stride.bound is None


def test_stride_all_values(example_strides):
def test_stride_all_values(example_strides: tuple[Stride, ...]):
stride1, stride2, _, dynamic_stride = example_strides
assert stride1.all_values() == [0, 1, 2, 3]
assert stride2.all_values() == [0, 4, 8, 12, 16, 20]
with pytest.raises(ValueError):
dynamic_stride.all_values()


def test_stride_str(example_strides):
def test_stride_str(example_strides: tuple[Stride, ...]):
stride1, stride2, stride3, dynamic_stride = example_strides
assert str(stride1) == "4 -> 1"
assert str(stride2) == "6 -> 4"
Expand Down
16 changes: 10 additions & 6 deletions tests/ir/tsl/test_tiled_stride.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ def example_strides():


@pytest.fixture()
def example_tiled_strides(example_strides):
def example_tiled_strides(example_strides: tuple[Stride, ...]):
stride1, stride2, stride3, dynamic_stride = example_strides
tiledStride1 = TiledStride([stride2, stride1])
tiledStride2 = TiledStride([stride3, stride2, stride1])
tiledStride3 = TiledStride([dynamic_stride, stride1])
return tiledStride1, tiledStride2, tiledStride3


def test_tiled_stride_constructor(example_strides, example_tiled_strides):
def test_tiled_stride_constructor(
example_strides: tuple[Stride, ...], example_tiled_strides: tuple[TiledStride, ...]
):
stride1, stride2, stride3, _ = example_strides
tiledStride1, tiledStride2, _ = example_tiled_strides
assert tiledStride1.strides[0] == stride2
Expand All @@ -38,21 +40,23 @@ def test_tiled_stride_from_stride():
assert tiledStride2.strides[2] == Stride(24, 4)


def test_tiled_stride_depth(example_tiled_strides):
def test_tiled_stride_depth(example_tiled_strides: tuple[TiledStride, ...]):
tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides
assert tiledStride1.depth() == 2
assert tiledStride2.depth() == 3
assert tiledStride3.depth() == 2


def test_tiled_stride_str(example_tiled_strides):
def test_tiled_stride_str(example_tiled_strides: tuple[TiledStride, ...]):
tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides
assert str(tiledStride1) == "[6, 4] -> (4, 1)"
assert str(tiledStride2) == "[2, 6, 4] -> (24, 4, 1)"
assert str(tiledStride3) == "[?, 4] -> (?, 1)"


def test_tiled_stride_iter(example_strides, example_tiled_strides):
def test_tiled_stride_iter(
example_strides: tuple[Stride, ...], example_tiled_strides: tuple[TiledStride, ...]
):
stride1, stride2, stride3, _ = example_strides
strides = [stride3, stride2, stride1]

Expand All @@ -64,7 +68,7 @@ def test_tiled_stride_iter(example_strides, example_tiled_strides):
assert stride == strides[depth]


def test_tiled_stride_tile_bounds(example_tiled_strides):
def test_tiled_stride_tile_bounds(example_tiled_strides: tuple[TiledStride, ...]):
tiledStride1, tiledStride2, tiledStride3 = example_tiled_strides
assert tiledStride1.tile_bounds() == [6, 4]
assert tiledStride2.tile_bounds() == [2, 6, 4]
Expand Down
18 changes: 9 additions & 9 deletions tests/ir/tsl/test_tiled_strided_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ def example_tsl():
return tsl, tsl2


def test_tsl_constructor(example_tsl):
def test_tsl_constructor(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, _ = example_tsl
assert isinstance(tsl.tstrides[0], TiledStride)
assert isinstance(tsl.tstrides[1], TiledStride)


def test_tsl_from_strides():
strides = [None, 1]
tile_bounds = [[16, 4], [16, 4]]
tile_bounds: list[list[int | None]] = [[16, 4], [16, 4]]
tsl_constructor = TiledStridedLayout(
[
TiledStride([Stride(None, 16), Stride(None, 4)]),
Expand All @@ -49,13 +49,13 @@ def test_tsl_from_strides():
assert tsl_constructor == tsl_from_strides


def test_tsl_str(example_tsl):
def test_tsl_str(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, tsl2 = example_tsl
assert str(tsl) == "[2, 4] -> (32, 4), [2, 4] -> (16, 1), offset: 5"
assert str(tsl2) == "[2, 4] -> (32, 4), [?, 4] -> (?, 1), offset: 7"


def test_tsl_iter(example_tsl):
def test_tsl_iter(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, _ = example_tsl
count = 0
for dim, depth, stride in tsl:
Expand All @@ -67,19 +67,19 @@ def test_tsl_iter(example_tsl):
assert count == tsl.dimension() * tsl.tstrides[0].depth()


def test_tsl_all_values(example_tsl):
def test_tsl_all_values(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, tsl2 = example_tsl
assert set(tsl.all_values()) == set(range(64))
with pytest.raises(ValueError):
tsl2.all_values()


def test_tsl_tile_bounds(example_tsl):
def test_tsl_tile_bounds(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, _ = example_tsl
assert tsl.tile_bounds() == [[2, 4], [2, 4]]


def test_tsl_self_overlaps(example_tsl):
def test_tsl_self_overlaps(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, _ = example_tsl
assert not tsl.self_overlaps()

Expand All @@ -100,7 +100,7 @@ def test_tsl_self_overlaps(example_tsl):
assert tsl2.self_overlaps()


def test_tsl_is_dense(example_tsl):
def test_tsl_is_dense(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, _ = example_tsl
assert tsl.is_dense()

Expand All @@ -121,7 +121,7 @@ def test_tsl_is_dense(example_tsl):
assert not tsl2.is_dense()


def test_tsl_equal_tile_bounds(example_tsl):
def test_tsl_equal_tile_bounds(example_tsl: tuple[TiledStridedLayout, ...]):
tsl, tsl2 = example_tsl
assert tsl.equal_tile_bounds(tsl)
assert not tsl.equal_tile_bounds(tsl2)
Expand Down
Loading