diff --git a/optimum_benchmark/backends/ipex/backend.py b/optimum_benchmark/backends/ipex/backend.py
index 8939fdb0..b584ff6c 100644
--- a/optimum_benchmark/backends/ipex/backend.py
+++ b/optimum_benchmark/backends/ipex/backend.py
@@ -84,31 +84,14 @@ def automodel_kwargs(self) -> Dict[str, Any]:
         if self.config.torch_dtype is not None:
             kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)

-        print(kwargs)
-
         return kwargs

     @property
-    def is_dp_distributed(self) -> bool:
+    def split_between_processes(self) -> bool:
         return is_torch_distributed_available() and torch.distributed.is_initialized()

-    def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
-        if self.is_dp_distributed:
-            if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
-                raise ValueError(
-                    f"Batch size {input_shapes['batch_size']} must be divisible by "
-                    f"data parallel world size {torch.distributed.get_world_size()}"
-                )
-            # distributing batch size across processes
-            input_shapes["batch_size"] //= torch.distributed.get_world_size()
-
-        # registering input shapes for usage during model reshaping
-        self.input_shapes = input_shapes
-
-        return input_shapes
-
     def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if self.is_dp_distributed:
+        if self.split_between_processes:
             with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                 inputs = process_inputs
diff --git a/optimum_benchmark/backends/onnxruntime/backend.py b/optimum_benchmark/backends/onnxruntime/backend.py
index 223da6dc..2fffcc36 100644
--- a/optimum_benchmark/backends/onnxruntime/backend.py
+++ b/optimum_benchmark/backends/onnxruntime/backend.py
@@ -280,20 +280,12 @@ def quantize_onnx_files(self) -> None:
         if self.pretrained_config is not None:
             self.pretrained_config.save_pretrained(self.quantized_model)

-    def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
-        if self.is_dp_distributed:
-            if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
-                raise ValueError(
-                    f"Batch size {input_shapes['batch_size']} must be divisible by "
-                    f"data parallel world size {torch.distributed.get_world_size()}"
-                )
-            # distributing batch size across processes
-            input_shapes["batch_size"] //= torch.distributed.get_world_size()
-
-        return input_shapes
+    @property
+    def split_between_processes(self) -> bool:
+        return is_torch_distributed_available() and torch.distributed.is_initialized()

     def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if self.is_dp_distributed:
+        if self.split_between_processes:
             with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                 inputs = process_inputs
diff --git a/optimum_benchmark/backends/openvino/backend.py b/optimum_benchmark/backends/openvino/backend.py
index 9db49fb2..f0aa1925 100644
--- a/optimum_benchmark/backends/openvino/backend.py
+++ b/optimum_benchmark/backends/openvino/backend.py
@@ -82,7 +82,7 @@ def load(self) -> None:
         if self.config.reshape:
             static_shapes = {
                 key: value
-                for key, value in {**self.input_shapes, **self.model_shapes}.items()
+                for key, value in self.model_shapes.items()
                 if key in inspect.getfullargspec(self.pretrained_model.reshape).args
             }
             if ("sequence_length" in static_shapes) and ("height" in static_shapes) and ("width" in static_shapes):
@@ -135,20 +135,6 @@ def _load_ovmodel_with_no_weights(self) -> None:
         self.config.export = original_export
         self.config.model = original_model

-    @property
-    def is_dp_distributed(self) -> bool:
-        return is_torch_distributed_available() and torch.distributed.is_initialized()
-
-    @property
-    def ovmodel_kwargs(self) -> Dict[str, Any]:
-        kwargs = {}
-
-        if self.config.task in TEXT_GENERATION_TASKS:
-            kwargs["use_cache"] = self.config.use_cache
-            kwargs["use_merged"] = self.config.use_merged
-
-        return kwargs
-
     def quantize_automodel(self) -> None:
         self.logger.info("\t+ Attempting quantization")
         self.quantized_model = f"{self.tmpdir.name}/quantized_model"
@@ -181,23 +167,22 @@ def quantize_automodel(self) -> None:
             batch_size=1,
         )

-    def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
-        if self.is_dp_distributed:
-            if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
-                raise ValueError(
-                    f"Batch size {input_shapes['batch_size']} must be divisible by "
-                    f"data parallel world size {torch.distributed.get_world_size()}"
-                )
-            # distributing batch size across processes
-            input_shapes["batch_size"] //= torch.distributed.get_world_size()
+    @property
+    def ovmodel_kwargs(self) -> Dict[str, Any]:
+        kwargs = {}

-        # registering input shapes for usage during model reshaping
-        self.input_shapes = input_shapes
+        if self.config.task in TEXT_GENERATION_TASKS:
+            kwargs["use_cache"] = self.config.use_cache
+            kwargs["use_merged"] = self.config.use_merged

-        return input_shapes
+        return kwargs
+
+    @property
+    def split_between_processes(self) -> bool:
+        return is_torch_distributed_available() and torch.distributed.is_initialized()

     def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if self.is_dp_distributed:
+        if self.split_between_processes:
             with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                 inputs = process_inputs
@@ -205,6 +190,14 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
             if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names:
                 inputs.pop(key)

+        if "input_ids" in inputs:
+            self.model_shapes.update(dict(zip(["batch_size", "sequence_length"], inputs["input_ids"].shape)))
+
+        if "pixel_values" in inputs:
+            self.model_shapes.update(
+                dict(zip(["batch_size", "num_channels", "height", "width"], inputs["pixel_values"].shape))
+            )
+
         return inputs

     def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
diff --git a/optimum_benchmark/backends/pytorch/backend.py b/optimum_benchmark/backends/pytorch/backend.py
index bfbe9745..5052e148 100644
--- a/optimum_benchmark/backends/pytorch/backend.py
+++ b/optimum_benchmark/backends/pytorch/backend.py
@@ -27,7 +27,7 @@
     import deepspeed  # type: ignore

 if is_torch_distributed_available():
-    import torch.distributed
+    import torch.distributed  # type: ignore

 if is_zentorch_available():
     import zentorch  # type: ignore  # noqa: F401
@@ -326,18 +326,6 @@ def process_quantization_config(self) -> None:
         else:
             raise ValueError(f"Quantization scheme {self.config.quantization_scheme} not recognized")

-    @property
-    def is_distributed(self) -> bool:
-        return is_torch_distributed_available() and torch.distributed.is_initialized()
-
-    @property
-    def is_tp_distributed(self) -> bool:
-        return self.is_distributed and self.config.deepspeed_inference
-
-    @property
-    def is_dp_distributed(self) -> bool:
-        return self.is_distributed and not self.config.deepspeed_inference
-
     @property
     def is_quantized(self) -> bool:
         return self.config.quantization_scheme is not None or (
@@ -407,35 +395,26 @@ def automodel_kwargs(self) -> Dict[str, Any]:

         return kwargs

-    def prepare_input_shapes(self, input_shapes: Dict[str, Any]) -> Dict[str, Any]:
-        if self.is_dp_distributed:
-            if input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
-                raise ValueError(
-                    f"Batch size {input_shapes['batch_size']} must be divisible by "
-                    f"data parallel world size {torch.distributed.get_world_size()}"
-                )
-            # distributing batch size across processes
-            input_shapes["batch_size"] //= torch.distributed.get_world_size()
-
-        if self.is_tp_distributed:
-            if torch.distributed.get_rank() != 0:
-                # zeroing throughput on other ranks
-                input_shapes["batch_size"] = 0
-
-        return input_shapes
+    @property
+    def split_between_processes(self) -> bool:
+        return (
+            is_torch_distributed_available()
+            and torch.distributed.is_initialized()
+            and not self.config.deepspeed_inference
+        )

     def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if self.is_dp_distributed:
+        if self.split_between_processes:
             with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                 inputs = process_inputs

-        if self.config.library == "timm":
-            inputs = {"x": inputs["pixel_values"]}
-
         for key, value in inputs.items():
             if isinstance(value, torch.Tensor):
                 inputs[key] = value.to(self.config.device)

+        if self.config.library == "timm":
+            inputs = {"x": inputs["pixel_values"]}
+
         return inputs

     @torch.inference_mode()
diff --git a/optimum_benchmark/backends/transformers_utils.py b/optimum_benchmark/backends/transformers_utils.py
index 3b38bc2c..efd2b8af 100644
--- a/optimum_benchmark/backends/transformers_utils.py
+++ b/optimum_benchmark/backends/transformers_utils.py
@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Type, Union

 import torch
 import transformers
@@ -7,6 +7,7 @@
     AutoConfig,
     AutoFeatureExtractor,
     AutoImageProcessor,
+    AutoModel,
     AutoProcessor,
     AutoTokenizer,
     FeatureExtractionMixin,
@@ -17,9 +18,7 @@
     SpecialTokensMixin,
 )

-from ..import_utils import is_torch_available
-
-TASKS_TO_MODEL_LOADERS = {
+TASKS_TO_AUTOMODEL_CLASS_NAMES = {
     # text processing
     "feature-extraction": "AutoModel",
     "fill-mask": "AutoModelForMaskedLM",
@@ -57,34 +56,26 @@
     "sentence-similarity": "feature-extraction",
 }

-if is_torch_available():
-    TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}
-    for task_name, model_loaders in TASKS_TO_MODEL_LOADERS.items():
-        TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name] = {}
-
-        if isinstance(model_loaders, str):
-            model_loaders = (model_loaders,)
-
-        for model_loader_name in model_loaders:
-            model_loader_class = getattr(transformers, model_loader_name, None)
-            if model_loader_class is not None:
-                TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name].update(
-                    model_loader_class._model_mapping._model_mapping
-                )
-else:
-    TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {}
-
-def get_transformers_automodel_loader_for_task(task: str, model_type: Optional[str] = None):
+def get_transformers_automodel_class_for_task(task: str, model_type: Optional[str] = None) -> Type["AutoModel"]:
     if task in SYNONYM_TASKS:
         task = SYNONYM_TASKS[task]

-    if model_type is not None:
-        model_loader_name = TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task][model_type]
+    if task not in TASKS_TO_AUTOMODEL_CLASS_NAMES:
+        raise ValueError(f"Task {task} not supported")
+
+    if isinstance(TASKS_TO_AUTOMODEL_CLASS_NAMES[task], str):
+        return getattr(transformers, TASKS_TO_AUTOMODEL_CLASS_NAMES[task])
     else:
-        model_loader_name = TASKS_TO_MODEL_LOADERS[task]
+        if model_type is None:
+            raise ValueError(f"Task {task} requires a model_type to be specified")
+
+        for automodel_class_name in TASKS_TO_AUTOMODEL_CLASS_NAMES[task]:
+            automodel_class = getattr(transformers, automodel_class_name)
+            if model_type in automodel_class._model_mapping._model_mapping:
+                return automodel_class

-    return getattr(transformers, model_loader_name)
+        raise ValueError(f"Task {task} not supported for model type {model_type}")


 PretrainedProcessor = Union["FeatureExtractionMixin", "ImageProcessingMixin", "SpecialTokensMixin", "ProcessorMixin"]
diff --git a/optimum_benchmark/benchmark/report.py b/optimum_benchmark/benchmark/report.py
index c4b0602d..b9edd960 100644
--- a/optimum_benchmark/benchmark/report.py
+++ b/optimum_benchmark/benchmark/report.py
@@ -35,16 +35,26 @@ def __post_init__(self):
             self.efficiency = Efficiency(**self.efficiency)

     @staticmethod
-    def aggregate(measurements: List["TargetMeasurements"]) -> "TargetMeasurements":
+    def aggregate_across_processes(measurements: List["TargetMeasurements"]) -> "TargetMeasurements":
         assert len(measurements) > 0, "No measurements to aggregate"

         m0 = measurements[0]

-        memory = Memory.aggregate([m.memory for m in measurements]) if m0.memory is not None else None
-        latency = Latency.aggregate([m.latency for m in measurements]) if m0.latency is not None else None
-        throughput = Throughput.aggregate([m.throughput for m in measurements]) if m0.throughput is not None else None
-        energy = Energy.aggregate([m.energy for m in measurements]) if m0.energy is not None else None
-        efficiency = Efficiency.aggregate([m.efficiency for m in measurements]) if m0.efficiency is not None else None
+        memory = Memory.aggregate_across_processes([m.memory for m in measurements]) if m0.memory is not None else None
+        latency = (
+            Latency.aggregate_across_processes([m.latency for m in measurements]) if m0.latency is not None else None
+        )
+        throughput = (
+            Throughput.aggregate_across_processes([m.throughput for m in measurements])
+            if m0.throughput is not None
+            else None
+        )
+        energy = Energy.aggregate_across_processes([m.energy for m in measurements]) if m0.energy is not None else None
+        efficiency = (
+            Efficiency.aggregate_across_processes([m.efficiency for m in measurements])
+            if m0.efficiency is not None
+            else None
+        )

         return TargetMeasurements(
             memory=memory, latency=latency, throughput=throughput, energy=energy, efficiency=efficiency
@@ -99,11 +109,11 @@ def __post_init__(self):
                 setattr(self, target, TargetMeasurements(**getattr(self, target)))

     @classmethod
-    def aggregate(cls, reports: List["BenchmarkReport"]) -> "BenchmarkReport":
+    def aggregate_across_processes(cls, reports: List["BenchmarkReport"]) -> "BenchmarkReport":
         aggregated_measurements = {}
         for target in reports[0].to_dict().keys():
             measurements = [getattr(report, target) for report in reports]
-            aggregated_measurements[target] = TargetMeasurements.aggregate(measurements)
+            aggregated_measurements[target] = TargetMeasurements.aggregate_across_processes(measurements)

         return cls.from_dict(aggregated_measurements)
diff --git a/optimum_benchmark/generators/task_generator.py b/optimum_benchmark/generators/task_generator.py
index f11d21eb..96dbb2e5 100644
--- a/optimum_benchmark/generators/task_generator.py
+++ b/optimum_benchmark/generators/task_generator.py
@@ -445,6 +445,4 @@ def __call__(self):
     "image-text-to-text": ImageTextToTextGenerator,
     # diffusers pipelines tasks
     "text-to-image": PromptGenerator,
-    "stable-diffusion": PromptGenerator,
-    "stable-diffusion-xl": PromptGenerator,
 }
diff --git a/optimum_benchmark/launchers/torchrun/launcher.py b/optimum_benchmark/launchers/torchrun/launcher.py
index 10b45d4d..99d5ba12 100644
--- a/optimum_benchmark/launchers/torchrun/launcher.py
+++ b/optimum_benchmark/launchers/torchrun/launcher.py
@@ -100,7 +100,7 @@ def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any])
             raise RuntimeError(f"Received an unexpected response from isolated process: {output}")

         self.logger.info("\t+ Aggregating reports from all rank processes")
-        report = BenchmarkReport.aggregate(reports)
+        report = BenchmarkReport.aggregate_across_processes(reports)

         return report
diff --git a/optimum_benchmark/scenarios/energy_star/scenario.py b/optimum_benchmark/scenarios/energy_star/scenario.py
index 39b12a04..8345cae0 100644
--- a/optimum_benchmark/scenarios/energy_star/scenario.py
+++ b/optimum_benchmark/scenarios/energy_star/scenario.py
@@ -38,7 +38,7 @@
 PREPROCESS_EFFICIENCY_UNIT = "samples/kWh"
 FORWARD_EFFICIENCY_UNIT = "samples/kWh"
-PREFILL_EFFICIENCY_UNIT = "tokens/kWh"
+PREFILL_EFFICIENCY_UNIT = "samples/kWh"
 DECODE_EFFICIENCY_UNIT = "tokens/kWh"
 CALL_EFFICIENCY_UNIT = "images/kWh"

@@ -50,9 +50,9 @@ def __init__(self, config: EnergyStarConfig) -> None:
         super().__init__(config)

     def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
-        self.task = backend.config.task
+        self.backend = backend

-        if self.task in TEXT_GENERATION_TASKS:
+        if self.backend.config.task in TEXT_GENERATION_TASKS:
             self.logger.info("\t+ Updating Text Generation kwargs with default values")
             self.config.generate_kwargs = {**TEXT_GENERATION_DEFAULT_KWARGS, **self.config.generate_kwargs}
             self.prefill_kwargs = {**self.config.generate_kwargs, **TEXT_GENERATION_PREFILL_OVERRIDES}
@@ -60,7 +60,7 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
             self.report = BenchmarkReport.from_list(
                 targets=["load_dataset", "preprocess_dataset", "load_model", "prefill", "decode"]
             )
-        elif self.task in IMAGE_DIFFUSION_TASKS:
+        elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
             self.logger.info("\t+ Updating Image Diffusion kwargs with default values")
             self.config.call_kwargs = {**IMAGE_DIFFUSION_DEFAULT_KWARGS, **self.config.call_kwargs}
             self.logger.info("\t+ Initializing Image Diffusion report")
@@ -80,17 +80,18 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
             )

         self.run_dataset_loading_energy_tracking()
-        self.run_model_loading_energy_tracking(backend)
         self.run_dataset_preprocessing_energy_tracking(backend)

         self.logger.info("\t+ Preparing sample inputs for model warmup")
-        self.raw_sample_inputs = self.dataset[: self.config.input_shapes["batch_size"]]
-        self.prepared_sample_inputs = backend.prepare_inputs(self.raw_sample_inputs)
+        self.sample_inputs = self.dataset[: self.config.input_shapes["batch_size"]]
+        self.sample_inputs = backend.prepare_inputs(self.sample_inputs)
+
+        self.run_model_loading_energy_tracking(backend)

-        if self.task in TEXT_GENERATION_TASKS:
+        if self.backend.config.task in TEXT_GENERATION_TASKS:
             self.warmup_text_generation(backend)
             self.run_text_generation_energy_tracking(backend)
-        elif self.task in IMAGE_DIFFUSION_TASKS:
+        elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
             self.warmup_image_diffusion(backend)
             self.run_image_diffusion_energy_tracking(backend)
         else:
@@ -115,7 +116,7 @@ def run_dataset_preprocessing_energy_tracking(self, backend: Backend[BackendConf
         self.logger.info("\t+ Running dataset preprocessing energy tracking")

         with self.energy_tracker.track(file_prefix="preprocess_dataset"):
-            self.dataset = TASKS_TO_PREPROCESSORS[self.task](
+            self.dataset = TASKS_TO_PREPROCESSORS[self.backend.config.task](
                 dataset=self.dataset,
                 scenario_config=self.config,
                 pretrained_config=backend.pretrained_config,
@@ -144,24 +145,22 @@ def run_model_loading_energy_tracking(self, backend: Backend[BackendConfigT]):
     # Text Generation warmup
     def warmup_text_generation(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Warming up backend for Text Generation")
-        backend.generate(self.prepared_sample_inputs, self.config.generate_kwargs)
+        backend.generate(self.sample_inputs, self.config.generate_kwargs)
         for _ in range(self.config.warmup_runs):
-            backend.generate(
-                self.prepared_sample_inputs, {**self.config.generate_kwargs, **TEXT_GENERATION_WARMUP_OVERRIDES}
-            )
+            backend.generate(self.sample_inputs, {**self.config.generate_kwargs, **TEXT_GENERATION_WARMUP_OVERRIDES})

     # Image Diffusion warmup
     def warmup_image_diffusion(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Warming up backend for Image Diffusion")
-        backend.call(self.prepared_sample_inputs, self.config.call_kwargs)
+        backend.call(self.sample_inputs, self.config.call_kwargs)
         for _ in range(self.config.warmup_runs):
-            backend.call(self.prepared_sample_inputs, {**self.config.call_kwargs, **IMAGE_DIFFUSION_WARMUP_OVERRIDES})
+            backend.call(self.sample_inputs, {**self.config.call_kwargs, **IMAGE_DIFFUSION_WARMUP_OVERRIDES})

     # Inference warmup
     def warmup_inference(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Warming up backend for Inference")
         for _ in range(self.config.warmup_runs):
-            backend.forward(self.prepared_sample_inputs, self.config.forward_kwargs)
+            backend.forward(self.sample_inputs, self.config.forward_kwargs)

     # Text Generation energy tracking
     def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):
@@ -243,25 +242,8 @@ def dataset_forward_volume(self) -> int:  # in samples
         return self.config.num_samples

     @property
-    def dataset_prefill_volume(self) -> int:  # in tokens
-        prefill_volume = 0
-
-        for sample in self.dataset:
-            if "input_ids" in sample.keys():
-                # text/image-text/video-image-text conditioned generation
-                prefill_volume += self.raw_sample_inputs["input_ids"].numel()
-            else:
-                # image/audio/other conditioned generation (1 bos token)
-                prefill_volume += 1
-
-        return prefill_volume
-
-    @property
-    def dataset_per_token_volume(self) -> int:  # in tokens
-        return (
-            self.config.num_samples
-            * self.config.generate_kwargs["num_beams"]  # at each beam stage there are num_beams tokens generated
-        )
+    def dataset_prefill_volume(self) -> int:  # in samples
+        return self.config.num_samples

     @property
     def dataset_decode_volume(self) -> int:  # in tokens
@@ -273,7 +255,7 @@ def dataset_decode_volume(self) -> int:  # in tokens

     @property
     def dataset_call_volume(self) -> int:  # in images
-        if self.task == "text-to-image":
+        if self.backend.config.task == "text-to-image":
             return self.config.num_samples * self.config.call_kwargs["num_images_per_prompt"]
         else:
             return self.config.num_samples
diff --git a/optimum_benchmark/scenarios/inference/config.py b/optimum_benchmark/scenarios/inference/config.py
index 57d482ab..d86962eb 100644
--- a/optimum_benchmark/scenarios/inference/config.py
+++ b/optimum_benchmark/scenarios/inference/config.py
@@ -9,7 +9,6 @@

 INPUT_SHAPES = {
     "batch_size": 2,
-    "sequence_length": 16,
 }
diff --git a/optimum_benchmark/scenarios/inference/scenario.py b/optimum_benchmark/scenarios/inference/scenario.py
index 2f0ac8e7..e05cb7b9 100644
--- a/optimum_benchmark/scenarios/inference/scenario.py
+++ b/optimum_benchmark/scenarios/inference/scenario.py
@@ -40,13 +40,17 @@
     "num_inference_steps": 2,
 }

-TEXT_GENERATION_THROUGHPUT_UNIT = "tokens/s"
-IMAGE_DIFFUSION_THROUGHPUT_UNIT = "images/s"
-INFERENCE_THROUGHPUT_UNIT = "samples/s"
-TEXT_GENERATION_EFFICIENCY_UNIT = "tokens/kWh"
-IMAGE_DIFFUSION_EFFICIENCY_UNIT = "images/kWh"
-INFERENCE_EFFICIENCY_UNIT = "samples/kWh"
+FORWARD_THROUGHPUT_UNIT = "samples/s"
+PREFILL_THROUGHPUT_UNIT = "samples/s"
+DECODE_THROUGHPUT_UNIT = "tokens/s"
+CALL_THROUGHPUT_UNIT = "images/s"
+
+
+FORWARD_EFFICIENCY_UNIT = "samples/kWh"
+PREFILL_EFFICIENCY_UNIT = "samples/kWh"
+DECODE_EFFICIENCY_UNIT = "tokens/kWh"
+CALL_EFFICIENCY_UNIT = "images/kWh"


 class InferenceScenario(Scenario[InferenceConfig]):
@@ -56,77 +60,71 @@ def __init__(self, config: InferenceConfig) -> None:
         super().__init__(config)

     def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
-        self.task = backend.config.task
+        self.backend = backend

-        self.logger.info("\t+ Creating input generator")
-        self.input_generator = InputGenerator(
-            task=self.task,
-            input_shapes=self.config.input_shapes,
-            model_shapes=backend.model_shapes,
-            model_type=backend.config.model_type,
-        )
-
-        if self.task in TEXT_GENERATION_TASKS:
-            self.logger.info("\t+ Generating Text Generation inputs")
-            self.inputs = self.input_generator()
+        if self.backend.config.task in TEXT_GENERATION_TASKS:
             self.logger.info("\t+ Updating Text Generation kwargs with default values")
             self.config.generate_kwargs = {**TEXT_GENERATION_DEFAULT_KWARGS, **self.config.generate_kwargs}
             self.logger.info("\t+ Initializing Text Generation report")
             self.report = BenchmarkReport.from_list(targets=["load", "prefill", "decode", "per_token"])
-        elif self.task in IMAGE_DIFFUSION_TASKS:
-            self.logger.info("\t+ Generating Image Diffusion inputs")
-            self.inputs = self.input_generator()
+        elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
             self.logger.info("\t+ Updating Image Diffusion kwargs with default values")
             self.config.call_kwargs = {**IMAGE_DIFFUSION_DEFAULT_KWARGS, **self.config.call_kwargs}
             self.logger.info("\t+ Initializing Image Diffusion report")
             self.report = BenchmarkReport.from_list(targets=["load", "call"])
         else:
-            self.logger.info("\t+ Generating Inference inputs")
-            self.inputs = self.input_generator()
             self.logger.info("\t+ Initializing Inference report")
             self.report = BenchmarkReport.from_list(targets=["load", "forward"])

-        self.logger.info("\t+ Preparing input shapes for Inference")
-        self.config.input_shapes = backend.prepare_input_shapes(input_shapes=self.config.input_shapes)
+        self.logger.info("\t+ Creating input generator")
+        self.input_generator = InputGenerator(
+            task=self.backend.config.task,
+            model_shapes=backend.model_shapes,
+            input_shapes=self.config.input_shapes,
+            model_type=backend.config.model_type,
+        )

-        self.run_model_loading_tracking(backend)
+        self.logger.info("\t+ Generating inputs")
+        self.inputs = self.input_generator()

         self.logger.info("\t+ Preparing inputs for Inference")
         self.inputs = backend.prepare_inputs(inputs=self.inputs)

+        self.run_model_loading_tracking(backend)
+
         if self.config.latency or self.config.energy:
             # latency and energy are metrics that require some warmup
             if self.config.warmup_runs > 0:
-                if self.task in TEXT_GENERATION_TASKS:
+                if self.backend.config.task in TEXT_GENERATION_TASKS:
                     self.warmup_text_generation(backend)
-                elif self.task in IMAGE_DIFFUSION_TASKS:
+                elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
                     self.warmup_image_diffusion(backend)
                 else:
                     self.warmup_inference(backend)

         if self.config.latency:
-            if self.task in TEXT_GENERATION_TASKS:
+            if self.backend.config.task in TEXT_GENERATION_TASKS:
                 if backend.config.name in PER_TOKEN_BACKENDS:
                     self.run_per_token_text_generation_latency_tracking(backend)
                 else:
                     self.run_text_generation_latency_tracking(backend)
-            elif self.task in IMAGE_DIFFUSION_TASKS:
+            elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
                 self.run_image_diffusion_latency_tracking(backend)
             else:
                 self.run_latency_inference_tracking(backend)

         if self.config.memory:
-            if self.task in TEXT_GENERATION_TASKS:
+            if self.backend.config.task in TEXT_GENERATION_TASKS:
                 self.run_text_generation_memory_tracking(backend)
-            elif self.task in IMAGE_DIFFUSION_TASKS:
+            elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
                 self.run_image_diffusion_memory_tracking(backend)
             else:
                 self.run_inference_memory_tracking(backend)

         if self.config.energy:
-            if self.task in TEXT_GENERATION_TASKS:
+            if self.backend.config.task in TEXT_GENERATION_TASKS:
                 self.run_text_generation_energy_tracking(backend)
-            elif self.task in IMAGE_DIFFUSION_TASKS:
+            elif self.backend.config.task in IMAGE_DIFFUSION_TASKS:
                 self.run_image_diffusion_energy_tracking(backend)
             else:
                 self.run_inference_energy_tracking(backend)
@@ -178,42 +176,42 @@ def run_model_loading_tracking(self, backend: Backend[BackendConfigT]):
     ## Memory tracking
     def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Running Text Generation memory tracking")
-        self.memory_tracker = MemoryTracker(
+        memory_tracker = MemoryTracker(
             backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
         )

         prefill_kwargs = {**self.config.generate_kwargs, **TEXT_GENERATION_PREFILL_OVERRIDES}

-        with self.memory_tracker.track():
+        with memory_tracker.track():
             _ = backend.prefill(self.inputs, prefill_kwargs)

-        self.report.prefill.memory = self.memory_tracker.get_max_memory()
+        self.report.prefill.memory = memory_tracker.get_max_memory()

-        with self.memory_tracker.track():
+        with memory_tracker.track():
             _ = backend.generate(self.inputs, self.config.generate_kwargs)

-        self.report.decode.memory = self.memory_tracker.get_max_memory()
+        self.report.decode.memory = memory_tracker.get_max_memory()

     def run_image_diffusion_memory_tracking(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Running Image Diffusion memory tracking")
-        self.memory_tracker = MemoryTracker(
+        memory_tracker = MemoryTracker(
             backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
         )

-        with self.memory_tracker.track():
+        with memory_tracker.track():
             _ = backend.call(self.inputs, self.config.call_kwargs)

-        self.report.call.memory = self.memory_tracker.get_max_memory()
+        self.report.call.memory = memory_tracker.get_max_memory()

     def run_inference_memory_tracking(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Running Inference memory tracking")
-        self.memory_tracker = MemoryTracker(
+        memory_tracker = MemoryTracker(
             backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
         )

-        with self.memory_tracker.track():
+        with memory_tracker.track():
             _ = backend.forward(self.inputs, self.config.forward_kwargs)

-        self.report.forward.memory = self.memory_tracker.get_max_memory()
+        self.report.forward.memory = memory_tracker.get_max_memory()

     ## Latency tracking
     def run_per_token_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
@@ -229,7 +227,6 @@ def run_per_token_text_generation_latency_tracking(self, backend: Backend[Backen
         prefill_latency = latency_tracker.get_prefill_latency()
         decode_latency = latency_tracker.get_decode_latency()

-        per_token_volume = self.atomic_per_token_volume
         prefill_volume = self.atomic_prefill_volume
         decode_volume = self.atomic_decode_volume

@@ -237,14 +234,12 @@ def run_per_token_text_generation_latency_tracking(self, backend: Backend[Backen
         self.report.prefill.latency = prefill_latency
         self.report.decode.latency = decode_latency

-        self.report.per_token.throughput = Throughput.from_latency(
-            per_token_latency, per_token_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
-        )
+        # we don't register a per-token throughput, as it's a confusing metric and the same as the decode throughput
         self.report.prefill.throughput = Throughput.from_latency(
-            prefill_latency, prefill_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
+            prefill_latency, prefill_volume, unit=PREFILL_THROUGHPUT_UNIT
         )
         self.report.decode.throughput = Throughput.from_latency(
-            decode_latency, decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
+            decode_latency, decode_volume, unit=DECODE_THROUGHPUT_UNIT
         )

     def run_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
@@ -261,7 +256,7 @@ def run_text_generation_latency_tracking(self, backend: Backend[BackendConfigT])
         self.report.prefill.latency = prefill_latency
         self.report.prefill.throughput = Throughput.from_latency(
-            prefill_latency, prefill_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
+            prefill_latency, prefill_volume, unit=PREFILL_THROUGHPUT_UNIT
         )

         latency_tracker.reset()
@@ -275,7 +270,7 @@ def run_text_generation_latency_tracking(self, backend: Backend[BackendConfigT])
         self.report.decode.latency = decode_latency
         self.report.decode.throughput = Throughput.from_latency(
-            decode_latency, decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
+            decode_latency, decode_volume, unit=DECODE_THROUGHPUT_UNIT
         )

     def run_image_diffusion_latency_tracking(self, backend: Backend[BackendConfigT]):
@@ -290,9 +285,7 @@ def run_image_diffusion_latency_tracking(self, backend: Backend[BackendConfigT])
         call_volume = self.atomic_call_volume

         self.report.call.latency = call_latency
-        self.report.call.throughput = Throughput.from_latency(
-            call_latency, call_volume, unit=IMAGE_DIFFUSION_THROUGHPUT_UNIT
-        )
+        self.report.call.throughput = Throughput.from_latency(call_latency, call_volume, unit=CALL_THROUGHPUT_UNIT)

     def run_latency_inference_tracking(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Running Inference latency tracking")
@@ -307,7 +300,7 @@ def run_latency_inference_tracking(self, backend: Backend[BackendConfigT]):
         self.report.forward.latency = forward_latency
         self.report.forward.throughput = Throughput.from_latency(
-            forward_latency, forward_volume, unit=INFERENCE_THROUGHPUT_UNIT
+            forward_latency, forward_volume, unit=FORWARD_THROUGHPUT_UNIT
         )

     ## Energy tracking
@@ -333,7 +326,7 @@ def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):
         self.report.prefill.energy = prefill_energy
         self.report.prefill.efficiency = Efficiency.from_energy(
-            prefill_energy, prefill_volume, unit=TEXT_GENERATION_EFFICIENCY_UNIT
+            prefill_energy, prefill_volume, unit=PREFILL_EFFICIENCY_UNIT
         )

         count = 0
@@ -352,7 +345,7 @@ def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):
         self.report.decode.energy = decode_energy
         self.report.decode.efficiency = Efficiency.from_energy(
-            decode_energy, decode_volume, unit=TEXT_GENERATION_EFFICIENCY_UNIT
+            decode_energy, decode_volume, unit=DECODE_EFFICIENCY_UNIT
         )

     def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):
@@ -375,9 +368,7 @@ def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):
         call_volume = self.atomic_call_volume

         self.report.call.energy = call_energy
-        self.report.call.efficiency = Efficiency.from_energy(
-            call_energy, call_volume, unit=IMAGE_DIFFUSION_EFFICIENCY_UNIT
-        )
+        self.report.call.efficiency = Efficiency.from_energy(call_energy, call_volume, unit=CALL_EFFICIENCY_UNIT)

     def run_inference_energy_tracking(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Running energy tracking")
@@ -400,31 +391,19 @@ def run_inference_energy_tracking(self, backend: Backend[BackendConfigT]):

         self.report.forward.energy = forward_energy
         self.report.forward.efficiency = Efficiency.from_energy(
-            forward_energy, forward_volume, unit=INFERENCE_EFFICIENCY_UNIT
+            forward_energy, forward_volume, unit=FORWARD_EFFICIENCY_UNIT
         )

     @property
-    def atomic_forward_volume(self) -> int:  # in samples
+    def atomic_forward_volume(self) -> int:  # in terms of processed samples
         return self.config.input_shapes["batch_size"]

     @property
-    def atomic_prefill_volume(self) -> int:  # in tokens
-        if {"input_ids", "prompt", "prompts"} & set(self.inputs.keys()):
-            # text conditioned generation (sequence_length tokens)
-            return self.config.input_shapes["batch_size"] * self.config.input_shapes["sequence_length"]
-        else:
-            # image/audio conditioned generation (1 bos token)
-            return self.config.input_shapes["batch_size"]
-
-    @property
-    def atomic_per_token_volume(self) -> int:  # in tokens
-        return (
-            self.config.input_shapes["batch_size"]
-            * self.config.generate_kwargs["num_beams"]  # at each beam stage there are num_beams tokens generated
-        )
+    def atomic_prefill_volume(self) -> int:  # in terms of processed samples
+        return self.config.input_shapes["batch_size"]

     @property
-    def atomic_decode_volume(self) -> int:  # in tokens
+    def atomic_decode_volume(self) -> int:  # in terms of output/generated tokens
         return (
             self.config.input_shapes["batch_size"]
             * self.config.generate_kwargs["num_beams"]  # at each beam stage there are num_beams tokens generated
@@ -432,8 +411,8 @@ def atomic_decode_volume(self) -> int:  # in tokens
         )

     @property
-    def atomic_call_volume(self) -> int:  # in images
+    def atomic_call_volume(self) -> int:  # in terms of output images
-        if self.task == "text-to-image":
+        if self.backend.config.task == "text-to-image":
             return self.config.input_shapes["batch_size"] * self.config.call_kwargs["num_images_per_prompt"]
         else:
             return self.config.input_shapes["batch_size"]
diff --git a/optimum_benchmark/trackers/energy.py b/optimum_benchmark/trackers/energy.py
index 3586809f..427c4d40 100644
--- a/optimum_benchmark/trackers/energy.py
+++ b/optimum_benchmark/trackers/energy.py
@@ -61,19 +61,20 @@ def __truediv__(self, scalar: float) -> "Energy":
         )

     @staticmethod
-    def aggregate(energies: List["Energy"]) -> "Energy":
-        if len(energies) == 0 or all(energy is None for energy in energies):
-            return None
+    def aggregate_across_processes(energies: List[Optional["Energy"]]) -> Optional["Energy"]:
+        if len(energies) == 0:
+            raise ValueError("No energy measurements to aggregate")
         elif any(energy is None for energy in energies):
             raise ValueError("Some energy measurements are missing")

         # since measurements are machine-level, we just take the average
+        total = sum(energy.total for energy in energies) / len(energies)
         cpu = sum(energy.cpu for energy in energies) / len(energies)
         gpu = sum(energy.gpu for energy in energies) / len(energies)
         ram = sum(energy.ram for energy in energies) / len(energies)
-        total = sum(energy.total for energy in energies) / len(energies)
+        unit = energies[0].unit

-        return Energy(cpu=cpu, gpu=gpu, ram=ram, total=total, unit=ENERGY_UNIT)
+        return Energy(cpu=cpu, gpu=gpu, ram=ram, total=total, unit=unit)

     def to_plain_text(self) -> str:
         plain_text = ""
@@ -109,14 +110,15 @@ class Efficiency:
     value: float

     @staticmethod
-    def aggregate(efficiencies: List["Efficiency"]) -> "Efficiency":
+    def aggregate_across_processes(efficiencies: List[Optional["Efficiency"]]) -> Optional["Efficiency"]:
         if len(efficiencies) == 0:
             raise ValueError("No efficiency measurements to aggregate")
         elif any(efficiency is None for efficiency in efficiencies):
             raise ValueError("Some efficiency measurements are None")

-        unit = efficiencies[0].unit
+        # since measurements are machine-level, we just take the average
         value = sum(efficiency.value for efficiency in efficiencies) / len(efficiencies)
+        unit = efficiencies[0].unit

         return Efficiency(value=value, unit=unit)
diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py
index 908108cb..de4ab341 100644
--- a/optimum_benchmark/trackers/latency.py
+++ b/optimum_benchmark/trackers/latency.py
@@ -53,14 +53,17 @@ def __sub__(self, latency: "Latency") -> "Latency":
         return Latency.from_values(values=latencies, unit=self.unit)

     @staticmethod
-    def aggregate(latencies: List["Latency"]) -> "Latency":
-        if len(latencies) == 0 or all(latency is None for latency in latencies):
-            return None
+    def aggregate_across_processes(latencies: List["Latency"]) -> "Latency":
+        if len(latencies) == 0:
+            raise ValueError("No latency measurements to aggregate")
         elif any(latency is None for latency in latencies):
             raise ValueError("Some latency measurements are missing")

-        unit = latencies[0].unit
+        # we combine the lists of latencies and statistics are then computed on this list
         values = sum((lat.values for lat in latencies), [])
+
+        unit = latencies[0].unit
+
         return Latency.from_values(values=values, unit=unit)

     @staticmethod
@@ -123,14 +126,15 @@ class Throughput:
     value: float

     @staticmethod
-    def aggregate(throughputs: List["Throughput"]) -> "Throughput":
+    def aggregate_across_processes(throughputs: List[Optional["Throughput"]]) -> Optional["Throughput"]:
         if len(throughputs) == 0:
             raise ValueError("No throughput measurements to aggregate")
         elif any(throughput is None for throughput in throughputs):
             raise ValueError("Some throughput measurements are missing")

+        # we compute throughputs on the whole input level so we just take the average
+        value = sum(throughput.value for throughput in throughputs) / len(throughputs)
         unit = throughputs[0].unit
-        value = sum(throughput.value for throughput in throughputs)

         return Throughput(value=value, unit=unit)
diff --git a/optimum_benchmark/trackers/memory.py b/optimum_benchmark/trackers/memory.py
index 5e9359b1..47edf71e 100644
--- a/optimum_benchmark/trackers/memory.py
+++ b/optimum_benchmark/trackers/memory.py
@@ -52,16 +52,14 @@ class Memory:
     max_allocated: Optional[float] = None

     @staticmethod
-    def aggregate(memories: List["Memory"]) -> "Memory":
+    def aggregate_across_processes(memories: List["Memory"]) -> "Memory":
         if len(memories) == 0:
             raise ValueError("No memory measurements to aggregate")
         elif any(memory is None for memory in memories):
             raise ValueError("Some memory measurements are missing")

-        unit = memories[0].unit
-
-        # process specific measurements
-        max_ram = sum(memory.max_ram for memory in memories)
+        # ram, reserved, allocated, and process_vram measurements are process-specific so they are summed
+        max_ram = sum(memory.max_ram for memory in memories) if memories[0].max_ram is not None else None
         max_reserved = sum(memory.max_reserved for memory in memories) if memories[0].max_reserved is not None else None
         max_allocated = (
             sum(memory.max_allocated for memory in memories) if memories[0].max_allocated is not None else None
@@ -69,10 +67,13 @@ def aggregate(memories: List["Memory"]) -> "Memory":
         max_process_vram = (
             sum(memory.max_process_vram for memory in memories) if memories[0].max_process_vram is not None else None
         )
-        # machine level measurements
+        # global_vram is not process-specific so we take the average
         max_global_vram = (
-            max(memory.max_global_vram for memory in memories) if memories[0].max_global_vram is not None else None
+            sum(memory.max_global_vram for memory in memories) / len(memories)
+            if memories[0].max_global_vram is not None
+            else None
         )
+        unit = memories[0].unit

         return Memory(
             unit=unit,