diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index fb6404bf6c59b..f60646ad5258d 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -25,7 +25,6 @@ ) from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.input import get_color_mapping from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( @@ -39,6 +38,7 @@ from langchain.schema.messages import BaseMessage from langchain.tools.base import BaseTool from langchain.utilities.asyncio import asyncio_timeout +from langchain.utils.input import get_color_mapping logger = logging.getLogger(__name__) diff --git a/libs/langchain/langchain/agents/agent_iterator.py b/libs/langchain/langchain/agents/agent_iterator.py index 40193387f9337..829e00c9eaefe 100644 --- a/libs/langchain/langchain/agents/agent_iterator.py +++ b/libs/langchain/langchain/agents/agent_iterator.py @@ -25,11 +25,11 @@ CallbackManagerForChainRun, Callbacks, ) -from langchain.input import get_color_mapping from langchain.load.dump import dumpd from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo from langchain.tools import BaseTool from langchain.utilities.asyncio import asyncio_timeout +from langchain.utils.input import get_color_mapping if TYPE_CHECKING: from langchain.agents.agent import AgentExecutor diff --git a/libs/langchain/langchain/callbacks/file.py b/libs/langchain/langchain/callbacks/file.py index 1c359e7a1d59c..3265fa0060b02 100644 --- a/libs/langchain/langchain/callbacks/file.py +++ b/libs/langchain/langchain/callbacks/file.py @@ -2,8 +2,8 @@ from typing import Any, Dict, Optional, TextIO, cast from langchain.callbacks.base import BaseCallbackHandler -from langchain.input import print_text from langchain.schema import AgentAction, AgentFinish +from langchain.utils.input import print_text class FileCallbackHandler(BaseCallbackHandler): diff --git a/libs/langchain/langchain/callbacks/stdout.py b/libs/langchain/langchain/callbacks/stdout.py index add431d970bfe..56e0e7d904af3 100644 --- a/libs/langchain/langchain/callbacks/stdout.py +++ b/libs/langchain/langchain/callbacks/stdout.py @@ -2,8 +2,8 @@ from typing import Any, Dict, List, Optional, Union from langchain.callbacks.base import BaseCallbackHandler -from langchain.input import print_text from langchain.schema import AgentAction, AgentFinish, LLMResult +from langchain.utils.input import print_text class StdOutCallbackHandler(BaseCallbackHandler): diff --git a/libs/langchain/langchain/callbacks/tracers/stdout.py b/libs/langchain/langchain/callbacks/tracers/stdout.py index b7121336a63e8..2c082196d74fd 100644 --- a/libs/langchain/langchain/callbacks/tracers/stdout.py +++ b/libs/langchain/langchain/callbacks/tracers/stdout.py @@ -3,7 +3,7 @@ from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import Run -from langchain.input import get_bolded_text, get_colored_text +from langchain.utils.input import get_bolded_text, get_colored_text def try_json_stringify(obj: Any, fallback: str) -> str: diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py index 40cb3d5ec7a1a..bb24607a97c27 100644 --- a/libs/langchain/langchain/chains/llm.py +++ b/libs/langchain/langchain/chains/llm.py @@ -14,7 +14,6 @@ Callbacks, ) from langchain.chains.base import Chain -from langchain.input import get_colored_text from langchain.load.dump import dumpd from langchain.prompts.prompt import PromptTemplate from langchain.schema import ( @@ -25,6 +24,7 @@ PromptValue, ) from langchain.schema.language_model import BaseLanguageModel +from langchain.utils.input import get_colored_text class LLMChain(Chain): diff --git a/libs/langchain/langchain/chains/openai_functions/openapi.py b/libs/langchain/langchain/chains/openai_functions/openapi.py index 5b991eff9d548..511123a601eff 100644 --- a/libs/langchain/langchain/chains/openai_functions/openapi.py +++ b/libs/langchain/langchain/chains/openai_functions/openapi.py @@ -12,13 +12,13 @@ from langchain.chains.base import Chain from langchain.chains.sequential import SequentialChain from langchain.chat_models import ChatOpenAI -from langchain.input import get_colored_text from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser from langchain.prompts import ChatPromptTemplate from langchain.schema import BasePromptTemplate from langchain.schema.language_model import BaseLanguageModel from langchain.tools import APIOperation from langchain.utilities.openapi import OpenAPISpec +from langchain.utils.input import get_colored_text def _get_description(o: Any, prefer_short: bool) -> Optional[str]: diff --git a/libs/langchain/langchain/chains/sequential.py b/libs/langchain/langchain/chains/sequential.py index 078a0d7fffa43..26cbaf7021d1b 100644 --- a/libs/langchain/langchain/chains/sequential.py +++ b/libs/langchain/langchain/chains/sequential.py @@ -8,7 +8,7 @@ CallbackManagerForChainRun, ) from langchain.chains.base import Chain -from langchain.input import get_color_mapping +from langchain.utils.input import get_color_mapping class SequentialChain(Chain): diff --git a/libs/langchain/langchain/input.py b/libs/langchain/langchain/input.py index 8d5ae6cc24fb8..7fa443ef45277 100644 --- a/libs/langchain/langchain/input.py +++ b/libs/langchain/langchain/input.py @@ -1,42 +1,14 @@ -"""Handle chained inputs.""" -from typing import Dict, List, Optional, TextIO - -_TEXT_COLOR_MAPPING = { - "blue": "36;1", - "yellow": "33;1", - "pink": "38;5;200", - "green": "32;1", - "red": "31;1", -} - - -def get_color_mapping( - items: List[str], excluded_colors: Optional[List] = None -) -> Dict[str, str]: - """Get mapping for items to a support color.""" - colors = list(_TEXT_COLOR_MAPPING.keys()) - if excluded_colors is not None: - colors = [c for c in colors if c not in excluded_colors] - color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)} - return color_mapping - - -def get_colored_text(text: str, color: str) -> str: - """Get colored text.""" - color_str = _TEXT_COLOR_MAPPING[color] - return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" - - -def get_bolded_text(text: str) -> str: - """Get bolded text.""" - return f"\033[1m{text}\033[0m" - - -def print_text( - text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None -) -> None: - """Print text with highlighting and no end characters.""" - text_to_print = get_colored_text(text, color) if color else text - print(text_to_print, end=end, file=file) - if file: - file.flush() # ensure all printed content are written to file +"""DEPRECATED: Kept for backwards compatibility.""" +from langchain.utils.input import ( + get_bolded_text, + get_color_mapping, + get_colored_text, + print_text, +) + +__all__ = [ + "get_bolded_text", + "get_color_mapping", + "get_colored_text", + "print_text", +] diff --git a/libs/langchain/langchain/model_laboratory.py b/libs/langchain/langchain/model_laboratory.py index 0ba871b9bd56a..87c44a211438a 100644 --- a/libs/langchain/langchain/model_laboratory.py +++ b/libs/langchain/langchain/model_laboratory.py @@ -5,9 +5,9 @@ from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.input import get_color_mapping, print_text from langchain.llms.base import BaseLLM from langchain.prompts.prompt import PromptTemplate +from langchain.utils.input import get_color_mapping, print_text class ModelLaboratory: diff --git a/libs/langchain/langchain/utils/__init__.py b/libs/langchain/langchain/utils/__init__.py index 74b0bb87d7b15..1242144c7a6fe 100644 --- a/libs/langchain/langchain/utils/__init__.py +++ b/libs/langchain/langchain/utils/__init__.py @@ -6,6 +6,12 @@ from langchain.utils.env import get_from_dict_or_env, get_from_env from langchain.utils.formatting import StrictFormatter, formatter +from langchain.utils.input import ( + get_bolded_text, + get_color_mapping, + get_colored_text, + print_text, +) from langchain.utils.math import cosine_similarity, cosine_similarity_top_k from langchain.utils.strings import comma_list, stringify_dict, stringify_value from langchain.utils.utils import ( @@ -24,11 +30,15 @@ "cosine_similarity", "cosine_similarity_top_k", "formatter", + "get_bolded_text", + "get_color_mapping", + "get_colored_text", "get_from_dict_or_env", "get_from_env", "get_pydantic_field_names", "guard_import", "mock_now", + "print_text", "raise_for_status_with_text", "stringify_dict", "stringify_value", diff --git a/libs/langchain/langchain/utils/input.py b/libs/langchain/langchain/utils/input.py new file mode 100644 index 0000000000000..8d5ae6cc24fb8 --- /dev/null +++ b/libs/langchain/langchain/utils/input.py @@ -0,0 +1,42 @@ +"""Handle chained inputs.""" +from typing import Dict, List, Optional, TextIO + +_TEXT_COLOR_MAPPING = { + "blue": "36;1", + "yellow": "33;1", + "pink": "38;5;200", + "green": "32;1", + "red": "31;1", +} + + +def get_color_mapping( + items: List[str], excluded_colors: Optional[List] = None +) -> Dict[str, str]: + """Get mapping for items to a support color.""" + colors = list(_TEXT_COLOR_MAPPING.keys()) + if excluded_colors is not None: + colors = [c for c in colors if c not in excluded_colors] + color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)} + return color_mapping + + +def get_colored_text(text: str, color: str) -> str: + """Get colored text.""" + color_str = _TEXT_COLOR_MAPPING[color] + return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" + + +def get_bolded_text(text: str) -> str: + """Get bolded text.""" + return f"\033[1m{text}\033[0m" + + +def print_text( + text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None +) -> None: + """Print text with highlighting and no end characters.""" + text_to_print = get_colored_text(text, color) if color else text + print(text_to_print, end=end, file=file) + if file: + file.flush() # ensure all printed content are written to file