From 7f5d48690ca4e72398f773487e70a5e7bfdf8cfd Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 26 Nov 2024 03:22:36 -0800
Subject: [PATCH 1/2] Add torchao to optimum as a pytorch backend configuration (#297)

---
 docker/cpu/Dockerfile                         |  6 +++---
 docker/cuda-ort/Dockerfile                    |  6 +++---
 docker/cuda/Dockerfile                        |  6 +++---
 examples/pytorch_llama.py                     |  5 +++++
 optimum_benchmark/backends/pytorch/backend.py | 13 +++++++++++++
 optimum_benchmark/backends/pytorch/config.py  |  2 +-
 6 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/docker/cpu/Dockerfile b/docker/cpu/Dockerfile
index 9f61e0d4..92ed6163 100644
--- a/docker/cpu/Dockerfile
+++ b/docker/cpu/Dockerfile
@@ -33,11 +33,11 @@ ARG TORCH_VERSION=""
 ARG TORCH_RELEASE_TYPE=stable
 
 RUN if [ -n "${TORCH_VERSION}" ]; then \
-    pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ; \
+    pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/cpu ; \
     elif [ "${TORCH_RELEASE_TYPE}" = "stable" ]; then \
-    pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu ; \
+    pip install --no-cache-dir torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/cpu ; \
     elif [ "${TORCH_RELEASE_TYPE}" = "nightly" ]; then \
-    pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu ; \
+    pip install --no-cache-dir --pre torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/nightly/cpu ; \
     else \
     echo "Error: Invalid TORCH_RELEASE_TYPE. Must be 'stable', 'nightly', or specify a TORCH_VERSION." && exit 1 ; \
     fi
diff --git a/docker/cuda-ort/Dockerfile b/docker/cuda-ort/Dockerfile
index 88ba302e..6ed39063 100644
--- a/docker/cuda-ort/Dockerfile
+++ b/docker/cuda-ort/Dockerfile
@@ -32,11 +32,11 @@ ARG TORCH_CUDA=cu118
 ARG TORCH_VERSION=stable
 
 RUN if [ "${TORCH_VERSION}" = "stable" ]; then \
-    pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
+    pip install --no-cache-dir torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
     elif [ "${TORCH_VERSION}" = "nightly" ]; then \
-    pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
+    pip install --no-cache-dir --pre torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
     else \
-    pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
+    pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
     fi
 
 # Install torch-ort and onnxruntime-training
diff --git a/docker/cuda/Dockerfile b/docker/cuda/Dockerfile
index 64372600..462a6211 100644
--- a/docker/cuda/Dockerfile
+++ b/docker/cuda/Dockerfile
@@ -32,11 +32,11 @@ ARG TORCH_CUDA=cu124
 ARG TORCH_RELEASE_TYPE=stable
 
 RUN if [ -n "${TORCH_VERSION}" ]; then \
-    pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
+    pip install --no-cache-dir torch==${TORCH_VERSION} torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
     elif [ "${TORCH_RELEASE_TYPE}" = "stable" ]; then \
-    pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
+    pip install --no-cache-dir torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/${TORCH_CUDA} ; \
     elif [ "${TORCH_RELEASE_TYPE}" = "nightly" ]; then \
-    pip install --no-cache-dir --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
+    pip install --no-cache-dir --pre torch torchvision torchaudio torchao --index-url https://download.pytorch.org/whl/nightly/${TORCH_CUDA} ; \
     else \
     echo "Error: Invalid TORCH_RELEASE_TYPE. Must be 'stable', 'nightly', or specify a TORCH_VERSION." && exit 1 ; \
     fi
diff --git a/examples/pytorch_llama.py b/examples/pytorch_llama.py
index 90c09931..bcaaedcd 100644
--- a/examples/pytorch_llama.py
+++ b/examples/pytorch_llama.py
@@ -29,6 +29,11 @@
         "quantization_scheme": "gptq",
         "quantization_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256},
     },
+    "torchao-int4wo-128": {
+        "torch_dtype": "bfloat16",
+        "quantization_scheme": "torchao",
+        "quantization_config": {"quant_type": "int4_weight_only", "group_size": 128},
+    }
 }
diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index fcf522b5..c7ddb14f 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -16,6 +16,7 @@
     TrainerState,
     TrainingArguments,
 )
+from transformers import TorchAoConfig
 
 from ...import_utils import is_deepspeed_available, is_torch_distributed_available, is_zentorch_available
 from ..base import Backend
@@ -323,6 +324,11 @@ def process_quantization_config(self) -> None:
             self.quantization_config = BitsAndBytesConfig(
                 **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
             )
+        elif self.is_torchao_quantized:
+            self.logger.info("\t+ Processing TorchAO config")
+            self.quantization_config = TorchAoConfig(
+                **dict(getattr(self.pretrained_config, "quantization_config", {}), **self.config.quantization_config)
+            )
         else:
             raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")
 
@@ -366,6 +372,13 @@ def is_awq_quantized(self) -> bool:
             and self.pretrained_config.quantization_config.get("quant_method", None) == "awq"
         )
 
+    @property
+    def is_torchao_quantized(self) -> bool:
+        return self.config.quantization_scheme == "torchao" or (
+            hasattr(self.pretrained_config, "quantization_config")
+            and self.pretrained_config.quantization_config.get("quant_method", None) == "torchao"
+        )
+
     @property
     def is_exllamav2(self) -> bool:
         return (
diff --git a/optimum_benchmark/backends/pytorch/config.py b/optimum_benchmark/backends/pytorch/config.py
index 225718e5..ec48f639 100644
--- a/optimum_benchmark/backends/pytorch/config.py
+++ b/optimum_benchmark/backends/pytorch/config.py
@@ -9,7 +9,7 @@
 AMP_DTYPES = ["bfloat16", "float16"]
 TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]
 
-QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}, "gptq": {}, "awq": {}}
+QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}, "gptq": {}, "awq": {}, "torchao": {}}
 
 
 @dataclass

From 9104793fa9ba932870d9ceb029cd628a1388b11e Mon Sep 17 00:00:00 2001
From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Date: Tue, 26 Nov 2024 16:12:50 +0100
Subject: [PATCH 2/2] fix llamacpp and windows libuv (#298)

---
 optimum_benchmark/backends/base.py               | 7 ++-----
 optimum_benchmark/backends/llama_cpp/backend.py  | 9 ++-------
 optimum_benchmark/backends/pytorch/backend.py    | 2 +-
 optimum_benchmark/launchers/torchrun/launcher.py | 5 -----
 tests/configs/_gguf_.yaml                        | 4 ++--
 tests/test_cli.py                                | 6 ++++++
 6 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/optimum_benchmark/backends/base.py b/optimum_benchmark/backends/base.py
index 6726f91f..1c039163 100644
--- a/optimum_benchmark/backends/base.py
+++ b/optimum_benchmark/backends/base.py
@@ -70,14 +70,11 @@ def __init__(self, config: BackendConfigT):
 
         elif self.config.library == "llama_cpp":
             self.logger.info("\t+ Benchmarking a LlamaCpp model")
-            # TOD: need a custom method to extract shapes from gguf
-            self.model_shapes = extract_transformers_shapes_from_artifacts(
-                self.pretrained_config, self.pretrained_processor
-            )
             self.pretrained_processor = None
-            self.generation_config = None
             self.pretrained_config = None
+            self.generation_config = None
             self.automodel_loader = None
+            self.model_shapes = {}
 
         else:
             self.logger.info("\t+ Benchmarking a Transformers model")
diff --git a/optimum_benchmark/backends/llama_cpp/backend.py b/optimum_benchmark/backends/llama_cpp/backend.py
index 06215cbf..c9d6bbf8 100644
--- a/optimum_benchmark/backends/llama_cpp/backend.py
+++ b/optimum_benchmark/backends/llama_cpp/backend.py
@@ -41,15 +41,10 @@ def llama_cpp_kwargs(self) -> Dict[str, Any]:
             "echo": False,
         }
 
-    def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
-        if self.config.task == "text-generation":
-            if input_shapes["batch_size"] != 1:
-                raise ValueError("Batch size must be 1 for LlamaCpp text generation")
-
-        return input_shapes
-
     def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if self.config.task == "text-generation":
+            if inputs["input_ids"].shape[0] != 1:
+                raise ValueError("Batch size must be 1 for LlamaCpp text generation")
             return {"tokens": inputs["input_ids"].squeeze(0).tolist()}
 
         elif self.config.task == "feature-extraction":
diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index c7ddb14f..ba76f10a 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -25,7 +25,7 @@
 from .config import PyTorchConfig
 
 if is_deepspeed_available():
-    import deepspeed
+    import deepspeed  # type: ignore
 
 if is_torch_distributed_available():
     import torch.distributed
diff --git a/optimum_benchmark/launchers/torchrun/launcher.py b/optimum_benchmark/launchers/torchrun/launcher.py
index 98c076ee..10b45d4d 100644
--- a/optimum_benchmark/launchers/torchrun/launcher.py
+++ b/optimum_benchmark/launchers/torchrun/launcher.py
@@ -1,5 +1,4 @@
 import os
-import sys
 import traceback
 from contextlib import ExitStack
 from logging import Logger
@@ -155,10 +154,6 @@ def entrypoint(worker: Callable[..., BenchmarkReport], worker_args: List[Any], l
     else:
         setup_logging(level="ERROR", to_file=log_to_file, prefix=f"RANK-PROCESS-{rank}")
 
-    if sys.platform == "win32":
-        logger.info("\t+ Disabline libuv on Windows")
-        os.environ["USE_LIBUV"] = "0"
-
     if torch.cuda.is_available():
         logger.info(f"\t+ Setting torch.distributed cuda device to {rank}")
         device = torch.device("cuda", rank)
diff --git a/tests/configs/_gguf_.yaml b/tests/configs/_gguf_.yaml
index 007a03e7..41ef8027 100644
--- a/tests/configs/_gguf_.yaml
+++ b/tests/configs/_gguf_.yaml
@@ -2,6 +2,6 @@ hydra:
   mode: MULTIRUN
   sweeper:
     params:
+      backend.model: ggml-org/models
       backend.task: text-generation,feature-extraction
-      backend.model: QuantFactory/gpt2-GGUF
-      backend.filename: gpt2.Q4_0.gguf
+      backend.filename: tinyllamas/stories15M-q8_0.gguf
diff --git a/tests/test_cli.py b/tests/test_cli.py
index c18b26fb..3a510806 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -53,6 +53,9 @@ def test_cli_configs(config_name):
 
 @pytest.mark.parametrize("launcher", ["inline", "process", "torchrun"])
 def test_cli_exit_code_0(launcher):
+    if launcher == "torchrun" and sys.platform == "win32":
+        pytest.skip("torchrun is not supported on Windows")
+
     args_0 = [
         "optimum-benchmark",
         "--config-dir",
@@ -73,6 +76,9 @@
 
 @pytest.mark.parametrize("launcher", ["inline", "process", "torchrun"])
 def test_cli_exit_code_1(launcher):
+    if launcher == "torchrun" and sys.platform == "win32":
+        pytest.skip("torchrun is not supported on Windows")
+
     args_1 = [
         "optimum-benchmark",
         "--config-dir",