Skip to content

Commit

Permalink
pyright: fix custom dialects snax, snax_stream, tsl
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 3, 2025
1 parent 502135e commit c1b2c19
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 41 deletions.
25 changes: 14 additions & 11 deletions compiler/dialects/snax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
IndexType,
IntegerAttr,
IntegerType,
MemrefLayoutAttr,
MemRefType,
NoneAttr,
UnrankedMemrefType,
Expand Down Expand Up @@ -64,8 +65,8 @@ class LayoutCast(IRDLOperation):

name = "snax.layout_cast"

source = operand_def(MemRefType[Attribute] | UnrankedMemrefType[Attribute])
dest = result_def(MemRefType[Attribute] | UnrankedMemrefType[Attribute])
source = operand_def(MemRefType)
dest = result_def(MemRefType)

def __init__(
self,
Expand All @@ -77,14 +78,16 @@ def __init__(
@staticmethod
def from_type_and_target_layout(
source: SSAValue | Operation,
layout: Attribute,
layout: MemrefLayoutAttr,
) -> LayoutCast:
assert isinstance(source.type, MemRefType)
source = SSAValue.get(source)
assert isinstance(source.type, Attribute)
source_type = cast(MemRefType[Attribute], source.type)
dest = MemRefType(
source.type.get_element_type(),
shape=source.type.get_shape(),
source_type.get_element_type(),
source_type.get_shape(),
layout=layout,
memory_space=source.type.memory_space,
memory_space=source_type.memory_space,
)
return LayoutCast(source, dest)

Expand Down Expand Up @@ -117,8 +120,8 @@ class Alloc(IRDLOperation):

name = "snax.alloc"

size: Operand = operand_def(IntegerType | IndexType)
shapes: VarOperand = var_operand_def(IntegerType | IndexType)
size: Operand = operand_def(IndexType)
shapes: VarOperand = var_operand_def(IndexType)
result: OpResult = result_def(LLVMStructType)
memory_space: Attribute | None = opt_prop_def(Attribute)
alignment: AnyIntegerAttr | None = opt_prop_def(AnyIntegerAttr)
Expand All @@ -129,7 +132,7 @@ def __init__(
size: SSAValue | Operation,
shapes: list[SSAValue | Operation],
memory_space: Attribute = NoneAttr(),
alignment: AnyIntegerAttr = None,
alignment: AnyIntegerAttr | None = None,
integer_type: IntegerType = i32,
):
# output type is llvm struct memref descriptor
Expand Down Expand Up @@ -184,7 +187,7 @@ def parse_parameter(cls, parser: AttrParser) -> StreamerConfiguration:
parser.parse_punctuation("[")

# Determine streamer options
opts = []
opts: Sequence[StreamerOpts] = []
if parser.parse_optional_keyword("opts"):
parser.parse_punctuation("=")
while not parser.parse_optional_punctuation(","):
Expand Down
2 changes: 1 addition & 1 deletion compiler/dialects/snax_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
parameters: Sequence[Attribute] = []
for arg in (upper_bounds, temporal_strides, spatial_strides):
if not isinstance(arg, ArrayAttr):
arg = ArrayAttr([IntAttr(x) if isinstance(x, int) else x for x in arg])
arg = ArrayAttr([IntAttr(x) for x in arg])
parameters.append(arg)
super().__init__(parameters)

Expand Down
37 changes: 15 additions & 22 deletions compiler/dialects/tsl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from math import prod
from typing import cast

from xdsl.dialects.arith import ConstantOp, DivUIOp, MuliOp
from xdsl.dialects.builtin import (
Expand All @@ -11,7 +12,7 @@
StridedLayoutAttr,
)
from xdsl.dialects.memref import DimOp, ExtractStridedMetaDataOp
from xdsl.ir import Data, Dialect, Operation, SSAValue
from xdsl.ir import Attribute, Data, Dialect, Operation, SSAValue
from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap
from xdsl.irdl import (
irdl_attr_definition,
Expand All @@ -32,7 +33,7 @@ class TiledStridedLayoutAttr(MemrefLayoutAttr, Data[TiledStridedLayout]):
@classmethod
def parse_parameter(cls, parser: AttrParser) -> TiledStridedLayout:
with parser.in_angle_brackets():
tslparser = TSLParser(parser._parser_state)
tslparser = TSLParser(parser.pos)
return tslparser.parse()

def print_parameter(self, printer: Printer) -> None:
Expand All @@ -42,10 +43,6 @@ def get_affine_map(self) -> AffineMap:
if self.data.is_dynamic():
raise NotImplementedError("Dynamic case is not implemented yet!")

# TODO: the affine map should result in element offset, not byte offset
# i will probably transition the tsl definition to element offset
# as well, to make everything more convenient

result = AffineConstantExpr(0)
for dim in range(self.data.dimension()):
max_depth = self.data.tstrides[dim].depth()
Expand Down Expand Up @@ -86,15 +83,15 @@ def get_bound_ops(
"""

result: list[Operation] = []
result_mapping: dict[(int, int), Operation] = {}
result_mapping: dict[tuple[int, int], Operation] = {}

tsl = self.data

if isinstance(memref_op_or_shapes, SSAValue | Operation):
# if the argument passed is a memref, generate shape operation
# list by using the dim operation
memref = memref_op_or_shapes
shapes = []
shapes: list[Operation] = []
for dim in range(tsl.dimension()):
dim_index_op = ConstantOp.from_int_and_width(dim, IndexType())
dim_op = DimOp.from_source_and_index(memref, dim_index_op)
Expand Down Expand Up @@ -134,6 +131,7 @@ def get_bound_ops(
# inner tile depths are all static by definition of TSL
for depth in range(1, tsl.tstrides[dim].depth()):
stride = tsl.get_stride(dim, depth)
assert stride.bound is not None
bound_op = ConstantOp.from_int_and_width(stride.bound, IndexType())
result.append(bound_op)
result_mapping[(dim, depth)] = bound_op
Expand Down Expand Up @@ -169,19 +167,20 @@ def get_step_ops(
result: list[Operation] = []
result_mapping: dict[tuple[int, int], Operation] = {}
tsl = self.data
el_bytes = 1

# Handle the special case where a tsl is constructed from a stridedlayoutattr
# In this case, if there are dynamic strides, we cannot perform
# the TSL contiguity assumptions. Instead, dynamic strides are
# fetched from the extract strided metadata operation.
if (
memref_op
and isinstance(memref_op.type, MemRefType)
and isinstance(memref_op.type.layout, StridedLayoutAttr)
if memref_op and isinstance(
(memref_type := cast(MemRefType[Attribute], memref_op.type)).layout,
StridedLayoutAttr,
):
metadata_op = ExtractStridedMetaDataOp(memref_op)
assert isinstance(memref_type.element_type, FixedBitwidthType)
element_size_op = ConstantOp.from_int_and_width(
memref_op.type.element_type.width.data // 8, IndexType()
memref_type.element_type.size, IndexType()
)
result.extend([metadata_op, element_size_op])
for dim in range(tsl.dimension()):
Expand All @@ -191,15 +190,9 @@ def get_step_ops(
stride = MuliOp(metadata_op.strides[dim], element_size_op)
result_mapping[(dim, depth)] = stride

# optional bytes correction
if in_bytes:
assert memref_op
assert isinstance(memref_op.type, MemRefType)
assert isinstance(memref_op.type.element_type, FixedBitwidthType)
el_bytes = memref_op.type.element_type.size
else:
# else use 1 such that 1 element = 1 byte
el_bytes = 1
# optional bytes correction
if in_bytes:
el_bytes = memref_type.element_type.size

# to handle the dynamic case, we must first find the largest
# statically defined step, and then use that to calculate the
Expand Down
6 changes: 3 additions & 3 deletions compiler/parser/tsl_parser.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from xdsl.parser.base_parser import BaseParser, ParserState
from xdsl.parser.base_parser import BaseParser
from xdsl.utils.exceptions import ParseError
from xdsl.utils.lexer import Token
from xdsl.utils.lexer import Position, Token

from compiler.ir.tsl import Stride, TiledStride, TiledStridedLayout


class TSLParser(BaseParser):
def __init__(self, state: ParserState) -> None:
def __init__(self, state: Position) -> None:
self._resume_from(state)

def _parse_int_or_question(self, context_msg: str = "") -> int | None:
Expand Down
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ typeCheckingMode = "strict"
"compiler/accelerators/snax_gemmx.py",
"compiler/accelerators/snax_hwpe_mult.py",
"compiler/dialects/accfg.py",
"compiler/dialects/snax.py",
"compiler/dialects/snax_stream.py",
"compiler/dialects/tsl.py",
"compiler/inference/dataflow.py",
"compiler/inference/helpers.py",
"compiler/inference/scoped_setups.py",
Expand Down

0 comments on commit c1b2c19

Please sign in to comment.