diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py index 01a3c7a1..0555578f 100644 --- a/optimum_benchmark/trackers/latency.py +++ b/optimum_benchmark/trackers/latency.py @@ -275,7 +275,7 @@ def get_latency(self) -> Latency: ] else: latencies = [ - (start_event - end_event) for start_event, end_event in zip(self.start_events, self.end_events) + (end_event - start_event) for start_event, end_event in zip(self.start_events, self.end_events) ] assert all(latency >= 0 for latency in latencies) @@ -357,7 +357,7 @@ def track(self): self.per_token_start_events.extend(self.per_token_events[:-1]) self.per_token_end_events.extend(self.per_token_events[1:]) - def __call__(self, *args, scores: torch.FloatTensor, **kwargs): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): if self.is_pytorch_cuda: event = torch.cuda.Event(enable_timing=True) event.record() @@ -383,7 +383,7 @@ def get_prefill_latency(self) -> Latency: ] else: latencies = [ - (start_event - end_event) + (end_event - start_event) for start_event, end_event in zip(self.prefill_start_events, self.prefill_end_events) ] @@ -401,7 +401,7 @@ def get_decode_latency(self) -> Latency: ] else: latencies = [ - (start_event - end_event) + (end_event - start_event) for start_event, end_event in zip(self.decode_start_events, self.decode_end_events) ] @@ -419,7 +419,7 @@ def get_per_token_latency(self) -> Latency: ] else: latencies = [ - (start_event - end_event) + (end_event - start_event) for start_event, end_event in zip(self.per_token_start_events, self.per_token_end_events) ] @@ -516,7 +516,7 @@ def get_step_latency(self) -> Latency: ] else: latencies = [ - (start_event - end_event) + (end_event - start_event) for start_event, end_event in zip(self.per_step_start_events, self.per_step_end_events) ] @@ -534,7 +534,7 @@ def get_call_latency(self) -> Latency: ] else: latencies = [ - (start_event - end_event) + (end_event - start_event) for start_event, end_event in zip(self.call_start_events, self.call_end_events) ] @@ -583,7 +583,7 @@ def get_latency(self) -> Latency: ] else: latencies = [ - (start_event - end_event) for start_event, end_event in zip(self.start_events, self.end_events) + (end_event - start_event) for start_event, end_event in zip(self.start_events, self.end_events) ] assert all(latency >= 0 for latency in latencies)