community: callback before yield for _stream/_astream
mackong committed Feb 22, 2024
1 parent 919b8a3 commit b340e25
Showing 22 changed files with 66 additions and 55 deletions.
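Every touched `_stream`/`_astream` applies the same change: the chunk is built, `on_llm_new_token` is invoked on the run manager, and only then is the chunk yielded, instead of yielding before the callback as before. Below is a minimal sketch of the new ordering as a standalone generator, not the code of any one file; the helper name `stream_chunks` is made up for illustration, while `stream_resp`, `data.completion`, and the run-manager call mirror the anthropic.py hunk that follows.

from typing import Any, Iterator, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.messages import AIMessageChunk
from langchain_core.outputs import ChatGenerationChunk


def stream_chunks(
    stream_resp: Any,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Iterator[ChatGenerationChunk]:
    """Sketch of the post-commit ordering: callback first, then yield."""
    for data in stream_resp:
        delta = data.completion
        chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
        if run_manager:
            # The token callback now fires before the chunk reaches the consumer,
            # so handlers still see the token even if the caller stops iterating
            # right after receiving a chunk.
            run_manager.on_llm_new_token(delta, chunk=chunk)
        yield chunk

The commit message does not state the motivation, but the presumable effect is that callback handlers observe each token before downstream consumers process it, and the callback for a chunk still fires when a consumer breaks out of the stream immediately after receiving it.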
4 changes: 2 additions & 2 deletions libs/community/langchain_community/chat_models/anthropic.py
@@ -143,9 +143,9 @@ def _stream(
         for data in stream_resp:
             delta = data.completion
             chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-            yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk
 
     async def _astream(
         self,
@@ -163,9 +163,9 @@ async def _astream(
         async for data in stream_resp:
             delta = data.completion
             chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-            yield chunk
             if run_manager:
                 await run_manager.on_llm_new_token(delta, chunk=chunk)
+            yield chunk
 
     def _generate(
         self,
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/baichuan.py
@@ -219,9 +219,9 @@ def _stream(
             )
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk
 
     def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
         parameters = {**self._default_params, **kwargs}
@@ -346,9 +346,9 @@ def _stream(
                 ),
                 generation_info=msg.additional_kwargs,
             )
-            yield chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk
 
     async def _astream(
         self,
@@ -372,6 +372,6 @@ async def _astream(
                 ),
                 generation_info=msg.additional_kwargs,
             )
-            yield chunk
             if run_manager:
                 await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk
4 changes: 2 additions & 2 deletions libs/community/langchain_community/chat_models/cohere.py
@@ -148,9 +148,9 @@ def _stream(
             if data.event_type == "text-generation":
                 delta = data.text
                 chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-                yield chunk
                 if run_manager:
                     run_manager.on_llm_new_token(delta, chunk=chunk)
+                yield chunk
 
     async def _astream(
         self,
@@ -166,9 +166,9 @@ async def _astream(
             if data.event_type == "text-generation":
                 delta = data.text
                 chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
-                yield chunk
                 if run_manager:
                     await run_manager.on_llm_new_token(delta, chunk=chunk)
+                yield chunk
 
     def _get_generation_info(self, response: Any) -> Dict[str, Any]:
         """Get the generation info from cohere API response."""
4 changes: 2 additions & 2 deletions libs/community/langchain_community/chat_models/deepinfra.py
@@ -329,9 +329,9 @@ def _stream(
             chunk = _handle_sse_line(line)
             if chunk:
                 cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
-                yield cg_chunk
                 if run_manager:
                     run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
+                yield cg_chunk
 
     async def _astream(
         self,
@@ -352,11 +352,11 @@ async def _astream(
             chunk = _handle_sse_line(line)
             if chunk:
                 cg_chunk = ChatGenerationChunk(message=chunk, generation_info=None)
-                yield cg_chunk
                 if run_manager:
                     await run_manager.on_llm_new_token(
                         str(chunk.content), chunk=cg_chunk
                     )
+                yield cg_chunk
 
     async def _agenerate(
         self,
14 changes: 6 additions & 8 deletions libs/community/langchain_community/chat_models/edenai.py
@@ -206,12 +206,10 @@ def _stream(
         for chunk_response in response.iter_lines():
             chunk = json.loads(chunk_response.decode())
             token = chunk["text"]
-            chat_generatio_chunk = ChatGenerationChunk(
-                message=AIMessageChunk(content=token)
-            )
-            yield chat_generatio_chunk
+            cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
             if run_manager:
-                run_manager.on_llm_new_token(token, chunk=chat_generatio_chunk)
+                run_manager.on_llm_new_token(token, chunk=cg_chunk)
+            yield cg_chunk
 
     async def _astream(
         self,
@@ -246,14 +244,14 @@ async def _astream(
             async for chunk_response in response.content:
                 chunk = json.loads(chunk_response.decode())
                 token = chunk["text"]
-                chat_generation_chunk = ChatGenerationChunk(
+                cg_chunk = ChatGenerationChunk(
                     message=AIMessageChunk(content=token)
                 )
-                yield chat_generation_chunk
                 if run_manager:
                     await run_manager.on_llm_new_token(
-                        token=chunk["text"], chunk=chat_generation_chunk
+                        token=chunk["text"], chunk=cg_chunk
                     )
+                yield cg_chunk
 
     def _generate(
         self,
16 changes: 10 additions & 6 deletions libs/community/langchain_community/chat_models/fireworks.py
@@ -219,10 +219,12 @@ def _stream(
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-            yield chunk
+            cg_chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info
+            )
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
+            yield cg_chunk
 
     async def _astream(
         self,
@@ -250,10 +252,12 @@ async def _astream(
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-            yield chunk
+            cg_chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info
+            )
             if run_manager:
-                await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
+                await run_manager.on_llm_new_token(token=chunk.text, chunk=cg_chunk)
+            yield cg_chunk
 
 
 def conditional_decorator(
4 changes: 2 additions & 2 deletions libs/community/langchain_community/chat_models/gigachat.py
@@ -155,9 +155,9 @@ def _stream(
             if chunk.choices:
                 content = chunk.choices[0].delta.content
                 cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
-                yield cg_chunk
                 if run_manager:
                     run_manager.on_llm_new_token(content, chunk=cg_chunk)
+                yield cg_chunk
 
     async def _astream(
         self,
@@ -172,9 +172,9 @@ async def _astream(
             if chunk.choices:
                 content = chunk.choices[0].delta.content
                 cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
-                yield cg_chunk
                 if run_manager:
                     await run_manager.on_llm_new_token(content, chunk=cg_chunk)
+                yield cg_chunk
 
     def get_num_tokens(self, text: str) -> int:
         """Count approximate number of tokens"""
8 changes: 4 additions & 4 deletions libs/community/langchain_community/chat_models/gpt_router.py
@@ -325,13 +325,13 @@ def _stream(
                 chunk.data, default_chunk_class
             )
 
-            yield chunk
-
             if run_manager:
                 run_manager.on_llm_new_token(
                     token=chunk.message.content, chunk=chunk.message
                 )
 
+            yield chunk
+
     async def _astream(
         self,
         messages: List[BaseMessage],
@@ -358,13 +358,13 @@ async def _astream(
                 chunk.data, default_chunk_class
             )
 
-            yield chunk
-
             if run_manager:
                 await run_manager.on_llm_new_token(
                     token=chunk.message.content, chunk=chunk.message
                 )
 
+            yield chunk
+
     def _create_message_dicts(
         self, messages: List[BaseMessage], stop: Optional[List[str]]
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/hunyuan.py
@@ -276,9 +276,9 @@ def _stream(
             )
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk
 
     def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
         if self.hunyuan_secret_key is None:
4 changes: 2 additions & 2 deletions libs/community/langchain_community/chat_models/jinachat.py
@@ -313,9 +313,9 @@ def _stream(
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk
 
     def _generate(
         self,
@@ -373,9 +373,9 @@ async def _astream(
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk
 
     async def _agenerate(
         self,
8 changes: 5 additions & 3 deletions libs/community/langchain_community/chat_models/konko.py
@@ -217,10 +217,12 @@ def _stream(
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-            yield chunk
+            cg_chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info
+            )
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
+            yield cg_chunk
 
     def _generate(
         self,
4 changes: 2 additions & 2 deletions libs/community/langchain_community/chat_models/litellm.py
@@ -356,9 +356,9 @@ def _stream(
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk
 
     async def _astream(
         self,
@@ -380,9 +380,9 @@ async def _astream(
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+            yield cg_chunk
 
     async def _agenerate(
         self,
@@ -124,9 +124,9 @@ def _stream(
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk, **params)
+            yield cg_chunk
 
     async def _astream(
         self,
@@ -150,11 +150,11 @@ async def _astream(
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             default_chunk_class = chunk.__class__
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 await run_manager.on_llm_new_token(
                     chunk.content, chunk=cg_chunk, **params
                 )
+            yield cg_chunk
 
     async def _agenerate(
         self,
6 changes: 3 additions & 3 deletions libs/community/langchain_community/chat_models/llama_edge.py
@@ -188,12 +188,12 @@ def _stream(
                 else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
+            cg_chunk = ChatGenerationChunk(
                 message=chunk, generation_info=generation_info
             )
-            yield chunk
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
+            yield cg_chunk
 
     def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> requests.Response:
         if self.service_url is None:
3 changes: 3 additions & 0 deletions libs/community/langchain_community/chat_models/ollama.py
@@ -318,6 +318,7 @@ def _stream(
                 if run_manager:
                     run_manager.on_llm_new_token(
                         chunk.text,
+                        chunk=chunk,
                         verbose=self.verbose,
                     )
                 yield chunk
@@ -337,6 +338,7 @@ async def _astream(
                 if run_manager:
                     await run_manager.on_llm_new_token(
                         chunk.text,
+                        chunk=chunk,
                         verbose=self.verbose,
                     )
                 yield chunk
@@ -356,6 +358,7 @@ def _legacy_stream(
                 if run_manager:
                     run_manager.on_llm_new_token(
                         chunk.text,
+                        chunk=chunk,
                         verbose=self.verbose,
                     )
                 yield chunk
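In ollama.py the callback already ran before the yield; the three hunks above only start forwarding chunk=chunk so that handlers receive the structured ChatGenerationChunk rather than just the token text. A minimal sketch of a handler that consumes it follows; the PrintChunkHandler class and its print logic are illustrative, not part of this commit.

from typing import Any, Optional, Union

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk


class PrintChunkHandler(BaseCallbackHandler):
    """Illustrative handler: logs each streamed token and its chunk payload."""

    def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        **kwargs: Any,
    ) -> None:
        # With chunk= forwarded by the chat model, a handler can inspect the
        # chunk's message and generation_info instead of only the raw token.
        print(f"token={token!r} chunk_type={type(chunk).__name__}")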
16 changes: 10 additions & 6 deletions libs/community/langchain_community/chat_models/openai.py
@@ -411,10 +411,12 @@ def _stream(
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-            yield chunk
+            cg_chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info
+            )
             if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+                run_manager.on_llm_new_token(chunk.text, chunk=cg_chunk)
+            yield cg_chunk
 
     def _generate(
         self,
@@ -501,10 +503,12 @@ async def _astream(
                 dict(finish_reason=finish_reason) if finish_reason is not None else None
             )
             default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
-            yield chunk
+            cg_chunk = ChatGenerationChunk(
+                message=chunk, generation_info=generation_info
+            )
             if run_manager:
-                await run_manager.on_llm_new_token(token=chunk.text, chunk=chunk)
+                await run_manager.on_llm_new_token(token=chunk.text, chunk=cg_chunk)
+            yield cg_chunk
 
     async def _agenerate(
         self,
2 changes: 1 addition & 1 deletion libs/community/langchain_community/chat_models/sparkllm.py
@@ -237,9 +237,9 @@ def _stream(
             delta = content["data"]
             chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
             cg_chunk = ChatGenerationChunk(message=chunk)
-            yield cg_chunk
             if run_manager:
                 run_manager.on_llm_new_token(str(chunk.content), chunk=cg_chunk)
+            yield cg_chunk
 
     def _generate(
         self,