Skip to content

Commit

Permalink
add aiex dialect with aiex.runtime_sequence operation
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendumoulin committed Jan 8, 2025
1 parent 35db9e3 commit 8a4d56c
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/filecheck/dialects/aiex/ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: XDSL_ROUNDTRIP
// RUN: XDSL_GENERIC_ROUNDTRIP
// RUN: AIE_ROUNDTRIP
// RUN: AIE_GENERIC_ROUNDTRIP



aie.device(npu1_1col) {
// CHECK: aie.device
// CHECK-GENERIC: "aie.device"

aiex.runtime_sequence () { }

// CHECK-NEXT: aiex.runtime_sequence() {
// CHECK-NEXT: }

// CHECK-GENERIC-NEXT: "aiex.runtime_sequence"() ({
// CHECK-GENERIC-NEXT: }) : () -> ()

aiex.runtime_sequence(%0 : memref<16xi8>, %1 : memref<16xi8>) { }

// CHECK-NEXT: aiex.runtime_sequence(%{{.*}}: memref<16xi8>, %{{.*}}: memref<16xi8>) {
// CHECK-NEXT: }

// CHECK-GENERIC-NEXT: "aiex.runtime_sequence"() ({
// CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}}: memref<16xi8>, %{{.*}}: memref<16xi8>):
// CHECK-GENERIC-NEXT: }) : () -> ()

aiex.runtime_sequence @testabc(%2 : memref<16xi8>, %3 : memref<16xi8>) { }

// CHECK-NEXT: aiex.runtime_sequence @testabc(%{{.*}}: memref<16xi8>, %{{.*}}: memref<16xi8>) {
// CHECK-NEXT: }

// CHECK-GENERIC-NEXT: "aiex.runtime_sequence"() <{sym_name = "testabc"}> ({
// CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}}: memref<16xi8>, %{{.*}}: memref<16xi8>):
// CHECK-GENERIC-NEXT: }) : () -> ()

}
6 changes: 6 additions & 0 deletions xdsl_aie/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ def get_aie():

return AIE

def get_aiex():
from xdsl_aie.dialects.aiex import AIEX

return AIEX

def get_arith():
from xdsl.dialects.arith import Arith

Expand Down Expand Up @@ -94,6 +99,7 @@ def get_vector():
return {
"affine": get_affine,
"aie": get_aie,
"aiex": get_aiex,
"arith": get_arith,
"bufferization": get_bufferization,
"builtin": get_builtin,
Expand Down
69 changes: 69 additions & 0 deletions xdsl_aie/dialects/aiex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import Self

from xdsl.dialects.builtin import StringAttr
from xdsl.ir import Dialect, Region
from xdsl.irdl import (
IRDLOperation,
irdl_op_definition,
opt_prop_def,
region_def,
traits_def,
)
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.traits import HasParent, NoTerminator

from xdsl_aie.dialects.aie import DeviceOp


@irdl_op_definition
class RuntimeSequenceOp(IRDLOperation):
name = "aiex.runtime_sequence"

sym_name = opt_prop_def(StringAttr)

body = region_def()

traits = traits_def(HasParent(DeviceOp), NoTerminator())

def __init__(self, body: Region, name: StringAttr | str | None = None):
if isinstance(name, str):
name = StringAttr(name)

super().__init__(properties={"sym_name": name}, regions=[body])

def print(self, printer: Printer):
if self.sym_name:
printer.print_string(" @" + self.sym_name.data)
printer.print_string("(")
if self.body.blocks:
printer.print_list(self.body.blocks[0].args, printer.print_block_argument)
printer.print_string(") ")
printer.print_region(
self.body, print_entry_block_args=False, print_empty_block=False
)

@classmethod
def parse(cls, parser: Parser) -> Self:
name = parser.parse_optional_symbol_name()
parser.parse_characters("(")
args: list[Parser.Argument] | None = []
while True:
if arg := parser.parse_optional_argument():
args.append(arg)
if not parser.parse_optional_punctuation(","):
break
parser.parse_characters(")")
if not len(args):
args = None
region = parser.parse_region(args)
return cls(body=region, name=name)


AIEX = Dialect(
"aiex",
[
RuntimeSequenceOp,
],
[],
)

0 comments on commit 8a4d56c

Please sign in to comment.