diff --git a/Makefile b/Makefile
index 86176c0a..242604b7 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
 # List of targets that are not associated with files
-.PHONY: quality style install build_docker_cpu build_docker_cuda build_docker_rocm test_cli_cpu_neural_compressor test_cli_cpu_onnxruntime test_cli_cpu_openvino test_cli_cpu_pytorch test_cli_rocm_pytorch test_cli_cuda_pytorch test_api_cpu test_api_cuda test_api_rocm test_api_misc
+.PHONY: quality style install

 quality:
     ruff check .

@@ -26,7 +26,7 @@ test_cli_cpu_neural_compressor:
         --rm \
         --pid=host \
         --entrypoint /bin/bash \
-        --volume $(PWD):/workspace \
+        --volume $(shell pwd):/workspace \
         --workdir /workspace \
         opt-bench-cpu:latest -c "pip install -e .[testing,neural-compressor,diffusers,timm] && pytest tests/ -k 'cli and cpu and neural_compressor' -x"

@@ -35,7 +35,7 @@ test_cli_cpu_onnxruntime:
         --rm \
         --pid=host \
         --entrypoint /bin/bash \
-        --volume $(PWD):/workspace \
+        --volume $(shell pwd):/workspace \
         --workdir /workspace \
         opt-bench-cpu:latest -c "pip install -e .[testing,onnxruntime,diffusers,timm] && pytest tests/ -k 'cli and cpu and onnxruntime' -x"

@@ -44,7 +44,7 @@ test_cli_cpu_openvino:
         --rm \
         --pid=host \
         --entrypoint /bin/bash \
-        --volume $(PWD):/workspace \
+        --volume $(shell pwd):/workspace \
         --workdir /workspace \
         opt-bench-cpu:latest -c "pip install -e .[testing,openvino,diffusers,timm] && pytest tests/ -k 'cli and cpu and openvino' -x"

@@ -53,7 +53,7 @@ test_cli_cpu_pytorch:
         --rm \
         --pid=host \
         --entrypoint /bin/bash \
-        --volume $(PWD):/workspace \
+        --volume $(shell pwd):/workspace \
         --workdir /workspace \
         opt-bench-cpu:latest -c "pip install -e .[testing,diffusers,timm] && pytest tests/ -k 'cli and cpu and pytorch' -x"

@@ -66,7 +66,7 @@ test_cli_rocm_pytorch:
         --device /dev/dri/renderD129 \
         --group-add video \
         --entrypoint /bin/bash \
-        --volume $(PWD):/workspace \
+        --volume $(shell pwd):/workspace \
         --workdir /workspace \
         opt-bench-rocm:latest -c "pip install -e .[testing,diffusers,timm,deepspeed,peft] && pytest tests/ -k 'cli and cuda and pytorch' -x"

@@ -76,16 +76,26 @@ test_cli_cuda_pytorch:
         --pid=host \
         --gpus '"device=0,1"' \
         --entrypoint /bin/bash \
-        --volume $(PWD):/workspace \
+        --volume $(shell pwd):/workspace \
         --workdir /workspace \
         opt-bench-cuda:latest -c "pip install -e .[testing,diffusers,timm,deepspeed,peft] && pytest tests/ -k 'cli and cuda and pytorch' -x"

+test_cli_tensorrt_llm:
+    docker run \
+        --rm \
+        --pid=host \
+        --gpus '"device=0,1"' \
+        --entrypoint /bin/bash \
+        --volume $(shell pwd):/workspace \
+        --workdir /workspace \
+        opt-bench-tensorrt-llm:latest -c "pip install -e .[testing] && pip uninstall -y nvidia-ml-py && pytest tests/ -k 'cli and tensorrt_llm' -x"
+
 test_api_cpu:
     docker run \
         --rm \
         --pid=host \
         --entrypoint /bin/bash \
-        --volume $(PWD):/workspace \
+        --volume $(shell pwd):/workspace \
         --workdir /workspace \
         opt-bench-cpu:latest -c "pip install -e .[testing,timm,diffusers] && pytest tests/ -k 'api and cpu' -x"

@@ -95,7 +105,7 @@ test_api_cuda:
         --pid=host \
         --gpus '"device=0,1"' \
         --entrypoint /bin/bash \
-        --volume $(PWD):/workspace \
+        --volume $(shell pwd):/workspace \
         --workdir /workspace \
         opt-bench-cuda:latest -c "pip install -e .[testing,timm,diffusers] && pytest tests/ -k 'api and cuda' -x"

@@ -108,7 +118,7 @@ test_api_rocm:
         --device /dev/dri/renderD129 \
         --group-add video \
         --entrypoint /bin/bash \
-        --volume $(PWD):/workspace \
+        --volume $(shell pwd):/workspace \
         --workdir /workspace \
         opt-bench-rocm:latest -c "pip install -e .[testing,timm,diffusers] && pytest tests/ -k 'api and cuda' -x"

@@ -117,6 +127,6 @@ test_api_misc:
         --rm \
         --pid=host \
         --entrypoint /bin/bash \
-        --volume $(PWD):/workspace \
+        --volume $(shell pwd):/workspace \
         --workdir /workspace \
         opt-bench-cpu:latest -c "pip install -e .[testing,timm,diffusers] && pytest tests/ -k 'api and not (cpu or cuda or rocm or tensorrt)' -x"
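A note on the test selection used by these targets, illustrated with hypothetical test names (the real suite's names are not shown in this diff): pytest's `-k` option evaluates a boolean expression against each test id, which is how the new `test_cli_tensorrt_llm` target picks up only the TensorRT-LLM CLI tests.

```python
# Hypothetical test names showing how `-k 'cli and tensorrt_llm'` filters: pytest treats the
# -k argument as a boolean expression over substrings of each collected test's id.


def test_cli_cuda_tensorrt_llm():
    # selected: the id contains both "cli" and "tensorrt_llm"
    pass


def test_api_cuda_pytorch():
    # not selected: the id contains neither "cli" nor "tensorrt_llm"
    pass
```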
diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py
index 3beb1387..7a3b1984 100644
--- a/optimum_benchmark/backends/tensorrt_llm/backend.py
+++ b/optimum_benchmark/backends/tensorrt_llm/backend.py
@@ -47,13 +47,15 @@ def load_trtmodel_from_pretrained(self) -> None:

     def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> ModelOutput:
         return self.pretrained_model.generate(
-            input_ids=inputs.get("input_ids", None), attention_mask=inputs.get("attention_mask", None), max_new_tokens=1
+            input_ids=inputs.get("input_ids"),
+            attention_mask=inputs.get("attention_mask"),
+            max_new_tokens=1,
         )

     def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> ModelOutput:
         return self.pretrained_model.generate(
-            input_ids=inputs.get("inputs", None),  # diff names
-            attention_mask=inputs.get("attention_mask", None),
+            input_ids=inputs.get("input_ids"),
+            attention_mask=inputs.get("attention_mask"),
             # important for benchmarking
             max_new_tokens=kwargs.get("max_new_tokens", -1),
             min_length=kwargs.get("min_new_tokens", -1),  # why different ?
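A minimal sketch (plain Python, not the library's API) of the failure mode the `generate` fix above removes: `dict.get` returns `None` for a missing key instead of raising, so the old `inputs.get("inputs")` silently passed `input_ids=None` to the TensorRT-LLM model.

```python
# The benchmark builds inputs under model-native key names (see inputs_utils.py later in
# this diff), so the backend must read the same names; a mismatch is hidden by dict.get.
inputs = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}

broken = inputs.get("inputs")      # None -- the old, mismatched key name
fixed = inputs.get("input_ids")    # [[1, 2, 3]] -- matches what the benchmark now produces

assert broken is None and fixed is not None
```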
diff --git a/optimum_benchmark/benchmarks/inference/benchmark.py b/optimum_benchmark/benchmarks/inference/benchmark.py
index d9420b0b..69606bbb 100644
--- a/optimum_benchmark/benchmarks/inference/benchmark.py
+++ b/optimum_benchmark/benchmarks/inference/benchmark.py
@@ -1,12 +1,14 @@
 from dataclasses import dataclass
 from logging import getLogger

+from transformers import LogitsProcessorList
+
 from ...backends.base import Backend, BackendConfigT
 from ...generators.input_generator import InputGenerator
 from ...import_utils import is_torch_distributed_available
 from ...task_utils import IMAGE_DIFFUSION_TASKS, TEXT_GENERATION_TASKS
 from ...trackers.energy import Efficiency, EnergyTracker
-from ...trackers.latency import LatencyTracker, Throughput
+from ...trackers.latency import LatencyLogitsProcessor, LatencyTracker, Throughput
 from ...trackers.memory import MemoryTracker
 from ..base import Benchmark
 from ..report import BenchmarkMeasurements, BenchmarkReport
@@ -18,8 +20,9 @@

 LOGGER = getLogger("inference")

-IMAGE_DIFFUSION_KWARGS = {"num_inference_steps": 30, "num_images_per_prompt": 1}
+PER_TOKEN_BACKENDS = ["pytorch", "onnxruntime", "openvino", "neural-compressor"]

+IMAGE_DIFFUSION_KWARGS = {"num_inference_steps": 30, "num_images_per_prompt": 1}
 TEXT_GENERATION_KWARGS = {
     "num_return_sequences": 1,
     "max_new_tokens": 100,
@@ -31,21 +34,21 @@
     "num_beams": 1,
 }

-EFFICIENCY_UNIT = "samples/kWh"
-THROUGHPUT_UNIT = "samples/s"
-
-PREFILL_THROUGHPUT_UNIT = "tokens/s"
-DECODE_THROUGHPUT_UNIT = "tokens/s"
-CALL_THROUGHPUT_UNIT = "images/s"
+TEXT_GENERATION_THROUGHPUT_UNIT = "tokens/s"
+IMAGE_DIFFUSION_THROUGHPUT_UNIT = "images/s"
+INFERENCE_THROUGHPUT_UNIT = "samples/s"

-PREFILL_EFFICIENCY_UNIT = "tokens/kWh"
-DECODE_EFFICIENCY_UNIT = "tokens/kWh"
-CALL_EFFICIENCY_UNIT = "images/kWh"
+TEXT_GENERATION_EFFICIENCY_UNIT = "tokens/kWh"
+IMAGE_DIFFUSION_EFFICIENCY_UNIT = "images/kWh"
+INFERENCE_EFFICIENCY_UNIT = "samples/kWh"


 @dataclass
-class InferenceReport(BenchmarkReport):
-    forward: BenchmarkMeasurements
+class TextGenerationReport(BenchmarkReport):
+    prefill: BenchmarkMeasurements
+    decode: BenchmarkMeasurements
+    per_token: BenchmarkMeasurements


 @dataclass
@@ -54,9 +57,8 @@ class ImageDiffusionReport(BenchmarkReport):


 @dataclass
-class TextGenerationReport(BenchmarkReport):
-    prefill: BenchmarkMeasurements
-    decode: BenchmarkMeasurements
+class InferenceReport(BenchmarkReport):
+    forward: BenchmarkMeasurements


 class InferenceBenchmark(Benchmark[InferenceConfig]):
@@ -65,7 +67,7 @@ class InferenceBenchmark(Benchmark[InferenceConfig]):
     def __init__(self, config: InferenceConfig) -> None:
         super().__init__(config)

-    def run(self, backend: Backend) -> None:
+    def run(self, backend: Backend[BackendConfigT]) -> None:
         if is_torch_distributed_available() and torch.distributed.is_initialized():
             LOGGER.info("\t+ Distributing batch size across processes")
             if self.config.input_shapes["batch_size"] % torch.distributed.get_world_size() != 0:
@@ -87,7 +89,9 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
             LOGGER.info("\t+ Updating Text Generation kwargs with default values")
             self.config.generate_kwargs = {**TEXT_GENERATION_KWARGS, **self.config.generate_kwargs}
             LOGGER.info("\t+ Initializing Text Generation report")
-            self.report = TextGenerationReport(prefill=BenchmarkMeasurements(), decode=BenchmarkMeasurements())
+            self.report = TextGenerationReport(
+                decode=BenchmarkMeasurements(), prefill=BenchmarkMeasurements(), per_token=BenchmarkMeasurements()
+            )

         elif backend.config.task in IMAGE_DIFFUSION_TASKS:
             LOGGER.info("\t+ Generating Image Diffusion inputs")
@@ -142,6 +146,8 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
         LOGGER.info("\t+ Creating inference latency tracker")
         self.latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
         if backend.config.task in TEXT_GENERATION_TASKS:
+            LOGGER.info("\t+ Creating latency logits processor tracker")
+
             self.run_text_generation_latency_tracking(backend)
         elif backend.config.task in IMAGE_DIFFUSION_TASKS:
             self.run_image_diffusion_latency_tracking(backend)
@@ -165,7 +171,7 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
             self.report.log_efficiency()

     ## Memory tracking
-    def run_text_generation_memory_tracking(self, backend: Backend):
+    def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running memory tracking")
         self.memory_tracker.reset()
         with self.memory_tracker.track():
@@ -179,7 +185,7 @@ def run_text_generation_memory_tracking(self, backend: Backend):

         self.report.decode.memory = self.memory_tracker.get_max_memory()

-    def run_image_diffusion_memory_tracking(self, backend: Backend):
+    def run_image_diffusion_memory_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running memory tracking")
         self.memory_tracker.reset()
         with self.memory_tracker.track():
@@ -187,7 +193,7 @@ def run_image_diffusion_memory_tracking(self, backend: Backend):

         self.report.call.memory = self.memory_tracker.get_max_memory()

-    def run_inference_memory_tracking(self, backend: Backend):
+    def run_inference_memory_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running memory tracking")
         self.memory_tracker.reset()
         with self.memory_tracker.track():
@@ -196,33 +202,51 @@ def run_inference_memory_tracking(self, backend: Backend):

         self.report.forward.memory = self.memory_tracker.get_max_memory()

     ## Latency tracking
-    def run_text_generation_latency_tracking(self, backend: Backend):
+    def run_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running latency tracking")
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
             with self.latency_tracker.track():
                 _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
+        forward_latency = self.latency_tracker.get_latency()

-        forward_latency = self.latency_tracker.get_latency()
-        forward_latency.log(prefix="forward")
         self.report.prefill.latency = forward_latency
-        self.report.prefill.throughput = self.latency_tracker.get_throughput(
-            volume=self.prefill_volume, unit=PREFILL_THROUGHPUT_UNIT
+        self.report.prefill.throughput = Throughput.from_latency(
+            self.report.prefill.latency, self.text_generation_prefill_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
         )

-        self.latency_tracker.reset()
-        while self.latency_tracker.get_elapsed_time() < self.config.duration:
-            with self.latency_tracker.track():
-                _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
+        if backend.config.name in PER_TOKEN_BACKENDS:
+            self.logits_processor = LatencyLogitsProcessor(device=backend.config.device, backend=backend.config.name)
+            self.config.generate_kwargs["logits_processor"] = LogitsProcessorList([self.logits_processor])
+            self.logits_processor.reset()

-        generate_latency = self.latency_tracker.get_latency()
-        generate_latency.log(prefix="generate")
-        self.report.decode.latency = generate_latency - self.report.prefill.latency.mean
-        self.report.decode.throughput = Throughput.from_latency(
-            self.report.decode.latency, self.decode_volume, unit=DECODE_THROUGHPUT_UNIT
-        )
+            while self.logits_processor.get_elapsed_time() < self.config.duration:
+                with self.logits_processor.track():
+                    _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)

-    def run_image_diffusion_latency_tracking(self, backend: Backend):
+            self.report.decode.latency = self.logits_processor.get_decode_latency()
+            self.report.per_token.latency = self.logits_processor.get_per_token_latency()
+            self.report.decode.throughput = Throughput.from_latency(
+                self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
+            )
+            self.report.per_token.throughput = Throughput.from_latency(
+                self.report.per_token.latency,
+                self.text_generation_per_token_volume,
+                unit=TEXT_GENERATION_THROUGHPUT_UNIT,
+            )
+        else:
+            self.latency_tracker.reset()
+            while self.latency_tracker.get_elapsed_time() < self.config.duration:
+                with self.latency_tracker.track():
+                    _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
+            generate_latency = self.latency_tracker.get_latency()
+
+            self.report.decode.latency = generate_latency - forward_latency
+            self.report.decode.throughput = Throughput.from_latency(
+                self.report.decode.latency, self.text_generation_decode_volume, unit=TEXT_GENERATION_THROUGHPUT_UNIT
+            )
+
+    def run_image_diffusion_latency_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running latency tracking")
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
@@ -231,10 +255,10 @@ def run_image_diffusion_latency_tracking(self, backend: Backend):

         self.report.call.latency = self.latency_tracker.get_latency()
         self.report.call.throughput = Throughput.from_latency(
-            self.report.call.latency, self.call_volume, unit=CALL_THROUGHPUT_UNIT
+            self.report.call.latency, self.image_diffusion_volume, unit=IMAGE_DIFFUSION_THROUGHPUT_UNIT
         )

-    def run_latency_inference_tracking(self, backend: Backend):
+    def run_latency_inference_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running latency tracking")
         self.latency_tracker.reset()
         while self.latency_tracker.get_elapsed_time() < self.config.duration:
@@ -243,31 +267,33 @@ def run_latency_inference_tracking(self, backend: Backend):

         self.report.forward.latency = self.latency_tracker.get_latency()
         self.report.forward.throughput = Throughput.from_latency(
-            self.report.forward.latency, self.forward_volume, unit=THROUGHPUT_UNIT
+            self.report.forward.latency, self.inference_volume, unit=INFERENCE_THROUGHPUT_UNIT
         )

     ## Energy tracking
-    def run_text_generation_energy_tracking(self, backend: Backend):
+    def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running energy tracking")
         self.energy_tracker.reset()
         with self.energy_tracker.track():
             _ = backend.forward(self.forward_inputs, self.config.forward_kwargs)
+        forward_energy = self.energy_tracker.get_energy()

-        self.report.prefill.energy = self.energy_tracker.get_energy()
+        self.report.prefill.energy = forward_energy
         self.report.prefill.efficiency = Efficiency.from_energy(
-            self.report.prefill.energy, self.prefill_volume, unit=PREFILL_EFFICIENCY_UNIT
+            self.report.prefill.energy, self.text_generation_prefill_volume, unit=TEXT_GENERATION_EFFICIENCY_UNIT
         )

         self.energy_tracker.reset()
         with self.energy_tracker.track():
             _ = backend.generate(self.generate_inputs, self.config.generate_kwargs)
+        generate_energy = self.energy_tracker.get_energy()

-        self.report.decode.energy = self.energy_tracker.get_energy() - self.report.prefill.energy
+        self.report.decode.energy = generate_energy - forward_energy
         self.report.decode.efficiency = Efficiency.from_energy(
-            self.report.decode.energy, self.decode_volume, unit=DECODE_EFFICIENCY_UNIT
+            self.report.decode.energy, self.text_generation_decode_volume, unit=TEXT_GENERATION_EFFICIENCY_UNIT
         )

-    def run_image_diffusion_energy_tracking(self, backend: Backend):
+    def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running energy tracking")
         self.energy_tracker.reset()
         with self.energy_tracker.track():
@@ -275,10 +301,10 @@ def run_image_diffusion_energy_tracking(self, backend: Backend):

         self.report.call.energy = self.energy_tracker.get_energy()
         self.report.call.efficiency = Efficiency.from_energy(
-            self.report.call.energy, self.call_volume, unit=CALL_EFFICIENCY_UNIT
+            self.report.call.energy, self.image_diffusion_volume, unit=IMAGE_DIFFUSION_EFFICIENCY_UNIT
         )

-    def run_inference_energy_tracking(self, backend: Backend):
+    def run_inference_energy_tracking(self, backend: Backend[BackendConfigT]):
         LOGGER.info("\t+ Running energy tracking")
         self.energy_tracker.reset()
         with self.energy_tracker.track():
@@ -286,27 +312,31 @@ def run_inference_energy_tracking(self, backend: Backend):

         self.report.forward.energy = self.energy_tracker.get_energy()
         self.report.forward.efficiency = Efficiency.from_energy(
-            self.report.forward.energy, self.forward_volume, unit=EFFICIENCY_UNIT
+            self.report.forward.energy, self.inference_volume, unit=INFERENCE_EFFICIENCY_UNIT
         )

     @property
-    def forward_volume(self) -> int:  # in samples
+    def inference_volume(self) -> int:  # in samples
         return self.config.input_shapes["batch_size"]

     @property
-    def prefill_volume(self) -> int:  # in tokens
+    def image_diffusion_volume(self) -> int:  # in images
+        return self.config.input_shapes["batch_size"] * self.config.call_kwargs["num_images_per_prompt"]
+
+    @property
+    def text_generation_prefill_volume(self) -> int:  # in tokens
         return self.config.input_shapes["batch_size"] * self.config.input_shapes["sequence_length"]

     @property
-    def call_volume(self) -> int:  # in images
-        return self.config.input_shapes["batch_size"] * self.config.call_kwargs["num_images_per_prompt"]
+    def text_generation_per_token_volume(self) -> int:  # in tokens
+        return self.config.input_shapes["batch_size"] * self.config.generate_kwargs["num_return_sequences"]

     @property
-    def decode_volume(self) -> int:  # in tokens
+    def text_generation_decode_volume(self) -> int:  # in tokens
         return (
             self.config.input_shapes["batch_size"]
             * self.config.generate_kwargs["num_return_sequences"]
-            * self.config.generate_kwargs["max_new_tokens"]
+            * (self.config.generate_kwargs["max_new_tokens"] - 1)  # 1 token is generated during prefill
         )

     def get_report(self) -> InferenceReport:
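A self-contained sketch of the per-token timing mechanism that `PER_TOKEN_BACKENDS` now rely on, assuming a small causal LM (`gpt2`) purely for illustration; the real `LatencyLogitsProcessor` additionally records CUDA events, synchronizes before reading them, and wraps each run in distributed barriers.

```python
# A LogitsProcessor is invoked once per generated token, so recording a timestamp in
# __call__ yields one event per decoding step; consecutive differences approximate
# per-token latency, and their sum approximates the decode phase of one generate call.
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor, LogitsProcessorList


class TimestampLogitsProcessor(LogitsProcessor):
    def __init__(self):
        self.timestamps = []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self.timestamps.append(time.perf_counter())
        return scores  # scores pass through untouched; we only observe the call


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

processor = TimestampLogitsProcessor()
inputs = tokenizer("per-token latency", return_tensors="pt")
model.generate(
    **inputs,
    max_new_tokens=10,
    do_sample=False,
    logits_processor=LogitsProcessorList([processor]),
)

per_token = [t1 - t0 for t0, t1 in zip(processor.timestamps, processor.timestamps[1:])]
print(len(processor.timestamps), sum(per_token))
```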
diff --git a/optimum_benchmark/benchmarks/inference/inputs_utils.py b/optimum_benchmark/benchmarks/inference/inputs_utils.py
index f4dc5bd1..62bce47a 100644
--- a/optimum_benchmark/benchmarks/inference/inputs_utils.py
+++ b/optimum_benchmark/benchmarks/inference/inputs_utils.py
@@ -1,16 +1,16 @@
 def extract_text_generation_inputs(inputs):
     if "pixel_values" in inputs:
         # image input
-        text_generation_inputs = {"inputs": inputs["pixel_values"]}
+        text_generation_inputs = {"pixel_values": inputs["pixel_values"]}
     elif "input_values" in inputs:
         # speech input
-        text_generation_inputs = {"inputs": inputs["input_values"]}
+        text_generation_inputs = {"input_values": inputs["input_values"]}
     elif "input_features" in inputs:
         # waveform input
-        text_generation_inputs = {"inputs": inputs["input_features"]}
+        text_generation_inputs = {"input_features": inputs["input_features"]}
     elif "input_ids" in inputs:
         # text input
-        text_generation_inputs = {"inputs": inputs["input_ids"]}
+        text_generation_inputs = {"input_ids": inputs["input_ids"]}
     else:
         raise ValueError("Could not find any valid text generation inputs.")
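A small regression-style check (hypothetical, not part of this PR) for the key renaming above; it assumes the function returns the extracted dict, as the backends' `forward`/`generate` expect.

```python
# The extracted dict should expose the model-native key so backends can forward it verbatim.
import torch

from optimum_benchmark.benchmarks.inference.inputs_utils import extract_text_generation_inputs


def test_text_inputs_keep_model_native_key():
    inputs = {"input_ids": torch.ones(2, 8, dtype=torch.long)}
    extracted = extract_text_generation_inputs(inputs)
    assert "input_ids" in extracted   # new behaviour: model-native key is preserved
    assert "inputs" not in extracted  # old generic key is no longer used
```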
"Latency": + if not isinstance(latency, Latency): + raise ValueError(f"Cannot subtract {type(latency)} from Latency") - latencies = [lat - scalar for lat in self.values] + latencies = [lat - latency.mean for lat in self.values] return Latency.from_values(values=latencies, unit=self.unit) @staticmethod @@ -156,9 +156,6 @@ def get_latency(self) -> Latency: return Latency.from_values(latencies_list, unit=LATENCY_UNIT) - def get_throughput(self, volume: int, unit: str) -> Throughput: - return Throughput.from_latency(self.get_latency(), volume, unit) - class LatencyTrainerCallback(TrainerCallback): def __init__(self, device: str, backend: str) -> None: @@ -197,39 +194,71 @@ def get_latency(self) -> Latency: return Latency.from_values(latencies_list, unit=LATENCY_UNIT) - def get_throughput(self, volume: int, unit: str) -> Throughput: - return Throughput.from_latency(self.get_latency(), volume, unit) - class LatencyLogitsProcessor(LogitsProcessor): def __init__(self, device: str, backend: str): self.device = device self.backend = backend + self.distributed = is_torch_distributed_available() and torch.distributed.is_initialized() self.reset() def reset(self): - self.events: List[Union[float, torch.cuda.Event]] = [] + # for each generate (run) pass, we store the time of each token + self.run_events: List[List[Union[float, torch.cuda.Event]]] = [] + self.start_time: float = time.perf_counter() + + def get_elapsed_time(self) -> float: + return time.perf_counter() - self.start_time + + @contextmanager + def track(self): + if self.distributed: + torch.distributed.barrier(device_ids=[torch.cuda.current_device()] if self.device == "cuda" else None) + + self.tok_events: List[Union[float, torch.cuda.Event]] = [] + + yield # this is where generate is called, and for each token, we record an event + + self.run_events.append(self.tok_events) + + if self.distributed: + torch.distributed.barrier(device_ids=[torch.cuda.current_device()] if self.device == "cuda" else None) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): if self.device == "cuda" and self.backend == "pytorch": event = torch.cuda.Event(enable_timing=True) event.record() - self.events.append(event) + self.tok_events.append(event) else: - self.events.append(time.perf_counter()) + self.tok_events.append(time.perf_counter()) return scores - def get_latency(self) -> Latency: - if self.device == "cuda" and self.backend == "pytorch": - # synchronize the device to make sure all events have been recorded - torch.cuda.synchronize() - latencies_list = [self.events[i - 1].elapsed_time(self.events[i]) / 1e3 for i in range(1, len(self.events))] - else: - latencies_list = [(self.events[i] - self.events[i - 1]) for i in range(1, len(self.events))] + def get_per_token_latency(self) -> Latency: + latencies_list = [] + for tok_events in self.run_events: + if self.device == "cuda" and self.backend == "pytorch": + # synchronize the device to make sure all events have been recorded + torch.cuda.synchronize() + latencies_list.extend( + [tok_events[i - 1].elapsed_time(tok_events[i]) / 1e3 for i in range(1, len(tok_events))] + ) + else: + latencies_list.extend([(tok_events[i] - tok_events[i - 1]) for i in range(1, len(tok_events))]) return Latency.from_values(latencies_list, unit=LATENCY_UNIT) - def get_throughput(self, volume: int, unit: str) -> Throughput: - return Throughput.from_latency(self.get_latency(), volume, unit) + def get_decode_latency(self) -> Latency: + latencies_list = [] + for tok_events in self.run_events: + if self.device == 
"cuda" and self.backend == "pytorch": + # synchronize the device to make sure all events have been recorded + torch.cuda.synchronize() + latencies_list.append( + sum([tok_events[i - 1].elapsed_time(tok_events[i]) / 1e3 for i in range(1, len(tok_events))]) + ) + else: + latencies_list.append(sum([(tok_events[i] - tok_events[i - 1]) for i in range(1, len(tok_events))])) + + return Latency.from_values(latencies_list, unit=LATENCY_UNIT)