
Commit

token latency
mengfei25 committed Dec 24, 2024
1 parent 53fad64 commit 325a63b
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/transformers/generation/utils.py
@@ -1803,6 +1803,7 @@ def generate(
         generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
         self._validate_model_kwargs(model_kwargs.copy())
         self._validate_assistant(assistant_model)
+        self.token_latency = self.config.token_latency if hasattr(self.config, "token_latency") else None

         # 2. Set generation parameters if not already defined
         if synced_gpus is None:
@@ -2992,11 +2993,13 @@ def _sample(
         this_peer_finished = False
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+        latency_list = []

         while self._has_unfinished_sequences(
             this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
         ):
             # prepare model inputs
+            tic = time.time()
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

             # prepare variable output controls (note: some models won't accept all output controls)
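The timing pattern these hunks introduce is plain wall-clock measurement per loop iteration: start a timer before input preparation, append the elapsed time at the end of the step. Stripped of the generation machinery, the pattern reduces to the following sketch (illustrative only, not part of the commit):

import time

latency_list = []
for _ in range(4):  # stand-in for the decoding while-loop
    tic = time.time()
    # ... prepare inputs, run one forward pass, select the next token ...
    latency_list.append(time.time() - tic)
# latency_list[0] covers the prefill step; later entries are per-token decode steps.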
@@ -3064,14 +3067,15 @@ def _sample(

             # This is needed to properly delete outputs.logits which may be very large for first iteration
             # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+            latency_list.append(time.time() - tic)
             del outputs

         if streamer is not None:
             streamer.end()

         if return_dict_in_generate:
             if self.config.is_encoder_decoder:
-                return GenerateEncoderDecoderOutput(
+                output_result = GenerateEncoderDecoderOutput(
                     sequences=input_ids,
                     scores=scores,
                     logits=raw_logits,
@@ -3083,7 +3087,7 @@ def _sample(
                     past_key_values=model_kwargs.get("past_key_values"),
                 )
             else:
-                return GenerateDecoderOnlyOutput(
+                output_result = GenerateDecoderOnlyOutput(
                     sequences=input_ids,
                     scores=scores,
                     logits=raw_logits,
@@ -3092,7 +3096,11 @@ def _sample(
                     past_key_values=model_kwargs.get("past_key_values"),
                 )
         else:
-            return input_ids
+            output_result = input_ids
+        if self.token_latency is not None:
+            return (output_result, latency_list)
+        else:
+            return output_result

     def _temporary_reorder_cache(self, past_key_values, beam_idx):
         """
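For context, a minimal sketch of how the patched build might be exercised. The checkpoint, prompt, and generation arguments below are illustrative assumptions, not part of the commit; what the commit itself defines is that setting token_latency on the model config makes _sample return a (sequences, latency_list) tuple instead of the sequences alone (a dict-style result is likewise wrapped in a tuple).

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any decoder-only model that decodes via _sample works.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Opt in to latency collection: generate() reads this attribute off the config,
# and the hasattr guard leaves configs without it on the old return type.
model.config.token_latency = True

inputs = tokenizer("Hello, my name is", return_tensors="pt")
sequences, latency_list = model.generate(**inputs, max_new_tokens=16, do_sample=False)

# One wall-clock duration (seconds) per decoding step: the first entry includes
# the prefill forward pass, the remaining entries are single-token decode steps.
print(f"first token: {latency_list[0]:.4f}s")
print(f"avg decode step: {sum(latency_list[1:]) / max(len(latency_list) - 1, 1):.4f}s")
print(tokenizer.decode(sequences[0], skip_special_tokens=True))

Note that time.time() measures host-side wall-clock time, so on asynchronous accelerators the per-step numbers reflect when the host observed each step; strict device timing would require synchronizing before each read.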
