From 12fea5b868edd12b0d576e7f8bfc922d0167eeab Mon Sep 17 00:00:00 2001 From: Jorge Piedrahita Ortiz Date: Tue, 15 Oct 2024 12:24:36 -0500 Subject: [PATCH] community: sambastudio chat model integration minor fix (#27238) **Description:** sambastudio chat model integration minor fix fix default params fix usage metadata when streaming --- .../chat_models/sambanova.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/libs/community/langchain_community/chat_models/sambanova.py b/libs/community/langchain_community/chat_models/sambanova.py index cd95f4beefa5f..62bc366486805 100644 --- a/libs/community/langchain_community/chat_models/sambanova.py +++ b/libs/community/langchain_community/chat_models/sambanova.py @@ -174,10 +174,10 @@ class ChatSambaNovaCloud(BaseChatModel): temperature: float = Field(default=0.7) """model temperature""" - top_p: Optional[float] = Field() + top_p: Optional[float] = Field(default=None) """model top p""" - top_k: Optional[int] = Field() + top_k: Optional[int] = Field(default=None) """model top k""" stream_options: dict = Field(default={"include_usage": True}) @@ -593,7 +593,7 @@ class ChatSambaStudio(BaseChatModel): streaming_url: str = Field(default="", exclude=True) """SambaStudio streaming Url""" - model: Optional[str] = Field() + model: Optional[str] = Field(default=None) """The name of the model or expert to use (for CoE endpoints)""" streaming: bool = Field(default=False) @@ -605,16 +605,16 @@ class ChatSambaStudio(BaseChatModel): temperature: Optional[float] = Field(default=0.7) """model temperature""" - top_p: Optional[float] = Field() + top_p: Optional[float] = Field(default=None) """model top p""" - top_k: Optional[int] = Field() + top_k: Optional[int] = Field(default=None) """model top k""" - do_sample: Optional[bool] = Field() + do_sample: Optional[bool] = Field(default=None) """whether to do sampling""" - process_prompt: Optional[bool] = Field() + process_prompt: Optional[bool] = Field(default=True) """whether process prompt (for CoE generic v1 and v2 endpoints)""" stream_options: dict = Field(default={"include_usage": True}) @@ -1012,6 +1012,16 @@ def _process_stream_response( "system_fingerprint": data["system_fingerprint"], "created": data["created"], } + if data.get("usage") is not None: + content = "" + id = data["id"] + metadata = { + "finish_reason": finish_reason, + "usage": data.get("usage"), + "model_name": data["model"], + "system_fingerprint": data["system_fingerprint"], + "created": data["created"], + } yield AIMessageChunk( content=content, id=id,