From 0d37b4c27d1410ad813e2151eae082a163d61617 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 13 Oct 2023 17:36:44 -0400 Subject: [PATCH] Add python,pandas,xorbits,spark agents to experimental (#11774) See for contex https://github.com/langchain-ai/langchain/discussions/11680 --- .../langchain_experimental/agents/__init__.py | 0 .../agents/agent_toolkits/__init__.py | 0 .../agents/agent_toolkits/pandas/__init__.py | 1 + .../agents/agent_toolkits/pandas/base.py | 341 ++++++++++++++++++ .../agents/agent_toolkits/pandas/prompt.py | 44 +++ .../agents/agent_toolkits/python/__init__.py | 0 .../agents/agent_toolkits/python/base.py | 59 +++ .../agents/agent_toolkits/python/prompt.py | 9 + .../agents/agent_toolkits/spark/__init__.py | 1 + .../agents/agent_toolkits/spark/base.py | 81 +++++ .../agents/agent_toolkits/spark/prompt.py | 13 + .../agents/agent_toolkits/xorbits/__init__.py | 1 + .../agents/agent_toolkits/xorbits/base.py | 91 +++++ .../agents/agent_toolkits/xorbits/prompt.py | 33 ++ .../tools/python/__init__.py | 0 .../tools/python/tool.py | 150 ++++++++ .../utilities/python.py | 71 ++++ .../tests/unit_tests/python/__init__.py | 0 .../tests/unit_tests/python/test_python_1.py | 112 ++++++ .../tests/unit_tests/python/test_python_2.py | 164 +++++++++ 20 files changed, 1171 insertions(+) create mode 100644 libs/experimental/langchain_experimental/agents/__init__.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/__init__.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/__init__.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/base.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/prompt.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/python/__init__.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/python/base.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/python/prompt.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/spark/__init__.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/spark/base.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/spark/prompt.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/__init__.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/base.py create mode 100644 libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/prompt.py create mode 100644 libs/experimental/langchain_experimental/tools/python/__init__.py create mode 100644 libs/experimental/langchain_experimental/tools/python/tool.py create mode 100644 libs/experimental/langchain_experimental/utilities/python.py create mode 100644 libs/experimental/tests/unit_tests/python/__init__.py create mode 100644 libs/experimental/tests/unit_tests/python/test_python_1.py create mode 100644 libs/experimental/tests/unit_tests/python/test_python_2.py diff --git a/libs/experimental/langchain_experimental/agents/__init__.py b/libs/experimental/langchain_experimental/agents/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/__init__.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/__init__.py new file mode 100644 index 
0000000000000..e69de29bb2d1d diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/__init__.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/__init__.py new file mode 100644 index 0000000000000..a6dc608d470e7 --- /dev/null +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/__init__.py @@ -0,0 +1 @@ +"""Pandas toolkit.""" diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/base.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/base.py new file mode 100644 index 0000000000000..ef5e1eae8a566 --- /dev/null +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/base.py @@ -0,0 +1,341 @@ +"""Agent for working with pandas objects.""" +from typing import Any, Dict, List, Optional, Sequence, Tuple + +from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent +from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent +from langchain.agents.types import AgentType +from langchain.callbacks.base import BaseCallbackManager +from langchain.chains.llm import LLMChain +from langchain.schema import BasePromptTemplate +from langchain.schema.language_model import BaseLanguageModel +from langchain.schema.messages import SystemMessage +from langchain.tools import BaseTool + +from langchain_experimental.agents.agent_toolkits.pandas.prompt import ( + FUNCTIONS_WITH_DF, + FUNCTIONS_WITH_MULTI_DF, + MULTI_DF_PREFIX, + MULTI_DF_PREFIX_FUNCTIONS, + PREFIX, + PREFIX_FUNCTIONS, + SUFFIX_NO_DF, + SUFFIX_WITH_DF, + SUFFIX_WITH_MULTI_DF, +) +from langchain_experimental.tools.python.tool import PythonAstREPLTool + + +def _get_multi_prompt( + dfs: List[Any], + prefix: Optional[str] = None, + suffix: Optional[str] = None, + input_variables: Optional[List[str]] = None, + include_df_in_prompt: Optional[bool] = True, + number_of_head_rows: int = 5, +) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + num_dfs = len(dfs) + if suffix is not None: + suffix_to_use = suffix + include_dfs_head = True + elif include_df_in_prompt: + suffix_to_use = SUFFIX_WITH_MULTI_DF + include_dfs_head = True + else: + suffix_to_use = SUFFIX_NO_DF + include_dfs_head = False + if input_variables is None: + input_variables = ["input", "agent_scratchpad", "num_dfs"] + if include_dfs_head: + input_variables += ["dfs_head"] + + if prefix is None: + prefix = MULTI_DF_PREFIX + + df_locals = {} + for i, dataframe in enumerate(dfs): + df_locals[f"df{i + 1}"] = dataframe + tools = [PythonAstREPLTool(locals=df_locals)] + + prompt = ZeroShotAgent.create_prompt( + tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables + ) + + partial_prompt = prompt.partial() + if "dfs_head" in input_variables: + dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs]) + partial_prompt = partial_prompt.partial(num_dfs=str(num_dfs), dfs_head=dfs_head) + if "num_dfs" in input_variables: + partial_prompt = partial_prompt.partial(num_dfs=str(num_dfs)) + return partial_prompt, tools + + +def _get_single_prompt( + df: Any, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + input_variables: Optional[List[str]] = None, + include_df_in_prompt: Optional[bool] = True, + number_of_head_rows: int = 5, +) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + if suffix is not None: + suffix_to_use = suffix + include_df_head = True + elif include_df_in_prompt: + suffix_to_use = 
SUFFIX_WITH_DF + include_df_head = True + else: + suffix_to_use = SUFFIX_NO_DF + include_df_head = False + + if input_variables is None: + input_variables = ["input", "agent_scratchpad"] + if include_df_head: + input_variables += ["df_head"] + + if prefix is None: + prefix = PREFIX + + tools = [PythonAstREPLTool(locals={"df": df})] + + prompt = ZeroShotAgent.create_prompt( + tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables + ) + + partial_prompt = prompt.partial() + if "df_head" in input_variables: + partial_prompt = partial_prompt.partial( + df_head=str(df.head(number_of_head_rows).to_markdown()) + ) + return partial_prompt, tools + + +def _get_prompt_and_tools( + df: Any, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + input_variables: Optional[List[str]] = None, + include_df_in_prompt: Optional[bool] = True, + number_of_head_rows: int = 5, +) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + try: + import pandas as pd + + pd.set_option("display.max_columns", None) + except ImportError: + raise ImportError( + "pandas package not found, please install with `pip install pandas`" + ) + + if include_df_in_prompt is not None and suffix is not None: + raise ValueError("If suffix is specified, include_df_in_prompt should not be.") + + if isinstance(df, list): + for item in df: + if not isinstance(item, pd.DataFrame): + raise ValueError(f"Expected pandas object, got {type(df)}") + return _get_multi_prompt( + df, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + include_df_in_prompt=include_df_in_prompt, + number_of_head_rows=number_of_head_rows, + ) + else: + if not isinstance(df, pd.DataFrame): + raise ValueError(f"Expected pandas object, got {type(df)}") + return _get_single_prompt( + df, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + include_df_in_prompt=include_df_in_prompt, + number_of_head_rows=number_of_head_rows, + ) + + +def _get_functions_single_prompt( + df: Any, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + include_df_in_prompt: Optional[bool] = True, + number_of_head_rows: int = 5, +) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + if suffix is not None: + suffix_to_use = suffix + if include_df_in_prompt: + suffix_to_use = suffix_to_use.format( + df_head=str(df.head(number_of_head_rows).to_markdown()) + ) + elif include_df_in_prompt: + suffix_to_use = FUNCTIONS_WITH_DF.format( + df_head=str(df.head(number_of_head_rows).to_markdown()) + ) + else: + suffix_to_use = "" + + if prefix is None: + prefix = PREFIX_FUNCTIONS + + tools = [PythonAstREPLTool(locals={"df": df})] + system_message = SystemMessage(content=prefix + suffix_to_use) + prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message) + return prompt, tools + + +def _get_functions_multi_prompt( + dfs: Any, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + include_df_in_prompt: Optional[bool] = True, + number_of_head_rows: int = 5, +) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + if suffix is not None: + suffix_to_use = suffix + if include_df_in_prompt: + dfs_head = "\n\n".join( + [d.head(number_of_head_rows).to_markdown() for d in dfs] + ) + suffix_to_use = suffix_to_use.format( + dfs_head=dfs_head, + ) + elif include_df_in_prompt: + dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs]) + suffix_to_use = FUNCTIONS_WITH_MULTI_DF.format( + dfs_head=dfs_head, + ) + else: + suffix_to_use = "" + + if prefix is None: + prefix = 
MULTI_DF_PREFIX_FUNCTIONS + prefix = prefix.format(num_dfs=str(len(dfs))) + + df_locals = {} + for i, dataframe in enumerate(dfs): + df_locals[f"df{i + 1}"] = dataframe + tools = [PythonAstREPLTool(locals=df_locals)] + system_message = SystemMessage(content=prefix + suffix_to_use) + prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message) + return prompt, tools + + +def _get_functions_prompt_and_tools( + df: Any, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + input_variables: Optional[List[str]] = None, + include_df_in_prompt: Optional[bool] = True, + number_of_head_rows: int = 5, +) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + try: + import pandas as pd + + pd.set_option("display.max_columns", None) + except ImportError: + raise ImportError( + "pandas package not found, please install with `pip install pandas`" + ) + if input_variables is not None: + raise ValueError("`input_variables` is not supported at the moment.") + + if include_df_in_prompt is not None and suffix is not None: + raise ValueError("If suffix is specified, include_df_in_prompt should not be.") + + if isinstance(df, list): + for item in df: + if not isinstance(item, pd.DataFrame): + raise ValueError(f"Expected pandas object, got {type(df)}") + return _get_functions_multi_prompt( + df, + prefix=prefix, + suffix=suffix, + include_df_in_prompt=include_df_in_prompt, + number_of_head_rows=number_of_head_rows, + ) + else: + if not isinstance(df, pd.DataFrame): + raise ValueError(f"Expected pandas object, got {type(df)}") + return _get_functions_single_prompt( + df, + prefix=prefix, + suffix=suffix, + include_df_in_prompt=include_df_in_prompt, + number_of_head_rows=number_of_head_rows, + ) + + +def create_pandas_dataframe_agent( + llm: BaseLanguageModel, + df: Any, + agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: Optional[str] = None, + suffix: Optional[str] = None, + input_variables: Optional[List[str]] = None, + verbose: bool = False, + return_intermediate_steps: bool = False, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + include_df_in_prompt: Optional[bool] = True, + number_of_head_rows: int = 5, + extra_tools: Sequence[BaseTool] = (), + **kwargs: Dict[str, Any], +) -> AgentExecutor: + """Construct a pandas agent from an LLM and dataframe.""" + agent: BaseSingleActionAgent + if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: + prompt, base_tools = _get_prompt_and_tools( + df, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + include_df_in_prompt=include_df_in_prompt, + number_of_head_rows=number_of_head_rows, + ) + tools = base_tools + list(extra_tools) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent( + llm_chain=llm_chain, + allowed_tools=tool_names, + callback_manager=callback_manager, + **kwargs, + ) + elif agent_type == AgentType.OPENAI_FUNCTIONS: + _prompt, base_tools = _get_functions_prompt_and_tools( + df, + prefix=prefix, + suffix=suffix, + input_variables=input_variables, + include_df_in_prompt=include_df_in_prompt, + number_of_head_rows=number_of_head_rows, + ) + tools = base_tools + list(extra_tools) + agent = OpenAIFunctionsAgent( + llm=llm, + prompt=_prompt, + tools=tools, + callback_manager=callback_manager, 
+ **kwargs, + ) + else: + raise ValueError(f"Agent type {agent_type} not supported at the moment.") + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + return_intermediate_steps=return_intermediate_steps, + max_iterations=max_iterations, + max_execution_time=max_execution_time, + early_stopping_method=early_stopping_method, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/prompt.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/prompt.py new file mode 100644 index 0000000000000..72b2bc8b20bc4 --- /dev/null +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/prompt.py @@ -0,0 +1,44 @@ +# flake8: noqa + +PREFIX = """ +You are working with a pandas dataframe in Python. The name of the dataframe is `df`. +You should use the tools below to answer the question posed of you:""" + +MULTI_DF_PREFIX = """ +You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc. You +should use the tools below to answer the question posed of you:""" + +SUFFIX_NO_DF = """ +Begin! +Question: {input} +{agent_scratchpad}""" + +SUFFIX_WITH_DF = """ +This is the result of `print(df.head())`: +{df_head} + +Begin! +Question: {input} +{agent_scratchpad}""" + +SUFFIX_WITH_MULTI_DF = """ +This is the result of `print(df.head())` for each dataframe: +{dfs_head} + +Begin! +Question: {input} +{agent_scratchpad}""" + +PREFIX_FUNCTIONS = """ +You are working with a pandas dataframe in Python. The name of the dataframe is `df`.""" + +MULTI_DF_PREFIX_FUNCTIONS = """ +You are working with {num_dfs} pandas dataframes in Python named df1, df2, etc.""" + +FUNCTIONS_WITH_DF = """ +This is the result of `print(df.head())`: +{df_head}""" + +FUNCTIONS_WITH_MULTI_DF = """ +This is the result of `print(df.head())` for each dataframe: +{dfs_head}""" diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/python/__init__.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/python/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/python/base.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/python/base.py new file mode 100644 index 0000000000000..10d9632357d84 --- /dev/null +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/python/base.py @@ -0,0 +1,59 @@ +"""Python agent.""" + +from typing import Any, Dict, Optional + +from langchain.agents.agent import AgentExecutor, BaseSingleActionAgent +from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent +from langchain.agents.types import AgentType +from langchain.callbacks.base import BaseCallbackManager +from langchain.chains.llm import LLMChain +from langchain.schema.language_model import BaseLanguageModel +from langchain.schema.messages import SystemMessage + +from langchain_experimental.agents.agent_toolkits.python.prompt import PREFIX +from langchain_experimental.tools.python.tool import PythonREPLTool + + +def create_python_agent( + llm: BaseLanguageModel, + tool: PythonREPLTool, + agent_type: AgentType = AgentType.ZERO_SHOT_REACT_DESCRIPTION, + callback_manager: Optional[BaseCallbackManager] = None, + verbose: bool = False, + prefix: str = PREFIX, + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Dict[str, Any], +) -> 
AgentExecutor: + """Construct a python agent from an LLM and tool.""" + tools = [tool] + agent: BaseSingleActionAgent + + if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: + prompt = ZeroShotAgent.create_prompt(tools, prefix=prefix) + llm_chain = LLMChain( + llm=llm, + prompt=prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + elif agent_type == AgentType.OPENAI_FUNCTIONS: + system_message = SystemMessage(content=prefix) + _prompt = OpenAIFunctionsAgent.create_prompt(system_message=system_message) + agent = OpenAIFunctionsAgent( + llm=llm, + prompt=_prompt, + tools=tools, + callback_manager=callback_manager, + **kwargs, + ) + else: + raise ValueError(f"Agent type {agent_type} not supported at the moment.") + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/python/prompt.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/python/prompt.py new file mode 100644 index 0000000000000..fc97e7916eb47 --- /dev/null +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/python/prompt.py @@ -0,0 +1,9 @@ +# flake8: noqa + +PREFIX = """You are an agent designed to write and execute python code to answer questions. +You have access to a python REPL, which you can use to execute python code. +If you get an error, debug your code and try again. +Only use the output of your code to answer the question. +You might know the answer without running any code, but you should still run the code to get the answer. +If it does not seem like you can write code to answer the question, just return "I don't know" as the answer. 
+""" diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/spark/__init__.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/spark/__init__.py new file mode 100644 index 0000000000000..ded6eb03a420a --- /dev/null +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/spark/__init__.py @@ -0,0 +1 @@ +"""spark toolkit""" diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/spark/base.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/spark/base.py new file mode 100644 index 0000000000000..d2ed41d65eec4 --- /dev/null +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/spark/base.py @@ -0,0 +1,81 @@ +"""Agent for working with pandas objects.""" +from typing import Any, Dict, List, Optional + +from langchain.agents.agent import AgentExecutor +from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.callbacks.base import BaseCallbackManager +from langchain.chains.llm import LLMChain +from langchain.llms.base import BaseLLM + +from langchain_experimental.agents.agent_toolkits.spark.prompt import PREFIX, SUFFIX +from langchain_experimental.tools.python.tool import PythonAstREPLTool + + +def _validate_spark_df(df: Any) -> bool: + try: + from pyspark.sql import DataFrame as SparkLocalDataFrame + + return isinstance(df, SparkLocalDataFrame) + except ImportError: + return False + + +def _validate_spark_connect_df(df: Any) -> bool: + try: + from pyspark.sql.connect.dataframe import DataFrame as SparkConnectDataFrame + + return isinstance(df, SparkConnectDataFrame) + except ImportError: + return False + + +def create_spark_dataframe_agent( + llm: BaseLLM, + df: Any, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = PREFIX, + suffix: str = SUFFIX, + input_variables: Optional[List[str]] = None, + verbose: bool = False, + return_intermediate_steps: bool = False, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Dict[str, Any], +) -> AgentExecutor: + """Construct a Spark agent from an LLM and dataframe.""" + + if not _validate_spark_df(df) and not _validate_spark_connect_df(df): + raise ImportError("Spark is not installed. 
run `pip install pyspark`.") + + if input_variables is None: + input_variables = ["df", "input", "agent_scratchpad"] + tools = [PythonAstREPLTool(locals={"df": df})] + prompt = ZeroShotAgent.create_prompt( + tools, prefix=prefix, suffix=suffix, input_variables=input_variables + ) + partial_prompt = prompt.partial(df=str(df.first())) + llm_chain = LLMChain( + llm=llm, + prompt=partial_prompt, + callback_manager=callback_manager, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent( + llm_chain=llm_chain, + allowed_tools=tool_names, + callback_manager=callback_manager, + **kwargs, + ) + return AgentExecutor.from_agent_and_tools( + agent=agent, + tools=tools, + callback_manager=callback_manager, + verbose=verbose, + return_intermediate_steps=return_intermediate_steps, + max_iterations=max_iterations, + max_execution_time=max_execution_time, + early_stopping_method=early_stopping_method, + **(agent_executor_kwargs or {}), + ) diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/spark/prompt.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/spark/prompt.py new file mode 100644 index 0000000000000..32ce2c3423540 --- /dev/null +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/spark/prompt.py @@ -0,0 +1,13 @@ +# flake8: noqa + +PREFIX = """ +You are working with a spark dataframe in Python. The name of the dataframe is `df`. +You should use the tools below to answer the question posed of you:""" + +SUFFIX = """ +This is the result of `print(df.first())`: +{df} + +Begin! +Question: {input} +{agent_scratchpad}""" diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/__init__.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/__init__.py new file mode 100644 index 0000000000000..71bf7b70ea2c9 --- /dev/null +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/__init__.py @@ -0,0 +1 @@ +"""Xorbits toolkit.""" diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/base.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/base.py new file mode 100644 index 0000000000000..70f8118869213 --- /dev/null +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/base.py @@ -0,0 +1,91 @@ +"""Agent for working with xorbits objects.""" +from typing import Any, Dict, List, Optional + +from langchain.agents.agent import AgentExecutor +from langchain.agents.mrkl.base import ZeroShotAgent +from langchain.callbacks.base import BaseCallbackManager +from langchain.chains.llm import LLMChain +from langchain.llms.base import BaseLLM + +from langchain_experimental.agents.agent_toolkits.xorbits.prompt import ( + NP_PREFIX, + NP_SUFFIX, + PD_PREFIX, + PD_SUFFIX, +) +from langchain_experimental.tools.python.tool import PythonAstREPLTool + + +def create_xorbits_agent( + llm: BaseLLM, + data: Any, + callback_manager: Optional[BaseCallbackManager] = None, + prefix: str = "", + suffix: str = "", + input_variables: Optional[List[str]] = None, + verbose: bool = False, + return_intermediate_steps: bool = False, + max_iterations: Optional[int] = 15, + max_execution_time: Optional[float] = None, + early_stopping_method: str = "force", + agent_executor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Dict[str, Any], +) -> AgentExecutor: + """Construct a xorbits agent from an LLM and dataframe.""" + try: + from xorbits import numpy as np + from xorbits import pandas as pd + except ImportError: + raise 
ImportError(
+            "Xorbits package not installed, please install with `pip install xorbits`"
+        )
+
+    if not isinstance(data, (pd.DataFrame, np.ndarray)):
+        raise ValueError(
+            f"Expected Xorbits DataFrame or ndarray object, got {type(data)}"
+        )
+    if input_variables is None:
+        input_variables = ["data", "input", "agent_scratchpad"]
+    tools = [PythonAstREPLTool(locals={"data": data})]
+    prompt, partial_input = None, None
+
+    if isinstance(data, pd.DataFrame):
+        prompt = ZeroShotAgent.create_prompt(
+            tools,
+            prefix=PD_PREFIX if prefix == "" else prefix,
+            suffix=PD_SUFFIX if suffix == "" else suffix,
+            input_variables=input_variables,
+        )
+        partial_input = str(data.head())
+    else:
+        prompt = ZeroShotAgent.create_prompt(
+            tools,
+            prefix=NP_PREFIX if prefix == "" else prefix,
+            suffix=NP_SUFFIX if suffix == "" else suffix,
+            input_variables=input_variables,
+        )
+        partial_input = str(data[: len(data) // 2])
+    partial_prompt = prompt.partial(data=partial_input)
+    llm_chain = LLMChain(
+        llm=llm,
+        prompt=partial_prompt,
+        callback_manager=callback_manager,
+    )
+    tool_names = [tool.name for tool in tools]
+    agent = ZeroShotAgent(
+        llm_chain=llm_chain,
+        allowed_tools=tool_names,
+        callback_manager=callback_manager,
+        **kwargs,
+    )
+    return AgentExecutor.from_agent_and_tools(
+        agent=agent,
+        tools=tools,
+        callback_manager=callback_manager,
+        verbose=verbose,
+        return_intermediate_steps=return_intermediate_steps,
+        max_iterations=max_iterations,
+        max_execution_time=max_execution_time,
+        early_stopping_method=early_stopping_method,
+        **(agent_executor_kwargs or {}),
+    )
diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/prompt.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/prompt.py
new file mode 100644
index 0000000000000..6db3a41a79fbe
--- /dev/null
+++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/xorbits/prompt.py
@@ -0,0 +1,33 @@
+PD_PREFIX = """
+You are working with a Xorbits dataframe object in Python.
+Before importing Numpy or Pandas in the current script,
+remember to import the xorbits version of the library instead.
+To import the xorbits version of Pandas, replace the original import statement
+`import pandas as pd` with `import xorbits.pandas as pd`.
+The name of the input is `data`.
+You should use the tools below to answer the question posed of you:"""
+
+PD_SUFFIX = """
+This is the result of `print(data)`:
+{data}
+
+Begin!
+Question: {input}
+{agent_scratchpad}"""
+
+NP_PREFIX = """
+You are working with a Xorbits ndarray object in Python.
+Before importing Numpy in the current script,
+remember to import the xorbits version of the library instead.
+To import the xorbits version of Numpy, replace the original import statement
+`import numpy as np` with `import xorbits.numpy as np`.
+The name of the input is `data`.
+You should use the tools below to answer the question posed of you:"""
+
+NP_SUFFIX = """
+This is the result of `print(data)`:
+{data}
+
+Begin!
+Question: {input} +{agent_scratchpad}""" diff --git a/libs/experimental/langchain_experimental/tools/python/__init__.py b/libs/experimental/langchain_experimental/tools/python/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/experimental/langchain_experimental/tools/python/tool.py b/libs/experimental/langchain_experimental/tools/python/tool.py new file mode 100644 index 0000000000000..eb9b1c08a3637 --- /dev/null +++ b/libs/experimental/langchain_experimental/tools/python/tool.py @@ -0,0 +1,150 @@ +"""A tool for running python code in a REPL.""" + +import ast +import asyncio +import re +import sys +from contextlib import redirect_stdout +from io import StringIO +from typing import Any, Dict, Optional, Type + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForToolRun, + CallbackManagerForToolRun, +) +from langchain.pydantic_v1 import BaseModel, Field, root_validator +from langchain.tools.base import BaseTool + +from langchain_experimental.utilities.python import PythonREPL + + +def _get_default_python_repl() -> PythonREPL: + return PythonREPL(_globals=globals(), _locals=None) + + +def sanitize_input(query: str) -> str: + """Sanitize input to the python REPL. + Remove whitespace, backtick & python (if llm mistakes python console as terminal) + + Args: + query: The query to sanitize + + Returns: + str: The sanitized query + """ + + # Removes `, whitespace & python from start + query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query) + # Removes whitespace & ` from end + query = re.sub(r"(\s|`)*$", "", query) + return query + + +class PythonREPLTool(BaseTool): + """A tool for running python code in a REPL.""" + + name: str = "Python_REPL" + description: str = ( + "A Python shell. Use this to execute python commands. " + "Input should be a valid python command. " + "If you want to see the output of a value, you should print it out " + "with `print(...)`." + ) + python_repl: PythonREPL = Field(default_factory=_get_default_python_repl) + sanitize_input: bool = True + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> Any: + """Use the tool.""" + if self.sanitize_input: + query = sanitize_input(query) + return self.python_repl.run(query) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> Any: + """Use the tool asynchronously.""" + if self.sanitize_input: + query = sanitize_input(query) + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, self.run, query) + + return result + + +class PythonInputs(BaseModel): + query: str = Field(description="code snippet to run") + + +class PythonAstREPLTool(BaseTool): + """A tool for running python code in a REPL.""" + + name: str = "python_repl_ast" + description: str = ( + "A Python shell. Use this to execute python commands. " + "Input should be a valid python command. " + "When using this tool, sometimes output is abbreviated - " + "make sure it does not look abbreviated before using it in your answer." 
+ ) + globals: Optional[Dict] = Field(default_factory=dict) + locals: Optional[Dict] = Field(default_factory=dict) + sanitize_input: bool = True + args_schema: Type[BaseModel] = PythonInputs + + @root_validator(pre=True) + def validate_python_version(cls, values: Dict) -> Dict: + """Validate valid python version.""" + if sys.version_info < (3, 9): + raise ValueError( + "This tool relies on Python 3.9 or higher " + "(as it uses new functionality in the `ast` module, " + f"you have Python version: {sys.version}" + ) + return values + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the tool.""" + try: + if self.sanitize_input: + query = sanitize_input(query) + tree = ast.parse(query) + module = ast.Module(tree.body[:-1], type_ignores=[]) + exec(ast.unparse(module), self.globals, self.locals) # type: ignore + module_end = ast.Module(tree.body[-1:], type_ignores=[]) + module_end_str = ast.unparse(module_end) # type: ignore + io_buffer = StringIO() + try: + with redirect_stdout(io_buffer): + ret = eval(module_end_str, self.globals, self.locals) + if ret is None: + return io_buffer.getvalue() + else: + return ret + except Exception: + with redirect_stdout(io_buffer): + exec(module_end_str, self.globals, self.locals) + return io_buffer.getvalue() + except Exception as e: + return "{}: {}".format(type(e).__name__, str(e)) + + async def _arun( + self, + query: str, + run_manager: Optional[AsyncCallbackManagerForToolRun] = None, + ) -> Any: + """Use the tool asynchronously.""" + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, self._run, query) + + return result diff --git a/libs/experimental/langchain_experimental/utilities/python.py b/libs/experimental/langchain_experimental/utilities/python.py new file mode 100644 index 0000000000000..d2f5d2fb72a44 --- /dev/null +++ b/libs/experimental/langchain_experimental/utilities/python.py @@ -0,0 +1,71 @@ +import functools +import logging +import multiprocessing +import sys +from io import StringIO +from typing import Dict, Optional + +from langchain.pydantic_v1 import BaseModel, Field + +logger = logging.getLogger(__name__) + + +@functools.lru_cache(maxsize=None) +def warn_once() -> None: + """Warn once about the dangers of PythonREPL.""" + logger.warning("Python REPL can execute arbitrary code. Use with caution.") + + +class PythonREPL(BaseModel): + """Simulates a standalone Python REPL.""" + + globals: Optional[Dict] = Field(default_factory=dict, alias="_globals") + locals: Optional[Dict] = Field(default_factory=dict, alias="_locals") + + @classmethod + def worker( + cls, + command: str, + globals: Optional[Dict], + locals: Optional[Dict], + queue: multiprocessing.Queue, + ) -> None: + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + try: + exec(command, globals, locals) + sys.stdout = old_stdout + queue.put(mystdout.getvalue()) + except Exception as e: + sys.stdout = old_stdout + queue.put(repr(e)) + + def run(self, command: str, timeout: Optional[int] = None) -> str: + """Run command with own globals/locals and returns anything printed. 
+ Timeout after the specified number of seconds.""" + + # Warn against dangers of PythonREPL + warn_once() + + queue: multiprocessing.Queue = multiprocessing.Queue() + + # Only use multiprocessing if we are enforcing a timeout + if timeout is not None: + # create a Process + p = multiprocessing.Process( + target=self.worker, args=(command, self.globals, self.locals, queue) + ) + + # start it + p.start() + + # wait for the process to finish or kill it after timeout seconds + p.join(timeout) + + if p.is_alive(): + p.terminate() + return "Execution timed out" + else: + self.worker(command, self.globals, self.locals, queue) + # get the result from the worker function + return queue.get() diff --git a/libs/experimental/tests/unit_tests/python/__init__.py b/libs/experimental/tests/unit_tests/python/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/experimental/tests/unit_tests/python/test_python_1.py b/libs/experimental/tests/unit_tests/python/test_python_1.py new file mode 100644 index 0000000000000..46d92c29e6ea5 --- /dev/null +++ b/libs/experimental/tests/unit_tests/python/test_python_1.py @@ -0,0 +1,112 @@ +"""Test functionality of Python REPL.""" +import sys + +import pytest + +from langchain_experimental.tools.python.tool import PythonAstREPLTool, PythonREPLTool +from langchain_experimental.utilities.python import PythonREPL + +_SAMPLE_CODE = """ +``` +def multiply(): + print(5*6) +multiply() +``` +""" + +_AST_SAMPLE_CODE = """ +``` +def multiply(): + return(5*6) +multiply() +``` +""" + +_AST_SAMPLE_CODE_EXECUTE = """ +``` +def multiply(a, b): + return(5*6) +a = 5 +b = 6 + +multiply(a, b) +``` +""" + + +def test_python_repl() -> None: + """Test functionality when globals/locals are not provided.""" + repl = PythonREPL() + + # Run a simple initial command. + repl.run("foo = 1") + assert repl.locals is not None + assert repl.locals["foo"] == 1 + + # Now run a command that accesses `foo` to make sure it still has it. 
+ repl.run("bar = foo * 2") + assert repl.locals is not None + assert repl.locals["bar"] == 2 + + +def test_python_repl_no_previous_variables() -> None: + """Test that it does not have access to variables created outside the scope.""" + foo = 3 # noqa: F841 + repl = PythonREPL() + output = repl.run("print(foo)") + assert output == """NameError("name 'foo' is not defined")""" + + +def test_python_repl_pass_in_locals() -> None: + """Test functionality when passing in locals.""" + _locals = {"foo": 4} + repl = PythonREPL(_locals=_locals) + repl.run("bar = foo * 2") + assert repl.locals is not None + assert repl.locals["bar"] == 8 + + +def test_functionality() -> None: + """Test correct functionality.""" + chain = PythonREPL() + code = "print(1 + 1)" + output = chain.run(code) + assert output == "2\n" + + +def test_functionality_multiline() -> None: + """Test correct functionality for ChatGPT multiline commands.""" + chain = PythonREPL() + tool = PythonREPLTool(python_repl=chain) + output = tool.run(_SAMPLE_CODE) + assert output == "30\n" + + +def test_python_ast_repl_multiline() -> None: + """Test correct functionality for ChatGPT multiline commands.""" + if sys.version_info < (3, 9): + pytest.skip("Python 3.9+ is required for this test") + tool = PythonAstREPLTool() + output = tool.run(_AST_SAMPLE_CODE) + assert output == 30 + + +def test_python_ast_repl_multi_statement() -> None: + """Test correct functionality for ChatGPT multi statement commands.""" + if sys.version_info < (3, 9): + pytest.skip("Python 3.9+ is required for this test") + tool = PythonAstREPLTool() + output = tool.run(_AST_SAMPLE_CODE_EXECUTE) + assert output == 30 + + +def test_function() -> None: + """Test correct functionality.""" + chain = PythonREPL() + code = "def add(a, b): " " return a + b" + output = chain.run(code) + assert output == "" + + code = "print(add(1, 2))" + output = chain.run(code) + assert output == "3\n" diff --git a/libs/experimental/tests/unit_tests/python/test_python_2.py b/libs/experimental/tests/unit_tests/python/test_python_2.py new file mode 100644 index 0000000000000..9a7f2c8d04efb --- /dev/null +++ b/libs/experimental/tests/unit_tests/python/test_python_2.py @@ -0,0 +1,164 @@ +"""Test Python REPL Tools.""" +import sys + +import numpy as np +import pytest + +from langchain_experimental.tools.python.tool import ( + PythonAstREPLTool, + PythonREPLTool, + sanitize_input, +) + + +def test_python_repl_tool_single_input() -> None: + """Test that the python REPL tool works with a single input.""" + tool = PythonREPLTool() + assert tool.is_single_input + assert int(tool.run("print(1 + 1)").strip()) == 2 + + +def test_python_repl_print() -> None: + program = """ +import numpy as np +v1 = np.array([1, 2, 3]) +v2 = np.array([4, 5, 6]) +dot_product = np.dot(v1, v2) +print("The dot product is {:d}.".format(dot_product)) + """ + tool = PythonREPLTool() + assert tool.run(program) == "The dot product is 32.\n" + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_tool_single_input() -> None: + """Test that the python REPL tool works with a single input.""" + tool = PythonAstREPLTool() + assert tool.is_single_input + assert tool.run("1 + 1") == 2 + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." 
+) +def test_python_ast_repl_return() -> None: + program = """ +``` +import numpy as np +v1 = np.array([1, 2, 3]) +v2 = np.array([4, 5, 6]) +dot_product = np.dot(v1, v2) +int(dot_product) +``` + """ + tool = PythonAstREPLTool() + assert tool.run(program) == 32 + + program = """ +```python +import numpy as np +v1 = np.array([1, 2, 3]) +v2 = np.array([4, 5, 6]) +dot_product = np.dot(v1, v2) +int(dot_product) +``` + """ + assert tool.run(program) == 32 + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_print() -> None: + program = """python +string = "racecar" +if string == string[::-1]: + print(string, "is a palindrome") +else: + print(string, "is not a palindrome")""" + tool = PythonAstREPLTool() + assert tool.run(program) == "racecar is a palindrome\n" + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_repl_print_python_backticks() -> None: + program = "`print('`python` is a great language.')`" + tool = PythonAstREPLTool() + assert tool.run(program) == "`python` is a great language.\n" + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_raise_exception() -> None: + data = {"Name": ["John", "Alice"], "Age": [30, 25]} + program = """ +import pandas as pd +df = pd.DataFrame(data) +df['Gender'] + """ + tool = PythonAstREPLTool(locals={"data": data}) + expected_outputs = ( + "KeyError: 'Gender'", + "ModuleNotFoundError: No module named 'pandas'", + ) + assert tool.run(program) in expected_outputs + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_one_line_print() -> None: + program = 'print("The square of {} is {:.2f}".format(3, 3**2))' + tool = PythonAstREPLTool() + assert tool.run(program) == "The square of 3 is 9.00\n" + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_one_line_return() -> None: + arr = np.array([1, 2, 3, 4, 5]) + tool = PythonAstREPLTool(locals={"arr": arr}) + program = "`(arr**2).sum() # Returns sum of squares`" + assert tool.run(program) == 55 + + +@pytest.mark.skipif( + sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." +) +def test_python_ast_repl_one_line_exception() -> None: + program = "[1, 2, 3][4]" + tool = PythonAstREPLTool() + assert tool.run(program) == "IndexError: list index out of range" + + +def test_sanitize_input() -> None: + query = """ + ``` + p = 5 + ``` + """ + expected = "p = 5" + actual = sanitize_input(query) + assert expected == actual + + query = """ + ```python + p = 5 + ``` + """ + expected = "p = 5" + actual = sanitize_input(query) + assert expected == actual + + query = """ + p = 5 + """ + expected = "p = 5" + actual = sanitize_input(query) + assert expected == actual
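
For reference, a minimal usage sketch of the new modules (illustrative only, not part of this patch): it shows how the pandas agent and the timeout-aware PythonREPL added above are expected to be wired together. The import paths follow the module layout introduced in this PR; the ChatOpenAI model, the "titanic.csv" file, and the question string are placeholders, so treat this as a sketch rather than a tested example.

# Illustrative sketch only; assumes pandas, an OpenAI API key, and a local
# "titanic.csv" file. Model name and dataset are placeholders.
import pandas as pd
from langchain.chat_models import ChatOpenAI

from langchain_experimental.agents.agent_toolkits.pandas.base import (
    create_pandas_dataframe_agent,
)
from langchain_experimental.utilities.python import PythonREPL

df = pd.read_csv("titanic.csv")  # placeholder dataset
agent = create_pandas_dataframe_agent(
    ChatOpenAI(temperature=0),  # any BaseLanguageModel should work here
    df,
    verbose=True,
    number_of_head_rows=5,  # rows surfaced to the model via df.head()
)
agent.run("How many rows are there?")

# PythonREPL only spawns a worker process when a timeout is requested;
# otherwise the command runs in-process and captured stdout is returned.
repl = PythonREPL()
print(repl.run("print(1 + 1)"))                 # "2\n"
print(repl.run("while True: pass", timeout=2))  # "Execution timed out"

Keeping the subprocess path behind the timeout argument mirrors the worker/queue design in utilities/python.py above: the common, fast path stays in-process, and multiprocessing is paid for only when a caller actually wants execution to be interruptible.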