diff --git a/optimum_benchmark/backends/vllm/backend.py b/optimum_benchmark/backends/vllm/backend.py
index e90f3e7e..eadd6c0a 100644
--- a/optimum_benchmark/backends/vllm/backend.py
+++ b/optimum_benchmark/backends/vllm/backend.py
@@ -117,35 +117,44 @@ def batch_offline_engine_generate(self, inputs: Dict[str, Any], kwargs: Dict[str
             self.pretrained_model.add_request(
                 inputs=prompt,
                 request_id=str(i),
-                params=SamplingParams(
-                    ignore_eos=True,
-                    detokenize=True,
-                    seed=self.config.seed,
-                    n=kwargs.get("num_return_sequences"),
-                    max_tokens=kwargs.get("max_new_tokens"),
-                    min_tokens=kwargs.get("min_new_tokens"),
-                    use_beam_search=kwargs.get("num_beams") > 1,
-                    logits_processors=kwargs.get("logits_processors", None),
-                ),
+                params=self.get_sampling_params(kwargs),
             )
 
         while self.pretrained_model.has_unfinished_requests():
             self.pretrained_model.step()
 
+    def get_sampling_params(self, kwargs: Dict[str, Any]) -> SamplingParams:
+        """Build the ``SamplingParams`` used by both the offline and online generate paths.
+
+        Args:
+            kwargs: Generation kwargs; reads ``num_return_sequences``, ``max_new_tokens``,
+                ``min_new_tokens``, ``logits_processors`` and ``num_beams``.
+
+        Returns:
+            The sampling params, with ``logprobs`` set to ``2 * num_beams`` when ``num_beams > 1``.
+        """
+        # a missing or None `num_beams` means no beam search (avoids a `None > 1` TypeError)
+        num_beams = kwargs.get("num_beams") or 1
+        params = SamplingParams(
+            ignore_eos=True,
+            detokenize=True,
+            seed=self.config.seed,
+            n=kwargs.get("num_return_sequences"),
+            max_tokens=kwargs.get("max_new_tokens"),
+            min_tokens=kwargs.get("min_new_tokens"),
+            logits_processors=kwargs.get("logits_processors", None),
+        )
+        # following huggingface transformers implementation
+        # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/generation/beam_search.py#L534
+        if num_beams > 1:
+            params.logprobs = 2 * num_beams
+        return params
+
     async def single_online_engine_generate(self, prompt: str, request_id: str, kwargs: Dict[str, Any]) -> Any:
         stream = await self.pretrained_model.add_request(
             inputs=prompt,
             request_id=request_id,
-            params=SamplingParams(
-                ignore_eos=True,
-                detokenize=True,
-                seed=self.config.seed,
-                n=kwargs.get("num_return_sequences"),
-                max_tokens=kwargs.get("max_new_tokens"),
-                min_tokens=kwargs.get("min_new_tokens"),
-                use_beam_search=kwargs.get("num_beams") > 1,
-                logits_processors=kwargs.get("logits_processors", None),
-            ),
+            params=self.get_sampling_params(kwargs),
         )
 
         async for _ in stream: