From 081deb8beb6989912d8107b196ee6ad69ca78f5a Mon Sep 17 00:00:00 2001
From: Baber Abbasi <92168766+baberabb@users.noreply.github.com>
Date: Wed, 24 Jan 2024 02:44:27 +0500
Subject: [PATCH] manage default (greedy) gen_kwargs in vllm (#1341)

* manage default (greedy) gen_kwargs in vllm better

* mirror HF `do_sample`

* just need to set temp=0 for greedy

---
 lm_eval/models/vllm_causallms.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index 6912428e..5c208d90 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -170,14 +170,8 @@ def _model_generate(
         stop: Optional[List[str]] = None,
         **kwargs,
     ):
-        if "do_sample" in kwargs.keys():
-            kwargs.pop("do_sample")
         if generate:
-            # hf defaults
-            kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
-            kwargs["spaces_between_special_tokens"] = kwargs.get(
-                "spaces_between_special_tokens", False
-            )
+            kwargs = self.modify_gen_kwargs(kwargs)
             sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
         else:
             sampling_params = SamplingParams(
@@ -438,3 +432,16 @@ def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
                     break
 
         return continuation_logprobs, is_greedy
+
+    @staticmethod
+    def modify_gen_kwargs(kwargs: dict) -> dict:
+        # sampling_params
+        do_sample = kwargs.pop("do_sample", False)
+        if do_sample is not True:
+            kwargs["temperature"] = 0.0
+        # hf defaults
+        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
+        kwargs["spaces_between_special_tokens"] = kwargs.get(
+            "spaces_between_special_tokens", False
+        )
+        return kwargs
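
Note for reviewers: before this patch, `do_sample` was silently dropped, so a
task requesting greedy decoding could still sample at vLLM's default
temperature. Below is a minimal, self-contained sketch of the behavior after
the patch; the helper body is copied from the hunk above, the usage lines and
their argument values are illustrative only, and it assumes vllm is installed:

    from vllm import SamplingParams

    def modify_gen_kwargs(kwargs: dict) -> dict:
        # Mirror HF semantics: a missing or falsy do_sample means greedy
        # decoding, which vLLM expresses as temperature=0.
        do_sample = kwargs.pop("do_sample", False)
        if do_sample is not True:
            kwargs["temperature"] = 0.0
        # Match HF detokenization defaults.
        kwargs["skip_special_tokens"] = kwargs.get("skip_special_tokens", False)
        kwargs["spaces_between_special_tokens"] = kwargs.get(
            "spaces_between_special_tokens", False
        )
        return kwargs

    # A task passing do_sample=False (or omitting it) now decodes greedily,
    # while explicit sampling arguments pass through untouched:
    greedy = SamplingParams(max_tokens=32, **modify_gen_kwargs({"do_sample": False}))
    sampled = SamplingParams(
        max_tokens=32, **modify_gen_kwargs({"do_sample": True, "temperature": 0.8})
    )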