diff --git a/examples/pytorch_vlm.yaml b/examples/pytorch_vlm.yaml new file mode 100644 index 00000000..a39f8c8a --- /dev/null +++ b/examples/pytorch_vlm.yaml @@ -0,0 +1,42 @@ +defaults: + - benchmark + - scenario: inference + - launcher: process + - backend: pytorch + - _base_ + - _self_ + +name: pytorch_vlm + +launcher: + device_isolation: true + device_isolation_action: warn + +backend: + device: cuda + device_ids: 0 + no_weights: true + torch_dtype: float16 + model: Qwen/Qwen2-VL-7B-Instruct + +scenario: + memory: true + latency: true + + warmup_runs: 10 + iterations: 10 + duration: 10 + + input_shapes: + # text + batch_size: 1 + sequence_length: 256 + # image + num_images: 2 + num_channels: 3 + height: 224 + width: 224 + + generate_kwargs: + max_new_tokens: 32 + min_new_tokens: 32 diff --git a/optimum_benchmark/backends/base.py b/optimum_benchmark/backends/base.py index 73ef52cd..6726f91f 100644 --- a/optimum_benchmark/backends/base.py +++ b/optimum_benchmark/backends/base.py @@ -36,7 +36,6 @@ class Backend(Generic[BackendConfigT], ABC): NAME: ClassVar[str] - model_type: str model_shapes: Dict[str, int] pretrained_model: PreTrainedModel diff --git a/optimum_benchmark/backends/timm_utils.py b/optimum_benchmark/backends/timm_utils.py index 941e0991..dbaf36fd 100644 --- a/optimum_benchmark/backends/timm_utils.py +++ b/optimum_benchmark/backends/timm_utils.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Dict from transformers import PretrainedConfig @@ -35,15 +36,17 @@ def extract_timm_shapes_from_config(config: PretrainedConfig) -> Dict[str, Any]: shapes = {} # image input - shapes["num_channels"] = artifacts_dict.get("num_channels", None) - if shapes["num_channels"] is None: - # processors have different names for the number of channels + if "num_channels" in artifacts_dict: + shapes["num_channels"] = artifacts_dict.get("num_channels", None) + elif "channels" in artifacts_dict: shapes["num_channels"] = artifacts_dict.get("channels", None) - image_size = artifacts_dict.get("image_size", None) - if image_size is None: - # processors have different names for the image size - image_size = artifacts_dict.get("size", None) + if "image_size" in artifacts_dict: + image_size = artifacts_dict["image_size"] + elif "size" in artifacts_dict: + image_size = artifacts_dict["size"] + else: + image_size = None if isinstance(image_size, (int, float)): shapes["height"] = image_size @@ -57,24 +60,15 @@ def extract_timm_shapes_from_config(config: PretrainedConfig) -> Dict[str, Any]: elif isinstance(image_size, dict) and len(image_size) == 1: shapes["height"] = list(image_size.values())[0] shapes["width"] = list(image_size.values())[0] - else: - shapes["height"] = None - shapes["width"] = None - input_size = artifacts_dict.get("input_size", None) - if input_size is not None: + if "input_size" in artifacts_dict: + input_size = artifacts_dict.get("input_size", None) shapes["num_channels"] = input_size[0] shapes["height"] = input_size[1] shapes["width"] = input_size[2] - # classification labels - id2label = artifacts_dict.get("id2label", None) - if id2label is not None: - shapes["num_labels"] = len(id2label) - - num_classes = artifacts_dict.get("num_classes", None) - if num_classes is not None: - shapes["num_labels"] = num_classes + if "num_classes" not in artifacts_dict: + warnings.warn("Could not extract shapes [num_channels, height, width] from timm model config.") return shapes diff --git a/optimum_benchmark/backends/transformers_utils.py b/optimum_benchmark/backends/transformers_utils.py index 
71e8de63..1ef86dd3 100644 --- a/optimum_benchmark/backends/transformers_utils.py +++ b/optimum_benchmark/backends/transformers_utils.py @@ -1,4 +1,3 @@ -import warnings from contextlib import contextmanager from typing import Any, Dict, Optional, Union @@ -7,6 +6,7 @@ from transformers import ( AutoConfig, AutoFeatureExtractor, + AutoImageProcessor, AutoProcessor, AutoTokenizer, FeatureExtractionMixin, @@ -47,6 +47,7 @@ "image-to-text": "AutoModelForVision2Seq", "text-generation": "AutoModelForCausalLM", "text2text-generation": "AutoModelForSeq2SeqLM", + "image-text-to-text": "AutoModelForImageTextToText", "visual-question-answering": "AutoModelForVisualQuestionAnswering", "automatic-speech-recognition": ("AutoModelForSpeechSeq2Seq", "AutoModelForCTC"), } @@ -64,8 +65,11 @@ model_loaders = (model_loaders,) for model_loader_name in model_loaders: - model_loader_class = getattr(transformers, model_loader_name) - TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name].update(model_loader_class._model_mapping._model_mapping) + model_loader_class = getattr(transformers, model_loader_name, None) + if model_loader_class is not None: + TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES[task_name].update( + model_loader_class._model_mapping._model_mapping + ) else: TASKS_TO_MODEL_TYPES_TO_MODEL_CLASSES = {} @@ -107,56 +111,83 @@ def get_transformers_pretrained_processor(model: str, **kwargs) -> Optional["Pre return AutoFeatureExtractor.from_pretrained(model, **kwargs) except Exception: try: - return AutoTokenizer.from_pretrained(model, **kwargs) + return AutoImageProcessor.from_pretrained(model, **kwargs) except Exception: - return None + try: + return AutoTokenizer.from_pretrained(model, **kwargs) + except Exception: + return None + + +def get_flat_dict(d: Dict[str, Any]) -> Dict[str, Any]: + flat_dict = {} + for k, v in d.items(): + if isinstance(v, dict): + flat_dict.update(get_flat_dict(v)) + else: + flat_dict[k] = v + return flat_dict + + +def get_flat_artifact_dict(artifact: Union[PretrainedConfig, PretrainedProcessor]) -> Dict[str, Any]: + artifact_dict = {} + + if isinstance(artifact, ProcessorMixin): + artifact_dict.update( + {k: v for k, v in artifact.__dict__.items() if isinstance(v, (int, str, float, bool, list, tuple, dict))} + ) + for attribute in artifact.attributes: + artifact_dict.update(get_flat_artifact_dict(getattr(artifact, attribute))) + elif hasattr(artifact, "to_dict"): + artifact_dict.update( + {k: v for k, v in artifact.to_dict().items() if isinstance(v, (int, str, float, bool, list, tuple, dict))} + ) + else: + artifact_dict.update( + {k: v for k, v in artifact.__dict__.items() if isinstance(v, (int, str, float, bool, list, tuple, dict))} + ) + + artifact_dict = get_flat_dict(artifact_dict) + + return artifact_dict def extract_transformers_shapes_from_artifacts( - config: Optional["PretrainedConfig"] = None, processor: Optional["PretrainedProcessor"] = None + config: Optional["PretrainedConfig"] = None, + processor: Optional["PretrainedProcessor"] = None, ) -> Dict[str, Any]: - artifacts_dict = {} + flat_artifacts_dict = {} - if config is not None and hasattr(config, "to_dict"): - config_dict = {k: v for k, v in config.to_dict().items() if v is not None} - artifacts_dict.update(config_dict) - elif config is not None: - try: - config_dict = {k: getattr(config, k) for k in dir(config) if isinstance(getattr(config, k), int)} - artifacts_dict.update(config_dict) - except Exception: - warnings.warn(f"Could not extract shapes from config {config}") + if config is not None: + 
flat_artifacts_dict.update(get_flat_artifact_dict(config)) - if processor is not None and hasattr(processor, "to_dict"): - processor_dict = {k: v for k, v in processor.to_dict().items() if v is not None} - artifacts_dict.update(processor_dict) - elif processor is not None: - try: - processor_dict = { - k: getattr(processor, k) for k in dir(processor) if isinstance(getattr(processor, k), int) - } - except Exception: - warnings.warn(f"Could not extract shapes from processor {processor}") + if processor is not None: + flat_artifacts_dict.update(get_flat_artifact_dict(processor)) shapes = {} # text input - shapes["vocab_size"] = artifacts_dict.get("vocab_size", None) - shapes["type_vocab_size"] = artifacts_dict.get("type_vocab_size", None) - shapes["max_position_embeddings"] = artifacts_dict.get("max_position_embeddings", None) - if shapes["max_position_embeddings"] is None: - shapes["max_position_embeddings"] = artifacts_dict.get("n_positions", None) + if "vocab_size" in flat_artifacts_dict: + shapes["vocab_size"] = flat_artifacts_dict["vocab_size"] + + if "type_vocab_size" in flat_artifacts_dict: + shapes["type_vocab_size"] = flat_artifacts_dict["type_vocab_size"] + + if "max_position_embeddings" in flat_artifacts_dict: + shapes["max_position_embeddings"] = flat_artifacts_dict["max_position_embeddings"] + elif "n_positions" in flat_artifacts_dict: + shapes["max_position_embeddings"] = flat_artifacts_dict["n_positions"] # image input - shapes["num_channels"] = artifacts_dict.get("num_channels", None) - if shapes["num_channels"] is None: - # processors have different names for the number of channels - shapes["num_channels"] = artifacts_dict.get("channels", None) + if "num_channels" in flat_artifacts_dict: + shapes["num_channels"] = flat_artifacts_dict["num_channels"] - image_size = artifacts_dict.get("image_size", None) - if image_size is None: - # processors have different names for the image size - image_size = artifacts_dict.get("size", None) + if "image_size" in flat_artifacts_dict: + image_size = flat_artifacts_dict["image_size"] + elif "size" in flat_artifacts_dict: + image_size = flat_artifacts_dict["size"] + else: + image_size = None if isinstance(image_size, (int, float)): shapes["height"] = image_size @@ -170,29 +201,41 @@ def extract_transformers_shapes_from_artifacts( elif isinstance(image_size, dict) and len(image_size) == 1: shapes["height"] = list(image_size.values())[0] shapes["width"] = list(image_size.values())[0] - else: - shapes["height"] = None - shapes["width"] = None - input_size = artifacts_dict.get("input_size", None) - if input_size is not None: + if "input_size" in flat_artifacts_dict: + input_size = flat_artifacts_dict["input_size"] shapes["num_channels"] = input_size[0] shapes["height"] = input_size[1] shapes["width"] = input_size[2] # classification labels - id2label = artifacts_dict.get("id2label", None) - if id2label is not None: + if "id2label" in flat_artifacts_dict: + id2label = flat_artifacts_dict["id2label"] shapes["num_labels"] = len(id2label) - - num_classes = artifacts_dict.get("num_classes", None) - if num_classes is not None: - shapes["num_labels"] = num_classes + elif "num_classes" in flat_artifacts_dict: + shapes["num_labels"] = flat_artifacts_dict["num_classes"] # object detection labels - shapes["num_queries"] = artifacts_dict.get("num_queries", None) - if shapes["num_queries"] == 0: - shapes["num_queries"] = 2 + if "num_queries" in flat_artifacts_dict: + shapes["num_queries"] = flat_artifacts_dict["num_queries"] + + # image-text input + + if 
"patch_size" in flat_artifacts_dict: + shapes["patch_size"] = flat_artifacts_dict["patch_size"] + if "in_chans" in flat_artifacts_dict: + shapes["num_channels"] = flat_artifacts_dict["in_chans"] + if "image_seq_len" in flat_artifacts_dict: + shapes["image_seq_len"] = flat_artifacts_dict["image_seq_len"] + if "image_token_id" in flat_artifacts_dict: + shapes["image_token_id"] = flat_artifacts_dict["image_token_id"] + if "spatial_merge_size" in flat_artifacts_dict: + shapes["spatial_merge_size"] = flat_artifacts_dict["spatial_merge_size"] + if "do_image_splitting" in flat_artifacts_dict: + shapes["do_image_splitting"] = flat_artifacts_dict["do_image_splitting"] + + if "temporal_patch_size" in flat_artifacts_dict: + shapes["temporal_patch_size"] = flat_artifacts_dict["temporal_patch_size"] return shapes diff --git a/optimum_benchmark/generators/base.py b/optimum_benchmark/generators/base.py new file mode 100644 index 00000000..e4d779b9 --- /dev/null +++ b/optimum_benchmark/generators/base.py @@ -0,0 +1,52 @@ +import logging +import random +import string +from abc import ABC +from typing import Dict, List, Tuple + +import torch + +LOGGER = logging.getLogger("generators") + + +class BaseGenerator(ABC): + def __init__(self, shapes: Dict[str, int], with_labels: bool): + self.shapes = shapes + self.with_labels = with_labels + + def assert_not_missing_shapes(self, required_shapes: List[str]): + for shape in required_shapes: + assert self.shapes.get(shape, None) is not None, ( + f"{shape} either couldn't be inferred automatically from model artifacts or should be provided by the user. " + f"Please provide it under `scenario.input_shapes.{shape}` or open an issue/PR in optimum-benchmark repository. " + ) + + @staticmethod + def generate_constant_integers(value: int, shape: Tuple[int]): + return torch.full(shape, value, dtype=torch.int64) + + @staticmethod + def generate_constant_floats(value: float, shape: Tuple[int]): + return torch.full(shape, value, dtype=torch.float32) + + @staticmethod + def generate_random_integers(min_value: int, max_value: int, shape: Tuple[int]): + return torch.randint(min_value, max_value, shape) + + @staticmethod + def generate_random_floats(min_value: float, max_value: float, shape: Tuple[int]): + return torch.rand(shape) * (max_value - min_value) + min_value + + @staticmethod + def generate_ranges(start: int, stop: int, shape: Tuple[int]): + return torch.arange(start, stop).repeat(shape[0], 1) + + @staticmethod + def generate_random_strings(num_seq: int) -> List[str]: + return [ + "".join(random.choice(string.ascii_letters + string.digits) for _ in range(random.randint(10, 100))) + for _ in range(num_seq) + ] + + def __call__(self): + raise NotImplementedError("Generator must implement __call__ method") diff --git a/optimum_benchmark/generators/dataset_generator.py b/optimum_benchmark/generators/dataset_generator.py index bbaa87f0..efc8a029 100644 --- a/optimum_benchmark/generators/dataset_generator.py +++ b/optimum_benchmark/generators/dataset_generator.py @@ -1,29 +1,41 @@ -from typing import Dict +from typing import Dict, Optional from datasets import Dataset -from .task_generator import TASKS_TO_GENERATORS, TaskGenerator +from .base import BaseGenerator +from .model_generator import MODEL_TYPE_TO_GENERATORS +from .task_generator import TASKS_TO_GENERATORS class DatasetGenerator: - task_generator: TaskGenerator + generator: BaseGenerator - def __init__(self, task: str, dataset_shapes: Dict[str, int], model_shapes: Dict[str, int]) -> None: - 
dataset_shapes["batch_size"] = dataset_shapes["dataset_size"] + def __init__( + self, + task: str, + dataset_shapes: Dict[str, int], + model_shapes: Dict[str, int], + model_type: Optional[str] = None, + ) -> None: + # dataset_shapes take precedence over model_shapes + all_shapes = {**model_shapes, **dataset_shapes} + all_shapes["batch_size"] = all_shapes.pop("dataset_size", None) - if task in TASKS_TO_GENERATORS: - shapes = {**dataset_shapes, **model_shapes} - self.task_generator = TASKS_TO_GENERATORS[task](shapes=shapes, with_labels=True) + if model_type in MODEL_TYPE_TO_GENERATORS: + self.generator = MODEL_TYPE_TO_GENERATORS[model_type](shapes=all_shapes, with_labels=True) + elif task in TASKS_TO_GENERATORS: + self.generator = TASKS_TO_GENERATORS[task](shapes=all_shapes, with_labels=True) else: raise NotImplementedError( - f"Task {task} is supported. \n" - f"Available tasks: {list(TASKS_TO_GENERATORS.keys())}. \n" - "If you want to add support for this task, " - "please submit a PR or a feature request to optimum-benchmark. \n" + f"Task {task} is not supported for dataset generation. " + f"Available tasks: {list(TASKS_TO_GENERATORS.keys())}. " + f"Available model types: {list(MODEL_TYPE_TO_GENERATORS.keys())}. " + "If you want to add support for this task or model type, " + "please submit a PR or a feature request to optimum-benchmark." ) def __call__(self) -> Dataset: - task_dataset = self.task_generator() + task_dataset = self.generator() task_dataset = Dataset.from_dict(task_dataset) task_dataset.set_format(type="torch", columns=list(task_dataset.features.keys())) return task_dataset diff --git a/optimum_benchmark/generators/input_generator.py b/optimum_benchmark/generators/input_generator.py index 1dd5501a..2f05dc62 100644 --- a/optimum_benchmark/generators/input_generator.py +++ b/optimum_benchmark/generators/input_generator.py @@ -1,23 +1,36 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional -from .task_generator import TASKS_TO_GENERATORS, TaskGenerator +from .base import BaseGenerator +from .model_generator import MODEL_TYPE_TO_GENERATORS +from .task_generator import TASKS_TO_GENERATORS class InputGenerator: - task_generator: TaskGenerator + generator: BaseGenerator - def __init__(self, task: str, input_shapes: Dict[str, int], model_shapes: Dict[str, int]) -> None: - if task in TASKS_TO_GENERATORS: - shapes = {**input_shapes, **model_shapes} - self.task_generator = TASKS_TO_GENERATORS[task](shapes=shapes, with_labels=False) + def __init__( + self, + task: str, + input_shapes: Dict[str, int], + model_shapes: Dict[str, int], + model_type: Optional[str] = None, + ) -> None: + # input_shapes take precedence over model_shapes + all_shapes = {**model_shapes, **input_shapes} + + if model_type in MODEL_TYPE_TO_GENERATORS: + self.generator = MODEL_TYPE_TO_GENERATORS[model_type](shapes=all_shapes, with_labels=False) + elif task in TASKS_TO_GENERATORS: + self.generator = TASKS_TO_GENERATORS[task](shapes=all_shapes, with_labels=False) else: raise NotImplementedError( - f"Task {task} is not supported. " + f"Task {task} is not supported for input generation. " f"Available tasks: {list(TASKS_TO_GENERATORS.keys())}. " - "If you want to add support for this task, " - "please submit a PR or a feature request to optimum-benchmark. " + f"Available model types: {list(MODEL_TYPE_TO_GENERATORS.keys())}. " + "If you want to add support for this task or model type, " + "please submit a PR or a feature request to optimum-benchmark." 
) def __call__(self) -> Dict[str, Any]: - task_input = self.task_generator() + task_input = self.generator() return task_input diff --git a/optimum_benchmark/generators/model_generator.py b/optimum_benchmark/generators/model_generator.py new file mode 100644 index 00000000..e709398a --- /dev/null +++ b/optimum_benchmark/generators/model_generator.py @@ -0,0 +1,259 @@ +import logging + +import torch + +from .base import BaseGenerator + +LOGGER = logging.getLogger("generators") + +DEFAULT_VOCAB_SIZE = 2 + + +class IdeficsGenerator(BaseGenerator): + def input_ids(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length", "num_images", "image_token_id"]) + + text_tokens = self.generate_random_integers( + min_value=0, + max_value=self.shapes.get("vocab_size", DEFAULT_VOCAB_SIZE), + shape=(self.shapes["batch_size"], self.shapes["sequence_length"]), + ) + + image_tokens = self.generate_constant_integers( + value=self.shapes["image_token_id"], + shape=(self.shapes["batch_size"], self.shapes["num_images"]), + ) + + return torch.cat((text_tokens, image_tokens), dim=1) + + def attention_mask(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length", "num_images"]) + + return self.generate_constant_integers( + value=1, # no sparsity + shape=( + self.shapes["batch_size"], + self.shapes["sequence_length"] + self.shapes["num_images"], + ), + ) + + def pixel_values(self): + self.assert_not_missing_shapes(["batch_size", "num_images", "num_channels", "height", "width"]) + + return self.generate_random_floats( + min_value=0, + max_value=1, + shape=( + self.shapes["batch_size"], + self.shapes["num_images"], + self.shapes["num_channels"], + self.shapes["height"], + self.shapes["width"], + ), + ) + + def image_attention_mask(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length", "num_images"]) + + return self.generate_constant_integers( + value=1, # no sparsity + shape=( + self.shapes["batch_size"], + self.shapes["sequence_length"] + self.shapes["num_images"], + self.shapes["num_images"], + ), + ) + + def __call__(self): + dummy = {} + + dummy["input_ids"] = self.input_ids() + dummy["pixel_values"] = self.pixel_values() + dummy["attention_mask"] = self.attention_mask() + dummy["image_attention_mask"] = self.image_attention_mask() + + if self.with_labels: + dummy["labels"] = self.input_ids() + + return dummy + + +class Idefics2Generator(BaseGenerator): + def input_ids(self): + self.assert_not_missing_shapes( + ["batch_size", "sequence_length", "num_images", "image_seq_len", "image_token_id", "do_image_splitting"] + ) + + text_tokens = self.generate_random_integers( + min_value=0, + max_value=self.shapes.get("vocab_size", DEFAULT_VOCAB_SIZE), + shape=(self.shapes["batch_size"], self.shapes["sequence_length"]), + ) + + image_tokens = self.generate_constant_integers( + value=self.shapes["image_token_id"], + shape=( + self.shapes["batch_size"], + self.shapes["num_images"] + * self.shapes["image_seq_len"] + * (5 if self.shapes["do_image_splitting"] else 1), + ), + ) + + return torch.cat((text_tokens, image_tokens), dim=1) + + def attention_mask(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length", "num_images", "do_image_splitting"]) + + return self.generate_constant_integers( + value=1, # no sparsity + shape=( + self.shapes["batch_size"], + self.shapes["sequence_length"] + + self.shapes["num_images"] + * self.shapes["image_seq_len"] + * (5 if self.shapes["do_image_splitting"] else 1), + ), + ) + + def pixel_values(self): + 
self.assert_not_missing_shapes( + ["batch_size", "num_images", "num_channels", "height", "width", "do_image_splitting"] + ) + + return self.generate_random_floats( + min_value=0, + max_value=1, + shape=( + self.shapes["batch_size"], + self.shapes["num_images"] * (5 if self.shapes["do_image_splitting"] else 1), + self.shapes["num_channels"], + self.shapes["height"], + self.shapes["width"], + ), + ) + + def pixel_attention_mask(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length", "num_images", "do_image_splitting"]) + + return self.generate_constant_integers( + value=1, # no sparsity + shape=( + self.shapes["batch_size"], + self.shapes["num_images"] * (5 if self.shapes["do_image_splitting"] else 1), + self.shapes["height"], + self.shapes["width"], + ), + ) + + def __call__(self): + dummy = {} + + dummy["input_ids"] = self.input_ids() + dummy["pixel_values"] = self.pixel_values() + dummy["attention_mask"] = self.attention_mask() + dummy["pixel_attention_mask"] = self.pixel_attention_mask() + + print("input_ids", dummy["input_ids"].shape) + print("pixel_values", dummy["pixel_values"].shape) + print("attention_mask", dummy["attention_mask"].shape) + print("pixel_attention_mask", dummy["pixel_attention_mask"].shape) + + if self.with_labels: + dummy["labels"] = self.input_ids() + + return dummy + + +class Qwen2VLGenerator(BaseGenerator): + def input_ids(self): + self.assert_not_missing_shapes( + [ + "batch_size", + "sequence_length", + "num_images", + "num_channels", + "height", + "width", + "patch_size", + "temporal_patch_size", + "spatial_merge_size", + "image_token_id", + ] + ) + + text_tokens = self.generate_random_integers( + min_value=0, + max_value=self.shapes.get("vocab_size", DEFAULT_VOCAB_SIZE), + shape=( + self.shapes["batch_size"], + self.shapes["sequence_length"], + ), + ) + image_tokens = self.generate_constant_integers( + value=self.shapes["image_token_id"], + shape=( + self.shapes["batch_size"], + int( + self.shapes["num_images"] + * self.shapes["height"] + * self.shapes["width"] + / self.shapes["temporal_patch_size"] + / self.shapes["spatial_merge_size"] + / self.shapes["patch_size"] ** 2 + ), + ), + ) + + return torch.cat((text_tokens, image_tokens), dim=1) + + def pixel_values(self): + self.assert_not_missing_shapes( + ["num_images", "num_channels", "height", "width", "patch_size", "temporal_patch_size"] + ) + + return self.generate_random_floats( + min_value=0, + max_value=1, + shape=( + self.shapes["num_images"] + * int(self.shapes["height"] / self.shapes["patch_size"]) + * int(self.shapes["width"] / self.shapes["patch_size"]), + self.shapes["num_channels"] + * self.shapes["patch_size"] + * self.shapes["patch_size"] + * self.shapes["temporal_patch_size"], + ), + ) + + def image_grid_thw(self): + self.assert_not_missing_shapes(["num_images", "height", "width", "patch_size"]) + + return torch.tensor( + [ + [ + self.shapes["num_images"], + int(self.shapes["height"] / self.shapes["patch_size"]), + int(self.shapes["width"] / self.shapes["patch_size"]), + ] + ] + ) + + def __call__(self): + dummy = {} + + dummy["input_ids"] = self.input_ids() + dummy["pixel_values"] = self.pixel_values() + dummy["image_grid_thw"] = self.image_grid_thw() + + if self.with_labels: + dummy["labels"] = self.input_ids() + + return dummy + + +MODEL_TYPE_TO_GENERATORS = { + "idefics": IdeficsGenerator, + "idefics2": Idefics2Generator, + "qwen2_vl": Qwen2VLGenerator, +} diff --git a/optimum_benchmark/generators/task_generator.py b/optimum_benchmark/generators/task_generator.py index 
76131578..f11d21eb 100644 --- a/optimum_benchmark/generators/task_generator.py +++ b/optimum_benchmark/generators/task_generator.py @@ -1,11 +1,6 @@ import logging -import random -import string -from abc import ABC -from typing import List, Tuple -# TODO: drop torch dependency and use numpy instead -import torch +from .base import BaseGenerator LOGGER = logging.getLogger("generators") @@ -14,57 +9,36 @@ DEFAULT_TYPE_VOCAB_SIZE = 2 -class TaskGenerator(ABC): - def __init__(self, shapes, with_labels: bool): - self.shapes = shapes - self.with_labels = with_labels - - @staticmethod - def generate_random_integers(min_value: int, max_value: int, shape: Tuple[int]): - return torch.randint(min_value, max_value, shape) - - @staticmethod - def generate_random_floats(min_value: float, max_value: float, shape: Tuple[int]): - return torch.rand(shape) * (max_value - min_value) + min_value - - @staticmethod - def generate_ranges(start: int, stop: int, shape: Tuple[int]): - return torch.arange(start, stop).repeat(shape[0], 1) - - @staticmethod - def generate_random_strings(num_seq: int) -> List[str]: - return [ - "".join(random.choice(string.ascii_letters + string.digits) for _ in range(random.randint(10, 100))) - for _ in range(num_seq) - ] - - def __call__(self): - raise NotImplementedError("Generator must implement __call__ method") - - -class TextGenerator(TaskGenerator): +class TextGenerator(BaseGenerator): def input_ids(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length"]) + return self.generate_random_integers( min_value=0, - max_value=self.shapes["vocab_size"] or DEFAULT_VOCAB_SIZE, + max_value=self.shapes.get("vocab_size", DEFAULT_VOCAB_SIZE), shape=(self.shapes["batch_size"], self.shapes["sequence_length"]), ) def attention_mask(self): - return self.generate_random_integers( - min_value=1, # avoid sparse attention - max_value=2, + self.assert_not_missing_shapes(["batch_size", "sequence_length"]) + + return self.generate_constant_integers( + value=1, # no sparsity shape=(self.shapes["batch_size"], self.shapes["sequence_length"]), ) def token_type_ids(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length"]) + return self.generate_random_integers( min_value=0, - max_value=self.shapes["type_vocab_size"] or DEFAULT_TYPE_VOCAB_SIZE, + max_value=self.shapes.get("type_vocab_size", DEFAULT_TYPE_VOCAB_SIZE), shape=(self.shapes["batch_size"], self.shapes["sequence_length"]), ) def position_ids(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length"]) + return self.generate_ranges( start=0, stop=self.shapes["sequence_length"], @@ -72,39 +46,65 @@ def position_ids(self): ) def requires_token_type_ids(self): - return self.shapes["type_vocab_size"] is not None and self.shapes["type_vocab_size"] > 1 + return self.shapes.get("type_vocab_size", None) is not None and self.shapes["type_vocab_size"] > 1 def requires_position_ids(self): - return self.shapes["max_position_embeddings"] is not None + return ( + self.shapes.get("max_position_embeddings", None) is not None and self.shapes["max_position_embeddings"] > 1 + ) -class ImageGenerator(TaskGenerator): +class ImageGenerator(BaseGenerator): def pixel_values(self): + self.assert_not_missing_shapes(["batch_size", "num_channels", "height", "width"]) + return self.generate_random_floats( min_value=0, max_value=1, - shape=(self.shapes["batch_size"], self.shapes["num_channels"], self.shapes["height"], self.shapes["width"]), + shape=( + self.shapes["batch_size"], + self.shapes["num_channels"], + 
self.shapes["height"], + self.shapes["width"], + ), ) -class AudioGenerator(TaskGenerator): +class AudioGenerator(BaseGenerator): def input_values(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length"]) + return self.generate_random_floats( - min_value=-1, max_value=1, shape=(self.shapes["batch_size"], self.shapes["sequence_length"]) + min_value=-1, + max_value=1, + shape=( + self.shapes["batch_size"], + self.shapes["sequence_length"], + ), ) def input_features(self): + self.assert_not_missing_shapes(["batch_size", "feature_size", "nb_max_frames"]) + return self.generate_random_floats( min_value=-1, max_value=1, - shape=(self.shapes["batch_size"], self.shapes["feature_size"], self.shapes["nb_max_frames"]), + shape=( + self.shapes["batch_size"], + self.shapes["feature_size"], + self.shapes["nb_max_frames"], + ), ) class TextClassificationGenerator(TextGenerator): def labels(self): + self.assert_not_missing_shapes(["batch_size"]) + return self.generate_random_integers( - min_value=0, max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"],) + min_value=0, + max_value=self.shapes.get("num_labels", DEFAULT_NUM_LABELS), + shape=(self.shapes["batch_size"],), ) def __call__(self): @@ -127,9 +127,11 @@ def __call__(self): class TokenClassificationGenerator(TextGenerator): def labels(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length"]) + return self.generate_random_integers( min_value=0, - max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, + max_value=self.shapes.get("num_labels", DEFAULT_NUM_LABELS), shape=(self.shapes["batch_size"], self.shapes["sequence_length"]), ) @@ -177,13 +179,21 @@ def __call__(self): class QuestionAnsweringGenerator(TextGenerator): def start_positions(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length"]) + return self.generate_random_integers( - min_value=0, max_value=self.shapes["sequence_length"], shape=(self.shapes["batch_size"],) + min_value=0, + max_value=self.shapes["sequence_length"], + shape=(self.shapes["batch_size"],), ) def end_positions(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length"]) + return self.generate_random_integers( - min_value=0, max_value=self.shapes["sequence_length"], shape=(self.shapes["batch_size"],) + min_value=0, + max_value=self.shapes["sequence_length"], + shape=(self.shapes["batch_size"],), ) def __call__(self): @@ -220,7 +230,35 @@ def __call__(self): class MultipleChoiceGenerator(TextGenerator): + def input_ids(self): + self.assert_not_missing_shapes(["batch_size", "num_choices", "sequence_length"]) + + return self.generate_random_integers( + min_value=0, + max_value=self.shapes.get("vocab_size", DEFAULT_VOCAB_SIZE), + shape=(self.shapes["batch_size"], self.shapes["num_choices"], self.shapes["sequence_length"]), + ) + + def attention_mask(self): + self.assert_not_missing_shapes(["batch_size", "num_choices", "sequence_length"]) + + return self.generate_constant_integers( + value=1, # no sparsity + shape=(self.shapes["batch_size"], self.shapes["num_choices"], self.shapes["sequence_length"]), + ) + + def token_type_ids(self): + self.assert_not_missing_shapes(["batch_size", "num_choices", "sequence_length"]) + + return self.generate_random_integers( + min_value=0, + max_value=self.shapes.get("type_vocab_size", DEFAULT_TYPE_VOCAB_SIZE), + shape=(self.shapes["batch_size"], self.shapes["num_choices"], self.shapes["sequence_length"]), + ) + def labels(self): + self.assert_not_missing_shapes(["batch_size", 
"num_choices"]) + return self.generate_random_integers( min_value=0, max_value=self.shapes["num_choices"], shape=(self.shapes["batch_size"],) ) @@ -228,24 +266,11 @@ def labels(self): def __call__(self): dummy = {} - dummy["input_ids"] = ( - self.input_ids() - .reshape(self.shapes["batch_size"], 1, self.shapes["sequence_length"]) - .repeat(1, self.shapes["num_choices"], 1) - ) - - dummy["attention_mask"] = ( - self.attention_mask() - .reshape(self.shapes["batch_size"], 1, self.shapes["sequence_length"]) - .repeat(1, self.shapes["num_choices"], 1) - ) + dummy["input_ids"] = self.input_ids() + dummy["attention_mask"] = self.attention_mask() if self.requires_token_type_ids(): - dummy["token_type_ids"] = ( - self.token_type_ids() - .reshape(self.shapes["batch_size"], 1, self.shapes["sequence_length"]) - .repeat(1, self.shapes["num_choices"], 1) - ) + dummy["token_type_ids"] = self.token_type_ids() if self.with_labels: dummy["label"] = self.labels() @@ -255,8 +280,12 @@ def __call__(self): class ImageClassificationGenerator(ImageGenerator): def labels(self): + self.assert_not_missing_shapes(["batch_size"]) + return self.generate_random_integers( - min_value=0, max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"],) + min_value=0, + max_value=self.shapes.get("num_labels", DEFAULT_NUM_LABELS), + shape=(self.shapes["batch_size"],), ) def __call__(self): @@ -271,11 +300,13 @@ def __call__(self): class ObjectDetectionGenerator(ImageGenerator): def labels(self): + self.assert_not_missing_shapes(["batch_size", "num_queries"]) + return [ { "class_labels": self.generate_random_integers( min_value=0, - max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, + max_value=self.shapes.get("num_labels", DEFAULT_NUM_LABELS), shape=(self.shapes["num_queries"],), ), "boxes": self.generate_random_floats(min_value=-1, max_value=1, shape=(self.shapes["num_queries"], 4)), @@ -295,9 +326,11 @@ def __call__(self): class SemanticSegmentationGenerator(ImageGenerator): def labels(self): + self.assert_not_missing_shapes(["batch_size", "height", "width"]) + return self.generate_random_integers( min_value=0, - max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, + max_value=self.shapes.get("num_labels", DEFAULT_NUM_LABELS), shape=(self.shapes["batch_size"], self.shapes["height"], self.shapes["width"]), ) @@ -313,8 +346,10 @@ def __call__(self): class AudioClassificationGenerator(AudioGenerator): def labels(self): + self.assert_not_missing_shapes(["batch_size"]) + return self.generate_random_integers( - min_value=0, max_value=self.shapes["num_labels"] or DEFAULT_NUM_LABELS, shape=(self.shapes["batch_size"],) + min_value=0, max_value=self.shapes.get("num_labels", DEFAULT_NUM_LABELS), shape=(self.shapes["batch_size"],) ) def __call__(self): @@ -329,6 +364,8 @@ def __call__(self): class AutomaticSpeechRecognitionGenerator(AudioGenerator): def labels(self): + self.assert_not_missing_shapes(["batch_size", "sequence_length"]) + return self.generate_random_integers( min_value=0, max_value=self.shapes["vocab_size"] or DEFAULT_TYPE_VOCAB_SIZE, @@ -345,8 +382,10 @@ def __call__(self): return dummy -class PromptGenerator(TaskGenerator): +class PromptGenerator(BaseGenerator): def prompt(self): + self.assert_not_missing_shapes(["batch_size"]) + return self.generate_random_strings(num_seq=self.shapes["batch_size"]) def __call__(self): @@ -360,9 +399,7 @@ class FeatureExtractionGenerator(TextGenerator, ImageGenerator): def __call__(self): dummy = {} - if self.shapes.get("num_channels", None) 
is not None and self.shapes.get("height", None) is not None: - dummy["pixel_values"] = self.pixel_values() - else: + if self.shapes.get("sequence_length", None) is not None: dummy["input_ids"] = self.input_ids() dummy["attention_mask"] = self.attention_mask() @@ -372,6 +409,23 @@ def __call__(self): if self.requires_position_ids(): dummy["position_ids"] = self.position_ids() + if self.shapes.get("height", None) is not None: + dummy["pixel_values"] = self.pixel_values() + + return dummy + + +class ImageTextToTextGenerator(TextGenerator, ImageGenerator): + def __call__(self): + dummy = {} + + dummy["input_ids"] = self.input_ids() + dummy["attention_mask"] = self.attention_mask() + dummy["pixel_values"] = self.pixel_values() + + if self.with_labels: + dummy["labels"] = self.input_ids() + return dummy @@ -388,6 +442,7 @@ def __call__(self): "image-classification": ImageClassificationGenerator, "object-detection": ObjectDetectionGenerator, "semantic-segmentation": SemanticSegmentationGenerator, + "image-text-to-text": ImageTextToTextGenerator, # diffusers pipelines tasks "text-to-image": PromptGenerator, "stable-diffusion": PromptGenerator, diff --git a/optimum_benchmark/scenarios/inference/config.py b/optimum_benchmark/scenarios/inference/config.py index 2c05d97f..57d482ab 100644 --- a/optimum_benchmark/scenarios/inference/config.py +++ b/optimum_benchmark/scenarios/inference/config.py @@ -7,7 +7,10 @@ LOGGER = getLogger("inference") -INPUT_SHAPES = {"batch_size": 2, "num_choices": 2, "sequence_length": 16} +INPUT_SHAPES = { + "batch_size": 2, + "sequence_length": 16, +} @dataclass diff --git a/optimum_benchmark/scenarios/inference/scenario.py b/optimum_benchmark/scenarios/inference/scenario.py index f2f18e0b..512f269d 100644 --- a/optimum_benchmark/scenarios/inference/scenario.py +++ b/optimum_benchmark/scenarios/inference/scenario.py @@ -21,8 +21,6 @@ "min_new_tokens": 100, "do_sample": False, "use_cache": True, - "pad_token_id": 0, - "eos_token_id": 0, "num_beams": 1, } TEXT_GENERATION_PREFILL_OVERRIDES = { @@ -60,7 +58,10 @@ def __init__(self, config: InferenceConfig) -> None: def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport: self.logger.info("\t+ Creating input generator") self.input_generator = InputGenerator( - task=backend.config.task, model_shapes=backend.model_shapes, input_shapes=self.config.input_shapes + task=backend.config.task, + input_shapes=self.config.input_shapes, + model_shapes=backend.model_shapes, + model_type=backend.config.model_type, ) if backend.config.task in TEXT_GENERATION_TASKS: @@ -414,8 +415,8 @@ def atomic_call_volume(self) -> int: # in images @property def atomic_prefill_volume(self) -> int: # in tokens if {"input_ids", "prompt", "prompts"} & set(self.inputs.keys()): - # text conditioned generation (1 bos token or sequence_length tokens) - return self.config.input_shapes["batch_size"] * max(self.config.input_shapes["sequence_length"], 1) + # text conditioned generation (sequence_length tokens) + return self.config.input_shapes["batch_size"] * self.config.input_shapes["sequence_length"] else: # image/audio conditioned generation (1 bos token) return self.config.input_shapes["batch_size"] diff --git a/optimum_benchmark/task_utils.py b/optimum_benchmark/task_utils.py index 337e835e..0a2a98c2 100644 --- a/optimum_benchmark/task_utils.py +++ b/optimum_benchmark/task_utils.py @@ -47,6 +47,7 @@ "image-to-text", "conversational", "text-generation", + "image-text-to-text", "text2text-generation", "automatic-speech-recognition", ] diff --git 
a/tests/configs/_image_text_to_text_.yaml b/tests/configs/_image_text_to_text_.yaml new file mode 100644 index 00000000..20043a67 --- /dev/null +++ b/tests/configs/_image_text_to_text_.yaml @@ -0,0 +1,10 @@ +hydra: + mode: MULTIRUN + sweeper: + params: + backend.task: image-text-to-text + backend.model: hf-internal-testing/tiny-random-GitForCausalLM, + hf-internal-testing/tiny-random-BlipForConditionalGeneration, + hf-internal-testing/tiny-random-Blip2ForConditionalGeneration, + hf-internal-testing/tiny-random-IdeficsForVisionText2Text + +scenario.input_shapes.num_images: 2 diff --git a/tests/configs/cpu_inference_pytorch_image_text_to_text.yaml b/tests/configs/cpu_inference_pytorch_image_text_to_text.yaml new file mode 100644 index 00000000..df125a3a --- /dev/null +++ b/tests/configs/cpu_inference_pytorch_image_text_to_text.yaml @@ -0,0 +1,11 @@ +defaults: + # order of inheritance, last one overrides previous ones + - _base_ # inherits from base config + - _cpu_ # inherits from cpu config + - _inference_ # inherits from inference config + - _image_text_to_text_ # inherits from image text to text config + - _no_weights_ # inherits from no weights config + - _self_ # hydra 1.1 compatibility + - override backend: pytorch + +name: cpu_inference_pytorch_image_text_to_text diff --git a/tests/test_api.py b/tests/test_api.py index 66ee16f9..fd6e2dac 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -47,6 +47,9 @@ def test_api_launch(device, scenario, library, task, model): benchmark_name = f"{device}_{scenario}_{library}_{task}_{model}" + if task == "multiple-choice": + INPUT_SHAPES["num_choices"] = 2 + if device == "cuda": device_isolation = True if is_rocm_system(): @@ -82,7 +85,7 @@ def test_api_launch(device, scenario, library, task, model): duration=1, iterations=1, warmup_runs=1, - input_shapes={"batch_size": 1, "sequence_length": 2}, + input_shapes=INPUT_SHAPES, generate_kwargs={"max_new_tokens": 2, "min_new_tokens": 2}, call_kwargs={"num_inference_steps": 2}, ) @@ -170,7 +173,14 @@ def test_api_input_generator(library, task, model): else: raise ValueError(f"Unknown library {library}") - input_generator = InputGenerator(task=task, input_shapes=INPUT_SHAPES, model_shapes=model_shapes) + if task == "multiple-choice": + INPUT_SHAPES["num_choices"] = 2 + + input_generator = InputGenerator( + task=task, + input_shapes=INPUT_SHAPES, + model_shapes=model_shapes, + ) generated_inputs = input_generator() assert len(generated_inputs) > 0, "No inputs were generated" @@ -193,6 +203,9 @@ def test_api_dataset_generator(library, task, model): else: raise ValueError(f"Unknown library {library}") + if task == "multiple-choice": + DATASET_SHAPES["num_choices"] = 2 + generator = DatasetGenerator(task=task, dataset_shapes=DATASET_SHAPES, model_shapes=model_shapes) generated_dataset = generator()
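A minimal usage sketch (not part of the diff) of the model-type-aware generation path this PR introduces. It only uses classes added above; the `model_shapes` values (`vocab_size`, `patch_size`, `temporal_patch_size`, `spatial_merge_size`, `image_token_id`) are illustrative stand-ins for what `extract_transformers_shapes_from_artifacts` would normally infer from the model's config and processor.

```python
# Hypothetical sketch: exercising the new model-type-aware input generation path.
from optimum_benchmark.generators.input_generator import InputGenerator

# shapes that scenario.input_shapes would normally provide
input_shapes = {
    "batch_size": 1,
    "sequence_length": 256,
    "num_images": 2,
    "num_channels": 3,
    "height": 224,
    "width": 224,
}

# shapes normally extracted from the model config/processor; values here are illustrative only
model_shapes = {
    "vocab_size": 152064,
    "patch_size": 14,
    "temporal_patch_size": 2,
    "spatial_merge_size": 2,
    "image_token_id": 151655,
}

# model_type routes to Qwen2VLGenerator via MODEL_TYPE_TO_GENERATORS;
# without a matching model_type, task="image-text-to-text" falls back to
# the generic ImageTextToTextGenerator.
generator = InputGenerator(
    task="image-text-to-text",
    input_shapes=input_shapes,
    model_shapes=model_shapes,
    model_type="qwen2_vl",
)

inputs = generator()  # dict with input_ids, pixel_values, image_grid_thw
print({name: tuple(tensor.shape) for name, tensor in inputs.items()})
```

If the project's usual Hydra entry point applies unchanged, the new example config would be launched as `optimum-benchmark --config-dir examples/ --config-name pytorch_vlm`.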