Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snakemake: Add gemm kernel #313

Merged
merged 4 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build-run-kernel-snake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ jobs:
working-directory: kernels/${{ matrix.kernel }}
strategy:
matrix:
kernel: [alloc, simple_copy, transform_copy]
kernel: [alloc, simple_copy, transform_copy, gemm]
2 changes: 1 addition & 1 deletion .github/workflows/build-run-kernel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ jobs:
working-directory: kernels/${{ matrix.kernel }}
strategy:
matrix:
kernel: [streamer_alu, tiled_add, streamer_matmul, gemmini, rescale, gemm]
kernel: [streamer_alu, tiled_add, streamer_matmul, gemmini, rescale]
43 changes: 0 additions & 43 deletions kernels/gemm/Makefile

This file was deleted.

80 changes: 80 additions & 0 deletions kernels/gemm/Snakefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from util.snake.configs import get_snax_gemmx_config

config = get_snax_gemmx_config()
config["mlirpreprocflags"] = [
"--linalg-generalize-named-ops",
"--mlir-print-op-generic",
"--mlir-print-local-scope",
]
config["snaxoptflags"] = ",".join(
[
"convert-linalg-to-kernel",
"insert-accfg-op{accelerator=snax_gemmx}",
"dispatch-kernels",
"convert-linalg-to-stream",
"fuse-streaming-regions",
"snax-bufferize",
"alloc-to-global",
"set-memory-space",
"set-memory-layout",
"realize-memref-casts",
"insert-sync-barrier",
"dispatch-regions{nb_cores=2}",
"convert-stream-to-snax-stream",
"convert-linalg-to-accfg",
"convert-accfg-to-csr",
"snax-copy-to-dma",
"memref-to-snax",
"snax-to-func",
"clear-memory-space",
]
)


module default_rules:
snakefile:
"../../util/snake/default_rules.smk"
config:
config


use rule * from default_rules as default_*


rule compile_main:
input:
"main.c",
"data.h",
output:
"main.o",
shell:
"{config[cc]} {config[cflags]} -c {input[0]}"


rule all:
input:
"gemm.x",
shell:
"{config[vltsim]} {input[0]}"


from gendata import create_data_files


rule generate_data:
output:
"data.c",
"data.h",
run:
create_data_files()


rule link_snax_binary:
input:
"gemm.o",
"main.o",
"data.o",
output:
"gemm.x",
shell:
"{config[ld]} {config[ldflags]} {input} -o {output}"
6 changes: 2 additions & 4 deletions kernels/gemm/gendata.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# simple script to generate inputs and expected outputs for simple_matmult

import numpy as np

from util.gendata import create_data, create_header

if __name__ == "__main__":

def create_data_files():
# Reset random seed for reproducible behavior

np.random.seed(0)
Expand All @@ -15,7 +14,6 @@

A_size = [m, k]
B_size = [k, n]
O_size = [m, n]

# D = A.B + C
low_bound = -128
Expand Down
14 changes: 14 additions & 0 deletions util/snake/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,17 @@ def get_snax_mac_config():
f"-I{snitch_sw_path}/target/snitch_cluster/sw/snax/mac/build/mac.o"
)
return config


def get_snax_gemmx_config():
# use CONDA_PREFIX to access pixi env
snax_utils_path = os.environ["CONDA_PREFIX"] + "/snax-utils"
snitch_sw_path = snax_utils_path + "/snax-kul-cluster-mixed-narrow-wide"
config = {}
config.update(get_default_paths())
config.update(get_default_flags(snitch_sw_path))
config["vltsim"] = (
snax_utils_path
+ "/snax-kul-cluster-mixed-narrow-wide-rtl/bin/snitch_cluster.vlt"
)
return config
Loading