Skip to content

Commit

Permalink
Refactored input (langchain-ai#8202)
Browse files Browse the repository at this point in the history
Refactored `input.py`. The same as
langchain-ai#7961 langchain-ai#8098 langchain-ai#8099
input.py is in the root code folder. This creates the `langchain.input:
Input` group on the API Reference navigation ToC, on the same level as
Chains and Agents which is incorrect.

Refactoring:

- copied input.py file into utils/input.py
- I added the backwards compatibility ref in the original input.py. 
- changed several imports to a new ref

@hwchase17, @baskaryan
  • Loading branch information
leo-gan authored Jul 24, 2023
1 parent 72eb4fa commit 7cbe28b
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 51 deletions.
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
Expand All @@ -39,6 +38,7 @@
from langchain.schema.messages import BaseMessage
from langchain.tools.base import BaseTool
from langchain.utilities.asyncio import asyncio_timeout
from langchain.utils.input import get_color_mapping

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/agents/agent_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
CallbackManagerForChainRun,
Callbacks,
)
from langchain.input import get_color_mapping
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo
from langchain.tools import BaseTool
from langchain.utilities.asyncio import asyncio_timeout
from langchain.utils.input import get_color_mapping

if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/callbacks/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Any, Dict, Optional, TextIO, cast

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish
from langchain.utils.input import print_text


class FileCallbackHandler(BaseCallbackHandler):
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/callbacks/stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Any, Dict, List, Optional, Union

from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.utils.input import print_text


class StdOutCallbackHandler(BaseCallbackHandler):
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/callbacks/tracers/stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.input import get_bolded_text, get_colored_text
from langchain.utils.input import get_bolded_text, get_colored_text


def try_json_stringify(obj: Any, fallback: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Callbacks,
)
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
Expand All @@ -25,6 +24,7 @@
PromptValue,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.utils.input import get_colored_text


class LLMChain(Chain):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain
from langchain.chat_models import ChatOpenAI
from langchain.input import get_colored_text
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import APIOperation
from langchain.utilities.openapi import OpenAPISpec
from langchain.utils.input import get_colored_text


def _get_description(o: Any, prefer_short: bool) -> Optional[str]:
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/langchain/chains/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.input import get_color_mapping
from langchain.utils.input import get_color_mapping


class SequentialChain(Chain):
Expand Down
56 changes: 14 additions & 42 deletions libs/langchain/langchain/input.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,14 @@
"""Handle chained inputs."""
from typing import Dict, List, Optional, TextIO

_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}


def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping


def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"


def get_bolded_text(text: str) -> str:
"""Get bolded text."""
return f"\033[1m{text}\033[0m"


def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content are written to file
"""DEPRECATED: Kept for backwards compatibility."""
from langchain.utils.input import (
get_bolded_text,
get_color_mapping,
get_colored_text,
print_text,
)

__all__ = [
"get_bolded_text",
"get_color_mapping",
"get_colored_text",
"print_text",
]
2 changes: 1 addition & 1 deletion libs/langchain/langchain/model_laboratory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping, print_text
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.utils.input import get_color_mapping, print_text


class ModelLaboratory:
Expand Down
10 changes: 10 additions & 0 deletions libs/langchain/langchain/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

from langchain.utils.env import get_from_dict_or_env, get_from_env
from langchain.utils.formatting import StrictFormatter, formatter
from langchain.utils.input import (
get_bolded_text,
get_color_mapping,
get_colored_text,
print_text,
)
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
from langchain.utils.utils import (
Expand All @@ -24,11 +30,15 @@
"cosine_similarity",
"cosine_similarity_top_k",
"formatter",
"get_bolded_text",
"get_color_mapping",
"get_colored_text",
"get_from_dict_or_env",
"get_from_env",
"get_pydantic_field_names",
"guard_import",
"mock_now",
"print_text",
"raise_for_status_with_text",
"stringify_dict",
"stringify_value",
Expand Down
42 changes: 42 additions & 0 deletions libs/langchain/langchain/utils/input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Handle chained inputs."""
from typing import Dict, List, Optional, TextIO

_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}


def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping


def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"


def get_bolded_text(text: str) -> str:
"""Get bolded text."""
return f"\033[1m{text}\033[0m"


def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content are written to file

0 comments on commit 7cbe28b

Please sign in to comment.