diff --git a/docs/marimo/autotune.py b/docs/marimo/autotune.py index dd775bdfd7..288eba793c 100644 --- a/docs/marimo/autotune.py +++ b/docs/marimo/autotune.py @@ -445,7 +445,13 @@ def _( @app.cell -def _(memref_stream_ctx, memref_stream_module, mo, xmo): +def _( + memref_stream_cost_model, + memref_stream_ctx, + memref_stream_module, + mo, + xmo, +): from xdsl.transforms import memref_stream_interleave interleaved_ctx = memref_stream_ctx.clone() @@ -455,12 +461,19 @@ def _(memref_stream_ctx, memref_stream_module, mo, xmo): interleaved_ctx, interleaved_module ) + interleaved_cost = memref_stream_cost_model.estimate_cost(interleaved_module, interleaved_ctx) + mo.md(f""" - We can use the existing interleave pass to choose the unroll-and-jam factor: + We can use the existing interleave pass to choose the unroll-and-jam factor, with cost {interleaved_cost}: {xmo.module_html(interleaved_module)} """) - return interleaved_ctx, interleaved_module, memref_stream_interleave + return ( + interleaved_cost, + interleaved_ctx, + interleaved_module, + memref_stream_interleave, + ) @app.cell @@ -510,7 +523,7 @@ def apply(p: ModulePass, ctx: MLContext, op: ModuleOp) -> ModuleOp: op = op.clone() ctx = ctx.clone() p.apply(ctx, op) - return op + return ctx, op return (apply,) @@ -524,7 +537,7 @@ def _( uaj_passes, ): scores = tuple( - memref_stream_cost_model.estimate_cost(apply(p, memref_stream_ctx, memref_stream_module), memref_stream_ctx) + memref_stream_cost_model.estimate_cost(apply(p, memref_stream_ctx, memref_stream_module)[1], memref_stream_ctx) for p in uaj_passes ) @@ -538,8 +551,161 @@ def _( @app.cell def _(): + import numpy as np + from numpy.random import Generator as RandomGenerator + return RandomGenerator, np + + +@app.cell +def _(np): + rng = np.random.default_rng() + + type(rng) + return (rng,) + + +@app.cell +def _(): + from xdsl.interpreters.onnx import to_dtype + return (to_dtype,) + + +@app.cell +def _( + Attribute, + DenseIntOrFPElementsAttr, + RandomGenerator, + f64, + np, + to_dtype, +): + from xdsl.dialects.builtin import FloatAttr, ShapedType, ContainerType, TensorType + + + def random_attr_of_type(t: Attribute, rng: RandomGenerator) -> Attribute | None: + # if isinstance(type, ShapedType): + if t == f64: + return FloatAttr(rng.random(1, to_dtype(t))[0], f64) + elif isinstance(t, ShapedType) and isinstance(t, ContainerType): + values = rng.random(t.element_count(), to_dtype(t.get_element_type())) + return DenseIntOrFPElementsAttr.from_list( + t, values + ) + + _rng = np.random.default_rng() + random_attr_of_type(f64, _rng), random_attr_of_type(TensorType(f64, (2, 3)), _rng) + return ( + ContainerType, + FloatAttr, + ShapedType, + TensorType, + random_attr_of_type, + ) + + +@app.cell +def _( + LensCostModel, + MLContext, + MemrefStreamUnrollAndJamPass, + ModuleOp, + ModulePass, + SnitchCycleCostModel, + func, + memref_stream, + memref_stream_module, + msg_factors, + np, + random_attr_of_type, + riscv_passes, + uaj_passes, +): + class AutomaticUnrollAndJamPass(ModulePass): + + name = "automatic-unroll-and-jam" + + def apply(self, ctx: MLContext, op: ModuleOp) -> None: + msg_ops = tuple(child for child in memref_stream_module.walk() if isinstance(child, memref_stream.GenericOp)) + + if not msg_ops: + return + + assert len(msg_ops) == 1 + + msg_op = msg_ops[0] + + passes = tuple( + MemrefStreamUnrollAndJamPass(0, index, factor) + for index, factor in msg_factors + ) + + if not passes: + return + + func_op = msg_op.parent_op() + + assert isinstance(func_op, func.FuncOp), func_op.name + + func_op_name = func_op.sym_name.data + + arg_types = func_op.function_type.inputs + + rng = np.random.default_rng() + attrs = tuple(random_attr_of_type(t, rng) for t in arg_types) + + cost_model = LensCostModel(SnitchCycleCostModel(func_op_name, attrs), riscv_passes.passes) + + scores = tuple( + enumerate( + cost_model.estimate_cost(apply(p, ctx, op)[1], ctx) + for p in uaj_passes + ) + ) + + best_pass_index = min(scores, key=lambda x: x[1]) + + passes[best_pass_index[0]].apply(ctx, op) + return (AutomaticUnrollAndJamPass,) + + +@app.cell +def _(k, m, mo, n): + mo.md( + f""" + Here are the sliders again: + + {m}{m.value} + + {n}{n.value} + + {k}{k.value} + """ + ) return +@app.cell +def _( + AutomaticUnrollAndJamPass, + apply, + interleaved_cost, + memref_stream_cost_model, + memref_stream_ctx, + memref_stream_module, + mo, + xmo, +): + automated_ctx, automated_module = apply(AutomaticUnrollAndJamPass(), memref_stream_ctx, memref_stream_module) + + automated_cost = memref_stream_cost_model.estimate_cost(automated_module, automated_ctx) + + mo.md(f""" + Here's the updated module with our automated approach, with cost {automated_cost} vs the heuristic's {interleaved_cost}: + + {xmo.module_html(automated_module)} + """) + return automated_cost, automated_ctx, automated_module + + if __name__ == "__main__": app.run() diff --git a/pyproject.toml b/pyproject.toml index b2e4cd1c39..e194bd3ba6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,7 +141,7 @@ max-line-length = 300 "docs/mlir_interoperation.ipynb" = ["E402"] "docs/irdl.ipynb" = ["ALL"] "docs/database_example.ipynb" = ["ALL"] -"**/{docs/marimo}/*" = ["E501", "I001"] +"**/{docs/marimo}/*" = ["E501", "I001", "F821"] "_version.py" = ["ALL"] "__init__.py" = ["F403"]