Skip to content

Commit

Permalink
return new ctx as well as module
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Jan 7, 2025
1 parent 3036e06 commit 789585b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 31 deletions.
52 changes: 27 additions & 25 deletions docs/marimo/linalg_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,11 @@ def _(mo):

@app.cell
def _(MLContext, get_all_dialects):
ctx = MLContext()
linalg_ctx = MLContext()

for dialect_name, dialect_factory in get_all_dialects().items():
ctx.register_dialect(dialect_name, dialect_factory)
return ctx, dialect_factory, dialect_name
linalg_ctx.register_dialect(dialect_name, dialect_factory)
return dialect_factory, dialect_name, linalg_ctx


@app.cell
Expand All @@ -278,7 +278,7 @@ def _(
convert_linalg_to_loops,
convert_memref_to_riscv,
convert_scf_to_riscv_scf,
ctx,
linalg_ctx,
linalg_module,
reconcile_unrealized_casts,
xmo,
Expand All @@ -294,12 +294,12 @@ def _(
]
)

riscv_module, riscv_html = xmo.pipeline_html(
ctx, tuple(("", p) for p in lower_to_riscv.passes), linalg_module
riscv_ctx, riscv_module, riscv_html = xmo.pipeline_html(
linalg_ctx, linalg_module, tuple(("", p) for p in lower_to_riscv.passes)
)

riscv_html
return lower_to_riscv, riscv_html, riscv_module
return lower_to_riscv, riscv_ctx, riscv_html, riscv_module


@app.cell
Expand All @@ -319,7 +319,7 @@ def _(
CanonicalizePass,
PipelinePass,
RISCVRegisterAllocation,
ctx,
riscv_ctx,
riscv_module,
xmo,
):
Expand All @@ -330,20 +330,20 @@ def _(
]
)

regalloc_module, regalloc_html = xmo.pipeline_html(
ctx, tuple(("", p) for p in allocate_registers.passes), riscv_module
regalloc_ctx, regalloc_module, regalloc_html = xmo.pipeline_html(
riscv_ctx, riscv_module, tuple(("", p) for p in allocate_registers.passes),
)

regalloc_html
return allocate_registers, regalloc_html, regalloc_module
return allocate_registers, regalloc_ctx, regalloc_html, regalloc_module


@app.cell
def _(
CanonicalizePass,
ConvertRiscvScfToRiscvCfPass,
PipelinePass,
ctx,
regalloc_ctx,
regalloc_module,
xmo,
):
Expand All @@ -354,12 +354,12 @@ def _(
]
)

riscv_asm_module, assembly_html = xmo.pipeline_html(
ctx, (("", lower_to_asm),), regalloc_module
riscv_asm_ctx, riscv_asm_module, assembly_html = xmo.pipeline_html(
regalloc_ctx, regalloc_module, (("", lower_to_asm),),
)

assembly_html
return assembly_html, lower_to_asm, riscv_asm_module
return assembly_html, lower_to_asm, riscv_asm_ctx, riscv_asm_module


@app.cell
Expand Down Expand Up @@ -399,7 +399,7 @@ def _(
arith_add_fastmath,
convert_linalg_to_memref_stream,
convert_riscv_scf_for_to_frep,
ctx,
linalg_ctx,
linalg_module,
xmo,
):
Expand All @@ -415,15 +415,16 @@ def _(
]
)

snitch_stream_module, snitch_stream_html = xmo.pipeline_html(
ctx, tuple(("", p) for p in convert_linalg_to_snitch.passes), linalg_module
snitch_stream_ctx, snitch_stream_module, snitch_stream_html = xmo.pipeline_html(
linalg_ctx, linalg_module, tuple(("", p) for p in convert_linalg_to_snitch.passes),
)

snitch_stream_html
return (
LOWER_MEMREF_STREAM_TO_SNITCH_STREAM_PASSES,
OPTIMISE_MEMREF_STREAM_PASSES,
convert_linalg_to_snitch,
snitch_stream_ctx,
snitch_stream_html,
snitch_stream_module,
)
Expand All @@ -436,16 +437,17 @@ def _(mo):


@app.cell
def _(ctx, snitch_stream_module, xmo):
def _(snitch_stream_ctx, snitch_stream_module, xmo):
from xdsl.transforms.test_lower_linalg_to_snitch import LOWER_SNITCH_STREAM_TO_ASM_PASSES

snitch_asm_module, snitch_asm_html = xmo.pipeline_html(
ctx, tuple(("", p) for p in LOWER_SNITCH_STREAM_TO_ASM_PASSES), snitch_stream_module
snitch_asm_ctx, snitch_asm_module, snitch_asm_html = xmo.pipeline_html(
snitch_stream_ctx, snitch_stream_module, tuple(("", p) for p in LOWER_SNITCH_STREAM_TO_ASM_PASSES)
)

snitch_asm_html
return (
LOWER_SNITCH_STREAM_TO_ASM_PASSES,
snitch_asm_ctx,
snitch_asm_html,
snitch_asm_module,
)
Expand Down Expand Up @@ -493,7 +495,7 @@ def _(mo):


@app.cell
def _(TypedPtr, a_shape, b_shape, c_shape, ctx, mo, riscv_module):
def _(TypedPtr, a_shape, b_shape, c_shape, mo, riscv_ctx, riscv_module):
from math import prod

from xdsl.interpreter import Interpreter, OpCounter
Expand All @@ -511,7 +513,7 @@ def _(TypedPtr, a_shape, b_shape, c_shape, ctx, mo, riscv_module):
riscv_op_counter = OpCounter()
riscv_interpreter = Interpreter(riscv_module, listeners=(riscv_op_counter,))

register_implementations(riscv_interpreter, ctx, include_wgpu=False, include_onnx=False)
register_implementations(riscv_interpreter, riscv_ctx, include_wgpu=False, include_onnx=False)

riscv_interpreter.call_op("matmul", (a_shaped.data_ptr.raw, b_shaped.data_ptr.raw, riscv_c_shaped.data_ptr.raw))

Expand Down Expand Up @@ -551,10 +553,10 @@ def _(
b_shaped,
c_len,
c_shape,
ctx,
mo,
register_implementations,
riscv_c_shaped,
snitch_stream_ctx,
snitch_stream_module,
):
snitch_op_counter = OpCounter()
Expand All @@ -564,7 +566,7 @@ def _(

snitch_c_shaped = ShapedArray(TypedPtr.new_float64([0.0] * c_len), c_shape)

register_implementations(snitch_interpreter, ctx, include_wgpu=False, include_onnx=False)
register_implementations(snitch_interpreter, snitch_stream_ctx, include_wgpu=False, include_onnx=False)

snitch_interpreter.call_op(
"matmul",
Expand Down
6 changes: 3 additions & 3 deletions docs/marimo/onnx/onnx_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ def _(
mo,
xmo,
):
bufferized_module, linalg_html = xmo.pipeline_html(
bufferized_ctx, bufferized_module, linalg_html = xmo.pipeline_html(
ctx,
init_module,
(
(
mo.md(
Expand Down Expand Up @@ -213,8 +214,7 @@ def _(
]
)
)
),
init_module
)
)

linalg_html
Expand Down
6 changes: 3 additions & 3 deletions xdsl/utils/marimo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def _spec_str(p: ModulePass) -> str:


def pipeline_html(
ctx: MLContext, passes: Sequence[tuple[mo.Html, ModulePass]], module: ModuleOp
) -> tuple[ModuleOp, mo.Html]:
ctx: MLContext, module: ModuleOp, passes: Sequence[tuple[mo.Html, ModulePass]]
) -> tuple[MLContext, ModuleOp, mo.Html]:
"""
Returns a tuple of the resulting module after applying the passes, and the
Marimo-optimised representation of the modules throughout compilation.
Expand Down Expand Up @@ -64,4 +64,4 @@ def pipeline_html(
)
)
)
return (res, mo.carousel(d))
return (ctx, res, mo.carousel(d))

0 comments on commit 789585b

Please sign in to comment.