From 1e784bb4f5193087d05a457e395091d9d7102c95 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Wed, 4 Dec 2024 13:30:48 +0100
Subject: [PATCH] fix

---
 optimum_benchmark/scenarios/inference/scenario.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/optimum_benchmark/scenarios/inference/scenario.py b/optimum_benchmark/scenarios/inference/scenario.py
index 668c29c1..d9318b57 100644
--- a/optimum_benchmark/scenarios/inference/scenario.py
+++ b/optimum_benchmark/scenarios/inference/scenario.py
@@ -131,11 +131,12 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
     def init_trackers(self, backend: Backend[BackendConfigT]):
         if self.config.latency:
-            if backend.config.name in PER_TOKEN_BACKENDS:
+            if backend.config.task in TEXT_GENERATION_TASKS and backend.config.name in PER_TOKEN_BACKENDS:
                 self.latency_tracker = PerTokenLatencyLogitsProcessor(
                     backend=backend.config.name,
                     device=backend.config.device,
                 )
+                self.config.generate_kwargs["logits_processor"] = LogitsProcessorList([self.latency_tracker])
             else:
                 self.latency_tracker = LatencyTracker(
                     backend=backend.config.name,
                     device=backend.config.device,
@@ -233,10 +234,6 @@ def run_inference_memory_tracking(self, backend: Backend[BackendConfigT]):
     ## Latency tracking
     def run_per_token_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Running Per-Token Text Generation latency tracking")
-        per_token_kwargs = {
-            **self.config.generate_kwargs,
-            "logits_processor": LogitsProcessorList([self.latency_tracker]),
-        }
         self.latency_tracker.reset()
 
         while (
@@ -244,7 +241,7 @@ def run_per_token_text_generation_latency_tracking(self, backend: Backend[Backen
             or self.latency_tracker.count() < self.config.iterations
         ):
             with self.latency_tracker.track():
-                _ = backend.generate(self.inputs, per_token_kwargs)
+                _ = backend.generate(self.inputs, self.config.generate_kwargs)
 
             per_token_latency = self.latency_tracker.get_per_token_latency()
             prefill_latency = self.latency_tracker.get_prefill_latency()
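
Note: for readers unfamiliar with the mechanism this patch relies on, below is a minimal,
self-contained sketch of how a logits processor can piggyback on transformers' generate()
to timestamp each decoding step. The patch wires the tracker into generate_kwargs once in
init_trackers instead of rebuilding a kwargs dict on every benchmark iteration.
TimestampLogitsProcessor and the gpt2 checkpoint are illustrative stand-ins, not
optimum-benchmark's actual PerTokenLatencyLogitsProcessor.

import time

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)


class TimestampLogitsProcessor(LogitsProcessor):
    """Hypothetical tracker: records a timestamp at every decoding step
    without modifying the scores."""

    def __init__(self):
        self.timestamps = []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self.timestamps.append(time.perf_counter())
        return scores  # pass-through: observe only, never alter generation


tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

tracker = TimestampLogitsProcessor()
inputs = tokenizer("Hello", return_tensors="pt")

# Inject the tracker once, analogous to what the patched init_trackers does
# via self.config.generate_kwargs["logits_processor"].
_ = model.generate(**inputs, max_new_tokens=8, logits_processor=LogitsProcessorList([tracker]))

# Deltas between consecutive timestamps approximate per-token latencies.
per_token_latencies = [b - a for a, b in zip(tracker.timestamps, tracker.timestamps[1:])]
print(per_token_latencies)

Since generate() invokes the processor once per generated token, with the first call landing
right after the prefill forward pass, consecutive timestamps bound each decoding step; this is
presumably how get_per_token_latency() and get_prefill_latency() derive their measurements.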