Skip to content

Commit b794409

Browse files
authored
[Primitive] Extensible primitives and transformations in Python (cornell-zhang#504)
* [Primitive] Refactor * Decouple primitives * Remove redundant code * Fix pylint * Move context outside * Add user-defined primitives * Fix typo * Add create_buffer * Fix format * Fix linting * Update transform
1 parent 0e690d1 commit b794409

24 files changed

+1067
-429
lines changed

heterocl/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
# SPDX-License-Identifier: Apache-2.0
33
# pylint: disable=redefined-builtin
44

5-
from .schedule import Schedule, customize, create_schedule, Partition
5+
from .schedule import Schedule, customize, create_schedule
6+
from .primitives.base import Primitive, register_primitive
7+
from .primitives.partition import Partition
68
from .scheme import Scheme, create_scheme, create_schedule_from_scheme
79
from .build_module import lower, build
810
from .operation import *

heterocl/ast/ir_builder.py

+1-175
Original file line numberDiff line numberDiff line change
@@ -362,40 +362,10 @@ def build_visitor(self, op, ip):
362362
self.build_op_handle(op, ip)
363363
elif isinstance(op, ast.LoopHandle):
364364
self.build_loop_handle(op, ip)
365-
elif isinstance(op, ast.ReuseAtOp):
366-
self.build_reuse_at_op(op, ip)
367-
elif isinstance(op, ast.PartitionOp):
368-
self.build_partition_op(op, ip)
369-
elif isinstance(op, ast.ReplaceOp):
370-
self.build_replace_op(op, ip)
371-
elif isinstance(op, ast.ReshapeOp):
372-
self.build_reshape_op(op, ip)
373-
elif isinstance(op, ast.ReformOp):
374-
self.build_reform_op(op, ip)
375-
elif isinstance(op, ast.BufferAtOp):
376-
self.build_buffer_at_op(op, ip)
377365
elif isinstance(op, ast.InterKernelToOp):
378366
self.build_inter_kernel_to_op(op, ip)
379367
elif isinstance(op, ast.OutlineOp):
380368
self.build_outline_op(op, ip)
381-
elif isinstance(op, ast.ReorderOp):
382-
self.build_reorder_op(op, ip)
383-
elif isinstance(op, ast.SplitOp):
384-
self.build_split_op(op, ip)
385-
elif isinstance(op, ast.TileOp):
386-
self.build_tile_op(op, ip)
387-
elif isinstance(op, ast.PipelineOp):
388-
self.build_pipeline_op(op, ip)
389-
elif isinstance(op, ast.UnrollOp):
390-
self.build_unroll_op(op, ip)
391-
elif isinstance(op, ast.ParallelOp):
392-
self.build_parallel_op(op, ip)
393-
elif isinstance(op, ast.FuseOp):
394-
self.build_fuse_op(op, ip)
395-
elif isinstance(op, ast.ComputeAtOp):
396-
self.build_compute_at_op(op, ip)
397-
elif isinstance(op, ast.SystolicOp):
398-
self.build_systolic_op(op, ip)
399369
else:
400370
raise HCLNotImplementedError(
401371
f"{type(op)}'s build visitor is not implemented yet."
@@ -469,6 +439,7 @@ def build_func_op(self, op: ast.FuncOp, ip):
469439
# as the same argument object may be refered in multiple functions
470440
# we need to make sure that the result is not reused
471441
for arg in op.args:
442+
arg.prev_result = arg.result
472443
arg.result = None
473444

474445
def build_call_op(self, op: ast.CallOp, ip):
@@ -1610,75 +1581,6 @@ def build_loop_handle(self, op: ast.LoopHandle, ip):
16101581
op.ir_op = hdl_op
16111582
op.result = hdl_op.result
16121583

1613-
def build_partition_op(self, op: ast.PartitionOp, ip):
1614-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1615-
i32 = IntegerType.get_signless(32)
1616-
ui32 = IntegerType.get_unsigned(32)
1617-
partition_type = IntegerAttr.get(i32, op.kind)
1618-
dim = IntegerAttr.get(ui32, op.dim)
1619-
factor = IntegerAttr.get(ui32, op.factor)
1620-
self.build_visitor(op.tensor, ip)
1621-
partition_op = hcl_d.PartitionOp(
1622-
op.tensor.result,
1623-
partition_kind=partition_type,
1624-
dim=dim,
1625-
factor=factor,
1626-
ip=ip,
1627-
loc=loc,
1628-
)
1629-
op.ir_op = partition_op
1630-
1631-
def build_replace_op(self, op: ast.ReplaceOp, ip):
1632-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1633-
self.build_visitor(op.target, ip)
1634-
self.build_visitor(op.src, ip)
1635-
replace_op = hcl_d.ReplaceOp(op.target.result, op.src.result, ip=ip, loc=loc)
1636-
op.ir_op = replace_op
1637-
1638-
def build_reshape_op(self, op: ast.ReshapeOp, ip):
1639-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1640-
self.build_visitor(op.tensor, ip)
1641-
eletype = hcl_dtype_to_mlir(op.tensor.dtype)
1642-
memref_type = MemRefType.get(op.shape, eletype, loc=loc)
1643-
reshape_op = hcl_d.ReshapeOp(memref_type, op.tensor.result, ip=ip, loc=loc)
1644-
op.ir_op = reshape_op
1645-
1646-
def build_reform_op(self, op: ast.ReformOp, ip):
1647-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1648-
self.build_visitor(op.target, ip)
1649-
if op.layout == "nhwc":
1650-
attr = AffineMap.get_permutation([0, 2, 3, 1])
1651-
else:
1652-
raise RuntimeError("Not supported layout")
1653-
memref_type = MemRefType.get(op.target.shape, op.target.ir_op.dtype)
1654-
reform_op = hcl_d.ReformOp(memref_type, op.target.result, ip=ip, loc=loc)
1655-
reform_op.attributes["layout"] = AffineMapAttr.get(attr)
1656-
op.ir_op = reform_op
1657-
1658-
def build_reuse_at_op(self, op: ast.ReuseAtOp, ip):
1659-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1660-
self.build_visitor(op.target, ip)
1661-
self.build_visitor(op.axis, ip)
1662-
f32 = F32Type.get()
1663-
memref_type = MemRefType.get((1,), f32, loc=loc)
1664-
reuse_at_op = hcl_d.ReuseAtOp(
1665-
memref_type, op.target.result, op.axis.result, ip=ip, loc=loc
1666-
)
1667-
op.ir_op = reuse_at_op
1668-
op.result = reuse_at_op.result
1669-
1670-
def build_buffer_at_op(self, op: ast.BufferAtOp, ip):
1671-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1672-
self.build_visitor(op.target, ip)
1673-
self.build_visitor(op.axis, ip)
1674-
f32 = F32Type.get()
1675-
memref_type = MemRefType.get((1,), f32, loc=loc)
1676-
buffer_at_op = hcl_d.BufferAtOp(
1677-
memref_type, op.target.result, op.axis.result, ip=ip, loc=loc
1678-
)
1679-
op.ir_op = buffer_at_op
1680-
op.result = buffer_at_op.result
1681-
16821584
def build_inter_kernel_to_op(self, op: ast.InterKernelToOp, ip):
16831585
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
16841586
self.build_visitor(op.tensor, ip)
@@ -1704,79 +1606,3 @@ def build_outline_op(self, op: ast.OutlineOp, ip):
17041606
if op.axis is not None:
17051607
outline_op.attributes["axis"] = StringAttr.get(op.axis)
17061608
op.ir_op = outline_op
1707-
1708-
def build_reorder_op(self, op: ast.ReorderOp, ip):
1709-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1710-
for arg in op.args:
1711-
self.build_visitor(arg, ip)
1712-
arg_results = [arg.result for arg in op.args]
1713-
reorder_op = hcl_d.ReorderOp(arg_results, ip=ip, loc=loc)
1714-
op.ir_op = reorder_op
1715-
1716-
def build_split_op(self, op: ast.SplitOp, ip):
1717-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1718-
self.build_visitor(op.parent, ip)
1719-
i32 = IntegerType.get_unsigned(32)
1720-
factor = IntegerAttr.get(i32, op.factor)
1721-
split_op = hcl_d.SplitOp(op.parent.result, factor, ip=ip, loc=loc)
1722-
op.ir_op = split_op
1723-
for result_loop_hdl, hdl_result in zip(op.results, split_op.results):
1724-
result_loop_hdl.result = hdl_result
1725-
1726-
def build_tile_op(self, op: ast.TileOp, ip):
1727-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1728-
i32 = IntegerType.get_unsigned(32)
1729-
x_factor = IntegerAttr.get(i32, op.x_factor)
1730-
y_factor = IntegerAttr.get(i32, op.y_factor)
1731-
self.build_visitor(op.x_parent, ip)
1732-
self.build_visitor(op.y_parent, ip)
1733-
tile_op = hcl_d.TileOp(
1734-
op.x_parent.result, op.y_parent.result, x_factor, y_factor, ip=ip, loc=loc
1735-
)
1736-
op.ir_op = tile_op
1737-
for result_loop_hdl, hdl_result in zip(op.results, tile_op.results):
1738-
result_loop_hdl.result = hdl_result
1739-
1740-
def build_pipeline_op(self, op: ast.PipelineOp, ip):
1741-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1742-
self.build_visitor(op.target, ip)
1743-
i32 = IntegerType.get_unsigned(32)
1744-
ii = IntegerAttr.get(i32, op.ii)
1745-
pipeline_op = hcl_d.PipelineOp(op.target.result, ii=ii, ip=ip, loc=loc)
1746-
op.ir_op = pipeline_op
1747-
1748-
def build_unroll_op(self, op: ast.UnrollOp, ip):
1749-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1750-
self.build_visitor(op.target, ip)
1751-
i32 = IntegerType.get_unsigned(32)
1752-
factor = IntegerAttr.get(i32, op.factor)
1753-
unroll_op = hcl_d.UnrollOp(op.target.result, factor=factor, ip=ip, loc=loc)
1754-
op.ir_op = unroll_op
1755-
1756-
def build_parallel_op(self, op: ast.ParallelOp, ip):
1757-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1758-
self.build_visitor(op.target, ip)
1759-
parallel_op = hcl_d.ParallelOp(op.target.result, ip=ip, loc=loc)
1760-
op.ir_op = parallel_op
1761-
1762-
def build_fuse_op(self, op: ast.FuseOp, ip):
1763-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1764-
for arg in op.arg_list:
1765-
self.build_visitor(arg, ip)
1766-
arg_results = [arg.result for arg in op.arg_list]
1767-
fuse_op = hcl_d.FuseOp(arg_results, ip=ip, loc=loc)
1768-
op.ir_op = fuse_op
1769-
op.result = fuse_op.result
1770-
1771-
def build_compute_at_op(self, op: ast.ComputeAtOp, ip):
1772-
loc = Location.file(op.loc.filename, op.loc.lineno, 0)
1773-
self.build_visitor(op.stage, ip)
1774-
self.build_visitor(op.parent, ip)
1775-
self.build_visitor(op.axis, ip)
1776-
compute_at_op = hcl_d.ComputeAtOp(
1777-
op.stage.result, op.parent.result, op.axis.result, ip=ip, loc=loc
1778-
)
1779-
op.ir_op = compute_at_op
1780-
1781-
def build_systolic_op(self, op: ast.SystolicOp, ip):
1782-
op.target.ir_op.attributes["systolic"] = UnitAttr.get()

heterocl/build_module.py

+1-23
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
from .runtime import copy_build_files
2525
from .schedule import Schedule
2626
from .utils import hcl_dtype_to_mlir
27-
from .passes.pass_manager import PassManager as ast_pass_manager
28-
from .passes.nest_if import NestElseIf
29-
from .passes.promote_func import PromoteFunc
3027
from .ast.ir_builder import IRBuilder
3128
from .ast.build_cleaner import ASTCleaner
3229
from .ast import ast
@@ -51,26 +48,7 @@ def lower(
5148
"""Lowering step before build into target
5249
by applying optimization pass
5350
"""
54-
if schedule.is_lowered():
55-
raise APIError(
56-
"The module has been lowered. Please apply schedule primitives before the lowering process."
57-
)
58-
# HeteroCL Transformation Pipeline
59-
ast_pm = ast_pass_manager()
60-
ast_pm.add_pass(NestElseIf)
61-
ast_pm.add_pass(PromoteFunc)
62-
device_agnostic_ast = ast_pm.run(schedule.ast)
63-
schedule._ast = device_agnostic_ast
64-
65-
# Build MLIR IR
66-
set_context()
67-
agnostic_ir_builder = IRBuilder(device_agnostic_ast)
68-
agnostic_ir_builder.build()
69-
agnostic_module = agnostic_ir_builder.module
70-
schedule._module = _mlir_lower_pipeline(agnostic_module)
71-
schedule._top_func = agnostic_ir_builder.top_func
72-
exit_context()
73-
51+
schedule._module = _mlir_lower_pipeline(schedule._module)
7452
schedule.set_lowered()
7553
return schedule.module
7654

heterocl/ir/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright HeteroCL authors. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0

heterocl/ir/transform.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright HeteroCL authors. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
# pylint: disable=no-name-in-module, no-value-for-parameter
4+
5+
import hcl_mlir
6+
from hcl_mlir import UnitAttr, StringAttr, InsertionPoint, MemRefType
7+
from hcl_mlir.dialects import memref as memref_d
8+
9+
10+
def get_affine_loop_nests(func):
11+
loops = hcl_mlir.get_affine_loop_nests(func)
12+
res = []
13+
for loop in loops:
14+
res.append([(item["name"], item["body"]) for item in loop])
15+
return res
16+
17+
18+
def annotate(op, name):
19+
op.attributes[name] = UnitAttr.get()
20+
21+
22+
def build_for_loops(grid, ip, name="loop"):
23+
for_loops = []
24+
if isinstance(name, str):
25+
names = [name + f"_l_{i}" for i in range(len(grid))]
26+
stage_name = "S_" + name
27+
else: # list
28+
names = name
29+
stage_name = "S_" + "_".join(names)
30+
assert len(grid) >= 1
31+
32+
def recursive_for(for_handle, idx):
33+
if idx == len(grid):
34+
return
35+
with InsertionPoint(for_handle.body.operations[0]):
36+
new_for = hcl_mlir.make_for(0, grid[idx], name=names[idx])
37+
for_loops.append(new_for)
38+
recursive_for(new_for, idx + 1)
39+
40+
if not isinstance(ip, InsertionPoint):
41+
ip = InsertionPoint(ip)
42+
with ip:
43+
for_handle = hcl_mlir.make_for(0, grid[0], name=names[0], stage=stage_name)
44+
for_loops.append(for_handle)
45+
recursive_for(for_handle, 1)
46+
return for_loops
47+
48+
49+
def create_buffer(tensor, name, ip):
50+
with InsertionPoint(ip):
51+
alloc_op = memref_d.AllocOp(tensor.type, [], [])
52+
alloc_op.attributes["name"] = StringAttr.get(name)
53+
shape = MemRefType(tensor.type).shape
54+
for_loops = build_for_loops(shape, ip, name)
55+
induction_vars = [for_loop.induction_variable for for_loop in for_loops]
56+
with InsertionPoint(for_loops[-1].body.operations[0]):
57+
load = memref_d.LoadOp(tensor, induction_vars)
58+
memref_d.StoreOp(
59+
load.result,
60+
alloc_op.result,
61+
induction_vars,
62+
)
63+
# TODO: Upgrade LLVM version and use the following code
64+
# tensor.replace_all_uses_with(alloc_op.result)

heterocl/primitives/__init__.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright HeteroCL authors. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Schedule primitives."""
4+
5+
import os
6+
from os.path import abspath, dirname, join, isfile
7+
from inspect import getsourcefile, getmembers
8+
from importlib import import_module
9+
10+
from .base import register_primitive, PRIMITIVES, Primitive
11+
12+
path = dirname(abspath(getsourcefile(lambda: 0)))
13+
files = [
14+
f
15+
for f in os.listdir(path)
16+
if isfile(join(path, f)) and f not in {"__init__.py", "base.py"}
17+
]
18+
for file in files:
19+
mod = import_module(f".{file.split('.')[0]}", package="heterocl.primitives")
20+
# register the schedule primitive using the decorator
21+
getmembers(mod)

heterocl/primitives/base.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright HeteroCL authors. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Schedule primitive base."""
4+
# pylint: disable=no-method-argument
5+
6+
from __future__ import annotations
7+
from abc import ABCMeta, abstractmethod
8+
9+
PRIMITIVES = {}
10+
STAGE_PRIMITIVES = {}
11+
12+
13+
def register_primitive():
14+
"""Register a primitive to the schedule."""
15+
16+
def dectorator(cls):
17+
if cls.name in PRIMITIVES:
18+
raise ValueError(f"Primitive {cls.name} already registered")
19+
if not issubclass(cls, Primitive):
20+
raise ValueError(f"Class {cls} is not a subclass of Primitive")
21+
if hasattr(cls, "is_stage_primitive") and cls.is_stage_primitive:
22+
STAGE_PRIMITIVES[cls.name] = cls
23+
else:
24+
PRIMITIVES[cls.name] = cls
25+
return cls
26+
27+
return dectorator
28+
29+
30+
class Primitive(metaclass=ABCMeta):
31+
"""A base class of schedule primitives."""
32+
33+
@property
34+
@abstractmethod
35+
def name():
36+
"""The name of the primitive."""
37+
raise NotImplementedError
38+
39+
@staticmethod
40+
@abstractmethod
41+
def apply(sch, *args, **kwargs):
42+
"""Apply the primitive to the schedule."""
43+
raise NotImplementedError

0 commit comments

Comments
 (0)