Add template helpers to the op_reg facility. (#10)
* This support was ported and generalized from sharktank, where the
  technique originated.
* Support ops defined in terms of either a str.format or Jinja2 based
  ASM template.
* Helpers for inlining and calling.
* Example ops used for testing; a brief usage sketch follows below.
* Adds a dependency on Jinja2.

Signed-off-by: Stella Laurenzo <[email protected]>
stellaraccident authored May 2, 2024
1 parent 63a2411 commit 31d4378
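Once either test module below is imported, its op registers under the named torch library and can be invoked eagerly. A minimal sketch, assuming the standard `torch.ops.<library>.<op>` dispatch path that `def_library` and `CustomOp.register` set up:

```
# Usage sketch (assumption: CustomOp registration exposes the op via
# torch.ops.<library_name>, per the def_library("_turbine_jinja_test")
# call in _jinja_test_ops.py below).
import torch

from shark_turbine.ops import _jinja_test_ops  # noqa: F401 (registers test_add)

a = torch.rand(3, 4)
b = torch.rand(3, 4)
result = torch.ops._turbine_jinja_test.test_add(a, b)
torch.testing.assert_close(result, a + b)
```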
Showing 14 changed files with 380 additions and 30 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
# Build/test requirements.
Jinja2==3.1.3
numpy==1.26.3
pytest==8.0.0
pytest-xdist==3.5.0
1 change: 1 addition & 0 deletions setup.py
@@ -104,6 +104,7 @@ def initialize_options(self):
f"iree-compiler{get_version_spec('iree-compiler')}",
f"iree-runtime{get_version_spec('iree-runtime')}",
"torch>=2.3.0",
f"Jinja2{get_version_spec('Jinja2')}",
],
extras_require={
"testing": [
1 change: 1 addition & 0 deletions shark_turbine/kernel/compiler/vector_codegen.py
@@ -4,6 +4,7 @@
actual loads/stores/computes to local vectors using PyTorch tensor
level operations executed as threads over a grid.
"""

from typing import Any, Callable, Type, Optional, Sequence, Union, List
import types

51 changes: 51 additions & 0 deletions shark_turbine/ops/_jinja_test_ops.py
@@ -0,0 +1,51 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..support.ir_imports import (
    RankedTensorType,
)

from ..runtime.op_reg import (
    CustomOp,
    KernelBuilder,
    KernelSelection,
    def_library,
    impl_helper,
)

__all__ = [
    "test_add",
]

LIBRARY = def_library("_turbine_jinja_test")
_templates = impl_helper.JinjaTemplateLoader(__name__)


@CustomOp.register(library=LIBRARY)
class test_add(CustomOp):
    signature = "test_add(Tensor t1, Tensor t2) -> (Tensor)"

    def select(self, ksel: KernelSelection):
        # Specialize on all dims so each distinct shape/dtype selects its
        # own kernel.
        t1_desc = ksel.arg_tensor(0)
        t1_desc.specialize_all_dims()
        t2_desc = ksel.arg_tensor(1)
        t2_desc.specialize_all_dims()
        result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype)
        result_desc.specialize_all_dims()

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        result_type = kb.arg_bindings[0].type  # type: ignore
        rtt = RankedTensorType(result_type)
        function_name = f"turbine_test_add_jinja_{rtt.rank}d_{str(rtt.element_type)}"
        # Expand the Jinja2 template, inline it, and call the resulting func.
        func_op = _templates.inline_template_function(
            kb,
            "test_add_jinja",
            function_name,
            rank=rtt.rank,
            element_type=str(rtt.element_type),
            tensor_type=str(rtt),
        )
        kb.yield_results(*impl_helper.call_function(func_op, *kb.arg_bindings))
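Because `select` specializes on all dims, each distinct rank/element type combination inlines its own specialized function. A small illustration of how `generate` derives the name, using the MLIR Python bindings directly (the `tensor<2x3xf32>` input is just an example):

```
# Sketch: deriving the specialized kernel name the same way generate() does.
from iree.compiler.ir import Context, RankedTensorType, Type

with Context():
    rtt = RankedTensorType(Type.parse("tensor<2x3xf32>"))
    function_name = f"turbine_test_add_jinja_{rtt.rank}d_{str(rtt.element_type)}"
    assert function_name == "turbine_test_add_jinja_2d_f32"
```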
73 changes: 73 additions & 0 deletions shark_turbine/ops/_str_format_test_ops.py
@@ -0,0 +1,73 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ..support.ir_imports import (
    RankedTensorType,
)

from ..runtime.op_reg import (
    CustomOp,
    KernelBuilder,
    KernelSelection,
    def_library,
    impl_helper,
)

__all__ = [
    "test_add",
    "syntax_error",
]

LIBRARY = def_library("_turbine_str_format_test")
_templates = impl_helper.StrFormatTemplateLoader(__name__)


@CustomOp.register(library=LIBRARY)
class test_add(CustomOp):
    signature = "test_add(Tensor t1, Tensor t2) -> (Tensor)"

    def select(self, ksel: KernelSelection):
        # Specialize on all dims so each distinct shape/dtype selects its
        # own kernel.
        t1_desc = ksel.arg_tensor(0)
        t1_desc.specialize_all_dims()
        t2_desc = ksel.arg_tensor(1)
        t2_desc.specialize_all_dims()
        result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype)
        result_desc.specialize_all_dims()

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        result_type = kb.arg_bindings[0].type  # type: ignore
        rtt = RankedTensorType(result_type)
        function_name = (
            f"turbine_test_add_strformat_{rtt.rank}d_{str(rtt.element_type)}"
        )
        # Expand the str.format template, inline it, and call the result.
        func_op = _templates.inline_template_function(
            kb,
            "test_add_strformat",
            function_name,
            rank=rtt.rank,
            element_type=str(rtt.element_type),
            tensor_type=str(rtt),
        )
        kb.yield_results(*impl_helper.call_function(func_op, *kb.arg_bindings))


@CustomOp.register(library=LIBRARY)
class syntax_error(CustomOp):
    signature = "syntax_error(Tensor t1) -> (Tensor)"

    def select(self, ksel: KernelSelection):
        t1_desc = ksel.arg_tensor(0)
        t1_desc.specialize_all_dims()
        result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype)
        result_desc.specialize_all_dims()

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        # Deliberately loads an invalid template to exercise error reporting.
        function_name = "syntax_error"
        func_op = _templates.inline_template_function(
            kb,
            "test_syntax_error",
            function_name,
        )
        kb.yield_results(*impl_helper.call_function(func_op, *kb.arg_bindings))
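The `syntax_error` op exists purely to exercise the loader's error reporting: its template is not valid MLIR, so inlining should fail with the numbered ASM dump produced by `_parse_module_asm`. A hedged sketch of what a test might assert, assuming the same eager dispatch path as above:

```
# Sketch: invoking syntax_error should surface the template-parse failure.
import pytest
import torch

from shark_turbine.ops import _str_format_test_ops  # noqa: F401

with pytest.raises(Exception, match="Error parsing generated op template"):
    torch.ops._turbine_str_format_test.syntax_error(torch.rand(4))
```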
22 changes: 0 additions & 22 deletions shark_turbine/ops/iree.py
@@ -8,7 +8,6 @@
from typing import cast

from ..support.ir_imports import (
    Operation,
    RankedTensorType,
    StringAttr,
    Value,
@@ -61,24 +60,3 @@ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        key = cast(AttrArg, ksel.arg_descs[0])
        _emit_tensor_trace(kb, cast(str, key.v), [kb.arg_bindings[1]])
        kb.yield_results(kb.arg_bindings[1])


@CustomOp.register(library=IREE_LIBRARY)
class _test_add(CustomOp):
    signature = "_test_add(Tensor t1, Tensor t2) -> (Tensor)"

    def select(self, ksel: KernelSelection):
        t1_desc = ksel.arg_tensor(0)
        t1_desc.specialize_all_dims()
        t2_desc = ksel.arg_tensor(1)
        t2_desc.specialize_all_dims()
        result_desc = ksel.return_new_tensor(list(t1_desc.t.shape), t1_desc.t.dtype)
        result_desc.specialize_all_dims()

    def generate(self, ksel: KernelSelection, kb: KernelBuilder):
        t1, t2 = kb.arg_bindings
        result_type = t1.type  # type: ignore
        result = Operation.create(
            "tosa.add", results=[result_type], operands=[t1, t2]
        ).result
        kb.yield_results(result)
12 changes: 12 additions & 0 deletions shark_turbine/ops/templates/test_add_jinja.mlir
@@ -0,0 +1,12 @@
!tensor_type = {{tensor_type}}

module {

util.func private @turbine_test_add_jinja_{{rank}}d_{{element_type}}(
    %a: !tensor_type, %b: !tensor_type
) -> !tensor_type {
  %out = tensor.empty() : !tensor_type
  %0 = linalg.add ins(%a, %b : !tensor_type, !tensor_type) outs(%out : !tensor_type) -> !tensor_type
  util.return %0 : !tensor_type
}
}
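For a concrete picture of the expansion, this template can be rendered standalone with Jinja2; for a rank-2 f32 case the placeholders resolve as shown in the comments (the parameter values are illustrative):

```
# Sketch: rendering test_add_jinja.mlir outside the kernel builder.
from jinja2 import Environment, PackageLoader

env = Environment(loader=PackageLoader("shark_turbine.ops", "templates"))
asm = env.get_template("test_add_jinja.mlir").render(
    rank=2, element_type="f32", tensor_type="tensor<2x3xf32>"
)
# The rendered ASM begins:
#   !tensor_type = tensor<2x3xf32>
#   ...
#   util.func private @turbine_test_add_jinja_2d_f32(
print(asm)
```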
12 changes: 12 additions & 0 deletions shark_turbine/ops/templates/test_add_strformat.mlir
@@ -0,0 +1,12 @@
!tensor_type = {tensor_type}

module {{

util.func private @turbine_test_add_strformat_{rank}d_{element_type}(
    %a: !tensor_type, %b: !tensor_type
) -> !tensor_type {{
  %out = tensor.empty() : !tensor_type
  %0 = linalg.add ins(%a, %b : !tensor_type, !tensor_type) outs(%out : !tensor_type) -> !tensor_type
  util.return %0 : !tensor_type
}}
}}
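Note the doubled braces: with `str.format`, every literal `{` and `}` in the MLIR must be escaped as `{{` and `}}` so that only `{rank}`, `{element_type}`, and `{tensor_type}` are treated as placeholders. A self-contained illustration:

```
# Sketch: str.format brace escaping as used by the template above.
src = "module {{\n  util.func private @f_{rank}d_{element_type}() {{ }}\n}}"
print(src.format(rank=2, element_type="f32"))
# module {
#   util.func private @f_2d_f32() { }
# }
```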
1 change: 1 addition & 0 deletions shark_turbine/ops/templates/test_syntax_error.mlir
@@ -0,0 +1 @@
THIS IS A SYNTAX ERROR
1 change: 1 addition & 0 deletions shark_turbine/runtime/op_reg/__init__.py
@@ -5,3 +5,4 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .base import *
from . import impl_helper
182 changes: 182 additions & 0 deletions shark_turbine/runtime/op_reg/impl_helper.py
@@ -0,0 +1,182 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Helpers for implementing ops.
Typical usage:
```
_templates = JinjaTemplateLoader(__name__)
def generate(kb: KernelBuilder):
func_op = _templates.inline_template_function(
kb, "my_template", "function_name", **kwargs)
return call_function(func_op, *values)
```
"""

from typing import Sequence

from abc import ABC, abstractmethod
import logging
import textwrap

from ...support.logging import runtime_logger as logger

from ...support.ir_imports import (
FlatSymbolRefAttr,
FunctionType,
MLIRError,
Operation,
StringAttr,
TypeAttr,
Value,
)

from ...transforms.merger import Merger

from .base import (
KernelBuilder,
)


__all__ = [
"TemplateLoader",
"StrFormatTemplateLoader",
"call_function",
]


class TemplateLoader(ABC):
    """Base class for templates that can be loaded by name."""

    @abstractmethod
    def load_template(self, kb: KernelBuilder, name: str, **kwargs) -> Operation:
        """Loads a template by name and kwargs, returning the module operation."""
        ...

    def _parse_module_asm(self, kb: KernelBuilder, asm: str) -> Operation:
        try:
            module_op = Operation.parse(asm, context=kb.context)
        except MLIRError as e:
            # Report the failure with a line-numbered dump of the expanded ASM.
            lines = asm.splitlines()
            lines_numbered = "\n".join(
                [f"  {str(i+1):>5}: {l}" for i, l in enumerate(lines)]
            )
            raise RuntimeError(
                f"Error parsing generated op template:"
                f"\n{textwrap.indent(str(e), '  ')}"
                f"\n{lines_numbered}"
            )
        return module_op.operation

    def inline_template_function(
        self,
        kb: KernelBuilder,
        template_file: str,
        function_name: str,
        **kwargs,
    ) -> Operation:
        """Inlines a template module by first expanding its ASM via **kwargs.

        Returns the inlined symbol `function_name`, which is expected to be
        defined in the template.
        """
        # Reuse the function if an earlier call already inlined it.
        try:
            return kb.symbol_table[function_name]
        except KeyError:
            pass
        source_module_op = self.load_template(kb, template_file, **kwargs)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Generated kernel IR %s:\n%s", function_name, str(source_module_op)
            )
        merger = Merger(
            source_module_op, kb.module_body.owner, target_symbol_table=kb.symbol_table
        )
        merger.merge()
        return kb.symbol_table[function_name]


class StrFormatTemplateLoader(TemplateLoader):
    """Template loader that uses str.format.

    Usage:
      _templates = StrFormatTemplateLoader(__name__)

    By default, this will resolve a template like "foo" from templates/foo.mlir
    in the package directory.
    """

    def __init__(
        self,
        package_name: str,
        package_path: str = "templates",
        *,
        suffix: str = ".mlir",
    ):
        self.parent_package_name = ".".join(package_name.split(".")[0:-1])
        self.package_path = package_path
        self.suffix = suffix

    def load_template(self, kb: KernelBuilder, name: str, **kwargs) -> Operation:
        from importlib import resources

        res = (
            resources.files(self.parent_package_name)
            / self.package_path
            / f"{name}{self.suffix}"
        )
        contents = res.read_text().format(**kwargs)
        return self._parse_module_asm(kb, contents)


class JinjaTemplateLoader(TemplateLoader):
    """Template loader based on jinja templates.

    Usage:
      _templates = JinjaTemplateLoader(__name__)

    By default, this will resolve a template like "foo" from templates/foo.mlir
    in the package directory.
    """

    def __init__(
        self,
        package_name: str,
        package_path: str = "templates",
        *,
        suffix: str = ".mlir",
    ):
        try:
            from jinja2 import Environment, PackageLoader
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Cannot use JinjaTemplateLoader if jinja2 is not installed"
            ) from e
        self.env = Environment(loader=PackageLoader(package_name, package_path))
        self.suffix = suffix

    def load_template(self, kb: KernelBuilder, name: str, **kwargs) -> Operation:
        template_file = f"{name}{self.suffix}"
        contents = self.env.get_template(template_file).render(**kwargs)
        return self._parse_module_asm(kb, contents)


def call_function(target_function: Operation, *operands: Value) -> Sequence[Value]:
    """Emits a util.call for a util.func target function operation."""
    target_symbol = FlatSymbolRefAttr.get(
        StringAttr(target_function.attributes["sym_name"]).value_bytes
    )
    ftype = FunctionType(TypeAttr(target_function.attributes["function_type"]).value)
    return Operation.create(
        "util.call",
        results=ftype.results,
        operands=operands,
        attributes={
            "callee": target_symbol,
        },
    ).results
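`TemplateLoader` is deliberately an open extension point: anything that can produce module ASM can back it. As a hedged illustration (not part of this commit), a hypothetical dict-backed loader for tests might look like:

```
# Sketch: a hypothetical in-memory TemplateLoader (illustrative only).
from shark_turbine.runtime.op_reg import KernelBuilder, impl_helper
from shark_turbine.support.ir_imports import Operation


class InMemoryTemplateLoader(impl_helper.TemplateLoader):
    """Renders str.format templates from an in-memory dict."""

    def __init__(self, templates: dict):
        self.templates = templates

    def load_template(self, kb: KernelBuilder, name: str, **kwargs) -> Operation:
        asm = self.templates[name].format(**kwargs)
        return self._parse_module_asm(kb, asm)
```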