From 4a05fc1f3a5421f19d020d997dbccbd6448c5dee Mon Sep 17 00:00:00 2001
From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Date: Mon, 9 Dec 2024 12:33:35 +0100
Subject: [PATCH] Add per_step diffusion measurements (#303)

---
 optimum_benchmark/backends/base.py            |   2 +-
 optimum_benchmark/backends/diffusers_utils.py |  36 +-
 optimum_benchmark/backends/ipex/backend.py    |   1 -
 .../scenarios/energy_star/scenario.py         |  25 +-
 .../scenarios/inference/scenario.py           | 140 +++---
 .../scenarios/training/scenario.py            |   6 +-
 optimum_benchmark/task_utils.py               |  13 +-
 optimum_benchmark/trackers/__init__.py        |  20 +-
 optimum_benchmark/trackers/latency.py         | 463 ++++++++++++------
 tests/configs/_gguf_.yaml                     |   4 +-
 tests/configs/_inference_.yaml                |   6 +-
 tests/test_api.py                             |  49 +-
 12 files changed, 453 insertions(+), 312 deletions(-)

diff --git a/optimum_benchmark/backends/base.py b/optimum_benchmark/backends/base.py
index 8488b457..78d0bef7 100644
--- a/optimum_benchmark/backends/base.py
+++ b/optimum_benchmark/backends/base.py
@@ -55,8 +55,8 @@ def __init__(self, config: BackendConfigT):
         if self.config.library == "diffusers":
             self.logger.info("\t+ Benchmarking a Diffusers pipeline")
             self.pretrained_config = get_diffusers_pretrained_config(self.config.model, **self.config.model_kwargs)
-            self.model_shapes = extract_diffusers_shapes_from_model(self.config.model, **self.config.model_kwargs)
             self.automodel_loader = get_diffusers_auto_pipeline_class_for_task(self.config.task)
+            self.model_shapes = extract_diffusers_shapes_from_model()
             self.pretrained_processor = None
             self.generation_config = None

diff --git a/optimum_benchmark/backends/diffusers_utils.py b/optimum_benchmark/backends/diffusers_utils.py
index 345d30f0..a2d75bc8 100644
--- a/optimum_benchmark/backends/diffusers_utils.py
+++ b/optimum_benchmark/backends/diffusers_utils.py
@@ -1,8 +1,5 @@
-import warnings
 from typing import Dict

-from hydra.utils import get_class
-
 from ..import_utils import is_diffusers_available
 from ..task_utils import TASKS_TO_AUTO_PIPELINE_CLASS_NAMES, map_from_synonym_task

@@ -34,41 +31,10 @@ def get_diffusers_pretrained_config(model: str, **kwargs) -> Dict[str, int]:
     return pipeline_config


-def extract_diffusers_shapes_from_model(model: str, **kwargs) -> Dict[str, int]:
+def extract_diffusers_shapes_from_model(**kwargs) -> Dict[str, int]:
     if not is_diffusers_available():
         raise ImportError("diffusers is not available. 
Please, pip install diffusers.") - model_config = get_diffusers_pretrained_config(model, **kwargs) - shapes = {} - if "vae" in model_config: - vae_import_path = model_config["vae"] - vae_class = get_class(f"{vae_import_path[0]}.{vae_import_path[1]}") - vae_config = vae_class.load_config(model, subfolder="vae", **kwargs) - shapes["num_channels"] = vae_config["out_channels"] - shapes["height"] = vae_config["sample_size"] - shapes["width"] = vae_config["sample_size"] - - elif "vae_decoder" in model_config: - vae_import_path = model_config["vae_decoder"] - vae_class = get_class(f"{vae_import_path[0]}.{vae_import_path[1]}") - vae_config = vae_class.load_config(model, subfolder="vae_decoder", **kwargs) - shapes["num_channels"] = vae_config["out_channels"] - shapes["height"] = vae_config["sample_size"] - shapes["width"] = vae_config["sample_size"] - - elif "vae_encoder" in model_config: - vae_import_path = model_config["vae_encoder"] - vae_class = get_class(f"{vae_import_path[0]}.{vae_import_path[1]}") - vae_config = vae_class.load_config(model, subfolder="vae_encoder", **kwargs) - shapes["num_channels"] = vae_config["out_channels"] - shapes["height"] = vae_config["sample_size"] - shapes["width"] = vae_config["sample_size"] - - else: - warnings.warn("Could not extract shapes [num_channels, height, width] from diffusion pipeline.") - shapes["num_channels"] = -1 - shapes["height"] = -1 - shapes["width"] = -1 return shapes diff --git a/optimum_benchmark/backends/ipex/backend.py b/optimum_benchmark/backends/ipex/backend.py index b584ff6c..7e4983a9 100644 --- a/optimum_benchmark/backends/ipex/backend.py +++ b/optimum_benchmark/backends/ipex/backend.py @@ -63,7 +63,6 @@ def _load_ipexmodel_from_pretrained(self) -> None: self.pretrained_model = self.ipexmodel_class.from_pretrained( self.config.model, export=self.config.export, - device=self.config.device, **self.config.model_kwargs, **self.automodel_kwargs, ) diff --git a/optimum_benchmark/scenarios/energy_star/scenario.py b/optimum_benchmark/scenarios/energy_star/scenario.py index 4358c405..e9cbf59d 100644 --- a/optimum_benchmark/scenarios/energy_star/scenario.py +++ b/optimum_benchmark/scenarios/energy_star/scenario.py @@ -137,7 +137,6 @@ def track(self, task_name: str): if self.config.memory: context_stack.enter_context(self.memory_tracker.track()) if self.config.latency: - self.latency_tracker.reset() context_stack.enter_context(self.latency_tracker.track()) yield @@ -173,17 +172,17 @@ def run_dataset_preprocessing_tracking(self): if self.config.energy: preprocess_energy = self.energy_tracker.get_energy() - preprocess_volume = self.dataset_preprocess_volume + self.report.preprocess_dataset.energy = preprocess_energy self.report.preprocess_dataset.efficiency = Efficiency.from_energy( - preprocess_energy, preprocess_volume, unit=PREPROCESS_EFFICIENCY_UNIT + preprocess_energy, self.dataset_preprocess_volume, unit=PREPROCESS_EFFICIENCY_UNIT ) if self.config.latency: preprocess_latency = self.latency_tracker.get_latency() - preprocess_volume = self.dataset_preprocess_volume + self.report.preprocess_dataset.latency = preprocess_latency self.report.preprocess_dataset.throughput = Throughput.from_latency( - preprocess_latency, preprocess_volume, unit=PREPROCESS_THROUGHPUT_UNIT + preprocess_latency, self.dataset_preprocess_volume, unit=PREPROCESS_THROUGHPUT_UNIT ) if self.config.memory: self.report.preprocess_dataset.memory = self.memory_tracker.get_max_memory() @@ -237,17 +236,17 @@ def run_text_generation_tracking(self): if self.config.energy: 
prefill_energy = self.energy_tracker.get_energy() - decode_energy = self.dataset_prefill_volume + self.report.prefill.energy = prefill_energy self.report.prefill.efficiency = Efficiency.from_energy( - prefill_energy, decode_energy, unit=PREFILL_EFFICIENCY_UNIT + prefill_energy, self.dataset_prefill_volume, unit=PREFILL_EFFICIENCY_UNIT ) if self.config.latency: prefill_latency = self.latency_tracker.get_latency() - prefill_volume = self.dataset_prefill_volume + self.report.prefill.latency = prefill_latency self.report.prefill.throughput = Throughput.from_latency( - prefill_latency, prefill_volume, unit=PREFILL_THROUGHPUT_UNIT + prefill_latency, self.dataset_prefill_volume, unit=PREFILL_THROUGHPUT_UNIT ) if self.config.memory: self.report.prefill.memory = self.memory_tracker.get_max_memory() @@ -260,18 +259,18 @@ def run_text_generation_tracking(self): if self.config.energy: generate_energy = self.energy_tracker.get_energy() decode_energy = generate_energy - prefill_energy - decode_volume = self.dataset_decode_volume + self.report.decode.energy = decode_energy self.report.decode.efficiency = Efficiency.from_energy( - decode_energy, decode_volume, unit=DECODE_EFFICIENCY_UNIT + decode_energy, self.dataset_decode_volume, unit=DECODE_EFFICIENCY_UNIT ) if self.config.latency: generate_latency = self.latency_tracker.get_latency() decode_latency = generate_latency - prefill_latency - decode_volume = self.dataset_decode_volume + self.report.decode.latency = decode_latency self.report.decode.throughput = Throughput.from_latency( - decode_latency, decode_volume, unit=DECODE_THROUGHPUT_UNIT + decode_latency, self.dataset_decode_volume, unit=DECODE_THROUGHPUT_UNIT ) if self.config.memory: self.report.decode.memory = self.memory_tracker.get_max_memory() diff --git a/optimum_benchmark/scenarios/inference/scenario.py b/optimum_benchmark/scenarios/inference/scenario.py index 4a804b5f..45461714 100644 --- a/optimum_benchmark/scenarios/inference/scenario.py +++ b/optimum_benchmark/scenarios/inference/scenario.py @@ -8,7 +8,12 @@ from ...generators.input_generator import InputGenerator from ...task_utils import IMAGE_DIFFUSION_TASKS, TEXT_GENERATION_TASKS from ...trackers.energy import Efficiency, EnergyTracker -from ...trackers.latency import LatencyTracker, PerTokenLatencyLogitsProcessor, Throughput +from ...trackers.latency import ( + LatencySessionTracker, + PerStepLatencySessionTrackerPipelineCallback, + PerTokenLatencySessionTrackerLogitsProcessor, + Throughput, +) from ...trackers.memory import MemoryTracker from ..base import Scenario from .config import InferenceConfig @@ -74,19 +79,29 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport: self.logger.info("\t+ Updating Image Diffusion kwargs with default values") self.config.call_kwargs = {**IMAGE_DIFFUSION_DEFAULT_KWARGS, **self.config.call_kwargs} self.logger.info("\t+ Initializing Image Diffusion report") - self.report = BenchmarkReport.from_list(targets=["load_model", "call"]) + self.report = BenchmarkReport.from_list(targets=["load_model", "call", "per_step"]) else: self.logger.info("\t+ Initializing Inference report") self.report = BenchmarkReport.from_list(targets=["load_model", "forward"]) if self.config.latency: self.logger.info("\t+ Initializing Latency tracker") - self.latency_tracker = LatencyTracker(backend=self.backend.config.name, device=self.backend.config.device) + self.latency_tracker = LatencySessionTracker( + device=self.backend.config.device, backend=self.backend.config.name + ) if self.backend.config.task in 
TEXT_GENERATION_TASKS and self.backend.config.name in PER_TOKEN_BACKENDS: self.logger.info("\t+ Initializing Per-Token Latency tracker") - self.per_token_latency_tracker = PerTokenLatencyLogitsProcessor( - backend=self.backend.config.name, device=self.backend.config.device + self.per_token_latency_tracker = PerTokenLatencySessionTrackerLogitsProcessor( + device=self.backend.config.device, backend=self.backend.config.name + ) + self.config.generate_kwargs["logits_processor"] = LogitsProcessorList([self.per_token_latency_tracker]) + elif self.backend.config.task in IMAGE_DIFFUSION_TASKS: + self.logger.info("\t+ Initializing Diffusion Step Latency tracker") + self.per_step_latency_tracker = PerStepLatencySessionTrackerPipelineCallback( + device=self.backend.config.device, backend=self.backend.config.name ) + self.config.call_kwargs["callback_on_step_end"] = self.per_step_latency_tracker + if self.config.memory: self.logger.info("\t+ Initializing Memory tracker") self.memory_tracker = MemoryTracker( @@ -94,6 +109,7 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport: device=self.backend.config.device, device_ids=self.backend.config.device_ids, ) + if self.config.energy: self.logger.info("\t+ Initializing Energy tracker") self.energy_tracker = EnergyTracker( @@ -161,7 +177,7 @@ def run_model_loading_tracking(self): if self.config.memory: context_stack.enter_context(self.memory_tracker.track()) if self.config.latency: - self.latency_tracker.reset() + context_stack.enter_context(self.latency_tracker.session()) context_stack.enter_context(self.latency_tracker.track()) self.backend.load() @@ -229,21 +245,17 @@ def run_inference_memory_tracking(self): def run_per_token_text_generation_latency_tracking(self): self.logger.info("\t+ Running Per-Token Text Generation latency tracking") - self.config.generate_kwargs["logits_processor"] = LogitsProcessorList([self.per_token_latency_tracker]) - - self.per_token_latency_tracker.reset() - while ( - self.per_token_latency_tracker.elapsed() < self.config.duration - or self.per_token_latency_tracker.count() < self.config.iterations - ): - with self.per_token_latency_tracker.track(): - self.backend.generate(self.inputs, self.config.generate_kwargs) + with self.per_token_latency_tracker.session(): + while ( + self.per_token_latency_tracker.elapsed() < self.config.duration + or self.per_token_latency_tracker.count() < self.config.iterations + ): + with self.per_token_latency_tracker.track(): + self.backend.generate(self.inputs, self.config.generate_kwargs) per_token_latency = self.per_token_latency_tracker.get_per_token_latency() prefill_latency = self.per_token_latency_tracker.get_prefill_latency() decode_latency = self.per_token_latency_tracker.get_decode_latency() - prefill_volume = self.atomic_prefill_volume - decode_volume = self.atomic_decode_volume self.report.per_token.latency = per_token_latency self.report.prefill.latency = prefill_latency @@ -253,84 +265,86 @@ def run_per_token_text_generation_latency_tracking(self): # it's a confusing metric and the same signal as the decode throughput self.report.prefill.throughput = Throughput.from_latency( - prefill_latency, prefill_volume, unit=PREFILL_THROUGHPUT_UNIT + prefill_latency, self.atomic_prefill_volume, unit=PREFILL_THROUGHPUT_UNIT ) self.report.decode.throughput = Throughput.from_latency( - decode_latency, decode_volume, unit=DECODE_THROUGHPUT_UNIT + decode_latency, self.atomic_decode_volume, unit=DECODE_THROUGHPUT_UNIT ) ## Text Generation latency tracking def 
run_text_generation_latency_tracking(self): self.logger.info("\t+ Running Text Generation latency tracking") + prefill_kwargs = {**self.config.generate_kwargs, **TEXT_GENERATION_PREFILL_OVERRIDES} - self.latency_tracker.reset() - while ( - self.latency_tracker.elapsed() < self.config.duration - or self.latency_tracker.count() < self.config.iterations - ): - with self.latency_tracker.track(): - self.backend.prefill(self.inputs, prefill_kwargs) + with self.latency_tracker.session(): + while ( + self.latency_tracker.elapsed() < self.config.duration + or self.latency_tracker.count() < self.config.iterations + ): + with self.latency_tracker.track(): + self.backend.prefill(self.inputs, prefill_kwargs) prefill_latency = self.latency_tracker.get_latency() - prefill_volume = self.atomic_prefill_volume self.report.prefill.latency = prefill_latency self.report.prefill.throughput = Throughput.from_latency( - prefill_latency, prefill_volume, unit=PREFILL_THROUGHPUT_UNIT + prefill_latency, self.atomic_prefill_volume, unit=PREFILL_THROUGHPUT_UNIT ) - self.latency_tracker.reset() - while ( - self.latency_tracker.elapsed() < self.config.duration - or self.latency_tracker.count() < self.config.iterations - ): - with self.latency_tracker.track(): - self.backend.generate(self.inputs, self.config.generate_kwargs) + with self.latency_tracker.session(): + while ( + self.latency_tracker.elapsed() < self.config.duration + or self.latency_tracker.count() < self.config.iterations + ): + with self.latency_tracker.track(): + self.backend.generate(self.inputs, self.config.generate_kwargs) generate_latency = self.latency_tracker.get_latency() decode_latency = generate_latency - prefill_latency - decode_volume = self.atomic_decode_volume self.report.decode.latency = decode_latency self.report.decode.throughput = Throughput.from_latency( - decode_latency, decode_volume, unit=DECODE_THROUGHPUT_UNIT + decode_latency, self.atomic_decode_volume, unit=DECODE_THROUGHPUT_UNIT ) def run_image_diffusion_latency_tracking(self): self.logger.info("\t+ Running Image Diffusion latency tracking") - self.latency_tracker.reset() - while ( - self.latency_tracker.elapsed() < self.config.duration - or self.latency_tracker.count() < self.config.iterations - ): - with self.latency_tracker.track(): - self.backend.call(self.inputs, self.config.call_kwargs) + with self.per_step_latency_tracker.session(): + while ( + self.per_step_latency_tracker.elapsed() < self.config.duration + or self.per_step_latency_tracker.count() < self.config.iterations + ): + with self.per_step_latency_tracker.track(): + self.backend.call(self.inputs, self.config.call_kwargs) - call_latency = self.latency_tracker.get_latency() - call_volume = self.atomic_call_volume + call_latency = self.per_step_latency_tracker.get_call_latency() + per_step_latency = self.per_step_latency_tracker.get_step_latency() self.report.call.latency = call_latency - self.report.call.throughput = Throughput.from_latency(call_latency, call_volume, unit=CALL_THROUGHPUT_UNIT) + self.report.per_step.latency = per_step_latency + + self.report.call.throughput = Throughput.from_latency( + call_latency, self.atomic_call_volume, unit=CALL_THROUGHPUT_UNIT + ) def run_inference_latency_tracking(self): self.logger.info("\t+ Running Inference latency tracking") - self.latency_tracker.reset() - while ( - self.latency_tracker.elapsed() < self.config.duration - or self.latency_tracker.count() < self.config.iterations - ): - with self.latency_tracker.track(): - self.backend.forward(self.inputs, 
self.config.forward_kwargs) + with self.latency_tracker.session(): + while ( + self.latency_tracker.elapsed() < self.config.duration + or self.latency_tracker.count() < self.config.iterations + ): + with self.latency_tracker.track(): + self.backend.forward(self.inputs, self.config.forward_kwargs) forward_latency = self.latency_tracker.get_latency() - forward_volume = self.atomic_forward_volume self.report.forward.latency = forward_latency self.report.forward.throughput = Throughput.from_latency( - forward_latency, forward_volume, unit=FORWARD_THROUGHPUT_UNIT + forward_latency, self.atomic_forward_volume, unit=FORWARD_THROUGHPUT_UNIT ) ## Energy tracking @@ -349,11 +363,10 @@ def run_text_generation_energy_tracking(self): count += 1 prefill_energy = self.energy_tracker.get_energy() / count - prefill_volume = self.atomic_prefill_volume self.report.prefill.energy = prefill_energy self.report.prefill.efficiency = Efficiency.from_energy( - prefill_energy, prefill_volume, unit=PREFILL_EFFICIENCY_UNIT + prefill_energy, self.atomic_prefill_volume, unit=PREFILL_EFFICIENCY_UNIT ) count = 0 @@ -368,11 +381,10 @@ def run_text_generation_energy_tracking(self): generate_energy = self.energy_tracker.get_energy() / count decode_energy = generate_energy - prefill_energy - decode_volume = self.atomic_decode_volume self.report.decode.energy = decode_energy self.report.decode.efficiency = Efficiency.from_energy( - decode_energy, decode_volume, unit=DECODE_EFFICIENCY_UNIT + decode_energy, self.atomic_decode_volume, unit=DECODE_EFFICIENCY_UNIT ) def run_image_diffusion_energy_tracking(self): @@ -389,10 +401,11 @@ def run_image_diffusion_energy_tracking(self): count += 1 call_energy = self.energy_tracker.get_energy() / count - call_volume = self.atomic_call_volume self.report.call.energy = call_energy - self.report.call.efficiency = Efficiency.from_energy(call_energy, call_volume, unit=CALL_EFFICIENCY_UNIT) + self.report.call.efficiency = Efficiency.from_energy( + call_energy, self.atomic_call_volume, unit=CALL_EFFICIENCY_UNIT + ) def run_inference_energy_tracking(self): self.logger.info("\t+ Running energy tracking") @@ -408,11 +421,10 @@ def run_inference_energy_tracking(self): count += 1 forward_energy = self.energy_tracker.get_energy() / count - forward_volume = self.atomic_forward_volume self.report.forward.energy = forward_energy self.report.forward.efficiency = Efficiency.from_energy( - forward_energy, forward_volume, unit=FORWARD_EFFICIENCY_UNIT + forward_energy, self.atomic_forward_volume, unit=FORWARD_EFFICIENCY_UNIT ) @property diff --git a/optimum_benchmark/scenarios/training/scenario.py b/optimum_benchmark/scenarios/training/scenario.py index 7ee3ff0d..63b51ad9 100644 --- a/optimum_benchmark/scenarios/training/scenario.py +++ b/optimum_benchmark/scenarios/training/scenario.py @@ -6,7 +6,7 @@ from ...benchmark.report import BenchmarkReport from ...generators.dataset_generator import DatasetGenerator from ...trackers.energy import Efficiency, EnergyTracker -from ...trackers.latency import StepLatencyTrainerCallback, Throughput +from ...trackers.latency import StepLatencyTrackerTrainerCallback, Throughput from ...trackers.memory import MemoryTracker from ..base import Scenario from .config import TrainingConfig @@ -40,7 +40,9 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport: with ExitStack() as context_stack: if self.config.latency: - latency_callback = StepLatencyTrainerCallback(device=backend.config.device, backend=backend.config.name) + latency_callback = 
StepLatencyTrackerTrainerCallback( + device=backend.config.device, backend=backend.config.name + ) training_callbackes.append(latency_callback) if self.config.memory: memory_tracker = MemoryTracker( diff --git a/optimum_benchmark/task_utils.py b/optimum_benchmark/task_utils.py index c5b43739..51be2ee5 100644 --- a/optimum_benchmark/task_utils.py +++ b/optimum_benchmark/task_utils.py @@ -254,6 +254,9 @@ def infer_task_from_model_name_or_path( if inferred_task_name is not None: break + if inferred_task_name is None: + raise KeyError(f"Could not find the proper task name for target class name {target_class_name}.") + elif library_name == "transformers": transformers_config = get_repo_config(model_name_or_path, "config.json", token=token, revision=revision) target_class_name = transformers_config["architectures"][0] @@ -266,8 +269,8 @@ def infer_task_from_model_name_or_path( if inferred_task_name is not None: break - if inferred_task_name is None: - raise KeyError(f"Could not find the proper task name for {auto_model_class_name}.") + if inferred_task_name is None: + raise KeyError(f"Could not find the proper task name for target class name {target_class_name}.") return map_from_synonym_task(inferred_task_name) @@ -302,11 +305,11 @@ def infer_model_type_from_model_name_or_path( if inferred_model_type is not None: break + if inferred_model_type is None: + raise KeyError(f"Could not find the proper model type for target class name {target_class_name}.") + elif library_name == "transformers": transformers_config = get_repo_config(model_name_or_path, "config.json", token=token, revision=revision) inferred_model_type = transformers_config["model_type"] - if inferred_model_type is None: - raise KeyError(f"Could not find the proper model type for {model_name_or_path}.") - return inferred_model_type diff --git a/optimum_benchmark/trackers/__init__.py b/optimum_benchmark/trackers/__init__.py index b5c95acc..22fd146b 100644 --- a/optimum_benchmark/trackers/__init__.py +++ b/optimum_benchmark/trackers/__init__.py @@ -1,16 +1,26 @@ from .energy import Efficiency, Energy, EnergyTracker -from .latency import Latency, LatencyTracker, PerTokenLatencyLogitsProcessor, StepLatencyTrainerCallback, Throughput +from .latency import ( + Latency, + LatencySessionTracker, + LatencyTracker, + PerStepLatencySessionTrackerPipelineCallback, + PerTokenLatencySessionTrackerLogitsProcessor, + StepLatencyTrackerTrainerCallback, + Throughput, +) from .memory import Memory, MemoryTracker __all__ = [ + "Efficiency", "Energy", "EnergyTracker", "Latency", + "LatencySessionTracker", "LatencyTracker", + "PerStepLatencySessionTrackerPipelineCallback", + "PerTokenLatencySessionTrackerLogitsProcessor", + "StepLatencyTrackerTrainerCallback", + "Throughput", "Memory", "MemoryTracker", - "PerTokenLatencyLogitsProcessor", - "StepLatencyTrainerCallback", - "Throughput", - "Efficiency", ] diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py index 2850c503..f3850e1c 100644 --- a/optimum_benchmark/trackers/latency.py +++ b/optimum_benchmark/trackers/latency.py @@ -8,7 +8,7 @@ import torch from rich.console import Console from rich.markdown import Markdown -from transformers import LogitsProcessor, TrainerCallback +from transformers import TrainerCallback CONSOLE = Console() LOGGER = getLogger("latency") @@ -46,9 +46,7 @@ def __getitem__(self, index) -> float: def __sub__(self, latency: "Latency") -> "Latency": latencies = [lat - latency.mean for lat in self.values] - assert not any( - latency < 0 for 
latency in latencies - ), "Negative latency detected. Please increase the dimensions of your benchmark (inputs/warmup/iterations)." + assert all(latency >= 0 for latency in latencies) return Latency.from_values(values=latencies, unit=self.unit) @@ -176,78 +174,39 @@ def __init__(self, device: str, backend: str): else: LOGGER.info("\t\t+ Tracking latency using CPU performance counter") - self.start_time: Optional[float] = None - self.start_events: List[Union[float, torch.cuda.Event]] = [] - self.end_events: List[Union[float, torch.cuda.Event]] = [] - - def reset(self): - self.start_time = None - self.start_events = [] - self.end_events = [] + self.start_event: Optional[Union[float, torch.cuda.Event]] = None + self.end_event: Optional[Union[float, torch.cuda.Event]] = None @contextmanager def track(self): if self.is_pytorch_cuda: - yield from self._pytorch_cuda_latency() - else: - yield from self._cpu_latency() - - def _pytorch_cuda_latency(self): - self.start_events.append(torch.cuda.Event(enable_timing=True)) - self.start_events[-1].record() - - yield - - self.end_events.append(torch.cuda.Event(enable_timing=True)) - self.end_events[-1].record() + self.start_event = torch.cuda.Event(enable_timing=True) + self.end_event = torch.cuda.Event(enable_timing=True) - def _cpu_latency(self): - self.start_events.append(time.perf_counter()) - - yield - - self.end_events.append(time.perf_counter()) + self.start_event.record() + yield + self.end_event.record() + else: + self.start_event = time.perf_counter() + yield + self.end_event = time.perf_counter() def get_latency(self) -> Latency: - assert len(self.start_events) == len( - self.end_events - ), "Mismatched number of start and end events, get_latency() should only be called outside of track() context" + assert self.start_event is not None and self.end_event is not None if self.is_pytorch_cuda: torch.cuda.synchronize() - - latencies_list = [ - self.start_events[i].elapsed_time(self.end_events[i]) / 1e3 for i in range(len(self.start_events)) - ] + latency = self.start_event.elapsed_time(self.end_event) / 1e3 else: - latencies_list = [(self.end_events[i] - self.start_events[i]) for i in range(len(self.start_events))] - - assert not any( - latency < 0 for latency in latencies_list - ), "Negative latency detected. Please increase the dimensions of your benchmark (inputs/warmup/iterations)." 
- - return Latency.from_values(latencies_list, unit=LATENCY_UNIT) - - def count(self): - assert len(self.start_events) == len( - self.end_events - ), "Mismatched number of start and end events, count() should only be called outside of track() context" - - return len(self.start_events) + latency = self.end_event - self.start_event - def elapsed(self): - if self.start_time is None: - assert ( - len(self.start_events) == 0 and len(self.end_events) == 0 - ), "Number of recorded events is not zero, make sure to reset() the tracker properly" - - self.start_time = time.perf_counter() + assert latency >= 0 - return time.perf_counter() - self.start_time + return Latency.from_values([latency], unit=LATENCY_UNIT) -class StepLatencyTrainerCallback(TrainerCallback): - def __init__(self, device: str, backend: str) -> None: +class LatencySessionTracker: + def __init__(self, device: str, backend: str): self.device = device self.backend = backend @@ -261,46 +220,67 @@ def __init__(self, device: str, backend: str) -> None: self.start_events: List[Union[float, torch.cuda.Event]] = [] self.end_events: List[Union[float, torch.cuda.Event]] = [] - def reset(self): + self.start_time: Optional[float] = None + + @contextmanager + def session(self): + assert self.start_time is None + self.start_events = [] self.end_events = [] - def on_step_begin(self, *args, **kwargs): - if self.is_pytorch_cuda: - self.start_events.append(torch.cuda.Event(enable_timing=True)) - self.start_events[-1].record() - else: - self.start_events.append(time.perf_counter()) + self.start_time = time.perf_counter() + yield + self.start_time = None - def on_step_end(self, *args, **kwargs): + def count(self) -> int: + assert self.start_time is not None, "This method can only be called inside of a '.session()' context" + assert len(self.start_events) == len(self.end_events) + + return len(self.start_events) + + def elapsed(self): + assert self.start_time is not None, "This method can only be called inside of a '.session()' context" + + return time.perf_counter() - self.start_time + + @contextmanager + def track(self): if self.is_pytorch_cuda: - self.end_events.append(torch.cuda.Event(enable_timing=True)) - self.end_events[-1].record() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + yield + end_event.record() else: - self.end_events.append(time.perf_counter()) + start_event = time.perf_counter() + yield + end_event = time.perf_counter() + + self.start_events.append(start_event) + self.end_events.append(end_event) def get_latency(self) -> Latency: - assert len(self.start_events) == len( - self.end_events - ), "Mismatched number of start and end events, get_latency() should only be called outside of track() context" + assert len(self.end_events) == len(self.start_events) >= 0 if self.is_pytorch_cuda: torch.cuda.synchronize() - - latencies_list = [ - self.start_events[i].elapsed_time(self.end_events[i]) / 1e3 for i in range(len(self.start_events)) + latencies = [ + start_event.elapsed_time(end_event) / 1e3 + for start_event, end_event in zip(self.start_events, self.end_events) ] else: - latencies_list = [(self.end_events[i] - self.start_events[i]) for i in range(len(self.start_events))] + latencies = [ + (end_event - start_event) for start_event, end_event in zip(self.start_events, self.end_events) + ] - assert not any( - latency < 0 for latency in latencies_list - ), "Negative latency detected. 
Please increase the dimensions of your benchmark (inputs/warmup/iterations)." + assert all(latency >= 0 for latency in latencies) - return Latency.from_values(latencies_list, unit=LATENCY_UNIT) + return Latency.from_values(latencies, unit=LATENCY_UNIT) -class PerTokenLatencyLogitsProcessor(LogitsProcessor): +class PerTokenLatencySessionTrackerLogitsProcessor: def __init__(self, device: str, backend: str): self.device = device self.backend = backend @@ -312,144 +292,311 @@ def __init__(self, device: str, backend: str): else: LOGGER.info("\t\t+ Tracking latency using CPU performance counter") - self.start_time: Optional[float] = None - self.prefilled: Optional[bool] = None - - self.per_token_events: List[List[Union[float, torch.cuda.Event]]] = [] self.prefill_start_events: List[Union[float, torch.cuda.Event]] = [] self.prefill_end_events: List[Union[float, torch.cuda.Event]] = [] + self.per_token_start_events: List[Union[float, torch.cuda.Event]] = [] + self.per_token_end_events: List[Union[float, torch.cuda.Event]] = [] + self.per_token_events: List[Union[float, torch.cuda.Event]] = [] self.decode_start_events: List[Union[float, torch.cuda.Event]] = [] self.decode_end_events: List[Union[float, torch.cuda.Event]] = [] - def reset(self): - self.start_time = None - self.prefilled = None + self.start_time: Optional[float] = None + + @contextmanager + def session(self): + assert self.start_time is None - self.per_token_events = [] self.prefill_start_events = [] self.prefill_end_events = [] + self.per_token_start_events = [] + self.per_token_end_events = [] + self.per_token_events = [] self.decode_start_events = [] self.decode_end_events = [] - @contextmanager - def track(self): - self.prefilled = False - self.per_token_events.append([]) + self.start_time = time.perf_counter() + yield + self.start_time = None - if self.is_pytorch_cuda: - self.prefill_start_events.append(torch.cuda.Event(enable_timing=True)) - self.prefill_start_events[-1].record() - else: - self.prefill_start_events.append(time.perf_counter()) + def count(self) -> int: + assert self.start_time is not None, "This method can only be called inside of a '.session()' context" + assert ( + len(self.prefill_start_events) + == len(self.prefill_end_events) + == len(self.decode_start_events) + == len(self.decode_end_events) + ) - yield + return len(self.prefill_start_events) + def elapsed(self): + assert self.start_time is not None, "This method can only be called inside of a '.session()' context" + + return time.perf_counter() - self.start_time + + @contextmanager + def track(self): if self.is_pytorch_cuda: - self.decode_end_events.append(torch.cuda.Event(enable_timing=True)) - self.decode_end_events[-1].record() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + yield + end_event.record() else: - self.decode_end_events.append(time.perf_counter()) + start_event = time.perf_counter() + yield + end_event = time.perf_counter() - self.prefilled = False + self.prefill_start_events.append(start_event) + self.decode_end_events.append(end_event) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): - assert ( - self.prefilled is not None - ), "PerTokenLatencyLogitsProcessor should only be called inside of track() context" + self.per_token_start_events.extend(self.per_token_events[:-1]) + self.per_token_end_events.extend(self.per_token_events[1:]) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): if 
self.is_pytorch_cuda: event = torch.cuda.Event(enable_timing=True) event.record() else: event = time.perf_counter() - if not self.prefilled: + if len(self.prefill_start_events) == len(self.prefill_end_events): + # on the first call (prefill), there will be the same number of prefill/decode start/end events self.prefill_end_events.append(event) self.decode_start_events.append(event) - self.prefilled = True - self.per_token_events[-1].append(event) + self.per_token_events.append(event) return scores def get_prefill_latency(self) -> Latency: + assert len(self.prefill_start_events) == len(self.prefill_end_events) > 0 + if self.is_pytorch_cuda: torch.cuda.synchronize() - latencies_list = [ - self.prefill_start_events[i].elapsed_time(self.prefill_end_events[i]) / 1e3 - for i in range(len(self.prefill_start_events)) + latencies = [ + start_event.elapsed_time(end_event) / 1e3 + for start_event, end_event in zip(self.prefill_start_events, self.prefill_end_events) ] else: - latencies_list = [ - (self.prefill_end_events[i] - self.prefill_start_events[i]) - for i in range(len(self.prefill_start_events)) + latencies = [ + (end_event - start_event) + for start_event, end_event in zip(self.prefill_start_events, self.prefill_end_events) ] - assert not any( - latency < 0 for latency in latencies_list - ), "Negative latency detected. Please increase the dimensions of your benchmark (inputs/warmup/iterations)." + assert all(latency >= 0 for latency in latencies) - return Latency.from_values(latencies_list, unit=LATENCY_UNIT) + return Latency.from_values(latencies, unit=LATENCY_UNIT) def get_decode_latency(self) -> Latency: + assert len(self.decode_start_events) == len(self.decode_end_events) > 0 + if self.is_pytorch_cuda: torch.cuda.synchronize() - latencies_list = [ - self.decode_start_events[i].elapsed_time(self.decode_end_events[i]) / 1e3 - for i in range(len(self.decode_start_events)) + latencies = [ + start_event.elapsed_time(end_event) / 1e3 + for start_event, end_event in zip(self.decode_start_events, self.decode_end_events) ] else: - latencies_list = [ - (self.decode_end_events[i] - self.decode_start_events[i]) for i in range(len(self.decode_start_events)) + latencies = [ + (end_event - start_event) + for start_event, end_event in zip(self.decode_start_events, self.decode_end_events) ] - assert not any( - latency < 0 for latency in latencies_list - ), "Negative latency detected. Please increase the dimensions of your benchmark (inputs/warmup/iterations)." 
+ assert all(latency >= 0 for latency in latencies) - return Latency.from_values(latencies_list, unit=LATENCY_UNIT) + return Latency.from_values(latencies, unit=LATENCY_UNIT) def get_per_token_latency(self) -> Latency: - assert ( - len(self.per_token_events) > 0 - ), "No per-token events recorded, make sure to pass the PerTokenLatencyLogitsProcessor to the generate() method" + assert len(self.per_token_start_events) == len(self.per_token_end_events) > 0 if self.is_pytorch_cuda: torch.cuda.synchronize() - latencies_list = [ - self.per_token_events[i][j].elapsed_time(self.per_token_events[i][j + 1]) / 1e3 - for i in range(len(self.per_token_events)) - for j in range(0, len(self.per_token_events[i]) - 1) + latencies = [ + start_event.elapsed_time(end_event) / 1e3 + for start_event, end_event in zip(self.per_token_start_events, self.per_token_end_events) ] else: - latencies_list = [ - (self.per_token_events[i][j + 1] - self.per_token_events[i][j]) - for i in range(len(self.per_token_events)) - for j in range(0, len(self.per_token_events[i]) - 1) + latencies = [ + (end_event - start_event) + for start_event, end_event in zip(self.per_token_start_events, self.per_token_end_events) ] - assert not any( - latency < 0 for latency in latencies_list - ), "Negative latency detected. Please increase the dimensions of your benchmark (inputs/warmup/iterations)." + assert all(latency >= 0 for latency in latencies) - return Latency.from_values(latencies_list, unit=LATENCY_UNIT) + return Latency.from_values(latencies, unit=LATENCY_UNIT) - def count(self): - assert len(self.prefill_start_events) == len( - self.prefill_end_events - ), "Mismatched number of start and end events, count() should only be called outside of track() context" - return len(self.prefill_start_events) +class PerStepLatencySessionTrackerPipelineCallback: + tensor_inputs = [] - def elapsed(self): - if self.start_time is None: - assert ( - len(self.prefill_start_events) == 0 and len(self.prefill_end_events) == 0 - ), "Number of recorded events is not zero, make sure to reset() the tracker properly" + def __init__(self, device: str, backend: str): + self.device = device + self.backend = backend - self.start_time = time.perf_counter() + self.is_pytorch_cuda = (self.backend, self.device) == ("pytorch", "cuda") + + if self.is_pytorch_cuda: + LOGGER.info("\t\t+ Tracking latency using Pytorch CUDA events") + else: + LOGGER.info("\t\t+ Tracking latency using CPU performance counter") + + self.call_start_events: List[Union[float, torch.cuda.Event]] = [] + self.call_end_events: List[Union[float, torch.cuda.Event]] = [] + self.per_step_start_events: List[Union[float, torch.cuda.Event]] = [] + self.per_step_end_events: List[Union[float, torch.cuda.Event]] = [] + self.per_step_events: List[Union[float, torch.cuda.Event]] = [] + + self.start_time: Optional[float] = None + + @contextmanager + def session(self): + assert self.start_time is None + + self.call_start_events = [] + self.call_end_events = [] + self.per_step_start_events = [] + self.per_step_end_events = [] + self.per_step_events = [] + + self.start_time = time.perf_counter() + yield + self.start_time = None + + def count(self) -> int: + assert self.start_time is not None, "This method can only be called inside of a '.session()' context" + assert len(self.call_start_events) == len(self.call_start_events) + + return len(self.call_start_events) + + def elapsed(self): + assert self.start_time is not None, "This method can only be called inside of a '.session()' context" return time.perf_counter() 
- self.start_time + + @contextmanager + def track(self): + if self.is_pytorch_cuda: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + yield + end_event.record() + else: + start_event = time.perf_counter() + yield + end_event = time.perf_counter() + + self.call_start_events.append(start_event) + self.call_end_events.append(end_event) + + self.per_step_start_events.extend(self.per_step_events[:-1]) + self.per_step_end_events.extend(self.per_step_events[1:]) + + def __call__(self, pipeline, step_index, timestep, callback_kwargs): + if self.is_pytorch_cuda: + event = torch.cuda.Event(enable_timing=True) + event.record() + else: + event = time.perf_counter() + + self.per_step_events.append(event) + + return callback_kwargs + + def get_step_latency(self) -> Latency: + assert len(self.per_step_start_events) == len(self.per_step_end_events) > 0 + + if self.is_pytorch_cuda: + torch.cuda.synchronize() + + latencies = [ + start_event.elapsed_time(end_event) / 1e3 + for start_event, end_event in zip(self.per_step_start_events, self.per_step_end_events) + ] + else: + latencies = [ + (end_event - start_event) + for start_event, end_event in zip(self.per_step_start_events, self.per_step_end_events) + ] + + assert all(latency >= 0 for latency in latencies) + + return Latency.from_values(latencies, unit=LATENCY_UNIT) + + def get_call_latency(self) -> Latency: + assert len(self.call_start_events) == len(self.call_end_events) > 0 + + if self.is_pytorch_cuda: + torch.cuda.synchronize() + + latencies = [ + start_event.elapsed_time(end_event) / 1e3 + for start_event, end_event in zip(self.call_start_events, self.call_end_events) + ] + else: + latencies = [ + (end_event - start_event) + for start_event, end_event in zip(self.call_start_events, self.call_end_events) + ] + + assert all(latency >= 0 for latency in latencies) + + return Latency.from_values(latencies, unit=LATENCY_UNIT) + + +class StepLatencyTrackerTrainerCallback(TrainerCallback): + def __init__(self, device: str, backend: str) -> None: + self.device = device + self.backend = backend + + self.is_pytorch_cuda = (self.backend, self.device) == ("pytorch", "cuda") + + if self.is_pytorch_cuda: + LOGGER.info("\t\t+ Tracking latency using Pytorch CUDA events") + else: + LOGGER.info("\t\t+ Tracking latency using CPU performance counter") + + self.start_events: List[Union[float, torch.cuda.Event]] = [] + self.end_events: List[Union[float, torch.cuda.Event]] = [] + + def on_step_begin(self, *args, **kwargs): + if self.is_pytorch_cuda: + event = torch.cuda.Event(enable_timing=True) + event.record() + else: + event = time.perf_counter() + + self.start_events.append(event) + + def on_step_end(self, *args, **kwargs): + if self.is_pytorch_cuda: + event = torch.cuda.Event(enable_timing=True) + event.record() + else: + event = time.perf_counter() + + self.end_events.append(event) + + def get_latency(self) -> Latency: + assert len(self.start_events) == len(self.end_events) > 0 + + if self.is_pytorch_cuda: + torch.cuda.synchronize() + latencies = [ + start_event.elapsed_time(end_event) / 1e3 + for start_event, end_event in zip(self.start_events, self.end_events) + ] + else: + latencies = [ + (end_event - start_event) for start_event, end_event in zip(self.start_events, self.end_events) + ] + + assert all(latency >= 0 for latency in latencies) + + return Latency.from_values(latencies, unit=LATENCY_UNIT) diff --git a/tests/configs/_gguf_.yaml b/tests/configs/_gguf_.yaml index 
41ef8027..d4cae6ab 100644 --- a/tests/configs/_gguf_.yaml +++ b/tests/configs/_gguf_.yaml @@ -2,6 +2,6 @@ hydra: mode: MULTIRUN sweeper: params: - backend.model: ggml-org/models backend.task: text-generation,feature-extraction - backend.filename: tinyllamas/stories15M-q8_0.gguf + backend.filename: DistilGPT2-TinyStories.Q4_K_S.gguf + backend.model: mradermacher/DistilGPT2-TinyStories-GGUF diff --git a/tests/configs/_inference_.yaml b/tests/configs/_inference_.yaml index 82b2fcd6..88762ed3 100644 --- a/tests/configs/_inference_.yaml +++ b/tests/configs/_inference_.yaml @@ -14,8 +14,8 @@ scenario: sequence_length: 16 generate_kwargs: - max_new_tokens: 16 - min_new_tokens: 16 + max_new_tokens: 4 + min_new_tokens: 4 call_kwargs: - num_inference_steps: 2 + num_inference_steps: 4 diff --git a/tests/test_api.py b/tests/test_api.py index 34598a02..a0bb4754 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -23,7 +23,7 @@ from optimum_benchmark.generators.input_generator import InputGenerator from optimum_benchmark.import_utils import get_git_revision_hash from optimum_benchmark.system_utils import is_nvidia_system, is_rocm_system -from optimum_benchmark.trackers import LatencyTracker, MemoryTracker +from optimum_benchmark.trackers import LatencySessionTracker, MemoryTracker PUSH_REPO_ID = os.environ.get("PUSH_REPO_ID", "optimum-benchmark/local") @@ -55,6 +55,9 @@ @pytest.mark.parametrize("scenario", ["training", "inference"]) @pytest.mark.parametrize("library,task,model", LIBRARIES_TASKS_MODELS) def test_api_launch(device, scenario, library, task, model): + if scenario == "training" and library != "transformers": + pytest.skip("Training is only supported with transformers library models") + benchmark_name = f"{device}_{scenario}_{library}_{task}_{model}" if device == "cuda": @@ -65,24 +68,26 @@ def test_api_launch(device, scenario, library, task, model): elif is_nvidia_system(): device_isolation_action = "error" device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "0") + else: + raise RuntimeError("Using CUDA device on a machine that is neither NVIDIA nor ROCM.") else: device_isolation_action = None device_isolation = False device_ids = None - launcher_config = ProcessConfig(device_isolation=device_isolation, device_isolation_action=device_isolation_action) + launcher_config = ProcessConfig( + device_isolation=device_isolation, + device_isolation_action=device_isolation_action, + ) if scenario == "training": - if library == "transformers": - scenario_config = TrainingConfig( - memory=True, - latency=True, - energy=not is_rocm_system(), - warmup_steps=2, - max_steps=5, - ) - else: - pytest.skip("Training scenario is only available for Transformers library") + scenario_config = TrainingConfig( + memory=True, + latency=True, + energy=not is_rocm_system(), + warmup_steps=2, + max_steps=5, + ) elif scenario == "inference": scenario_config = InferenceConfig( @@ -227,12 +232,12 @@ def test_api_dataset_generator(library, task, model): @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("backend", ["pytorch", "other"]) def test_api_latency_tracker(device, backend): - tracker = LatencyTracker(device=device, backend=backend) + tracker = LatencySessionTracker(device=device, backend=backend) - tracker.reset() - while tracker.elapsed() < 2: - with tracker.track(): - time.sleep(1) + with tracker.session(): + while tracker.elapsed() < 2: + with tracker.track(): + time.sleep(1) latency = tracker.get_latency() latency.log() @@ -241,10 +246,10 @@ def test_api_latency_tracker(device, 
backend): assert latency.mean > 0.9 assert len(latency.values) == 2 - tracker.reset() - while tracker.count() < 2: - with tracker.track(): - time.sleep(1) + with tracker.session(): + while tracker.count() < 2: + with tracker.track(): + time.sleep(1) latency = tracker.get_latency() latency.log() @@ -273,7 +278,6 @@ def test_api_memory_tracker(device, backend): tracker = MemoryTracker(device=device, backend=backend, device_ids=device_ids) - tracker.reset() with tracker.track(): time.sleep(1) pass @@ -281,7 +285,6 @@ def test_api_memory_tracker(device, backend): initial_memory = tracker.get_max_memory() initial_memory.log() - tracker.reset() with tracker.track(): array = torch.randn((10000, 10000), dtype=torch.float64, device=device) expected_memory = array.nbytes / 1e6
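
The session-based tracker API introduced by this patch is exercised end to end in the updated test_api_latency_tracker above; the following is a minimal standalone sketch of the same usage (illustrative only: time.sleep() stands in for a real measured call, and the "cpu"/"other" device/backend pair simply selects the perf_counter code path rather than CUDA events).

import time

from optimum_benchmark.trackers import LatencySessionTracker

tracker = LatencySessionTracker(device="cpu", backend="other")

# elapsed() and count() are only valid inside a session(), which also resets the recorded events
with tracker.session():
    while tracker.elapsed() < 2:
        with tracker.track():
            time.sleep(1)  # placeholder for a real measured call, e.g. backend.forward(inputs, kwargs)

latency = tracker.get_latency()  # aggregates every tracked iteration into a Latency object
latency.log()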