diff --git a/.scripts/community_split/script_integrations.sh b/.scripts/community_split/script_integrations.sh index 7fd38cc50990e..941262e119989 100755 --- a/.scripts/community_split/script_integrations.sh +++ b/.scripts/community_split/script_integrations.sh @@ -130,7 +130,6 @@ git grep -l 'from langchain.tools.base' | xargs sed -i '' 's/from langchain.tool git grep -l 'from langchain_community.llms.openai' | xargs sed -i '' 's/from langchain_community.llms.openai/from langchain_openai.llm/g' git grep -l 'from langchain_community.chat_models.openai' | xargs sed -i '' 's/from langchain_community.chat_models.openai/from langchain_openai.chat_model/g' git grep -l 'from langchain_community.embeddings.openai' | xargs sed -i '' 's/from langchain_community.embeddings.openai/from langchain_openai.embedding/g' -git grep -l 'from langchain.utils.json_schema' | xargs sed -i '' 's/from langchain.utils.json_schema/from langchain_core.utils.json_schema/g' cd .. @@ -152,11 +151,10 @@ mv community/langchain_community/embeddings/azure_openai.py partners/openai/lang cp langchain/langchain/utils/openai.py partners/openai/langchain_openai/utils.py cp langchain/langchain/utils/openai_functions.py partners/openai/langchain_openai/functions.py +git add partners core git grep -l 'from langchain.utils.json_schema' | xargs sed -i '' 's/from langchain.utils.json_schema/from langchain_core.utils.json_schema/g' -git add partners core - rm community/langchain_community/chat_models/base.py rm community/langchain_community/llms/base.py rm community/langchain_community/tools/base.py diff --git a/libs/community/langchain_community/adapters/openai.py b/libs/community/langchain_community/adapters/openai.py index 8607468b81d23..0af759ebf5b08 100644 --- a/libs/community/langchain_community/adapters/openai.py +++ b/libs/community/langchain_community/adapters/openai.py @@ -25,6 +25,7 @@ SystemMessage, ToolMessage, ) +from langchain_core.pydantic_v1 import BaseModel from typing_extensions import Literal @@ -38,6 +39,29 @@ async def aenumerate( i += 1 +class IndexableBaseModel(BaseModel): + """Allows a BaseModel to return its fields by string variable indexing""" + + def __getitem__(self, item: str) -> Any: + return getattr(self, item) + + +class Choice(IndexableBaseModel): + message: dict + + +class ChatCompletions(IndexableBaseModel): + choices: List[Choice] + + +class ChoiceChunk(IndexableBaseModel): + delta: dict + + +class ChatCompletionChunk(IndexableBaseModel): + choices: List[ChoiceChunk] + + def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: """Convert a dictionary to a LangChain message. 
@@ -129,7 +153,7 @@ def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMess return [convert_dict_to_message(m) for m in messages] -def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]: +def _convert_message_chunk(chunk: BaseMessageChunk, i: int) -> dict: _dict: Dict[str, Any] = {} if isinstance(chunk, AIMessageChunk): if i == 0: @@ -148,6 +172,11 @@ def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str # This only happens at the end of streams, and OpenAI returns as empty dict if _dict == {"content": ""}: _dict = {} + return _dict + + +def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]: + _dict = _convert_message_chunk(chunk, i) return {"choices": [{"delta": _dict}]} @@ -262,3 +291,109 @@ def convert_messages_for_finetuning( for session in sessions if _has_assistant_message(session) ] + + +class Completions: + """Completion.""" + + @overload + @staticmethod + def create( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[False] = False, + **kwargs: Any, + ) -> ChatCompletions: + ... + + @overload + @staticmethod + def create( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[True], + **kwargs: Any, + ) -> Iterable: + ... + + @staticmethod + def create( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: bool = False, + **kwargs: Any, + ) -> Union[ChatCompletions, Iterable]: + models = importlib.import_module("langchain.chat_models") + model_cls = getattr(models, provider) + model_config = model_cls(**kwargs) + converted_messages = convert_openai_messages(messages) + if not stream: + result = model_config.invoke(converted_messages) + return ChatCompletions( + choices=[Choice(message=convert_message_to_dict(result))] + ) + else: + return ( + ChatCompletionChunk( + choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))] + ) + for i, c in enumerate(model_config.stream(converted_messages)) + ) + + @overload + @staticmethod + async def acreate( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[False] = False, + **kwargs: Any, + ) -> ChatCompletions: + ... + + @overload + @staticmethod + async def acreate( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: Literal[True], + **kwargs: Any, + ) -> AsyncIterator: + ... 
+ + @staticmethod + async def acreate( + messages: Sequence[Dict[str, Any]], + *, + provider: str = "ChatOpenAI", + stream: bool = False, + **kwargs: Any, + ) -> Union[ChatCompletions, AsyncIterator]: + models = importlib.import_module("langchain.chat_models") + model_cls = getattr(models, provider) + model_config = model_cls(**kwargs) + converted_messages = convert_openai_messages(messages) + if not stream: + result = await model_config.ainvoke(converted_messages) + return ChatCompletions( + choices=[Choice(message=convert_message_to_dict(result))] + ) + else: + return ( + ChatCompletionChunk( + choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))] + ) + async for i, c in aenumerate(model_config.astream(converted_messages)) + ) + + +class Chat: + def __init__(self) -> None: + self.completions = Completions() + + +chat = Chat() diff --git a/libs/community/langchain_community/agent_toolkits/__init__.py b/libs/community/langchain_community/agent_toolkits/__init__.py index 258631e07ed2e..9ab2aef75cf12 100644 --- a/libs/community/langchain_community/agent_toolkits/__init__.py +++ b/libs/community/langchain_community/agent_toolkits/__init__.py @@ -34,6 +34,7 @@ from langchain_community.agent_toolkits.json.base import create_json_agent from langchain_community.agent_toolkits.json.toolkit import JsonToolkit from langchain_community.agent_toolkits.multion.toolkit import MultionToolkit +from langchain_community.agent_toolkits.nasa.toolkit import NasaToolkit from langchain_community.agent_toolkits.nla.toolkit import NLAToolkit from langchain_community.agent_toolkits.office365.toolkit import O365Toolkit from langchain_community.agent_toolkits.openapi.base import create_openapi_agent @@ -49,6 +50,7 @@ from langchain_community.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit from langchain_community.agent_toolkits.sql.base import create_sql_agent from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit +from langchain_community.agent_toolkits.steam.toolkit import SteamToolkit from langchain_community.agent_toolkits.vectorstore.base import ( create_vectorstore_agent, create_vectorstore_router_agent, @@ -94,12 +96,14 @@ def __getattr__(name: str) -> Any: "JiraToolkit", "JsonToolkit", "MultionToolkit", + "NasaToolkit", "NLAToolkit", "O365Toolkit", "OpenAPIToolkit", "PlayWrightBrowserToolkit", "PowerBIToolkit", "SlackToolkit", + "SteamToolkit", "SQLDatabaseToolkit", "SparkSQLToolkit", "VectorStoreInfo", diff --git a/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py b/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py index 97acfd615d6d7..9d05b5f770750 100644 --- a/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py +++ b/libs/community/langchain_community/agent_toolkits/conversational_retrieval/openai_functions.py @@ -5,7 +5,6 @@ from langchain_core.messages import SystemMessage from langchain_core.prompts.chat import MessagesPlaceholder from langchain_core.tools import BaseTool -from langchain_openai.chat_model import ChatOpenAI from langchain_community.agents.agent import AgentExecutor from langchain_community.agents.openai_functions_agent.agent_token_buffer_memory import ( @@ -57,8 +56,6 @@ def create_conversational_retrieval_agent( An agent executor initialized appropriately """ - if not isinstance(llm, ChatOpenAI): - raise ValueError("Only supported with ChatOpenAI models.") if remember_intermediate_steps: memory: BaseMemory = 
AgentTokenBufferMemory( memory_key=memory_key, llm=llm, max_token_limit=max_token_limit diff --git a/libs/community/langchain_community/agent_toolkits/github/toolkit.py b/libs/community/langchain_community/agent_toolkits/github/toolkit.py index ee055b145e65c..7c85504e1c62e 100644 --- a/libs/community/langchain_community/agent_toolkits/github/toolkit.py +++ b/libs/community/langchain_community/agent_toolkits/github/toolkit.py @@ -1,22 +1,128 @@ """GitHub Toolkit.""" from typing import Dict, List +from langchain_core.pydantic_v1 import BaseModel, Field + from langchain_community.agent_toolkits.base import BaseToolkit from langchain_community.tools import BaseTool from langchain_community.tools.github.prompt import ( COMMENT_ON_ISSUE_PROMPT, + CREATE_BRANCH_PROMPT, CREATE_FILE_PROMPT, CREATE_PULL_REQUEST_PROMPT, + CREATE_REVIEW_REQUEST_PROMPT, DELETE_FILE_PROMPT, + GET_FILES_FROM_DIRECTORY_PROMPT, GET_ISSUE_PROMPT, GET_ISSUES_PROMPT, + GET_PR_PROMPT, + LIST_BRANCHES_IN_REPO_PROMPT, + LIST_PRS_PROMPT, + LIST_PULL_REQUEST_FILES, + OVERVIEW_EXISTING_FILES_BOT_BRANCH, + OVERVIEW_EXISTING_FILES_IN_MAIN, READ_FILE_PROMPT, + SEARCH_CODE_PROMPT, + SEARCH_ISSUES_AND_PRS_PROMPT, + SET_ACTIVE_BRANCH_PROMPT, UPDATE_FILE_PROMPT, ) from langchain_community.tools.github.tool import GitHubAction from langchain_community.utilities.github import GitHubAPIWrapper +class NoInput(BaseModel): + no_input: str = Field("", description="No input required, e.g. `` (empty string).") + + +class GetIssue(BaseModel): + issue_number: int = Field(0, description="Issue number as an integer, e.g. `42`") + + +class CommentOnIssue(BaseModel): + input: str = Field(..., description="Follow the required formatting.") + + +class GetPR(BaseModel): + pr_number: int = Field(0, description="The PR number as an integer, e.g. `12`") + + +class CreatePR(BaseModel): + formatted_pr: str = Field(..., description="Follow the required formatting.") + + +class CreateFile(BaseModel): + formatted_file: str = Field(..., description="Follow the required formatting.") + + +class ReadFile(BaseModel): + formatted_filepath: str = Field( + ..., + description=( + "The full file path of the file you would like to read where the " + "path must NOT start with a slash, e.g. `some_dir/my_file.py`." + ), + ) + + +class UpdateFile(BaseModel): + formatted_file_update: str = Field( + ..., description="Strictly follow the provided rules." + ) + + +class DeleteFile(BaseModel): + formatted_filepath: str = Field( + ..., + description=( + "The full file path of the file you would like to delete" + " where the path must NOT start with a slash, e.g." + " `some_dir/my_file.py`. Only input a string," + " not the param name." + ), + ) + + +class DirectoryPath(BaseModel): + input: str = Field( + "", + description=( + "The path of the directory, e.g. `some_dir/inner_dir`." + " Only input a string, do not include the parameter name." + ), + ) + + +class BranchName(BaseModel): + branch_name: str = Field( + ..., description="The name of the branch, e.g. `my_branch`." + ) + + +class SearchCode(BaseModel): + search_query: str = Field( + ..., + description=( + "A keyword-focused natural language search" + "query for code, e.g. `MyFunctionName()`." + ), + ) + + +class CreateReviewRequest(BaseModel): + username: str = Field( + ..., + description="GitHub username of the user being requested, e.g. `my_username`.", + ) + + +class SearchIssuesAndPRs(BaseModel): + search_query: str = Field( + ..., + description="Natural language search query, e.g. 
`My issue title or topic`.", + ) + + class GitHubToolkit(BaseToolkit): """GitHub Toolkit. @@ -41,41 +147,127 @@ def from_github_api_wrapper( "mode": "get_issues", "name": "Get Issues", "description": GET_ISSUES_PROMPT, + "args_schema": NoInput, }, { "mode": "get_issue", "name": "Get Issue", "description": GET_ISSUE_PROMPT, + "args_schema": GetIssue, }, { "mode": "comment_on_issue", "name": "Comment on Issue", "description": COMMENT_ON_ISSUE_PROMPT, + "args_schema": CommentOnIssue, + }, + { + "mode": "list_open_pull_requests", + "name": "List open pull requests (PRs)", + "description": LIST_PRS_PROMPT, + "args_schema": NoInput, + }, + { + "mode": "get_pull_request", + "name": "Get Pull Request", + "description": GET_PR_PROMPT, + "args_schema": GetPR, + }, + { + "mode": "list_pull_request_files", + "name": "Overview of files included in PR", + "description": LIST_PULL_REQUEST_FILES, + "args_schema": GetPR, }, { "mode": "create_pull_request", "name": "Create Pull Request", "description": CREATE_PULL_REQUEST_PROMPT, + "args_schema": CreatePR, + }, + { + "mode": "list_pull_request_files", + "name": "List Pull Requests' Files", + "description": LIST_PULL_REQUEST_FILES, + "args_schema": GetPR, }, { "mode": "create_file", "name": "Create File", "description": CREATE_FILE_PROMPT, + "args_schema": CreateFile, }, { "mode": "read_file", "name": "Read File", "description": READ_FILE_PROMPT, + "args_schema": ReadFile, }, { "mode": "update_file", "name": "Update File", "description": UPDATE_FILE_PROMPT, + "args_schema": UpdateFile, }, { "mode": "delete_file", "name": "Delete File", "description": DELETE_FILE_PROMPT, + "args_schema": DeleteFile, + }, + { + "mode": "list_files_in_main_branch", + "name": "Overview of existing files in Main branch", + "description": OVERVIEW_EXISTING_FILES_IN_MAIN, + "args_schema": NoInput, + }, + { + "mode": "list_files_in_bot_branch", + "name": "Overview of files in current working branch", + "description": OVERVIEW_EXISTING_FILES_BOT_BRANCH, + "args_schema": NoInput, + }, + { + "mode": "list_branches_in_repo", + "name": "List branches in this repository", + "description": LIST_BRANCHES_IN_REPO_PROMPT, + "args_schema": NoInput, + }, + { + "mode": "set_active_branch", + "name": "Set active branch", + "description": SET_ACTIVE_BRANCH_PROMPT, + "args_schema": BranchName, + }, + { + "mode": "create_branch", + "name": "Create a new branch", + "description": CREATE_BRANCH_PROMPT, + "args_schema": BranchName, + }, + { + "mode": "get_files_from_directory", + "name": "Get files from a directory", + "description": GET_FILES_FROM_DIRECTORY_PROMPT, + "args_schema": DirectoryPath, + }, + { + "mode": "search_issues_and_prs", + "name": "Search issues and pull requests", + "description": SEARCH_ISSUES_AND_PRS_PROMPT, + "args_schema": SearchIssuesAndPRs, + }, + { + "mode": "search_code", + "name": "Search code", + "description": SEARCH_CODE_PROMPT, + "args_schema": SearchCode, + }, + { + "mode": "create_review_request", + "name": "Create review request", + "description": CREATE_REVIEW_REQUEST_PROMPT, + "args_schema": CreateReviewRequest, }, ] tools = [ @@ -84,6 +276,7 @@ def from_github_api_wrapper( description=action["description"], mode=action["mode"], api_wrapper=github_api_wrapper, + args_schema=action.get("args_schema", None), ) for action in operations ] diff --git a/libs/community/langchain_community/agent_toolkits/nasa/__init__.py b/libs/community/langchain_community/agent_toolkits/nasa/__init__.py new file mode 100644 index 0000000000000..a13c3ec706c6d --- /dev/null +++ 
b/libs/community/langchain_community/agent_toolkits/nasa/__init__.py @@ -0,0 +1 @@ +"""NASA Toolkit""" diff --git a/libs/community/langchain_community/agent_toolkits/nasa/toolkit.py b/libs/community/langchain_community/agent_toolkits/nasa/toolkit.py new file mode 100644 index 0000000000000..46edd98af3fb0 --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/nasa/toolkit.py @@ -0,0 +1,57 @@ +from typing import Dict, List + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.nasa.prompt import ( + NASA_CAPTIONS_PROMPT, + NASA_MANIFEST_PROMPT, + NASA_METADATA_PROMPT, + NASA_SEARCH_PROMPT, +) +from langchain_community.tools.nasa.tool import NasaAction +from langchain_community.utilities.nasa import NasaAPIWrapper + + +class NasaToolkit(BaseToolkit): + """Nasa Toolkit.""" + + tools: List[BaseTool] = [] + + @classmethod + def from_nasa_api_wrapper(cls, nasa_api_wrapper: NasaAPIWrapper) -> "NasaToolkit": + operations: List[Dict] = [ + { + "mode": "search_media", + "name": "Search NASA Image and Video Library media", + "description": NASA_SEARCH_PROMPT, + }, + { + "mode": "get_media_metadata_manifest", + "name": "Get NASA Image and Video Library media metadata manifest", + "description": NASA_MANIFEST_PROMPT, + }, + { + "mode": "get_media_metadata_location", + "name": "Get NASA Image and Video Library media metadata location", + "description": NASA_METADATA_PROMPT, + }, + { + "mode": "get_video_captions_location", + "name": "Get NASA Image and Video Library video captions location", + "description": NASA_CAPTIONS_PROMPT, + }, + ] + tools = [ + NasaAction( + name=action["name"], + description=action["description"], + mode=action["mode"], + api_wrapper=nasa_api_wrapper, + ) + for action in operations + ] + return cls(tools=tools) + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return self.tools diff --git a/libs/community/langchain_community/agent_toolkits/openapi/planner.py b/libs/community/langchain_community/agent_toolkits/openapi/planner.py index 0e4f323306376..cfe28a1787de3 100644 --- a/libs/community/langchain_community/agent_toolkits/openapi/planner.py +++ b/libs/community/langchain_community/agent_toolkits/openapi/planner.py @@ -36,6 +36,7 @@ from langchain_community.agents.mrkl.base import ZeroShotAgent from langchain_community.chains.llm import LLMChain from langchain_community.memory import ReadOnlySharedMemory +from langchain_community.output_parsers.json import parse_json_markdown from langchain_community.tools.requests.tool import BaseRequestsTool from langchain_community.utilities.requests import RequestsWrapper @@ -79,7 +80,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool): def _run(self, text: str) -> str: try: - data = json.loads(text) + data = parse_json_markdown(text) except json.JSONDecodeError as e: raise e data_params = data.get("params") @@ -109,7 +110,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool): def _run(self, text: str) -> str: try: - data = json.loads(text) + data = parse_json_markdown(text) except json.JSONDecodeError as e: raise e response = self.requests_wrapper.post(data["url"], data["data"]) @@ -138,7 +139,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool): def _run(self, text: str) -> str: try: - data = json.loads(text) + data = parse_json_markdown(text) except json.JSONDecodeError as e: raise e response = self.requests_wrapper.patch(data["url"], data["data"]) @@ -167,7 
+168,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool): def _run(self, text: str) -> str: try: - data = json.loads(text) + data = parse_json_markdown(text) except json.JSONDecodeError as e: raise e response = self.requests_wrapper.put(data["url"], data["data"]) @@ -197,7 +198,7 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool): def _run(self, text: str) -> str: try: - data = json.loads(text) + data = parse_json_markdown(text) except json.JSONDecodeError as e: raise e response = self.requests_wrapper.delete(data["url"]) diff --git a/libs/community/langchain_community/agent_toolkits/steam/__init__.py b/libs/community/langchain_community/agent_toolkits/steam/__init__.py new file mode 100644 index 0000000000000..f99981082424e --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/steam/__init__.py @@ -0,0 +1 @@ +"""Steam Toolkit.""" diff --git a/libs/community/langchain_community/agent_toolkits/steam/toolkit.py b/libs/community/langchain_community/agent_toolkits/steam/toolkit.py new file mode 100644 index 0000000000000..1fe57b032bf7f --- /dev/null +++ b/libs/community/langchain_community/agent_toolkits/steam/toolkit.py @@ -0,0 +1,48 @@ +"""Steam Toolkit.""" +from typing import List + +from langchain_community.agent_toolkits.base import BaseToolkit +from langchain_community.tools import BaseTool +from langchain_community.tools.steam.prompt import ( + STEAM_GET_GAMES_DETAILS, + STEAM_GET_RECOMMENDED_GAMES, +) +from langchain_community.tools.steam.tool import SteamWebAPIQueryRun +from langchain_community.utilities.steam import SteamWebAPIWrapper + + +class SteamToolkit(BaseToolkit): + """Steam Toolkit.""" + + tools: List[BaseTool] = [] + + @classmethod + def from_steam_api_wrapper( + cls, steam_api_wrapper: SteamWebAPIWrapper + ) -> "SteamToolkit": + operations: List[dict] = [ + { + "mode": "get_games_details", + "name": "Get Games Details", + "description": STEAM_GET_GAMES_DETAILS, + }, + { + "mode": "get_recommended_games", + "name": "Get Recommended Games", + "description": STEAM_GET_RECOMMENDED_GAMES, + }, + ] + tools = [ + SteamWebAPIQueryRun( + name=action["name"], + description=action["description"], + mode=action["mode"], + api_wrapper=steam_api_wrapper, + ) + for action in operations + ] + return cls(tools=tools) + + def get_tools(self) -> List[BaseTool]: + """Get the tools in the toolkit.""" + return self.tools diff --git a/libs/community/langchain_community/callbacks/tracers/comet.py b/libs/community/langchain_community/callbacks/tracers/comet.py new file mode 100644 index 0000000000000..45545736c4c17 --- /dev/null +++ b/libs/community/langchain_community/callbacks/tracers/comet.py @@ -0,0 +1,137 @@ +from types import ModuleType, SimpleNamespace +from typing import TYPE_CHECKING, Any, Callable, Dict + +from langchain_core.callbacks.tracers.base import BaseTracer + +if TYPE_CHECKING: + from uuid import UUID + + from comet_llm import Span + from comet_llm.chains.chain import Chain + from langchain_core.callbacks.tracers.schemas import Run + + +def _get_run_type(run: "Run") -> str: + if isinstance(run.run_type, str): + return run.run_type + elif hasattr(run.run_type, "value"): + return run.run_type.value + else: + return str(run.run_type) + + +def import_comet_llm_api() -> SimpleNamespace: + """Import comet_llm api and raise an error if it is not installed.""" + try: + from comet_llm import ( + experiment_info, # noqa: F401 + flush, # noqa: F401 + ) + from comet_llm.chains import api as chain_api # noqa: F401 + from comet_llm.chains import ( 
+            chain,  # noqa: F401
+            span,  # noqa: F401
+        )
+
+    except ImportError:
+        raise ImportError(
+            "To use the CometTracer you need to have the "
+            "`comet_llm>=2.0.0` python package installed. Please install it with"
+            " `pip install -U comet_llm`"
+        )
+    return SimpleNamespace(
+        chain=chain,
+        span=span,
+        chain_api=chain_api,
+        experiment_info=experiment_info,
+        flush=flush,
+    )
+
+
+class CometTracer(BaseTracer):
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self._span_map: Dict["UUID", "Span"] = {}
+        self._chains_map: Dict["UUID", "Chain"] = {}
+        self._initialize_comet_modules()
+
+    def _initialize_comet_modules(self) -> None:
+        comet_llm_api = import_comet_llm_api()
+        self._chain: ModuleType = comet_llm_api.chain
+        self._span: ModuleType = comet_llm_api.span
+        self._chain_api: ModuleType = comet_llm_api.chain_api
+        self._experiment_info: ModuleType = comet_llm_api.experiment_info
+        self._flush: Callable[[], None] = comet_llm_api.flush
+
+    def _persist_run(self, run: "Run") -> None:
+        chain_ = self._chains_map[run.id]
+        chain_.set_outputs(outputs=run.outputs)
+        self._chain_api.log_chain(chain_)
+
+    def _process_start_trace(self, run: "Run") -> None:
+        if not run.parent_run_id:
+            # This is the first run, which maps to a chain
+            chain_: "Chain" = self._chain.Chain(
+                inputs=run.inputs,
+                metadata=None,
+                experiment_info=self._experiment_info.get(),
+            )
+            self._chains_map[run.id] = chain_
+        else:
+            span: "Span" = self._span.Span(
+                inputs=run.inputs,
+                category=_get_run_type(run),
+                metadata=run.extra,
+                name=run.name,
+            )
+            span.__api__start__(self._chains_map[run.parent_run_id])
+            self._chains_map[run.id] = self._chains_map[run.parent_run_id]
+            self._span_map[run.id] = span
+
+    def _process_end_trace(self, run: "Run") -> None:
+        if not run.parent_run_id:
+            pass
+            # Langchain will call _persist_run for us
+        else:
+            span = self._span_map[run.id]
+            span.set_outputs(outputs=run.outputs)
+            span.__api__end__()
+
+    def flush(self) -> None:
+        self._flush()
+
+    def _on_llm_start(self, run: "Run") -> None:
+        """Process the LLM Run upon start."""
+        self._process_start_trace(run)
+
+    def _on_llm_end(self, run: "Run") -> None:
+        """Process the LLM Run."""
+        self._process_end_trace(run)
+
+    def _on_llm_error(self, run: "Run") -> None:
+        """Process the LLM Run upon error."""
+        self._process_end_trace(run)
+
+    def _on_chain_start(self, run: "Run") -> None:
+        """Process the Chain Run upon start."""
+        self._process_start_trace(run)
+
+    def _on_chain_end(self, run: "Run") -> None:
+        """Process the Chain Run."""
+        self._process_end_trace(run)
+
+    def _on_chain_error(self, run: "Run") -> None:
+        """Process the Chain Run upon error."""
+        self._process_end_trace(run)
+
+    def _on_tool_start(self, run: "Run") -> None:
+        """Process the Tool Run upon start."""
+        self._process_start_trace(run)
+
+    def _on_tool_end(self, run: "Run") -> None:
+        """Process the Tool Run."""
+        self._process_end_trace(run)
+
+    def _on_tool_error(self, run: "Run") -> None:
+        """Process the Tool Run upon error."""
+        self._process_end_trace(run)
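The new `CometTracer` maps LangChain runs onto Comet's chain/span hierarchy: the root run opens a `Chain`, nested runs open `Span`s, and `_persist_run` logs the finished chain. A minimal usage sketch, assuming `comet_llm` is installed and Comet credentials are configured in the environment; the invoked runnable is illustrative:

```python
from langchain_community.callbacks.tracers.comet import CometTracer

tracer = CometTracer()

# Pass the tracer through the standard callbacks mechanism; each invocation is
# then logged to Comet as a chain with nested spans per child run, e.g.:
# llm.invoke("Hello", config={"callbacks": [tracer]})

# Force-send any buffered traces before the process exits.
tracer.flush()
```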
diff --git a/libs/community/langchain_community/document_loaders/__init__.py b/libs/community/langchain_community/document_loaders/__init__.py
index a046880e15286..48143a30b664e 100644
--- a/libs/community/langchain_community/document_loaders/__init__.py
+++ b/libs/community/langchain_community/document_loaders/__init__.py
@@ -66,6 +66,7 @@ from langchain_community.document_loaders.concurrent import ConcurrentLoader
 from langchain_community.document_loaders.confluence import ConfluenceLoader
 from langchain_community.document_loaders.conllu import CoNLLULoader
+from langchain_community.document_loaders.couchbase import CouchbaseLoader
 from langchain_community.document_loaders.csv_loader import (
     CSVLoader,
     UnstructuredCSVLoader,
@@ -84,7 +85,6 @@
     OutlookMessageLoader,
     UnstructuredEmailLoader,
 )
-from langchain_community.document_loaders.embaas import EmbaasBlobLoader, EmbaasLoader
 from langchain_community.document_loaders.epub import UnstructuredEPubLoader
 from langchain_community.document_loaders.etherscan import EtherscanLoader
 from langchain_community.document_loaders.evernote import EverNoteLoader
@@ -265,6 +265,7 @@
     "CollegeConfidentialLoader",
     "ConcurrentLoader",
     "ConfluenceLoader",
+    "CouchbaseLoader",
     "CubeSemanticLoader",
     "DataFrameLoader",
     "DatadogLogsLoader",
@@ -276,8 +277,6 @@
     "Docx2txtLoader",
     "DropboxLoader",
     "DuckDBLoader",
-    "EmbaasBlobLoader",
-    "EmbaasLoader",
     "EtherscanLoader",
     "EverNoteLoader",
     "FacebookChatLoader",
diff --git a/libs/community/langchain_community/document_loaders/blob_loaders/schema.py b/libs/community/langchain_community/document_loaders/blob_loaders/schema.py
index 9d1e737e3745a..c2f88a14015ab 100644
--- a/libs/community/langchain_community/document_loaders/blob_loaders/schema.py
+++ b/libs/community/langchain_community/document_loaders/blob_loaders/schema.py
@@ -11,9 +11,9 @@
 from abc import ABC, abstractmethod
 from io import BufferedReader, BytesIO
 from pathlib import PurePath
-from typing import Any, Generator, Iterable, Mapping, Optional, Union
+from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Union, cast
 
-from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
 
 PathLike = Union[str, PurePath]
 
@@ -28,14 +28,20 @@ class Blob(BaseModel):
     Inspired by: https://developer.mozilla.org/en-US/docs/Web/API/Blob
     """
 
-    data: Union[bytes, str, None]  # Raw data
-    mimetype: Optional[str] = None  # Not to be confused with a file extension
-    encoding: str = "utf-8"  # Use utf-8 as default encoding, if decoding to string
-    # Location where the original content was found
-    # Represent location on the local file system
-    # Useful for situations where downstream code assumes it must work with file paths
-    # rather than in-memory content.
+    data: Union[bytes, str, None]
+    """Raw data associated with the blob."""
+    mimetype: Optional[str] = None
+    """MimeType not to be confused with a file extension."""
+    encoding: str = "utf-8"
+    """Encoding to use if decoding the bytes into a string.
+
+    Use utf-8 as default encoding, if decoding to string.
+    """
     path: Optional[PathLike] = None
+    """Location where the original content was found."""
+
+    metadata: Dict[str, Any] = Field(default_factory=dict)
+    """Metadata about the blob (e.g., source)"""
 
     class Config:
         arbitrary_types_allowed = True
@@ -43,7 +49,15 @@ class Config:
 
     @property
     def source(self) -> Optional[str]:
-        """The source location of the blob as string if known otherwise none."""
+        """The source location of the blob as a string if known, otherwise None.
+
+        If a path is associated with the blob, it defaults to the path location,
+        unless a metadata field called "source" is explicitly set, in which case
+        that value is used instead.
+        """
+        if self.metadata and "source" in self.metadata:
+            return cast(Optional[str], self.metadata["source"])
         return str(self.path) if self.path else None
 
     @root_validator(pre=True)
@@ -96,6 +110,7 @@ def from_path(
         encoding: str = "utf-8",
         mime_type: Optional[str] = None,
         guess_type: bool = True,
+        metadata: Optional[dict] = None,
     ) -> Blob:
         """Load the blob from a path like object.
 
@@ -105,6 +120,7 @@ def from_path(
             mime_type: if provided, will be set as the mime-type of the data
             guess_type: If True, the mimetype will be guessed from the file extension,
                         if a mime-type was not provided
+            metadata: Metadata to associate with the blob
 
         Returns:
             Blob instance
@@ -115,7 +131,13 @@ def from_path(
             _mimetype = mime_type
         # We do not load the data immediately, instead we treat the blob as a
        # reference to the underlying data.
-        return cls(data=None, mimetype=_mimetype, encoding=encoding, path=path)
+        return cls(
+            data=None,
+            mimetype=_mimetype,
+            encoding=encoding,
+            path=path,
+            metadata=metadata if metadata is not None else {},
+        )
 
     @classmethod
     def from_data(
@@ -125,6 +147,7 @@ def from_data(
         encoding: str = "utf-8",
         mime_type: Optional[str] = None,
         path: Optional[str] = None,
+        metadata: Optional[dict] = None,
     ) -> Blob:
         """Initialize the blob from in-memory data.
 
@@ -133,11 +156,18 @@ def from_data(
             encoding: Encoding to use if decoding the bytes into a string
             mime_type: if provided, will be set as the mime-type of the data
             path: if provided, will be set as the source from which the data came
+            metadata: Metadata to associate with the blob
 
         Returns:
             Blob instance
         """
-        return cls(data=data, mimetype=mime_type, encoding=encoding, path=path)
+        return cls(
+            data=data,
+            mimetype=mime_type,
+            encoding=encoding,
+            path=path,
+            metadata=metadata if metadata is not None else {},
+        )
 
     def __repr__(self) -> str:
         """Define the blob representation."""
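The new `metadata` field changes how `Blob.source` resolves, which downstream loaders rely on. A quick sketch of the behavior added above (the paths are illustrative; no file I/O happens because `from_path` only stores a reference to the data):

```python
from langchain_community.document_loaders.blob_loaders import Blob

blob = Blob.from_path(
    "data/report.pdf",
    metadata={"source": "s3://bucket/report.pdf"},
)
# An explicit "source" metadata entry takes precedence over the path...
assert blob.source == "s3://bucket/report.pdf"

# ...while without metadata, source still falls back to the path.
assert Blob.from_path("data/report.pdf").source == "data/report.pdf"
```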
diff --git a/libs/community/langchain_community/document_loaders/couchbase.py b/libs/community/langchain_community/document_loaders/couchbase.py
new file mode 100644
index 0000000000000..fabc0a7398744
--- /dev/null
+++ b/libs/community/langchain_community/document_loaders/couchbase.py
@@ -0,0 +1,100 @@
+import logging
+from typing import Iterator, List, Optional
+
+from langchain_core.documents import Document
+
+from langchain_community.document_loaders.base import BaseLoader
+
+logger = logging.getLogger(__name__)
+
+
+class CouchbaseLoader(BaseLoader):
+    """Load documents from `Couchbase`.
+
+    Each document represents one row of the result. The `page_content_fields`
+    are written into the `page_content` of the document. The `metadata_fields`
+    are written into the `metadata` of the document. By default, all columns
+    are written into the `page_content` and none into the `metadata`.
+    """
+
+    def __init__(
+        self,
+        connection_string: str,
+        db_username: str,
+        db_password: str,
+        query: str,
+        *,
+        page_content_fields: Optional[List[str]] = None,
+        metadata_fields: Optional[List[str]] = None,
+    ) -> None:
+        """Initialize Couchbase document loader.
+
+        Args:
+            connection_string (str): The connection string to the Couchbase cluster.
+            db_username (str): The username to connect to the Couchbase cluster.
+            db_password (str): The password to connect to the Couchbase cluster.
+            query (str): The SQL++ query to execute.
+            page_content_fields (Optional[List[str]]): The columns to write into the
+                `page_content` field of the document. By default, all columns are
+                written.
+            metadata_fields (Optional[List[str]]): The columns to write into the
+                `metadata` field of the document. By default, no columns are written.
+        """
+        try:
+            from couchbase.auth import PasswordAuthenticator
+            from couchbase.cluster import Cluster
+            from couchbase.options import ClusterOptions
+        except ImportError as e:
+            raise ImportError(
+                "Could not import couchbase package. "
+                "Please install couchbase SDK with `pip install couchbase`."
+            ) from e
+        if not connection_string:
+            raise ValueError("connection_string must be provided.")
+
+        if not db_username:
+            raise ValueError("db_username must be provided.")
+
+        if not db_password:
+            raise ValueError("db_password must be provided.")
+
+        auth = PasswordAuthenticator(
+            db_username,
+            db_password,
+        )
+
+        self.cluster: Cluster = Cluster(connection_string, ClusterOptions(auth))
+        self.query = query
+        self.page_content_fields = page_content_fields
+        self.metadata_fields = metadata_fields
+
+    def load(self) -> List[Document]:
+        """Load Couchbase data into Document objects."""
+        return list(self.lazy_load())
+
+    def lazy_load(self) -> Iterator[Document]:
+        """Load Couchbase data into Document objects lazily."""
+        from datetime import timedelta
+
+        # Ensure connection to Couchbase cluster
+        self.cluster.wait_until_ready(timedelta(seconds=5))
+
+        # Run SQL++ Query
+        result = self.cluster.query(self.query)
+        for row in result:
+            metadata_fields = self.metadata_fields
+            page_content_fields = self.page_content_fields
+
+            if not page_content_fields:
+                page_content_fields = list(row.keys())
+
+            if not metadata_fields:
+                metadata_fields = []
+
+            metadata = {field: row[field] for field in metadata_fields}
+
+            document = "\n".join(
+                f"{k}: {v}" for k, v in row.items() if k in page_content_fields
+            )
+
+            yield Document(page_content=document, metadata=metadata)
diff --git a/libs/langchain/langchain/document_loaders/embaas.py b/libs/langchain/langchain/document_loaders/embaas.py
index a98eb21390f25..4c2fe9d9deb5b 100644
--- a/libs/langchain/langchain/document_loaders/embaas.py
+++ b/libs/langchain/langchain/document_loaders/embaas.py
@@ -1,244 +1,17 @@
-import base64
-import warnings
-from typing import Any, Dict, Iterator, List, Optional
-
-import requests
-from langchain_core.documents import Document
-from langchain_core.pydantic_v1 import BaseModel, root_validator, validator
-from langchain_core.utils import get_from_dict_or_env
-from typing_extensions import NotRequired, TypedDict
-
-from langchain_community.document_loaders.base import BaseBlobParser, BaseLoader
-from langchain_community.document_loaders.blob_loaders import Blob
-from langchain_community.text_splitter import TextSplitter
-
-EMBAAS_DOC_API_URL = "https://api.embaas.io/v1/document/extract-text/bytes/"
-
-
-class EmbaasDocumentExtractionParameters(TypedDict):
-    """Parameters for the embaas document extraction API."""
-
-    mime_type: NotRequired[str]
-    """The mime type of the document."""
-    file_extension: NotRequired[str]
-    """The file extension of the document."""
-    file_name: NotRequired[str]
-    """The file name of the document."""
-
-    should_chunk: NotRequired[bool]
-    """Whether to chunk the document into pages."""
-    chunk_size: NotRequired[int]
-    """The maximum size of the text chunks."""
-    chunk_overlap: NotRequired[int]
-    """The maximum overlap allowed between chunks."""
-    chunk_splitter: NotRequired[str]
-    """The text splitter class name for creating chunks."""
-    separators: NotRequired[List[str]]
-    """The separators for chunks."""
-
-    
should_embed: NotRequired[bool] - """Whether to create embeddings for the document in the response.""" - model: NotRequired[str] - """The model to pass to the Embaas document extraction API.""" - instruction: NotRequired[str] - """The instruction to pass to the Embaas document extraction API.""" - - -class EmbaasDocumentExtractionPayload(EmbaasDocumentExtractionParameters): - """Payload for the Embaas document extraction API.""" - - bytes: str - """The base64 encoded bytes of the document to extract text from.""" - - -class BaseEmbaasLoader(BaseModel): - """Base loader for `Embaas` document extraction API.""" - - embaas_api_key: Optional[str] = None - """The API key for the Embaas document extraction API.""" - api_url: str = EMBAAS_DOC_API_URL - """The URL of the Embaas document extraction API.""" - params: EmbaasDocumentExtractionParameters = EmbaasDocumentExtractionParameters() - """Additional parameters to pass to the Embaas document extraction API.""" - - @root_validator(pre=True) - def validate_environment(cls, values: Dict) -> Dict: - """Validate that api key and python package exists in environment.""" - embaas_api_key = get_from_dict_or_env( - values, "embaas_api_key", "EMBAAS_API_KEY" - ) - values["embaas_api_key"] = embaas_api_key - return values - - -class EmbaasBlobLoader(BaseEmbaasLoader, BaseBlobParser): - """Load `Embaas` blob. - - To use, you should have the - environment variable ``EMBAAS_API_KEY`` set with your API key, or pass - it as a named parameter to the constructor. - - Example: - .. code-block:: python - - # Default parsing - from langchain_community.document_loaders.embaas import EmbaasBlobLoader - loader = EmbaasBlobLoader() - blob = Blob.from_path(path="example.mp3") - documents = loader.parse(blob=blob) - - # Custom api parameters (create embeddings automatically) - from langchain_community.document_loaders.embaas import EmbaasBlobLoader - loader = EmbaasBlobLoader( - params={ - "should_embed": True, - "model": "e5-large-v2", - "chunk_size": 256, - "chunk_splitter": "CharacterTextSplitter" - } - ) - blob = Blob.from_path(path="example.pdf") - documents = loader.parse(blob=blob) - """ - - def lazy_parse(self, blob: Blob) -> Iterator[Document]: - """Parses the blob lazily. - - Args: - blob: The blob to parse. 
- """ - yield from self._get_documents(blob=blob) - - @staticmethod - def _api_response_to_documents(chunks: List[Dict[str, Any]]) -> List[Document]: - """Convert the API response to a list of documents.""" - docs = [] - for chunk in chunks: - metadata = chunk["metadata"] - if chunk.get("embedding", None) is not None: - metadata["embedding"] = chunk["embedding"] - doc = Document(page_content=chunk["text"], metadata=metadata) - docs.append(doc) - - return docs - - def _generate_payload(self, blob: Blob) -> EmbaasDocumentExtractionPayload: - """Generates payload for the API request.""" - base64_byte_str = base64.b64encode(blob.as_bytes()).decode() - payload: EmbaasDocumentExtractionPayload = EmbaasDocumentExtractionPayload( - bytes=base64_byte_str, - # Workaround for mypy issue: https://github.com/python/mypy/issues/9408 - # type: ignore - **self.params, - ) - - if blob.mimetype is not None and payload.get("mime_type", None) is None: - payload["mime_type"] = blob.mimetype - - return payload - - def _handle_request( - self, payload: EmbaasDocumentExtractionPayload - ) -> List[Document]: - """Sends a request to the embaas API and handles the response.""" - headers = { - "Authorization": f"Bearer {self.embaas_api_key}", - "Content-Type": "application/json", - } - - response = requests.post(self.api_url, headers=headers, json=payload) - response.raise_for_status() - - parsed_response = response.json() - return EmbaasBlobLoader._api_response_to_documents( - chunks=parsed_response["data"]["chunks"] - ) - - def _get_documents(self, blob: Blob) -> Iterator[Document]: - """Get the documents from the blob.""" - payload = self._generate_payload(blob=blob) - - try: - documents = self._handle_request(payload=payload) - except requests.exceptions.RequestException as e: - if e.response is None or not e.response.text: - raise ValueError( - f"Error raised by Embaas document text extraction API: {e}" - ) - - parsed_response = e.response.json() - if "message" in parsed_response: - raise ValueError( - f"Validation Error raised by Embaas document text extraction API:" - f" {parsed_response['message']}" - ) - raise - - yield from documents - - -class EmbaasLoader(BaseEmbaasLoader, BaseLoader): - """Load from `Embaas`. - - To use, you should have the - environment variable ``EMBAAS_API_KEY`` set with your API key, or pass - it as a named parameter to the constructor. - - Example: - .. code-block:: python - - # Default parsing - from langchain_community.document_loaders.embaas import EmbaasLoader - loader = EmbaasLoader(file_path="example.mp3") - documents = loader.load() - - # Custom api parameters (create embeddings automatically) - from langchain_community.document_loaders.embaas import EmbaasBlobLoader - loader = EmbaasBlobLoader( - file_path="example.pdf", - params={ - "should_embed": True, - "model": "e5-large-v2", - "chunk_size": 256, - "chunk_splitter": "CharacterTextSplitter" - } - ) - documents = loader.load() - """ - - file_path: str - """The path to the file to load.""" - blob_loader: Optional[EmbaasBlobLoader] - """The blob loader to use. 
If not provided, a default one will be created.""" - - @validator("blob_loader", always=True) - def validate_blob_loader( - cls, v: EmbaasBlobLoader, values: Dict - ) -> EmbaasBlobLoader: - return v or EmbaasBlobLoader( - embaas_api_key=values["embaas_api_key"], - api_url=values["api_url"], - params=values["params"], - ) - - def lazy_load(self) -> Iterator[Document]: - """Load the documents from the file path lazily.""" - blob = Blob.from_path(path=self.file_path) - - assert self.blob_loader is not None - # Should never be None, but mypy doesn't know that. - yield from self.blob_loader.lazy_parse(blob=blob) - - def load(self) -> List[Document]: - return list(self.lazy_load()) - - def load_and_split( - self, text_splitter: Optional[TextSplitter] = None - ) -> List[Document]: - if self.params.get("should_embed", False): - warnings.warn( - "Embeddings are not supported with load_and_split." - " Use the API splitter to properly generate embeddings." - " For more information see embaas.io docs." - ) - return super().load_and_split(text_splitter=text_splitter) +from langchain_community.document_loaders.embaas import ( + EMBAAS_DOC_API_URL, + BaseEmbaasLoader, + EmbaasBlobLoader, + EmbaasDocumentExtractionParameters, + EmbaasDocumentExtractionPayload, + EmbaasLoader, +) + +__all__ = [ + "EMBAAS_DOC_API_URL", + "EmbaasDocumentExtractionParameters", + "EmbaasDocumentExtractionPayload", + "BaseEmbaasLoader", + "EmbaasBlobLoader", + "EmbaasLoader", +] diff --git a/libs/community/langchain_community/document_loaders/geodataframe.py b/libs/community/langchain_community/document_loaders/geodataframe.py index 3782b7bbc2a13..09a4c5ae9f97e 100644 --- a/libs/community/langchain_community/document_loaders/geodataframe.py +++ b/libs/community/langchain_community/document_loaders/geodataframe.py @@ -35,7 +35,7 @@ def __init__(self, data_frame: Any, page_content_column: str = "geometry"): f"Expected data_frame to have a column named {page_content_column}" ) - if not isinstance(data_frame[page_content_column].iloc[0], gpd.GeoSeries): + if not isinstance(data_frame[page_content_column], gpd.GeoSeries): raise ValueError( f"Expected data_frame[{page_content_column}] to be a GeoSeries" ) diff --git a/libs/community/langchain_community/document_loaders/obsidian.py b/libs/community/langchain_community/document_loaders/obsidian.py index 40f6747069fce..e4b69341b679d 100644 --- a/libs/community/langchain_community/document_loaders/obsidian.py +++ b/libs/community/langchain_community/document_loaders/obsidian.py @@ -1,7 +1,8 @@ +import functools import logging import re from pathlib import Path -from typing import List +from typing import Any, Dict, List import yaml from langchain_core.documents import Document @@ -15,6 +16,7 @@ class ObsidianLoader(BaseLoader): """Load `Obsidian` files from directory.""" FRONT_MATTER_REGEX = re.compile(r"^---\n(.*?)\n---\n", re.DOTALL) + TEMPLATE_VARIABLE_REGEX = re.compile(r"{{(.*?)}}", re.DOTALL) TAG_REGEX = re.compile(r"[^\S\/]#([a-zA-Z_]+[-_/\w]*)") DATAVIEW_LINE_REGEX = re.compile(r"^\s*(\w+)::\s*(.*)$", re.MULTILINE) DATAVIEW_INLINE_BRACKET_REGEX = re.compile(r"\[(\w+)::\s*(.*)\]", re.MULTILINE) @@ -35,6 +37,27 @@ def __init__( self.encoding = encoding self.collect_metadata = collect_metadata + def _replace_template_var( + self, placeholders: Dict[str, str], match: re.Match + ) -> str: + """Replace a template variable with a placeholder.""" + placeholder = f"__TEMPLATE_VAR_{len(placeholders)}__" + placeholders[placeholder] = match.group(1) + return placeholder + + def 
_restore_template_vars(self, obj: Any, placeholders: Dict[str, str]) -> Any: + """Restore template variables replaced with placeholders to original values.""" + if isinstance(obj, str): + for placeholder, value in placeholders.items(): + obj = obj.replace(placeholder, f"{{{{{value}}}}}") + elif isinstance(obj, dict): + for key, value in obj.items(): + obj[key] = self._restore_template_vars(value, placeholders) + elif isinstance(obj, list): + for i, item in enumerate(obj): + obj[i] = self._restore_template_vars(item, placeholders) + return obj + def _parse_front_matter(self, content: str) -> dict: """Parse front matter metadata from the content and return it as a dict.""" if not self.collect_metadata: @@ -44,8 +67,17 @@ def _parse_front_matter(self, content: str) -> dict: if not match: return {} + placeholders: Dict[str, str] = {} + replace_template_var = functools.partial( + self._replace_template_var, placeholders + ) + front_matter_text = self.TEMPLATE_VARIABLE_REGEX.sub( + replace_template_var, match.group(1) + ) + try: - front_matter = yaml.safe_load(match.group(1)) + front_matter = yaml.safe_load(front_matter_text) + front_matter = self._restore_template_vars(front_matter, placeholders) # If tags are a string, split them into a list if "tags" in front_matter and isinstance(front_matter["tags"], str): diff --git a/libs/community/langchain_community/document_loaders/parsers/audio.py b/libs/community/langchain_community/document_loaders/parsers/audio.py index 266d204cce6c3..77ec5688a6539 100644 --- a/libs/community/langchain_community/document_loaders/parsers/audio.py +++ b/libs/community/langchain_community/document_loaders/parsers/audio.py @@ -3,6 +3,7 @@ from typing import Dict, Iterator, Optional, Tuple from langchain_core.documents import Document +from langchain_openai.utils import is_openai_v1 from langchain_community.document_loaders.base import BaseBlobParser from langchain_community.document_loaders.blob_loaders import Blob @@ -36,9 +37,13 @@ def lazy_parse(self, blob: Blob) -> Iterator[Document]: "pydub package not found, please install it with " "`pip install pydub`" ) - # Set the API key if provided - if self.api_key: - openai.api_key = self.api_key + if is_openai_v1(): + # api_key optional, defaults to `os.environ['OPENAI_API_KEY']` + client = openai.OpenAI(api_key=self.api_key) + else: + # Set the API key if provided + if self.api_key: + openai.api_key = self.api_key # Audio file from disk audio = AudioSegment.from_file(blob.path) @@ -63,7 +68,12 @@ def lazy_parse(self, blob: Blob) -> Iterator[Document]: attempts = 0 while attempts < 3: try: - transcript = openai.Audio.transcribe("whisper-1", file_obj) + if is_openai_v1(): + transcript = client.audio.transcriptions.create( + model="whisper-1", file=file_obj + ) + else: + transcript = openai.Audio.transcribe("whisper-1", file_obj) break except Exception as e: attempts += 1 diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py index fdb39f4a10439..b5ceee27a59a5 100644 --- a/libs/community/langchain_community/embeddings/__init__.py +++ b/libs/community/langchain_community/embeddings/__init__.py @@ -26,6 +26,7 @@ QianfanEmbeddingsEndpoint, ) from langchain_community.embeddings.bedrock import BedrockEmbeddings +from langchain_community.embeddings.bookend import BookendEmbeddings from langchain_community.embeddings.cache import CacheBackedEmbeddings from langchain_community.embeddings.clarifai import ClarifaiEmbeddings from 
langchain_community.embeddings.cohere import CohereEmbeddings @@ -137,6 +138,7 @@ "QianfanEmbeddingsEndpoint", "JohnSnowLabsEmbeddings", "VoyageEmbeddings", + "BookendEmbeddings", ] diff --git a/libs/community/langchain_community/embeddings/bookend.py b/libs/community/langchain_community/embeddings/bookend.py new file mode 100644 index 0000000000000..fdfcea3eb10c5 --- /dev/null +++ b/libs/community/langchain_community/embeddings/bookend.py @@ -0,0 +1,91 @@ +"""Wrapper around Bookend AI embedding models.""" + +import json +from typing import Any, List + +import requests +from langchain_core.pydantic_v1 import BaseModel, Field + +from langchain_community.schema.embeddings import Embeddings + +API_URL = "https://api.bookend.ai/" +DEFAULT_TASK = "embeddings" +PATH = "/models/predict" + + +class BookendEmbeddings(BaseModel, Embeddings): + """Bookend AI sentence_transformers embedding models. + + Example: + .. code-block:: python + + from langchain_community.embeddings import BookendEmbeddings + + bookend = BookendEmbeddings( + domain={domain} + api_token={api_token} + model_id={model_id} + ) + bookend.embed_documents([ + "Please put on these earmuffs because I can't you hear.", + "Baby wipes are made of chocolate stardust.", + ]) + bookend.embed_query( + "She only paints with bold colors; she does not like pastels." + ) + """ + + domain: str + """Request for a domain at https://bookend.ai/ to use this embeddings module.""" + api_token: str + """Request for an API token at https://bookend.ai/ to use this embeddings module.""" + model_id: str + """Embeddings model ID to use.""" + auth_header: dict = Field(default_factory=dict) + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.auth_header = {"Authorization": "Basic {}".format(self.api_token)} + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed documents using a Bookend deployed embeddings model. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + result = [] + headers = self.auth_header + headers["Content-Type"] = "application/json; charset=utf-8" + params = { + "model_id": self.model_id, + "task": DEFAULT_TASK, + } + + for text in texts: + data = json.dumps( + {"text": text, "question": None, "context": None, "instruction": None} + ) + r = requests.request( + "POST", + API_URL + self.domain + PATH, + headers=headers, + params=params, + data=data, + ) + result.append(r.json()[0]["data"]) + + return result + + def embed_query(self, text: str) -> List[float]: + """Embed a query using a Bookend deployed embeddings model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + return self.embed_documents([text])[0] diff --git a/libs/community/langchain_community/embeddings/clarifai.py b/libs/community/langchain_community/embeddings/clarifai.py index ef19690296030..e52fdf5dc2109 100644 --- a/libs/community/langchain_community/embeddings/clarifai.py +++ b/libs/community/langchain_community/embeddings/clarifai.py @@ -63,8 +63,8 @@ def validate_environment(cls, values: Dict) -> Dict: raise ValueError("Please provide a model_id.") try: - from clarifai.auth.helper import ClarifaiAuthHelper from clarifai.client import create_stub + from clarifai.client.auth.helper import ClarifaiAuthHelper except ImportError: raise ImportError( "Could not import clarifai python package. 
" diff --git a/libs/community/langchain_community/embeddings/cloudflare_workersai.py b/libs/community/langchain_community/embeddings/cloudflare_workersai.py new file mode 100644 index 0000000000000..81f6f83d4a52a --- /dev/null +++ b/libs/community/langchain_community/embeddings/cloudflare_workersai.py @@ -0,0 +1,94 @@ +from typing import Any, Dict, List + +import requests +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Extra + +DEFAULT_MODEL_NAME = "@cf/baai/bge-base-en-v1.5" + + +class CloudflareWorkersAIEmbeddings(BaseModel, Embeddings): + """Cloudflare Workers AI embedding model. + + To use, you need to provide an API token and + account ID to access Cloudflare Workers AI. + + Example: + .. code-block:: python + + from langchain_community.embeddings import CloudflareWorkersAIEmbeddings + + account_id = "my_account_id" + api_token = "my_secret_api_token" + model_name = "@cf/baai/bge-small-en-v1.5" + + cf = CloudflareWorkersAIEmbeddings( + account_id=account_id, + api_token=api_token, + model_name=model_name + ) + """ + + api_base_url: str = "https://api.cloudflare.com/client/v4/accounts" + account_id: str + api_token: str + model_name: str = DEFAULT_MODEL_NAME + batch_size: int = 50 + strip_new_lines: bool = True + headers: Dict[str, str] = {"Authorization": "Bearer "} + + def __init__(self, **kwargs: Any): + """Initialize the Cloudflare Workers AI client.""" + super().__init__(**kwargs) + + self.headers = {"Authorization": f"Bearer {self.api_token}"} + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using Cloudflare Workers AI. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + if self.strip_new_lines: + texts = [text.replace("\n", " ") for text in texts] + + batches = [ + texts[i : i + self.batch_size] + for i in range(0, len(texts), self.batch_size) + ] + embeddings = [] + + for batch in batches: + response = requests.post( + f"{self.api_base_url}/{self.account_id}/ai/run/{self.model_name}", + headers=self.headers, + json={"text": batch}, + ) + embeddings.extend(response.json()["result"]["data"]) + + return embeddings + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using Cloudflare Workers AI. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + text = text.replace("\n", " ") if self.strip_new_lines else text + response = requests.post( + f"{self.api_base_url}/{self.account_id}/ai/run/{self.model_name}", + headers=self.headers, + json={"text": [text]}, + ) + return response.json()["result"]["data"][0] diff --git a/libs/community/langchain_community/embeddings/cohere.py b/libs/community/langchain_community/embeddings/cohere.py index ecb20a9013700..11617be5e3d10 100644 --- a/libs/community/langchain_community/embeddings/cohere.py +++ b/libs/community/langchain_community/embeddings/cohere.py @@ -1,9 +1,10 @@ from typing import Any, Dict, List, Optional -from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain_core.utils import get_from_dict_or_env +from langchain_community.schema.embeddings import Embeddings + class CohereEmbeddings(BaseModel, Embeddings): """Cohere embedding models. 
@@ -17,7 +18,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
         from langchain_community.embeddings import CohereEmbeddings
 
         cohere = CohereEmbeddings(
-            model="embed-english-light-v3.0", cohere_api_key="my-api-key"
+            model="embed-english-light-v3.0",
+            cohere_api_key="my-api-key"
         )
     """
 
@@ -77,8 +79,30 @@ def validate_environment(cls, values: Dict) -> Dict:
         )
         return values
 
+    def embed(
+        self, texts: List[str], *, input_type: Optional[str] = None
+    ) -> List[List[float]]:
+        embeddings = self.client.embed(
+            model=self.model,
+            texts=texts,
+            input_type=input_type,
+            truncate=self.truncate,
+        ).embeddings
+        return [list(map(float, e)) for e in embeddings]
+
+    async def aembed(
+        self, texts: List[str], *, input_type: Optional[str] = None
+    ) -> List[List[float]]:
+        embeddings = (
+            await self.async_client.embed(
+                model=self.model,
+                texts=texts,
+                input_type=input_type,
+                truncate=self.truncate,
+            )
+        ).embeddings
+        return [list(map(float, e)) for e in embeddings]
+
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        """Call out to Cohere's embedding endpoint.
+        """Embed a list of document texts.
 
         Args:
             texts: The list of texts to embed.
@@ -86,13 +110,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = self.client.embed(
-            model=self.model,
-            texts=texts,
-            input_type="search_document",
-            truncate=self.truncate,
-        ).embeddings
-        return [list(map(float, e)) for e in embeddings]
+        return self.embed(texts, input_type="search_document")
 
     async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         """Async call out to Cohere's embedding endpoint.
@@ -103,13 +121,7 @@ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = await self.async_client.embed(
-            model=self.model,
-            texts=texts,
-            input_type="search_document",
-            truncate=self.truncate,
-        )
-        return [list(map(float, e)) for e in embeddings.embeddings]
+        return await self.aembed(texts, input_type="search_document")
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to Cohere's embedding endpoint.
@@ -120,13 +132,7 @@ def embed_query(self, text: str) -> List[float]:
         Returns:
             Embeddings for the text.
         """
-        embeddings = self.client.embed(
-            model=self.model,
-            texts=[text],
-            input_type="search_query",
-            truncate=self.truncate,
-        ).embeddings
-        return [list(map(float, e)) for e in embeddings][0]
+        return self.embed([text], input_type="search_query")[0]
 
     async def aembed_query(self, text: str) -> List[float]:
         """Async call out to Cohere's embedding endpoint.
@@ -137,10 +143,4 @@ async def aembed_query(self, text: str) -> List[float]:
         Args:
             text: The text to embed.
 
         Returns:
             Embeddings for the text.
         """
-        embeddings = await self.async_client.embed(
-            model=self.model,
-            texts=[text],
-            input_type="search_query",
-            truncate=self.truncate,
-        )
-        return [list(map(float, e)) for e in embeddings.embeddings][0]
+        return (await self.aembed([text], input_type="search_query"))[0]
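With this refactor, all four public methods funnel through a single `embed`/`aembed` pair that forwards Cohere's `input_type`. A short usage sketch (the model name comes from the class docstring above; the API key is a placeholder, and each embedding call hits the Cohere API):

```python
from langchain_community.embeddings import CohereEmbeddings

cohere = CohereEmbeddings(
    model="embed-english-light-v3.0",
    cohere_api_key="my-api-key",
)

# Documents and queries are embedded with the input_type Cohere expects.
doc_vectors = cohere.embed_documents(["hello world"])  # input_type="search_document"
query_vector = cohere.embed_query("greeting")  # input_type="search_query"
```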
""" - embeddings = await self.async_client.embed( - model=self.model, - texts=[text], - input_type="search_query", - truncate=self.truncate, - ) - return [list(map(float, e)) for e in embeddings.embeddings][0] + return (await self.aembed([text], input_type="search_query"))[0] diff --git a/libs/community/langchain_community/embeddings/embaas.py b/libs/community/langchain_community/embeddings/embaas.py index 02153a8173843..e9e5d5f3ace46 100644 --- a/libs/community/langchain_community/embeddings/embaas.py +++ b/libs/community/langchain_community/embeddings/embaas.py @@ -1,11 +1,13 @@ from typing import Any, Dict, List, Mapping, Optional import requests -from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain_core.utils import get_from_dict_or_env +from requests.adapters import HTTPAdapter, Retry from typing_extensions import NotRequired, TypedDict +from langchain_community.schema.embeddings import Embeddings + # Currently supported maximum batch size for embedding requests MAX_BATCH_SIZE = 256 EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/" @@ -50,6 +52,10 @@ class EmbaasEmbeddings(BaseModel, Embeddings): api_url: str = EMBAAS_API_URL """The URL for the embaas embeddings API.""" embaas_api_key: Optional[str] = None + """max number of retries for requests""" + max_retries: Optional[int] = 3 + """request timeout in seconds""" + timeout: Optional[int] = 30 class Config: """Configuration for this pydantic object.""" @@ -84,8 +90,22 @@ def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]] "Content-Type": "application/json", } - response = requests.post(self.api_url, headers=headers, json=payload) - response.raise_for_status() + session = requests.Session() + retries = Retry( + total=self.max_retries, + backoff_factor=0.5, + allowed_methods=["POST"], + raise_on_status=True, + ) + + session.mount("http://", HTTPAdapter(max_retries=retries)) + session.mount("https://", HTTPAdapter(max_retries=retries)) + response = session.post( + self.api_url, + headers=headers, + json=payload, + timeout=self.timeout, + ) parsed_response = response.json() embeddings = [item["embedding"] for item in parsed_response["data"]] diff --git a/libs/community/langchain_community/embeddings/huggingface.py b/libs/community/langchain_community/embeddings/huggingface.py index 6102092f0a968..6ca23b6bed5db 100644 --- a/libs/community/langchain_community/embeddings/huggingface.py +++ b/libs/community/langchain_community/embeddings/huggingface.py @@ -279,9 +279,15 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings): """Your API key for the HuggingFace Inference API.""" model_name: str = "sentence-transformers/all-MiniLM-L6-v2" """The name of the model to use for text embeddings.""" + api_url: Optional[str] = None + """Custom inference endpoint url. 
None for using default public url.""" @property def _api_url(self) -> str: + return self.api_url or self._default_api_url + + @property + def _default_api_url(self) -> str: return ( "https://api-inference.huggingface.co" "/pipeline" diff --git a/libs/community/langchain_community/embeddings/jina.py b/libs/community/langchain_community/embeddings/jina.py index da84d7d47865f..783615e59432c 100644 --- a/libs/community/langchain_community/embeddings/jina.py +++ b/libs/community/langchain_community/embeddings/jina.py @@ -1,4 +1,3 @@ -import os from typing import Any, Dict, List, Optional import requests @@ -6,69 +5,54 @@ from langchain_core.pydantic_v1 import BaseModel, root_validator from langchain_core.utils import get_from_dict_or_env +JINA_API_URL: str = "https://api.jina.ai/v1/embeddings" + class JinaEmbeddings(BaseModel, Embeddings): """Jina embedding models.""" - client: Any #: :meta private: - - model_name: str = "ViT-B-32::openai" - """Model name to use.""" - - jina_auth_token: Optional[str] = None - jina_api_url: str = "https://api.clip.jina.ai/api/v1/models/" - request_headers: Optional[dict] = None + session: Any #: :meta private: + model_name: str = "jina-embeddings-v2-base-en" + jina_api_key: Optional[str] = None @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that auth token exists in environment.""" - # Set Auth - jina_auth_token = get_from_dict_or_env( - values, "jina_auth_token", "JINA_AUTH_TOKEN" - ) - values["jina_auth_token"] = jina_auth_token - values["request_headers"] = (("authorization", jina_auth_token),) - - # Test that package is installed try: - import jina - except ImportError: - raise ImportError( - "Could not import `jina` python package. " - "Please install it with `pip install jina`." - ) + jina_api_key = get_from_dict_or_env(values, "jina_api_key", "JINA_API_KEY") + except ValueError as original_exc: + try: + jina_api_key = get_from_dict_or_env( + values, "jina_auth_token", "JINA_AUTH_TOKEN" + ) + except ValueError: + raise original_exc + session = requests.Session() + session.headers.update( + { + "Authorization": f"Bearer {jina_api_key}", + "Accept-Encoding": "identity", + "Content-type": "application/json", + } + ) + values["session"] = session + return values - # Setup client - jina_api_url = os.environ.get("JINA_API_URL", values["jina_api_url"]) - model_name = values["model_name"] - try: - resp = requests.get( - jina_api_url + f"?model_name={model_name}", - headers={"Authorization": jina_auth_token}, - ) + def _embed(self, texts: List[str]) -> List[List[float]]: + # Call Jina AI Embedding API + resp = self.session.post( # type: ignore + JINA_API_URL, json={"input": texts, "model": self.model_name} + ).json() + if "data" not in resp: + raise RuntimeError(resp["detail"]) - if resp.status_code == 401: - raise ValueError( - "The given Jina auth token is invalid. " - "Please check your Jina auth token." - ) - elif resp.status_code == 404: - raise ValueError( - f"The given model name `{model_name}` is not valid. " - f"Please go to https://cloud.jina.ai/user/inference " - f"and create a model with the given model name." 
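The Jina rewrite above (its `_embed` helper concludes just after this hunk) drops the grpc `jina`/`docarray` client in favor of one authenticated POST to the public embeddings endpoint. The equivalent raw call, with a placeholder API key:

```python
import requests

JINA_API_URL = "https://api.jina.ai/v1/embeddings"

session = requests.Session()
session.headers.update(
    {
        "Authorization": "Bearer my-jina-api-key",  # placeholder key
        "Accept-Encoding": "identity",
        "Content-type": "application/json",
    }
)

resp = session.post(
    JINA_API_URL,
    json={"input": ["hello", "world"], "model": "jina-embeddings-v2-base-en"},
).json()

# Results may arrive out of order, so sort by the index the API supplies.
sorted_items = sorted(resp["data"], key=lambda e: e["index"])
vectors = [item["embedding"] for item in sorted_items]
```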
- ) - resp.raise_for_status() + embeddings = resp["data"] - endpoint = resp.json()["endpoints"]["grpc"] - values["client"] = jina.Client(host=endpoint) - except requests.exceptions.HTTPError as err: - raise ValueError(f"Error: {err!r}") - return values + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore - def _post(self, docs: List[Any], **kwargs: Any) -> Any: - payload = dict(inputs=docs, metadata=self.request_headers, **kwargs) - return self.client.post(on="/encode", **payload) + # Return just the embeddings + return [result["embedding"] for result in sorted_embeddings] def embed_documents(self, texts: List[str]) -> List[List[float]]: """Call out to Jina's embedding endpoint. @@ -77,12 +61,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]: Returns: List of embeddings, one for each text. """ - from docarray import Document, DocumentArray - - embeddings = self._post( - docs=DocumentArray([Document(text=t) for t in texts]) - ).embeddings - return [list(map(float, e)) for e in embeddings] + return self._embed(texts) def embed_query(self, text: str) -> List[float]: """Call out to Jina's embedding endpoint. @@ -91,7 +70,4 @@ def embed_query(self, text: str) -> List[float]: Returns: Embeddings for the text. """ - from docarray import Document, DocumentArray - - embedding = self._post(docs=DocumentArray([Document(text=text)])).embeddings[0] - return list(map(float, embedding)) + return self._embed([text])[0] diff --git a/libs/community/langchain_community/llms/arcee.py b/libs/community/langchain_community/llms/arcee.py index 6e956a15d5233..1f8538c53d3a2 100644 --- a/libs/community/langchain_community/llms/arcee.py +++ b/libs/community/langchain_community/llms/arcee.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, Union, cast from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.pydantic_v1 import Extra, SecretStr, root_validator @@ -30,7 +30,7 @@ class Arcee(LLM): _client: Optional[ArceeWrapper] = None #: :meta private: """Arcee _client.""" - arcee_api_key: Optional[SecretStr] = None + arcee_api_key: Union[SecretStr, str, None] = None """Arcee API Key""" model: str @@ -66,15 +66,16 @@ def __init__(self, **data: Any) -> None: """Initializes private fields.""" super().__init__(**data) + api_key = cast(SecretStr, self.arcee_api_key) self._client = ArceeWrapper( - arcee_api_key=cast(SecretStr, self.arcee_api_key), + arcee_api_key=api_key, arcee_api_url=self.arcee_api_url, arcee_api_version=self.arcee_api_version, model_kwargs=self.model_kwargs, model_name=self.model, ) - @root_validator() + @root_validator(pre=False) def validate_environments(cls, values: Dict) -> Dict: """Validate Arcee environment variables.""" @@ -106,7 +107,7 @@ def validate_environments(cls, values: Dict) -> Dict: ) # validate model kwargs - if values["model_kwargs"]: + if values.get("model_kwargs"): kw = values["model_kwargs"] # validate size @@ -120,7 +121,6 @@ def validate_environments(cls, values: Dict) -> Dict: raise ValueError("`filters` must be a list") for f in kw.get("filters"): DALMFilter(**f) - return values def _call( diff --git a/libs/community/langchain_community/llms/clarifai.py b/libs/community/langchain_community/llms/clarifai.py index e7fe74af68f5f..c8053ad972f44 100644 --- a/libs/community/langchain_community/llms/clarifai.py +++ b/libs/community/langchain_community/llms/clarifai.py @@ -71,8 +71,8 @@ def 
validate_environment(cls, values: Dict) -> Dict: raise ValueError("Please provide a model_id.") try: - from clarifai.auth.helper import ClarifaiAuthHelper from clarifai.client import create_stub + from clarifai.client.auth.helper import ClarifaiAuthHelper except ImportError: raise ImportError( "Could not import clarifai python package. " diff --git a/libs/community/langchain_community/storage/upstash_redis.py b/libs/community/langchain_community/storage/upstash_redis.py index 193b8c7c4b5c3..7dc436ce33e2c 100644 --- a/libs/community/langchain_community/storage/upstash_redis.py +++ b/libs/community/langchain_community/storage/upstash_redis.py @@ -1,9 +1,10 @@ from typing import Any, Iterator, List, Optional, Sequence, Tuple, cast +from langchain_core._api.deprecation import deprecated from langchain_core.stores import BaseStore -class UpstashRedisStore(BaseStore[str, str]): +class _UpstashRedisStore(BaseStore[str, str]): """BaseStore implementation using Upstash Redis as the underlying store.""" def __init__( @@ -117,3 +118,57 @@ def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: yield relative_key else: yield key + + +@deprecated("0.0.335", alternative="UpstashRedisByteStore") +class UpstashRedisStore(_UpstashRedisStore): + """ + BaseStore implementation using Upstash Redis + as the underlying store to store strings. + + Deprecated in favor of the more generic UpstashRedisByteStore. + """ + + +class UpstashRedisByteStore(BaseStore[str, bytes]): + """ + BaseStore implementation using Upstash Redis + as the underlying store to store raw bytes. + """ + + def __init__( + self, + *, + client: Any = None, + url: Optional[str] = None, + token: Optional[str] = None, + ttl: Optional[int] = None, + namespace: Optional[str] = None, + ) -> None: + self.underlying_store = _UpstashRedisStore( + client=client, url=url, token=token, ttl=ttl, namespace=namespace + ) + + def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: + """Get the values associated with the given keys.""" + return [ + value.encode("utf-8") if value is not None else None + for value in self.underlying_store.mget(keys) + ] + + def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + """Set the given key-value pairs.""" + self.underlying_store.mset( + [ + (k, v.decode("utf-8")) if v is not None else None + for k, v in key_value_pairs + ] + ) + + def mdelete(self, keys: Sequence[str]) -> None: + """Delete the given keys.""" + self.underlying_store.mdelete(keys) + + def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]: + """Yield keys in the store.""" + yield from self.underlying_store.yield_keys(prefix=prefix) diff --git a/libs/community/langchain_community/tools/__init__.py b/libs/community/langchain_community/tools/__init__.py index 386ef3f48701f..a92353c477b2e 100644 --- a/libs/community/langchain_community/tools/__init__.py +++ b/libs/community/langchain_community/tools/__init__.py @@ -296,6 +296,18 @@ def _import_google_serper_tool_GoogleSerperRun() -> Any: return GoogleSerperRun +def _import_searchapi_tool_SearchAPIResults() -> Any: + from langchain_community.tools.searchapi.tool import SearchAPIResults + + return SearchAPIResults + + +def _import_searchapi_tool_SearchAPIRun() -> Any: + from langchain_community.tools.searchapi.tool import SearchAPIRun + + return SearchAPIRun + + def _import_graphql_tool() -> Any: from langchain_community.tools.graphql.tool import BaseGraphQLTool @@ -350,6 +362,12 @@ def _import_metaphor_search() -> Any: return MetaphorSearchResults +def 
_import_nasa_tool() -> Any: + from langchain_community.tools.nasa.tool import NasaAction + + return NasaAction + + def _import_office365_create_draft_message() -> Any: from langchain_community.tools.office365.create_draft_message import ( O365CreateDraftMessage, @@ -548,6 +566,12 @@ def _import_requests_tool_RequestsPutTool() -> Any: return RequestsPutTool +def _import_steam_webapi_tool() -> Any: + from langchain_community.tools.steam.tool import SteamWebAPIQueryRun + + return SteamWebAPIQueryRun + + def _import_scenexplain_tool() -> Any: from langchain_community.tools.scenexplain.tool import SceneXplainTool @@ -823,6 +847,10 @@ def __getattr__(name: str) -> Any: return _import_google_serper_tool_GoogleSerperResults() elif name == "GoogleSerperRun": return _import_google_serper_tool_GoogleSerperRun() + elif name == "SearchAPIResults": + return _import_searchapi_tool_SearchAPIResults() + elif name == "SearchAPIRun": + return _import_searchapi_tool_SearchAPIRun() elif name == "BaseGraphQLTool": return _import_graphql_tool() elif name == "HumanInputRun": @@ -841,6 +869,8 @@ def __getattr__(name: str) -> Any: return _import_merriam_webster_tool() elif name == "MetaphorSearchResults": return _import_metaphor_search() + elif name == "NasaAction": + return _import_nasa_tool() elif name == "O365CreateDraftMessage": return _import_office365_create_draft_message() elif name == "O365SearchEvents": @@ -903,6 +933,8 @@ def __getattr__(name: str) -> Any: return _import_requests_tool_RequestsPostTool() elif name == "RequestsPutTool": return _import_requests_tool_RequestsPutTool() + elif name == "SteamWebAPIQueryRun": + return _import_steam_webapi_tool() elif name == "SceneXplainTool": return _import_scenexplain_tool() elif name == "SearxSearchResults": @@ -1023,6 +1055,8 @@ def __getattr__(name: str) -> Any: "GoogleSearchRun", "GoogleSerperResults", "GoogleSerperRun", + "SearchAPIResults", + "SearchAPIRun", "HumanInputRun", "IFTTTWebhook", "InfoPowerBITool", @@ -1038,6 +1072,7 @@ def __getattr__(name: str) -> Any: "MerriamWebsterQueryRun", "MetaphorSearchResults", "MoveFileTool", + "NasaAction", "NavigateBackTool", "NavigateTool", "O365CreateDraftMessage", @@ -1060,6 +1095,7 @@ def __getattr__(name: str) -> Any: "RequestsPatchTool", "RequestsPostTool", "RequestsPutTool", + "SteamWebAPIQueryRun", "SceneXplainTool", "SearxSearchResults", "SearxSearchRun", diff --git a/libs/community/langchain_community/tools/ddg_search/tool.py b/libs/community/langchain_community/tools/ddg_search/tool.py index 786f828fdb653..2af4b6e5edf0d 100644 --- a/libs/community/langchain_community/tools/ddg_search/tool.py +++ b/libs/community/langchain_community/tools/ddg_search/tool.py @@ -46,11 +46,11 @@ class DuckDuckGoSearchResults(BaseTool): "Useful for when you need to answer questions about current events. " "Input should be a search query. 
Output is a JSON array of the query results" ) - num_results: int = 4 + max_results: int = Field(alias="num_results", default=4) api_wrapper: DuckDuckGoSearchAPIWrapper = Field( default_factory=DuckDuckGoSearchAPIWrapper ) - backend: str = "api" + backend: str = "text" args_schema: Type[BaseModel] = DDGInput def _run( @@ -59,7 +59,7 @@ def _run( run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Use the tool.""" - res = self.api_wrapper.results(query, self.num_results, backend=self.backend) + res = self.api_wrapper.results(query, self.max_results, source=self.backend) res_strs = [", ".join([f"{k}: {v}" for k, v in d.items()]) for d in res] return ", ".join([f"[{rs}]" for rs in res_strs]) diff --git a/libs/community/langchain_community/tools/github/prompt.py b/libs/community/langchain_community/tools/github/prompt.py index e0e72a808b802..3d66713e02b7e 100644 --- a/libs/community/langchain_community/tools/github/prompt.py +++ b/libs/community/langchain_community/tools/github/prompt.py @@ -1,19 +1,17 @@ # flake8: noqa GET_ISSUES_PROMPT = """ -This tool will fetch a list of the repository's issues. It will return the title, and issue number of 5 issues. It takes no input. -""" +This tool will fetch a list of the repository's issues. It will return the title, and issue number of 5 issues. It takes no input.""" GET_ISSUE_PROMPT = """ -This tool will fetch the title, body, and comment thread of a specific issue. **VERY IMPORTANT**: You must specify the issue number as an integer. -""" +This tool will fetch the title, body, and comment thread of a specific issue. **VERY IMPORTANT**: You must specify the issue number as an integer.""" COMMENT_ON_ISSUE_PROMPT = """ This tool is useful when you need to comment on a GitHub issue. Simply pass in the issue number and the comment you would like to make. Please use this sparingly as we don't want to clutter the comment threads. **VERY IMPORTANT**: Your input to this tool MUST strictly follow these rules: - First you must specify the issue number as an integer - Then you must place two newlines -- Then you must specify your comment -""" +- Then you must specify your comment""" + CREATE_PULL_REQUEST_PROMPT = """ This tool is useful when you need to create a new pull request in a GitHub repository. **VERY IMPORTANT**: Your input to this tool MUST strictly follow these rules: @@ -21,13 +19,13 @@ - Then you must place two newlines - Then you must write the body or description of the pull request -To reference an issue in the body, put its issue number directly after a #. -For example, if you would like to create a pull request called "README updates" with contents "added contributors' names, closes issue #3", you would pass in the following string: +When appropriate, always reference relevant issues in the body by using the syntax `closes #<issue_number>`. [...] >>>> OLD NEW <<<< new contents ->>>> NEW -""" +>>>> NEW""" DELETE_FILE_PROMPT = """ -This tool is a wrapper for the GitHub API, useful when you need to delete a file in a GitHub repository. Simply pass in the full file path of the file you would like to delete. **IMPORTANT**: the path must not start with a slash -""" +This tool is a wrapper for the GitHub API, useful when you need to delete a file in a GitHub repository. Simply pass in the full file path of the file you would like to delete. **IMPORTANT**: the path must not start with a slash""" + +GET_PR_PROMPT = """ +This tool will fetch the title, body, comment thread and commit history of a specific Pull Request (by PR number). 
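`DuckDuckGoSearchResults` renames `num_results` to `max_results` while keeping the old keyword working through a pydantic v1 alias. A minimal sketch of that back-compat pattern; the `Demo` model is illustrative only:

```python
from langchain_core.pydantic_v1 import BaseModel, Field

class Demo(BaseModel):
    # New canonical name; old callers can still pass num_results=...
    max_results: int = Field(alias="num_results", default=4)

    class Config:
        # Accept the field name in addition to the alias.
        allow_population_by_field_name = True

assert Demo(num_results=7).max_results == 7  # old keyword still accepted
assert Demo(max_results=9).max_results == 9  # new keyword works too
```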
**VERY IMPORTANT**: You must specify the PR number as an integer.""" + +LIST_PRS_PROMPT = """ +This tool will fetch a list of the repository's Pull Requests (PRs). It will return the title, and PR number of 5 PRs. It takes no input.""" + +LIST_PULL_REQUEST_FILES = """ +This tool will fetch the full text of all files in a pull request (PR) given the PR number as an input. This is useful for understanding the code changes in a PR or contributing to it. **VERY IMPORTANT**: You must specify the PR number as an integer input parameter.""" + +OVERVIEW_EXISTING_FILES_IN_MAIN = """ +This tool will provide an overview of all existing files in the main branch of the repository. It will list the file names, their respective paths, and a brief summary of their contents. This can be useful for understanding the structure and content of the repository, especially when navigating through large codebases. No input parameters are required.""" + +OVERVIEW_EXISTING_FILES_BOT_BRANCH = """ +This tool will provide an overview of all files in your current working branch where you should implement changes. This is great for getting a high level overview of the structure of your code. No input parameters are required.""" + +SEARCH_ISSUES_AND_PRS_PROMPT = """ +This tool will search for issues and pull requests in the repository. **VERY IMPORTANT**: You must specify the search query as a string input parameter.""" + +SEARCH_CODE_PROMPT = """ +This tool will search for code in the repository. **VERY IMPORTANT**: You must specify the search query as a string input parameter.""" + +CREATE_REVIEW_REQUEST_PROMPT = """ +This tool will create a review request on the open pull request that matches the current active branch. **VERY IMPORTANT**: You must specify the username of the person who is being requested as a string input parameter.""" + +LIST_BRANCHES_IN_REPO_PROMPT = """ +This tool will fetch a list of all branches in the repository. It will return the name of each branch. No input parameters are required.""" + +SET_ACTIVE_BRANCH_PROMPT = """ +This tool will set the active branch in the repository, similar to `git checkout <branch_name>` and `git switch -c <branch_name>`. **VERY IMPORTANT**: You must specify the name of the branch as a string input parameter.""" + +CREATE_BRANCH_PROMPT = """ +This tool will create a new branch in the repository. **VERY IMPORTANT**: You must specify the name of the new branch as a string input parameter.""" + +GET_FILES_FROM_DIRECTORY_PROMPT = """ +This tool will fetch a list of all files in a specified directory. 
**VERY IMPORTANT**: You must specify the path of the directory as a string input parameter.""" diff --git a/libs/community/langchain_community/tools/github/tool.py b/libs/community/langchain_community/tools/github/tool.py index 3d3643f9754e6..1ad9f101d932d 100644 --- a/libs/community/langchain_community/tools/github/tool.py +++ b/libs/community/langchain_community/tools/github/tool.py @@ -7,10 +7,10 @@ GITHUB_REPOSITORY -> format: {owner}/{repo} """ -from typing import Optional +from typing import Optional, Type from langchain_core.callbacks.manager import CallbackManagerForToolRun -from langchain_core.pydantic_v1 import Field +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.tools import BaseTool from langchain_community.utilities.github import GitHubAPIWrapper @@ -23,11 +23,15 @@ class GitHubAction(BaseTool): mode: str name: str = "" description: str = "" + args_schema: Optional[Type[BaseModel]] = None def _run( self, - instructions: str, + instructions: Optional[str] = "", run_manager: Optional[CallbackManagerForToolRun] = None, ) -> str: """Use the GitHub API to run an operation.""" + if not instructions or instructions == "{}": + # Catch other forms of empty input that GPT-4 likes to send. + instructions = "" return self.api_wrapper.run(self.mode, instructions) diff --git a/libs/community/langchain_community/tools/nasa/__init__.py b/libs/community/langchain_community/tools/nasa/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/community/langchain_community/tools/nasa/prompt.py b/libs/community/langchain_community/tools/nasa/prompt.py new file mode 100644 index 0000000000000..4c7a3846a7e83 --- /dev/null +++ b/libs/community/langchain_community/tools/nasa/prompt.py @@ -0,0 +1,82 @@ +# flake8: noqa +NASA_SEARCH_PROMPT = """ + This tool is a wrapper around NASA's search API, useful when you need to search through NASA's Image and Video Library. + The input to this tool is a query specified by the user, and will be passed into NASA's `search` function. + + At least one parameter must be provided. + + There are optional parameters that can be passed by the user based on their query + specifications. Each item in this list contains pound sign (#) separated values, the first value is the parameter name, + the second value is the datatype and the third value is the description: {{ + + - q#string#Free text search terms to compare to all indexed metadata. + - center#string#NASA center which published the media. + - description#string#Terms to search for in “Description” fields. + - description_508#string#Terms to search for in “508 Description” fields. + - keywords #string#Terms to search for in “Keywords” fields. Separate multiple values with commas. + - location #string#Terms to search for in “Location” fields. + - media_type#string#Media types to restrict the search to. Available types: [“image”,“video”, “audio”]. Separate multiple values with commas. + - nasa_id #string#The media asset’s NASA ID. + - page#integer#Page number, starting at 1, of results to get.- + - page_size#integer#Number of results per page. Default: 100. + - photographer#string#The primary photographer’s name. + - secondary_creator#string#A secondary photographer/videographer’s name. + - title #string#Terms to search for in “Title” fields. + - year_start#string#The start year for results. Format: YYYY. + - year_end #string#The end year for results. Format: YYYY. + + }} + + Below are several task descriptions along with their respective input examples. 
+ Task: get the 2nd page of image and video content starting from the year 2002 to 2010 + Example Input: {{"year_start": "2002", "year_end": "2010", "page": 2}} + + Task: get the image and video content of saturn photographed by John Appleseed + Example Input: {{"q": "saturn", "photographer": "John Appleseed"}} + + Task: search for Meteor Showers with description "Search Description" with media type image + Example Input: {{"q": "Meteor Shower", "description": "Search Description", "media_type": "image"}} + + Task: get the image and video content from year 2008 to 2010 from Kennedy Center + Example Input: {{"year_start": "2008", "year_end": "2010", "location": "Kennedy Center"}} + """ + + +NASA_MANIFEST_PROMPT = """ + This tool is a wrapper around NASA's media asset manifest API, useful when you need to retrieve a media + asset's manifest. The input to this tool should include a string representing a NASA ID for a media asset that the user is trying to get the media asset manifest data for. The NASA ID will be passed as a string into NASA's `get_media_metadata_manifest` function. + + The following list contains some examples of NASA IDs for a media asset that you can use to better extract the NASA ID from the input string to the tool. + - GSFC_20171102_Archive_e000579 + - Launch-Sound_Delta-PAM-Random-Commentary + - iss066m260341519_Expedition_66_Education_Inflight_with_Random_Lake_School_District_220203 + - 6973610 + - GRC-2020-CM-0167.4 + - Expedition_55_Inflight_Japan_VIP_Event_May_31_2018_659970 + - NASA 60th_SEAL_SLIVER_150DPI +""" + +NASA_METADATA_PROMPT = """ + This tool is a wrapper around NASA's media asset metadata location API, useful when you need to retrieve the media asset's metadata. The input to this tool should include a string representing a NASA ID for a media asset that the user is trying to get the media asset metadata location for. The NASA ID will be passed as a string into NASA's `get_media_metadata_manifest` function. + + The following list contains some examples of NASA IDs for a media asset that you can use to better extract the NASA ID from the input string to the tool. + - GSFC_20171102_Archive_e000579 + - Launch-Sound_Delta-PAM-Random-Commentary + - iss066m260341519_Expedition_66_Education_Inflight_with_Random_Lake_School_District_220203 + - 6973610 + - GRC-2020-CM-0167.4 + - Expedition_55_Inflight_Japan_VIP_Event_May_31_2018_659970 + - NASA 60th_SEAL_SLIVER_150DPI +""" + +NASA_CAPTIONS_PROMPT = """ + This tool is a wrapper around NASA's video asset captions location API, useful when you need + to retrieve the location of the captions of a specific video. The input to this tool should include a string representing a NASA ID for a video media asset that the user is trying to get the location of the captions for. The NASA ID will be passed as a string into NASA's `get_media_metadata_manifest` function. + + The following list contains some examples of NASA IDs for a video asset that you can use to better extract the NASA ID from the input string to the tool. 
+ - 2017-08-09 - Video File RS-25 Engine Test + - 20180415-TESS_Social_Briefing + - 201_TakingWildOutOfWildfire + - 2022-H1_V_EuropaClipper-4 + - 2022_0429_Recientemente +""" diff --git a/libs/community/langchain_community/tools/nasa/tool.py b/libs/community/langchain_community/tools/nasa/tool.py new file mode 100644 index 0000000000000..23ca207779f2d --- /dev/null +++ b/libs/community/langchain_community/tools/nasa/tool.py @@ -0,0 +1,29 @@ +""" +This tool allows agents to interact with the NASA API, specifically +the NASA Image & Video Library and Exoplanet APIs +""" + +from typing import Optional + +from langchain_core.callbacks.manager import CallbackManagerForToolRun +from langchain_core.pydantic_v1 import Field +from langchain_core.tools import BaseTool + +from langchain_community.utilities.nasa import NasaAPIWrapper + + +class NasaAction(BaseTool): + """Tool that queries the NASA API.""" + + api_wrapper: NasaAPIWrapper = Field(default_factory=NasaAPIWrapper) + mode: str + name: str = "" + description: str = "" + + def _run( + self, + instructions: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the NASA API to run an operation.""" + return self.api_wrapper.run(self.mode, instructions) diff --git a/libs/community/langchain_community/tools/playwright/utils.py b/libs/community/langchain_community/tools/playwright/utils.py index eb874f2eb4d65..692288fdde318 100644 --- a/libs/community/langchain_community/tools/playwright/utils.py +++ b/libs/community/langchain_community/tools/playwright/utils.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Any, Coroutine, TypeVar +from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, TypeVar if TYPE_CHECKING: from playwright.async_api import Browser as AsyncBrowser @@ -50,12 +50,15 @@ def get_current_page(browser: SyncBrowser) -> SyncPage: return context.pages[-1] -def create_async_playwright_browser(headless: bool = True) -> AsyncBrowser: +def create_async_playwright_browser( + headless: bool = True, args: Optional[List[str]] = None +) -> AsyncBrowser: """ Create an async playwright browser. Args: headless: Whether to run the browser in headless mode. Defaults to True. + args: arguments to pass to browser.chromium.launch Returns: AsyncBrowser: The playwright browser. @@ -63,15 +66,18 @@ def create_async_playwright_browser(headless: bool = True) -> AsyncBrowser: from playwright.async_api import async_playwright browser = run_async(async_playwright().start()) - return run_async(browser.chromium.launch(headless=headless)) + return run_async(browser.chromium.launch(headless=headless, args=args)) -def create_sync_playwright_browser(headless: bool = True) -> SyncBrowser: +def create_sync_playwright_browser( + headless: bool = True, args: Optional[List[str]] = None +) -> SyncBrowser: """ Create a playwright browser. Args: headless: Whether to run the browser in headless mode. Defaults to True. + args: arguments to pass to browser.chromium.launch Returns: SyncBrowser: The playwright browser. 
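The playwright helpers now forward an optional `args` list straight to `chromium.launch`, which is useful for flags such as `--no-sandbox` when running inside containers. A usage sketch:

```python
from langchain_community.tools.playwright.utils import create_sync_playwright_browser

# Flags commonly needed for Chromium in Docker; adjust to your environment.
browser = create_sync_playwright_browser(
    headless=True,
    args=["--no-sandbox", "--disable-dev-shm-usage"],
)
```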
@@ -79,7 +85,7 @@ def create_sync_playwright_browser(headless: bool = True) -> SyncBrowser: from playwright.sync_api import sync_playwright browser = sync_playwright().start() - return browser.chromium.launch(headless=headless) + return browser.chromium.launch(headless=headless, args=args) T = TypeVar("T") diff --git a/libs/community/langchain_community/tools/steam/__init__.py b/libs/community/langchain_community/tools/steam/__init__.py new file mode 100644 index 0000000000000..9367fd95b3089 --- /dev/null +++ b/libs/community/langchain_community/tools/steam/__init__.py @@ -0,0 +1 @@ +"""Steam API toolkit""" diff --git a/libs/community/langchain_community/tools/steam/prompt.py b/libs/community/langchain_community/tools/steam/prompt.py new file mode 100644 index 0000000000000..6f82e2ff4f2f1 --- /dev/null +++ b/libs/community/langchain_community/tools/steam/prompt.py @@ -0,0 +1,26 @@ +STEAM_GET_GAMES_DETAILS = """ + This tool is a wrapper around python-steam-api's steam.apps.search_games API and + steam.apps.get_app_details API, useful when you need to search for a game. + The input to this tool is a string specifying the name of the game you want to + search for. For example, to search for a game called "Counter-Strike: Global + Offensive", you would input "Counter-Strike: Global Offensive" as the game name. + This input will be passed into steam.apps.search_games to find the game id, link + and price, and then the game id will be passed into steam.apps.get_app_details to + get the detailed description and supported languages of the game. Finally, the + results are combined and returned as a string. +""" + +STEAM_GET_RECOMMENDED_GAMES = """ + This tool is a wrapper around python-steam-api's steam.users.get_owned_games API + and steamspypi's steamspypi.download API, useful when you need to get a list of + recommended games. The input to this tool is a string specifying the steam id of + the user you want to get recommended games for. For example, to get recommended + games for a user with steam id 76561197960435530, you would input + "76561197960435530" as the steam id. This steam id is then used to form a + data_request sent to steamspypi's steamspypi.download to retrieve the genres of the + user's owned games. The tool then calculates the frequency of each genre, identifies + the most popular one, and stores it in a dictionary. Subsequently, it uses + steamspypi.download to retrieve all games in this genre and returns the 5 most-played + games that are not owned by the user. + +""" diff --git a/libs/community/langchain_community/tools/steam/tool.py b/libs/community/langchain_community/tools/steam/tool.py new file mode 100644 index 0000000000000..cd0cfdbc8b90a --- /dev/null +++ b/libs/community/langchain_community/tools/steam/tool.py @@ -0,0 +1,30 @@ +"""Tool for Steam Web API""" + +from typing import Optional + +from langchain_core.callbacks.manager import CallbackManagerForToolRun +from langchain_core.tools import BaseTool + +from langchain_community.utilities.steam import SteamWebAPIWrapper + + +class SteamWebAPIQueryRun(BaseTool): + """Tool that searches the Steam Web API.""" + + mode: str + name: str = "Steam" + description: str = ( + "A wrapper around Steam Web API. " + "Steam Tool is useful for fetching User profiles and stats, Game data and more! " + "Input should be the User or Game you want to query."
+ ) + + api_wrapper: SteamWebAPIWrapper + + def _run( + self, + query: str, + run_manager: Optional[CallbackManagerForToolRun] = None, + ) -> str: + """Use the Steam-WebAPI tool.""" + return self.api_wrapper.run(self.mode, query) diff --git a/libs/community/langchain_community/utilities/__init__.py b/libs/community/langchain_community/utilities/__init__.py index efaf877c4f74c..ac91e964d2de9 100644 --- a/libs/community/langchain_community/utilities/__init__.py +++ b/libs/community/langchain_community/utilities/__init__.py @@ -224,6 +224,12 @@ def _import_sql_database() -> Any: return SQLDatabase +def _import_steam_webapi() -> Any: + from langchain_community.utilities.steam import SteamWebAPIWrapper + + return SteamWebAPIWrapper + + def _import_stackexchange() -> Any: from langchain_community.utilities.stackexchange import StackExchangeAPIWrapper @@ -260,6 +266,12 @@ def _import_zapier() -> Any: return ZapierNLAWrapper +def _import_nasa() -> Any: + from langchain_community.utilities.nasa import NasaAPIWrapper + + return NasaAPIWrapper + + def __getattr__(name: str) -> Any: if name == "AlphaVantageAPIWrapper": return _import_alpha_vantage() @@ -307,6 +319,8 @@ def __getattr__(name: str) -> Any: return _import_merriam_webster() elif name == "MetaphorSearchAPIWrapper": return _import_metaphor_search() + elif name == "NasaAPIWrapper": + return _import_nasa() elif name == "OpenWeatherMapAPIWrapper": return _import_openweathermap() elif name == "OutlineAPIWrapper": @@ -333,6 +347,8 @@ def __getattr__(name: str) -> Any: return _import_stackexchange() elif name == "SQLDatabase": return _import_sql_database() + elif name == "SteamWebAPIWrapper": + return _import_steam_webapi() elif name == "TensorflowDatasets": return _import_tensorflow_datasets() elif name == "TwilioAPIWrapper": @@ -371,6 +387,7 @@ def __getattr__(name: str) -> Any: "MaxComputeAPIWrapper", "MerriamWebsterAPIWrapper", "MetaphorSearchAPIWrapper", + "NasaAPIWrapper", "OpenWeatherMapAPIWrapper", "OutlineAPIWrapper", "Portkey", @@ -379,6 +396,7 @@ def __getattr__(name: str) -> Any: "PythonREPL", "Requests", "RequestsWrapper", + "SteamWebAPIWrapper", "SQLDatabase", "SceneXplainAPIWrapper", "SearchApiAPIWrapper", diff --git a/libs/community/langchain_community/utilities/arcee.py b/libs/community/langchain_community/utilities/arcee.py index 743930b93e8d4..7217034858310 100644 --- a/libs/community/langchain_community/utilities/arcee.py +++ b/libs/community/langchain_community/utilities/arcee.py @@ -96,11 +96,14 @@ def adapt(cls, arcee_document: ArceeDocument) -> Document: class ArceeWrapper: - """Wrapper for Arcee API.""" + """Wrapper for Arcee API. + + For more details, see: https://www.arcee.ai/ + """ def __init__( self, - arcee_api_key: SecretStr, + arcee_api_key: Union[str, SecretStr], arcee_api_url: str, arcee_api_version: str, model_kwargs: Optional[Dict[str, Any]], @@ -114,9 +117,12 @@ def __init__( arcee_api_version: Version of Arcee API. model_kwargs: Keyword arguments for Arcee API. model_name: Name of an Arcee model. 
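Both `tools/__init__.py` and `utilities/__init__.py` register the new Steam and NASA integrations the same way: a small `_import_*` thunk plus a module-level `__getattr__` (PEP 562), so the heavy submodule is imported only on first attribute access. The pattern in miniature:

```python
# A minimal sketch of the lazy-import pattern these __init__ modules use.
from typing import Any

def _import_nasa() -> Any:
    from langchain_community.utilities.nasa import NasaAPIWrapper
    return NasaAPIWrapper

def __getattr__(name: str) -> Any:
    # Called only when normal lookup fails (PEP 562), deferring the import.
    if name == "NasaAPIWrapper":
        return _import_nasa()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```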
- """ - self.arcee_api_key = arcee_api_key + if isinstance(arcee_api_key, str): + arcee_api_key_ = SecretStr(arcee_api_key) + else: + arcee_api_key_ = arcee_api_key + self.arcee_api_key: SecretStr = arcee_api_key_ self.model_kwargs = model_kwargs self.arcee_api_url = arcee_api_url self.arcee_api_version = arcee_api_version @@ -166,8 +172,13 @@ def _make_request( def _make_request_headers(self, headers: Optional[Dict] = None) -> Dict: headers = headers or {} + if not isinstance(self.arcee_api_key, SecretStr): + raise TypeError( + f"arcee_api_key must be a SecretStr. Got {type(self.arcee_api_key)}" + ) + api_key = self.arcee_api_key.get_secret_value() internal_headers = { - "X-Token": self.arcee_api_key.get_secret_value(), + "X-Token": api_key, "Content-Type": "application/json", } headers.update(internal_headers) diff --git a/libs/community/langchain_community/utilities/duckduckgo_search.py b/libs/community/langchain_community/utilities/duckduckgo_search.py index 67e89875ceca3..d258726896ab3 100644 --- a/libs/community/langchain_community/utilities/duckduckgo_search.py +++ b/libs/community/langchain_community/utilities/duckduckgo_search.py @@ -18,6 +18,8 @@ class DuckDuckGoSearchAPIWrapper(BaseModel): safesearch: str = "moderate" time: Optional[str] = "y" max_results: int = 5 + backend: str = "api" # which backend to use in DDGS.text() (api, html, lite) + source: str = "text" # which function to use in DDGS (DDGS.text() or DDGS.news()) class Config: """Configuration for this pydantic object.""" @@ -32,43 +34,69 @@ def validate_environment(cls, values: Dict) -> Dict: except ImportError: raise ImportError( "Could not import duckduckgo-search python package. " - "Please install it with `pip install duckduckgo-search`." + "Please install it with `pip install -U duckduckgo-search`." 
) return values - def get_snippets(self, query: str) -> List[str]: - """Run query through DuckDuckGo and return concatenated results.""" + def _ddgs_text( + self, query: str, max_results: Optional[int] = None + ) -> List[Dict[str, str]]: + """Run query through DuckDuckGo text search and return results.""" + from duckduckgo_search import DDGS + + with DDGS() as ddgs: + ddgs_gen = ddgs.text( + query, + region=self.region, + safesearch=self.safesearch, + timelimit=self.time, + max_results=max_results or self.max_results, + backend=self.backend, + ) + if ddgs_gen: + return [r for r in ddgs_gen] + return [] + + def _ddgs_news( + self, query: str, max_results: Optional[int] = None + ) -> List[Dict[str, str]]: + """Run query through DuckDuckGo news search and return results.""" from duckduckgo_search import DDGS with DDGS() as ddgs: - results = ddgs.text( + ddgs_gen = ddgs.news( query, region=self.region, safesearch=self.safesearch, timelimit=self.time, + max_results=max_results or self.max_results, ) - if results is None: - return ["No good DuckDuckGo Search Result was found"] - snippets = [] - for i, res in enumerate(results, 1): - if res is not None: - snippets.append(res["body"]) - if len(snippets) == self.max_results: - break - return snippets + if ddgs_gen: + return [r for r in ddgs_gen] + return [] def run(self, query: str) -> str: - snippets = self.get_snippets(query) - return " ".join(snippets) + """Run query through DuckDuckGo and return concatenated results.""" + if self.source == "text": + results = self._ddgs_text(query) + elif self.source == "news": + results = self._ddgs_news(query) + else: + results = [] + + if not results: + return "No good DuckDuckGo Search Result was found" + return " ".join(r["body"] for r in results) def results( - self, query: str, num_results: int, backend: str = "api" + self, query: str, max_results: int, source: Optional[str] = None ) -> List[Dict[str, str]]: """Run query through DuckDuckGo and return metadata. Args: query: The query to search for. - num_results: The number of results to return. + max_results: The number of results to return. + source: The source to look from. Returns: A list of dictionaries with the following keys: @@ -76,38 +104,27 @@ def results( title - The title of the result. link - The link to the result. 
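With `source` selecting between `DDGS.text()` and `DDGS.news()`, one wrapper now serves both result types. A usage sketch (requires the `duckduckgo-search` package):

```python
from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper

# source="news" routes run()/results() through DDGS.news() instead of DDGS.text().
news_search = DuckDuckGoSearchAPIWrapper(source="news", max_results=3)
print(news_search.run("latest space launch"))
```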
""" - from duckduckgo_search import DDGS - - with DDGS() as ddgs: - results = ddgs.text( - query, - region=self.region, - safesearch=self.safesearch, - timelimit=self.time, - backend=backend, - ) - if results is None: - return [{"Result": "No good DuckDuckGo Search Result was found"}] - - def to_metadata(result: Dict) -> Dict[str, str]: - if backend == "news": - return { - "date": result["date"], - "title": result["title"], - "snippet": result["body"], - "source": result["source"], - "link": result["url"], - } - return { - "snippet": result["body"], - "title": result["title"], - "link": result["href"], + source = source or self.source + if source == "text": + results = [ + {"snippet": r["body"], "title": r["title"], "link": r["href"]} + for r in self._ddgs_text(query, max_results=max_results) + ] + elif source == "news": + results = [ + { + "snippet": r["body"], + "title": r["title"], + "link": r["url"], + "date": r["date"], + "source": r["source"], } + for r in self._ddgs_news(query, max_results=max_results) + ] + else: + results = [] + + if results is None: + results = [{"Result": "No good DuckDuckGo Search Result was found"}] - formatted_results = [] - for i, res in enumerate(results, 1): - if res is not None: - formatted_results.append(to_metadata(res)) - if len(formatted_results) == num_results: - break - return formatted_results + return results diff --git a/libs/community/langchain_community/utilities/github.py b/libs/community/langchain_community/utilities/github.py index 6e83a6107d8f0..f4a7646975742 100644 --- a/libs/community/langchain_community/utilities/github.py +++ b/libs/community/langchain_community/utilities/github.py @@ -4,11 +4,14 @@ import json from typing import TYPE_CHECKING, Any, Dict, List, Optional +import requests +import tiktoken from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator from langchain_core.utils import get_from_dict_or_env if TYPE_CHECKING: from github.Issue import Issue + from github.PullRequest import PullRequest class GitHubAPIWrapper(BaseModel): @@ -19,7 +22,7 @@ class GitHubAPIWrapper(BaseModel): github_repository: Optional[str] = None github_app_id: Optional[str] = None github_app_private_key: Optional[str] = None - github_branch: Optional[str] = None + active_branch: Optional[str] = None github_base_branch: Optional[str] = None class Config: @@ -40,13 +43,6 @@ def validate_environment(cls, values: Dict) -> Dict: values, "github_app_private_key", "GITHUB_APP_PRIVATE_KEY" ) - github_branch = get_from_dict_or_env( - values, "github_branch", "GITHUB_BRANCH", default="master" - ) - github_base_branch = get_from_dict_or_env( - values, "github_base_branch", "GITHUB_BASE_BRANCH", default="master" - ) - try: from github import Auth, GithubIntegration @@ -56,8 +52,13 @@ def validate_environment(cls, values: Dict) -> Dict: "Please install it with `pip install PyGithub`" ) - with open(github_app_private_key, "r") as f: - private_key = f.read() + try: + # interpret the key as a file path + # fallback to interpreting as the key itself + with open(github_app_private_key, "r") as f: + private_key = f.read() + except Exception: + private_key = github_app_private_key auth = Auth.AppAuth( github_app_id, @@ -68,13 +69,28 @@ def validate_environment(cls, values: Dict) -> Dict: # create a GitHub instance: g = installation.get_github_for_installation() + repo = g.get_repo(github_repository) + + github_base_branch = get_from_dict_or_env( + values, + "github_base_branch", + "GITHUB_BASE_BRANCH", + default=repo.default_branch, + ) + + active_branch = 
get_from_dict_or_env( + values, + "active_branch", + "ACTIVE_BRANCH", + default=repo.default_branch, + ) values["github"] = g - values["github_repo_instance"] = g.get_repo(github_repository) + values["github_repo_instance"] = repo values["github_repository"] = github_repository values["github_app_id"] = github_app_id values["github_app_private_key"] = github_app_private_key - values["github_branch"] = github_branch + values["active_branch"] = active_branch values["github_base_branch"] = github_base_branch return values @@ -91,19 +107,45 @@ def parse_issues(self, issues: List[Issue]) -> List[dict]: for issue in issues: title = issue.title number = issue.number - parsed.append({"title": title, "number": number}) + opened_by = issue.user.login if issue.user else None + issue_dict = {"title": title, "number": number} + if opened_by is not None: + issue_dict["opened_by"] = opened_by + parsed.append(issue_dict) + return parsed + + def parse_pull_requests(self, pull_requests: List[PullRequest]) -> List[dict]: + """ + Extracts title and number from each Pull Request and puts them in a dictionary + Parameters: + pull_requests(List[PullRequest]): A list of Github PullRequest objects + Returns: + List[dict]: A list of dictionaries of PR titles and numbers + """ + parsed = [] + for pr in pull_requests: + parsed.append( + { + "title": pr.title, + "number": pr.number, + "commits": str(pr.commits), + "comments": str(pr.comments), + } + ) return parsed def get_issues(self) -> str: """ - Fetches all open issues from the repo + Fetches all open issues from the repo excluding pull requests Returns: str: A plaintext report containing the number of issues and each issue's title and number. """ issues = self.github_repo_instance.get_issues(state="open") - if issues.totalCount > 0: + # Filter out pull requests (part of GH issues object) + issues = [issue for issue in issues if not issue.pull_request] + if issues: parsed_issues = self.parse_issues(issues) parsed_issues_str = ( "Found " + str(len(parsed_issues)) + " issues:\n" + str(parsed_issues) @@ -112,14 +154,201 @@ def get_issues(self) -> str: else: return "No open issues available" + def list_open_pull_requests(self) -> str: + """ + Fetches all open PRs from the repo + + Returns: + str: A plaintext report containing the number of PRs + and each PR's title and number. + """ + pull_requests = self.github_repo_instance.get_pulls(state="open") + if pull_requests.totalCount > 0: + parsed_prs = self.parse_pull_requests(pull_requests) + parsed_prs_str = ( + "Found " + str(len(parsed_prs)) + " pull requests:\n" + str(parsed_prs) + ) + return parsed_prs_str + else: + return "No open pull requests available" + + def list_files_in_main_branch(self) -> str: + """ + Fetches all files in the main branch of the repo. + + Returns: + str: A plaintext report containing the paths and names of the files. + """ + files: List[str] = [] + try: + contents = self.github_repo_instance.get_contents( + "", ref=self.github_base_branch + ) + for content in contents: + if content.type == "dir": + files.extend(self.get_files_from_directory(content.path)) + else: + files.append(content.path) + + if files: + files_str = "\n".join(files) + return f"Found {len(files)} files in the main branch:\n{files_str}" + else: + return "No files found in the main branch" + except Exception as e: + return str(e) + + def set_active_branch(self, branch_name: str) -> str: + """Equivalent to `git checkout branch_name` for this Agent. + Clones formatting from Github. 
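Two small robustness tricks appear in the validator above: `github_app_private_key` may be either a path or the PEM text itself, and the default branch is read from the repository instead of assuming `master`. The key-loading fallback in isolation:

```python
def load_private_key(github_app_private_key: str) -> str:
    # Interpret the value as a file path first; fall back to treating it
    # as the PEM contents themselves (handy for env-var deployments).
    try:
        with open(github_app_private_key, "r") as f:
            return f.read()
    except Exception:
        return github_app_private_key
```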
+ + Returns an Error (as a string) if branch doesn't exist. + """ + curr_branches = [ + branch.name for branch in self.github_repo_instance.get_branches() + ] + if branch_name in curr_branches: + self.active_branch = branch_name + return f"Switched to branch `{branch_name}`" + else: + return ( + f"Error {branch_name} does not exist, " + f"in repo with current branches: {str(curr_branches)}" + ) + + def list_branches_in_repo(self) -> str: + """ + Fetches a list of all branches in the repository. + + Returns: + str: A plaintext report containing the names of the branches. + """ + try: + branches = [ + branch.name for branch in self.github_repo_instance.get_branches() + ] + if branches: + branches_str = "\n".join(branches) + return ( + f"Found {len(branches)} branches in the repository:" + f"\n{branches_str}" + ) + else: + return "No branches found in the repository" + except Exception as e: + return str(e) + + def create_branch(self, proposed_branch_name: str) -> str: + """ + Create a new branch, and set it as the active bot branch. + Equivalent to `git switch -c proposed_branch_name` + If the proposed branch already exists, we append _v1 then _v2... + until a unique name is found. + + Returns: + str: A plaintext success message. + """ + from github import GithubException + + i = 0 + new_branch_name = proposed_branch_name + base_branch = self.github_repo_instance.get_branch( + self.github_repo_instance.default_branch + ) + for i in range(1000): + try: + self.github_repo_instance.create_git_ref( + ref=f"refs/heads/{new_branch_name}", sha=base_branch.commit.sha + ) + self.active_branch = new_branch_name + return ( + f"Branch '{new_branch_name}' " + "created successfully, and set as current active branch." + ) + except GithubException as e: + if e.status == 422 and "Reference already exists" in e.data["message"]: + i += 1 + new_branch_name = f"{proposed_branch_name}_v{i}" + else: + # Handle any other exceptions + print(f"Failed to create branch. Error: {e}") + raise Exception( + "Unable to create branch name from proposed_branch_name: " + f"{proposed_branch_name}" + ) + return ( + "Unable to create branch. " + "At least 1000 branches exist with names derived from " + f"proposed_branch_name: `{proposed_branch_name}`" + ) + + def list_files_in_bot_branch(self) -> str: + """ + Fetches all files in the active branch of the repo, + the branch the bot uses to make changes. + + Returns: + str: A plaintext list containing the filepaths in the branch. + """ + files: List[str] = [] + try: + contents = self.github_repo_instance.get_contents( + "", ref=self.active_branch + ) + for content in contents: + if content.type == "dir": + files.extend(self.get_files_from_directory(content.path)) + else: + files.append(content.path) + + if files: + files_str = "\n".join(files) + return ( + f"Found {len(files)} files in branch `{self.active_branch}`:\n" + f"{files_str}" + ) + else: + return f"No files found in branch: `{self.active_branch}`" + except Exception as e: + return f"Error: {e}" + + def get_files_from_directory(self, directory_path: str) -> str: + """ + Recursively fetches files from a directory in the repo. + + Parameters: + directory_path (str): Path to the directory + + Returns: + str: List of file paths, or an error message. 
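`create_branch` resolves name collisions by appending `_v1`, `_v2`, ... until GitHub accepts the ref, treating PyGithub's 422 "Reference already exists" as the collision signal. The retry skeleton, assuming a PyGithub `repo` object and a base commit SHA:

```python
from github import GithubException  # PyGithub

def next_unique_branch(repo, proposed: str, base_sha: str) -> str:
    # Try `proposed`, then proposed_v1, proposed_v2, ... up to a sanity cap.
    name = proposed
    for i in range(1, 1001):
        try:
            repo.create_git_ref(ref=f"refs/heads/{name}", sha=base_sha)
            return name
        except GithubException as e:
            if e.status == 422 and "Reference already exists" in e.data["message"]:
                name = f"{proposed}_v{i}"
            else:
                raise
    raise RuntimeError(f"could not derive a unique branch name from {proposed!r}")
```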
+ """ + from github import GithubException + + files: List[str] = [] + try: + contents = self.github_repo_instance.get_contents( + directory_path, ref=self.active_branch + ) + except GithubException as e: + return f"Error: status code {e.status}, {e.message}" + + for content in contents: + if content.type == "dir": + files.extend(self.get_files_from_directory(content.path)) + else: + files.append(content.path) + return str(files) + def get_issue(self, issue_number: int) -> Dict[str, Any]: """ Fetches a specific issue and its first 10 comments Parameters: issue_number(int): The number for the github issue Returns: - dict: A doctionary containing the issue's title, - body, and comments as a string + dict: A dictionary containing the issue's title, + body, comments as a string, and the username of the user + who opened the issue """ issue = self.github_repo_instance.get_issue(number=issue_number) page = 0 @@ -132,12 +361,142 @@ def get_issue(self, issue_number: int) -> Dict[str, Any]: comments.append({"body": comment.body, "user": comment.user.login}) page += 1 + opened_by = None + if issue.user and issue.user.login: + opened_by = issue.user.login + return { + "number": issue_number, "title": issue.title, "body": issue.body, "comments": str(comments), + "opened_by": str(opened_by), } + def list_pull_request_files(self, pr_number: int) -> List[Dict[str, Any]]: + """Fetches the full text of all files in a PR. Truncates after first 3k tokens. + # TODO: Enhancement to summarize files with ctags if they're getting long. + + Args: + pr_number(int): The number of the pull request on Github + + Returns: + dict: A dictionary containing the issue's title, + body, and comments as a string + """ + MAX_TOKENS_FOR_FILES = 3_000 + pr_files = [] + pr = self.github_repo_instance.get_pull(number=int(pr_number)) + total_tokens = 0 + page = 0 + while True: # or while (total_tokens + tiktoken()) < MAX_TOKENS_FOR_FILES: + files_page = pr.get_files().get_page(page) + if len(files_page) == 0: + break + for file in files_page: + try: + file_metadata_response = requests.get(file.contents_url) + if file_metadata_response.status_code == 200: + download_url = json.loads(file_metadata_response.text)[ + "download_url" + ] + else: + print(f"Failed to download file: {file.contents_url}, skipping") + continue + + file_content_response = requests.get(download_url) + if file_content_response.status_code == 200: + # Save the content as a UTF-8 string + file_content = file_content_response.text + else: + print( + "Failed downloading file content " + f"(Error {file_content_response.status_code}). Skipping" + ) + continue + + file_tokens = len( + tiktoken.get_encoding("cl100k_base").encode( + file_content + file.filename + "file_name file_contents" + ) + ) + if (total_tokens + file_tokens) < MAX_TOKENS_FOR_FILES: + pr_files.append( + { + "filename": file.filename, + "contents": file_content, + "additions": file.additions, + "deletions": file.deletions, + } + ) + total_tokens += file_tokens + except Exception as e: + print(f"Error when reading files from a PR on github. {e}") + page += 1 + return pr_files + + def get_pull_request(self, pr_number: int) -> Dict[str, Any]: + """ + Fetches a specific pull request and its first 10 comments, + limited by max_tokens. 
+ + Parameters: + pr_number(int): The number for the Github pull request + Returns: + dict: A dictionary containing the pull's title, body, + and comments as a string + """ + max_tokens = 2_000 + pull = self.github_repo_instance.get_pull(number=pr_number) + total_tokens = 0 + + def get_tokens(text: str) -> int: + return len(tiktoken.get_encoding("cl100k_base").encode(text)) + + def add_to_dict(data_dict: Dict[str, Any], key: str, value: str) -> None: + nonlocal total_tokens # Declare total_tokens as nonlocal + tokens = get_tokens(value) + if total_tokens + tokens <= max_tokens: + data_dict[key] = value + total_tokens += tokens # Now this will modify the outer variable + + response_dict: Dict[str, str] = {} + add_to_dict(response_dict, "title", pull.title) + add_to_dict(response_dict, "number", str(pr_number)) + add_to_dict(response_dict, "body", pull.body) + + comments: List[str] = [] + page = 0 + while len(comments) <= 10: + comments_page = pull.get_issue_comments().get_page(page) + if len(comments_page) == 0: + break + for comment in comments_page: + comment_str = str({"body": comment.body, "user": comment.user.login}) + if total_tokens + get_tokens(comment_str) > max_tokens: + break + comments.append(comment_str) + total_tokens += get_tokens(comment_str) + page += 1 + add_to_dict(response_dict, "comments", str(comments)) + + commits: List[str] = [] + page = 0 + while len(commits) <= 10: + commits_page = pull.get_commits().get_page(page) + if len(commits_page) == 0: + break + for commit in commits_page: + commit_str = str({"message": commit.commit.message}) + if total_tokens + get_tokens(commit_str) > max_tokens: + break + commits.append(commit_str) + total_tokens += get_tokens(commit_str) + page += 1 + add_to_dict(response_dict, "commits", str(commits)) + return response_dict + def create_pull_request(self, pr_query: str) -> str: """ Makes a pull request from the bot's branch to the base branch @@ -149,9 +508,9 @@ def create_pull_request(self, pr_query: str) -> str: Returns: str: A success or failure message """ - if self.github_base_branch == self.github_branch: + if self.github_base_branch == self.active_branch: return """Cannot make a pull request because - commits are already in the master branch""" + commits are already in the main or master branch.""" else: try: title = pr_query.split("\n")[0] @@ -159,7 +518,7 @@ def create_pull_request(self, pr_query: str) -> str: pr = self.github_repo_instance.create_pull( title=title, body=body, - head=self.github_branch, + head=self.active_branch, base=self.github_base_branch, ) return f"Successfully created PR number {pr.number}" @@ -197,33 +556,60 @@ def create_file(self, file_query: str) -> str: Returns: str: A success or failure message """ + if self.active_branch == self.github_base_branch: + return ( + "You're attempting to commit directly to the " + f"{self.github_base_branch} branch, which is protected. " + "Please create a new branch and try again." + ) + file_path = file_query.split("\n")[0] file_contents = file_query[len(file_path) + 2 :] + try: - exists = self.github_repo_instance.get_contents(file_path) - if exists is None: - self.github_repo_instance.create_file( - path=file_path, - message="Create " + file_path, - content=file_contents, - branch=self.github_branch, + try: + file = self.github_repo_instance.get_contents( + file_path, ref=self.active_branch ) - return "Created file " + file_path - else: - return f"File already exists at {file_path}. 
+                if file:
+                    return (
+                        f"File already exists at `{file_path}` "
+                        f"on branch `{self.active_branch}`. You must use "
+                        "`update_file` to modify it."
+                    )
+            except Exception:
+                # expected behavior, file shouldn't exist yet
+                pass
+
+            self.github_repo_instance.create_file(
+                path=file_path,
+                message="Create " + file_path,
+                content=file_contents,
+                branch=self.active_branch,
+            )
+            return "Created file " + file_path
         except Exception as e:
             return "Unable to make file due to error:\n" + str(e)
 
     def read_file(self, file_path: str) -> str:
         """
-        Reads a file from the github repo
+        Read a file from this agent's branch, defined by self.active_branch,
+        which supports PR branches.
         Parameters:
             file_path(str): the file path
         Returns:
-            str: The file decoded as a string
+            str: The file decoded as a string, or an error message if not found
         """
-        file = self.github_repo_instance.get_contents(file_path)
-        return file.decoded_content.decode("utf-8")
+        try:
+            file = self.github_repo_instance.get_contents(
+                file_path, ref=self.active_branch
+            )
+            return file.decoded_content.decode("utf-8")
+        except Exception as e:
+            return (
+                f"File not found `{file_path}` on branch "
+                f"`{self.active_branch}`. Error: {str(e)}"
+            )
 
     def update_file(self, file_query: str) -> str:
         """
@@ -243,8 +629,14 @@ def update_file(self, file_query: str) -> str:
         Returns:
             A success or failure message
         """
+        if self.active_branch == self.github_base_branch:
+            return (
+                "You're attempting to commit directly "
+                f"to the {self.github_base_branch} branch, which is protected. "
+                "Please create a new branch and try again."
+            )
         try:
-            file_path = file_query.split("\n")[0]
+            file_path: str = file_query.split("\n")[0]
             old_file_contents = (
                 file_query.split("OLD <<<<")[1].split(">>>> OLD")[0].strip()
             )
@@ -266,12 +658,14 @@ def update_file(self, file_query: str) -> str:
 
             self.github_repo_instance.update_file(
                 path=file_path,
-                message="Update " + file_path,
+                message="Update " + str(file_path),
                 content=updated_file_content,
-                branch=self.github_branch,
-                sha=self.github_repo_instance.get_contents(file_path).sha,
+                branch=self.active_branch,
+                sha=self.github_repo_instance.get_contents(
+                    file_path, ref=self.active_branch
+                ).sha,
             )
-            return "Updated file " + file_path
+            return "Updated file " + str(file_path)
         except Exception as e:
             return "Unable to update file due to error:\n" + str(e)
 
@@ -283,23 +677,119 @@ def delete_file(self, file_path: str) -> str:
         Returns:
             str: Success or failure message
         """
+        if self.active_branch == self.github_base_branch:
+            return (
+                "You're attempting to commit directly "
+                f"to the {self.github_base_branch} branch, which is protected. "
+                "Please create a new branch and try again."
+            )
         try:
-            file = self.github_repo_instance.get_contents(file_path)
             self.github_repo_instance.delete_file(
                 path=file_path,
                 message="Delete " + file_path,
-                branch=self.github_branch,
-                sha=file.sha,
+                branch=self.active_branch,
+                sha=self.github_repo_instance.get_contents(
+                    file_path, ref=self.active_branch
+                ).sha,
             )
             return "Deleted file " + file_path
         except Exception as e:
             return "Unable to delete file due to error:\n" + str(e)
 
+    def search_issues_and_prs(self, query: str) -> str:
+        """
+        Searches issues and pull requests in the repository.
+
+        Parameters:
+            query(str): The search query
+
+        Returns:
+            str: A string containing the top 5 matching issues and pull requests
+        """
+        search_result = self.github.search_issues(query, repo=self.github_repository)
+        max_items = min(5, search_result.totalCount)
+        results = [f"Top {max_items} results:"]
+        for issue in search_result[:max_items]:
+            results.append(
+                f"Title: {issue.title}, Number: {issue.number}, State: {issue.state}"
+            )
+        return "\n".join(results)
+
+    def search_code(self, query: str) -> str:
+        """
+        Searches code in the repository.
+        # TODO: limit total tokens returned...
+
+        Parameters:
+            query(str): The search query
+
+        Returns:
+            str: A string containing, at most, the top 5 search results
+        """
+        search_result = self.github.search_code(
+            query=query, repo=self.github_repository
+        )
+        if search_result.totalCount == 0:
+            return "0 results found."
+        max_results = min(5, search_result.totalCount)
+        results = [f"Showing top {max_results} of {search_result.totalCount} results:"]
+        count = 0
+        for code in search_result:
+            if count >= max_results:
+                break
+            # Get the file content using the PyGithub get_contents method
+            file_content = self.github_repo_instance.get_contents(
+                code.path, ref=self.active_branch
+            ).decoded_content.decode()
+            results.append(
+                f"Filepath: `{code.path}`\nFile contents: "
+                f"{file_content}\n"
+            )
+            count += 1
+        return "\n".join(results)
+
+    def create_review_request(self, reviewer_username: str) -> str:
+        """
+        Creates a review request on *THE* open pull request
+        that matches the current active_branch.
+
+        Parameters:
+            reviewer_username(str): The username of the person who is being requested
+
+        Returns:
+            str: A message confirming the creation of the review request
+        """
+        pull_requests = self.github_repo_instance.get_pulls(
+            state="open", sort="created"
+        )
+        # find PR against active_branch
+        pr = next(
+            (pr for pr in pull_requests if pr.head.ref == self.active_branch), None
+        )
+        if pr is None:
+            return (
+                "No open pull request found for the "
+                f"current branch `{self.active_branch}`"
+            )
+
+        try:
+            pr.create_review_request(reviewers=[reviewer_username])
+            return (
+                f"Review request created for user {reviewer_username} "
+                f"on PR #{pr.number}"
+            )
+        except Exception as e:
+            return f"Failed to create a review request with error {e}"
+
     def run(self, mode: str, query: str) -> str:
-        if mode == "get_issues":
-            return self.get_issues()
-        elif mode == "get_issue":
+        if mode == "get_issue":
             return json.dumps(self.get_issue(int(query)))
+        elif mode == "get_pull_request":
+            return json.dumps(self.get_pull_request(int(query)))
+        elif mode == "list_pull_request_files":
+            return json.dumps(self.list_pull_request_files(int(query)))
+        elif mode == "get_issues":
+            return self.get_issues()
         elif mode == "comment_on_issue":
             return self.comment_on_issue(query)
         elif mode == "create_file":
@@ -312,5 +802,25 @@ def run(self, mode: str, query: str) -> str:
             return self.update_file(query)
         elif mode == "delete_file":
             return self.delete_file(query)
+        elif mode == "list_open_pull_requests":
+            return self.list_open_pull_requests()
+        elif mode == "list_files_in_main_branch":
+            return self.list_files_in_main_branch()
+        elif mode == "list_files_in_bot_branch":
+            return self.list_files_in_bot_branch()
+        elif mode == "list_branches_in_repo":
+            return self.list_branches_in_repo()
+        elif mode == "set_active_branch":
+            return self.set_active_branch(query)
+        elif mode == "create_branch":
+            return self.create_branch(query)
+        elif mode == "get_files_from_directory":
+            return self.get_files_from_directory(query)
+        elif mode == "search_issues_and_prs":
+            return self.search_issues_and_prs(query)
+        elif mode == "search_code":
+            return self.search_code(query)
+        elif mode == "create_review_request":
+            return self.create_review_request(query)
         else:
-            raise ValueError("Invalid mode" + mode)
+            raise ValueError("Invalid mode: " + mode)
diff --git a/libs/community/langchain_community/utilities/nasa.py b/libs/community/langchain_community/utilities/nasa.py
new file mode 100644
index 0000000000000..b58889ca0de31
--- /dev/null
+++ b/libs/community/langchain_community/utilities/nasa.py
@@ -0,0 +1,51 @@
+"""Util that calls several NASA APIs."""
+import json
+
+import requests
+from langchain_core.pydantic_v1 import BaseModel
+
+IMAGE_AND_VIDEO_LIBRARY_URL = "https://images-api.nasa.gov"
+
+
+class NasaAPIWrapper(BaseModel):
+    def get_media(self, query: str) -> dict:
+        params = json.loads(query)
+        if params.get("q"):
+            queryText = params["q"]
+            params.pop("q")
+        else:
+            queryText = ""
+        response = requests.get(
+            IMAGE_AND_VIDEO_LIBRARY_URL + "/search?q=" + queryText, params=params
+        )
+        data = response.json()
+        return data
+
+    def get_media_metadata_manifest(self, query: str) -> dict:
+        response = requests.get(IMAGE_AND_VIDEO_LIBRARY_URL + "/asset/" + query)
+        return response.json()
+
+    def get_media_metadata_location(self, query: str) -> dict:
+        response = requests.get(IMAGE_AND_VIDEO_LIBRARY_URL + "/metadata/" + query)
+        return response.json()
+
+    def get_video_captions_location(self, query: str) -> dict:
+        response = requests.get(IMAGE_AND_VIDEO_LIBRARY_URL + "/captions/" + query)
+        return response.json()
+
+    def run(self, mode: str, query: str) -> str:
+        if mode == "search_media":
+            output = self.get_media(query)
+        elif mode == "get_media_metadata_manifest":
+            output = self.get_media_metadata_manifest(query)
+        elif mode == "get_media_metadata_location":
+            output = self.get_media_metadata_location(query)
+        elif mode == "get_video_captions_location":
+            output = self.get_video_captions_location(query)
+        else:
+            output = f"ModeError: Got unexpected mode {mode}."
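+        # output is either a parsed-JSON dict or the ModeError string; fall
+        # back to str() below in case the payload is not JSON-serializable.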
+ + try: + return json.dumps(output) + except Exception: + return str(output) diff --git a/libs/community/langchain_community/utilities/steam.py b/libs/community/langchain_community/utilities/steam.py new file mode 100644 index 0000000000000..778c3c6870e48 --- /dev/null +++ b/libs/community/langchain_community/utilities/steam.py @@ -0,0 +1,164 @@ +"""Util that calls Steam-WebAPI.""" + +from typing import Any, List + +from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator + + +class SteamWebAPIWrapper(BaseModel): + """Wrapper for Steam API.""" + + steam: Any # for python-steam-api + + from langchain_community.tools.steam.prompt import ( + STEAM_GET_GAMES_DETAILS, + STEAM_GET_RECOMMENDED_GAMES, + ) + + # operations: a list of dictionaries, each representing a specific operation that + # can be performed with the API + operations: List[dict] = [ + { + "mode": "get_game_details", + "name": "Get Game Details", + "description": STEAM_GET_GAMES_DETAILS, + }, + { + "mode": "get_recommended_games", + "name": "Get Recommended Games", + "description": STEAM_GET_RECOMMENDED_GAMES, + }, + ] + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + def get_operations(self) -> List[dict]: + """Return a list of operations.""" + return self.operations + + @root_validator + def validate_environment(cls, values: dict) -> dict: + """Validate api key and python package has been configured.""" + + # check if the python package is installed + try: + from steam import Steam + except ImportError: + raise ImportError("python-steam-api library is not installed. ") + + try: + from decouple import config + except ImportError: + raise ImportError("decouple library is not installed. ") + + # initialize the steam attribute for python-steam-api usage + KEY = config("STEAM_KEY") + steam = Steam(KEY) + values["steam"] = steam + return values + + def parse_to_str(self, details: dict) -> str: # For later parsing + """Parse the details result.""" + result = "" + for key, value in details.items(): + result += "The " + str(key) + " is: " + str(value) + "\n" + return result + + def get_id_link_price(self, games: dict) -> dict: + """The response may contain more than one game, so we need to choose the right + one and return the id.""" + + game_info = {} + for app in games["apps"]: + game_info["id"] = app["id"] + game_info["link"] = app["link"] + game_info["price"] = app["price"] + break + return game_info + + def remove_html_tags(self, html_string: str) -> str: + from bs4 import BeautifulSoup + + soup = BeautifulSoup(html_string, "html.parser") + return soup.get_text() + + def details_of_games(self, name: str) -> str: + games = self.steam.apps.search_games(name) + info_partOne_dict = self.get_id_link_price(games) + info_partOne = self.parse_to_str(info_partOne_dict) + id = str(info_partOne_dict.get("id")) + info_dict = self.steam.apps.get_app_details(id) + data = info_dict.get(id).get("data") + detailed_description = data.get("detailed_description") + + # detailed_description contains
+        # some other html tags, so we need to
+        # remove them
+        detailed_description = self.remove_html_tags(detailed_description)
+        supported_languages = info_dict.get(id).get("data").get("supported_languages")
+        info_partTwo = (
+            "The summary of the game is: "
+            + detailed_description
+            + "\n"
+            + "The supported languages of the game are: "
+            + supported_languages
+            + "\n"
+        )
+        info = info_partOne + info_partTwo
+        return info
+
+    def get_steam_id(self, name: str) -> str:
+        user = self.steam.users.search_user(name)
+        steam_id = user["player"]["steamid"]
+        return steam_id
+
+    def get_users_games(self, steam_id: str) -> List[str]:
+        return self.steam.users.get_owned_games(steam_id, False, False)
+
+    def recommended_games(self, steam_id: str) -> str:
+        try:
+            import steamspypi
+        except ImportError:
+            raise ImportError("steamspypi library is not installed.")
+        users_games = self.get_users_games(steam_id)
+        result = {}  # type: ignore
+        most_popular_genre = ""
+        most_popular_genre_count = 0
+        for game in users_games["games"]:  # type: ignore
+            appid = game["appid"]
+            data_request = {"request": "appdetails", "appid": appid}
+            genreStore = steamspypi.download(data_request)
+            genreList = genreStore.get("genre", "").split(", ")
+
+            for genre in genreList:
+                if genre in result:
+                    result[genre] += 1
+                else:
+                    result[genre] = 1
+                if result[genre] > most_popular_genre_count:
+                    most_popular_genre_count = result[genre]
+                    most_popular_genre = genre
+
+        data_request = dict()
+        data_request["request"] = "genre"
+        data_request["genre"] = most_popular_genre
+        data = steamspypi.download(data_request)
+        sorted_data = sorted(
+            data.values(), key=lambda x: x.get("average_forever", 0), reverse=True
+        )
+        owned_games = [game["appid"] for game in users_games["games"]]  # type: ignore
+        remaining_games = [
+            game for game in sorted_data if game["appid"] not in owned_games
+        ]
+        top_5_popular_not_owned = [game["name"] for game in remaining_games[:5]]
+        return str(top_5_popular_not_owned)
+
+    def run(self, mode: str, game: str) -> str:
+        if mode == "get_games_details":
+            return self.details_of_games(game)
+        elif mode == "get_recommended_games":
+            return self.recommended_games(game)
+        else:
+            raise ValueError(f"Invalid mode {mode} for Steam API.")
diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py
index 48071995d4266..b4be68b3bc19b 100644
--- a/libs/community/langchain_community/vectorstores/azuresearch.py
+++ b/libs/community/langchain_community/vectorstores/azuresearch.py
@@ -14,6 +14,7 @@
     Optional,
     Tuple,
     Type,
+    Union,
 )
 
 import numpy as np
@@ -35,10 +36,13 @@
     from azure.search.documents.indexes.models import (
         ScoringProfile,
         SearchField,
-        SemanticSettings,
         VectorSearch,
     )
+    try:
+        from azure.search.documents.indexes.models import SemanticSearch
+    except ImportError:
+        from azure.search.documents.indexes.models import SemanticSettings  # <11.4.0
 
 # Allow overriding field names for Azure Search
 FIELDS_ID = get_from_env(
@@ -68,7 +72,7 @@ def _get_search_client(
     semantic_configuration_name: Optional[str] = None,
     fields: Optional[List[SearchField]] = None,
     vector_search: Optional[VectorSearch] = None,
-    semantic_settings: Optional[SemanticSettings] = None,
+    semantic_settings: Optional[Union[SemanticSearch, SemanticSettings]] = None,
     scoring_profiles: Optional[List[ScoringProfile]] = None,
     default_scoring_profile: Optional[str] = None,
    default_fields: Optional[List[SearchField]] = None,
@@ -80,15 +84,30 @@
     from azure.search.documents import SearchClient
     from azure.search.documents.indexes import SearchIndexClient
     from azure.search.documents.indexes.models import (
-        HnswVectorSearchAlgorithmConfiguration,
-        PrioritizedFields,
         SearchIndex,
         SemanticConfiguration,
         SemanticField,
-        SemanticSettings,
         VectorSearch,
     )
 
+    # class names changed for versions >= 11.4.0
+    try:
+        from azure.search.documents.indexes.models import (
+            HnswAlgorithmConfiguration,  # HnswVectorSearchAlgorithmConfiguration is old
+            SemanticPrioritizedFields,  # PrioritizedFields outdated
+            SemanticSearch,  # SemanticSettings outdated
+        )
+
+        NEW_VERSION = True
+    except ImportError:
+        from azure.search.documents.indexes.models import (
+            HnswVectorSearchAlgorithmConfiguration,
+            PrioritizedFields,
+            SemanticSettings,
+        )
+
+        NEW_VERSION = False
+
     default_fields = default_fields or []
     if key is None:
         credential = DefaultAzureCredential()
@@ -134,34 +153,71 @@ def fmt_err(x: str) -> str:
         fields = default_fields
     # Vector search configuration
     if vector_search is None:
-        vector_search = VectorSearch(
-            algorithm_configurations=[
-                HnswVectorSearchAlgorithmConfiguration(
-                    name="default",
-                    kind="hnsw",
-                    parameters={  # type: ignore
-                        "m": 4,
-                        "efConstruction": 400,
-                        "efSearch": 500,
-                        "metric": "cosine",
-                    },
-                )
-            ]
-        )
+        if NEW_VERSION:
+            # >= 11.4.0:
+            # VectorSearch(algorithm_configuration) --> VectorSearch(algorithms)
+            # HnswVectorSearchAlgorithmConfiguration --> HnswAlgorithmConfiguration
+            vector_search = VectorSearch(
+                algorithms=[
+                    HnswAlgorithmConfiguration(
+                        name="default",
+                        kind="hnsw",
+                        parameters={  # type: ignore
+                            "m": 4,
+                            "efConstruction": 400,
+                            "efSearch": 500,
+                            "metric": "cosine",
+                        },
+                    )
+                ]
+            )
+        else:  # < 11.4.0
+            vector_search = VectorSearch(
+                algorithm_configurations=[
+                    HnswVectorSearchAlgorithmConfiguration(
+                        name="default",
+                        kind="hnsw",
+                        parameters={  # type: ignore
+                            "m": 4,
+                            "efConstruction": 400,
+                            "efSearch": 500,
+                            "metric": "cosine",
+                        },
+                    )
+                ]
+            )
+
     # Create the semantic settings with the configuration
     if semantic_settings is None and semantic_configuration_name is not None:
-        semantic_settings = SemanticSettings(
-            configurations=[
-                SemanticConfiguration(
-                    name=semantic_configuration_name,
-                    prioritized_fields=PrioritizedFields(
-                        prioritized_content_fields=[
-                            SemanticField(field_name=FIELDS_CONTENT)
-                        ],
-                    ),
-                )
-            ]
-        )
+        if NEW_VERSION:
+            # >= 11.4.0: SemanticSettings --> SemanticSearch
+            # PrioritizedFields(prioritized_content_fields)
+            # --> SemanticPrioritizedFields(content_fields)
+            semantic_settings = SemanticSearch(
+                configurations=[
+                    SemanticConfiguration(
+                        name=semantic_configuration_name,
+                        prioritized_fields=SemanticPrioritizedFields(
+                            content_fields=[
+                                SemanticField(field_name=FIELDS_CONTENT)
+                            ],
+                        ),
+                    )
+                ]
+            )
+        else:  # < 11.4.0
+            semantic_settings = SemanticSettings(
+                configurations=[
+                    SemanticConfiguration(
+                        name=semantic_configuration_name,
+                        prioritized_fields=PrioritizedFields(
+                            prioritized_content_fields=[
+                                SemanticField(field_name=FIELDS_CONTENT)
+                            ],
+                        ),
+                    )
+                ]
+            )
     # Create the search index with the semantic settings and vector search
     index = SearchIndex(
         name=index_name,
@@ -195,7 +251,7 @@ def __init__(
         semantic_query_language: str = "en-us",
         fields: Optional[List[SearchField]] = None,
         vector_search: Optional[VectorSearch] = None,
-        semantic_settings: Optional[SemanticSettings] = None,
+        semantic_settings: Optional[Union[SemanticSearch, SemanticSettings]] = None,
         scoring_profiles: Optional[List[ScoringProfile]] = None,
        default_scoring_profile: Optional[str] = None,
         **kwargs: Any,
@@ -390,10 +446,21 @@ def vector_search_with_score(
                 (
                     Document(
                         page_content=result.pop(FIELDS_CONTENT),
-                        metadata=json.loads(result[FIELDS_METADATA])
-                        if FIELDS_METADATA in result
-                        else {
-                            k: v for k, v in result.items() if k != FIELDS_CONTENT_VECTOR
+                        metadata={
+                            **(
+                                {FIELDS_ID: result.pop(FIELDS_ID)}
+                                if FIELDS_ID in result
+                                else {}
+                            ),
+                            **(
+                                json.loads(result[FIELDS_METADATA])
+                                if FIELDS_METADATA in result
+                                else {
+                                    k: v
+                                    for k, v in result.items()
+                                    if k != FIELDS_CONTENT_VECTOR
+                                }
+                            ),
                         },
                     ),
                     float(result["@search.score"]),
@@ -451,10 +518,21 @@ def hybrid_search_with_score(
                 (
                     Document(
                         page_content=result.pop(FIELDS_CONTENT),
-                        metadata=json.loads(result[FIELDS_METADATA])
-                        if FIELDS_METADATA in result
-                        else {
-                            k: v for k, v in result.items() if k != FIELDS_CONTENT_VECTOR
+                        metadata={
+                            **(
+                                {FIELDS_ID: result.pop(FIELDS_ID)}
+                                if FIELDS_ID in result
+                                else {}
+                            ),
+                            **(
+                                json.loads(result[FIELDS_METADATA])
+                                if FIELDS_METADATA in result
+                                else {
+                                    k: v
+                                    for k, v in result.items()
+                                    if k != FIELDS_CONTENT_VECTOR
+                                }
+                            ),
                         },
                     ),
                     float(result["@search.score"]),
@@ -546,6 +624,11 @@ def semantic_hybrid_search_with_score_and_rerank(
                     Document(
                         page_content=result.pop(FIELDS_CONTENT),
                         metadata={
+                            **(
+                                {FIELDS_ID: result.pop(FIELDS_ID)}
+                                if FIELDS_ID in result
+                                else {}
+                            ),
                             **(
                                 json.loads(result[FIELDS_METADATA])
                                 if FIELDS_METADATA in result
diff --git a/libs/community/langchain_community/vectorstores/chroma.py b/libs/community/langchain_community/vectorstores/chroma.py
index 2f81afd27c5b9..bad8612a82ba7 100644
--- a/libs/community/langchain_community/vectorstores/chroma.py
+++ b/libs/community/langchain_community/vectorstores/chroma.py
@@ -176,7 +176,7 @@ def add_images(
         """Run more images through the embeddings and add to the vectorstore.
 
         Args:
-            images (List[List[float]]): Images to add to the vectorstore.
+            uris (List[str]): File paths to the images to add to the vectorstore.
             metadatas (Optional[List[dict]], optional): Optional list of metadatas.
             ids (Optional[List[str]], optional): Optional list of IDs.
 
diff --git a/libs/community/langchain_community/vectorstores/momento_vector_index.py b/libs/community/langchain_community/vectorstores/momento_vector_index.py
index 3c72256214097..b8f5b3e55122e 100644
--- a/libs/community/langchain_community/vectorstores/momento_vector_index.py
+++ b/libs/community/langchain_community/vectorstores/momento_vector_index.py
@@ -1,3 +1,4 @@
+import logging
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -11,15 +12,20 @@
 )
 from uuid import uuid4
 
+import numpy as np
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.utils import get_from_env
 from langchain_core.vectorstores import VectorStore
 
-from langchain_community.vectorstores.utils import DistanceStrategy
+from langchain_community.vectorstores.utils import (
+    DistanceStrategy,
+    maximal_marginal_relevance,
+)
 
 VST = TypeVar("VST", bound="VectorStore")
 
+logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
     from momento import PreviewVectorIndexClient
@@ -75,9 +81,8 @@ def __init__(
             index_name (str, optional): The name of the index to store the documents
                 in. Defaults to "default".
             distance_strategy (DistanceStrategy, optional): The distance strategy to
-                use. Defaults to DistanceStrategy.COSINE. If you select
-                DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses the squared
-                Euclidean distance.
+                use. If you select DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses
+                the squared Euclidean distance.
Defaults to DistanceStrategy.COSINE. text_field (str, optional): The name of the metadata field to store the original text in. Defaults to "text". ensure_index_exists (bool, optional): Whether to ensure that the index @@ -125,6 +130,7 @@ def _create_index_if_not_exists(self, num_dimensions: int) -> bool: elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE: similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY else: + logger.error(f"Distance strategy {self.distance_strategy} not implemented.") raise ValueError( f"Distance strategy {self.distance_strategy} not implemented." ) @@ -137,8 +143,10 @@ def _create_index_if_not_exists(self, num_dimensions: int) -> bool: elif isinstance(response, CreateIndex.IndexAlreadyExists): return False elif isinstance(response, CreateIndex.Error): + logger.error(f"Error creating index: {response.inner_exception}") raise response.inner_exception else: + logger.error(f"Unexpected response: {response}") raise Exception(f"Unexpected response: {response}") def add_texts( @@ -331,6 +339,87 @@ def similarity_search_by_vector( ) return [doc for doc, _ in results] + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance. + """ + from momento.requests.vector_index import ALL_METADATA + from momento.responses.vector_index import SearchAndFetchVectors + + response = self._client.search_and_fetch_vectors( + self.index_name, embedding, top_k=fetch_k, metadata_fields=ALL_METADATA + ) + + if isinstance(response, SearchAndFetchVectors.Success): + pass + elif isinstance(response, SearchAndFetchVectors.Error): + logger.error(f"Error searching and fetching vectors: {response}") + return [] + else: + logger.error(f"Unexpected response: {response}") + raise Exception(f"Unexpected response: {response}") + + mmr_selected = maximal_marginal_relevance( + query_embedding=np.array([embedding], dtype=np.float32), + embedding_list=[hit.vector for hit in response.hits], + lambda_mult=lambda_mult, + k=k, + ) + selected = [response.hits[i].metadata for i in mmr_selected] + return [ + Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore # noqa: E501 + for metadata in selected + ] + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. 
+ lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + Returns: + List of Documents selected by maximal marginal relevance. + """ + embedding = self._embedding.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding, k, fetch_k, lambda_mult, **kwargs + ) + @classmethod def from_texts( cls: Type[VST], diff --git a/libs/community/tests/integration_tests/document_loaders/test_couchbase.py b/libs/community/tests/integration_tests/document_loaders/test_couchbase.py new file mode 100644 index 0000000000000..f400867962616 --- /dev/null +++ b/libs/community/tests/integration_tests/document_loaders/test_couchbase.py @@ -0,0 +1,44 @@ +import unittest + +from langchain_community.document_loaders.couchbase import CouchbaseLoader + +try: + import couchbase # noqa: F401 + + couchbase_installed = True +except ImportError: + couchbase_installed = False + + +@unittest.skipIf(not couchbase_installed, "couchbase not installed") +class TestCouchbaseLoader(unittest.TestCase): + def setUp(self) -> None: + self.conn_string = "" + self.database_user = "" + self.database_password = "" + self.valid_query = "select h.* from `travel-sample`.inventory.hotel h limit 10" + self.valid_page_content_fields = ["country", "name", "description"] + self.valid_metadata_fields = ["id"] + + def test_couchbase_loader(self) -> None: + """Test Couchbase loader.""" + loader = CouchbaseLoader( + connection_string=self.conn_string, + db_username=self.database_user, + db_password=self.database_password, + query=self.valid_query, + page_content_fields=self.valid_page_content_fields, + metadata_fields=self.valid_metadata_fields, + ) + docs = loader.load() + print(docs) + + assert len(docs) > 0 # assuming the query returns at least one document + for doc in docs: + print(doc) + assert ( + doc.page_content != "" + ) # assuming that every document has page_content + assert ( + "id" in doc.metadata and doc.metadata["id"] != "" + ) # assuming that every document has 'id' diff --git a/libs/community/tests/integration_tests/embeddings/test_bookend.py b/libs/community/tests/integration_tests/embeddings/test_bookend.py new file mode 100644 index 0000000000000..c15036f14bd6b --- /dev/null +++ b/libs/community/tests/integration_tests/embeddings/test_bookend.py @@ -0,0 +1,27 @@ +"""Test Bookend AI embeddings.""" +from langchain_community.embeddings.bookend import BookendEmbeddings + + +def test_bookend_embedding_documents() -> None: + """Test Bookend AI embeddings for documents.""" + documents = ["foo bar", "bar foo"] + embedding = BookendEmbeddings( + domain="", + api_token="", + model_id="", + ) + output = embedding.embed_documents(documents) + assert len(output) == 2 + assert len(output[0]) == 768 + + +def test_bookend_embedding_query() -> None: + """Test Bookend AI embeddings for query.""" + document = "foo bar" + embedding = BookendEmbeddings( + domain="", + api_token="", + model_id="", + ) + output = embedding.embed_query(document) + assert len(output) == 768 diff --git a/libs/community/tests/integration_tests/embeddings/test_cloudflare_workersai.py b/libs/community/tests/integration_tests/embeddings/test_cloudflare_workersai.py new file mode 100644 index 0000000000000..55261f51725ba --- /dev/null +++ b/libs/community/tests/integration_tests/embeddings/test_cloudflare_workersai.py @@ -0,0 +1,55 @@ +"""Test Cloudflare Workers AI embeddings.""" + +import responses + +from 
langchain_community.embeddings.cloudflare_workersai import (
+    CloudflareWorkersAIEmbeddings,
+)
+
+
+@responses.activate
+def test_cloudflare_workers_ai_embedding_documents() -> None:
+    """Test Cloudflare Workers AI embeddings."""
+    documents = ["foo bar", "foo bar", "foo bar"]
+
+    responses.add(
+        responses.POST,
+        "https://api.cloudflare.com/client/v4/accounts/123/ai/run/@cf/baai/bge-base-en-v1.5",
+        json={
+            "result": {
+                "shape": [3, 768],
+                "data": [[0.0] * 768, [0.0] * 768, [0.0] * 768],
+            },
+            "success": "true",
+            "errors": [],
+            "messages": [],
+        },
+    )
+
+    embeddings = CloudflareWorkersAIEmbeddings(account_id="123", api_token="abc")
+    output = embeddings.embed_documents(documents)
+
+    assert len(output) == 3
+    assert len(output[0]) == 768
+
+
+@responses.activate
+def test_cloudflare_workers_ai_embedding_query() -> None:
+    """Test Cloudflare Workers AI embeddings."""
+
+    responses.add(
+        responses.POST,
+        "https://api.cloudflare.com/client/v4/accounts/123/ai/run/@cf/baai/bge-base-en-v1.5",
+        json={
+            "result": {"shape": [1, 768], "data": [[0.0] * 768]},
+            "success": "true",
+            "errors": [],
+            "messages": [],
+        },
+    )
+
+    document = "foo bar"
+    embeddings = CloudflareWorkersAIEmbeddings(account_id="123", api_token="abc")
+    output = embeddings.embed_query(document)
+
+    assert len(output) == 768
diff --git a/libs/community/tests/integration_tests/llms/test_arcee.py b/libs/community/tests/integration_tests/llms/test_arcee.py
index 59193d7a6831c..95797988f01ae 100644
--- a/libs/community/tests/integration_tests/llms/test_arcee.py
+++ b/libs/community/tests/integration_tests/llms/test_arcee.py
@@ -1,34 +1,70 @@
-"""Test Arcee llm"""
+from unittest.mock import MagicMock, patch
+
 from langchain_core.pydantic_v1 import SecretStr
 from pytest import CaptureFixture, MonkeyPatch
 
 from langchain_community.llms.arcee import Arcee
 
 
-def test_api_key_is_secret_string() -> None:
-    llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
-    assert isinstance(llm.arcee_api_key, SecretStr)
+@patch("langchain_community.utilities.arcee.requests.get")
+def test_arcee_api_key_is_secret_string(mock_get: MagicMock) -> None:
+    mock_response = mock_get.return_value
+    mock_response.status_code = 200
+    mock_response.json.return_value = {
+        "model_id": "",
+        "status": "training_complete",
+    }
 
+    arcee_without_env_var = Arcee(
+        model="DALM-PubMed",
+        arcee_api_key="secret_api_key",
+        arcee_api_url="https://localhost",
+        arcee_api_version="version",
+    )
+    assert isinstance(arcee_without_env_var.arcee_api_key, SecretStr)
 
-def test_api_key_masked_when_passed_from_env(
-    monkeypatch: MonkeyPatch, capsys: CaptureFixture
-) -> None:
-    """Test initialization with an API key provided via an env variable"""
-    monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key")
-    llm = Arcee(model="DALM-PubMed")
 
+@patch("langchain_community.utilities.arcee.requests.get")
+def test_api_key_masked_when_passed_via_constructor(
+    mock_get: MagicMock, capsys: CaptureFixture
+) -> None:
+    mock_response = mock_get.return_value
+    mock_response.status_code = 200
+    mock_response.json.return_value = {
+        "model_id": "",
+        "status": "training_complete",
+    }
 
-    print(llm.arcee_api_key, end="")
+    arcee_without_env_var = Arcee(
+        model="DALM-PubMed",
+        arcee_api_key="secret_api_key",
+        arcee_api_url="https://localhost",
+        arcee_api_version="version",
+    )
+    print(arcee_without_env_var.arcee_api_key, end="")
     captured = capsys.readouterr()
 
-    assert captured.out == "**********"
+    assert "**********" == captured.out
 
 
-def test_api_key_masked_when_passed_via_constructor(
-    capsys: CaptureFixture,
+
+@patch("langchain_community.utilities.arcee.requests.get")
+def test_api_key_masked_when_passed_from_env(
+    mock_get: MagicMock, capsys: CaptureFixture, monkeypatch: MonkeyPatch
 ) -> None:
-    """Test initialization with an API key provided via the initializer"""
-    llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key")
+    mock_response = mock_get.return_value
+    mock_response.status_code = 200
+    mock_response.json.return_value = {
+        "model_id": "",
+        "status": "training_complete",
+    }
 
-    print(llm.arcee_api_key, end="")
+    monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key")
+    arcee_with_env_var = Arcee(
+        model="DALM-PubMed",
+        arcee_api_url="https://localhost",
+        arcee_api_version="version",
+    )
+    print(arcee_with_env_var.arcee_api_key, end="")
     captured = capsys.readouterr()
 
-    assert captured.out == "**********"
+
+    assert "**********" == captured.out
diff --git a/libs/community/tests/integration_tests/storage/test_upstash_redis.py b/libs/community/tests/integration_tests/storage/test_upstash_redis.py
index d7b824f8602bd..853de4c234fa8 100644
--- a/libs/community/tests/integration_tests/storage/test_upstash_redis.py
+++ b/libs/community/tests/integration_tests/storage/test_upstash_redis.py
@@ -5,7 +5,7 @@
 
 import pytest
 
-from langchain_community.storage.upstash_redis import UpstashRedisStore
+from langchain_community.storage.upstash_redis import UpstashRedisByteStore
 
 if TYPE_CHECKING:
     from upstash_redis import Redis
@@ -34,16 +34,16 @@ def redis_client() -> Redis:
 
 
 def test_mget(redis_client: Redis) -> None:
-    store = UpstashRedisStore(client=redis_client, ttl=None)
+    store = UpstashRedisByteStore(client=redis_client, ttl=None)
     keys = ["key1", "key2"]
     redis_client.mset({"key1": "value1", "key2": "value2"})
     result = store.mget(keys)
-    assert result == ["value1", "value2"]
+    assert result == [b"value1", b"value2"]
 
 
 def test_mset(redis_client: Redis) -> None:
-    store = UpstashRedisStore(client=redis_client, ttl=None)
-    key_value_pairs = [("key1", "value1"), ("key2", "value2")]
+    store = UpstashRedisByteStore(client=redis_client, ttl=None)
+    key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
     store.mset(key_value_pairs)
     result = redis_client.mget("key1", "key2")
     assert result == ["value1", "value2"]
@@ -51,7 +51,7 @@ def test_mset(redis_client: Redis) -> None:
 
 def test_mdelete(redis_client: Redis) -> None:
     """Test that deletion works as expected."""
-    store = UpstashRedisStore(client=redis_client, ttl=None)
+    store = UpstashRedisByteStore(client=redis_client, ttl=None)
     keys = ["key1", "key2"]
     redis_client.mset({"key1": "value1", "key2": "value2"})
     store.mdelete(keys)
@@ -60,7 +60,7 @@ def test_mdelete(redis_client: Redis) -> None:
 
 def test_yield_keys(redis_client: Redis) -> None:
-    store = UpstashRedisStore(client=redis_client, ttl=None)
+    store = UpstashRedisByteStore(client=redis_client, ttl=None)
     redis_client.mset({"key1": "value2", "key2": "value2"})
     assert sorted(store.yield_keys()) == ["key1", "key2"]
     assert sorted(store.yield_keys(prefix="key*")) == ["key1", "key2"]
@@ -68,8 +68,8 @@
 
 def test_namespace(redis_client: Redis) -> None:
-    store = UpstashRedisStore(client=redis_client, ttl=None, namespace="meow")
-    key_value_pairs = [("key1", "value1"), ("key2", "value2")]
+    store = UpstashRedisByteStore(client=redis_client, ttl=None, namespace="meow")
+    key_value_pairs = [("key1", b"value1"), ("key2", b"value2")]
     store.mset(key_value_pairs)
     cursor, all_keys = redis_client.scan(0)
diff --git a/libs/community/tests/integration_tests/utilities/test_duckduckdgo_search_api.py b/libs/community/tests/integration_tests/utilities/test_duckduckdgo_search_api.py
index 82debbb7694ea..d8b1a6e165bd1 100644
--- a/libs/community/tests/integration_tests/utilities/test_duckduckdgo_search_api.py
+++ b/libs/community/tests/integration_tests/utilities/test_duckduckdgo_search_api.py
@@ -1,11 +1,14 @@
 import pytest
 
-from langchain_community.tools.ddg_search.tool import DuckDuckGoSearchRun
+from langchain_community.tools.ddg_search.tool import (
+    DuckDuckGoSearchResults,
+    DuckDuckGoSearchRun,
+)
 
 
 def ddg_installed() -> bool:
     try:
-        from duckduckgo_search import ddg  # noqa: F401
+        from duckduckgo_search import DDGS  # noqa: F401
 
         return True
     except Exception as e:
@@ -20,3 +23,12 @@ def test_ddg_search_tool() -> None:
     result = tool(keywords)
     print(result)
     assert len(result.split()) > 20
+
+
+@pytest.mark.skipif(not ddg_installed(), reason="requires duckduckgo-search package")
+def test_ddg_search_news_tool() -> None:
+    keywords = "Tesla"
+    tool = DuckDuckGoSearchResults(source="news")
+    result = tool(keywords)
+    print(result)
+    assert len(result.split()) > 20
diff --git a/libs/community/tests/integration_tests/utilities/test_nasa.py b/libs/community/tests/integration_tests/utilities/test_nasa.py
new file mode 100644
index 0000000000000..621b76552b8ea
--- /dev/null
+++ b/libs/community/tests/integration_tests/utilities/test_nasa.py
@@ -0,0 +1,32 @@
+"""Integration test for NASA API Wrapper."""
+from langchain_community.utilities.nasa import NasaAPIWrapper
+
+
+def test_media_search() -> None:
+    """Test for NASA Image and Video Library media search"""
+    nasa = NasaAPIWrapper()
+    query = (
+        '{"q": "saturn", "year_start": "2002", '
+        '"year_end": "2010", "page": 2}'
+    )
+    output = nasa.run("search_media", query)
+    assert output is not None
+    assert "collection" in output
+
+
+def test_get_media_metadata_manifest() -> None:
+    """Test for retrieving media metadata manifest from NASA Image and Video Library"""
+    nasa = NasaAPIWrapper()
+    output = nasa.run("get_media_metadata_manifest", "2022_0707_Recientemente")
+    assert output is not None
+
+
+def test_get_media_metadata_location() -> None:
+    """Test for retrieving media metadata location from NASA Image and Video Library"""
+    nasa = NasaAPIWrapper()
+    output = nasa.run("get_media_metadata_location", "as11-40-5874")
+    assert output is not None
+
+
+def test_get_video_captions_location() -> None:
+    """Test for retrieving video captions location from NASA Image and Video Library"""
+    nasa = NasaAPIWrapper()
+    output = nasa.run("get_video_captions_location", "172_ISS-Slosh.sr")
+    assert output is not None
diff --git a/libs/community/tests/integration_tests/utilities/test_steam_api.py b/libs/community/tests/integration_tests/utilities/test_steam_api.py
new file mode 100644
index 0000000000000..be61f5f1ab202
--- /dev/null
+++ b/libs/community/tests/integration_tests/utilities/test_steam_api.py
@@ -0,0 +1,22 @@
+import ast
+
+from langchain_community.utilities.steam import SteamWebAPIWrapper
+
+
+def test_get_game_details() -> None:
+    """Test for getting game details on Steam"""
+    steam = SteamWebAPIWrapper()
+    output = steam.run("get_game_details", "Terraria")
+    assert "id" in output
+    assert "link" in output
+    assert "detailed description" in output
+    assert "supported languages" in output
+    assert "price" in output
+
+
+def test_get_recommended_games() -> None:
+    """Test for getting recommended games on Steam"""
+    steam =
SteamWebAPIWrapper() + output = steam.run("get_recommended_games", "76561198362745711") + output = ast.literal_eval(output) + assert len(output) == 5 diff --git a/libs/community/tests/integration_tests/vectorstores/fake_embeddings.py b/libs/community/tests/integration_tests/vectorstores/fake_embeddings.py index 7b99c696444af..5de74832de05b 100644 --- a/libs/community/tests/integration_tests/vectorstores/fake_embeddings.py +++ b/libs/community/tests/integration_tests/vectorstores/fake_embeddings.py @@ -53,11 +53,6 @@ def embed_query(self, text: str) -> List[float]: """Return consistent embeddings for the text, if seen before, or a constant one if the text is unknown.""" return self.embed_documents([text])[0] - if text not in self.known_texts: - return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)] - return [float(1.0)] * (self.dimensionality - 1) + [ - float(self.known_texts.index(text)) - ] class AngularTwoDimensionalEmbeddings(Embeddings): diff --git a/libs/community/tests/integration_tests/vectorstores/test_momento_vector_index.py b/libs/community/tests/integration_tests/vectorstores/test_momento_vector_index.py index 712b847650a83..0e8e178fd7d67 100644 --- a/libs/community/tests/integration_tests/vectorstores/test_momento_vector_index.py +++ b/libs/community/tests/integration_tests/vectorstores/test_momento_vector_index.py @@ -125,7 +125,7 @@ def test_from_texts_with_metadatas( def test_from_texts_with_scores(vector_store: MomentoVectorIndex) -> None: - # """Test end to end construction and search with scores and IDs.""" + """Test end to end construction and search with scores and IDs.""" texts = ["apple", "orange", "hammer"] metadatas = [{"page": f"{i}"} for i in range(len(texts))] @@ -162,3 +162,25 @@ def test_add_documents_with_ids(vector_store: MomentoVectorIndex) -> None: ) assert isinstance(response, Search.Success) assert [hit.id for hit in response.hits] == ids + + +def test_max_marginal_relevance_search(vector_store: MomentoVectorIndex) -> None: + """Test max marginal relevance search.""" + pepperoni_pizza = "pepperoni pizza" + cheese_pizza = "cheese pizza" + hot_dog = "hot dog" + + vector_store.add_texts([pepperoni_pizza, cheese_pizza, hot_dog]) + wait() + search_results = vector_store.similarity_search("pizza", k=2) + + assert search_results == [ + Document(page_content=pepperoni_pizza, metadata={}), + Document(page_content=cheese_pizza, metadata={}), + ] + + search_results = vector_store.max_marginal_relevance_search(query="pizza", k=2) + assert search_results == [ + Document(page_content=pepperoni_pizza, metadata={}), + Document(page_content=hot_dog, metadata={}), + ] diff --git a/libs/community/tests/unit_tests/callbacks/tracers/test_base_tracer.py b/libs/community/tests/unit_tests/callbacks/tracers/test_base_tracer.py index 342e595182654..92ea3a1a4d00e 100644 --- a/libs/community/tests/unit_tests/callbacks/tracers/test_base_tracer.py +++ b/libs/community/tests/unit_tests/callbacks/tracers/test_base_tracer.py @@ -331,6 +331,42 @@ def test_tracer_llm_run_on_error() -> None: assert tracer.runs == [compare_run] +@freeze_time("2023-01-01") +def test_tracer_llm_run_on_error_callback() -> None: + """Test tracer on an LLM run with an error and a callback.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = Run( + id=str(uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + events=[ + {"name": "start", "time": datetime.utcnow()}, + {"name": "error", "time": datetime.utcnow()}, + ], + extra={}, + execution_order=1, + 
child_execution_order=1, + serialized=SERIALIZED, + inputs=dict(prompts=[]), + outputs=None, + error=repr(exception), + run_type="llm", + ) + + class FakeTracerWithLlmErrorCallback(FakeTracer): + error_run = None + + def _on_llm_error(self, run: Run) -> None: + self.error_run = run + + tracer = FakeTracerWithLlmErrorCallback() + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) + tracer.on_llm_error(exception, run_id=uuid) + assert tracer.error_run == compare_run + + @freeze_time("2023-01-01") def test_tracer_chain_run_on_error() -> None: """Test tracer on a Chain run with an error.""" diff --git a/libs/community/tests/unit_tests/callbacks/tracers/test_comet.py b/libs/community/tests/unit_tests/callbacks/tracers/test_comet.py new file mode 100644 index 0000000000000..7aa2c41c2ee43 --- /dev/null +++ b/libs/community/tests/unit_tests/callbacks/tracers/test_comet.py @@ -0,0 +1,97 @@ +import uuid +from types import SimpleNamespace +from unittest import mock + +from langchain_core.callbacks.tracers import comet +from langchain_core.outputs import LLMResult + + +def test_comet_tracer__trace_chain_with_single_span__happyflow() -> None: + # Setup mocks + chain_module_mock = mock.Mock() + chain_instance_mock = mock.Mock() + chain_module_mock.Chain.return_value = chain_instance_mock + + span_module_mock = mock.Mock() + span_instance_mock = mock.MagicMock() + span_instance_mock.__api__start__ = mock.Mock() + span_instance_mock.__api__end__ = mock.Mock() + + span_module_mock.Span.return_value = span_instance_mock + + experiment_info_module_mock = mock.Mock() + experiment_info_module_mock.get.return_value = "the-experiment-info" + + chain_api_module_mock = mock.Mock() + + comet_ml_api_mock = SimpleNamespace( + chain=chain_module_mock, + span=span_module_mock, + experiment_info=experiment_info_module_mock, + chain_api=chain_api_module_mock, + flush="not-used-in-this-test", + ) + + # Create tracer + with mock.patch.object( + comet, "import_comet_llm_api", return_value=comet_ml_api_mock + ): + tracer = comet.CometTracer() + + run_id_1 = uuid.UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a") + run_id_2 = uuid.UUID("4f31216e-7c26-4027-a5fd-0bbf9ace17dc") + + # Parent run + tracer.on_chain_start( + {"name": "chain-input"}, + {"input": "chain-input-prompt"}, + parent_run_id=None, + run_id=run_id_1, + ) + + # Check that chain was created + chain_module_mock.Chain.assert_called_once_with( + inputs={"input": "chain-input-prompt"}, + metadata=None, + experiment_info="the-experiment-info", + ) + + # Child run + tracer.on_llm_start( + {"name": "span-input"}, + ["span-input-prompt"], + parent_run_id=run_id_1, + run_id=run_id_2, + ) + + # Check that Span was created and attached to chain + span_module_mock.Span.assert_called_once_with( + inputs={"prompts": ["span-input-prompt"]}, + category=mock.ANY, + metadata=mock.ANY, + name=mock.ANY, + ) + span_instance_mock.__api__start__(chain_instance_mock) + + # Child run end + tracer.on_llm_end( + LLMResult(generations=[], llm_output={"span-output-key": "span-output-value"}), + run_id=run_id_2, + ) + # Check that Span outputs are set and span is ended + span_instance_mock.set_outputs.assert_called_once() + actual_span_outputs = span_instance_mock.set_outputs.call_args[1]["outputs"] + assert { + "llm_output": {"span-output-key": "span-output-value"}, + "generations": [], + }.items() <= actual_span_outputs.items() + span_instance_mock.__api__end__() + + # Parent run end + tracer.on_chain_end({"chain-output-key": "chain-output-value"}, run_id=run_id_1) + + # Check 
that chain outputs are set and chain is logged + chain_instance_mock.set_outputs.assert_called_once() + actual_chain_outputs = chain_instance_mock.set_outputs.call_args[1]["outputs"] + assert ("chain-output-key", "chain-output-value") in actual_chain_outputs.items() + chain_api_module_mock.log_chain.assert_called_once_with(chain_instance_mock) diff --git a/libs/community/tests/unit_tests/document_loaders/blob_loaders/test_schema.py b/libs/community/tests/unit_tests/document_loaders/blob_loaders/test_schema.py index 8f89f43835706..15ed29af43a54 100644 --- a/libs/community/tests/unit_tests/document_loaders/blob_loaders/test_schema.py +++ b/libs/community/tests/unit_tests/document_loaders/blob_loaders/test_schema.py @@ -135,3 +135,20 @@ def yield_blobs(self) -> Iterable[Blob]: yield Blob(data=b"Hello, World!") assert list(TestLoader().yield_blobs()) == [Blob(data=b"Hello, World!")] + + +def test_metadata_and_source() -> None: + """Test metadata and source""" + blob = Blob(path="some_file", data="b") + assert blob.source == "some_file" + assert blob.metadata == {} + blob = Blob(data=b"", metadata={"source": "hello"}) + assert blob.source == "hello" + assert blob.metadata == {"source": "hello"} + + blob = Blob.from_data("data", metadata={"source": "somewhere"}) + assert blob.source == "somewhere" + + with get_temp_file(b"hello") as path: + blob = Blob.from_path(path, metadata={"source": "somewhere"}) + assert blob.source == "somewhere" diff --git a/libs/community/tests/unit_tests/document_loaders/sample_documents/obsidian/template_var_frontmatter.md b/libs/community/tests/unit_tests/document_loaders/sample_documents/obsidian/template_var_frontmatter.md new file mode 100644 index 0000000000000..7bab90737c31f --- /dev/null +++ b/libs/community/tests/unit_tests/document_loaders/sample_documents/obsidian/template_var_frontmatter.md @@ -0,0 +1,12 @@ +--- +aString: {{var}} +anArray: +- element +- {{varElement}} +aDict: + dictId1: 'val' + dictId2: '{{varVal}}' +tags: [ 'tag', '{{varTag}}' ] +--- + +Frontmatter contains template variables. 
diff --git a/libs/community/tests/unit_tests/document_loaders/test_couchbase.py b/libs/community/tests/unit_tests/document_loaders/test_couchbase.py new file mode 100644 index 0000000000000..038000d9e70d7 --- /dev/null +++ b/libs/community/tests/unit_tests/document_loaders/test_couchbase.py @@ -0,0 +1,6 @@ +"""Test importing the Couchbase document loader.""" + + +def test_couchbase_import() -> None: + """Test that the Couchbase document loader can be imported.""" + from langchain_community.document_loaders import CouchbaseLoader # noqa: F401 diff --git a/libs/community/tests/unit_tests/document_loaders/test_imports.py b/libs/community/tests/unit_tests/document_loaders/test_imports.py index 43db9787e4ae3..69f20546d04e6 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_imports.py +++ b/libs/community/tests/unit_tests/document_loaders/test_imports.py @@ -41,6 +41,7 @@ "CollegeConfidentialLoader", "ConcurrentLoader", "ConfluenceLoader", + "CouchbaseLoader", "CubeSemanticLoader", "DataFrameLoader", "DatadogLogsLoader", @@ -52,8 +53,6 @@ "Docx2txtLoader", "DropboxLoader", "DuckDBLoader", - "EmbaasBlobLoader", - "EmbaasLoader", "EtherscanLoader", "EverNoteLoader", "FacebookChatLoader", diff --git a/libs/community/tests/unit_tests/document_loaders/test_obsidian.py b/libs/community/tests/unit_tests/document_loaders/test_obsidian.py index b820d3335f308..d73b53cfcb276 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_obsidian.py +++ b/libs/community/tests/unit_tests/document_loaders/test_obsidian.py @@ -17,7 +17,7 @@ def test_page_content_loaded() -> None: """Verify that all docs have page_content""" - assert len(docs) == 5 + assert len(docs) == 6 assert all(doc.page_content for doc in docs) @@ -27,7 +27,7 @@ def test_disable_collect_metadata() -> None: str(OBSIDIAN_EXAMPLE_PATH), collect_metadata=False ) docs_wo = loader_without_metadata.load() - assert len(docs_wo) == 5 + assert len(docs_wo) == 6 assert all(doc.page_content for doc in docs_wo) assert all(set(doc.metadata) == STANDARD_METADATA_FIELDS for doc in docs_wo) @@ -45,6 +45,24 @@ def test_metadata_with_frontmatter() -> None: assert set(doc.metadata["tags"].split(",")) == {"journal/entry", "obsidian"} +def test_metadata_with_template_vars_in_frontmatter() -> None: + """Verify frontmatter fields with template variables are loaded.""" + doc = next( + doc for doc in docs if doc.metadata["source"] == "template_var_frontmatter.md" + ) + FRONTMATTER_FIELDS = { + "aString", + "anArray", + "aDict", + "tags", + } + assert set(doc.metadata) == FRONTMATTER_FIELDS | STANDARD_METADATA_FIELDS + assert doc.metadata["aString"] == "{{var}}" + assert doc.metadata["anArray"] == "['element', '{{varElement}}']" + assert doc.metadata["aDict"] == "{'dictId1': 'val', 'dictId2': '{{varVal}}'}" + assert set(doc.metadata["tags"].split(",")) == {"tag", "{{varTag}}"} + + def test_metadata_with_bad_frontmatter() -> None: """Verify a doc with non-yaml frontmatter.""" doc = next(doc for doc in docs if doc.metadata["source"] == "bad_frontmatter.md") diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py index dd447ab6fc8c8..8af0e0125288a 100644 --- a/libs/community/tests/unit_tests/embeddings/test_imports.py +++ b/libs/community/tests/unit_tests/embeddings/test_imports.py @@ -53,6 +53,7 @@ "QianfanEmbeddingsEndpoint", "JohnSnowLabsEmbeddings", "VoyageEmbeddings", + "BookendEmbeddings", ] diff --git a/libs/community/tests/unit_tests/tools/test_imports.py 
b/libs/community/tests/unit_tests/tools/test_imports.py index 8ea7913424774..424540de098dd 100644 --- a/libs/community/tests/unit_tests/tools/test_imports.py +++ b/libs/community/tests/unit_tests/tools/test_imports.py @@ -68,6 +68,7 @@ "ListSparkSQLTool", "MetaphorSearchResults", "MoveFileTool", + "NasaAction", "NavigateBackTool", "NavigateTool", "O365CreateDraftMessage", @@ -91,6 +92,8 @@ "RequestsPostTool", "RequestsPutTool", "SceneXplainTool", + "SearchAPIRun", + "SearchAPIResults", "SearxSearchResults", "SearxSearchRun", "ShellTool", @@ -101,6 +104,7 @@ "SleepTool", "StackExchangeTool", "StdInInquireTool", + "SteamWebAPIQueryRun", "SteamshipImageGenerationTool", "StructuredTool", "Tool", diff --git a/libs/community/tests/unit_tests/tools/test_public_api.py b/libs/community/tests/unit_tests/tools/test_public_api.py index 81a5926fdd167..624262f3d1967 100644 --- a/libs/community/tests/unit_tests/tools/test_public_api.py +++ b/libs/community/tests/unit_tests/tools/test_public_api.py @@ -70,6 +70,7 @@ "MerriamWebsterQueryRun", "MetaphorSearchResults", "MoveFileTool", + "NasaAction", "NavigateBackTool", "NavigateTool", "O365CreateDraftMessage", @@ -93,6 +94,8 @@ "RequestsPostTool", "RequestsPutTool", "SceneXplainTool", + "SearchAPIResults", + "SearchAPIRun", "SearxSearchResults", "SearxSearchRun", "ShellTool", @@ -105,6 +108,7 @@ "StackExchangeTool", "SteamshipImageGenerationTool", "StructuredTool", + "SteamWebAPIQueryRun", "Tool", "VectorStoreQATool", "VectorStoreQAWithSourcesTool", diff --git a/libs/community/tests/unit_tests/utilities/test_imports.py b/libs/community/tests/unit_tests/utilities/test_imports.py index 8caae85e3cad3..6499d1b857b2e 100644 --- a/libs/community/tests/unit_tests/utilities/test_imports.py +++ b/libs/community/tests/unit_tests/utilities/test_imports.py @@ -23,6 +23,7 @@ "LambdaWrapper", "MaxComputeAPIWrapper", "MetaphorSearchAPIWrapper", + "NasaAPIWrapper", "OpenWeatherMapAPIWrapper", "OutlineAPIWrapper", "Portkey", @@ -38,6 +39,7 @@ "SerpAPIWrapper", "SparkSQL", "StackExchangeAPIWrapper", + "SteamWebAPIWrapper", "TensorflowDatasets", "TextRequestsWrapper", "TwilioAPIWrapper", diff --git a/libs/core/tests/unit_tests/utils/test_json_schema.py b/libs/core/tests/unit_tests/utils/test_json_schema.py index ffbf0c1353b58..3b0c7b4fe5165 100644 --- a/libs/core/tests/unit_tests/utils/test_json_schema.py +++ b/libs/core/tests/unit_tests/utils/test_json_schema.py @@ -1,5 +1,6 @@ import pytest -from langchain.utils.json_schema import dereference_refs + +from langchain_core.utils.json_schema import dereference_refs def test_dereference_refs_no_refs() -> None: diff --git a/libs/langchain/langchain/adapters/openai.py b/libs/langchain/langchain/adapters/openai.py index bfb2aeea51354..b061bfae696e8 100644 --- a/libs/langchain/langchain/adapters/openai.py +++ b/libs/langchain/langchain/adapters/openai.py @@ -1,7 +1,16 @@ from langchain_community.adapters.openai import ( + Chat, ChatCompletion, + ChatCompletionChunk, + ChatCompletions, + Choice, + ChoiceChunk, + Completions, + IndexableBaseModel, + _convert_message_chunk, _convert_message_chunk_to_delta, _has_assistant_message, + chat, convert_dict_to_message, convert_message_to_dict, convert_messages_for_finetuning, @@ -9,11 +18,20 @@ ) __all__ = [ + "IndexableBaseModel", + "Choice", + "ChatCompletions", + "ChoiceChunk", + "ChatCompletionChunk", "convert_dict_to_message", "convert_message_to_dict", "convert_openai_messages", + "_convert_message_chunk", "_convert_message_chunk_to_delta", "ChatCompletion", "_has_assistant_message", 
"convert_messages_for_finetuning", + "Completions", + "Chat", + "chat", ] diff --git a/libs/langchain/langchain/agents/agent_toolkits/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/__init__.py index 2266b2413bbf8..062ab1f317633 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/__init__.py +++ b/libs/langchain/langchain/agents/agent_toolkits/__init__.py @@ -34,6 +34,7 @@ from langchain.agents.agent_toolkits.json.base import create_json_agent from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit from langchain.agents.agent_toolkits.multion.toolkit import MultionToolkit +from langchain.agents.agent_toolkits.nasa.toolkit import NasaToolkit from langchain.agents.agent_toolkits.nla.toolkit import NLAToolkit from langchain.agents.agent_toolkits.office365.toolkit import O365Toolkit from langchain.agents.agent_toolkits.openapi.base import create_openapi_agent @@ -47,6 +48,7 @@ from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit from langchain.agents.agent_toolkits.sql.base import create_sql_agent from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit +from langchain.agents.agent_toolkits.steam.toolkit import SteamToolkit from langchain.agents.agent_toolkits.vectorstore.base import ( create_vectorstore_agent, create_vectorstore_router_agent, @@ -92,12 +94,14 @@ def __getattr__(name: str) -> Any: "JiraToolkit", "JsonToolkit", "MultionToolkit", + "NasaToolkit", "NLAToolkit", "O365Toolkit", "OpenAPIToolkit", "PlayWrightBrowserToolkit", "PowerBIToolkit", "SlackToolkit", + "SteamToolkit", "SQLDatabaseToolkit", "SparkSQLToolkit", "VectorStoreInfo", diff --git a/libs/langchain/langchain/agents/agent_toolkits/github/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/github/toolkit.py index 0242ddd503d09..2c4b9e213b049 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/github/toolkit.py +++ b/libs/langchain/langchain/agents/agent_toolkits/github/toolkit.py @@ -1,3 +1,35 @@ -from langchain_community.agent_toolkits.github.toolkit import GitHubToolkit +from langchain_community.agent_toolkits.github.toolkit import ( + BranchName, + CommentOnIssue, + CreateFile, + CreatePR, + CreateReviewRequest, + DeleteFile, + DirectoryPath, + GetIssue, + GetPR, + GitHubToolkit, + NoInput, + ReadFile, + SearchCode, + SearchIssuesAndPRs, + UpdateFile, +) -__all__ = ["GitHubToolkit"] +__all__ = [ + "NoInput", + "GetIssue", + "CommentOnIssue", + "GetPR", + "CreatePR", + "CreateFile", + "ReadFile", + "UpdateFile", + "DeleteFile", + "DirectoryPath", + "BranchName", + "SearchCode", + "CreateReviewRequest", + "SearchIssuesAndPRs", + "GitHubToolkit", +] diff --git a/libs/langchain/langchain/agents/agent_toolkits/nasa/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/nasa/__init__.py new file mode 100644 index 0000000000000..a13c3ec706c6d --- /dev/null +++ b/libs/langchain/langchain/agents/agent_toolkits/nasa/__init__.py @@ -0,0 +1 @@ +"""NASA Toolkit""" diff --git a/libs/langchain/langchain/agents/agent_toolkits/nasa/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/nasa/toolkit.py new file mode 100644 index 0000000000000..312233762c9a6 --- /dev/null +++ b/libs/langchain/langchain/agents/agent_toolkits/nasa/toolkit.py @@ -0,0 +1,3 @@ +from langchain_community.agent_toolkits.nasa.toolkit import NasaToolkit + +__all__ = ["NasaToolkit"] diff --git a/libs/langchain/langchain/agents/agent_toolkits/steam/__init__.py b/libs/langchain/langchain/agents/agent_toolkits/steam/__init__.py new file mode 100644 index 
0000000000000..f99981082424e --- /dev/null +++ b/libs/langchain/langchain/agents/agent_toolkits/steam/__init__.py @@ -0,0 +1 @@ +"""Steam Toolkit.""" diff --git a/libs/langchain/langchain/agents/agent_toolkits/steam/toolkit.py b/libs/langchain/langchain/agents/agent_toolkits/steam/toolkit.py new file mode 100644 index 0000000000000..4fde82fea2413 --- /dev/null +++ b/libs/langchain/langchain/agents/agent_toolkits/steam/toolkit.py @@ -0,0 +1,3 @@ +from langchain_community.agent_toolkits.steam.toolkit import SteamToolkit + +__all__ = ["SteamToolkit"] diff --git a/libs/langchain/langchain/agents/openai_functions_agent/base.py b/libs/langchain/langchain/agents/openai_functions_agent/base.py index 472fec05b41a5..b70f16494d124 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/base.py @@ -26,7 +26,6 @@ ) from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks -from langchain.chat_models.openai import ChatOpenAI from langchain.tools.render import format_tool_to_openai_function @@ -50,12 +49,6 @@ def get_allowed_tools(self) -> List[str]: """Get allowed tools.""" return [t.name for t in self.tools] - @root_validator - def validate_llm(cls, values: dict) -> dict: - if not isinstance(values["llm"], ChatOpenAI): - raise ValueError("Only supported with ChatOpenAI models.") - return values - @root_validator def validate_prompt(cls, values: dict) -> dict: prompt: BasePromptTemplate = values["prompt"] @@ -162,7 +155,7 @@ def return_stopped_response( agent_decision = self.plan( intermediate_steps, with_functions=False, **kwargs ) - if type(agent_decision) == AgentFinish: + if isinstance(agent_decision, AgentFinish): return agent_decision else: raise ValueError( @@ -222,8 +215,6 @@ def from_llm_and_tools( **kwargs: Any, ) -> BaseSingleActionAgent: """Construct an agent from an LLM and tools.""" - if not isinstance(llm, ChatOpenAI): - raise ValueError("Only supported with ChatOpenAI models.") prompt = cls.create_prompt( extra_prompt_messages=extra_prompt_messages, system_message=system_message, diff --git a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py index 8b8d1da9f095b..d25944863757d 100644 --- a/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py +++ b/libs/langchain/langchain/agents/openai_functions_multi_agent/base.py @@ -26,7 +26,6 @@ ) from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks -from langchain.chat_models.openai import ChatOpenAI from langchain.tools import BaseTool # For backwards compatibility @@ -109,12 +108,6 @@ def get_allowed_tools(self) -> List[str]: """Get allowed tools.""" return [t.name for t in self.tools] - @root_validator - def validate_llm(cls, values: dict) -> dict: - if not isinstance(values["llm"], ChatOpenAI): - raise ValueError("Only supported with ChatOpenAI models.") - return values - @root_validator def validate_prompt(cls, values: dict) -> dict: prompt: BasePromptTemplate = values["prompt"] diff --git a/libs/langchain/langchain/agents/self_ask_with_search/base.py b/libs/langchain/langchain/agents/self_ask_with_search/base.py index df3bdea919acb..86a10722017b0 100644 --- a/libs/langchain/langchain/agents/self_ask_with_search/base.py +++ b/libs/langchain/langchain/agents/self_ask_with_search/base.py @@ -13,6 +13,7 @@ from langchain.agents.tools import Tool from 
langchain.agents.utils import validate_tools_single_input from langchain.utilities.google_serper import GoogleSerperAPIWrapper +from langchain.utilities.searchapi import SearchApiAPIWrapper from langchain.utilities.serpapi import SerpAPIWrapper @@ -64,7 +65,9 @@ class SelfAskWithSearchChain(AgentExecutor): def __init__( self, llm: BaseLanguageModel, - search_chain: Union[GoogleSerperAPIWrapper, SerpAPIWrapper], + search_chain: Union[ + GoogleSerperAPIWrapper, SearchApiAPIWrapper, SerpAPIWrapper + ], **kwargs: Any, ): """Initialize only with an LLM and a search chain.""" diff --git a/libs/langchain/langchain/callbacks/tracers/comet.py b/libs/langchain/langchain/callbacks/tracers/comet.py new file mode 100644 index 0000000000000..bfe7bb44342ce --- /dev/null +++ b/libs/langchain/langchain/callbacks/tracers/comet.py @@ -0,0 +1,138 @@ +from types import ModuleType, SimpleNamespace +from typing import TYPE_CHECKING, Any, Callable, Dict + +from langchain.callbacks.tracers.base import BaseTracer + +if TYPE_CHECKING: + from uuid import UUID + + from comet_llm import Span + from comet_llm.chains.chain import Chain + + from langchain.callbacks.tracers.schemas import Run + + +def _get_run_type(run: "Run") -> str: + if isinstance(run.run_type, str): + return run.run_type + elif hasattr(run.run_type, "value"): + return run.run_type.value + else: + return str(run.run_type) + + +def import_comet_llm_api() -> SimpleNamespace: + """Import comet_llm api and raise an error if it is not installed.""" + try: + from comet_llm import ( + experiment_info, # noqa: F401 + flush, # noqa: F401 + ) + from comet_llm.chains import api as chain_api # noqa: F401 + from comet_llm.chains import ( + chain, # noqa: F401 + span, # noqa: F401 + ) + + except ImportError: + raise ImportError( + "To use the CometTracer you need to have the " + "`comet_llm>=2.0.0` python package installed. 
Please install it with" + " `pip install -U comet_llm`" + ) + return SimpleNamespace( + chain=chain, + span=span, + chain_api=chain_api, + experiment_info=experiment_info, + flush=flush, + ) + + +class CometTracer(BaseTracer): + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._span_map: Dict["UUID", "Span"] = {} + self._chains_map: Dict["UUID", "Chain"] = {} + self._initialize_comet_modules() + + def _initialize_comet_modules(self) -> None: + comet_llm_api = import_comet_llm_api() + self._chain: ModuleType = comet_llm_api.chain + self._span: ModuleType = comet_llm_api.span + self._chain_api: ModuleType = comet_llm_api.chain_api + self._experiment_info: ModuleType = comet_llm_api.experiment_info + self._flush: Callable[[], None] = comet_llm_api.flush + + def _persist_run(self, run: "Run") -> None: + chain_ = self._chains_map[run.id] + chain_.set_outputs(outputs=run.outputs) + self._chain_api.log_chain(chain_) + + def _process_start_trace(self, run: "Run") -> None: + if not run.parent_run_id: + # This is the first run, which maps to a chain + chain_: "Chain" = self._chain.Chain( + inputs=run.inputs, + metadata=None, + experiment_info=self._experiment_info.get(), + ) + self._chains_map[run.id] = chain_ + else: + span: "Span" = self._span.Span( + inputs=run.inputs, + category=_get_run_type(run), + metadata=run.extra, + name=run.name, + ) + span.__api__start__(self._chains_map[run.parent_run_id]) + self._chains_map[run.id] = self._chains_map[run.parent_run_id] + self._span_map[run.id] = span + + def _process_end_trace(self, run: "Run") -> None: + if not run.parent_run_id: + pass + # Langchain will call _persist_run for us + else: + span = self._span_map[run.id] + span.set_outputs(outputs=run.outputs) + span.__api__end__() + + def flush(self) -> None: + self._flush() + + def _on_llm_start(self, run: "Run") -> None: + """Process the LLM Run upon start.""" + self._process_start_trace(run) + + def _on_llm_end(self, run: "Run") -> None: + """Process the LLM Run.""" + self._process_end_trace(run) + + def _on_llm_error(self, run: "Run") -> None: + """Process the LLM Run upon error.""" + self._process_end_trace(run) + + def _on_chain_start(self, run: "Run") -> None: + """Process the Chain Run upon start.""" + self._process_start_trace(run) + + def _on_chain_end(self, run: "Run") -> None: + """Process the Chain Run.""" + self._process_end_trace(run) + + def _on_chain_error(self, run: "Run") -> None: + """Process the Chain Run upon error.""" + self._process_end_trace(run) + + def _on_tool_start(self, run: "Run") -> None: + """Process the Tool Run upon start.""" + self._process_start_trace(run) + + def _on_tool_end(self, run: "Run") -> None: + """Process the Tool Run.""" + self._process_end_trace(run) + + def _on_tool_error(self, run: "Run") -> None: + """Process the Tool Run upon error.""" + self._process_end_trace(run) diff --git a/libs/langchain/langchain/document_loaders/__init__.py b/libs/langchain/langchain/document_loaders/__init__.py index 119496f9c66e5..ba3867ffcb43c 100644 --- a/libs/langchain/langchain/document_loaders/__init__.py +++ b/libs/langchain/langchain/document_loaders/__init__.py @@ -62,6 +62,7 @@ from langchain.document_loaders.concurrent import ConcurrentLoader from langchain.document_loaders.confluence import ConfluenceLoader from langchain.document_loaders.conllu import CoNLLULoader +from langchain.document_loaders.couchbase import CouchbaseLoader from langchain.document_loaders.csv_loader import CSVLoader, UnstructuredCSVLoader from 
langchain.document_loaders.cube_semantic import CubeSemanticLoader from langchain.document_loaders.datadog_logs import DatadogLogsLoader @@ -77,7 +78,6 @@ OutlookMessageLoader, UnstructuredEmailLoader, ) -from langchain.document_loaders.embaas import EmbaasBlobLoader, EmbaasLoader from langchain.document_loaders.epub import UnstructuredEPubLoader from langchain.document_loaders.etherscan import EtherscanLoader from langchain.document_loaders.evernote import EverNoteLoader @@ -248,6 +248,7 @@ "CollegeConfidentialLoader", "ConcurrentLoader", "ConfluenceLoader", + "CouchbaseLoader", "CubeSemanticLoader", "DataFrameLoader", "DatadogLogsLoader", @@ -259,8 +260,6 @@ "Docx2txtLoader", "DropboxLoader", "DuckDBLoader", - "EmbaasBlobLoader", - "EmbaasLoader", "EtherscanLoader", "EverNoteLoader", "FacebookChatLoader", diff --git a/libs/langchain/langchain/document_loaders/couchbase.py b/libs/langchain/langchain/document_loaders/couchbase.py new file mode 100644 index 0000000000000..3c7f26e95b47c --- /dev/null +++ b/libs/langchain/langchain/document_loaders/couchbase.py @@ -0,0 +1,3 @@ +from langchain_community.document_loaders.couchbase import CouchbaseLoader, logger + +__all__ = ["logger", "CouchbaseLoader"] diff --git a/libs/langchain/langchain/document_loaders/embaas.py b/libs/langchain/langchain/document_loaders/embaas.py deleted file mode 100644 index 4c2fe9d9deb5b..0000000000000 --- a/libs/langchain/langchain/document_loaders/embaas.py +++ /dev/null @@ -1,17 +0,0 @@ -from langchain_community.document_loaders.embaas import ( - EMBAAS_DOC_API_URL, - BaseEmbaasLoader, - EmbaasBlobLoader, - EmbaasDocumentExtractionParameters, - EmbaasDocumentExtractionPayload, - EmbaasLoader, -) - -__all__ = [ - "EMBAAS_DOC_API_URL", - "EmbaasDocumentExtractionParameters", - "EmbaasDocumentExtractionPayload", - "BaseEmbaasLoader", - "EmbaasBlobLoader", - "EmbaasLoader", -] diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py index 8f2887942562e..3710a6e1969fa 100644 --- a/libs/langchain/langchain/embeddings/__init__.py +++ b/libs/langchain/langchain/embeddings/__init__.py @@ -22,6 +22,7 @@ from langchain.embeddings.azure_openai import AzureOpenAIEmbeddings from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint from langchain.embeddings.bedrock import BedrockEmbeddings +from langchain.embeddings.bookend import BookendEmbeddings from langchain.embeddings.cache import CacheBackedEmbeddings from langchain.embeddings.clarifai import ClarifaiEmbeddings from langchain.embeddings.cohere import CohereEmbeddings @@ -127,6 +128,7 @@ "QianfanEmbeddingsEndpoint", "JohnSnowLabsEmbeddings", "VoyageEmbeddings", + "BookendEmbeddings", ] diff --git a/libs/langchain/langchain/embeddings/bookend.py b/libs/langchain/langchain/embeddings/bookend.py new file mode 100644 index 0000000000000..eb192d19b21b5 --- /dev/null +++ b/libs/langchain/langchain/embeddings/bookend.py @@ -0,0 +1,8 @@ +from langchain_community.embeddings.bookend import ( + API_URL, + DEFAULT_TASK, + PATH, + BookendEmbeddings, +) + +__all__ = ["API_URL", "DEFAULT_TASK", "PATH", "BookendEmbeddings"] diff --git a/libs/langchain/langchain/embeddings/cloudflare_workersai.py b/libs/langchain/langchain/embeddings/cloudflare_workersai.py new file mode 100644 index 0000000000000..4757a6c4e9d42 --- /dev/null +++ b/libs/langchain/langchain/embeddings/cloudflare_workersai.py @@ -0,0 +1,6 @@ +from langchain_community.embeddings.cloudflare_workersai import ( + DEFAULT_MODEL_NAME, + 
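# CloudflareWorkersAIEmbeddings posts to the Workers AI REST endpoint for the
# model named by DEFAULT_MODEL_NAME. A minimal sketch, mirroring the
# integration test added below (account_id and api_token are placeholders):
#
#   from langchain.embeddings.cloudflare_workersai import (
#       CloudflareWorkersAIEmbeddings,
#   )
#
#   embeddings = CloudflareWorkersAIEmbeddings(account_id="123", api_token="abc")
#   vectors = embeddings.embed_documents(["foo bar", "bar foo"])
#   query_vector = embeddings.embed_query("foo bar")  # 768-dim for bge-base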
CloudflareWorkersAIEmbeddings, +) + +__all__ = ["DEFAULT_MODEL_NAME", "CloudflareWorkersAIEmbeddings"] diff --git a/libs/langchain/langchain/retrievers/arcee.py b/libs/langchain/langchain/retrievers/arcee.py index 7d3e7b822f5b2..e360f62a03f3b 100644 --- a/libs/langchain/langchain/retrievers/arcee.py +++ b/libs/langchain/langchain/retrievers/arcee.py @@ -61,7 +61,7 @@ def __init__(self, **data: Any) -> None: super().__init__(**data) self._client = ArceeWrapper( - arcee_api_key=self.arcee_api_key, + arcee_api_key=self.arcee_api_key.get_secret_value(), arcee_api_url=self.arcee_api_url, arcee_api_version=self.arcee_api_version, model_kwargs=self.model_kwargs, diff --git a/libs/langchain/langchain/retrievers/azure_cognitive_search.py b/libs/langchain/langchain/retrievers/azure_cognitive_search.py index 27cdd91815715..8824d986cc872 100644 --- a/libs/langchain/langchain/retrievers/azure_cognitive_search.py +++ b/libs/langchain/langchain/retrievers/azure_cognitive_search.py @@ -13,7 +13,10 @@ AsyncCallbackManagerForRetrieverRun, CallbackManagerForRetrieverRun, ) -from langchain.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env, get_from_env + +DEFAULT_URL_SUFFIX = "search.windows.net" +"""Default URL Suffix for endpoint connection - commercial cloud""" class AzureCognitiveSearchRetriever(BaseRetriever): @@ -54,7 +57,10 @@ def validate_environment(cls, values: Dict) -> Dict: return values def _build_search_url(self, query: str) -> str: - base_url = f"https://{self.service_name}.search.windows.net/" + url_suffix = get_from_env( + "", "AZURE_COGNITIVE_SEARCH_URL_SUFFIX", DEFAULT_URL_SUFFIX + ) + base_url = f"https://{self.service_name}.{url_suffix}/" endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}" top_param = f"&$top={self.top_k}" if self.top_k else "" return base_url + endpoint_path + f"&search={query}" + top_param diff --git a/libs/langchain/langchain/retrievers/multi_vector.py b/libs/langchain/langchain/retrievers/multi_vector.py index f05e0859eb15b..dcc81b554c363 100644 --- a/libs/langchain/langchain/retrievers/multi_vector.py +++ b/libs/langchain/langchain/retrievers/multi_vector.py @@ -1,13 +1,13 @@ from enum import Enum -from typing import List +from typing import List, Optional from langchain_core.documents import Document -from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever from langchain_core.stores import BaseStore from langchain_core.vectorstores import VectorStore from langchain.callbacks.manager import CallbackManagerForRetrieverRun +from langchain.storage._lc_store import create_kv_docstore class SearchType(str, Enum): @@ -27,12 +27,35 @@ class MultiVectorRetriever(BaseRetriever): and their embedding vectors""" docstore: BaseStore[str, Document] """The storage layer for the parent documents""" - id_key: str = "doc_id" - search_kwargs: dict = Field(default_factory=dict) + id_key: str + search_kwargs: dict """Keyword arguments to pass to the search function.""" - search_type: SearchType = SearchType.similarity + search_type: SearchType """Type of search to perform (similarity / mmr)""" + def __init__( + self, + *, + vectorstore: VectorStore, + docstore: Optional[BaseStore[str, Document]] = None, + base_store: Optional[BaseStore[str, bytes]] = None, + id_key: str = "doc_id", + search_kwargs: Optional[dict] = None, + search_type: SearchType = SearchType.similarity, + ): + if base_store is not None: + docstore = create_kv_docstore(base_store) + elif docstore is None: + raise 
Exception("You must pass a `base_store` parameter.") + + super().__init__( + vectorstore=vectorstore, + docstore=docstore, + id_key=id_key, + search_kwargs=search_kwargs if search_kwargs is not None else {}, + search_type=search_type, + ) + def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: diff --git a/libs/langchain/langchain/storage/upstash_redis.py b/libs/langchain/langchain/storage/upstash_redis.py index 922252ebc8aa8..4b5311786c04f 100644 --- a/libs/langchain/langchain/storage/upstash_redis.py +++ b/libs/langchain/langchain/storage/upstash_redis.py @@ -1,3 +1,7 @@ -from langchain_community.storage.upstash_redis import UpstashRedisStore +from langchain_community.storage.upstash_redis import ( + UpstashRedisByteStore, + UpstashRedisStore, + _UpstashRedisStore, +) -__all__ = ["UpstashRedisStore"] +__all__ = ["_UpstashRedisStore", "UpstashRedisStore", "UpstashRedisByteStore"] diff --git a/libs/langchain/langchain/tools/__init__.py b/libs/langchain/langchain/tools/__init__.py index f73341a8010c9..a5478f36435bf 100644 --- a/libs/langchain/langchain/tools/__init__.py +++ b/libs/langchain/langchain/tools/__init__.py @@ -284,6 +284,18 @@ def _import_google_serper_tool_GoogleSerperRun() -> Any: return GoogleSerperRun +def _import_searchapi_tool_SearchAPIResults() -> Any: + from langchain.tools.searchapi.tool import SearchAPIResults + + return SearchAPIResults + + +def _import_searchapi_tool_SearchAPIRun() -> Any: + from langchain.tools.searchapi.tool import SearchAPIRun + + return SearchAPIRun + + def _import_graphql_tool() -> Any: from langchain.tools.graphql.tool import BaseGraphQLTool @@ -338,6 +350,12 @@ def _import_metaphor_search() -> Any: return MetaphorSearchResults +def _import_nasa_tool() -> Any: + from langchain.tools.nasa.tool import NasaAction + + return NasaAction + + def _import_office365_create_draft_message() -> Any: from langchain.tools.office365.create_draft_message import O365CreateDraftMessage @@ -534,6 +552,12 @@ def _import_requests_tool_RequestsPutTool() -> Any: return RequestsPutTool +def _import_steam_webapi_tool() -> Any: + from langchain.tools.steam.tool import SteamWebAPIQueryRun + + return SteamWebAPIQueryRun + + def _import_scenexplain_tool() -> Any: from langchain.tools.scenexplain.tool import SceneXplainTool @@ -807,6 +831,10 @@ def __getattr__(name: str) -> Any: return _import_google_serper_tool_GoogleSerperResults() elif name == "GoogleSerperRun": return _import_google_serper_tool_GoogleSerperRun() + elif name == "SearchAPIResults": + return _import_searchapi_tool_SearchAPIResults() + elif name == "SearchAPIRun": + return _import_searchapi_tool_SearchAPIRun() elif name == "BaseGraphQLTool": return _import_graphql_tool() elif name == "HumanInputRun": @@ -825,6 +853,8 @@ def __getattr__(name: str) -> Any: return _import_merriam_webster_tool() elif name == "MetaphorSearchResults": return _import_metaphor_search() + elif name == "NasaAction": + return _import_nasa_tool() elif name == "O365CreateDraftMessage": return _import_office365_create_draft_message() elif name == "O365SearchEvents": @@ -887,6 +917,8 @@ def __getattr__(name: str) -> Any: return _import_requests_tool_RequestsPostTool() elif name == "RequestsPutTool": return _import_requests_tool_RequestsPutTool() + elif name == "SteamWebAPIQueryRun": + return _import_steam_webapi_tool() elif name == "SceneXplainTool": return _import_scenexplain_tool() elif name == "SearxSearchResults": @@ -1007,6 +1039,8 @@ def __getattr__(name: str) -> Any: 
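# MultiVectorRetriever now wires up its docstore in the constructor shown
# above: pass a ready BaseStore[str, Document] as `docstore`, or a raw
# BaseStore[str, bytes] as `base_store` to have it wrapped via
# create_kv_docstore. A minimal sketch, assuming `vectorstore` is any
# initialized VectorStore:
#
#   from langchain.retrievers.multi_vector import MultiVectorRetriever
#   from langchain.storage import InMemoryStore
#
#   retriever = MultiVectorRetriever(
#       vectorstore=vectorstore,
#       docstore=InMemoryStore(),  # or: base_store=<a BaseStore[str, bytes]>
#       id_key="doc_id",
#   )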
"GoogleSearchRun", "GoogleSerperResults", "GoogleSerperRun", + "SearchAPIResults", + "SearchAPIRun", "HumanInputRun", "IFTTTWebhook", "InfoPowerBITool", @@ -1022,6 +1056,7 @@ def __getattr__(name: str) -> Any: "MerriamWebsterQueryRun", "MetaphorSearchResults", "MoveFileTool", + "NasaAction", "NavigateBackTool", "NavigateTool", "O365CreateDraftMessage", @@ -1044,6 +1079,7 @@ def __getattr__(name: str) -> Any: "RequestsPatchTool", "RequestsPostTool", "RequestsPutTool", + "SteamWebAPIQueryRun", "SceneXplainTool", "SearxSearchResults", "SearxSearchRun", diff --git a/libs/langchain/langchain/tools/github/prompt.py b/libs/langchain/langchain/tools/github/prompt.py index a651b015928b1..4a21b8991136d 100644 --- a/libs/langchain/langchain/tools/github/prompt.py +++ b/libs/langchain/langchain/tools/github/prompt.py @@ -1,11 +1,23 @@ from langchain_community.tools.github.prompt import ( COMMENT_ON_ISSUE_PROMPT, + CREATE_BRANCH_PROMPT, CREATE_FILE_PROMPT, CREATE_PULL_REQUEST_PROMPT, + CREATE_REVIEW_REQUEST_PROMPT, DELETE_FILE_PROMPT, + GET_FILES_FROM_DIRECTORY_PROMPT, GET_ISSUE_PROMPT, GET_ISSUES_PROMPT, + GET_PR_PROMPT, + LIST_BRANCHES_IN_REPO_PROMPT, + LIST_PRS_PROMPT, + LIST_PULL_REQUEST_FILES, + OVERVIEW_EXISTING_FILES_BOT_BRANCH, + OVERVIEW_EXISTING_FILES_IN_MAIN, READ_FILE_PROMPT, + SEARCH_CODE_PROMPT, + SEARCH_ISSUES_AND_PRS_PROMPT, + SET_ACTIVE_BRANCH_PROMPT, UPDATE_FILE_PROMPT, ) @@ -18,4 +30,16 @@ "READ_FILE_PROMPT", "UPDATE_FILE_PROMPT", "DELETE_FILE_PROMPT", + "GET_PR_PROMPT", + "LIST_PRS_PROMPT", + "LIST_PULL_REQUEST_FILES", + "OVERVIEW_EXISTING_FILES_IN_MAIN", + "OVERVIEW_EXISTING_FILES_BOT_BRANCH", + "SEARCH_ISSUES_AND_PRS_PROMPT", + "SEARCH_CODE_PROMPT", + "CREATE_REVIEW_REQUEST_PROMPT", + "LIST_BRANCHES_IN_REPO_PROMPT", + "SET_ACTIVE_BRANCH_PROMPT", + "CREATE_BRANCH_PROMPT", + "GET_FILES_FROM_DIRECTORY_PROMPT", ] diff --git a/libs/langchain/langchain/tools/nasa/__init__.py b/libs/langchain/langchain/tools/nasa/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/langchain/langchain/tools/nasa/prompt.py b/libs/langchain/langchain/tools/nasa/prompt.py new file mode 100644 index 0000000000000..a862e0623c1e6 --- /dev/null +++ b/libs/langchain/langchain/tools/nasa/prompt.py @@ -0,0 +1,13 @@ +from langchain_community.tools.nasa.prompt import ( + NASA_CAPTIONS_PROMPT, + NASA_MANIFEST_PROMPT, + NASA_METADATA_PROMPT, + NASA_SEARCH_PROMPT, +) + +__all__ = [ + "NASA_SEARCH_PROMPT", + "NASA_MANIFEST_PROMPT", + "NASA_METADATA_PROMPT", + "NASA_CAPTIONS_PROMPT", +] diff --git a/libs/langchain/langchain/tools/nasa/tool.py b/libs/langchain/langchain/tools/nasa/tool.py new file mode 100644 index 0000000000000..5b8ea1405debd --- /dev/null +++ b/libs/langchain/langchain/tools/nasa/tool.py @@ -0,0 +1,3 @@ +from langchain_community.tools.nasa.tool import NasaAction + +__all__ = ["NasaAction"] diff --git a/libs/langchain/langchain/tools/steam/__init__.py b/libs/langchain/langchain/tools/steam/__init__.py new file mode 100644 index 0000000000000..9367fd95b3089 --- /dev/null +++ b/libs/langchain/langchain/tools/steam/__init__.py @@ -0,0 +1 @@ +"""Steam API toolkit""" diff --git a/libs/langchain/langchain/tools/steam/prompt.py b/libs/langchain/langchain/tools/steam/prompt.py new file mode 100644 index 0000000000000..172c3054461fa --- /dev/null +++ b/libs/langchain/langchain/tools/steam/prompt.py @@ -0,0 +1,6 @@ +from langchain_community.tools.steam.prompt import ( + STEAM_GET_GAMES_DETAILS, + STEAM_GET_RECOMMENDED_GAMES, +) + +__all__ = ["STEAM_GET_GAMES_DETAILS", 
"STEAM_GET_RECOMMENDED_GAMES"] diff --git a/libs/langchain/langchain/tools/steam/tool.py b/libs/langchain/langchain/tools/steam/tool.py new file mode 100644 index 0000000000000..534b757e76e55 --- /dev/null +++ b/libs/langchain/langchain/tools/steam/tool.py @@ -0,0 +1,3 @@ +from langchain_community.tools.steam.tool import SteamWebAPIQueryRun + +__all__ = ["SteamWebAPIQueryRun"] diff --git a/libs/langchain/langchain/utilities/__init__.py b/libs/langchain/langchain/utilities/__init__.py index 817a5475ab3ff..986f99de295bf 100644 --- a/libs/langchain/langchain/utilities/__init__.py +++ b/libs/langchain/langchain/utilities/__init__.py @@ -218,6 +218,12 @@ def _import_sql_database() -> Any: return SQLDatabase +def _import_steam_webapi() -> Any: + from langchain.utilities.steam import SteamWebAPIWrapper + + return SteamWebAPIWrapper + + def _import_stackexchange() -> Any: from langchain.utilities.stackexchange import StackExchangeAPIWrapper @@ -254,6 +260,12 @@ def _import_zapier() -> Any: return ZapierNLAWrapper +def _import_nasa() -> Any: + from langchain.utilities.nasa import NasaAPIWrapper + + return NasaAPIWrapper + + def __getattr__(name: str) -> Any: if name == "AlphaVantageAPIWrapper": return _import_alpha_vantage() @@ -301,6 +313,8 @@ def __getattr__(name: str) -> Any: return _import_merriam_webster() elif name == "MetaphorSearchAPIWrapper": return _import_metaphor_search() + elif name == "NasaAPIWrapper": + return _import_nasa() elif name == "OpenWeatherMapAPIWrapper": return _import_openweathermap() elif name == "OutlineAPIWrapper": @@ -327,6 +341,8 @@ def __getattr__(name: str) -> Any: return _import_stackexchange() elif name == "SQLDatabase": return _import_sql_database() + elif name == "SteamWebAPIWrapper": + return _import_steam_webapi() elif name == "TensorflowDatasets": return _import_tensorflow_datasets() elif name == "TwilioAPIWrapper": @@ -365,6 +381,7 @@ def __getattr__(name: str) -> Any: "MaxComputeAPIWrapper", "MerriamWebsterAPIWrapper", "MetaphorSearchAPIWrapper", + "NasaAPIWrapper", "OpenWeatherMapAPIWrapper", "OutlineAPIWrapper", "Portkey", @@ -373,6 +390,7 @@ def __getattr__(name: str) -> Any: "PythonREPL", "Requests", "RequestsWrapper", + "SteamWebAPIWrapper", "SQLDatabase", "SceneXplainAPIWrapper", "SearchApiAPIWrapper", diff --git a/libs/langchain/langchain/utilities/nasa.py b/libs/langchain/langchain/utilities/nasa.py new file mode 100644 index 0000000000000..94f356f39ee57 --- /dev/null +++ b/libs/langchain/langchain/utilities/nasa.py @@ -0,0 +1,6 @@ +from langchain_community.utilities.nasa import ( + IMAGE_AND_VIDEO_LIBRARY_URL, + NasaAPIWrapper, +) + +__all__ = ["IMAGE_AND_VIDEO_LIBRARY_URL", "NasaAPIWrapper"] diff --git a/libs/langchain/langchain/utilities/steam.py b/libs/langchain/langchain/utilities/steam.py new file mode 100644 index 0000000000000..37ff4c9690767 --- /dev/null +++ b/libs/langchain/langchain/utilities/steam.py @@ -0,0 +1,3 @@ +from langchain_community.utilities.steam import SteamWebAPIWrapper + +__all__ = ["SteamWebAPIWrapper"] diff --git a/libs/langchain/langchain/vectorstores/momento_vector_index.py b/libs/langchain/langchain/vectorstores/momento_vector_index.py index 301cb150d8c9d..8a9ac81da7f6c 100644 --- a/libs/langchain/langchain/vectorstores/momento_vector_index.py +++ b/libs/langchain/langchain/vectorstores/momento_vector_index.py @@ -1,6 +1,7 @@ from langchain_community.vectorstores.momento_vector_index import ( VST, MomentoVectorIndex, + logger, ) -__all__ = ["VST", "MomentoVectorIndex"] +__all__ = ["VST", "logger", 
"MomentoVectorIndex"] diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 717a022829d74..2cea6b5ddd801 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -1662,6 +1662,40 @@ lint = ["black (>=22.6.0)", "mdformat (>0.7)", "mdformat-gfm (>=0.3.5)", "ruff ( test = ["pytest"] typing = ["mypy (>=0.990)"] +[[package]] +name = "couchbase" +version = "4.1.9" +description = "Python Client for Couchbase" +optional = true +python-versions = ">=3.7" +files = [ + {file = "couchbase-4.1.9-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:f36f65d5ea66ebebe8f9055feb44c72b60b64b8c466ee177c7eaf6d97b71b41a"}, + {file = "couchbase-4.1.9-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b750cb641a44084137444e86ba2cf596e713dceaaa8dcd4a09c370ddd5e3bca2"}, + {file = "couchbase-4.1.9-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:371f4c5e00965d6579e98cd6e49eb8543e3aeabb64d9ac41dae5b85c831faed4"}, + {file = "couchbase-4.1.9-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cfe53bfa29d72d5fa921554408ff7fada301e4641b652f2551060ebd3d1cc096"}, + {file = "couchbase-4.1.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d675d0d862eae34ebedd09e4f40e26ac0092ea0dca93520616cd68d195a1fb3a"}, + {file = "couchbase-4.1.9-cp310-cp310-win_amd64.whl", hash = "sha256:c8adc08a70cbe5e1b1e0e45ebbb4ea5879b3f1aba64d09770d6e35a760201609"}, + {file = "couchbase-4.1.9-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:8f3e16fedb2dd79dba81df5eb1fb6e493ee720ef12be5a2699ac540955775647"}, + {file = "couchbase-4.1.9-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8bb93e17304499fb9b6972efe8a75ea156a097eed983b4802a478ad6cef500b3"}, + {file = "couchbase-4.1.9-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:85da68da9efd5ed35d031a5725744ee36653f940ad16c252d9927f481581366c"}, + {file = "couchbase-4.1.9-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e04f014a8990c89195689af4d332028a6769b45221d861778c079e9f67184e6e"}, + {file = "couchbase-4.1.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:50db238605620ca1a2f4ed36f7820a2d61323a8a425986fd3caf1d9be4eb7f46"}, + {file = "couchbase-4.1.9-cp311-cp311-win_amd64.whl", hash = "sha256:ba9312755c88d39d86cae7ba11c15a6255d8afe5c552bbc1e2f6b66c880bd08e"}, + {file = "couchbase-4.1.9-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:17bdf8db3721e4f7c54b7e50db16fa6c65733d45cfd6c3bf50cd80a7f1672ea8"}, + {file = "couchbase-4.1.9-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a2fb14336b78843691a9f007fbbd0c33959ea4ae4e323112614673601772fb84"}, + {file = "couchbase-4.1.9-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3af36a4b25f948a4dd1a349ba5ddfa87a228cbdfbb8228a5045e187849392857"}, + {file = "couchbase-4.1.9-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f1a6d03fb4fc76aedeede7a55f957936863256b654ce38f05a508925cbd1c713"}, + {file = "couchbase-4.1.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:454c46c9fb6e485f1aba53f64a4b794e2146db480ccd32eaa80b2bba0f53895e"}, + {file = "couchbase-4.1.9-cp38-cp38-win_amd64.whl", hash = "sha256:4c35c2ef600677121b95540c8e78bb43ce5d18cafd49036ea256643ed00ac042"}, + {file = "couchbase-4.1.9-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:40bebe809042efceae95fba8d2a1f0bfecd144c090cf638d8283e038ffea6f19"}, + {file = "couchbase-4.1.9-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f9e956b6580baf4365c4a1b4e22622dc0948447f5ce106d24ed59532302b164f"}, 
+ {file = "couchbase-4.1.9-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:150916388ee2813d242de014fb3ad5e259103e5cd0f1ce600280cc1c11732980"}, + {file = "couchbase-4.1.9-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bf2d1fc8fe22f6e3e4b5e41c7fc367a3a4537dd272a26859f01796724d2ae977"}, + {file = "couchbase-4.1.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9d9ffbb6897a3e68193a8611032230e5d520ae07ae74923305acf8670eb5281b"}, + {file = "couchbase-4.1.9-cp39-cp39-win_amd64.whl", hash = "sha256:b11ff93f4b5da9437fdfb384943dfbf0dac054394d30d21b5e50852dc1d27d2a"}, + {file = "couchbase-4.1.9.tar.gz", hash = "sha256:ee476c5e5b420610e5f4ce778b8c6c7a513f9f4dd4b57fe25000e94ad6eefb9e"}, +] + [[package]] name = "coverage" version = "7.3.2" @@ -4252,7 +4286,7 @@ tests = ["pandas (>=1.4)", "pytest", "pytest-asyncio", "pytest-mock"] [[package]] name = "langchain-core" -version = "0.0.9" +version = "0.0.10" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -4924,29 +4958,29 @@ files = [ [[package]] name = "momento" -version = "1.13.0" +version = "1.14.1" description = "SDK for Momento" optional = true python-versions = ">=3.7,<4.0" files = [ - {file = "momento-1.13.0-py3-none-any.whl", hash = "sha256:dd5ace5b8d679e882afcefaa16bc413973c270b0a7a1c6c45f3eb60b0b9526de"}, - {file = "momento-1.13.0.tar.gz", hash = "sha256:39419627542b8f5997a777ff91aa3aaf6406b7d76fb83cd84284a0f7d1f9e356"}, + {file = "momento-1.14.1-py3-none-any.whl", hash = "sha256:241e46669e39c19627396f2b2b027a912861f1b8097fc9f97b05b76b3d90d199"}, + {file = "momento-1.14.1.tar.gz", hash = "sha256:d200a5e7463f7746a8a611474af1c245183d7ddf9346d9592760b78b6e801560"}, ] [package.dependencies] grpcio = ">=1.46.0,<2.0.0" -momento-wire-types = ">=0.91.1,<0.92.0" +momento-wire-types = ">=0.96.0,<0.97.0" pyjwt = ">=2.4.0,<3.0.0" [[package]] name = "momento-wire-types" -version = "0.91.4" +version = "0.96.0" description = "Momento Client Proto Generated Files" optional = true python-versions = ">=3.7,<4.0" files = [ - {file = "momento_wire_types-0.91.4-py3-none-any.whl", hash = "sha256:f296249693de2f6c383a397e7616b84dd83dfd466743d34b035b90865000a2a8"}, - {file = "momento_wire_types-0.91.4.tar.gz", hash = "sha256:de8cd14a12835d95997eb9b753ea47e1a5d2916658ec9320e416da8bd835fdff"}, + {file = "momento_wire_types-0.96.0-py3-none-any.whl", hash = "sha256:93dc0e3c31bbe1f664ce33974f235bc20e63b5e35ea8e118f0c5e5ed3cda7709"}, + {file = "momento_wire_types-0.96.0.tar.gz", hash = "sha256:9c6c839c698741c54b9fc3a4fe0f82094ea5102418b02bb271ed6e64ea6d7d9e"}, ] [package.dependencies] @@ -11475,7 +11509,7 @@ cli = ["typer"] cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", 
"requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"] +extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "couchbase", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"] javascript = ["esprima"] llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"] openai = ["openai", "tiktoken"] @@ -11485,4 +11519,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "0cd9769243ade0dc1df941e902aa66c18a57333ae50309f004b4f60e6e27b5cf" +content-hash = "ffccc36a82a8a31fb7b1e3a4d9a024093dfaf25b6115a7c8a1fcbce9d1bb726b" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 31e0a60f89a6a..fd27a8f4e8c1b 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain" -version = "0.0.345" +version = "0.0.346" description = "Building applications with LLMs through composability" authors = [] license = "MIT" @@ -12,7 +12,7 @@ langchain-server = "langchain.server:main" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" -langchain-core = ">=0.0.9,<0.1" +langchain-core = ">=0.0.10,<0.1" pydantic = ">=1,<3" SQLAlchemy = ">=1.4,<3" requests = "^2" @@ -147,6 +147,7 @@ hologres-vector = {version = "^0.0.6", optional = true} praw = {version = "^7.7.1", optional = true} msal = {version = "^1.25.0", optional = true} databricks-vectorsearch = {version = "^0.21", optional = true} +couchbase = {version = "^4.1.9", optional = true} dgml-utils = {version = "^0.3.0", optional = true} datasets = {version = "^2.15.0", optional = true} @@ -391,6 +392,7 @@ extended_testing = [ "hologres-vector", "praw", "databricks-vectorsearch", + "couchbase", "dgml-utils", "cohere", ] diff --git a/libs/langchain/tests/integration_tests/chains/test_self_ask_with_search.py b/libs/langchain/tests/integration_tests/chains/test_self_ask_with_search.py index 61ef78d9228d2..f288f23dadf72 100644 --- a/libs/langchain/tests/integration_tests/chains/test_self_ask_with_search.py +++ b/libs/langchain/tests/integration_tests/chains/test_self_ask_with_search.py @@ -1,7 +1,7 @@ """Integration test for self ask with search.""" from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain from langchain.llms.openai import OpenAI -from langchain.utilities.google_serper import GoogleSerperAPIWrapper +from langchain.utilities.searchapi import SearchApiAPIWrapper def test_self_ask_with_search() -> None: @@ -9,10 +9,10 @@ def test_self_ask_with_search() -> None: question = "What is the hometown of the reigning men's U.S. Open champion?" 
chain = SelfAskWithSearchChain( llm=OpenAI(temperature=0), - search_chain=GoogleSerperAPIWrapper(), + search_chain=SearchApiAPIWrapper(), input_key="q", output_key="a", ) answer = chain.run(question) final_answer = answer.split("\n")[-1] - assert final_answer == "El Palmar, Spain" + assert final_answer == "Belgrade, Serbia" diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_couchbase.py b/libs/langchain/tests/integration_tests/document_loaders/test_couchbase.py new file mode 100644 index 0000000000000..d4585d0796b54 --- /dev/null +++ b/libs/langchain/tests/integration_tests/document_loaders/test_couchbase.py @@ -0,0 +1,44 @@ +import unittest + +from langchain.document_loaders.couchbase import CouchbaseLoader + +try: + import couchbase # noqa: F401 + + couchbase_installed = True +except ImportError: + couchbase_installed = False + + +@unittest.skipIf(not couchbase_installed, "couchbase not installed") +class TestCouchbaseLoader(unittest.TestCase): + def setUp(self) -> None: + self.conn_string = "" + self.database_user = "" + self.database_password = "" + self.valid_query = "select h.* from `travel-sample`.inventory.hotel h limit 10" + self.valid_page_content_fields = ["country", "name", "description"] + self.valid_metadata_fields = ["id"] + + def test_couchbase_loader(self) -> None: + """Test Couchbase loader.""" + loader = CouchbaseLoader( + connection_string=self.conn_string, + db_username=self.database_user, + db_password=self.database_password, + query=self.valid_query, + page_content_fields=self.valid_page_content_fields, + metadata_fields=self.valid_metadata_fields, + ) + docs = loader.load() + print(docs) + + assert len(docs) > 0 # assuming the query returns at least one document + for doc in docs: + print(doc) + assert ( + doc.page_content != "" + ) # assuming that every document has page_content + assert ( + "id" in doc.metadata and doc.metadata["id"] != "" + ) # assuming that every document has 'id' diff --git a/libs/langchain/tests/integration_tests/document_loaders/test_embaas.py b/libs/langchain/tests/integration_tests/document_loaders/test_embaas.py deleted file mode 100644 index 2170a143c66ac..0000000000000 --- a/libs/langchain/tests/integration_tests/document_loaders/test_embaas.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import Any -from unittest.mock import MagicMock, patch - -import responses - -from langchain.document_loaders import EmbaasBlobLoader, EmbaasLoader -from langchain.document_loaders.blob_loaders import Blob -from langchain.document_loaders.embaas import EMBAAS_DOC_API_URL - - -@responses.activate -def test_handle_request() -> None: - responses.add( - responses.POST, - EMBAAS_DOC_API_URL, - json={ - "data": { - "chunks": [ - { - "text": "Hello", - "metadata": {"start_page": 1, "end_page": 2}, - "embeddings": [0.0], - } - ] - } - }, - status=200, - ) - - loader = EmbaasBlobLoader(embaas_api_key="api_key", params={"should_embed": True}) - documents = loader.parse(blob=Blob.from_data(data="Hello")) - assert len(documents) == 1 - assert documents[0].page_content == "Hello" - assert documents[0].metadata["start_page"] == 1 - assert documents[0].metadata["end_page"] == 2 - assert documents[0].metadata["embeddings"] == [0.0] - - -@responses.activate -def test_handle_request_exception() -> None: - responses.add( - responses.POST, - EMBAAS_DOC_API_URL, - json={"message": "Invalid request"}, - status=400, - ) - loader = EmbaasBlobLoader(embaas_api_key="api_key") - try: - loader.parse(blob=Blob.from_data(data="Hello")) - except Exception as e: 
- assert "Invalid request" in str(e) - - -@patch.object(EmbaasBlobLoader, "_handle_request") -def test_load(mock_handle_request: Any) -> None: - mock_handle_request.return_value = [MagicMock()] - loader = EmbaasLoader(file_path="test_embaas.py", embaas_api_key="api_key") - documents = loader.load() - assert len(documents) == 1 diff --git a/libs/langchain/tests/integration_tests/embeddings/test_bookend.py b/libs/langchain/tests/integration_tests/embeddings/test_bookend.py new file mode 100644 index 0000000000000..940f67063802c --- /dev/null +++ b/libs/langchain/tests/integration_tests/embeddings/test_bookend.py @@ -0,0 +1,27 @@ +"""Test Bookend AI embeddings.""" +from langchain.embeddings.bookend import BookendEmbeddings + + +def test_bookend_embedding_documents() -> None: + """Test Bookend AI embeddings for documents.""" + documents = ["foo bar", "bar foo"] + embedding = BookendEmbeddings( + domain="", + api_token="", + model_id="", + ) + output = embedding.embed_documents(documents) + assert len(output) == 2 + assert len(output[0]) == 768 + + +def test_bookend_embedding_query() -> None: + """Test Bookend AI embeddings for query.""" + document = "foo bar" + embedding = BookendEmbeddings( + domain="", + api_token="", + model_id="", + ) + output = embedding.embed_query(document) + assert len(output) == 768 diff --git a/libs/langchain/tests/integration_tests/embeddings/test_cloudflare_workersai.py b/libs/langchain/tests/integration_tests/embeddings/test_cloudflare_workersai.py new file mode 100644 index 0000000000000..24ac031371728 --- /dev/null +++ b/libs/langchain/tests/integration_tests/embeddings/test_cloudflare_workersai.py @@ -0,0 +1,53 @@ +"""Test Cloudflare Workers AI embeddings.""" + +import responses + +from langchain.embeddings.cloudflare_workersai import CloudflareWorkersAIEmbeddings + + +@responses.activate +def test_cloudflare_workers_ai_embedding_documents() -> None: + """Test Cloudflare Workers AI embeddings.""" + documents = ["foo bar", "foo bar", "foo bar"] + + responses.add( + responses.POST, + "https://api.cloudflare.com/client/v4/accounts/123/ai/run/@cf/baai/bge-base-en-v1.5", + json={ + "result": { + "shape": [3, 768], + "data": [[0.0] * 768, [0.0] * 768, [0.0] * 768], + }, + "success": "true", + "errors": [], + "messages": [], + }, + ) + + embeddings = CloudflareWorkersAIEmbeddings(account_id="123", api_token="abc") + output = embeddings.embed_documents(documents) + + assert len(output) == 3 + assert len(output[0]) == 768 + + +@responses.activate +def test_cloudflare_workers_ai_embedding_query() -> None: + """Test Cloudflare Workers AI embeddings.""" + + responses.add( + responses.POST, + "https://api.cloudflare.com/client/v4/accounts/123/ai/run/@cf/baai/bge-base-en-v1.5", + json={ + "result": {"shape": [1, 768], "data": [[0.0] * 768]}, + "success": "true", + "errors": [], + "messages": [], + }, + ) + + document = "foo bar" + embeddings = CloudflareWorkersAIEmbeddings(account_id="123", api_token="abc") + output = embeddings.embed_query(document) + + assert len(output) == 768 diff --git a/libs/langchain/tests/integration_tests/llms/test_arcee.py b/libs/langchain/tests/integration_tests/llms/test_arcee.py index a795c1909ca3d..40daec3682fb9 100644 --- a/libs/langchain/tests/integration_tests/llms/test_arcee.py +++ b/libs/langchain/tests/integration_tests/llms/test_arcee.py @@ -1,34 +1,70 @@ -"""Test Arcee llm""" +from unittest.mock import MagicMock, patch + from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture, MonkeyPatch from 
langchain.llms.arcee import Arcee -def test_api_key_is_secret_string() -> None: - llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key") - assert isinstance(llm.arcee_api_key, SecretStr) +@patch("langchain.utilities.arcee.requests.get") +def test_arcee_api_key_is_secret_string(mock_get: MagicMock) -> None: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.json.return_value = { + "model_id": "", + "status": "training_complete", + } + arcee_without_env_var = Arcee( + model="DALM-PubMed", + arcee_api_key="secret_api_key", + arcee_api_url="https://localhost", + arcee_api_version="version", + ) + assert isinstance(arcee_without_env_var.arcee_api_key, SecretStr) -def test_api_key_masked_when_passed_from_env( - monkeypatch: MonkeyPatch, capsys: CaptureFixture -) -> None: - """Test initialization with an API key provided via an env variable""" - monkeypatch.setenv("ARCEE_API_KEY", "test-arcee-api-key") - llm = Arcee(model="DALM-PubMed") +@patch("langchain.utilities.arcee.requests.get") +def test_api_key_masked_when_passed_via_constructor( + mock_get: MagicMock, capsys: CaptureFixture +) -> None: + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.json.return_value = { + "model_id": "", + "status": "training_complete", + } - print(llm.arcee_api_key, end="") + arcee_without_env_var = Arcee( + model="DALM-PubMed", + arcee_api_key="secret_api_key", + arcee_api_url="https://localhost", + arcee_api_version="version", + ) + print(arcee_without_env_var.arcee_api_key, end="") captured = capsys.readouterr() - assert captured.out == "**********" + assert "**********" == captured.out -def test_api_key_masked_when_passed_via_constructor( - capsys: CaptureFixture, + +@patch("langchain.utilities.arcee.requests.get") +def test_api_key_masked_when_passed_from_env( + mock_get: MagicMock, capsys: CaptureFixture, monkeypatch: MonkeyPatch ) -> None: - """Test initialization with an API key provided via the initializer""" - llm = Arcee(model="DALM-PubMed", arcee_api_key="test-arcee-api-key") + mock_response = mock_get.return_value + mock_response.status_code = 200 + mock_response.json.return_value = { + "model_id": "", + "status": "training_complete", + } - print(llm.arcee_api_key, end="") + monkeypatch.setenv("ARCEE_API_KEY", "secret_api_key") + arcee_with_env_var = Arcee( + model="DALM-PubMed", + arcee_api_url="https://localhost", + arcee_api_version="version", + ) + print(arcee_with_env_var.arcee_api_key, end="") captured = capsys.readouterr() - assert captured.out == "**********" + + assert "**********" == captured.out diff --git a/libs/langchain/tests/integration_tests/storage/test_upstash_redis.py b/libs/langchain/tests/integration_tests/storage/test_upstash_redis.py index 183e09515b6a4..01ab5831fe3ed 100644 --- a/libs/langchain/tests/integration_tests/storage/test_upstash_redis.py +++ b/libs/langchain/tests/integration_tests/storage/test_upstash_redis.py @@ -5,7 +5,7 @@ import pytest -from langchain.storage.upstash_redis import UpstashRedisStore +from langchain.storage.upstash_redis import UpstashRedisByteStore if TYPE_CHECKING: from upstash_redis import Redis @@ -34,16 +34,16 @@ def redis_client() -> Redis: def test_mget(redis_client: Redis) -> None: - store = UpstashRedisStore(client=redis_client, ttl=None) + store = UpstashRedisByteStore(client=redis_client, ttl=None) keys = ["key1", "key2"] redis_client.mset({"key1": "value1", "key2": "value2"}) result = store.mget(keys) - assert result == ["value1", "value2"] + 
assert result == [b"value1", b"value2"] def test_mset(redis_client: Redis) -> None: - store = UpstashRedisStore(client=redis_client, ttl=None) - key_value_pairs = [("key1", "value1"), ("key2", "value2")] + store = UpstashRedisByteStore(client=redis_client, ttl=None) + key_value_pairs = [("key1", b"value1"), ("key2", b"value2")] store.mset(key_value_pairs) result = redis_client.mget("key1", "key2") assert result == ["value1", "value2"] @@ -51,7 +51,7 @@ def test_mset(redis_client: Redis) -> None: def test_mdelete(redis_client: Redis) -> None: """Test that deletion works as expected.""" - store = UpstashRedisStore(client=redis_client, ttl=None) + store = UpstashRedisByteStore(client=redis_client, ttl=None) keys = ["key1", "key2"] redis_client.mset({"key1": "value1", "key2": "value2"}) store.mdelete(keys) @@ -60,7 +60,7 @@ def test_mdelete(redis_client: Redis) -> None: def test_yield_keys(redis_client: Redis) -> None: - store = UpstashRedisStore(client=redis_client, ttl=None) + store = UpstashRedisByteStore(client=redis_client, ttl=None) redis_client.mset({"key1": "value2", "key2": "value2"}) assert sorted(store.yield_keys()) == ["key1", "key2"] assert sorted(store.yield_keys(prefix="key*")) == ["key1", "key2"] @@ -68,8 +68,8 @@ def test_yield_keys(redis_client: Redis) -> None: def test_namespace(redis_client: Redis) -> None: - store = UpstashRedisStore(client=redis_client, ttl=None, namespace="meow") - key_value_pairs = [("key1", "value1"), ("key2", "value2")] + store = UpstashRedisByteStore(client=redis_client, ttl=None, namespace="meow") + key_value_pairs = [("key1", b"value1"), ("key2", b"value2")] store.mset(key_value_pairs) cursor, all_keys = redis_client.scan(0) diff --git a/libs/langchain/tests/integration_tests/utilities/test_duckduckdgo_search_api.py b/libs/langchain/tests/integration_tests/utilities/test_duckduckdgo_search_api.py index 8d228e573d6b3..74f0f25fa370c 100644 --- a/libs/langchain/tests/integration_tests/utilities/test_duckduckdgo_search_api.py +++ b/libs/langchain/tests/integration_tests/utilities/test_duckduckdgo_search_api.py @@ -1,11 +1,11 @@ import pytest -from langchain.tools.ddg_search.tool import DuckDuckGoSearchRun +from langchain.tools.ddg_search.tool import DuckDuckGoSearchResults, DuckDuckGoSearchRun def ddg_installed() -> bool: try: - from duckduckgo_search import ddg # noqa: F401 + from duckduckgo_search import DDGS # noqa: F401 return True except Exception as e: @@ -20,3 +20,12 @@ def test_ddg_search_tool() -> None: result = tool(keywords) print(result) assert len(result.split()) > 20 + + +@pytest.mark.skipif(not ddg_installed(), reason="requires duckduckgo-search package") +def test_ddg_search_news_tool() -> None: + keywords = "Tesla" + tool = DuckDuckGoSearchResults(source="news") + result = tool(keywords) + print(result) + assert len(result.split()) > 20 diff --git a/libs/langchain/tests/integration_tests/utilities/test_nasa.py b/libs/langchain/tests/integration_tests/utilities/test_nasa.py new file mode 100644 index 0000000000000..c605626afd865 --- /dev/null +++ b/libs/langchain/tests/integration_tests/utilities/test_nasa.py @@ -0,0 +1,32 @@ +"""Integration test for NASA API Wrapper.""" +from langchain.utilities.nasa import NasaAPIWrapper + + +def test_media_search() -> None: + """Test for NASA Image and Video Library media search""" + nasa = NasaAPIWrapper() + query = '{"q": "saturn", + "year_start": "2002", "year_end": "2010", "page": 2}' + output = nasa.run("search_media", query) + assert output is not None + assert "collection" in output + + +def 
test_get_media_metadata_manifest() -> None: + """Test for retrieving media metadata manifest from NASA Image and Video Library""" + nasa = NasaAPIWrapper() + output = nasa.run("get_media_metadata_manifest", "2022_0707_Recientemente") + assert output is not None + + +def test_get_media_metadata_location() -> None: + """Test for retrieving media metadata location from NASA Image and Video Library""" + nasa = NasaAPIWrapper() + output = nasa.run("get_media_metadata_location", "as11-40-5874") + assert output is not None + + +def test_get_video_captions_location() -> None: + """Test for retrieving video captions location from NASA Image and Video Library""" + nasa = NasaAPIWrapper() + output = nasa.run("get_video_captions_location", "172_ISS-Slosh.sr") + assert output is not None diff --git a/libs/langchain/tests/integration_tests/utilities/test_steam_api.py b/libs/langchain/tests/integration_tests/utilities/test_steam_api.py new file mode 100644 index 0000000000000..24664b3943852 --- /dev/null +++ b/libs/langchain/tests/integration_tests/utilities/test_steam_api.py @@ -0,0 +1,22 @@ +import ast + +from langchain.utilities.steam import SteamWebAPIWrapper + + +def test_get_game_details() -> None: + """Test for getting game details on Steam""" + steam = SteamWebAPIWrapper() + output = steam.run("get_game_details", "Terraria") + assert "id" in output + assert "link" in output + assert "detailed description" in output + assert "supported languages" in output + assert "price" in output + + +def test_get_recommended_games() -> None: + """Test for getting recommended games on Steam""" + steam = SteamWebAPIWrapper() + output = steam.run("get_recommended_games", "76561198362745711") + output = ast.literal_eval(output) + assert len(output) == 5 diff --git a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py index 7b99c696444af..5de74832de05b 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py +++ b/libs/langchain/tests/integration_tests/vectorstores/fake_embeddings.py @@ -53,11 +53,6 @@ def embed_query(self, text: str) -> List[float]: """Return consistent embeddings for the text, if seen before, or a constant one if the text is unknown.""" return self.embed_documents([text])[0] - if text not in self.known_texts: - return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)] - return [float(1.0)] * (self.dimensionality - 1) + [ - float(self.known_texts.index(text)) - ] class AngularTwoDimensionalEmbeddings(Embeddings): diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_momento_vector_index.py b/libs/langchain/tests/integration_tests/vectorstores/test_momento_vector_index.py index 7689088ac5196..c4f20cf2e117a 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_momento_vector_index.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_momento_vector_index.py @@ -125,7 +125,7 @@ def test_from_texts_with_metadatas( def test_from_texts_with_scores(vector_store: MomentoVectorIndex) -> None: - # """Test end to end construction and search with scores and IDs.""" + """Test end to end construction and search with scores and IDs.""" texts = ["apple", "orange", "hammer"] metadatas = [{"page": f"{i}"} for i in range(len(texts))] @@ -162,3 +162,25 @@ def test_add_documents_with_ids(vector_store: MomentoVectorIndex) -> None: ) assert isinstance(response, Search.Success) assert [hit.id for hit in response.hits] == ids + + +def 
test_max_marginal_relevance_search(vector_store: MomentoVectorIndex) -> None: + """Test max marginal relevance search.""" + pepperoni_pizza = "pepperoni pizza" + cheese_pizza = "cheese pizza" + hot_dog = "hot dog" + + vector_store.add_texts([pepperoni_pizza, cheese_pizza, hot_dog]) + wait() + search_results = vector_store.similarity_search("pizza", k=2) + + assert search_results == [ + Document(page_content=pepperoni_pizza, metadata={}), + Document(page_content=cheese_pizza, metadata={}), + ] + + search_results = vector_store.max_marginal_relevance_search(query="pizza", k=2) + assert search_results == [ + Document(page_content=pepperoni_pizza, metadata={}), + Document(page_content=hot_dog, metadata={}), + ] diff --git a/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py b/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py index f658abe260854..94be5295c2f3f 100644 --- a/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py +++ b/libs/langchain/tests/unit_tests/callbacks/tracers/test_base_tracer.py @@ -332,6 +332,42 @@ def test_tracer_llm_run_on_error() -> None: assert tracer.runs == [compare_run] +@freeze_time("2023-01-01") +def test_tracer_llm_run_on_error_callback() -> None: + """Test tracer on an LLM run with an error and a callback.""" + exception = Exception("test") + uuid = uuid4() + + compare_run = Run( + id=str(uuid), + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + events=[ + {"name": "start", "time": datetime.utcnow()}, + {"name": "error", "time": datetime.utcnow()}, + ], + extra={}, + execution_order=1, + child_execution_order=1, + serialized=SERIALIZED, + inputs=dict(prompts=[]), + outputs=None, + error=repr(exception), + run_type="llm", + ) + + class FakeTracerWithLlmErrorCallback(FakeTracer): + error_run = None + + def _on_llm_error(self, run: Run) -> None: + self.error_run = run + + tracer = FakeTracerWithLlmErrorCallback() + tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid) + tracer.on_llm_error(exception, run_id=uuid) + assert tracer.error_run == compare_run + + @freeze_time("2023-01-01") def test_tracer_chain_run_on_error() -> None: """Test tracer on a Chain run with an error.""" diff --git a/libs/langchain/tests/unit_tests/callbacks/tracers/test_comet.py b/libs/langchain/tests/unit_tests/callbacks/tracers/test_comet.py new file mode 100644 index 0000000000000..3dff6520d477e --- /dev/null +++ b/libs/langchain/tests/unit_tests/callbacks/tracers/test_comet.py @@ -0,0 +1,97 @@ +import uuid +from types import SimpleNamespace +from unittest import mock + +from langchain.callbacks.tracers import comet +from langchain.schema.output import LLMResult + + +def test_comet_tracer__trace_chain_with_single_span__happyflow() -> None: + # Setup mocks + chain_module_mock = mock.Mock() + chain_instance_mock = mock.Mock() + chain_module_mock.Chain.return_value = chain_instance_mock + + span_module_mock = mock.Mock() + span_instance_mock = mock.MagicMock() + span_instance_mock.__api__start__ = mock.Mock() + span_instance_mock.__api__end__ = mock.Mock() + + span_module_mock.Span.return_value = span_instance_mock + + experiment_info_module_mock = mock.Mock() + experiment_info_module_mock.get.return_value = "the-experiment-info" + + chain_api_module_mock = mock.Mock() + + comet_ml_api_mock = SimpleNamespace( + chain=chain_module_mock, + span=span_module_mock, + experiment_info=experiment_info_module_mock, + chain_api=chain_api_module_mock, + flush="not-used-in-this-test", + ) + + # Create tracer + with 
mock.patch.object( + comet, "import_comet_llm_api", return_value=comet_ml_api_mock + ): + tracer = comet.CometTracer() + + run_id_1 = uuid.UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a") + run_id_2 = uuid.UUID("4f31216e-7c26-4027-a5fd-0bbf9ace17dc") + + # Parent run + tracer.on_chain_start( + {"name": "chain-input"}, + {"input": "chain-input-prompt"}, + parent_run_id=None, + run_id=run_id_1, + ) + + # Check that chain was created + chain_module_mock.Chain.assert_called_once_with( + inputs={"input": "chain-input-prompt"}, + metadata=None, + experiment_info="the-experiment-info", + ) + + # Child run + tracer.on_llm_start( + {"name": "span-input"}, + ["span-input-prompt"], + parent_run_id=run_id_1, + run_id=run_id_2, + ) + + # Check that Span was created and attached to chain + span_module_mock.Span.assert_called_once_with( + inputs={"prompts": ["span-input-prompt"]}, + category=mock.ANY, + metadata=mock.ANY, + name=mock.ANY, + ) + span_instance_mock.__api__start__.assert_called_once_with(chain_instance_mock) + + # Child run end + tracer.on_llm_end( + LLMResult(generations=[], llm_output={"span-output-key": "span-output-value"}), + run_id=run_id_2, + ) + # Check that Span outputs are set and span is ended + span_instance_mock.set_outputs.assert_called_once() + actual_span_outputs = span_instance_mock.set_outputs.call_args[1]["outputs"] + assert { + "llm_output": {"span-output-key": "span-output-value"}, + "generations": [], + }.items() <= actual_span_outputs.items() + span_instance_mock.__api__end__.assert_called_once() + + # Parent run end + tracer.on_chain_end({"chain-output-key": "chain-output-value"}, run_id=run_id_1) + + # Check that chain outputs are set and chain is logged + chain_instance_mock.set_outputs.assert_called_once() + actual_chain_outputs = chain_instance_mock.set_outputs.call_args[1]["outputs"] + assert ("chain-output-key", "chain-output-value") in actual_chain_outputs.items() + chain_api_module_mock.log_chain.assert_called_once_with(chain_instance_mock)
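A note on the test above: CometTracer resolves the optional comet_llm SDK through the lazy helper import_comet_llm_api, which is why the test can patch that single attribute and exercise the full chain/span flow with no SDK installed. A minimal, self-contained sketch of the same patching pattern; the import_sdk helper and start_span function here are illustrative stand-ins, not part of the tracer:

import sys
from types import SimpleNamespace
from unittest import mock


def import_sdk() -> SimpleNamespace:
    # Lazy import helper; real code would import the optional SDK here.
    raise ImportError("optional dependency not installed")


def start_span(name: str) -> object:
    # Code under test: resolves the SDK through the helper at call time,
    # so patching the helper is enough to stub out the whole SDK.
    sdk = import_sdk()
    return sdk.Span(name=name)


def test_start_span_without_sdk() -> None:
    fake_sdk = SimpleNamespace(Span=mock.Mock(return_value="fake-span"))
    # Patch the helper itself, mirroring how the comet test patches
    # import_comet_llm_api, so the missing SDK is never imported.
    with mock.patch.object(sys.modules[__name__], "import_sdk", return_value=fake_sdk):
        span = start_span("demo")
    assert span == "fake-span"
    fake_sdk.Span.assert_called_once_with(name="demo")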
diff --git a/libs/langchain/tests/unit_tests/document_loaders/blob_loaders/test_schema.py b/libs/langchain/tests/unit_tests/document_loaders/blob_loaders/test_schema.py index 2f18549b2a062..811af3dabd13d 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/blob_loaders/test_schema.py +++ b/libs/langchain/tests/unit_tests/document_loaders/blob_loaders/test_schema.py @@ -131,3 +131,20 @@ def yield_blobs(self) -> Iterable[Blob]: yield Blob(data=b"Hello, World!") assert list(TestLoader().yield_blobs()) == [Blob(data=b"Hello, World!")] + + +def test_metadata_and_source() -> None: + """Test metadata and source""" + blob = Blob(path="some_file", data="b") + assert blob.source == "some_file" + assert blob.metadata == {} + blob = Blob(data=b"", metadata={"source": "hello"}) + assert blob.source == "hello" + assert blob.metadata == {"source": "hello"} + + blob = Blob.from_data("data", metadata={"source": "somewhere"}) + assert blob.source == "somewhere" + + with get_temp_file(b"hello") as path: + blob = Blob.from_path(path, metadata={"source": "somewhere"}) + assert blob.source == "somewhere" diff --git a/libs/langchain/tests/unit_tests/document_loaders/sample_documents/obsidian/template_var_frontmatter.md b/libs/langchain/tests/unit_tests/document_loaders/sample_documents/obsidian/template_var_frontmatter.md new file mode 100644 index 0000000000000..7bab90737c31f --- /dev/null +++ b/libs/langchain/tests/unit_tests/document_loaders/sample_documents/obsidian/template_var_frontmatter.md @@ -0,0 +1,12 @@ +--- +aString: {{var}} +anArray: +- element +- {{varElement}} +aDict: + dictId1: 'val' + dictId2: '{{varVal}}' +tags: [ 'tag', '{{varTag}}' ] +--- + +Frontmatter contains template variables. diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_couchbase.py b/libs/langchain/tests/unit_tests/document_loaders/test_couchbase.py new file mode 100644 index 0000000000000..ec05691cbb0ec --- /dev/null +++ b/libs/langchain/tests/unit_tests/document_loaders/test_couchbase.py @@ -0,0 +1,6 @@ +"""Test importing the Couchbase document loader.""" + + +def test_couchbase_import() -> None: + """Test that the Couchbase document loader can be imported.""" + from langchain.document_loaders import CouchbaseLoader # noqa: F401 diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_imports.py b/libs/langchain/tests/unit_tests/document_loaders/test_imports.py index db754275234ba..18f6f22a5f00d 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/test_imports.py +++ b/libs/langchain/tests/unit_tests/document_loaders/test_imports.py @@ -41,6 +41,7 @@ "CollegeConfidentialLoader", "ConcurrentLoader", "ConfluenceLoader", + "CouchbaseLoader", "CubeSemanticLoader", "DataFrameLoader", "DatadogLogsLoader", @@ -52,8 +53,6 @@ "Docx2txtLoader", "DropboxLoader", "DuckDBLoader", - "EmbaasBlobLoader", - "EmbaasLoader", "EtherscanLoader", "EverNoteLoader", "FacebookChatLoader", diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_obsidian.py b/libs/langchain/tests/unit_tests/document_loaders/test_obsidian.py index 50f29d849e17b..e25bf80199d82 100644 --- a/libs/langchain/tests/unit_tests/document_loaders/test_obsidian.py +++ b/libs/langchain/tests/unit_tests/document_loaders/test_obsidian.py @@ -17,7 +17,7 @@ def test_page_content_loaded() -> None: """Verify that all docs have page_content""" - assert len(docs) == 5 + assert len(docs) == 6 assert all(doc.page_content for doc in docs) @@ -27,7 +27,7 @@ def test_disable_collect_metadata() -> None: str(OBSIDIAN_EXAMPLE_PATH), collect_metadata=False ) docs_wo = loader_without_metadata.load() - assert len(docs_wo) == 5 + assert len(docs_wo) == 6 assert all(doc.page_content for doc in docs_wo) assert all(set(doc.metadata) == STANDARD_METADATA_FIELDS for doc in docs_wo) @@ -45,6 +45,24 @@ def test_metadata_with_frontmatter() -> None: assert set(doc.metadata["tags"].split(",")) == {"journal/entry", "obsidian"} +def test_metadata_with_template_vars_in_frontmatter() -> None: + """Verify frontmatter fields with template variables are loaded.""" + doc = next( + doc for doc in docs if doc.metadata["source"] == "template_var_frontmatter.md" + ) + FRONTMATTER_FIELDS = { + "aString", + "anArray", + "aDict", + "tags", + } + assert set(doc.metadata) == FRONTMATTER_FIELDS | STANDARD_METADATA_FIELDS + assert doc.metadata["aString"] == "{{var}}" + assert doc.metadata["anArray"] == "['element', '{{varElement}}']" + assert doc.metadata["aDict"] == "{'dictId1': 'val', 'dictId2': '{{varVal}}'}" + assert set(doc.metadata["tags"].split(",")) == {"tag", "{{varTag}}"} + + def test_metadata_with_bad_frontmatter() -> None: """Verify a doc with non-yaml frontmatter.""" doc = next(doc for doc in docs if doc.metadata["source"] == "bad_frontmatter.md") diff --git a/libs/langchain/tests/unit_tests/embeddings/test_imports.py b/libs/langchain/tests/unit_tests/embeddings/test_imports.py index 9de69602dc6a7..8fe5df0994a50 100644 --- a/libs/langchain/tests/unit_tests/embeddings/test_imports.py +++ b/libs/langchain/tests/unit_tests/embeddings/test_imports.py @@ -53,6 +53,7 @@ "QianfanEmbeddingsEndpoint", "JohnSnowLabsEmbeddings",
"VoyageEmbeddings", + "BookendEmbeddings", ] diff --git a/libs/langchain/tests/unit_tests/test_dependencies.py b/libs/langchain/tests/unit_tests/test_dependencies.py index 2d326dcbbc082..a5297dd569f5a 100644 --- a/libs/langchain/tests/unit_tests/test_dependencies.py +++ b/libs/langchain/tests/unit_tests/test_dependencies.py @@ -104,5 +104,8 @@ def test_imports() -> None: from langchain.llms import OpenAI # noqa: F401 from langchain.retrievers import VespaRetriever # noqa: F401 from langchain.tools import DuckDuckGoSearchResults # noqa: F401 - from langchain.utilities import SerpAPIWrapper # noqa: F401 + from langchain.utilities import ( + SearchApiAPIWrapper, # noqa: F401 + SerpAPIWrapper, # noqa: F401 + ) from langchain.vectorstores import FAISS # noqa: F401 diff --git a/libs/langchain/tests/unit_tests/tools/test_imports.py b/libs/langchain/tests/unit_tests/tools/test_imports.py index 58bd210e9bb00..a0960a7d2f551 100644 --- a/libs/langchain/tests/unit_tests/tools/test_imports.py +++ b/libs/langchain/tests/unit_tests/tools/test_imports.py @@ -68,6 +68,7 @@ "ListSparkSQLTool", "MetaphorSearchResults", "MoveFileTool", + "NasaAction", "NavigateBackTool", "NavigateTool", "O365CreateDraftMessage", @@ -91,6 +92,8 @@ "RequestsPostTool", "RequestsPutTool", "SceneXplainTool", + "SearchAPIRun", + "SearchAPIResults", "SearxSearchResults", "SearxSearchRun", "ShellTool", @@ -101,6 +104,7 @@ "SleepTool", "StackExchangeTool", "StdInInquireTool", + "SteamWebAPIQueryRun", "SteamshipImageGenerationTool", "StructuredTool", "Tool", diff --git a/libs/langchain/tests/unit_tests/tools/test_public_api.py b/libs/langchain/tests/unit_tests/tools/test_public_api.py index 4db38fd13e13d..87ebceaae32f3 100644 --- a/libs/langchain/tests/unit_tests/tools/test_public_api.py +++ b/libs/langchain/tests/unit_tests/tools/test_public_api.py @@ -70,6 +70,7 @@ "MerriamWebsterQueryRun", "MetaphorSearchResults", "MoveFileTool", + "NasaAction", "NavigateBackTool", "NavigateTool", "O365CreateDraftMessage", @@ -93,6 +94,8 @@ "RequestsPostTool", "RequestsPutTool", "SceneXplainTool", + "SearchAPIResults", + "SearchAPIRun", "SearxSearchResults", "SearxSearchRun", "ShellTool", @@ -105,6 +108,7 @@ "StackExchangeTool", "SteamshipImageGenerationTool", "StructuredTool", + "SteamWebAPIQueryRun", "Tool", "VectorStoreQATool", "VectorStoreQAWithSourcesTool", diff --git a/libs/langchain/tests/unit_tests/utilities/test_imports.py b/libs/langchain/tests/unit_tests/utilities/test_imports.py index f1e6a27eda96f..49fddcff65522 100644 --- a/libs/langchain/tests/unit_tests/utilities/test_imports.py +++ b/libs/langchain/tests/unit_tests/utilities/test_imports.py @@ -23,6 +23,7 @@ "LambdaWrapper", "MaxComputeAPIWrapper", "MetaphorSearchAPIWrapper", + "NasaAPIWrapper", "OpenWeatherMapAPIWrapper", "OutlineAPIWrapper", "Portkey", @@ -38,6 +39,7 @@ "SerpAPIWrapper", "SparkSQL", "StackExchangeAPIWrapper", + "SteamWebAPIWrapper", "TensorflowDatasets", "TextRequestsWrapper", "TwilioAPIWrapper",