Skip to content

Commit

Permalink
add automated pass and cost
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Jan 7, 2025
1 parent a28c1d8 commit c449387
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 6 deletions.
176 changes: 171 additions & 5 deletions docs/marimo/autotune.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,13 @@ def _(


@app.cell
def _(memref_stream_ctx, memref_stream_module, mo, xmo):
def _(
memref_stream_cost_model,
memref_stream_ctx,
memref_stream_module,
mo,
xmo,
):
from xdsl.transforms import memref_stream_interleave

interleaved_ctx = memref_stream_ctx.clone()
Expand All @@ -455,12 +461,19 @@ def _(memref_stream_ctx, memref_stream_module, mo, xmo):
interleaved_ctx, interleaved_module
)

interleaved_cost = memref_stream_cost_model.estimate_cost(interleaved_module, interleaved_ctx)

mo.md(f"""
We can use the existing interleave pass to choose the unroll-and-jam factor:
We can use the existing interleave pass to choose the unroll-and-jam factor, with cost {interleaved_cost}:
{xmo.module_html(interleaved_module)}
""")
return interleaved_ctx, interleaved_module, memref_stream_interleave
return (
interleaved_cost,
interleaved_ctx,
interleaved_module,
memref_stream_interleave,
)


@app.cell
Expand Down Expand Up @@ -510,7 +523,7 @@ def apply(p: ModulePass, ctx: MLContext, op: ModuleOp) -> ModuleOp:
op = op.clone()
ctx = ctx.clone()
p.apply(ctx, op)
return op
return ctx, op
return (apply,)


Expand All @@ -524,7 +537,7 @@ def _(
uaj_passes,
):
scores = tuple(
memref_stream_cost_model.estimate_cost(apply(p, memref_stream_ctx, memref_stream_module), memref_stream_ctx)
memref_stream_cost_model.estimate_cost(apply(p, memref_stream_ctx, memref_stream_module)[1], memref_stream_ctx)
for p in uaj_passes
)

Expand All @@ -538,8 +551,161 @@ def _(

@app.cell
def _():
    # Make NumPy and its random `Generator` class (aliased) available to the
    # cells that list `np` / `RandomGenerator` among their inputs.
    from numpy.random import Generator as RandomGenerator
    import numpy as np

    return (RandomGenerator, np)


@app.cell
def _(np):
    # Create a NumPy random generator and export it as `rng`.
    rng = np.random.default_rng()

    # Bare trailing expression: the concrete type of the generator
    # (presumably shown as the cell's output by marimo).
    type(rng)
    return (rng,)


@app.cell
def _():
    # Export the `to_dtype` helper from xDSL's ONNX interpreter module.
    from xdsl.interpreters.onnx import to_dtype

    return (to_dtype,)


@app.cell
def _(
    Attribute,
    DenseIntOrFPElementsAttr,
    RandomGenerator,
    f64,
    np,
    to_dtype,
):
    from xdsl.dialects.builtin import FloatAttr, ShapedType, ContainerType, TensorType

    def random_attr_of_type(t: Attribute, rng: RandomGenerator) -> Attribute | None:
        """
        Build an attribute of type `t` whose values are drawn from `rng`.

        * `f64` gives a `FloatAttr` wrapping a single sample.
        * A type that is both a `ShapedType` and a `ContainerType` gives a
          `DenseIntOrFPElementsAttr` with `t.element_count()` samples.
        * Any other type gives `None`.
        """
        if t == f64:
            sample = rng.random(1, to_dtype(t))[0]
            return FloatAttr(sample, f64)

        if not (isinstance(t, ShapedType) and isinstance(t, ContainerType)):
            return None

        count = t.element_count()
        element_dtype = to_dtype(t.get_element_type())
        values = rng.random(count, element_dtype)
        return DenseIntOrFPElementsAttr.from_list(t, values)

    # Exercise the helper once for a scalar and once for a 2x3 tensor; the
    # resulting pair is the cell's trailing expression.
    _rng = np.random.default_rng()
    random_attr_of_type(f64, _rng), random_attr_of_type(TensorType(f64, (2, 3)), _rng)
    return (
        ContainerType,
        FloatAttr,
        ShapedType,
        TensorType,
        random_attr_of_type,
    )


@app.cell
def _(
    LensCostModel,
    MLContext,
    MemrefStreamUnrollAndJamPass,
    ModuleOp,
    ModulePass,
    SnitchCycleCostModel,
    func,
    memref_stream,
    memref_stream_module,
    msg_factors,
    np,
    random_attr_of_type,
    riscv_passes,
    uaj_passes,
):
    # NOTE: `memref_stream_module` and `uaj_passes` stay in the signature so the
    # cell's interface is unchanged, but the pass below no longer reads them: it
    # works purely on the `ctx`/`op` it is handed.

    class AutomaticUnrollAndJamPass(ModulePass):
        """
        Pick the cheapest unroll-and-jam configuration and apply it.

        One `MemrefStreamUnrollAndJamPass` candidate is built per
        `(index, factor)` pair in `msg_factors`. Each candidate is applied to a
        clone of the module, the clone is scored with a `LensCostModel`
        wrapping a `SnitchCycleCostModel`, and the candidate with the lowest
        score is then applied to the real module.
        """

        name = "automatic-unroll-and-jam"

        def apply(self, ctx: MLContext, op: ModuleOp) -> None:
            """
            Apply the lowest-cost unroll-and-jam candidate to `op` in place.

            Does nothing when `op` contains no `memref_stream.GenericOp` or
            when `msg_factors` is empty. Asserts that there is exactly one
            generic op and that its parent is a `func.FuncOp`.
            """
            # Inspect the module being transformed rather than a notebook-level
            # module, so the pass behaves correctly on whatever it is given.
            msg_ops = tuple(
                child
                for child in op.walk()
                if isinstance(child, memref_stream.GenericOp)
            )

            if not msg_ops:
                return

            assert len(msg_ops) == 1

            msg_op = msg_ops[0]

            passes = tuple(
                MemrefStreamUnrollAndJamPass(0, index, factor)
                for index, factor in msg_factors
            )

            if not passes:
                return

            func_op = msg_op.parent_op()

            # Report the parent's type rather than `func_op.name`, which would
            # itself fail if there is no parent at all.
            assert isinstance(func_op, func.FuncOp), type(func_op).__name__

            func_op_name = func_op.sym_name.data
            arg_types = func_op.function_type.inputs

            # Random inputs for the cost model; `random_attr_of_type` yields
            # `None` for argument types it does not handle.
            rng = np.random.default_rng()
            attrs = tuple(random_attr_of_type(t, rng) for t in arg_types)

            cost_model = LensCostModel(
                SnitchCycleCostModel(func_op_name, attrs), riscv_passes.passes
            )

            def trial_cost(candidate: ModulePass):
                # Score the candidate on clones so the real module and context
                # stay untouched until the winner is known.
                trial_op = op.clone()
                trial_ctx = ctx.clone()
                candidate.apply(trial_ctx, trial_op)
                return cost_model.estimate_cost(trial_op, ctx)

            # Score exactly the candidates we will index into below, so the
            # winning index always refers to the pass that produced the score.
            scores = tuple(trial_cost(candidate) for candidate in passes)

            # First minimum wins on ties.
            best_index = min(range(len(scores)), key=scores.__getitem__)

            passes[best_index].apply(ctx, op)

    return (AutomaticUnrollAndJamPass,)


# Cell: render a markdown block that shows the m, n and k inputs again, each
# immediately followed by its current `.value`. The cell exports nothing.
@app.cell
def _(k, m, mo, n):
mo.md(
f"""
Here are the sliders again:
{m}{m.value}
{n}{n.value}
{k}{k.value}
"""
)
return


# Cell: run AutomaticUnrollAndJamPass through the `apply` helper on the
# memref_stream module, estimate the cost of the result with
# `memref_stream_cost_model`, and render the module alongside that cost and the
# interleave heuristic's cost (`interleaved_cost`).
@app.cell
def _(
AutomaticUnrollAndJamPass,
apply,
interleaved_cost,
memref_stream_cost_model,
memref_stream_ctx,
memref_stream_module,
mo,
xmo,
):
# `apply` hands back a (context, module) pair.
automated_ctx, automated_module = apply(AutomaticUnrollAndJamPass(), memref_stream_ctx, memref_stream_module)

automated_cost = memref_stream_cost_model.estimate_cost(automated_module, automated_ctx)

mo.md(f"""
Here's the updated module with our automated approach, with cost {automated_cost} vs the heuristic's {interleaved_cost}:
{xmo.module_html(automated_module)}
""")
return automated_cost, automated_ctx, automated_module


# Run the app when this file is executed directly as a script.
if __name__ == "__main__":
app.run()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ max-line-length = 300
"docs/mlir_interoperation.ipynb" = ["E402"]
"docs/irdl.ipynb" = ["ALL"]
"docs/database_example.ipynb" = ["ALL"]
"**/{docs/marimo}/*" = ["E501", "I001"]
"**/{docs/marimo}/*" = ["E501", "I001", "F821"]
"_version.py" = ["ALL"]
"__init__.py" = ["F403"]

Expand Down

0 comments on commit c449387

Please sign in to comment.