From 69bbe229a88807c5708c84f13bc259d8cad1e5c4 Mon Sep 17 00:00:00 2001 From: Siddharth Venkatesan Date: Thu, 23 Jan 2025 15:50:48 -0800 Subject: [PATCH] [lmi][vllm] do not require do_sample to enable sampling (#2676) --- .../rolling_batch/lmi_dist_rolling_batch.py | 21 ++++++++-------- .../rolling_batch/vllm_rolling_batch.py | 25 +++++++++---------- .../user_guides/lmi_input_output_schema.md | 5 ++++ 3 files changed, 27 insertions(+), 24 deletions(-) diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index 5f78489d7..0c2dd666a 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -154,28 +154,27 @@ def translate_lmi_dist_params(self, parameters: dict): :return: The same parameters dict, but with lmi-dist style parameter names. """ parameters["max_tokens"] = parameters.pop("max_new_tokens", 30) - # If `do_sample` is not provided, force temperature=0.0, i.e. 
greedy - # else set to user-provided value or default to 1.0 - if not parameters.pop('do_sample', False): - parameters['temperature'] = 0.0 - else: - parameters['temperature'] = parameters.get('temperature', 1.0) + do_sample = parameters.pop("do_sample", None) + if do_sample is not None and do_sample is False: + parameters["temperature"] = 0.0 + if do_sample is None and parameters.get("temperature") is None: + parameters["temperature"] = 0.0 if "seed" in parameters.keys(): parameters["seed"] = int(parameters["seed"]) - if "stop_sequences" in parameters.keys(): + if "stop_sequences" in parameters: parameters["stop"] = parameters.pop("stop_sequences") - if "ignore_eos_token" in parameters.keys(): + if "ignore_eos_token" in parameters: parameters["ignore_eos"] = parameters.pop("ignore_eos_token") - if "num_beams" in parameters.keys(): + if "num_beams" in parameters: parameters["best_of"] = parameters.pop("num_beams") parameters["use_beam_search"] = True if parameters.pop("decoder_input_details", False): parameters["prompt_logprobs"] = 1 - if "best_of" in parameters.keys(): + if "best_of" in parameters: # if n is not explicitly set, we return `best_of` values sequences. if "n" not in "best_of": parameters["n"] = parameters["best_of"] - if "top_n_tokens" in parameters.keys(): + if "top_n_tokens" in parameters: parameters["logprobs"] = parameters.pop("top_n_tokens") else: parameters["logprobs"] = parameters.get("logprobs", 1) diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index 6745ced4b..45262af82 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -10,6 +10,7 @@ # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. 
See the License for # the specific language governing permissions and limitations under the License. +import logging from collections import OrderedDict, defaultdict from vllm import LLMEngine, SamplingParams @@ -85,31 +86,29 @@ def translate_vllm_params(self, parameters: dict) -> dict: :return: The same parameters dict, but with VLLM style parameter names. """ parameters["max_tokens"] = parameters.pop("max_new_tokens", 30) - if "seed" in parameters.keys(): + do_sample = parameters.pop("do_sample", None) + if do_sample is not None and do_sample is False: + parameters["temperature"] = 0.0 + if do_sample is None and parameters.get("temperature") is None: + parameters["temperature"] = 0.0 + if "seed" in parameters: parameters["seed"] = int(parameters["seed"]) - - # If `do_sample` is not provided, force temperature=0.0, i.e. greedy - # else set to user-provided value or default to 1.0 - if not parameters.pop('do_sample', False): - parameters['temperature'] = 0.0 - else: - parameters['temperature'] = parameters.get('temperature', 1.0) - if "stop_sequences" in parameters.keys(): + if "stop_sequences" in parameters: parameters["stop"] = parameters.pop("stop_sequences") - if "ignore_eos_token" in parameters.keys(): + if "ignore_eos_token" in parameters: parameters["ignore_eos"] = parameters.pop("ignore_eos_token") - if "num_beams" in parameters.keys(): + if "num_beams" in parameters: parameters["best_of"] = parameters.pop("num_beams") parameters["use_beam_search"] = True if parameters.pop("decoder_input_details", False): parameters["prompt_logprobs"] = 1 # if n is not explicitly set when best_of is set, we return `best_of` values sequences for tgi compatibility. 
- if "best_of" in parameters.keys(): + if "best_of" in parameters: if "n" not in "best_of": parameters["n"] = parameters["best_of"] - if "top_n_tokens" in parameters.keys(): + if "top_n_tokens" in parameters: parameters["logprobs"] = parameters.pop("top_n_tokens") else: parameters["logprobs"] = parameters.get("logprobs", 1) diff --git a/serving/docs/lmi/user_guides/lmi_input_output_schema.md b/serving/docs/lmi/user_guides/lmi_input_output_schema.md index 11747fd38..be93ceb50 100644 --- a/serving/docs/lmi/user_guides/lmi_input_output_schema.md +++ b/serving/docs/lmi/user_guides/lmi_input_output_schema.md @@ -289,6 +289,11 @@ If you are not specifying a specific engine or rolling batch implementation, we If you are deploying with a specific backend, additional parameters are available that are unique to the specific backend. +**Note:** +To enable sampling in LMI <= 0.31.0, you must specify `do_sample: true` in addition to any sampling parameters you set. +This behavior will change starting LMI 0.32.0, where you will no longer be required to set `do_sample`; +it will be inferred from the other sampling parameters. +#### Additional LMI Dist Generation parameters ```