diff --git a/libs/community/langchain_community/callbacks/panel_callback.py b/libs/community/langchain_community/callbacks/panel_callback.py
index acf2a3afe3b3f..1c3de3df64d9a 100644
--- a/libs/community/langchain_community/callbacks/panel_callback.py
+++ b/libs/community/langchain_community/callbacks/panel_callback.py
@@ -1,17 +1,17 @@
-"""The langchain module integrates Langchain support with Panel."""
-
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any, Optional
+from uuid import UUID
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import AgentAction, AgentFinish, LLMResult, Document
 from langchain_core.utils import guard_import
 
 if TYPE_CHECKING:
     import panel
     from panel.chat.feed import ChatFeed
     from panel.chat.interface import ChatInterface
+    from panel.chat.message import ChatMessage
 
 
 def import_panel() -> panel:
@@ -20,38 +20,13 @@ def import_panel() -> panel:
 
 
 class PanelCallbackHandler(BaseCallbackHandler):
-    """
-    The Langchain `PanelCallbackHandler` itself is not a widget or pane,
-    but is useful for rendering and streaming the *chain of thought* from
-    Langchain Tools, Agents, and Chains as `ChatMessage` objects.
-
-    :Example:
-
-    >>> chat_interface = pn.widgets.ChatInterface(
-            callback=callback, callback_user="Langchain"
-        )
-    >>> callback_handler = pn.widgets.langchain.PanelCallbackHandler(
-            instance=chat_interface
-        )
-    >>> llm = ChatOpenAI(streaming=True, callbacks=[callback_handler])
-    >>> chain = ConversationChain(llm=llm)
-
-    Args:
-        instance (ChatFeed | ChatInterface): The ChatFeed or ChatInterface
-            instance to stream messages to.
-        user (str, optional): The user to display in the chat feed.
-            Defaults to "LangChain".
-        avatar (str, optional): The avatar to display in the chat feed.
-            Defaults to DEFAULT_AVATARS["langchain"].
-    """
-
     def __init__(
         self,
-        instance: "ChatFeed" | "ChatInterface",
+        instance: ChatFeed | ChatInterface,
         user: str = "LangChain",
         avatar: str | None = None,
         step_width: int = 500,
-    ):
+    ) -> None:
         if BaseCallbackHandler is object:
             raise ImportError(
                 "LangChainCallbackHandler requires `langchain` to be installed."
@@ -62,22 +37,18 @@ def __init__(
         avatar = avatar or self._default_avatars["langchain"]
 
         self.instance = instance
-        self._step = None
-        self._message = None
+        self._step: Any = None  # set by on_tool_start, cleared in on_tool_end
+        self._message: Optional[ChatMessage] = None
         self._active_user = user
         self._active_avatar = avatar
         self._disabled_state = self.instance.disabled
-        self._is_streaming = None
+        self._is_streaming: Optional[bool] = None
         self._step_width = step_width
 
-        self._input_user = user  # original user
+        self._input_user = user
         self._input_avatar = avatar
 
-    def _update_active(self, avatar: str, label: str):
-        """
-        Prevent duplicate labels from being appended to the same user.
-        """
-        # not a typo; Langchain passes a string :/
+    def _update_active(self, avatar: str, label: str) -> None:
         if label == "None":
             return
 
@@ -85,12 +56,12 @@ def _update_active(self, avatar: str, label: str):
         if f"- {label}" not in self._active_user:
             self._active_user = f"{self._active_user} - {label}"
 
-    def _reset_active(self):
+    def _reset_active(self) -> None:
         self._active_user = self._input_user
         self._active_avatar = self._input_avatar
         self._message = None
 
-    def _on_start(self, serialized, kwargs):
+    def _on_start(self, serialized: dict[str, Any], kwargs: dict[str, Any]) -> None:
         model = kwargs.get("invocation_params", {}).get("model_name", "")
         self._is_streaming = serialized.get("kwargs", {}).get("streaming")
         messages = self.instance.objects
@@ -99,7 +70,7 @@ def _on_start(self, serialized, kwargs):
         if self._active_user and model not in self._active_user:
             self._active_user = f"{self._active_user} ({model})"
 
-    def _stream(self, message: str):
+    def _stream(self, message: str) -> Optional[ChatMessage]:
         if message:
             return self.instance.stream(
                 message,
@@ -109,34 +80,51 @@ def _stream(self, message: str):
             )
         return self._message
 
 
-    def on_llm_start(self, serialized: dict[str, Any], *args, **kwargs):
+    def on_llm_start(
+        self,
+        serialized: dict[str, Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
         self._on_start(serialized, kwargs)
         return super().on_llm_start(serialized, *args, **kwargs)
 
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
+    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
         self._message = self._stream(token)
         return super().on_llm_new_token(token, **kwargs)
 
-    def on_llm_end(self, response: LLMResult, *args, **kwargs):
+    def on_llm_end(self, response: LLMResult, *args: Any, **kwargs: Any) -> Any:
         if not self._is_streaming:
-            # on_llm_new_token does not get called if not streaming
             self._stream(response.generations[0][0].text)
         self._reset_active()
         return super().on_llm_end(response, *args, **kwargs)
 
-    def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], *args, **kwargs):
-        return super().on_llm_error(error, *args, **kwargs)
+    def on_llm_error(
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        return super().on_llm_error(
+            error, run_id=run_id, parent_run_id=parent_run_id, **kwargs
+        )
 
-    def on_agent_action(self, action: AgentAction, *args, **kwargs: Any) -> Any:
+    def on_agent_action(self, action: AgentAction, *args: Any, **kwargs: Any) -> Any:
         return super().on_agent_action(action, *args, **kwargs)
 
-    def on_agent_finish(self, finish: AgentFinish, *args, **kwargs: Any) -> Any:
+    def on_agent_finish(self, finish: AgentFinish, *args: Any, **kwargs: Any) -> Any:
         return super().on_agent_finish(finish, *args, **kwargs)
 
     def on_tool_start(
-        self, serialized: dict[str, Any], input_str: str, *args, **kwargs
-    ):
+        self,
+        serialized: dict[str, Any],
+        input_str: str,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
         self._update_active(self._default_avatars["tool"], serialized["name"])
         self._step = self.instance.add_step(
             title=f"Tool input: {input_str}",
@@ -148,37 +136,56 @@ def on_tool_start(
         )
         return super().on_tool_start(serialized, input_str, *args, **kwargs)
 
-    def on_tool_end(self, output: str, *args, **kwargs):
-        self._step.stream(output)
-        self._step.status = "success"
+    def on_tool_end(self, output: str, *args: Any, **kwargs: Any) -> Any:
+        if self._step is not None:
+            self._step.stream(output)
+            self._step.status = "success"
         self._reset_active()
         self._step = None
         return super().on_tool_end(output, *args, **kwargs)
 
     def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], *args, **kwargs
-    ):
-        return super().on_tool_error(error, *args, **kwargs)
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        return super().on_tool_error(
+            error, run_id=run_id, parent_run_id=parent_run_id, **kwargs
+        )
 
     def on_chain_start(
-        self, serialized: dict[str, Any], inputs: dict[str, Any], *args, **kwargs
-    ):
+        self,
+        serialized: dict[str, Any],
+        inputs: dict[str, Any],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
         self._disabled_state = self.instance.disabled
         self.instance.disabled = True
         return super().on_chain_start(serialized, inputs, *args, **kwargs)
 
-    def on_chain_end(self, outputs: dict[str, Any], *args, **kwargs):
+    def on_chain_end(
+        self, outputs: dict[str, Any], *args: Any, **kwargs: Any
+    ) -> Any:
         self.instance.disabled = self._disabled_state
         return super().on_chain_end(outputs, *args, **kwargs)
 
     def on_retriever_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
+        self,
+        error: BaseException,
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
     ) -> Any:
-        """Run when Retriever errors."""
-        return super().on_retriever_error(error, **kwargs)
+        return super().on_retriever_error(
            error, run_id=run_id, parent_run_id=parent_run_id, **kwargs
+        )
 
-    def on_retriever_end(self, documents, **kwargs: Any) -> Any:
-        """Run when Retriever ends running."""
+    def on_retriever_end(self, documents: list[Document], **kwargs: Any) -> Any:
         objects = [
             (f"Document {index}", document.page_content)
             for index, document in enumerate(documents)
@@ -194,15 +201,14 @@ def on_retriever_end(self, documents, **kwargs: Any) -> Any:
         )
         return super().on_retriever_end(documents=documents, **kwargs)
 
-    def on_text(self, text: str, **kwargs: Any):
-        """Run when text is received."""
-        return super().on_text(text, **kwargs)
+    def on_text(self, text: str, **kwargs: Any) -> None:
+        super().on_text(text, **kwargs)
 
     def on_chat_model_start(
-        self, serialized: dict[str, Any], messages: list, **kwargs: Any
+        self,
+        serialized: dict[str, Any],
+        messages: list[Any],
+        **kwargs: Any,
     ) -> None:
-        """
-        To prevent the inherited class from raising
-        NotImplementedError, will not call super() here.
-        """
+        # Deliberately skips super(), which raises NotImplementedError.
         self._on_start(serialized, kwargs)
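
Note for reviewers: this patch removes the module and class docstrings, including the only usage example for the handler. For reference, the snippet below reproduces that wiring against the patched module. It is a minimal sketch, not part of the diff, and assumes panel >= 1.3 (for pn.chat.ChatInterface) plus an installed langchain-openai with OPENAI_API_KEY set:

    import panel as pn
    from langchain.chains import ConversationChain
    from langchain_openai import ChatOpenAI

    from langchain_community.callbacks.panel_callback import PanelCallbackHandler

    pn.extension()

    # The handler renders LLM, tool, and chain callbacks as ChatMessage
    # objects in the feed; it is passed to LangChain, not rendered by Panel.
    chat_interface = pn.chat.ChatInterface(callback_user="LangChain")
    handler = PanelCallbackHandler(instance=chat_interface)

    llm = ChatOpenAI(streaming=True, callbacks=[handler])
    chain = ConversationChain(llm=llm)
    chain.invoke({"input": "What is Panel?"})

With streaming=True, tokens arrive through on_llm_new_token and exercise the _stream path; with streaming=False, on_llm_end renders the full completion in a single message instead.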