diff --git a/.github/workflows/ci-llama.yaml b/.github/workflows/ci-llama.yaml
index e1d79694a..999c3fbcc 100644
--- a/.github/workflows/ci-llama.yaml
+++ b/.github/workflows/ci-llama.yaml
@@ -8,6 +8,7 @@ name: Llama Benchmarking Tests
 on:
   workflow_dispatch:
+  pull_request:
   schedule:
     # Weekdays at 5:00 AM UTC = 10:00 PM PST.
     - cron: "0 5 * * 1-5"
@@ -75,7 +76,7 @@ jobs:
           "numpy<2.0"

       - name: Run llama test
-        run: pytest sharktank/tests/models/llama/benchmark_amdgpu_tests.py -v -s --longrun
+        run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --longrun --iree-hip-target=gfx942

       - name: Upload llama executable files
         uses: actions/upload-artifact@v4
diff --git a/.github/workflows/ci-sdxl.yaml b/.github/workflows/ci-sdxl.yaml
index b86c7dc1e..17dc5abec 100644
--- a/.github/workflows/ci-sdxl.yaml
+++ b/.github/workflows/ci-sdxl.yaml
@@ -99,4 +99,4 @@ jobs:
         working-directory: ${{ env.LIBSHORTFIN_DIR }}
         run: |
           ctest --timeout 30 --output-on-failure --test-dir build
-          pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu
+          HIP_VISIBLE_DEVICES=0 pytest tests/apps/sd/e2e_test.py -v -s --system=amdgpu
diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml
index 326dd3dad..237225b2a 100644
--- a/.github/workflows/ci_linux_x64-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64-libshortfin.yml
@@ -38,6 +38,9 @@ jobs:
   build-and-test:
     name: Build and test
    runs-on: ubuntu-24.04
+    strategy:
+      matrix:
+        python-version: ["3.11", "3.12"]

     steps:
     - name: Install dependencies
@@ -67,10 +70,10 @@ jobs:
        git submodule update --init --depth 1 -- third_party/googletest
        git submodule update --init --depth 1 -- third_party/hip-build-deps/

-    - name: Setup Python
+    - name: Setup Python ${{ matrix.python-version }}
       uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
       with:
-        python-version: "3.12"
+        python-version: ${{ matrix.python-version }}
         cache: "pip"
     - name: Install Python packages
       # TODO: Switch to `pip install -r requirements.txt -e shortfin/`.
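
The new `--iree-hip-target=gfx942` flag on the pytest command implies the option is registered and read in `sharktank/conftest.py`, which is exactly the plumbing the fixture cleanup below touches. A minimal, hypothetical sketch of how such an option is usually wired up with pytest; only the `iree_hip_target` name comes from the diff, everything else is illustrative:

```python
# Hypothetical conftest.py wiring for a --iree-hip-target CLI option; names other
# than iree_hip_target are illustrative, not taken verbatim from the repository.
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--iree-hip-target",
        action="store",
        default="gfx942",
        help="Target HIP architecture passed through to iree-compile.",
    )


@pytest.fixture(scope="class")
def iree_hip_target(request):
    # Expose the CLI value to test classes, similar to the get_iree_flags fixture used below.
    return request.config.getoption("--iree-hip-target")
```
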
diff --git a/sharktank/conftest.py b/sharktank/conftest.py
index 2076c39eb..afbc93e46 100644
--- a/sharktank/conftest.py
+++ b/sharktank/conftest.py
@@ -201,13 +201,6 @@ def caching(request: FixtureRequest) -> Optional[bool]:
     return set_fixture_from_cli_option(request, "caching")


-@pytest.fixture(scope="class")
-def iree_hip_target_type(request: FixtureRequest) -> Optional[str]:
-    return set_fixture_from_cli_option(
-        request, "iree_hip_target", "iree_hip_target_type"
-    )
-
-
 @pytest.fixture(scope="class")
 def tensor_parallelism_size(request: FixtureRequest) -> Optional[str]:
     return set_fixture_from_cli_option(
diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 0436c0008..40cfea94f 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -57,7 +57,7 @@ def main():
         "--attention-kernel",
         type=str,
         default="decomposed",
-        choices=["decomposed", "torch_sdpa"],
+        choices=["decomposed", "torch"],
     )

     args = cli.parse(parser)
diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index ef3c4800d..656b4432b 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -269,6 +269,7 @@ def decode(
         for block_idx, block in enumerate(self.attn_blocks):
             if block_idx == 0:
                 self.trace_tensor(f"llama.attn_block.{block_idx}.input", h)
+            block.attn.attention_kernel = "decomposed"
             h = block(
                 h,
                 start_positions=start_positions,
diff --git a/sharktank/sharktank/models/llama/tools/shard_llama.py b/sharktank/sharktank/models/llama/tools/shard_llama.py
index 9ac877d6f..bd1ffa696 100644
--- a/sharktank/sharktank/models/llama/tools/shard_llama.py
+++ b/sharktank/sharktank/models/llama/tools/shard_llama.py
@@ -33,12 +33,10 @@ def main():
     dataset = cli.get_input_dataset(args)

     if args.output_file is None:
-        print(f"Need file destination for IRPA file")
-        return
+        raise RuntimeError(f"Need file destination for IRPA file")

     if args.shard_count < 2:
-        print(f"Expect sharding greater than 1 found {args.shard_count}")
-        return
+        raise RuntimeError(f"Expect sharding greater than 1 found {args.shard_count}")

     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
     llama_config = LlamaModelConfig(hp)
diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py
index b7e7bb2d4..84c206d7a 100644
--- a/sharktank/sharktank/utils/export_artifacts.py
+++ b/sharktank/sharktank/utils/export_artifacts.py
@@ -5,11 +5,13 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 import os
+import sys
 import subprocess
 import logging
 import time
 from pathlib import Path
 from datetime import timedelta
+from typing import List

 import iree.compiler as ireec
@@ -25,6 +27,7 @@ class ExportArtifacts:

     def __init__(
         self,
+        *,
         irpa_path: str,
         batch_size: int,
         iree_hip_target: str,
@@ -59,9 +62,44 @@ def wrapper(*args, **kwargs):

         return wrapper

+    @timeit
+    def shard_irpa_file(
+        self,
+        *,
+        output_file: str,
+    ):
+        shard_irpa_args = [
+            "python3",
+            "-m",
+            "sharktank.models.llama.tools.shard_llama",
+            "--irpa-file",
+            self.irpa_path,
+            "--output-file",
+            output_file,
+            "--shard_count",
+            str(self.tensor_parallelism_size),
+        ]
+
+        cwd = self.sharktank_dir
+        cmd = subprocess.list2cmdline(shard_irpa_args)
+
+        logger.info(f"Sharding irpa file:\n" f"cd {cwd} && {cmd}")
+
+        proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd)
+        if proc.returncode != 0:
+            logger.error(
+                f"Error sharding irpa file with shard_llama.py\n"
+                f"{proc.stdout+proc.stderr}"
+            )
+        else:
+            logger.info(f"Sharded irpa file successfully:\n" f"{proc.stdout}")
+
+        return proc.returncode
+
     @timeit
     def export_to_mlir(
         self,
+        *,
         mlir_path: str,
         json_path: str,
     ):
@@ -78,18 +116,16 @@ def export_to_mlir(
             "--bs",
             str(self.batch_size),
         ]
-        if self.attention_kernel == "decomposed":
+        if self.attention_kernel in ["decomposed", "torch"]:
             export_args.append("--attention-kernel")
             export_args.append(self.attention_kernel)
-        elif self.attention_kernel == "torch_sdpa":
-            raise NotImplementedError("attention_kernel torch_sdpa not implemented yet")

         cwd = self.sharktank_dir
         cmd = subprocess.list2cmdline(export_args)

         logger.info(f"Exporting mlir:\n" f"cd {cwd} && {cmd}")

-        proc = subprocess.run(export_args, capture_output=True, cwd=cwd, text=True)
+        proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd)
         if proc.returncode != 0:
             logger.error(
                 f"Error exporting mlir with export_paged_llm_v1.py\n"
@@ -103,12 +139,14 @@ def export_to_mlir(
     @timeit
     def compile_to_vmfb(
         self,
+        *,
         mlir_path,
         vmfb_path,
+        hal_dump_path,
     ):
         # TODO: Control flag to enable multiple backends
         compile_flags = ["--iree-hip-target=" + self.iree_hip_target]
-
+        compile_flags += [f"--iree-hal-dump-executable-files-to={hal_dump_path}/files"]
         try:
             ireec.compile_file(
                 input_file=mlir_path,
@@ -121,7 +159,43 @@ def compile_to_vmfb(
         else:
             logger.info(f"Compiled to vmfb successfully:\n" f"{vmfb_path}")

-    def create_file(self, suffix, prefix):
+    def iree_benchmark_vmfb(
+        self,
+        *,
+        hip_device_id: str,
+        vmfb_name: str,
+        irpa_path: str,
+        args: List[str],
+        cwd: str | Path,
+    ):
+        """Runs a compiled program with the given args using `iree-benchmark-module`.
+        This assumes that the `iree-benchmark-module` command is available (usually via PATH).
+        Args:
+            hip_device_id: HIP device index; exported as ROCR_VISIBLE_DEVICES for the run.
+            vmfb_name: Name of the .vmfb file (relative to `cwd`).
+            irpa_path: Path to the parameter archive passed as `--parameters=model=...`.
+            args: List of arguments to pass to `iree-benchmark-module`.
+            cwd: Working directory to run the command within. (either string or Path works)
+        Raises Exception if running fails for some reason.
+        """
+        benchmark_args = [
+            f"ROCR_VISIBLE_DEVICES={hip_device_id}",
+            "iree-benchmark-module",
+            f"--device=hip://{hip_device_id}",
+            "--hip_use_streams=true",
+            "--hip_allow_inline_execution=true",
+            "--device_allocator=caching",
+            f"--module={vmfb_name}",
+            f"--parameters=model={irpa_path}",
+        ]
+        benchmark_args += args
+        cmd = subprocess.list2cmdline(benchmark_args)
+        logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}")
+        proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd)
+        return_code = proc.returncode
+        if return_code != 0:
+            raise RuntimeError(f"Error running benchmark {cmd} in cwd {cwd}")
+
+    def create_file(self, *, suffix, prefix):
         file_path = Path(prefix).with_suffix(suffix)
         f = open(file_path, "w")
         return file_path
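
Taken together, the new `shard_irpa_file`, `export_to_mlir`, `compile_to_vmfb`, and `iree_benchmark_vmfb` helpers are meant to be driven as one pipeline, which is how the renamed benchmark test below uses them. A hedged usage sketch; the paths, batch size, and target values are placeholders, not values mandated by the change:

```python
# Assumed end-to-end ExportArtifacts flow; every literal path and value below is a placeholder.
from sharktank.utils.export_artifacts import ExportArtifacts

artifacts = ExportArtifacts(
    irpa_path="/data/llama-3.1/8b/llama8b_f16.irpa",
    batch_size=4,
    iree_hip_target="gfx942",
    iree_hal_target_backends="rocm",
    attention_kernel="decomposed",
    tensor_parallelism_size=1,
)

irpa_path = "/data/llama-3.1/8b/llama8b_f16.irpa"
# Shard the parameter archive; fall back to the unsharded file if sharding fails.
if artifacts.shard_irpa_file(output_file="/tmp/llama8b_tp1.irpa") == 0:
    irpa_path = "/tmp/llama8b_tp1.irpa"

artifacts.export_to_mlir(mlir_path="/tmp/llama8b.mlir", json_path="/tmp/llama8b.json")
artifacts.compile_to_vmfb(
    mlir_path="/tmp/llama8b.mlir",
    vmfb_path="/tmp/llama8b.vmfb",
    hal_dump_path="/tmp/llama8b_dump",
)
artifacts.iree_benchmark_vmfb(
    hip_device_id="0",
    vmfb_name="/tmp/llama8b.vmfb",
    irpa_path=irpa_path,
    args=["--function=prefill_bs4"],
    cwd=".",
)
```
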
diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_tests.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
similarity index 73%
rename from sharktank/tests/models/llama/benchmark_amdgpu_tests.py
rename to sharktank/tests/models/llama/benchmark_amdgpu_test.py
index 174fcbe87..c99bbc7e1 100644
--- a/sharktank/tests/models/llama/benchmark_amdgpu_tests.py
+++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -13,76 +13,13 @@
 import subprocess
 from pathlib import Path
 from typing import List
+from sharktank.utils.export_artifacts import ExportArtifacts

 longrun = pytest.mark.skipif("not config.getoption('longrun')")
 is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'")


-class ExportMlirException(Exception):
-    """SHARK-Platform export MLIR exception that preserves the command line and error output."""
-
-    def __init__(self, process: subprocess.CompletedProcess, cwd: str):
-        try:
-            errs = process.stderr.decode("utf-8")
-        except:
-            errs = str(process.stderr)
-
-        super().__init__(
-            f"Error invoking export_paged_llama_v1.py\n"
-            f"Error code: {process.returncode}\n"
-            f"Stderr diagnostics:\n{errs}\n\n"
-            f"Invoked with:\n"
-            f"  cd {cwd} && {process.args}\n\n"
-        )
-
-
-class IreeCompileException(Exception):
-    """Compiler exception that preserves the command line and error output."""
-
-    def __init__(self, process: subprocess.CompletedProcess, cwd: str):
-        try:
-            errs = process.stderr.decode("utf-8")
-        except:
-            errs = str(process.stderr)
-
-        super().__init__(
-            f"Error invoking iree-compile\n"
-            f"Error code: {process.returncode}\n"
-            f"Stderr diagnostics:\n{errs}\n\n"
-            f"Invoked with:\n"
-            f"  cd {cwd} && {process.args}\n\n"
-        )
-
-
-class IreeBenchmarkException(Exception):
-    """Runtime exception that preserves the command line and error output."""
-
-    def __init__(
-        self, process: subprocess.CompletedProcess, cwd: str, compile_cmd: str
-    ):
-        # iree-run-module sends output to both stdout and stderr
-        try:
-            errs = process.stderr.decode("utf-8")
-        except:
-            errs = str(process.stderr)
-        try:
-            outs = process.stdout.decode("utf-8")
-        except:
-            outs = str(process.stdout)
-
-        super().__init__(
-            f"Error invoking iree-benchmark-module\n"
-            f"Error code: {process.returncode}\n"
-            f"Stderr diagnostics:\n{errs}\n"
-            f"Stdout diagnostics:\n{outs}\n"
-            f"Compiled with:\n"
-            f"  cd {cwd} && {compile_cmd}\n\n"
-            f"Run with:\n"
-            f"  cd {cwd} && {process.args}\n\n"
-        )
-
-
-@pytest.mark.usefixtures("iree_hip_target_type")
+@pytest.mark.usefixtures("get_iree_flags")
 class BaseBenchmarkTest(unittest.TestCase):
     directory_created = False
     current_date = datetime.now()
@@ -105,167 +42,34 @@ def setUpClass(cls):
     def setUp(self):
         self.hip_device_id = os.getenv("HIP_DEVICE_ID", default="0")

-    def create_file(self, *, suffix, prefix):
-        file_path = Path(prefix).with_suffix(suffix)
-        f = open(file_path, "w")
-        return file_path
-
-    def get_export_cmd(
-        self,
-        *,
-        attention_kernel: str,
-        tensor_parallelism_size: int,
-        irpa_path: str,
-        output_mlir_path: str,
-        output_json_path: str,
-    ):
-        export_args = [
-            "python3",
-            "-m",
-            "sharktank.examples.export_paged_llm_v1",
-            "--irpa-file",
-            irpa_path,
-            "--output-mlir",
-            output_mlir_path,
-            "--output-config",
-            output_json_path,
-        ]
-        if attention_kernel == "decomposed":
-            export_args.append("--attention-kernel")
-            export_args.append(attention_kernel)
-        elif attention_kernel == "torch_sdpa":
-            raise NotImplementedError(
-                "attention_kernel torch_sdpa not yet plumbed through"
-            )
-        if tensor_parallelism_size:
-            export_args.append("--tensor-parallelism-size")
-            export_args.append(str(tensor_parallelism_size))
-
-        cmd = subprocess.list2cmdline(export_args)
-        return cmd
-
-    def get_compile_cmd(
-        self, *, output_mlir_path: str, output_vmfb_path: str, args: [str]
-    ):
-        compile_args = ["iree-compile", output_mlir_path]
-        compile_args += args
-        compile_args += ["-o", output_vmfb_path]
-        cmd = subprocess.list2cmdline(compile_args)
-        return cmd
-
-    def export_mlir(
-        self,
-        *,
-        attention_kernel: str,
-        tensor_parallelism_size: int,
-        irpa_path: str,
-        output_mlir_path: str,
-        output_json_path: str,
-        cwd: str | Path,
-    ):
-        """Runs export_paged_llm_v1.py and exports an MLIR file.
-        Args:
-            irpa_path: Path to the model irpa file.
-            output_mlir_path: Path to the file to save the exported file.
-            output_json_path: Path to the file to save the config json file.
-        """
-        cmd = self.get_export_cmd(
-            attention_kernel=attention_kernel,
-            tensor_parallelism_size=tensor_parallelism_size,
-            irpa_path=irpa_path,
-            output_mlir_path=output_mlir_path,
-            output_json_path=output_json_path,
-        )
-        logging.getLogger().info(f"Launching export command:\n" f"cd {cwd} && {cmd}")
-        proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd)
-        return_code = proc.returncode
-        if return_code != 0:
-            raise ExportMlirException(proc, cwd)
-
-    def iree_compile(
-        self,
-        *,
-        mlir_path: str,
-        output_vmfb_path: str,
-        args: List[str],
-        cwd: str | Path,
-    ):
-        """Compiles an input MLIR file to an output .vmfb file.
-        This assumes that the `iree-compile` command is available (usually via PATH).
-        Args:
-            mlir_path: Path to the input MLIR file.
-            output_vmfb_path: Path for the output .vmfb file. The directory must already exist.
-            args: List of arguments to pass to `iree-compile`.
-            cwd: current working directory
-        Raises Exception if compilation fails for some reason.
-        """
-        cmd = self.get_compile_cmd(
-            output_mlir_path=mlir_path,
-            output_vmfb_path=output_vmfb_path,
-            args=args,
-        )
-        logging.getLogger().info(f"Launching compile command:\n" f"cd {cwd} && {cmd}")
-        proc = subprocess.run(cmd, shell=True, capture_output=True, cwd=cwd)
-        return_code = proc.returncode
-        if return_code != 0:
-            raise IreeCompileException(proc, cwd)
-
-    def iree_benchmark_module(
-        self,
-        *,
-        hip_device_id: str,
-        vmfb_name: str,
-        irpa_path: str,
-        args: List[str],
-        cwd: str | Path,
-    ):
-        """Runs a compiled program with the given args using `iree-benchmark-module`.
-        This assumes that the `iree-benchmark-module` command is available (usually via PATH).
-        Args:
-            vmfb_name: Name of the .vmfb file (relative to `cwd`).
-            args: List of arguments to pass to `iree-benchmark-module`.
-            cwd: Working directory to run the command within. (either string or Path works)
-            compile_cmd: Command used to compile the program, for inclusion in error messages.
-        Raises Exception if running fails for some reason.
-        """
-        benchmark_args = [
-            f"ROCR_VISIBLE_DEVICES={hip_device_id}",
-            "iree-benchmark-module",
-            f"--device=hip://{hip_device_id}",
-            "--hip_use_streams=true",
-            "--hip_allow_inline_execution=true",
-            "--device_allocator=caching",
-            f"--module={vmfb_name}",
-            f"--parameters=model={irpa_path}",
-        ]
-        benchmark_args += args
-        cmd = subprocess.list2cmdline(benchmark_args)
-        logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}")
-        proc = subprocess.run(cmd, shell=True, stdout=sys.stdout, cwd=cwd)
-        return_code = proc.returncode
-        if return_code != 0:
-            raise IreeBenchmarkException(proc, cwd, cmd)
-

 class BenchmarkLlama3_1_8B(BaseBenchmarkTest):
     def setUp(self):
         super().setUp()
         # TODO: add numpy files to Azure and download from it
-        artifacts_dir = Path("/data/llama-3.1/8b")
-        self.irpa_path = artifacts_dir / "llama8b_f16.irpa"
-        self.irpa_path_fp8 = artifacts_dir / "llama8b_fp8.irpa"
+        self.artifacts_dir = Path("/data/llama-3.1/8b")
+        self.irpa_path = self.artifacts_dir / "llama8b_f16.irpa"
+        self.irpa_path_fp8 = self.artifacts_dir / "llama8b_fp8.irpa"
         self.tensor_parallelism_size = 1
         self.dir_path_8b = self.dir_path / "llama-8b"
         self.temp_dir_8b = Path(self.dir_path_8b)
         self.temp_dir_8b.mkdir(parents=True, exist_ok=True)
+        self.llama8b_f16_artifacts = ExportArtifacts(
+            irpa_path=str(self.irpa_path),
+            batch_size=4,
+            iree_hip_target="gfx942",
+            iree_hal_target_backends="rocm",
+            attention_kernel="decomposed",
+            tensor_parallelism_size=self.tensor_parallelism_size,
+        )
         self.iree_compile_args = [
             "--iree-hal-target-backends=rocm",
-            f"--iree-hip-target={self.iree_hip_target_type}",
+            f"--iree-hip-target={self.iree_hip_target}",
         ]
-        self.prefill_args_f16 = artifacts_dir / "prefill_args"
-        self.decode_args_f16 = artifacts_dir / "decode_args"
-        self.prefill_args_fp8 = artifacts_dir / "prefill_args_fp8"
-        self.decode_args_fp8 = artifacts_dir / "decode_args_fp8"
+        self.prefill_args_f16 = self.artifacts_dir / "prefill_args"
+        self.decode_args_f16 = self.artifacts_dir / "decode_args"
+        self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8"
+        self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8"
         self.iree_run_prefill_args = [
             "--function=prefill_bs4",
             f"--input=@{self.prefill_args_f16}/tokens.npy",
@@ -305,28 +109,36 @@ def setUp(self):
     @is_mi300x
     def testBenchmark8B_f16_Decomposed(self):
         output_file_name = self.dir_path_8b / "f16_decomposed"
-        output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name)
-        output_json = self.create_file(suffix=".json", prefix=output_file_name)
-        output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name)
-        self.export_mlir(
-            attention_kernel="decomposed",
-            tensor_parallelism_size=self.tensor_parallelism_size,
-            irpa_path=self.irpa_path,
-            output_mlir_path=output_mlir,
-            output_json_path=output_json,
-            cwd=self.repo_root,
+        output_mlir = self.llama8b_f16_artifacts.create_file(
+            suffix=".mlir", prefix=output_file_name
         )
-        iree_compile_args = self.iree_compile_args + [
-            f"--iree-hal-dump-executable-files-to={output_file_name}/files"
-        ]
-        self.iree_compile(
+        output_json = self.llama8b_f16_artifacts.create_file(
+            suffix=".json", prefix=output_file_name
+        )
+        output_vmfb = self.llama8b_f16_artifacts.create_file(
+            suffix=".vmfb", prefix=output_file_name
+        )
+        output_shard_file_name = str(
+            self.artifacts_dir
+            / f"llama3.1_8b_fp16_tp{self.tensor_parallelism_size}_parameters.irpa"
+        )
+        # shard_irpa file
+        shard_return_code = self.llama8b_f16_artifacts.shard_irpa_file(
+            output_file=output_shard_file_name
+        )
+        if shard_return_code == 0:
+            self.irpa_path = output_shard_file_name
+        export_return_code = self.llama8b_f16_artifacts.export_to_mlir(
             mlir_path=output_mlir,
-            output_vmfb_path=output_vmfb,
-            args=iree_compile_args,
-            cwd=self.repo_root,
+            json_path=output_json,
+        )
+        self.llama8b_f16_artifacts.compile_to_vmfb(
+            mlir_path=str(output_mlir),
+            vmfb_path=output_vmfb,
+            hal_dump_path=output_file_name,
         )
         # benchmark prefill
-        self.iree_benchmark_module(
+        self.llama8b_f16_artifacts.iree_benchmark_vmfb(
             hip_device_id=self.hip_device_id,
             vmfb_name=output_vmfb,
             irpa_path=self.irpa_path,
@@ -334,7 +146,7 @@ def testBenchmark8B_f16_Decomposed(self):
             cwd=self.repo_root,
         )
         # benchmark decode
-        self.iree_benchmark_module(
+        self.llama8b_f16_artifacts.iree_benchmark_vmfb(
             hip_device_id=self.hip_device_id,
             vmfb_name=output_vmfb,
             irpa_path=self.irpa_path,
@@ -346,29 +158,38 @@ def testBenchmark8B_f16_Decomposed(self):

     @is_mi300x
     @pytest.mark.xfail(reason="torch_sdpa not yet plumbed through", strict=True)
     def testBenchmark8B_f16_Non_Decomposed(self):
-        output_file_name = self.dir_path_8b / "f16_torch_sdpa"
-        output_mlir = self.create_file(suffix=".mlir", prefix=output_file_name)
-        output_json = self.create_file(suffix=".json", prefix=output_file_name)
-        output_vmfb = self.create_file(suffix=".vmfb", prefix=output_file_name)
-        self.export_mlir(
-            attention_kernel="torch_sdpa",
-            tensor_parallelism_size=self.tensor_parallelism_size,
-            irpa_path=self.irpa_path,
-            output_mlir_path=output_mlir,
-            output_json_path=output_json,
-            cwd=self.repo_root,
+        output_file_name = self.dir_path_8b / "f16_torch"
+        output_mlir = self.llama8b_f16_artifacts.create_file(
+            suffix=".mlir", prefix=output_file_name
         )
-        iree_compile_args = self.iree_compile_args + [
-            f"--iree-hal-dump-executable-files-to={output_file_name}/files"
-        ]
-        self.iree_compile(
+        output_json = self.llama8b_f16_artifacts.create_file(
+            suffix=".json", prefix=output_file_name
+        )
+        output_vmfb = self.llama8b_f16_artifacts.create_file(
+            suffix=".vmfb", prefix=output_file_name
+        )
+        self.llama8b_f16_artifacts.attention_kernel = "torch"
+        output_shard_file_name = str(
+            self.artifacts_dir
+            / f"llama3.1_8b_fp16_tp{self.tensor_parallelism_size}_parameters_torch_sdpa.irpa"
+        )
+        # shard_irpa file
+        shard_return_code = self.llama8b_f16_artifacts.shard_irpa_file(
+            output_file=output_shard_file_name
+        )
+        if shard_return_code == 0:
+            self.irpa_path = output_shard_file_name
+        export_return_code = self.llama8b_f16_artifacts.export_to_mlir(
             mlir_path=output_mlir,
-            output_vmfb_path=output_vmfb,
-            args=iree_compile_args,
-            cwd=self.repo_root,
+            json_path=output_json,
+        )
+        self.llama8b_f16_artifacts.compile_to_vmfb(
+            mlir_path=str(output_mlir),
+            vmfb_path=output_vmfb,
+            hal_dump_path=output_file_name,
         )
         # benchmark prefill
-        self.iree_benchmark_module(
+        self.llama8b_f16_artifacts.iree_benchmark_vmfb(
            hip_device_id=self.hip_device_id,
            vmfb_name=output_vmfb,
            irpa_path=self.irpa_path,
@@ -376,7 +197,7 @@ def testBenchmark8B_f16_Non_Decomposed(self):
             cwd=self.repo_root,
         )
         # benchmark decode
-        self.iree_benchmark_module(
+        self.llama8b_f16_artifacts.iree_benchmark_vmfb(
             hip_device_id=self.hip_device_id,
             vmfb_name=output_vmfb,
             irpa_path=self.irpa_path,
@@ -482,7 +303,7 @@ def setUp(self):
         self.temp_dir_70b.mkdir(parents=True, exist_ok=True)
         self.iree_compile_args = [
             "--iree-hal-target-backends=rocm",
-            f"--iree-hip-target={self.iree_hip_target_type}",
+            f"--iree-hip-target={self.iree_hip_target}",
         ]
         self.prefill_args_f16 = artifacts_dir / "prefill_args"
         self.decode_args_f16 = artifacts_dir / "decode_args"
@@ -705,7 +526,7 @@ def setUp(self):
         self.temp_dir_405b.mkdir(parents=True, exist_ok=True)
         self.iree_compile_args = [
             "--iree-hal-target-backends=rocm",
-            f"--iree-hip-target={self.iree_hip_target_type}",
+            f"--iree-hip-target={self.iree_hip_target}",
         ]
         self.prefill_args_f16 = artifacts_dir / "prefill_args"
         self.decode_args_f16 = artifacts_dir / "decode_args"
diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py
index d6cf71e48..2deec49c0 100644
--- a/shortfin/python/shortfin_apps/sd/components/service.py
+++ b/shortfin/python/shortfin_apps/sd/components/service.py
@@ -24,6 +24,12 @@
 logger = logging.getLogger(__name__)

+prog_isolations = {
+    "none": sf.ProgramIsolation.NONE,
+    "per_fiber": sf.ProgramIsolation.PER_FIBER,
+    "per_call": sf.ProgramIsolation.PER_CALL,
+}
+

 class GenerateService:
     """Top level service interface for image generation."""
@@ -39,6 +45,9 @@ def __init__(
         sysman: SystemManager,
         tokenizers: list[Tokenizer],
         model_params: ModelParams,
+        fibers_per_device: int,
+        prog_isolation: str = "per_fiber",
+        show_progress: bool = False,
     ):
         self.name = name

@@ -50,17 +59,20 @@ def __init__(
         self.inference_modules: dict[str, sf.ProgramModule] = {}
         self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {}
         self.inference_programs: dict[str, sf.Program] = {}
-        self.procs_per_device = 1
+        self.trace_execution = False
+        self.show_progress = show_progress
+        self.fibers_per_device = fibers_per_device
+        self.prog_isolation = prog_isolations[prog_isolation]
         self.workers = []
         self.fibers = []
-        self.locks = []
+        self.fiber_status = []
         for idx, device in enumerate(self.sysman.ls.devices):
-            for i in range(self.procs_per_device):
+            for i in range(self.fibers_per_device):
                 worker = sysman.ls.create_worker(f"{name}-inference-{device.name}-{i}")
                 fiber = sysman.ls.create_fiber(worker, devices=[device])
                 self.workers.append(worker)
                 self.fibers.append(fiber)
-                self.locks.append(asyncio.Lock())
+                self.fiber_status.append(0)

         # Scope dependent objects.
         self.batcher = BatcherProcess(self)
@@ -99,7 +111,8 @@ def start(self):
                 self.inference_programs[component] = sf.Program(
                     modules=component_modules,
                     devices=fiber.raw_devices,
-                    trace_execution=False,
+                    isolation=self.prog_isolation,
+                    trace_execution=self.trace_execution,
                 )

         # TODO: export vmfbs with multiple batch size entrypoints
@@ -169,6 +182,7 @@ def __init__(self, service: GenerateService):
         self.strobe_enabled = True
         self.strobes: int = 0
         self.ideal_batch_size: int = max(service.model_params.max_batch_size)
+        self.num_fibers = len(service.fibers)

     def shutdown(self):
         self.batcher_infeed.close()
@@ -199,6 +213,7 @@ async def run(self):
                     logger.error("Illegal message received by batcher: %r", item)

                 self.board_flights()
+                self.strobe_enabled = True
             await strober_task

@@ -210,28 +225,40 @@ def board_flights(self):
             logger.info("Waiting a bit longer to fill flight")
             return
         self.strobes = 0
+        batches = self.sort_batches()
+        for idx, batch in batches.items():
+            for fidx, status in enumerate(self.service.fiber_status):
+                if (
+                    status == 0
+                    or self.service.prog_isolation == sf.ProgramIsolation.PER_CALL
+                ):
+                    self.board(batch["reqs"], index=fidx)
+                    break

-        batches = self.sort_pending()
-        for idx in batches.keys():
-            self.board(batches[idx]["reqs"], index=idx)
-
-    def sort_pending(self):
-        """Returns pending requests as sorted batches suitable for program invocations."""
+    def sort_batches(self):
+        """Files pending requests into sorted batches suitable for program invocations."""
+        reqs = self.pending_requests
+        next_key = 0
         batches = {}
-        for req in self.pending_requests:
+        for req in reqs:
             is_sorted = False
             req_metas = [req.phases[phase]["metadata"] for phase in req.phases.keys()]
-            next_key = 0
+
             for idx_key, data in batches.items():
                 if not isinstance(data, dict):
                     logger.error(
                         "Expected to find a dictionary containing a list of requests and their shared metadatas."
                     )
-                if data["meta"] == req_metas:
-                    batches[idx_key]["reqs"].append(req)
+                if len(batches[idx_key]["reqs"]) >= self.ideal_batch_size:
+                    # Batch is full
+                    next_key = idx_key + 1
+                    continue
+                elif data["meta"] == req_metas:
+                    batches[idx_key]["reqs"].extend([req])
                     is_sorted = True
                     break
-                next_key = idx_key + 1
+                else:
+                    next_key = idx_key + 1
             if not is_sorted:
                 batches[next_key] = {
                     "reqs": [req],
@@ -251,7 +278,8 @@ def board(self, request_bundle, index):
         if exec_process.exec_requests:
             for flighted_request in exec_process.exec_requests:
                 self.pending_requests.remove(flighted_request)
-            print(f"launching exec process for {exec_process.exec_requests}")
+            if self.service.prog_isolation != sf.ProgramIsolation.PER_CALL:
+                self.service.fiber_status[index] = 1
             exec_process.launch()

@@ -284,22 +312,22 @@ async def run(self):
             phases = self.exec_requests[0].phases
             req_count = len(self.exec_requests)
-            async with self.service.locks[self.worker_index]:
-                device0 = self.fiber.device(0)
-                if phases[InferencePhase.PREPARE]["required"]:
-                    await self._prepare(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.ENCODE]["required"]:
-                    await self._encode(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.DENOISE]["required"]:
-                    await self._denoise(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.DECODE]["required"]:
-                    await self._decode(device=device0, requests=self.exec_requests)
-                if phases[InferencePhase.POSTPROCESS]["required"]:
-                    await self._postprocess(device=device0, requests=self.exec_requests)
+            device0 = self.service.fibers[self.worker_index].device(0)
+            if phases[InferencePhase.PREPARE]["required"]:
+                await self._prepare(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.ENCODE]["required"]:
+                await self._encode(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.DENOISE]["required"]:
+                await self._denoise(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.DECODE]["required"]:
+                await self._decode(device=device0, requests=self.exec_requests)
+            if phases[InferencePhase.POSTPROCESS]["required"]:
+                await self._postprocess(device=device0, requests=self.exec_requests)
             for i in range(req_count):
                 req = self.exec_requests[i]
                 req.done.set_success()
+            self.service.fiber_status[self.worker_index] = 0

         except Exception:
             logger.exception("Fatal error in image generation")
@@ -345,7 +373,6 @@ async def _prepare(self, device, requests):
                 sfnp.fill_randn(sample_host, generator=generator)

             request.sample.copy_from(sample_host)
-            await device
         return

     async def _encode(self, device, requests):
@@ -385,15 +412,13 @@ async def _encode(self, device, requests):
             clip_inputs[idx].copy_from(host_arrs[idx])

         # Encode tokenized inputs.
-        logger.info(
+        logger.debug(
             "INVOKE %r: %s",
             fn,
             "".join([f"\n  {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]),
         )
-        await device
         pe, te = await fn(*clip_inputs, fiber=self.fiber)
-        await device
         for i in range(req_bs):
             cfg_mult = 2
             requests[i].prompt_embeds = pe.view(slice(i * cfg_mult, (i + 1) * cfg_mult))
@@ -477,20 +502,23 @@ async def _denoise(self, device, requests):
             ns_host.items = [step_count]
             num_steps.copy_from(ns_host)

-        await device
+        init_inputs = [
+            denoise_inputs["sample"],
+            num_steps,
+        ]
+
         # Initialize scheduler.
-        logger.info(
-            "INVOKE %r: %s",
+        logger.debug(
+            "INVOKE %r",
             fns["init"],
-            "".join([f"\n  0: {latents_shape}"]),
         )
         (latents, time_ids, timesteps, sigmas) = await fns["init"](
-            denoise_inputs["sample"], num_steps, fiber=self.fiber
+            *init_inputs, fiber=self.fiber
         )
-
-        await device
         for i, t in tqdm(
             enumerate(range(step_count)),
+            disable=(not self.service.show_progress),
+            desc=f"Worker #{self.worker_index} DENOISE (bs{req_bs})",
         ):
             step = sfnp.device_array.for_device(device, [1], sfnp.sint64)
             s_host = step.for_transfer()
@@ -498,14 +526,10 @@ async def _denoise(self, device, requests):
                 s_host.items = [i]
             step.copy_from(s_host)
             scale_inputs = [latents, step, timesteps, sigmas]
-            logger.info(
-                "INVOKE %r: %s",
+            logger.debug(
+                "INVOKE %r",
                 fns["scale"],
-                "".join(
-                    [f"\n  {i}: {ary.shape}" for i, ary in enumerate(scale_inputs)]
-                ),
             )
-            await device
             latent_model_input, t, sigma, next_sigma = await fns["scale"](
                 *scale_inputs, fiber=self.fiber
             )
@@ -519,32 +543,25 @@ async def _denoise(self, device, requests):
                 time_ids,
                 denoise_inputs["guidance_scale"],
             ]
-            logger.info(
-                "INVOKE %r: %s",
+            logger.debug(
+                "INVOKE %r",
                 fns["unet"],
-                "".join([f"\n  {i}: {ary.shape}" for i, ary in enumerate(unet_inputs)]),
             )
-            await device
             (noise_pred,) = await fns["unet"](*unet_inputs, fiber=self.fiber)
-            await device

             step_inputs = [noise_pred, latents, sigma, next_sigma]
-            logger.info(
-                "INVOKE %r: %s",
+            logger.debug(
+                "INVOKE %r",
                 fns["step"],
-                "".join([f"\n  {i}: {ary.shape}" for i, ary in enumerate(step_inputs)]),
             )
-            await device
             (latent_model_output,) = await fns["step"](*step_inputs, fiber=self.fiber)
             latents.copy_from(latent_model_output)
-            await device

         for idx, req in enumerate(requests):
             req.denoised_latents = sfnp.device_array.for_device(
                 device, latents_shape, self.service.model_params.vae_dtype
             )
             req.denoised_latents.copy_from(latents.view(idx))
-        await device
         return

     async def _decode(self, device, requests):
@@ -569,6 +586,11 @@ async def _decode(self, device, requests):
         await device

         # Decode the denoised latents.
+        logger.debug(
+            "INVOKE %r: %s",
+            fn,
+            "".join([f"\n  0: {latents.shape}"]),
+        )
         (image,) = await fn(latents, fiber=self.fiber)
         await device
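
Before the server changes, a compact restatement of the scheduling policy the batcher now follows may help: requests are grouped by identical phase metadata up to `ideal_batch_size`, and a group is only boarded onto a fiber whose status is 0 unless isolation is `per_call`. A simplified, self-contained sketch with plain dicts standing in for the service's request objects (this is not the shortfin API):

```python
# Standalone illustration of the batching rule in BatcherProcess.sort_batches above.
# Requests are plain dicts here; the real code works on request objects with phase metadata.
def sort_batches(pending_requests, ideal_batch_size):
    batches = {}
    next_key = 0
    for req in pending_requests:
        placed = False
        for key, data in batches.items():
            if len(data["reqs"]) >= ideal_batch_size:
                next_key = key + 1  # batch is full, keep looking
                continue
            if data["meta"] == req["meta"]:
                data["reqs"].append(req)
                placed = True
                break
            next_key = key + 1
        if not placed:
            batches[next_key] = {"reqs": [req], "meta": req["meta"]}
            next_key += 1
    return batches


reqs = [{"meta": "1024x1024"}, {"meta": "1024x1024"}, {"meta": "512x512"}]
print(sort_batches(reqs, ideal_batch_size=4))
# -> two batches: one holding both 1024x1024 requests, one holding the 512x512 request
```
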
diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py
index 5e7abd1fc..0327b0a9f 100644
--- a/shortfin/python/shortfin_apps/sd/server.py
+++ b/shortfin/python/shortfin_apps/sd/server.py
@@ -31,7 +31,9 @@
 from .components.tokenizer import Tokenizer

-logger = logging.getLogger(__name__)
+from shortfin.support.logging_setup import configure_main_logger
+
+logger = configure_main_logger("server")

 @asynccontextmanager
@@ -87,7 +89,13 @@ def configure(args) -> SystemManager:
     model_params = ModelParams.load_json(args.model_config)

     sm = GenerateService(
-        name="sd", sysman=sysman, tokenizers=tokenizers, model_params=model_params
+        name="sd",
+        sysman=sysman,
+        tokenizers=tokenizers,
+        model_params=model_params,
+        fibers_per_device=args.fibers_per_device,
+        prog_isolation=args.isolation,
+        show_progress=args.show_progress,
     )
     sm.load_inference_module(args.clip_vmfb, component="clip")
     sm.load_inference_module(args.unet_vmfb, component="unet")
@@ -188,10 +196,40 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
         nargs="*",
         help="Parameter archives to load",
     )
+    parser.add_argument(
+        "--fibers_per_device",
+        type=int,
+        default=1,
+        help="Concurrency control -- how many fibers are created per device to run inference.",
+    )
+    parser.add_argument(
+        "--isolation",
+        type=str,
+        default="per_fiber",
+        choices=["per_fiber", "per_call", "none"],
+        help="Concurrency control -- How to isolate programs.",
+    )
+    parser.add_argument(
+        "--log_level", type=str, default="error", choices=["info", "debug", "error"]
+    )
+    parser.add_argument(
+        "--show_progress",
+        action="store_true",
+        help="enable tqdm progress for unet iterations.",
+    )
+    log_levels = {
+        "info": logging.INFO,
+        "debug": logging.DEBUG,
+        "error": logging.ERROR,
+    }
+    args = parser.parse_args(argv)
+
+    log_level = log_levels[args.log_level]
+    logger.setLevel(log_level)
+
     global sysman
     sysman = configure(args)
-
     uvicorn.run(
         app,
         host=args.host,
@@ -202,9 +240,6 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):

 if __name__ == "__main__":
-    from shortfin.support.logging_setup import configure_main_logger
-
-    logger = configure_main_logger("server")
     main(
         sys.argv[1:],
         # Make logging defer to the default shortfin logging config.
diff --git a/shortfin/setup.py b/shortfin/setup.py
index f4df31237..9ba5b6140 100644
--- a/shortfin/setup.py
+++ b/shortfin/setup.py
@@ -111,12 +111,12 @@ def is_cpp_prebuilt():

 # Due to a quirk of setuptools, that package_dir map must only contain
 # paths relative to the directory containing setup.py. Why? No one knows.
-REL_SOURCE_DIR = SOURCE_DIR.relative_to(SETUPPY_DIR, walk_up=True)
-REL_BINARY_DIR = BINARY_DIR.relative_to(SETUPPY_DIR, walk_up=True)
-REL_CMAKE_DEFAULT_BUILD_DIR = CMAKE_DEFAULT_BUILD_DIR.relative_to(
-    SETUPPY_DIR, walk_up=True
+REL_SOURCE_DIR = Path(os.path.relpath(SOURCE_DIR, SETUPPY_DIR))
+REL_BINARY_DIR = Path(os.path.relpath(BINARY_DIR, SETUPPY_DIR))
+REL_CMAKE_DEFAULT_BUILD_DIR = Path(
+    os.path.relpath(CMAKE_DEFAULT_BUILD_DIR, SETUPPY_DIR)
 )
-REL_CMAKE_TRACY_BUILD_DIR = CMAKE_TRACY_BUILD_DIR.relative_to(SETUPPY_DIR, walk_up=True)
+REL_CMAKE_TRACY_BUILD_DIR = Path(os.path.relpath(CMAKE_TRACY_BUILD_DIR, SETUPPY_DIR))


 class CMakeExtension(Extension):
diff --git a/shortfin/tests/apps/sd/e2e_test.py b/shortfin/tests/apps/sd/e2e_test.py
index b8331946b..05b9ef69b 100644
--- a/shortfin/tests/apps/sd/e2e_test.py
+++ b/shortfin/tests/apps/sd/e2e_test.py
@@ -1,9 +1,3 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
 import json
 import requests
 import time
@@ -13,6 +7,7 @@
 import os
 import socket
 import sys
+import copy

 from contextlib import closing
 from datetime import datetime as dt
@@ -27,7 +22,7 @@
     "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"],
     "height": [1024],
     "width": [1024],
-    "steps": [20],
+    "steps": [5],
     "guidance_scale": [7.5],
     "seed": [0],
     "output_type": ["base64"],
@@ -51,12 +46,7 @@ def sd_artifacts(target: str = "gfx942"):
 cache = os.path.abspath("./tmp/sharktank/sd/")


-@pytest.fixture(scope="module")
-def sd_server():
-    # Create necessary directories
-
-    os.makedirs(cache, exist_ok=True)
-
+def start_server(fibers_per_device=1, isolation="per_fiber"):
     # Download model if it doesn't exist
     vmfbs_bucket = "https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/vmfbs/"
     weights_bucket = (
@@ -88,9 +78,67 @@ def sd_server():
     for arg in sd_artifacts().keys():
         artifact_arg = f"--{arg}={cache}/{sd_artifacts()[arg]}"
         srv_args.extend([artifact_arg])
+    srv_args.extend(
+        [
+            f"--fibers_per_device={fibers_per_device}",
+            f"--isolation={isolation}",
+        ]
+    )
     runner = ServerRunner(srv_args)
     # Wait for server to start
-    time.sleep(5)
+    time.sleep(3)
+    return runner
+
+
+@pytest.fixture(scope="module")
+def sd_server_fpd1():
+    # Create necessary directories
+
+    os.makedirs(cache, exist_ok=True)
+
+    runner = start_server(fibers_per_device=1)
+
+    yield runner
+
+    # Teardown: kill the server
+    del runner
+
+
+@pytest.fixture(scope="module")
+def sd_server_fpd1_per_call():
+    # Create necessary directories
+
+    os.makedirs(cache, exist_ok=True)
+
+    runner = start_server(fibers_per_device=1, isolation="per_call")
+
+    yield runner
+
+    # Teardown: kill the server
+    del runner
+
+
+@pytest.fixture(scope="module")
+def sd_server_fpd2():
+    # Create necessary directories
+
+    os.makedirs(cache, exist_ok=True)
+
+    runner = start_server(fibers_per_device=2)
+
+    yield runner
+
+    # Teardown: kill the server
+    del runner
+
+
+@pytest.fixture(scope="module")
+def sd_server_fpd8():
+    # Create necessary directories
+
+    os.makedirs(cache, exist_ok=True)
+
+    runner = start_server(fibers_per_device=8)

     yield runner
@@ -99,19 +147,46 @@ def sd_server():


 @pytest.mark.system("amdgpu")
-def test_sd_server(sd_server):
-    imgs, status_code = send_json_file(sd_server.url)
+def test_sd_server(sd_server_fpd1):
+    imgs, status_code = send_json_file(sd_server_fpd1.url)
     assert len(imgs) == 1
     assert status_code == 200


+@pytest.mark.system("amdgpu")
+def test_sd_server_bs4_dense(sd_server_fpd1):
+    imgs, status_code = send_json_file(sd_server_fpd1.url, num_copies=4)
+    assert len(imgs) == 4
+    assert status_code == 200
+
+
+@pytest.mark.system("amdgpu")
+def test_sd_server_bs8_percall(sd_server_fpd1_per_call):
+    imgs, status_code = send_json_file(sd_server_fpd1_per_call.url, num_copies=8)
+    assert len(imgs) == 8
+    assert status_code == 200
+
+
+@pytest.mark.system("amdgpu")
+def test_sd_server_bs4_dense_fpd2(sd_server_fpd2):
+    imgs, status_code = send_json_file(sd_server_fpd2.url, num_copies=4)
+    assert len(imgs) == 4
+    assert status_code == 200
+
+
+@pytest.mark.system("amdgpu")
+def test_sd_server_bs8_dense_fpd8(sd_server_fpd8):
+    imgs, status_code = send_json_file(sd_server_fpd8.url, num_copies=8)
+    assert len(imgs) == 8
+    assert status_code == 200
+
+
 class ServerRunner:
     def __init__(self, args):
         port = str(find_free_port())
         self.url = "http://0.0.0.0:" + port
         env = os.environ.copy()
         env["PYTHONUNBUFFERED"] = "1"
-        env["HIP_VISIBLE_DEVICES"] = "0"
         self.process = subprocess.Popen(
             [
                 *args,
@@ -158,27 +233,24 @@ def bytes_to_img(bytes, idx=0, width=1024, height=1024):
     return image


-def send_json_file(url="http://0.0.0.0:8000"):
+def send_json_file(url="http://0.0.0.0:8000", num_copies=1):
     # Read the JSON file
-    data = sample_request
+    data = copy.deepcopy(sample_request)
     imgs = []
     # Send the data to the /generate endpoint
+    data["prompt"] = (
+        [data["prompt"]]
+        if isinstance(data["prompt"], str)
+        else data["prompt"] * num_copies
+    )
     try:
         response = requests.post(url + "/generate", json=data)
         response.raise_for_status()  # Raise an error for bad responses
         request = json.loads(response.request.body.decode("utf-8"))
         for idx, item in enumerate(response.json()["images"]):
-            width = (
-                request["width"][idx]
-                if isinstance(request["height"], list)
-                else request["height"]
-            )
-            height = (
-                request["height"][idx]
-                if isinstance(request["height"], list)
-                else request["height"]
-            )
+            width = getbatched(request, idx, "width")
+            height = getbatched(request, idx, "height")
             img = bytes_to_img(item.encode("utf-8"), idx, width, height)
             imgs.append(img)

@@ -188,6 +260,16 @@ def send_json_file(url="http://0.0.0.0:8000"):
     return imgs, response.status_code


+def getbatched(req, idx, key):
+    if isinstance(req[key], list):
+        if len(req[key]) == 1:
+            return req[key][0]
+        elif len(req[key]) > idx:
+            return req[key][idx]
+    else:
+        return req[key]
+
+
 def find_free_port():
     """This tries to find a free port to run a server on for the test.