Skip to content

Commit

Permalink
community[patch]: add missing chunk parameter for _stream/_astream (#…
Browse files Browse the repository at this point in the history
…17807)

- Description: Add missing chunk parameter for _stream/_astream for some
chat models, make all chat models in a consistent behaviour.
- Issue: N/A
- Dependencies: N/A
  • Loading branch information
mackong authored Feb 21, 2024
1 parent 1b0802b commit 3189109
Show file tree
Hide file tree
Showing 15 changed files with 71 additions and 42 deletions.
5 changes: 3 additions & 2 deletions libs/community/langchain_community/chat_models/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,10 @@ def _stream(
m.get("delta"), default_chunk_class
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
parameters = {**self._default_params, **kwargs}
Expand Down
10 changes: 6 additions & 4 deletions libs/community/langchain_community/chat_models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,10 @@ def _stream(
for data in stream:
if data.event_type == "text-generation":
delta = data.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
run_manager.on_llm_new_token(delta)
run_manager.on_llm_new_token(delta, chunk=chunk)

async def _astream(
self,
Expand All @@ -164,9 +165,10 @@ async def _astream(
async for data in stream:
if data.event_type == "text-generation":
delta = data.text
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
await run_manager.on_llm_new_token(delta)
await run_manager.on_llm_new_token(delta, chunk=chunk)

def _get_generation_info(self, response: Any) -> Dict[str, Any]:
"""Get the generation info from cohere API response."""
Expand Down
12 changes: 8 additions & 4 deletions libs/community/langchain_community/chat_models/deepinfra.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,10 @@ def _stream(
for line in _parse_stream(response.iter_lines()):
chunk = _handle_sse_line(line)
if chunk:
yield ChatGenerationChunk(message=chunk, generation_info=None)
cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(str(chunk.content))
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)

async def _astream(
self,
Expand All @@ -350,9 +351,12 @@ async def _astream(
async for line in _parse_stream_async(response.content):
chunk = _handle_sse_line(line)
if chunk:
yield ChatGenerationChunk(message=chunk, generation_info=None)
cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(str(chunk.content))
await run_manager.on_llm_new_token(
str(chunk.content), chunk=cg_chunk
)

async def _agenerate(
self,
Expand Down
10 changes: 6 additions & 4 deletions libs/community/langchain_community/chat_models/gigachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,10 @@ def _stream(
for chunk in self._client.stream(payload):
if chunk.choices:
content = chunk.choices[0].delta.content
yield ChatGenerationChunk(message=AIMessageChunk(content=content))
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(content)
run_manager.on_llm_new_token(content, chunk=cg_chunk)

async def _astream(
self,
Expand All @@ -170,9 +171,10 @@ async def _astream(
async for chunk in self._client.astream(payload):
if chunk.choices:
content = chunk.choices[0].delta.content
yield ChatGenerationChunk(message=AIMessageChunk(content=content))
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(content)
await run_manager.on_llm_new_token(content, chunk=cg_chunk)

def get_num_tokens(self, text: str) -> int:
"""Count approximate number of tokens"""
Expand Down
5 changes: 3 additions & 2 deletions libs/community/langchain_community/chat_models/hunyuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,10 @@ def _stream(
choice["delta"], default_chunk_class
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
if self.hunyuan_secret_key is None:
Expand Down
10 changes: 6 additions & 4 deletions libs/community/langchain_community/chat_models/jinachat.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,10 @@ def _stream(
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

def _generate(
self,
Expand Down Expand Up @@ -371,9 +372,10 @@ async def _astream(
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

async def _agenerate(
self,
Expand Down
10 changes: 6 additions & 4 deletions libs/community/langchain_community/chat_models/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,10 @@ def _stream(
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

async def _astream(
self,
Expand All @@ -378,9 +379,10 @@ async def _astream(
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

async def _agenerate(
self,
Expand Down
12 changes: 8 additions & 4 deletions libs/community/langchain_community/chat_models/litellm_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,10 @@ def _stream(
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content, **params)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk, **params)

async def _astream(
self,
Expand All @@ -148,9 +149,12 @@ async def _astream(
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content, **params)
await run_manager.on_llm_new_token(
chunk.content, chunk=cg_chunk, **params
)

async def _agenerate(
self,
Expand Down
2 changes: 2 additions & 0 deletions libs/community/langchain_community/chat_models/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def _chat_stream_with_aggregation(
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=verbose,
)
if final_chunk is None:
Expand All @@ -221,6 +222,7 @@ async def _achat_stream_with_aggregation(
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
chunk=chunk,
verbose=verbose,
)
if final_chunk is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,12 @@ async def _astream(

# yield text, if any
if text:
cg_chunk = ChatGenerationChunk(message=content)
if run_manager:
await run_manager.on_llm_new_token(cast(str, content.content))
yield ChatGenerationChunk(message=content)
await run_manager.on_llm_new_token(
cast(str, content.content), chunk=cg_chunk
)
yield cg_chunk

# break if stop sequence found
if stop_seq_found:
Expand Down
5 changes: 3 additions & 2 deletions libs/community/langchain_community/chat_models/sparkllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,10 @@ def _stream(
continue
delta = content["data"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
yield ChatGenerationChunk(message=chunk)
cg_chunk = ChatGenerationChunk(message=chunk)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(str(chunk.content))
run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)

def _generate(
self,
Expand Down
5 changes: 3 additions & 2 deletions libs/community/langchain_community/chat_models/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,10 @@ def _stream(
chat = self._start_chat(history, **params)
responses = chat.send_message_streaming(question.content, **params)
for response in responses:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=response.text))
if run_manager:
run_manager.on_llm_new_token(response.text)
yield ChatGenerationChunk(message=AIMessageChunk(content=response.text))
run_manager.on_llm_new_token(response.text, chunk=chunk)
yield chunk

def _start_chat(
self, history: _ChatHistory, **kwargs: Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,10 @@ def _stream(
for res in self.client.stream_chat(params):
if res:
msg = convert_dict_to_message(res)
yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
yield chunk
if run_manager:
run_manager.on_llm_new_token(cast(str, msg.content))
run_manager.on_llm_new_token(cast(str, msg.content), chunk=chunk)

def _generate(
self,
Expand Down
10 changes: 6 additions & 4 deletions libs/community/langchain_community/chat_models/yuan2.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,12 +269,13 @@ def _stream(
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(
cg_chunk = ChatGenerationChunk(
message=chunk,
generation_info=generation_info,
)
yield cg_chunk
if run_manager:
run_manager.on_llm_new_token(chunk.content)
run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

def _generate(
self,
Expand Down Expand Up @@ -351,12 +352,13 @@ async def _astream(
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(
cg_chunk = ChatGenerationChunk(
message=chunk,
generation_info=generation_info,
)
yield cg_chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)

async def _agenerate(
self,
Expand Down
5 changes: 3 additions & 2 deletions libs/community/langchain_community/chat_models/zhipuai.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,10 @@ def _stream( # type: ignore[override]
for r in response.events():
if r.event == "add":
delta = r.data
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
yield chunk
if run_manager:
run_manager.on_llm_new_token(delta)
run_manager.on_llm_new_token(delta, chunk=chunk)

elif r.event == "error":
raise ValueError(f"Error from ZhipuAI API response: {r.data}")

0 comments on commit 3189109

Please sign in to comment.