feat: Added support for streaming in chat as a run_manager (#9)
MateuszOssGit authored Aug 8, 2024
1 parent 951b6b8 commit e9a7c79
Showing 6 changed files with 187 additions and 75 deletions.
11 changes: 2 additions & 9 deletions libs/ibm/langchain_ibm/chat_models.py
@@ -669,18 +669,13 @@ def _stream(
             message_chunk = _convert_delta_to_message_chunk(choice, default_chunk_class)
             generation_info = {}
-            if finish_reason := choice.get("stop_reason"):
+            if (finish_reason := choice.get("stop_reason")) != "not_finished":
                 generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
             chunk = ChatGenerationChunk(
                 message=message_chunk, generation_info=generation_info or None
             )
             if run_manager:
-                run_manager.on_llm_new_token(
-                    chunk.content, chunk=chunk, logprobs=logprobs
-                )
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

             yield chunk

@@ -751,8 +746,6 @@ def _create_chat_result(self, response: Union[dict]) -> ChatResult:
         for res in response["results"]:
             message = _convert_dict_to_message(res, call_id)
             generation_info = dict(finish_reason=res.get("stop_reason"))
-            if "logprobs" in res:
-                generation_info["logprobs"] = res["logprobs"]
             if "generated_token_count" in res:
                 sum_of_total_generated_tokens += res["generated_token_count"]
             if "input_token_count" in res:
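With these changes, _stream reports each chunk to the callback manager via run_manager.on_llm_new_token(chunk.text, chunk=chunk), and a finish reason is only surfaced once watsonx stops reporting "not_finished". Below is a minimal sketch of how a caller might observe the streamed tokens through a callback handler; the model ID, URL, project ID, and prompt are placeholders rather than part of this commit, and credentials are assumed to come from the usual WATSONX_APIKEY environment variable.

from langchain_core.callbacks import BaseCallbackHandler
from langchain_ibm import ChatWatsonx


class PrintTokenHandler(BaseCallbackHandler):
    """Receives the tokens emitted via run_manager.on_llm_new_token."""

    def on_llm_new_token(self, token, **kwargs):
        # After this commit, token is chunk.text and kwargs["chunk"] is the
        # ChatGenerationChunk itself.
        print(token, end="", flush=True)


# Placeholder watsonx.ai settings -- substitute real values.
chat = ChatWatsonx(
    model_id="ibm/granite-13b-chat-v2",
    url="https://us-south.ml.cloud.ibm.com",
    project_id="YOUR_PROJECT_ID",
    callbacks=[PrintTokenHandler()],
)

for _ in chat.stream("Write one sentence about token streaming."):
    # The handler above prints each token as it arrives; intermediate events
    # carry stop_reason == "not_finished" upstream, so no finish_reason is
    # attached until the stream actually ends.
    pass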
7 changes: 6 additions & 1 deletion libs/ibm/langchain_ibm/llms.py
@@ -355,10 +355,15 @@ def _stream_response_to_generation_chunk(
"""Convert a stream response to a generation chunk."""
if not stream_response["results"]:
return GenerationChunk(text="")

finish_reason = stream_response["results"][0].get("stop_reason", None)

return GenerationChunk(
text=stream_response["results"][0]["generated_text"],
generation_info=dict(
finish_reason=stream_response["results"][0].get("stop_reason", None),
finish_reason=None
if finish_reason == "not_finished"
else finish_reason,
llm_output={
"model_id": self.model_id,
"deployment_id": self.deployment_id,
Expand Down
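The llms.py change applies the same normalization at the WatsonxLLM level: intermediate stream events carry stop_reason == "not_finished", which is now mapped to None instead of being reported as a finish reason. A small standalone sketch of that rule follows, using hypothetical event dicts shaped like entries of a watsonx "results" list; the values below are illustrative, not taken from this commit.

from typing import Optional


def normalize_finish_reason(stop_reason: Optional[str]) -> Optional[str]:
    # Mirrors the diff above: "not_finished" marks an intermediate event and
    # should not be surfaced as a finish reason.
    return None if stop_reason == "not_finished" else stop_reason


# Hypothetical events shaped like entries of a watsonx "results" list.
events = [
    {"generated_text": "Hello", "stop_reason": "not_finished"},
    {"generated_text": ", world", "stop_reason": "not_finished"},
    {"generated_text": "!", "stop_reason": "eos_token"},
]

for event in events:
    print(repr(event["generated_text"]), normalize_finish_reason(event.get("stop_reason")))
# Prints None for the first two events and 'eos_token' only for the last one.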
