From 555f97becbf69f283040af67d20636d1be7e93d6 Mon Sep 17 00:00:00 2001 From: Moritz Schlager <87517800+mo374z@users.noreply.github.com> Date: Wed, 28 Aug 2024 11:02:35 +0200 Subject: [PATCH] community[patch]: fix model initialization bug for deepinfra (#25727) ### Description adds an init method to ChatDeepInfra to set the model_name attribute accordings to the argument ### Issue currently, the model_name specified by the user during initialization of the ChatDeepInfra class is never set. Therefore, it always chooses the default model (meta-llama/Llama-2-70b-chat-hf, however probably since this is deprecated it always uses meta-llama/Llama-3-70b-Instruct). We stumbled across this issue and fixed it as proposed in this pull request. Feel free to change the fix according to your coding guidelines and style, this is just a proposal and we want to draw attention to this problem. ### Dependencies no additional dependencies required Feel free to contact me or @timo282 and @finitearth if you have any questions. --------- Co-authored-by: Bagatur Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> --- .../langchain_community/chat_models/deepinfra.py | 5 +++++ .../tests/unit_tests/chat_models/test_deepinfra.py | 11 +++++++++++ 2 files changed, 16 insertions(+) create mode 100644 libs/community/tests/unit_tests/chat_models/test_deepinfra.py diff --git a/libs/community/langchain_community/chat_models/deepinfra.py b/libs/community/langchain_community/chat_models/deepinfra.py index 37fef6763b77e..3de532b0c4912 100644 --- a/libs/community/langchain_community/chat_models/deepinfra.py +++ b/libs/community/langchain_community/chat_models/deepinfra.py @@ -222,6 +222,11 @@ class ChatDeepInfra(BaseChatModel): streaming: bool = False max_retries: int = 1 + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling OpenAI API.""" diff --git a/libs/community/tests/unit_tests/chat_models/test_deepinfra.py b/libs/community/tests/unit_tests/chat_models/test_deepinfra.py new file mode 100644 index 0000000000000..48a1d08a2de60 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_deepinfra.py @@ -0,0 +1,11 @@ +from langchain_community.chat_models import ChatDeepInfra + + +def test_deepinfra_model_name_param() -> None: + llm = ChatDeepInfra(model_name="foo") # type: ignore[call-arg] + assert llm.model_name == "foo" + + +def test_deepinfra_model_param() -> None: + llm = ChatDeepInfra(model="foo") + assert llm.model_name == "foo"