Skip to content

Commit

Permalink
[lmi][vllm] do not require do_sample to enable sampling (#2676)
Browse files Browse the repository at this point in the history
  • Loading branch information
siddvenk authored Jan 23, 2025
1 parent c8da32c commit 69bbe22
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,28 +154,27 @@ def translate_lmi_dist_params(self, parameters: dict):
:return: The same parameters dict, but with lmi-dist style parameter names.
"""
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
# If `do_sample` is not provided, force temperature=0.0, i.e. greedy
# else set to user-provided value or default to 1.0
if not parameters.pop('do_sample', False):
parameters['temperature'] = 0.0
else:
parameters['temperature'] = parameters.get('temperature', 1.0)
do_sample = parameters.pop("do_sample", None)
if do_sample is not None and do_sample is False:
parameters["temperature"] = 0.0
if do_sample is None and parameters.get("temperature") is None:
parameters["temperature"] = 0.0
if "seed" in parameters.keys():
parameters["seed"] = int(parameters["seed"])
if "stop_sequences" in parameters.keys():
if "stop_sequences" in parameters:
parameters["stop"] = parameters.pop("stop_sequences")
if "ignore_eos_token" in parameters.keys():
if "ignore_eos_token" in parameters:
parameters["ignore_eos"] = parameters.pop("ignore_eos_token")
if "num_beams" in parameters.keys():
if "num_beams" in parameters:
parameters["best_of"] = parameters.pop("num_beams")
parameters["use_beam_search"] = True
if parameters.pop("decoder_input_details", False):
parameters["prompt_logprobs"] = 1
if "best_of" in parameters.keys():
if "best_of" in parameters:
# if n is not explicitly set, we return `best_of` values sequences.
if "n" not in "best_of":
parameters["n"] = parameters["best_of"]
if "top_n_tokens" in parameters.keys():
if "top_n_tokens" in parameters:
parameters["logprobs"] = parameters.pop("top_n_tokens")
else:
parameters["logprobs"] = parameters.get("logprobs", 1)
Expand Down
25 changes: 12 additions & 13 deletions engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from collections import OrderedDict, defaultdict

from vllm import LLMEngine, SamplingParams
Expand Down Expand Up @@ -85,31 +86,29 @@ def translate_vllm_params(self, parameters: dict) -> dict:
:return: The same parameters dict, but with VLLM style parameter names.
"""
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
if "seed" in parameters.keys():
do_sample = parameters.pop("do_sample", None)
if do_sample is not None and do_sample is False:
parameters["temperature"] = 0.0
if do_sample is None and parameters.get("temperature") is None:
parameters["temperature"] = 0.0
if "seed" in parameters:
parameters["seed"] = int(parameters["seed"])

# If `do_sample` is not provided, force temperature=0.0, i.e. greedy
# else set to user-provided value or default to 1.0
if not parameters.pop('do_sample', False):
parameters['temperature'] = 0.0
else:
parameters['temperature'] = parameters.get('temperature', 1.0)
if "stop_sequences" in parameters.keys():
if "stop_sequences" in parameters:
parameters["stop"] = parameters.pop("stop_sequences")
if "ignore_eos_token" in parameters.keys():
if "ignore_eos_token" in parameters:
parameters["ignore_eos"] = parameters.pop("ignore_eos_token")
if "num_beams" in parameters.keys():
if "num_beams" in parameters:
parameters["best_of"] = parameters.pop("num_beams")
parameters["use_beam_search"] = True
if parameters.pop("decoder_input_details", False):
parameters["prompt_logprobs"] = 1

# if n is not explicitly set when best_of is set, we return `best_of` values sequences for tgi compatibility.
if "best_of" in parameters.keys():
if "best_of" in parameters:
if "n" not in "best_of":
parameters["n"] = parameters["best_of"]

if "top_n_tokens" in parameters.keys():
if "top_n_tokens" in parameters:
parameters["logprobs"] = parameters.pop("top_n_tokens")
else:
parameters["logprobs"] = parameters.get("logprobs", 1)
Expand Down
5 changes: 5 additions & 0 deletions serving/docs/lmi/user_guides/lmi_input_output_schema.md
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,11 @@ If you are not specifying a specific engine or rolling batch implementation, we

If you are deploying with a specific backend, additional parameters are available that are unique to the specific backend.

**Note:**
To enable sampling in LMI <= 0.31.0, you must specify `do_sample: true` in addition to any sampling parameters you set.
This behavior will change starting in LMI 0.32.0, where you will no longer be required to set `do_sample`;
it will instead be inferred from the other sampling parameters.

#### Additional LMI Dist Generation parameters

```
Expand Down

0 comments on commit 69bbe22

Please sign in to comment.