Condense param filtering and move vllm import to local scope
enesbol committed Oct 15, 2024
1 parent f189cf1 commit 7f06674
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions libs/community/langchain_community/llms/vllm.py
@@ -5,7 +5,6 @@
 from langchain_core.outputs import Generation, LLMResult
 from langchain_core.utils import pre_init
 from pydantic import Field
-from vllm import SamplingParams
 
 from langchain_community.llms.openai import BaseOpenAI
 from langchain_community.utils.openai import is_openai_v1
@@ -124,21 +123,19 @@ def _generate(
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
+        from vllm import SamplingParams
 
         # build sampling parameters
         params = {**self._default_params, **kwargs, "stop": stop}
 
-        # Filter params for SamplingParams
-        sampling_param_keys = SamplingParams.__annotations__.keys()
-        sampling_params_dict = {
-            k: v for k, v in params.items() if k in sampling_param_keys
-        }
-
-        # Create SamplingParams instance for sampling
-        sampling_params = SamplingParams(**sampling_params_dict)
+        # filter params for SamplingParams
+        known_keys = SamplingParams.__annotations__.keys()
+        sample_params = SamplingParams(
+            **{k: v for k, v in params.items() if k in known_keys}
+        )
 
         # call the model
-        outputs = self.client.generate(prompts, sampling_params)
+        outputs = self.client.generate(prompts, sample_params)
 
         generations = []
         for output in outputs:
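Why the import moves: with `from vllm import SamplingParams` at module top level, merely importing `langchain_community.llms.vllm` fails when vllm is not installed; importing inside `_generate` defers that requirement to the first generation call. A minimal standalone sketch of the deferred-import pattern (hypothetical `LazyVLLM` class, not the actual langchain_community code):

from typing import Any, List


class LazyVLLM:
    """Hypothetical sketch: the heavy dependency is imported only at call time."""

    def generate_text(self, prompts: List[str], **kwargs: Any) -> List[str]:
        # Deferred import: the module containing this class imports fine
        # without vllm; this raises only when generation is attempted.
        try:
            from vllm import SamplingParams  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "Could not import vllm. Install it with `pip install vllm`."
            ) from exc
        # ...build SamplingParams(**kwargs) and call the engine here...
        return []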

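The condensed filtering relies on `SamplingParams.__annotations__.keys()` returning the field names the class declares, so any extra kwargs that `SamplingParams` would reject are dropped before construction. A self-contained illustration with a stand-in class (not vllm's real `SamplingParams`):

from typing import Any, Dict


class StubSamplingParams:
    """Stand-in with annotated fields, mimicking how the filter sees vllm's class."""

    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 16

    def __init__(self, **kwargs: Any) -> None:
        for key, value in kwargs.items():
            setattr(self, key, value)


params: Dict[str, Any] = {"temperature": 0.2, "max_tokens": 64, "tags": ["demo"]}
known_keys = StubSamplingParams.__annotations__.keys()
sample_params = StubSamplingParams(
    **{k: v for k, v in params.items() if k in known_keys}
)
print(sample_params.temperature, sample_params.max_tokens)  # 0.2 64 ("tags" dropped)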