Skip to content

Commit

Permalink
fix ObjectFifoOp (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin authored Dec 18, 2024
1 parent 6f5e680 commit 4226851
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 15 deletions.
28 changes: 28 additions & 0 deletions tests/filecheck/dialects/aie/object_fifo_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: XDSL_ROUNDTRIP
// RUN: XDSL_GENERIC_ROUNDTRIP
// RUN: AIE_ROUNDTRIP
// RUN: AIE_GENERIC_ROUNDTRIP

aie.device(npu1) {
%1 = aie.tile(0, 1)
%2 = aie.tile(0, 2)
aie.objectfifo @of1 (%1, { %2 }, 4 : i32) : !aie.objectfifo<memref<16xi32>>
}


// CHECK: module {
// CHECK-NEXT: aie.device(npu1) {
// CHECK-NEXT: %{{.*}} = aie.tile(0, 1)
// CHECK-NEXT: %{{.*}} = aie.tile(0, 2)
// CHECK-NEXT: aie.objectfifo @of1(%{{.*}}, {%{{.*}}}, 4 : i32) : !aie.objectfifo<memref<16xi32>>
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK-GENERIC: "builtin.module"() ({
// CHECK-GENERIC-NEXT: "aie.device"() <{{{["]?}}device{{["]?}} = 4 : i32}> ({
// CHECK-GENERIC-NEXT: %{{.*}} = "aie.tile"() <{{{["]?}}col{{["]?}} = 0 : i32, {{["]?}}row{{["]?}} = 1 : i32}> : () -> index
// CHECK-GENERIC-NEXT: %{{.*}} = "aie.tile"() <{{{["]?}}col{{["]?}} = 0 : i32, {{["]?}}row{{["]?}} = 2 : i32}> : () -> index
// CHECK-GENERIC-NEXT: "aie.objectfifo"(%0, %1) <{{{["]?}}dimensionsFromStreamPerConsumer{{["]?}} = #aie<bd_dim_layout_array_array[[]]>, {{["]?}}dimensionsToStream{{["]?}} = #aie<bd_dim_layout_array[]>, {{["]?}}disable_synchronization{{["]?}} = false, {{["]?}}elemNumber{{["]?}} = 4 : i32, {{["]?}}elemType{{["]?}} = !aie.objectfifo<memref<16xi32>>, {{["]?}}plio{{["]?}} = false, {{["]?}}sym_name{{["]?}} = "of1", {{["]?}}via_DMA{{["]?}} = false}> : (index, index) -> ()
// CHECK-GENERIC-NEXT: "aie.end"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
1 change: 0 additions & 1 deletion tests/filecheck/dialects/aie/tile_op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@

// CHECK: %{{.*}} = aie.tile(1, 2)
// CHECK-GENERIC: %{{.*}} = "aie.tile"() <{{{["]?}}col{{["]?}} = 1 : i32, {{["]?}}row{{["]?}} = 2 : i32}> : () -> index

89 changes: 75 additions & 14 deletions xdsl_aie/dialects/aie.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
I32,
AnyIntegerAttr,
ArrayAttr,
BoolAttr,
Float32Type,
IndexType,
IntAttr,
Expand All @@ -35,14 +36,17 @@
Attribute,
AttributeInvT,
Block,
Data,
Dialect,
EnumAttribute,
OpaqueSyntaxAttribute,
Operation,
OpResult,
ParametrizedAttribute,
Region,
SSAValue,
StrEnum,
TypeAttribute,
)
from xdsl.irdl import (
IRDLOperation,
Expand All @@ -59,7 +63,7 @@
traits_def,
var_operand_def,
)
from xdsl.parser import Parser
from xdsl.parser import AttrParser, Parser
from xdsl.printer import Printer
from xdsl.traits import (
HasParent,
Expand Down Expand Up @@ -170,7 +174,7 @@ class BufferTypeAttr(EnumAttribute[BufferTypeEnum]):


@irdl_attr_definition
class ObjectFIFO(Generic[AttributeInvT], ParametrizedAttribute):
class ObjectFIFO(Generic[AttributeInvT], ParametrizedAttribute, TypeAttribute):
name = "aie.objectfifo"

buffer: ParameterDef[AttributeInvT]
Expand Down Expand Up @@ -201,6 +205,45 @@ def from_element_type_and_shape(
return ObjectFIFOSubview([builtin.MemRefType(element_type, shape)])


class BDDimLayout(tuple[int]): ...


class BDDimLayoutArray(tuple[BDDimLayout, ...]): ...


class BDDimLayoutArrayArray(tuple[BDDimLayoutArray, ...]): ...


@irdl_attr_definition
class BDDimLayoutArrayAttr(Data[BDDimLayoutArray], OpaqueSyntaxAttribute):
name = "aie.bd_dim_layout_array"

@classmethod
def parse_parameter(cls, parser: AttrParser) -> BDDimLayoutArray:
parser.parse_punctuation("[")
parser.parse_punctuation("]")
return BDDimLayoutArray(tuple())

def print_parameter(self, printer: Printer) -> None:
printer.print_string(f"{list(self.data)}")


@irdl_attr_definition
class BDDimLayoutArrayArrayAttr(Data[BDDimLayoutArrayArray], OpaqueSyntaxAttribute):
name = "aie.bd_dim_layout_array_array"

@classmethod
def parse_parameter(cls, parser: AttrParser) -> BDDimLayoutArrayArray:
parser.parse_punctuation("[")
parser.parse_punctuation("[")
parser.parse_punctuation("]")
parser.parse_punctuation("]")
return BDDimLayoutArrayArray(tuple(tuple()))

def print_parameter(self, printer: Printer) -> None:
printer.print_string("[[]]")


@irdl_op_definition
class SwitchboxOp(IRDLOperation):
name = "aie.switchbox"
Expand Down Expand Up @@ -1068,37 +1111,53 @@ def parse(cls, parser: Parser) -> ObjectFIFOSubviewAccessOp:
class ObjectFifoOp(IRDLOperation):
name = "aie.objectfifo"

elemNumber = attr_def(IntegerAttr[IntegerType])
producerTile = operand_def(IndexType())
consumerTiles = var_operand_def(IndexType())
sym_name = attr_def(StringAttr)
object_fifo = attr_def(ObjectFIFO[Attribute])

sym_name = prop_def(StringAttr)

elemNumber = prop_def(IntegerAttr[IntegerType])
elemType = prop_def(ObjectFIFO[Attribute])

dimensionsToStream = prop_def(BDDimLayoutArrayAttr)
dimensionsFromStreamPerConsumer = prop_def(BDDimLayoutArrayArrayAttr)

via_DMA = prop_def(BoolAttr)
plio = prop_def(BoolAttr)
disable_synchronization = prop_def(BoolAttr)

traits = traits_def(SymbolOpInterface(), HasParent(DeviceOp))

def __init__(
self,
elemNumber: IntegerAttr[IntegerType],
producerTile: Operation | SSAValue,
consumerTiles: list[Operation | SSAValue],
name: str,
elemNumber: IntegerAttr[IntegerType],
referenced_type: Attribute,
shape: Iterable[int | IntAttr],
name: str,
):
object_fifo = ObjectFIFO[Attribute].from_element_type_and_shape(
elemType = ObjectFIFO[Attribute].from_element_type_and_shape(
referenced_type, shape
)
super().__init__(
attributes={
properties={
"dimensionsFromStreamPerConsumer": BDDimLayoutArrayArrayAttr(
BDDimLayoutArrayArray(tuple(tuple()))
),
"dimensionsToStream": BDDimLayoutArrayAttr(BDDimLayoutArray(tuple())),
"disable_synchronization": IntegerAttr.from_int_and_width(0, 1),
"elemNumber": elemNumber,
"object_fifo": object_fifo,
"elemType": elemType,
"plio": IntegerAttr.from_int_and_width(0, 1),
"sym_name": StringAttr(name),
"via_DMA": IntegerAttr.from_int_and_width(0, 1),
},
operands=[producerTile, consumerTiles],
operands=[producerTile, *consumerTiles],
)

def print(self, printer: Printer):
printer.print(" @", self.sym_name.data, "( ", self.producerTile, ", { ")
printer.print(" @", self.sym_name.data, "(", self.producerTile, ", {")
for i in range(len(self.consumerTiles) - 1):
printer.print(self.consumerTiles[i], ", ")

Expand All @@ -1107,7 +1166,7 @@ def print(self, printer: Printer):

printer.print(
") : !aie.objectfifo<",
self.object_fifo.buffer,
self.elemType.buffer,
">",
)

Expand Down Expand Up @@ -1140,7 +1199,7 @@ def parse(cls, parser: Parser) -> ObjectFifoOp:
referenced_type = objectfifo_type.element_type

object_fifo = ObjectFifoOp(
elemNumber, producerTile, consumerTiles, referenced_type, shape, name
producerTile, consumerTiles, name, elemNumber, referenced_type, shape
)

return object_fifo
Expand Down Expand Up @@ -1535,6 +1594,8 @@ def __init__(
EndOp,
],
[
BDDimLayoutArrayAttr,
BDDimLayoutArrayArrayAttr,
WireBundleAttr,
ObjectFIFO,
ObjectFIFOSubview,
Expand Down

0 comments on commit 4226851

Please sign in to comment.