Skip to content

Commit

Permalink
add offsets
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 8, 2024
1 parent 226d65b commit e649cb7
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
6 changes: 4 additions & 2 deletions compiler/ir/tsl/tiled_strided_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ class TiledStridedLayout:
"""

tstrides: list[TiledStride]
offset: int

def __init__(self, tstrides: list[TiledStride]):
def __init__(self, tstrides: list[TiledStride], offset: int = 0):
self.tstrides = tstrides
self.offset = offset

def __str__(self) -> str:
return "(" + ", ".join(map(str, self.tstrides)) + ")"
return "(" + ", ".join(map(str, self.tstrides)) + f", offset: {self.offset})"

def __iter__(self) -> Iterator[tuple[int, int, Stride]]:
"""Returns an iterator of (dim, depth, stride) over all the
Expand Down
11 changes: 8 additions & 3 deletions compiler/parser/tsl_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,16 @@ def _parse_tiled_stride(self) -> TiledStride:

def parse(self) -> TiledStridedLayout:
"""
tsl ::= `(` tiled-stride (`,` tiled-stride)* `)`
tsl ::= `(` tiled-stride (`,` tiled-stride)*`, offset: ` offset `)`
"""
self._parse_token(Token.Kind.L_PAREN, "Expected opening bracket")
tstrides = []
while not self._parse_optional_token(Token.Kind.R_PAREN):
self.parse_optional_characters("offset:")
# while not self._parse_optional_token(Token.Kind.R_PAREN):
while not self.parse_optional_characters("offset"):
tstrides.append(self._parse_tiled_stride())
self._parse_optional_token(Token.Kind.COMMA)
return TiledStridedLayout(tstrides)
self._parse_token(Token.Kind.COLON, "Expected colon")
offset = self.parse_integer()
self._parse_token(Token.Kind.R_PAREN, "Expected closing bracket")
return TiledStridedLayout(tstrides, offset=offset)
4 changes: 2 additions & 2 deletions tests/filecheck/dialects/tsl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

//CHECK: module
"builtin.module"() ({
//CHECK: #tsl.tsl<([2, 4] * [4, 2], [16, 32] * [4, 2])>
%0 = "memref.alloc"() <{"operandSegmentSizes" = array<i32: 0, 0>}> : () -> memref<64x64xindex, #tsl.tsl<([2, 4] * [4, 2], [16, 32] * [4, 2])>, 2 : i32>
//CHECK: #tsl.tsl<([2, 4] * [4, 2], [16, 32] * [4, 2], offset: 5)>
%0 = "memref.alloc"() <{"operandSegmentSizes" = array<i32: 0, 0>}> : () -> memref<64x64xindex, #tsl.tsl<([2, 4] * [4, 2], [16, 32] * [4, 2], offset: 5)>, 2 : i32>
}) : () -> ()
4 changes: 2 additions & 2 deletions tests/ir/tsl/test_tiled_strided_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def example_tsl():
Stride(16, 2),
]
)
tsl = TiledStridedLayout([tiledStride1, tiledStride2])
tsl = TiledStridedLayout([tiledStride1, tiledStride2], offset=5)
return tsl


Expand All @@ -31,7 +31,7 @@ def test_tsl_constructor(example_tsl):

def test_tsl_str(example_tsl):
tsl = example_tsl
assert str(tsl) == "([4, 32] * [4, 2], [1, 16] * [4, 2])"
assert str(tsl) == "([4, 32] * [4, 2], [1, 16] * [4, 2], offset: 5)"


def test_tsl_iter(example_tsl):
Expand Down

0 comments on commit e649cb7

Please sign in to comment.