Skip to content

Commit

Permalink
fix
Browse files (browse the repository at this point in the history)
  • Loading branch information
IlyasMoutawwakil committed Dec 9, 2024
1 parent 87fdb56 commit 026ddcd
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions optimum_benchmark/trackers/latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def get_latency(self) -> Latency:
]
else:
latencies = [
- (start_event - end_event) for start_event, end_event in zip(self.start_events, self.end_events)
+ (end_event - start_event) for start_event, end_event in zip(self.start_events, self.end_events)
]

assert all(latency >= 0 for latency in latencies)
Expand Down Expand Up @@ -357,7 +357,7 @@ def track(self):
self.per_token_start_events.extend(self.per_token_events[:-1])
self.per_token_end_events.extend(self.per_token_events[1:])

- def __call__(self, *args, scores: torch.FloatTensor, **kwargs):
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
if self.is_pytorch_cuda:
event = torch.cuda.Event(enable_timing=True)
event.record()
Expand All @@ -383,7 +383,7 @@ def get_prefill_latency(self) -> Latency:
]
else:
latencies = [
- (start_event - end_event)
+ (end_event - start_event)
for start_event, end_event in zip(self.prefill_start_events, self.prefill_end_events)
]

Expand All @@ -401,7 +401,7 @@ def get_decode_latency(self) -> Latency:
]
else:
latencies = [
- (start_event - end_event)
+ (end_event - start_event)
for start_event, end_event in zip(self.decode_start_events, self.decode_end_events)
]

Expand All @@ -419,7 +419,7 @@ def get_per_token_latency(self) -> Latency:
]
else:
latencies = [
- (start_event - end_event)
+ (end_event - start_event)
for start_event, end_event in zip(self.per_token_start_events, self.per_token_end_events)
]

Expand Down Expand Up @@ -516,7 +516,7 @@ def get_step_latency(self) -> Latency:
]
else:
latencies = [
- (start_event - end_event)
+ (end_event - start_event)
for start_event, end_event in zip(self.per_step_start_events, self.per_step_end_events)
]

Expand All @@ -534,7 +534,7 @@ def get_call_latency(self) -> Latency:
]
else:
latencies = [
- (start_event - end_event)
+ (end_event - start_event)
for start_event, end_event in zip(self.call_start_events, self.call_end_events)
]

Expand Down Expand Up @@ -583,7 +583,7 @@ def get_latency(self) -> Latency:
]
else:
latencies = [
- (start_event - end_event) for start_event, end_event in zip(self.start_events, self.end_events)
+ (end_event - start_event) for start_event, end_event in zip(self.start_events, self.end_events)
]

assert all(latency >= 0 for latency in latencies)
Expand Down

0 comments on commit 026ddcd

Please sign in to comment.