Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run simulations in parallel #347

Merged
merged 18 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions kernels/alloc/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,20 @@ config["snaxoptflags"] = ",".join(
)


module default_rules:
module snax_rules:
snakefile:
"../../util/snake/default_rules.smk"
"../../util/snake/snax.smk"
config:
config


rule all:
input:
"func.x",
shell:
"{config[vltsim]} {input[0]}"
use rule * from snax_rules as snax_*


use rule * from default_rules exclude compile_simple_main as default_*
# Rules
rule all:
input:
"func_traces.json",


rule compile_snax_binary:
Expand Down
19 changes: 9 additions & 10 deletions kernels/gemm/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,20 @@ config["snaxoptflags"] = ",".join(
)


module default_rules:
module snax_rules:
snakefile:
"../../util/snake/default_rules.smk"
"../../util/snake/snax.smk"
config:
config


use rule * from default_rules as default_*
use rule * from snax_rules as snax_*


# Rules
rule all:
input:
"gemm_traces.json",


rule compile_main:
Expand All @@ -51,13 +57,6 @@ rule compile_main:
"{config[cc]} {config[cflags]} -c {input[0]}"


rule all:
input:
"gemm.x",
shell:
"{config[vltsim]} {input[0]}"


from gendata import create_data_files


Expand Down
10 changes: 4 additions & 6 deletions kernels/rescale/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ config["snaxoptflags"] = ",".join(
)


module default_rules:
module snax_rules:
snakefile:
"../../util/snake/default_rules.smk"
"../../util/snake/snax.smk"
config:
config


use rule * from default_rules as default_*
use rule * from snax_rules as snax_*


rule link_snax_binary:
Expand All @@ -48,9 +48,7 @@ rule link_snax_binary:

rule all:
input:
"rescale.x",
shell:
"{config[vltsim]} {input[0]}"
"rescale_traces.json",


from gendata import create_data_files
Expand Down
10 changes: 4 additions & 6 deletions kernels/simple_copy/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,20 @@ config["snaxoptflags"] = ",".join(
)


module default_rules:
module snax_rules:
snakefile:
"../../util/snake/default_rules.smk"
"../../util/snake/snax.smk"
config:
config


use rule * from default_rules as default_*
use rule * from snax_rules as snax_*


# Rules
rule all:
input:
"simple_copy.x",
shell:
"{config[vltsim]} {input[0]}"
"simple_copy_traces.json",


rule compile_main:
Expand Down
15 changes: 7 additions & 8 deletions kernels/streamer_alu/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,23 @@ config["snaxoptflags"] = ",".join(
)


module default_rules:
module snax_rules:
snakefile:
"../../util/snake/default_rules.smk"
"../../util/snake/snax.smk"
config:
config


use rule * from default_rules as default_*
use rule * from snax_rules as snax_*


files = ["streamer_add", "streamer_add_stream"]


# Rules
rule all:
input:
"streamer_add.x",
"streamer_add_stream.x",
run:
for item in input:
shell("{config[vltsim]} {item}")
expand("{file}_traces.json", file=files),


from gendata import create_data_files
Expand Down
15 changes: 7 additions & 8 deletions kernels/streamer_matmul/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,23 @@ config["mlirtransformflags"] = [
]


module default_rules:
module snax_rules:
snakefile:
"../../util/snake/default_rules.smk"
"../../util/snake/snax.smk"
config:
config


use rule * from default_rules as default_*
use rule * from snax_rules as snax_*


files = ["quantized_matmul", "tiled_quantized_matmul"]


# Rules
rule all:
input:
"quantized_matmul.x",
"tiled_quantized_matmul.x",
run:
for item in input:
shell("{config[vltsim]} {item}")
expand("{file}_traces.json", file=files),


rule generate_quantized_matmul:
Expand Down
11 changes: 4 additions & 7 deletions kernels/tiled_add/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ def get_snax_opt_flags(options):
)


module default_rules:
module snax_rules:
snakefile:
"../../util/snake/default_rules.smk"
"../../util/snake/snax.smk"
config:
config


use rule * from default_rules as default_*
use rule * from snax_rules as snax_*


rule size_mlir:
Expand Down Expand Up @@ -123,15 +123,12 @@ COMPILER_OPTS = ["accfgboth", "noaccfgopt"]
rule all:
input:
expand(
"{file}_{array_size}_{tile_size}_{compiler_opt}.x",
"{file}_{array_size}_{tile_size}_{compiler_opt}_traces.json",
file=FILES,
array_size=ARRAY_SIZES,
tile_size=TILE_SIZES,
compiler_opt=COMPILER_OPTS,
),
run:
for item in input:
shell("{config[vltsim]} {item}")


from gendata import generate_data
Expand Down
16 changes: 6 additions & 10 deletions kernels/transform_copy/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,29 +18,25 @@ config["snaxoptflags"] = ",".join(
)


module default_rules:
module snax_rules:
snakefile:
"../../util/snake/default_rules.smk"
"../../util/snake/snax.smk"
config:
config


use rule * from default_rules as default_*
use rule * from snax_rules as snax_*


from gendata import create_files

files = ["transform_copy", "transform_from_none", "transform_from_strided"]


# Rules
rule all:
input:
"transform_copy.x",
"transform_from_none.x",
"transform_from_strided.x",
run:
shell("{config[vltsim]} {input[0]} {input[0]}")
shell("{config[vltsim]} {input[1]} {input[1]}")
shell("{config[vltsim]} {input[2]} {input[2]}")
expand("{file}_traces.json", file=files),


rule generate_data:
Expand Down
11 changes: 10 additions & 1 deletion util/snake/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
get_clang_flags,
get_default_flags,
)
from util.snake.paths import get_default_paths
from util.snake.paths import get_default_paths, get_default_snax_paths


def get_snax_mac_config() -> dict[str, Any]:
Expand All @@ -14,7 +14,10 @@ def get_snax_mac_config() -> dict[str, Any]:
snitch_sw_path = snax_utils_path + "/snax-mac"
config: dict[str, Any] = {}
config.update(get_default_paths())
config.update(get_default_snax_paths())
config.update(get_default_flags(snitch_sw_path))
config["num_chips"] = 1
config["num_harts"] = 2
config["vltsim"] = f"{snax_utils_path}/snax-mac-rtl/bin/snitch_cluster.vlt"
config["cflags"].append(
f"-I{snitch_sw_path}/target/snitch_cluster/sw/snax/mac/include"
Expand All @@ -31,7 +34,10 @@ def get_snax_gemmx_config() -> dict[str, Any]:
snitch_sw_path = snax_utils_path + "/snax-kul-cluster-mixed-narrow-wide"
config: dict[str, Any] = {}
config.update(get_default_paths())
config.update(get_default_snax_paths())
config.update(get_default_flags(snitch_sw_path))
config["num_chips"] = 1
config["num_harts"] = 2
config["vltsim"] = (
snax_utils_path
+ "/snax-kul-cluster-mixed-narrow-wide-rtl/bin/snitch_cluster.vlt"
Expand All @@ -45,7 +51,10 @@ def get_snax_alu_config() -> dict[str, Any]:
snitch_sw_path = snax_utils_path + "/snax-alu"
config: dict[str, Any] = {}
config.update(get_default_paths())
config.update(get_default_snax_paths())
config.update(get_default_flags(snitch_sw_path))
config["num_chips"] = 1
config["num_harts"] = 2
config["vltsim"] = snax_utils_path + "/snax-alu-rtl/bin/snitch_cluster.vlt"
return config

Expand Down
13 changes: 13 additions & 0 deletions util/snake/paths.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os


def get_default_paths() -> dict[str, str]:
return {
"cc": "clang",
Expand All @@ -6,3 +9,13 @@ def get_default_paths() -> dict[str, str]:
"mlir-translate": "mlir-translate",
"snax-opt": "snax-opt",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this not snax-specific?

}


def get_default_snax_paths() -> dict[str, str]:
# use CONDA_PREFIX to access pixi env
gen_trace_path = f"{os.environ['CONDA_PREFIX']}/snax-utils/gen_trace.py"
return {
"python": "python",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this snax-specific?

"spike-dasm": "spike-dasm",
"gen_trace.py": gen_trace_path,
}
61 changes: 61 additions & 0 deletions util/snake/snax.smk
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from util.tracing.merge_json import merge_json


# All snax compliant configs also use the default rules
module default_rules:
snakefile:
"default_rules.smk"
config:
config


use rule * from default_rules as default_*


rule simulate:
input:
"{file}.x",
output:
temp(
expand(
"{file}_trace_chip_{num_chips:02d}_hart_{num_harts:05d}.dasm",
file=["{file}"],
num_chips=range(config["num_chips"]),
num_harts=range(config["num_harts"]),
),
),
log:
"{file}.vltlog",
shell:
"{config[vltsim]} --prefix-trace={wildcards.file}_ {wildcards.file}.x 2>&1 | tee {log}"


rule trace_dasm:
"""
Use spike-dasm and gen_trace.py to make simulation traces human-readable
and aggregate stats for a specific hart's trace.
"""
input:
"{file}.dasm",
output:
temp("{file}_perf.json"),
temp("{file}.txt"),
shell:
"{config[spike-dasm]} < {input} | {config[python]} {config[gen_trace.py]} --permissive -d {output[0]} > {output[1]}"


rule aggregate_json:
"""
Aggregate traced stats for across chips and hart traces.
"""
input:
expand(
"{file}_trace_chip_{num_chips:02d}_hart_{num_harts:05d}_perf.json",
file=["{file}"],
num_chips=range(config["num_chips"]),
num_harts=range(config["num_harts"]),
),
output:
temp("{file}_traces.json"),
run:
merge_json(input, output[0])
Loading
Loading