Skip to content

Commit

Permalink
community: sambastudio chat model integration minor fix (#27238)
Browse files Browse the repository at this point in the history
**Description:** Minor fixes for the SambaStudio chat model integration:
 fix default parameter values (give optional fields explicit `default=None`/`default=True` instead of bare `Field()`)
 fix usage metadata emission when streaming (yield a final chunk carrying `usage` when the provider reports it)
  • Loading branch information
jhpiedrahitao authored Oct 15, 2024
1 parent fead474 commit 12fea5b
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions libs/community/langchain_community/chat_models/sambanova.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ class ChatSambaNovaCloud(BaseChatModel):
temperature: float = Field(default=0.7)
"""model temperature"""

top_p: Optional[float] = Field()
top_p: Optional[float] = Field(default=None)
"""model top p"""

top_k: Optional[int] = Field()
top_k: Optional[int] = Field(default=None)
"""model top k"""

stream_options: dict = Field(default={"include_usage": True})
Expand Down Expand Up @@ -593,7 +593,7 @@ class ChatSambaStudio(BaseChatModel):
streaming_url: str = Field(default="", exclude=True)
"""SambaStudio streaming Url"""

model: Optional[str] = Field()
model: Optional[str] = Field(default=None)
"""The name of the model or expert to use (for CoE endpoints)"""

streaming: bool = Field(default=False)
Expand All @@ -605,16 +605,16 @@ class ChatSambaStudio(BaseChatModel):
temperature: Optional[float] = Field(default=0.7)
"""model temperature"""

top_p: Optional[float] = Field()
top_p: Optional[float] = Field(default=None)
"""model top p"""

top_k: Optional[int] = Field()
top_k: Optional[int] = Field(default=None)
"""model top k"""

do_sample: Optional[bool] = Field()
do_sample: Optional[bool] = Field(default=None)
"""whether to do sampling"""

process_prompt: Optional[bool] = Field()
process_prompt: Optional[bool] = Field(default=True)
"""whether process prompt (for CoE generic v1 and v2 endpoints)"""

stream_options: dict = Field(default={"include_usage": True})
Expand Down Expand Up @@ -1012,6 +1012,16 @@ def _process_stream_response(
"system_fingerprint": data["system_fingerprint"],
"created": data["created"],
}
if data.get("usage") is not None:
content = ""
id = data["id"]
metadata = {
"finish_reason": finish_reason,
"usage": data.get("usage"),
"model_name": data["model"],
"system_fingerprint": data["system_fingerprint"],
"created": data["created"],
}
yield AIMessageChunk(
content=content,
id=id,
Expand Down

0 comments on commit 12fea5b

Please sign in to comment.