From 7b935559c4a22fbec7055caad0ebd71296bf99f2 Mon Sep 17 00:00:00 2001
From: Xin Yang <105740670+xyang16@users.noreply.github.com>
Date: Thu, 30 Jan 2025 12:37:31 -0800
Subject: [PATCH] [python] Update rolling batch params to output delta (#2636)

---
 .../rolling_batch/lmi_dist_rolling_batch.py   |   3 +-
 .../rolling_batch/rolling_batch_vllm_utils.py |  41 ++-----
 .../rolling_batch/vllm_rolling_batch.py       |   2 +
 .../djl_python/tests/test_rb_vllm_utils.py    | 108 ++----------
 4 files changed, 21 insertions(+), 133 deletions(-)

diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
index 0c2dd666a..773b3e3eb 100644
--- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -19,7 +19,7 @@
 from lmi_dist.arg_utils import VllmEngineArgs
 from lmi_dist.init_engine import engine_from_args
 from lmi_dist.seq2seq_engine import Seq2SeqPreprocessor
-from vllm import SamplingParams
+from vllm.sampling_params import RequestOutputKind
 from vllm.utils import AtomicCounter
 
 from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
@@ -153,6 +153,7 @@ def translate_lmi_dist_params(self, parameters: dict):
 
         :return: The same parameters dict, but with lmi-dist style parameter names.
         """
+        parameters["output_kind"] = RequestOutputKind.DELTA
         parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
         do_sample = parameters.pop("do_sample", None)
         if do_sample is not None and do_sample is False:
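Note: in vLLM releases around 0.6.x, SamplingParams accepts an output_kind
field, and RequestOutputKind.DELTA asks the engine to return only the tokens
generated since the previous step instead of the full cumulative output. A
minimal sketch of the two semantics, with token values borrowed from the test
data later in this patch:

    from vllm.sampling_params import RequestOutputKind, SamplingParams

    # CUMULATIVE (the previous behavior): each step repeats everything so far.
    #   step 1: token_ids == [302],      text == ' of'
    #   step 2: token_ids == [302, 272], text == ' of the'
    # DELTA (what this patch requests): each step carries only the new tokens.
    #   step 1: token_ids == [302],      text == ' of'
    #   step 2: token_ids == [272],      text == ' the'
    sampling_params = SamplingParams(max_tokens=30,
                                     output_kind=RequestOutputKind.DELTA)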
diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
index 7b3c3ce8d..688c00484 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
@@ -74,7 +74,7 @@ def update_request_cache_with_output(request_cache: OrderedDict,
             request_output.prompt_tokens_details.append(prompt_token)
 
     # sets the details of all sequences
-    update_multiple_sequences(cache, request_output, vllm_request_output)
+    update_multiple_sequences(request_output, vllm_request_output)
 
     # remove finished requests from cache
     if vllm_request_output.finished:
@@ -89,49 +89,28 @@ def update_request_cache_with_output(request_cache: OrderedDict,
     return request_cache
 
 
-def update_multiple_sequences(cache, request_output, vllm_request_output):
+def update_multiple_sequences(request_output, vllm_request_output):
     for completion_output in vllm_request_output.outputs:
-
         sequence_index = completion_output.index
-        if f"sequence_index_{sequence_index}" not in cache:
-            cache[f"sequence_index_{sequence_index}"] = {
-                "curr_length": 0,
-                "num_generated_tokens": 0
-            }
 
         if sequence_index not in request_output.sequences:
             request_output.sequences[sequence_index] = Sequence()
 
-        # set token of the sequence
-        # previous length of token ids generated
-        prev_len = cache[f"sequence_index_{sequence_index}"][
-            'num_generated_tokens']
-        # curr length of the token ids generated so far
-        cur_len = len(completion_output.token_ids)
-        cache[f"sequence_index_{sequence_index}"][
-            "num_generated_tokens"] = cur_len
-        # get the newly generated token_ids
-        new_token_ids = completion_output.token_ids[
-            prev_len:
-            cur_len] if prev_len < cur_len else completion_output.token_ids
+        new_token_ids = completion_output.token_ids
 
         # get the newly generated token texts for speculative decoding
         output_token_texts = []
         if hasattr(completion_output, "output_token_texts"):
-            output_token_texts = completion_output.output_token_texts[
-                prev_len:
-                cur_len] if prev_len < cur_len else completion_output.output_token_texts
+            output_token_texts = completion_output.output_token_texts
 
         top_tokens = []
         token_texts = []
        # calculate log probs and token_texts
         if completion_output.logprobs:
-            new_logprobs_list = completion_output.logprobs[
-                prev_len:
-                cur_len] if prev_len < cur_len else completion_output.logprobs
             new_logprobs = []
-            for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
+            for token_id, logprobs in zip(new_token_ids,
+                                          completion_output.logprobs):
                 new_logprobs.append(logprobs[token_id].logprob)
                 decoded_token = logprobs[token_id].decoded_token if logprobs[
                     token_id].decoded_token else ""
@@ -141,13 +120,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
                         Token(id=token_id_key,
                               text=logprob.decoded_token,
                               log_prob=logprob.logprob))
-
         elif new_token_ids:
             # TODO: Test and remove this. logprobs is always set 1. This case should never happen.
             new_logprobs = [None] * len(new_token_ids)
-            curr_length = cache[f"sequence_index_{sequence_index}"][
-                "curr_length"]
-            token_texts.append(completion_output.text[curr_length:])
+            token_texts.append(completion_output.text)
 
         if not output_token_texts:
             if len(token_texts) != len(new_token_ids):
@@ -186,9 +162,6 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
             request_output.sequences[sequence_index].set_next_top_tokens(
                 top_tokens)
 
-    cache[f"sequence_index_{sequence_index}"]["curr_length"] = len(
-        completion_output.text)
-
 
 def get_speculative_decoding_metrics_record(
         completion_output: CompletionOutput,
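With delta outputs, the per-sequence cache entries removed above
(num_generated_tokens, curr_length) lose their purpose: they existed only to
slice the newly generated suffix out of cumulative token_ids and text. A
runnable sketch of the simplification; the helper names are hypothetical, not
functions from this repo:

    # Old approach: slice the cumulative list, then remember how much was seen.
    def take_new_tokens_cumulative(all_token_ids, prev_len):
        return all_token_ids[prev_len:], len(all_token_ids)

    # New approach: the engine already sends only this step's tokens.
    def take_new_tokens_delta(delta_token_ids):
        return list(delta_token_ids)

    assert take_new_tokens_cumulative([302, 272], 1) == ([272], 2)
    assert take_new_tokens_delta([272]) == [272]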
""" + parameters["output_kind"] = RequestOutputKind.DELTA parameters["max_tokens"] = parameters.pop("max_new_tokens", 30) do_sample = parameters.pop("do_sample", None) if do_sample is not None and do_sample is False: diff --git a/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py b/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py index 41486fb20..c9aa52716 100644 --- a/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py +++ b/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py @@ -1,6 +1,5 @@ import sys import unittest -import uuid from dataclasses import dataclass from typing import List, Optional, Dict, Union from collections import OrderedDict @@ -12,8 +11,8 @@ import djl_python from djl_python.output_formatter import _json_output_formatter from djl_python.request import Request -from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token, RequestInput -'''These Mock classes are in compliance with vllm RequestOutput version 0.5.3.post1''' +from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token +'''These Mock classes are in compliance with vllm RequestOutput version 0.6.3.post1''' @dataclass @@ -148,23 +147,10 @@ def __init__( ], outputs=[ MockCompletionOutput(index=1, - text=' member of', - token_ids=[4292, 302], + text=' of', + token_ids=[302], cumulative_logprob=-4.3041129764169455, logprobs=[{ - 4292: - MockLogprob(logprob=-4.2740092277526855, - rank=4, - decoded_token=' member'), - 2032: - MockLogprob(logprob=-3.0240092277526855, - rank=1, - decoded_token=' big'), - 888: - MockLogprob(logprob=-4.4099884033203125, - rank=3, - decoded_token=' new'), - }, { 302: MockLogprob(logprob=-0.03010374866425991, rank=1, @@ -181,27 +167,10 @@ def __init__( finish_reason=None, stop_reason=None), MockCompletionOutput(index=0, - text=' consolidated', - token_ids=[22968, 601], + text='ated', + token_ids=[601], cumulative_logprob=-13.402491569519043, logprobs=[{ - 22968: - MockLogprob(logprob=-12.117759704589844, - rank=5308, - decoded_token=' consolid'), - 2032: - MockLogprob(logprob=-3.0240092277526855, - rank=1, - decoded_token=' big'), - 17372: - MockLogprob(logprob=-13.409988403320312, - rank=10489, - decoded_token=' crown'), - 888: - MockLogprob(logprob=-4.4099884033203125, - rank=3, - decoded_token=' new'), - }, { 601: MockLogprob(logprob=-1.2847318649291992, rank=2, @@ -235,37 +204,10 @@ def __init__( ], outputs=[ MockCompletionOutput(index=1, - text=' member of the', - token_ids=[4292, 302, - 272], + text=' the', + token_ids=[272], cumulative_logprob=-4.815703457221389, logprobs=[{ - 4292: - MockLogprob(logprob=-4.2740092277526855, - rank=4, - decoded_token=' member'), - 2032: - MockLogprob(logprob=-3.0240092277526855, - rank=1, - decoded_token=' big'), - 888: - MockLogprob(logprob=-4.4099884033203125, - rank=3, - decoded_token=' new'), - }, { - 302: - MockLogprob(logprob=-0.03010374866425991, - rank=1, - decoded_token=' of'), - 235290: - MockLogprob(logprob=-2.2026185989379883, - rank=1, - decoded_token='-'), - 578: - MockLogprob(logprob=-2.2026185989379883, - rank=2, - decoded_token=' and') - }, { 272: MockLogprob(logprob=-0.5115904808044434, rank=1, @@ -282,40 +224,10 @@ def __init__( finish_reason='length', stop_reason=None), MockCompletionOutput(index=0, - text=' consolidated or', - token_ids=[22968, 601, 442], + text=' or', + token_ids=[442], cumulative_logprob=-20.4010648727417, logprobs=[{ - 22968: - MockLogprob(logprob=-12.117759704589844, - rank=5308, - decoded_token=' consolid'), - 2032: - 
diff --git a/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py b/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
index 41486fb20..c9aa52716 100644
--- a/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
+++ b/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
@@ -1,6 +1,5 @@
 import sys
 import unittest
-import uuid
 from dataclasses import dataclass
 from typing import List, Optional, Dict, Union
 from collections import OrderedDict
@@ -12,8 +11,8 @@
 import djl_python
 from djl_python.output_formatter import _json_output_formatter
 from djl_python.request import Request
-from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token, RequestInput
-'''These Mock classes are in compliance with vllm RequestOutput version 0.5.3.post1'''
+from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token
+'''These Mock classes are in compliance with vllm RequestOutput version 0.6.3.post1'''
 
 
 @dataclass
@@ -148,23 +147,10 @@ def __init__(
         ],
         outputs=[
             MockCompletionOutput(index=1,
-                                 text=' member of',
-                                 token_ids=[4292, 302],
+                                 text=' of',
+                                 token_ids=[302],
                                  cumulative_logprob=-4.3041129764169455,
                                  logprobs=[{
-                                     4292:
-                                     MockLogprob(logprob=-4.2740092277526855,
-                                                 rank=4,
-                                                 decoded_token=' member'),
-                                     2032:
-                                     MockLogprob(logprob=-3.0240092277526855,
-                                                 rank=1,
-                                                 decoded_token=' big'),
-                                     888:
-                                     MockLogprob(logprob=-4.4099884033203125,
-                                                 rank=3,
-                                                 decoded_token=' new'),
-                                 }, {
                                      302:
                                      MockLogprob(logprob=-0.03010374866425991,
                                                  rank=1,
                                                  decoded_token=' of'),
                                  finish_reason=None,
                                  stop_reason=None),
             MockCompletionOutput(index=0,
-                                 text=' consolidated',
-                                 token_ids=[22968, 601],
+                                 text='ated',
+                                 token_ids=[601],
                                  cumulative_logprob=-13.402491569519043,
                                  logprobs=[{
-                                     22968:
-                                     MockLogprob(logprob=-12.117759704589844,
-                                                 rank=5308,
-                                                 decoded_token=' consolid'),
-                                     2032:
-                                     MockLogprob(logprob=-3.0240092277526855,
-                                                 rank=1,
-                                                 decoded_token=' big'),
-                                     17372:
-                                     MockLogprob(logprob=-13.409988403320312,
-                                                 rank=10489,
-                                                 decoded_token=' crown'),
-                                     888:
-                                     MockLogprob(logprob=-4.4099884033203125,
-                                                 rank=3,
-                                                 decoded_token=' new'),
-                                 }, {
                                      601:
                                      MockLogprob(logprob=-1.2847318649291992,
                                                  rank=2,
                                                  decoded_token='ated'),
@@ -235,37 +204,10 @@ def __init__(
         ],
         outputs=[
             MockCompletionOutput(index=1,
-                                 text=' member of the',
-                                 token_ids=[4292, 302,
-                                            272],
+                                 text=' the',
+                                 token_ids=[272],
                                  cumulative_logprob=-4.815703457221389,
                                  logprobs=[{
-                                     4292:
-                                     MockLogprob(logprob=-4.2740092277526855,
-                                                 rank=4,
-                                                 decoded_token=' member'),
-                                     2032:
-                                     MockLogprob(logprob=-3.0240092277526855,
-                                                 rank=1,
-                                                 decoded_token=' big'),
-                                     888:
-                                     MockLogprob(logprob=-4.4099884033203125,
-                                                 rank=3,
-                                                 decoded_token=' new'),
-                                 }, {
-                                     302:
-                                     MockLogprob(logprob=-0.03010374866425991,
-                                                 rank=1,
-                                                 decoded_token=' of'),
-                                     235290:
-                                     MockLogprob(logprob=-2.2026185989379883,
-                                                 rank=1,
-                                                 decoded_token='-'),
-                                     578:
-                                     MockLogprob(logprob=-2.2026185989379883,
-                                                 rank=2,
-                                                 decoded_token=' and')
-                                 }, {
                                      272:
                                      MockLogprob(logprob=-0.5115904808044434,
                                                  rank=1,
                                  finish_reason='length',
                                  stop_reason=None),
             MockCompletionOutput(index=0,
-                                 text=' consolidated or',
-                                 token_ids=[22968, 601, 442],
+                                 text=' or',
+                                 token_ids=[442],
                                  cumulative_logprob=-20.4010648727417,
                                  logprobs=[{
-                                     22968:
-                                     MockLogprob(logprob=-12.117759704589844,
-                                                 rank=5308,
-                                                 decoded_token=' consolid'),
-                                     2032:
-                                     MockLogprob(logprob=-3.0240092277526855,
-                                                 rank=1,
-                                                 decoded_token=' big'),
-                                     17372:
-                                     MockLogprob(logprob=-13.409988403320312,
-                                                 rank=10489,
-                                                 decoded_token=' crown'),
-                                     888:
-                                     MockLogprob(logprob=-4.4099884033203125,
-                                                 rank=3,
-                                                 decoded_token=' new'),
-                                 }, {
-                                     601:
-                                     MockLogprob(logprob=-1.2847318649291992,
-                                                 rank=2,
-                                                 decoded_token='ated'),
-                                     1028:
-                                     MockLogprob(logprob=-0.909731924533844,
-                                                 rank=1,
-                                                 decoded_token='ator'),
-                                     1162:
-                                     MockLogprob(logprob=-0.8929234743118286,
-                                                 rank=2,
-                                                 decoded_token=' year')
-                                 }, {
                                      442:
                                      MockLogprob(logprob=-6.998573303222656,
                                                  rank=188,