-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add template helpers to the op_reg facility. (#10)
* This support was ported and generalized from sharktank, where the technique originated. * Support ops defined in terms of either a str-format or jinja2 based ASM template. * Helpers for inlining and calling. * Example ops used for testing. * Adds dependency on Jinja2. Signed-off-by: Stella Laurenzo <[email protected]>
- Loading branch information
1 parent
63a2411
commit 31d4378
Showing
14 changed files
with
380 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Build/test requirements. | ||
Jinja2==3.1.3 | ||
numpy==1.26.3 | ||
pytest==8.0.0 | ||
pytest-xdist==3.5.0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from ..support.ir_imports import ( | ||
RankedTensorType, | ||
) | ||
|
||
from ..runtime.op_reg import ( | ||
CustomOp, | ||
KernelBuilder, | ||
KernelSelection, | ||
def_library, | ||
impl_helper, | ||
) | ||
|
||
__all__ = [ | ||
"trace", | ||
] | ||
|
||
LIBRARY = def_library("_turbine_jinja_test") | ||
_templates = impl_helper.JinjaTemplateLoader(__name__) | ||
|
||
|
||
@CustomOp.register(library=LIBRARY) | ||
class test_add(CustomOp): | ||
signature = "test_add(Tensor t1, Tensor t2) -> (Tensor)" | ||
|
||
def select(self, ksel: KernelSelection): | ||
t1_desc = ksel.arg_tensor(0) | ||
t1_desc.specialize_all_dims() | ||
t2_desc = ksel.arg_tensor(1) | ||
t2_desc.specialize_all_dims() | ||
result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype) | ||
result_desc.specialize_all_dims() | ||
|
||
def generate(self, ksel: KernelSelection, kb: KernelBuilder): | ||
result_type = kb.arg_bindings[0].type # type: ignore | ||
rtt = RankedTensorType(result_type) | ||
function_name = f"turbine_test_add_jinja_{rtt.rank}d_{str(rtt.element_type)}" | ||
func_op = _templates.inline_template_function( | ||
kb, | ||
"test_add_jinja", | ||
function_name, | ||
rank=rtt.rank, | ||
element_type=str(rtt.element_type), | ||
tensor_type=str(rtt), | ||
) | ||
kb.yield_results(*impl_helper.call_function(func_op, *kb.arg_bindings)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from ..support.ir_imports import ( | ||
RankedTensorType, | ||
) | ||
|
||
from ..runtime.op_reg import ( | ||
CustomOp, | ||
KernelBuilder, | ||
KernelSelection, | ||
def_library, | ||
impl_helper, | ||
) | ||
|
||
__all__ = [ | ||
"trace", | ||
] | ||
|
||
LIBRARY = def_library("_turbine_str_format_test") | ||
_templates = impl_helper.StrFormatTemplateLoader(__name__) | ||
|
||
|
||
@CustomOp.register(library=LIBRARY) | ||
class test_add(CustomOp): | ||
signature = "test_add(Tensor t1, Tensor t2) -> (Tensor)" | ||
|
||
def select(self, ksel: KernelSelection): | ||
t1_desc = ksel.arg_tensor(0) | ||
t1_desc.specialize_all_dims() | ||
t2_desc = ksel.arg_tensor(1) | ||
t2_desc.specialize_all_dims() | ||
result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype) | ||
result_desc.specialize_all_dims() | ||
|
||
def generate(self, ksel: KernelSelection, kb: KernelBuilder): | ||
result_type = kb.arg_bindings[0].type # type: ignore | ||
rtt = RankedTensorType(result_type) | ||
function_name = ( | ||
f"turbine_test_add_strformat_{rtt.rank}d_{str(rtt.element_type)}" | ||
) | ||
func_op = _templates.inline_template_function( | ||
kb, | ||
"test_add_strformat", | ||
function_name, | ||
rank=rtt.rank, | ||
element_type=str(rtt.element_type), | ||
tensor_type=str(rtt), | ||
) | ||
kb.yield_results(*impl_helper.call_function(func_op, *kb.arg_bindings)) | ||
|
||
|
||
@CustomOp.register(library=LIBRARY) | ||
class syntax_error(CustomOp): | ||
signature = "syntax_error(Tensor t1) -> (Tensor)" | ||
|
||
def select(self, ksel: KernelSelection): | ||
t1_desc = ksel.arg_tensor(0) | ||
t1_desc.specialize_all_dims() | ||
result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype) | ||
result_desc.specialize_all_dims() | ||
|
||
def generate(self, ksel: KernelSelection, kb: KernelBuilder): | ||
function_name = "syntax_error" | ||
func_op = _templates.inline_template_function( | ||
kb, | ||
"test_syntax_error", | ||
function_name, | ||
) | ||
kb.yield_results(*impl_helper.call_function(func_op, *kb.arg_bindings)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
!tensor_type = {{tensor_type}} | ||
|
||
module { | ||
|
||
util.func private @turbine_test_add_jinja_{{rank}}d_{{element_type}}( | ||
%a: !tensor_type, %b: !tensor_type | ||
) -> !tensor_type { | ||
%out = tensor.empty() : !tensor_type | ||
%0 = linalg.add ins(%a, %b : !tensor_type, !tensor_type) outs(%out : !tensor_type) -> !tensor_type | ||
util.return %0 : !tensor_type | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
!tensor_type = {tensor_type} | ||
|
||
module {{ | ||
|
||
util.func private @turbine_test_add_strformat_{rank}d_{element_type}( | ||
%a: !tensor_type, %b: !tensor_type | ||
) -> !tensor_type {{ | ||
%out = tensor.empty() : !tensor_type | ||
%0 = linalg.add ins(%a, %b : !tensor_type, !tensor_type) outs(%out : !tensor_type) -> !tensor_type | ||
util.return %0 : !tensor_type | ||
}} | ||
}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
THIS IS A SYNTAX ERROR |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ | |
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from .base import * | ||
from . import impl_helper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
"""Helpers for implementing ops. | ||
Typical usage: | ||
``` | ||
_templates = JinjaTemplateLoader(__name__) | ||
def generate(kb: KernelBuilder): | ||
func_op = _templates.inline_template_function( | ||
kb, "my_template", "function_name", **kwargs) | ||
return call_function(func_op, *values) | ||
``` | ||
""" | ||
|
||
from typing import Sequence | ||
|
||
from abc import ABC, abstractmethod | ||
import logging | ||
import textwrap | ||
|
||
from ...support.logging import runtime_logger as logger | ||
|
||
from ...support.ir_imports import ( | ||
FlatSymbolRefAttr, | ||
FunctionType, | ||
MLIRError, | ||
Operation, | ||
StringAttr, | ||
TypeAttr, | ||
Value, | ||
) | ||
|
||
from ...transforms.merger import Merger | ||
|
||
from .base import ( | ||
KernelBuilder, | ||
) | ||
|
||
|
||
__all__ = [ | ||
"TemplateLoader", | ||
"StrFormatTemplateLoader", | ||
"call_function", | ||
] | ||
|
||
|
||
class TemplateLoader(ABC): | ||
"""Base class for templates that can be loaded by name.""" | ||
|
||
@abstractmethod | ||
def load_template(self, kb: KernelBuilder, name: str, **kwargs) -> Operation: | ||
"""Loads a template by name and kwargs, returning the module operation.""" | ||
... | ||
|
||
def _parse_module_asm(self, kb: KernelBuilder, asm: str) -> Operation: | ||
try: | ||
module_op = Operation.parse(asm, context=kb.context) | ||
except MLIRError as e: | ||
lines = asm.splitlines() | ||
lines_numbered = "\n".join( | ||
[f" {str(i+1):>5}: {l}" for i, l in enumerate(lines)] | ||
) | ||
raise RuntimeError( | ||
f"Error parsing generated op template:" | ||
f"\n{textwrap.indent(str(e), ' ')}" | ||
f"\n{lines_numbered}" | ||
) | ||
return module_op.operation | ||
|
||
def inline_template_function( | ||
self, | ||
kb: KernelBuilder, | ||
template_file: str, | ||
function_name: str, | ||
**kwargs, | ||
) -> Operation: | ||
"""Inlines a template module by first expanding its ASM via **kwargs. | ||
Returns the inlined symbol `function_name`, which is expected to have been | ||
in the template. | ||
""" | ||
try: | ||
return kb.symbol_table[function_name] | ||
except KeyError: | ||
pass | ||
source_module_op = self.load_template(kb, template_file, **kwargs) | ||
if logger.isEnabledFor(logging.DEBUG): | ||
logger.debug( | ||
"Generated kernel IR %s:\n%s", function_name, str(source_module_op) | ||
) | ||
merger = Merger( | ||
source_module_op, kb.module_body.owner, target_symbol_table=kb.symbol_table | ||
) | ||
merger.merge() | ||
return kb.symbol_table[function_name] | ||
|
||
|
||
class StrFormatTemplateLoader(TemplateLoader): | ||
"""Template loader that uses str.format. | ||
Usage: | ||
_templates = StrFromatTemplateLoader(__name__) | ||
By default, this will resolve a template like "foo" from templates/foo.mlir | ||
in the package directory. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
package_name: str, | ||
package_path: str = "templates", | ||
*, | ||
suffix: str = ".mlir", | ||
): | ||
self.parent_package_name = ".".join(package_name.split(".")[0:-1]) | ||
self.package_path = package_path | ||
self.suffix = suffix | ||
|
||
def load_template(self, kb: KernelBuilder, name: str, **kwargs) -> Operation: | ||
from importlib import resources | ||
|
||
res = ( | ||
resources.files(self.parent_package_name) | ||
/ self.package_path | ||
/ f"{name}{self.suffix}" | ||
) | ||
contents = res.read_text().format(**kwargs) | ||
return self._parse_module_asm(kb, contents) | ||
|
||
|
||
class JinjaTemplateLoader(TemplateLoader): | ||
"""Template loader based on jinja templates. | ||
Usage: | ||
_templates = JinjaTemplateLoader(__name__) | ||
By default, this will resolve a template like "foo" from templates/foo.mlir | ||
in the package directory. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
package_name: str, | ||
package_path: str = "templates", | ||
*, | ||
suffix: str = ".mlir", | ||
): | ||
try: | ||
from jinja2 import Environment, PackageLoader, select_autoescape | ||
except ModuleNotFoundError as e: | ||
raise ModuleNotFoundError( | ||
"Cannot use JinjaTemplateLoader if jinja2 is not installed" | ||
) from e | ||
self.env = Environment(loader=PackageLoader(package_name, package_path)) | ||
self.suffix = suffix | ||
|
||
def load_template(self, kb: KernelBuilder, name: str, **kwargs) -> Operation: | ||
template_file = f"{name}{self.suffix}" | ||
contents = self.env.get_template(template_file).render(**kwargs) | ||
return self._parse_module_asm(kb, contents) | ||
|
||
|
||
def call_function(target_function: Operation, *operands: Value) -> Sequence[Value]: | ||
"""Emits a util.call for a util.func target function operation.""" | ||
target_symbol = FlatSymbolRefAttr.get( | ||
StringAttr(target_function.attributes["sym_name"]).value_bytes | ||
) | ||
ftype = FunctionType(TypeAttr(target_function.attributes["function_type"]).value) | ||
return Operation.create( | ||
"util.call", | ||
results=ftype.results, | ||
operands=operands, | ||
attributes={ | ||
"callee": target_symbol, | ||
}, | ||
).results |
Oops, something went wrong.