Skip to content

Commit

Permalink
documentation: (marimo) move module_html helper to utils (#3714)
Browse files Browse the repository at this point in the history
Unifies how we show modules in marimo notebooks.
  • Loading branch information
superlopuh authored Jan 7, 2025
1 parent 69cec51 commit 460bd6e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 31 deletions.
21 changes: 5 additions & 16 deletions docs/marimo/linalg_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _(
func,
linalg,
mo,
module_html,
xmo,
):
a_type = MemRefType(f64, a_shape)
b_type = MemRefType(f64, b_shape)
Expand Down Expand Up @@ -186,7 +186,7 @@ def _(
Here is matrix multiplication defined in the `linalg` dialect, with the iteration space decoupled from the computation:
{module_html(linalg_module)}
{xmo.module_html(linalg_module)}
""")
return (
a,
Expand Down Expand Up @@ -671,25 +671,14 @@ def format_row(key: str, *values: str):
)


@app.cell
def _(ModuleOp):
import html as htmllib

def module_html(module: ModuleOp) -> str:
return f"""\
<div style="overflow-y: scroll; height:400px;"><small><code style="white-space: pre-wrap;">{htmllib.escape(str(module))}</code></small></div>
"""
return htmllib, module_html


@app.cell
def _():
from collections import Counter
return (Counter,)


@app.cell
def _(Counter, ModuleOp, ModulePass, PipelinePass, ctx, mo, module_html):
def _(Counter, ModuleOp, ModulePass, PipelinePass, ctx, mo, xmo):
def spec_str(p: ModulePass) -> str:
if isinstance(p, PipelinePass):
return ",".join(str(c.pipeline_pass_spec()) for c in p.passes)
Expand All @@ -711,12 +700,12 @@ def pipeline_accordion(
header = f"{spec} ({d_key_count[spec]})"
else:
header = spec
html_res = module_html(res)
html_res = xmo.module_html(res)
d.append(mo.vstack(
(
header,
text,
mo.md(html_res),
html_res,
)
))
return (res, mo.carousel(d))
Expand Down
24 changes: 9 additions & 15 deletions docs/marimo/onnx/onnx_demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import marimo

__generated_with = "0.10.0"
__generated_with = "0.10.9"
app = marimo.App()


Expand Down Expand Up @@ -132,13 +132,13 @@ def _(mo, model_def):


@app.cell
def _(html, init_module, mo):
def _(init_module, mo, xmo):
mo.md(f"""
### Converting to `linalg`
Here is the xDSL representation of the function, it takes two `tensor` values of our chosen shape, passes them as operands to the `onnx.Add` operation, and returns it:
{html(init_module)}
{xmo.module_html(init_module)}
"""
)
return
Expand Down Expand Up @@ -260,15 +260,9 @@ def _():


@app.cell(hide_code=True)
def _(ModuleOp, mo):
import html as htmllib

def html(module: ModuleOp) -> mo.Html:
return f"""\
<small><code style="white-space: pre-wrap;">{htmllib.escape(str(module))}</code></small>
"""
# return mo.as_html(str(module))
return html, htmllib
def _():
import xdsl.utils.marimo as xmo
return (xmo,)


@app.cell(hide_code=True)
Expand All @@ -278,7 +272,7 @@ def _():


@app.cell(hide_code=True)
def _(Counter, ModuleOp, ModulePass, PipelinePass, ctx, html, mo):
def _(Counter, ModuleOp, ModulePass, PipelinePass, ctx, mo, xmo):
def spec_str(p: ModulePass) -> str:
if isinstance(p, PipelinePass):
return ",".join(str(c.pipeline_pass_spec()) for c in p.passes)
Expand All @@ -298,11 +292,11 @@ def pipeline_accordion(passes: tuple[tuple[mo.Html, ModulePass], ...], module: M
header = f"{spec} ({d_key_count[spec]})"
else:
header = spec
html_res = html(res)
html_res = xmo.module_html(res)
d[header] = mo.vstack((
text,
# mo.plain_text(f"Pass: {p.pipeline_pass_spec()}"),
mo.md(html_res)
html_res
))
return (res, mo.accordion(d))
return pipeline_accordion, spec_str
Expand Down
9 changes: 9 additions & 0 deletions xdsl/utils/marimo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import marimo as mo

from xdsl.dialects.builtin import ModuleOp


def asm_html(asm: str) -> mo.Html:
"""
Returns a Marimo-optimised representation of the assembly code passed in.
"""
return mo.ui.code_editor(asm, language="python", disabled=True)


def module_html(module: ModuleOp) -> mo.Html:
"""
Returns a Marimo-optimised representation of the module passed in.
"""
return mo.ui.code_editor(str(module), language="javascript", disabled=True)

0 comments on commit 460bd6e

Please sign in to comment.