Skip to content

Commit

Permalink
fix DeviceOp
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Dec 18, 2024
1 parent f30189c commit ca53e95
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 40 deletions.
6 changes: 0 additions & 6 deletions tests/filecheck/dialects/aie.mlir

This file was deleted.

111 changes: 111 additions & 0 deletions tests/filecheck/dialects/aie/device_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// RUN: XDSL_ROUNDTRIP
// RUN: XDSL_GENERIC_ROUNDTRIP
// RUN: AIE_ROUNDTRIP
// RUN: AIE_GENERIC_ROUNDTRIP

// operation: aie.device

aie.device(xcvc1902) {
arith.constant 1 : i32
}

aie.device(xcve2302) {
arith.constant 1 : i32
}

aie.device(xcve2802) {
arith.constant 1 : i32
}

aie.device(npu1) {
arith.constant 1 : i32
}

aie.device(npu1_1col) {
arith.constant 1 : i32
}

aie.device(npu1_2col) {
arith.constant 1 : i32
}

aie.device(npu1_3col) {
arith.constant 1 : i32
}

aie.device(npu1_4col) {
arith.constant 1 : i32
}

aie.device(npu2) {
arith.constant 1 : i32
}

// CHECK: module {
// CHECK-NEXT: aie.device(xcvc1902) {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : i32
// CHECK-NEXT: }
// CHECK-NEXT: aie.device(xcve2302) {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : i32
// CHECK-NEXT: }
// CHECK-NEXT: aie.device(xcve2802) {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : i32
// CHECK-NEXT: }
// CHECK-NEXT: aie.device(npu1) {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : i32
// CHECK-NEXT: }
// CHECK-NEXT: aie.device(npu1_1col) {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : i32
// CHECK-NEXT: }
// CHECK-NEXT: aie.device(npu1_2col) {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : i32
// CHECK-NEXT: }
// CHECK-NEXT: aie.device(npu1_3col) {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : i32
// CHECK-NEXT: }
// CHECK-NEXT: aie.device(npu1_4col) {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : i32
// CHECK-NEXT: }
// CHECK-NEXT: aie.device(npu2) {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : i32
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK-GENERIC: "builtin.module"() ({
// CHECK-GENERIC-NEXT: "aie.device"() <{{{["]?}}device{{["]?}} = 1 : i32}> ({
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{{{["]?}}value{{["]?}} = 1 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: "aie.end"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: "aie.device"() <{{{["]?}}device{{["]?}} = 2 : i32}> ({
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{{{["]?}}value{{["]?}} = 1 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: "aie.end"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: "aie.device"() <{{{["]?}}device{{["]?}} = 3 : i32}> ({
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{{{["]?}}value{{["]?}} = 1 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: "aie.end"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: "aie.device"() <{{{["]?}}device{{["]?}} = 4 : i32}> ({
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{{{["]?}}value{{["]?}} = 1 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: "aie.end"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: "aie.device"() <{{{["]?}}device{{["]?}} = 5 : i32}> ({
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{{{["]?}}value{{["]?}} = 1 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: "aie.end"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: "aie.device"() <{{{["]?}}device{{["]?}} = 6 : i32}> ({
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{{{["]?}}value{{["]?}} = 1 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: "aie.end"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: "aie.device"() <{{{["]?}}device{{["]?}} = 7 : i32}> ({
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{{{["]?}}value{{["]?}} = 1 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: "aie.end"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: "aie.device"() <{{{["]?}}device{{["]?}} = 8 : i32}> ({
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{{{["]?}}value{{["]?}} = 1 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: "aie.end"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: "aie.device"() <{{{["]?}}device{{["]?}} = 9 : i32}> ({
// CHECK-GENERIC-NEXT: %{{.*}} = "arith.constant"() <{{{["]?}}value{{["]?}} = 1 : i32}> : () -> i32
// CHECK-GENERIC-NEXT: "aie.end"() : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
// CHECK-GENERIC-NEXT: }) : () -> ()
87 changes: 53 additions & 34 deletions xdsl_aie/dialects/aie.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

from __future__ import annotations

from collections.abc import Iterable
from enum import auto
from typing import Generic
from collections.abc import Iterable, Sequence
from typing import Generic, Self

from xdsl.dialects import builtin
from xdsl.dialects.builtin import (
Expand Down Expand Up @@ -52,7 +51,7 @@
irdl_op_definition,
operand_def,
opt_attr_def,
opt_region_def,
prop_def,
region_def,
result_def,
successor_def,
Expand All @@ -64,24 +63,40 @@
from xdsl.traits import (
HasParent,
IsTerminator,
NoTerminator,
SingleBlockImplicitTerminator,
SymbolOpInterface,
SymbolTable,
ensure_terminator,
)
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa

CASCADE_SIZE = 384


class AIEDeviceEnum(StrEnum):
xcvc1902 = auto()
class IntStrEnum(StrEnum):
@classmethod
def get_sequence(cls) -> Sequence[Self]:
return tuple(x for x in cls)

@classmethod
def from_int(cls, value: int) -> Self:
return cls.get_sequence()[value - 1]

@irdl_attr_definition
class AIEDeviceAttr(EnumAttribute[AIEDeviceEnum]):
name = "aie.device_attr"
def get_int(self) -> int:
return self.get_sequence().index(self) + 1


class AIEDeviceEnum(IntStrEnum):
xcvc1902 = "xcvc1902"
xcve2302 = "xcve2302"
xcve2802 = "xcve2802"
npu1 = "npu1"
npu1_1col = "npu1_1col"
npu1_2col = "npu1_2col"
npu1_3col = "npu1_3col"
npu1_4col = "npu1_4col"
npu2 = "npu2"


class ObjectFifoPortEnum(StrEnum):
Expand Down Expand Up @@ -615,35 +630,51 @@ def __init__(self, arg: Operation | SSAValue):
super().__init__(operands=[arg])


@irdl_op_definition
class EndOp(IRDLOperation):
name = "aie.end"

def __init__(self):
super().__init__()

traits = traits_def(IsTerminator())

assembly_format = "attr-dict"


@irdl_op_definition
class DeviceOp(IRDLOperation):
name = "aie.device"

region = opt_region_def()
region = region_def("single_block")

device = attr_def(AIEDeviceAttr)
traits = traits_def(SymbolTable(), NoTerminator(), HasParent(ModuleOp))
device = prop_def(IntegerAttr[IntegerType])
traits = traits_def(
SymbolTable(), SingleBlockImplicitTerminator(EndOp), HasParent(ModuleOp)
)

def __init__(self, device: AIEDeviceAttr, region: Region):
super().__init__(attributes={"device": device}, regions=[region])
def __init__(self, device: IntegerAttr[IntegerType], region: Region):
super().__init__(properties={"device": device}, regions=[region])

def print(self, printer: Printer):
printer.print("(")
device_str = "xcvc1902" if self.device.data == AIEDeviceEnum.xcvc1902 else ""
printer.print(device_str)
printer.print(AIEDeviceEnum.from_int(self.device.value.data).value)
printer.print(") ")
if self.region is not None:
printer.print_region(self.region)
printer.print_region(self.region, print_block_terminators=False)

@classmethod
def parse(cls, parser: Parser) -> DeviceOp:
parser.parse_characters("(")

device = AIEDeviceAttr(AIEDeviceAttr.parse_parameter(parser))
device = parser.parse_str_enum(AIEDeviceEnum)
parser.parse_characters(")")
region = parser.parse_region()

return DeviceOp(device, region)
device_op = cls(IntegerAttr(device.get_int(), 32), region)

for trait in device_op.get_traits_of_type(SingleBlockImplicitTerminator):
ensure_terminator(device_op, trait)

return device_op


@irdl_op_definition
Expand Down Expand Up @@ -1164,18 +1195,6 @@ def __init__(self, col: IntegerAttr[IntegerType]):
super().__init__(attributes={"col": col}, result_types=[IndexType()])


@irdl_op_definition
class EndOp(IRDLOperation):
name = "aie.end"

def __init__(self):
super().__init__()

traits = traits_def(IsTerminator())

assembly_format = "attr-dict"


@irdl_op_definition
class PacketFlowOp(IRDLOperation):
name = "aie.packet_flow"
Expand Down

0 comments on commit ca53e95

Please sign in to comment.