Skip to content

Commit

Permalink
callback refactor (langchain-ai#13372)
Browse files Browse the repository at this point in the history
Co-authored-by: Nuno Campos <[email protected]>
  • Loading branch information
2 people authored and xieqihui committed Nov 21, 2023
1 parent 141b376 commit 658a724
Show file tree
Hide file tree
Showing 31 changed files with 4,848 additions and 4,642 deletions.
615 changes: 22 additions & 593 deletions libs/langchain/langchain/callbacks/base.py

Large diffs are not rendered by default.

2,086 changes: 58 additions & 2,028 deletions libs/langchain/langchain/callbacks/manager.py

Large diffs are not rendered by default.

98 changes: 2 additions & 96 deletions libs/langchain/langchain/callbacks/stdout.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,3 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional
from langchain.schema.callbacks.stdout import StdOutCallbackHandler

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.utils.input import print_text


class StdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""

def __init__(self, color: Optional[str] = None) -> None:
"""Initialize callback handler."""
self.color = color

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
pass

def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
pass

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass

def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass

def on_chain_start(
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
class_name = serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
print("\n\033[1m> Finished chain.\033[0m")

def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass

def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing."""
pass

def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
print_text(action.log, color=color or self.color)

def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
if observation_prefix is not None:
print_text(f"\n{observation_prefix}")
print_text(output, color=color or self.color)
if llm_prefix is not None:
print_text(f"\n{llm_prefix}")

def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass

def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
**kwargs: Any,
) -> None:
"""Run when agent ends."""
print_text(text, color=color or self.color, end=end)

def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text(finish.log, color=color or self.color, end="\n")
__all__ = ["StdOutCallbackHandler"]
8 changes: 4 additions & 4 deletions libs/langchain/langchain/callbacks/tracers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Tracers that record execution of LangChain runs."""

from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1
from langchain.callbacks.tracers.stdout import (
from langchain.callbacks.tracers.wandb import WandbTracer
from langchain.schema.callbacks.tracers.langchain import LangChainTracer
from langchain.schema.callbacks.tracers.langchain_v1 import LangChainTracerV1
from langchain.schema.callbacks.tracers.stdout import (
ConsoleCallbackHandler,
FunctionCallbackHandler,
)
from langchain.callbacks.tracers.wandb import WandbTracer

__all__ = [
"LangChainTracer",
Expand Down
Loading

0 comments on commit 658a724

Please sign in to comment.