Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Dec 4, 2024
1 parent 74d15d2 commit 1e784bb
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions optimum_benchmark/scenarios/inference/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:

def init_trackers(self, backend: Backend[BackendConfigT]):
if self.config.latency:
if backend.config.name in PER_TOKEN_BACKENDS:
if backend.config.task in TEXT_GENERATION_TASKS and backend.config.name in PER_TOKEN_BACKENDS:
self.latency_tracker = PerTokenLatencyLogitsProcessor(
backend=backend.config.name,
device=backend.config.device,
)
self.config.generate_kwargs["logits_processor"] = LogitsProcessorList([self.latency_tracker])
else:
self.latency_tracker = LatencyTracker(
backend=backend.config.name,
Expand Down Expand Up @@ -233,18 +234,14 @@ def run_inference_memory_tracking(self, backend: Backend[BackendConfigT]):
## Latency tracking
def run_per_token_text_generation_latency_tracking(self, backend: Backend[BackendConfigT]):
self.logger.info("\t+ Running Per-Token Text Generation latency tracking")
per_token_kwargs = {
**self.config.generate_kwargs,
"logits_processor": LogitsProcessorList([self.latency_tracker]),
}

self.latency_tracker.reset()
while (
self.latency_tracker.elapsed() < self.config.duration
or self.latency_tracker.count() < self.config.iterations
):
with self.latency_tracker.track():
_ = backend.generate(self.inputs, per_token_kwargs)
_ = backend.generate(self.inputs, self.config.generate_kwargs)

per_token_latency = self.latency_tracker.get_per_token_latency()
prefill_latency = self.latency_tracker.get_prefill_latency()
Expand Down

0 comments on commit 1e784bb

Please sign in to comment.