diff --git a/.github/workflows/ci-llama.yaml b/.github/workflows/ci-llama.yaml index e1d79694a..999c3fbcc 100644 --- a/.github/workflows/ci-llama.yaml +++ b/.github/workflows/ci-llama.yaml @@ -8,6 +8,7 @@ name: Llama Benchmarking Tests on: workflow_dispatch: + pull_request: schedule: # Weekdays at 5:00 AM UTC = 10:00 PM PST. - cron: "0 5 * * 1-5" @@ -75,7 +76,7 @@ jobs: "numpy<2.0" - name: Run llama test - run: pytest sharktank/tests/models/llama/benchmark_amdgpu_tests.py -v -s --longrun + run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --longrun --iree-hip-target=gfx942 - name: Upload llama executable files uses: actions/upload-artifact@v4 diff --git a/sharktank/conftest.py b/sharktank/conftest.py index 2076c39eb..afbc93e46 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -201,13 +201,6 @@ def caching(request: FixtureRequest) -> Optional[bool]: return set_fixture_from_cli_option(request, "caching") -@pytest.fixture(scope="class") -def iree_hip_target_type(request: FixtureRequest) -> Optional[str]: - return set_fixture_from_cli_option( - request, "iree_hip_target", "iree_hip_target_type" - ) - - @pytest.fixture(scope="class") def tensor_parallelism_size(request: FixtureRequest) -> Optional[str]: return set_fixture_from_cli_option( diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 0436c0008..40cfea94f 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -57,7 +57,7 @@ def main(): "--attention-kernel", type=str, default="decomposed", - choices=["decomposed", "torch_sdpa"], + choices=["decomposed", "torch"], ) args = cli.parse(parser) diff --git a/sharktank/sharktank/models/llama/tools/shard_llama.py b/sharktank/sharktank/models/llama/tools/shard_llama.py index 9ac877d6f..bd1ffa696 100644 --- a/sharktank/sharktank/models/llama/tools/shard_llama.py +++ b/sharktank/sharktank/models/llama/tools/shard_llama.py @@ -33,12 +33,10 @@ def main(): dataset = cli.get_input_dataset(args) if args.output_file is None: - print(f"Need file destination for IRPA file") - return + raise RuntimeError(f"Need file destination for IRPA file") if args.shard_count < 2: - print(f"Expect sharding greater than 1 found {args.shard_count}") - return + raise RuntimeError(f"Expect sharding greater than 1 found {args.shard_count}") hp = configs.LlamaHParams.from_gguf_props(dataset.properties) llama_config = LlamaModelConfig(hp) diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index b7e7bb2d4..84c206d7a 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -5,11 +5,13 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import os +import sys import subprocess import logging import time from pathlib import Path from datetime import timedelta +from typing import List import iree.compiler as ireec @@ -25,6 +27,7 @@ class ExportArtifacts: def __init__( self, + *, irpa_path: str, batch_size: int, iree_hip_target: str, @@ -59,9 +62,44 @@ def wrapper(*args, **kwargs): return wrapper + @timeit + def shard_irpa_file( + self, + *, + output_file: str, + ): + shard_irpa_args = [ + "python3", + "-m", + "sharktank.models.llama.tools.shard_llama", + "--irpa-file", + self.irpa_path, + "--output-file", + output_file, + "--shard_count", + str(self.tensor_parallelism_size), + ] + + cwd = self.sharktank_dir + cmd = subprocess.list2cmdline(shard_irpa_args) + + logger.info(f"Sharding irpa file:\n" f"cd {cwd} && {cmd}") + + proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd) + if proc.returncode != 0: + logger.error( + f"Error sharding irpa file with shard_llama.py\n" + f"{proc.stdout+proc.stderr}" + ) + else: + logger.info(f"Sharded irpa file successfully:\n" f"{proc.stdout}") + + return proc.returncode + @timeit def export_to_mlir( self, + *, mlir_path: str, json_path: str, ): @@ -78,18 +116,16 @@ def export_to_mlir( "--bs", str(self.batch_size), ] - if self.attention_kernel == "decomposed": + if self.attention_kernel in ["decomposed", "torch"]: export_args.append("--attention-kernel") export_args.append(self.attention_kernel) - elif self.attention_kernel == "torch_sdpa": - raise NotImplementedError("attention_kernel torch_sdpa not implemented yet") cwd = self.sharktank_dir cmd = subprocess.list2cmdline(export_args) logger.info(f"Exporting mlir:\n" f"cd {cwd} && {cmd}") - proc = subprocess.run(export_args, capture_output=True, cwd=cwd, text=True) + proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd) if proc.returncode != 0: logger.error( f"Error exporting mlir with export_paged_llm_v1.py\n" @@ -103,12 +139,14 @@ def export_to_mlir( @timeit def compile_to_vmfb( self, + *, mlir_path, vmfb_path, + hal_dump_path, ): # TODO: Control flag to enable multiple backends compile_flags = ["--iree-hip-target=" + self.iree_hip_target] - + compile_flags += [f"--iree-hal-dump-executable-files-to={hal_dump_path}/files"] try: ireec.compile_file( input_file=mlir_path, @@ -121,7 +159,43 @@ def compile_to_vmfb( else: logger.info(f"Compiled to vmfb successfully:\n" f"{vmfb_path}") - def create_file(self, suffix, prefix): + def iree_benchmark_vmfb( + self, + *, + hip_device_id: str, + vmfb_name: str, + irpa_path: str, + args: List[str], + cwd: str | Path, + ): + """Runs a compiled program with the given args using `iree-benchmark-module`. + This assumes that the `iree-benchmark-module` command is available (usually via PATH). + Args: + vmfb_name: Name of the .vmfb file (relative to `cwd`). + args: List of arguments to pass to `iree-benchmark-module`. + cwd: Working directory to run the command within. (either string or Path works) + compile_cmd: Command used to compile the program, for inclusion in error messages. + Raises Exception if running fails for some reason. + """ + benchmark_args = [ + f"ROCR_VISIBLE_DEVICES={hip_device_id}", + "iree-benchmark-module", + f"--device=hip://{hip_device_id}", + "--hip_use_streams=true", + "--hip_allow_inline_execution=true", + "--device_allocator=caching", + f"--module={vmfb_name}", + f"--parameters=model={irpa_path}", + ] + benchmark_args += args + cmd = subprocess.list2cmdline(benchmark_args) + logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}") + proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd) + return_code = proc.returncode + if return_code != 0: + raise RuntimeError(f"Error running benchmark {cmd} in cwd {cwd}") + + def create_file(self, *, suffix, prefix): file_path = Path(prefix).with_suffix(suffix) f = open(file_path, "w") return file_path diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_tests.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py similarity index 73% rename from sharktank/tests/models/llama/benchmark_amdgpu_tests.py rename to sharktank/tests/models/llama/benchmark_amdgpu_test.py index 174fcbe87..c99bbc7e1 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_tests.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -13,76 +13,13 @@ import subprocess from pathlib import Path from typing import List +from sharktank.utils.export_artifacts import ExportArtifacts longrun = pytest.mark.skipif("not config.getoption('longrun')") is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'") -class ExportMlirException(Exception): - """SHARK-Platform export MLIR exception that preserves the command line and error output.""" - - def __init__(self, process: subprocess.CompletedProcess, cwd: str): - try: - errs = process.stderr.decode("utf-8") - except: - errs = str(process.stderr) - - super().__init__( - f"Error invoking export_paged_llama_v1.py\n" - f"Error code: {process.returncode}\n" - f"Stderr diagnostics:\n{errs}\n\n" - f"Invoked with:\n" - f" cd {cwd} && {process.args}\n\n" - ) - - -class IreeCompileException(Exception): - """Compiler exception that preserves the command line and error output.""" - - def __init__(self, process: subprocess.CompletedProcess, cwd: str): - try: - errs = process.stderr.decode("utf-8") - except: - errs = str(process.stderr) - - super().__init__( - f"Error invoking iree-compile\n" - f"Error code: {process.returncode}\n" - f"Stderr diagnostics:\n{errs}\n\n" - f"Invoked with:\n" - f" cd {cwd} && {process.args}\n\n" - ) - - -class IreeBenchmarkException(Exception): - """Runtime exception that preserves the command line and error output.""" - - def __init__( - self, process: subprocess.CompletedProcess, cwd: str, compile_cmd: str - ): - # iree-run-module sends output to both stdout and stderr - try: - errs = process.stderr.decode("utf-8") - except: - errs = str(process.stderr) - try: - outs = process.stdout.decode("utf-8") - except: - outs = str(process.stdout) - - super().__init__( - f"Error invoking iree-benchmark-module\n" - f"Error code: {process.returncode}\n" - f"Stderr diagnostics:\n{errs}\n" - f"Stdout diagnostics:\n{outs}\n" - f"Compiled with:\n" - f" cd {cwd} && {compile_cmd}\n\n" - f"Run with:\n" - f" cd {cwd} && {process.args}\n\n" - ) - - -@pytest.mark.usefixtures("iree_hip_target_type") +@pytest.mark.usefixtures("get_iree_flags") class BaseBenchmarkTest(unittest.TestCase): directory_created = False current_date = datetime.now() @@ -105,167 +42,34 @@ def setUpClass(cls): def setUp(self): self.hip_device_id = os.getenv("HIP_DEVICE_ID", default="0") - def create_file(self, *, suffix, prefix): - file_path = Path(prefix).with_suffix(suffix) - f = open(file_path, "w") - return file_path - - def get_export_cmd( - self, - *, - attention_kernel: str, - tensor_parallelism_size: int, - irpa_path: str, - output_mlir_path: str, - output_json_path: str, - ): - export_args = [ - "python3", - "-m", - "sharktank.examples.export_paged_llm_v1", - "--irpa-file", - irpa_path, - "--output-mlir", - output_mlir_path, - "--output-config", - output_json_path, - ] - if attention_kernel == "decomposed": - export_args.append("--attention-kernel") - export_args.append(attention_kernel) - elif attention_kernel == "torch_sdpa": - raise NotImplementedError( - "attention_kernel torch_sdpa not yet plumbed through" - ) - if tensor_parallelism_size: - export_args.append("--tensor-parallelism-size") - export_args.append(str(tensor_parallelism_size)) - - cmd = subprocess.list2cmdline(export_args) - return cmd - - def get_compile_cmd( - self, *, output_mlir_path: str, output_vmfb_path: str, args: [str] - ): - compile_args = ["iree-compile", output_mlir_path] - compile_args += args - compile_args += ["-o", output_vmfb_path] - cmd = subprocess.list2cmdline(compile_args) - return cmd - - def export_mlir( - self, - *, - attention_kernel: str, - tensor_parallelism_size: int, - irpa_path: str, - output_mlir_path: str, - output_json_path: str, - cwd: str | Path, - ): - """Runs export_paged_llm_v1.py and exports an MLIR file. - Args: - irpa_path: Path to the model irpa file. - output_mlir_path: Path to the file to save the exported file. - output_json_path: Path to the file to save the config json file. - """ - cmd = self.get_export_cmd( - attention_kernel=attention_kernel, - tensor_parallelism_size=tensor_parallelism_size, - irpa_path=irpa_path, - output_mlir_path=output_mlir_path, - output_json_path=output_json_path, - ) - logging.getLogger().info(f"Launching export command:\n" f"cd {cwd} && {cmd}") - proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd) - return_code = proc.returncode - if return_code != 0: - raise ExportMlirException(proc, cwd) - - def iree_compile( - self, - *, - mlir_path: str, - output_vmfb_path: str, - args: List[str], - cwd: str | Path, - ): - """Compiles an input MLIR file to an output .vmfb file. - This assumes that the `iree-compile` command is available (usually via PATH). - Args: - mlir_path: Path to the input MLIR file. - output_vmfb_path: Path for the output .vmfb file. The directory must already exist. - args: List of arguments to pass to `iree-compile`. - cwd: current working directory - Raises Exception if compilation fails for some reason. - """ - cmd = self.get_compile_cmd( - output_mlir_path=mlir_path, - output_vmfb_path=output_vmfb_path, - args=args, - ) - logging.getLogger().info(f"Launching compile command:\n" f"cd {cwd} && {cmd}") - proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd) - return_code = proc.returncode - if return_code != 0: - raise IreeCompileException(proc, cwd) - - def iree_benchmark_module( - self, - *, - hip_device_id: str, - vmfb_name: str, - irpa_path: str, - args: List[str], - cwd: str | Path, - ): - """Runs a compiled program with the given args using `iree-benchmark-module`. - This assumes that the `iree-benchmark-module` command is available (usually via PATH). - Args: - vmfb_name: Name of the .vmfb file (relative to `cwd`). - args: List of arguments to pass to `iree-benchmark-module`. - cwd: Working directory to run the command within. (either string or Path works) - compile_cmd: Command used to compile the program, for inclusion in error messages. - Raises Exception if running fails for some reason. - """ - benchmark_args = [ - f"ROCR_VISIBLE_DEVICES={hip_device_id}", - "iree-benchmark-module", - f"--device=hip://{hip_device_id}", - "--hip_use_streams=true", - "--hip_allow_inline_execution=true", - "--device_allocator=caching", - f"--module={vmfb_name}", - f"--parameters=model={irpa_path}", - ] - benchmark_args += args - cmd = subprocess.list2cmdline(benchmark_args) - logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}") - proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd) - return_code = proc.returncode - if return_code != 0: - raise IreeBenchmarkException(proc, cwd, cmd) - class BenchmarkLlama3_1_8B(BaseBenchmarkTest): def setUp(self): super().setUp() # TODO: add numpy files to Azure and download from it - artifacts_dir = Path("/data/llama-3.1/8b") - self.irpa_path = artifacts_dir / "llama8b_f16.irpa" - self.irpa_path_fp8 = artifacts_dir / "llama8b_fp8.irpa" + self.artifacts_dir = Path("/data/llama-3.1/8b") + self.irpa_path = self.artifacts_dir / "llama8b_f16.irpa" + self.irpa_path_fp8 = self.artifacts_dir / "llama8b_fp8.irpa" self.tensor_parallelism_size = 1 self.dir_path_8b = self.dir_path / "llama-8b" self.temp_dir_8b = Path(self.dir_path_8b) self.temp_dir_8b.mkdir(parents=True, exist_ok=True) + self.llama8b_f16_artifacts = ExportArtifacts( + irpa_path=str(self.irpa_path), + batch_size=4, + iree_hip_target="gfx942", + iree_hal_target_backends="rocm", + attention_kernel="decomposed", + tensor_parallelism_size=self.tensor_parallelism_size, + ) self.iree_compile_args = [ "--iree-hal-target-backends=rocm", - f"--iree-hip-target={self.iree_hip_target_type}", + f"--iree-hip-target={self.iree_hip_target}", ] - self.prefill_args_f16 = artifacts_dir / "prefill_args" - self.decode_args_f16 = artifacts_dir / "decode_args" - self.prefill_args_fp8 = artifacts_dir / "prefill_args_fp8" - self.decode_args_fp8 = artifacts_dir / "decode_args_fp8" + self.prefill_args_f16 = self.artifacts_dir / "prefill_args" + self.decode_args_f16 = self.artifacts_dir / "decode_args" + self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8" + self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8" self.iree_run_prefill_args = [ "--function=prefill_bs4", f"--input=@{self.prefill_args_f16}/tokens.npy", @@ -305,28 +109,36 @@ def setUp(self): @is_mi300x def testBenchmark8B_f16_Decomposed(self): output_file_name = self.dir_path_8b / "f16_decomposed" - output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) - output_json = self.create_file(suffix=".json", prefix=output_file_name) - output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) - self.export_mlir( - attention_kernel="decomposed", - tensor_parallelism_size=self.tensor_parallelism_size, - irpa_path=self.irpa_path, - output_mlir_path=output_mlir, - output_json_path=output_json, - cwd=self.repo_root, + output_mlir = self.llama8b_f16_artifacts.create_file( + suffix=".mlir", prefix=output_file_name ) - iree_compile_args = self.iree_compile_args + [ - f"--iree-hal-dump-executable-files-to={output_file_name}/files" - ] - self.iree_compile( + output_json = self.llama8b_f16_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama8b_f16_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + output_shard_file_name = str( + self.artifacts_dir + / f"llama3.1_8b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa" + ) + # shard_irpa file + shard_return_code = self.llama8b_f16_artifacts.shard_irpa_file( + output_file=output_shard_file_name + ) + if shard_return_code == 0: + self.irpa_path = output_shard_file_name + export_return_code = self.llama8b_f16_artifacts.export_to_mlir( mlir_path=output_mlir, - output_vmfb_path=output_vmfb, - args=iree_compile_args, - cwd=self.repo_root, + json_path=output_json, + ) + self.llama8b_f16_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, ) # benchmark prefill - self.iree_benchmark_module( + self.llama8b_f16_artifacts.iree_benchmark_vmfb( hip_device_id=self.hip_device_id, vmfb_name=output_vmfb, irpa_path=self.irpa_path, @@ -334,7 +146,7 @@ def testBenchmark8B_f16_Decomposed(self): cwd=self.repo_root, ) # benchmark decode - self.iree_benchmark_module( + self.llama8b_f16_artifacts.iree_benchmark_vmfb( hip_device_id=self.hip_device_id, vmfb_name=output_vmfb, irpa_path=self.irpa_path, @@ -346,29 +158,38 @@ def testBenchmark8B_f16_Decomposed(self): @is_mi300x @pytest.mark.xfail(reason="torch_sdpa not yet plumbed through", strict=True) def testBenchmark8B_f16_Non_Decomposed(self): - output_file_name = self.dir_path_8b / "f16_torch_sdpa" - output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name) - output_json = self.create_file(suffix=".json", prefix=output_file_name) - output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name) - self.export_mlir( - attention_kernel="torch_sdpa", - tensor_parallelism_size=self.tensor_parallelism_size, - irpa_path=self.irpa_path, - output_mlir_path=output_mlir, - output_json_path=output_json, - cwd=self.repo_root, + output_file_name = self.dir_path_8b / "f16_torch" + output_mlir = self.llama8b_f16_artifacts.create_file( + suffix=".mlir", prefix=output_file_name ) - iree_compile_args = self.iree_compile_args + [ - f"--iree-hal-dump-executable-files-to={output_file_name}/files" - ] - self.iree_compile( + output_json = self.llama8b_f16_artifacts.create_file( + suffix=".json", prefix=output_file_name + ) + output_vmfb = self.llama8b_f16_artifacts.create_file( + suffix=".vmfb", prefix=output_file_name + ) + self.llama8b_f16_artifacts.attention_kernel = "torch" + output_shard_file_name = str( + self.artifacts_dir + / f"llama3.1_8b_fp16_tp{self.tensor_parallelism_size}_parameters_torch_sdpa.irpa" + ) + # shard_irpa file + shard_return_code = self.llama8b_f16_artifacts.shard_irpa_file( + output_file=output_shard_file_name + ) + if shard_return_code == 0: + self.irpa_path = output_shard_file_name + export_return_code = self.llama8b_f16_artifacts.export_to_mlir( mlir_path=output_mlir, - output_vmfb_path=output_vmfb, - args=iree_compile_args, - cwd=self.repo_root, + json_path=output_json, + ) + self.llama8b_f16_artifacts.compile_to_vmfb( + mlir_path=str(output_mlir), + vmfb_path=output_vmfb, + hal_dump_path=output_file_name, ) # benchmark prefill - self.iree_benchmark_module( + self.llama8b_f16_artifacts.iree_benchmark_vmfb( hip_device_id=self.hip_device_id, vmfb_name=output_vmfb, irpa_path=self.irpa_path, @@ -376,7 +197,7 @@ def testBenchmark8B_f16_Non_Decomposed(self): cwd=self.repo_root, ) # benchmark decode - self.iree_benchmark_module( + self.llama8b_f16_artifacts.iree_benchmark_vmfb( hip_device_id=self.hip_device_id, vmfb_name=output_vmfb, irpa_path=self.irpa_path, @@ -482,7 +303,7 @@ def setUp(self): self.temp_dir_70b.mkdir(parents=True, exist_ok=True) self.iree_compile_args = [ "--iree-hal-target-backends=rocm", - f"--iree-hip-target={self.iree_hip_target_type}", + f"--iree-hip-target={self.iree_hip_target}", ] self.prefill_args_f16 = artifacts_dir / "prefill_args" self.decode_args_f16 = artifacts_dir / "decode_args" @@ -705,7 +526,7 @@ def setUp(self): self.temp_dir_405b.mkdir(parents=True, exist_ok=True) self.iree_compile_args = [ "--iree-hal-target-backends=rocm", - f"--iree-hip-target={self.iree_hip_target_type}", + f"--iree-hip-target={self.iree_hip_target}", ] self.prefill_args_f16 = artifacts_dir / "prefill_args" self.decode_args_f16 = artifacts_dir / "decode_args"