Skip to content

Commit

Permalink
more linting
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 committed Nov 11, 2024
1 parent da06c1a commit 50bd414
Showing 1 changed file with 81 additions and 76 deletions.
157 changes: 81 additions & 76 deletions libs/community/langchain_community/callbacks/panel_callback.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
"""The langchain module integrates Langchain support with Panel."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Union
from typing import TYPE_CHECKING, Any, Optional, Union

Check failure on line 3 in libs/community/langchain_community/callbacks/panel_callback.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (F401)

langchain_community/callbacks/panel_callback.py:3:50: F401 `typing.Union` imported but unused

Check failure on line 3 in libs/community/langchain_community/callbacks/panel_callback.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (F401)

langchain_community/callbacks/panel_callback.py:3:50: F401 `typing.Union` imported but unused
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import AgentAction, AgentFinish, LLMResult, Document
from langchain_core.utils import guard_import

if TYPE_CHECKING:

Check failure on line 10 in libs/community/langchain_community/callbacks/panel_callback.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (I001)

langchain_community/callbacks/panel_callback.py:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 10 in libs/community/langchain_community/callbacks/panel_callback.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (I001)

langchain_community/callbacks/panel_callback.py:1:1: I001 Import block is un-sorted or un-formatted
import panel
from panel.chat.feed import ChatFeed
from panel.chat.interface import ChatInterface
from panel.widgets import ChatMessage


def import_panel() -> panel:
Expand All @@ -20,38 +20,13 @@ def import_panel() -> panel:


class PanelCallbackHandler(BaseCallbackHandler):
"""
The Langchain `PanelCallbackHandler` itself is not a widget or pane,
but is useful for rendering and streaming the *chain of thought* from
Langchain Tools, Agents, and Chains as `ChatMessage` objects.
:Example:
>>> chat_interface = pn.widgets.ChatInterface(
callback=callback, callback_user="Langchain"
)
>>> callback_handler = pn.widgets.langchain.PanelCallbackHandler(
instance=chat_interface
)
>>> llm = ChatOpenAI(streaming=True, callbacks=[callback_handler])
>>> chain = ConversationChain(llm=llm)
Args:
instance (ChatFeed | ChatInterface): The ChatFeed or ChatInterface
instance to stream messages to.
user (str, optional): The user to display in the chat feed.
Defaults to "LangChain".
avatar (str, optional): The avatar to display in the chat feed.
Defaults to DEFAULT_AVATARS["langchain"].
"""

def __init__(
self,
instance: "ChatFeed" | "ChatInterface",
instance: ChatFeed | ChatInterface,
user: str = "LangChain",
avatar: str | None = None,
step_width: int = 500,
):
) -> None:
if BaseCallbackHandler is object:
raise ImportError(
"LangChainCallbackHandler requires `langchain` to be installed."
Expand All @@ -62,35 +37,31 @@ def __init__(
avatar = avatar or self._default_avatars["langchain"]

self.instance = instance
self._step = None
self._message = None
self._step: Any = None # Can be None or panel.widgets.ChatMessage
self._message: Optional[ChatMessage] = None
self._active_user = user
self._active_avatar = avatar
self._disabled_state = self.instance.disabled
self._is_streaming = None
self._is_streaming: Optional[bool] = None
self._step_width = step_width

self._input_user = user # original user
self._input_user = user
self._input_avatar = avatar

def _update_active(self, avatar: str, label: str):
"""
Prevent duplicate labels from being appended to the same user.
"""
# not a typo; Langchain passes a string :/
def _update_active(self, avatar: str, label: str) -> None:
if label == "None":
return

self._active_avatar = avatar
if f"- {label}" not in self._active_user:
self._active_user = f"{self._active_user} - {label}"

def _reset_active(self):
def _reset_active(self) -> None:
self._active_user = self._input_user
self._active_avatar = self._input_avatar
self._message = None

def _on_start(self, serialized, kwargs):
def _on_start(self, serialized: dict[str, Any], kwargs: dict[str, Any]) -> None:
model = kwargs.get("invocation_params", {}).get("model_name", "")
self._is_streaming = serialized.get("kwargs", {}).get("streaming")
messages = self.instance.objects
Expand All @@ -99,7 +70,7 @@ def _on_start(self, serialized, kwargs):
if self._active_user and model not in self._active_user:
self._active_user = f"{self._active_user} ({model})"

def _stream(self, message: str):
def _stream(self, message: str) -> Optional[ChatMessage]:
if message:
return self.instance.stream(
message,
Expand All @@ -109,34 +80,51 @@ def _stream(self, message: str):
)
return self._message

def on_llm_start(self, serialized: dict[str, Any], *args, **kwargs):
def on_llm_start(
self,
serialized: dict[str, Any],
*args: Any,
**kwargs: Any,
) -> Any:
self._on_start(serialized, kwargs)
return super().on_llm_start(serialized, *args, **kwargs)

def on_llm_new_token(self, token: str, **kwargs) -> None:
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self._message = self._stream(token)
return super().on_llm_new_token(token, **kwargs)

def on_llm_end(self, response: LLMResult, *args, **kwargs):
def on_llm_end(self, response: LLMResult, *args: Any, **kwargs: Any) -> Any:
if not self._is_streaming:
# on_llm_new_token does not get called if not streaming
self._stream(response.generations[0][0].text)

self._reset_active()
return super().on_llm_end(response, *args, **kwargs)

def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], *args, **kwargs):
return super().on_llm_error(error, *args, **kwargs)
def on_llm_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
return super().on_llm_error(
error, run_id=run_id, parent_run_id=parent_run_id, **kwargs
)

def on_agent_action(self, action: AgentAction, *args, **kwargs: Any) -> Any:
def on_agent_action(self, action: AgentAction, *args: Any, **kwargs: Any) -> Any:
return super().on_agent_action(action, *args, **kwargs)

def on_agent_finish(self, finish: AgentFinish, *args, **kwargs: Any) -> Any:
def on_agent_finish(self, finish: AgentFinish, *args: Any, **kwargs: Any) -> Any:
return super().on_agent_finish(finish, *args, **kwargs)

def on_tool_start(
self, serialized: dict[str, Any], input_str: str, *args, **kwargs
):
self,
serialized: dict[str, Any],
input_str: str,
*args: Any,
**kwargs: Any,
) -> Any:
self._update_active(self._default_avatars["tool"], serialized["name"])
self._step = self.instance.add_step(
title=f"Tool input: {input_str}",
Expand All @@ -148,37 +136,56 @@ def on_tool_start(
)
return super().on_tool_start(serialized, input_str, *args, **kwargs)

def on_tool_end(self, output: str, *args, **kwargs):
self._step.stream(output)
self._step.status = "success"
def on_tool_end(self, output: str, *args: Any, **kwargs: Any) -> Any:
if self._step is not None:
self._step.stream(output)
self._step.status = "success"
self._reset_active()
self._step = None
return super().on_tool_end(output, *args, **kwargs)

def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], *args, **kwargs
):
return super().on_tool_error(error, *args, **kwargs)
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
return super().on_tool_error(
error, run_id=run_id, parent_run_id=parent_run_id, **kwargs
)

def on_chain_start(
self, serialized: dict[str, Any], inputs: dict[str, Any], *args, **kwargs
):
self,
serialized: dict[str, Any],
inputs: dict[str, Any],
*args: Any,
**kwargs: Any,
) -> Any:
self._disabled_state = self.instance.disabled
self.instance.disabled = True
return super().on_chain_start(serialized, inputs, *args, **kwargs)

def on_chain_end(self, outputs: dict[str, Any], *args, **kwargs):
def on_chain_end(
self, outputs: dict[str, Any], *args: Any, **kwargs: Any
) -> Any:
self.instance.disabled = self._disabled_state
return super().on_chain_end(outputs, *args, **kwargs)

def on_retriever_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever errors."""
return super().on_retriever_error(error, **kwargs)
return super().on_retriever_error(
error, run_id=run_id, parent_run_id=parent_run_id, **kwargs
)

def on_retriever_end(self, documents, **kwargs: Any) -> Any:
"""Run when Retriever ends running."""
def on_retriever_end(self, documents: list[Document], **kwargs: Any) -> Any:
objects = [
(f"Document {index}", document.page_content)
for index, document in enumerate(documents)
Expand All @@ -194,15 +201,13 @@ def on_retriever_end(self, documents, **kwargs: Any) -> Any:
)
return super().on_retriever_end(documents=documents, **kwargs)

def on_text(self, text: str, **kwargs: Any):
"""Run when text is received."""
return super().on_text(text, **kwargs)
def on_text(self, text: str, **kwargs: Any) -> None:
super().on_text(text, **kwargs)

def on_chat_model_start(
self, serialized: dict[str, Any], messages: list, **kwargs: Any
self,
serialized: dict[str, Any],
messages: list[Any],
**kwargs: Any,
) -> None:
"""
To prevent the inherited class from raising
NotImplementedError, will not call super() here.
"""
self._on_start(serialized, kwargs)
self._on_start(serialized, kwargs)

0 comments on commit 50bd414

Please sign in to comment.