From 7f06674bc387f571aed05636a2b5aaebb9a5818a Mon Sep 17 00:00:00 2001
From: enesbol
Date: Tue, 15 Oct 2024 20:39:35 +0200
Subject: [PATCH] Condense param filtering and move vllm import to local scope

---
 libs/community/langchain_community/llms/vllm.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/libs/community/langchain_community/llms/vllm.py b/libs/community/langchain_community/llms/vllm.py
index ccaf629ec45e5..dc8a7a76d24ed 100644
--- a/libs/community/langchain_community/llms/vllm.py
+++ b/libs/community/langchain_community/llms/vllm.py
@@ -5,7 +5,6 @@
 from langchain_core.outputs import Generation, LLMResult
 from langchain_core.utils import pre_init
 from pydantic import Field
-from vllm import SamplingParams
 
 from langchain_community.llms.openai import BaseOpenAI
 from langchain_community.utils.openai import is_openai_v1
@@ -124,21 +123,19 @@ def _generate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
+        from vllm import SamplingParams
 
         # build sampling parameters
         params = {**self._default_params, **kwargs, "stop": stop}
 
-        # Filter params for SamplingParams
-        sampling_param_keys = SamplingParams.__annotations__.keys()
-        sampling_params_dict = {
-            k: v for k, v in params.items() if k in sampling_param_keys
-        }
-
-        # Create SamplingParams instance for sampling
-        sampling_params = SamplingParams(**sampling_params_dict)
+        # filter params for SamplingParams
+        known_keys = SamplingParams.__annotations__.keys()
+        sample_params = SamplingParams(
+            **{k: v for k, v in params.items() if k in known_keys}
+        )
 
         # call the model
-        outputs = self.client.generate(prompts, sampling_params)
+        outputs = self.client.generate(prompts, sample_params)
 
         generations = []
         for output in outputs:
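For reviewers who want to try the condensed filtering pattern outside of langchain: below is a minimal, self-contained sketch of the same technique. The `SamplingParams` dataclass here is a hypothetical stand-in so the snippet runs without vllm installed; the real `vllm.SamplingParams` declares many more fields, but the `__annotations__`-based filtering works the same way. `build_sampling_params` and `unknown_option` are illustrative names, not part of the patched module.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class SamplingParams:
    """Hypothetical stand-in for vllm.SamplingParams; the real class has
    many more fields, but is likewise filterable via __annotations__."""

    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 16
    stop: Optional[List[str]] = None


def build_sampling_params(defaults: Dict[str, Any], **kwargs: Any) -> SamplingParams:
    # Merge defaults with call-time kwargs, then keep only the keys that
    # SamplingParams actually declares -- the same condensed filter the
    # patch introduces in _generate().
    params = {**defaults, **kwargs}
    known_keys = SamplingParams.__annotations__.keys()
    return SamplingParams(**{k: v for k, v in params.items() if k in known_keys})


# The made-up key "unknown_option" is silently dropped instead of raising
# TypeError inside SamplingParams.__init__.
print(build_sampling_params({"temperature": 0.2}, max_tokens=64, unknown_option=True))
# -> SamplingParams(temperature=0.2, top_p=1.0, max_tokens=64, stop=None)
```

One caveat worth noting: a class's `__annotations__` only contains annotations declared directly on that class, not inherited ones, so this filter assumes `SamplingParams` declares all of its fields itself.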