Skip to content

Commit

Permalink
community[patch]: fix model initialization bug for deepinfra (#25727)
Browse files Browse the repository at this point in the history
### Description
adds an init method to ChatDeepInfra to set the model_name attribute
accordings to the argument
### Issue
currently, the model_name specified by the user during initialization of
the ChatDeepInfra class is never set. Therefore, it always chooses the
default model (meta-llama/Llama-2-70b-chat-hf, however probably since
this is deprecated it always uses meta-llama/Llama-3-70b-Instruct). We
stumbled across this issue and fixed it as proposed in this pull
request. Feel free to change the fix according to your coding guidelines
and style, this is just a proposal and we want to draw attention to this
problem.
### Dependencies
no additional dependencies required

Feel free to contact me or @timo282 and @finitearth if you have any
questions.

---------

Co-authored-by: Bagatur <[email protected]>
Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
3 people authored Aug 28, 2024
1 parent a052173 commit 555f97b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
5 changes: 5 additions & 0 deletions libs/community/langchain_community/chat_models/deepinfra.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ class ChatDeepInfra(BaseChatModel):
streaming: bool = False
max_retries: int = 1

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True

@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
Expand Down
11 changes: 11 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_deepinfra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from langchain_community.chat_models import ChatDeepInfra


def test_deepinfra_model_name_param() -> None:
llm = ChatDeepInfra(model_name="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"


def test_deepinfra_model_param() -> None:
llm = ChatDeepInfra(model="foo")
assert llm.model_name == "foo"

0 comments on commit 555f97b

Please sign in to comment.