Skip to content

Commit

Permalink
Refactor tiled_matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
JosseVanDelm committed Nov 29, 2024
1 parent df3a9c1 commit 16babcd
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 55 deletions.
134 changes: 84 additions & 50 deletions benchmarks/tiled_matmul/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,51 +24,63 @@ LDFLAGS.append(
f"{SNITCH_SW_PATH}/target/snitch_cluster/sw/snax/streamer-gemm/build/snax-streamer-gemm-lib.o"
)

SNAXOPTFLAGS = [
"insert-accfg-op{accelerator=snax_gemm}",
"convert-linalg-to-kernel",
"dispatch-kernels",
"set-memory-space",
"set-memory-layout",
"realize-memref-casts",
"test-remove-memref-copy",
"insert-sync-barrier",
"reuse-memref-allocs",
"test-add-mcycle-around-loop",
"snax-lower-mcycle",
"dispatch-regions",
"convert-linalg-to-stream",
"convert-stream-to-snax-stream",
"convert-linalg-to-accfg",
"snax-copy-to-dma",
"memref-to-snax",
"snax-to-func",
"clear-memory-space",
"function-constant-pinning",
"mlir-opt{"
+ "\ ".join(
MLIRPREPROCFLAGS, MLIRPREPROC2FLAGS, MLIRPREPROC3FLAGS = get_mlir_preproc_flags()
MLIROPTFLAGS = get_mlir_postproc_flags()


def get_snax_opt_flags(options):
flags = []
match options:
case "deduponly":
flags = ["accfg-dedup"]
case "overlaponly":
flags = ["accfg-config-overlap"]
case "accfgboth":
flags = ["accfg-dedup", "accfg-config-overlap"]
return ",".join(
[
"executable=mlir-opt",
"generic=true",
"arguments='"
+ ",".join(
"insert-accfg-op{accelerator=snax_gemm}",
"convert-linalg-to-kernel",
"dispatch-kernels",
"set-memory-space",
"set-memory-layout",
"realize-memref-casts",
"test-remove-memref-copy",
"insert-sync-barrier",
"reuse-memref-allocs",
"test-add-mcycle-around-loop",
"snax-lower-mcycle",
"dispatch-regions",
"convert-linalg-to-stream",
"convert-stream-to-snax-stream",
"convert-linalg-to-accfg",
"snax-copy-to-dma",
"memref-to-snax",
"snax-to-func",
"clear-memory-space",
"function-constant-pinning",
"mlir-opt{"
+ "\ ".join(
[
"-cse",
"-canonicalize",
"-allow-unregistered-dialect",
"-mlir-print-op-generic",
"executable=mlir-opt",
"generic=true",
"arguments='"
+ ",".join(
[
"-cse",
"-canonicalize",
"-allow-unregistered-dialect",
"-mlir-print-op-generic",
]
),
]
),
)
+ "'}",
*flags,
"convert-accfg-to-csr",
]
)
+ "'}",
"accfg-dedup",
"accfg-config-overlap",
"convert-accfg-to-csr",
]
SNAXOPTFLAGS = ",".join(SNAXOPTFLAGS)
MLIRPREPROCFLAGS, MLIRPREPROC2FLAGS, MLIRPREPROC3FLAGS = get_mlir_preproc_flags()
MLIROPTFLAGS = get_mlir_postproc_flags()


sizes = [
[16, 16, 16],
Expand All @@ -79,30 +91,50 @@ sizes = [
[512, 512, 512],
]

# Benchmark variants. Each name is expanded into the {options} part of the
# generated_{m}_{n}_{k}_{options}.x targets requested by run_benchmarks, and
# snax_opt_mlir hands it to get_snax_opt_flags() to pick the extra accfg
# passes:
#   "deduponly"   -> accfg-dedup
#   "overlaponly" -> accfg-config-overlap
#   "accfgboth"   -> accfg-dedup and accfg-config-overlap
#   "noaccfgopt"  -> matches no case there, so no extra accfg pass is added
options = [
"noaccfgopt",
"deduponly",
"overlaponly",
"accfgboth",
]

# List of tiling-factor pairs; generate_mlir currently only reads the first
# entry (tiling_factors[0]) and forwards it to create_tiled_matrix_multiply.
tiling_factors = [[8, 8]]

from genbenchmark import create_tiled_matrix_multiply, write_module_to_file
from gendata import create_test_data

# Rules


# Rules
rule run_benchmarks:
input:
expand("generated_{size[0]}_{size[1]}_{size[2]}.generated.x", size=sizes),
expand(
"generated_{size[0]}_{size[1]}_{size[2]}_{options}.x",
size=sizes,
options=options,
),


rule generate_mlir:
output:
"generated_{m}_{n}_{k}.generated.mlir",
params:
k=lambda wildcards: int(wildcards.k),
m=lambda wildcards: int(wildcards.m),
n=lambda wildcards: int(wildcards.n),
# hardcoded to [8,8] for now
tiling_factors=lambda _: tiling_factors[0],
run:
write_module_to_file(
create_tiled_matrix_multiply(wildcards.k, wildcards.m, wildcards.n),
create_tiled_matrix_multiply(
params.k, params.m, params.n, params.tiling_factors
),
output[0],
)


rule preprocess_mlir:
input:
"{file}.mlir",
"{file}.generated.mlir",
output:
temp("{file}.preproc1.mlir"),
temp("{file}.preproc2.mlir"),
Expand All @@ -119,11 +151,13 @@ rule preprocess_mlir:

rule snax_opt_mlir:
input:
"{file}.preprocfinal.mlir",
"generated_{m}_{n}_{k}.preprocfinal.mlir",
output:
temp("{file}.snax-opt.mlir"),
temp("generated_{m}_{n}_{k}_{options}.snax-opt.mlir"),
params:
snax_flags=lambda wildcards: get_snax_opt_flags(wildcards.options),
shell:
"{SNAXOPT} -p {SNAXOPTFLAGS} -o {output} {input}"
"{SNAXOPT} -p {params.snax_flags} -o {output} {input}"


rule postprocess_mlir:
Expand Down Expand Up @@ -182,11 +216,11 @@ rule compile_main:

rule compile_snax_binary:
input:
"generated_{m}_{n}_{k}.generated.o",
"generated_{m}_{n}_{k}_{options}.o",
"main_{m}_{n}_{k}.o",
"data_{m}_{n}_{k}.o",
output:
"generated_{m}_{n}_{k}.generated.x",
"generated_{m}_{n}_{k}_{options}.x",
shell:
"{LD} {LDFLAGS} {input} -o {output}"

Expand Down
6 changes: 1 addition & 5 deletions benchmarks/tiled_matmul/genbenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@
from util.snax_benchmark import SNAXBenchmark


def create_tiled_matrix_multiply(k, m, n):
tiling_factors = [8, 8]
k = int(k)
m = int(m)
n = int(n)
def create_tiled_matrix_multiply(k, m, n, tiling_factors):
"""
Generate IR in the form of:
```
Expand Down

0 comments on commit 16babcd

Please sign in to comment.