diff --git a/libs/community/langchain_community/chat_models/human.py b/libs/community/langchain_community/chat_models/human.py index 0ac1a407c92a6..e0294746934f8 100644 --- a/libs/community/langchain_community/chat_models/human.py +++ b/libs/community/langchain_community/chat_models/human.py @@ -1,12 +1,9 @@ """ChatModel wrapper which returns user input as the response..""" -import asyncio -from functools import partial from io import StringIO from typing import Any, Callable, Dict, List, Mapping, Optional import yaml from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models.chat_models import BaseChatModel @@ -111,15 +108,3 @@ def _generate( self.message_func(messages, **self.message_kwargs) user_input = self.input_func(messages, stop=stop, **self.input_kwargs) return ChatResult(generations=[ChatGeneration(message=user_input)]) - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - func = partial( - self._generate, messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await asyncio.get_event_loop().run_in_executor(None, func) diff --git a/libs/community/langchain_community/chat_models/mlflow.py b/libs/community/langchain_community/chat_models/mlflow.py index ee289527bb005..7068644439ded 100644 --- a/libs/community/langchain_community/chat_models/mlflow.py +++ b/libs/community/langchain_community/chat_models/mlflow.py @@ -1,11 +1,8 @@ -import asyncio import logging -from functools import partial from typing import Any, Dict, List, Mapping, Optional from urllib.parse import urlparse from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models import BaseChatModel @@ -125,18 +122,6 @@ def _generate( resp = self._client.predict(endpoint=self.endpoint, inputs=data) return ChatMlflow._create_chat_result(resp) - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - func = partial( - self._generate, messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await asyncio.get_event_loop().run_in_executor(None, func) - @property def _identifying_params(self) -> Dict[str, Any]: return self._default_params diff --git a/libs/community/langchain_community/chat_models/mlflow_ai_gateway.py b/libs/community/langchain_community/chat_models/mlflow_ai_gateway.py index 5674f69fc2c0a..39ad6b550b176 100644 --- a/libs/community/langchain_community/chat_models/mlflow_ai_gateway.py +++ b/libs/community/langchain_community/chat_models/mlflow_ai_gateway.py @@ -1,11 +1,8 @@ -import asyncio import logging import warnings -from functools import partial from typing import Any, Dict, List, Mapping, Optional from langchain_core.callbacks import ( - AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models.chat_models import BaseChatModel @@ -116,18 +113,6 @@ def _generate( resp = mlflow.gateway.query(self.route, data=data) return ChatMLflowAIGateway._create_chat_result(resp) - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - func = partial( - self._generate, messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await asyncio.get_event_loop().run_in_executor(None, func) - @property def _identifying_params(self) -> Dict[str, Any]: return self._default_params diff --git a/libs/community/langchain_community/chat_models/pai_eas_endpoint.py b/libs/community/langchain_community/chat_models/pai_eas_endpoint.py index 85f13246817d9..e9f231514d012 100644 --- a/libs/community/langchain_community/chat_models/pai_eas_endpoint.py +++ b/libs/community/langchain_community/chat_models/pai_eas_endpoint.py @@ -1,7 +1,5 @@ -import asyncio import json import logging -from functools import partial from typing import Any, AsyncIterator, Dict, List, Optional, cast import requests @@ -300,25 +298,3 @@ async def _astream( # break if stop sequence found if stop_seq_found: break - - async def _agenerate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, - stream: Optional[bool] = None, - **kwargs: Any, - ) -> ChatResult: - if stream if stream is not None else self.streaming: - generation: Optional[ChatGenerationChunk] = None - async for chunk in self._astream( - messages=messages, stop=stop, run_manager=run_manager, **kwargs - ): - generation = chunk - assert generation is not None - return ChatResult(generations=[generation]) - - func = partial( - self._generate, messages, stop=stop, run_manager=run_manager, **kwargs - ) - return await asyncio.get_event_loop().run_in_executor(None, func) diff --git a/libs/community/langchain_community/embeddings/bedrock.py b/libs/community/langchain_community/embeddings/bedrock.py index 8e98bfe285819..529809fb91163 100644 --- a/libs/community/langchain_community/embeddings/bedrock.py +++ b/libs/community/langchain_community/embeddings/bedrock.py @@ -1,11 +1,11 @@ import asyncio import json import os -from functools import partial from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.runnables.config import run_in_executor class BedrockEmbeddings(BaseModel, Embeddings): @@ -181,9 +181,7 @@ async def aembed_query(self, text: str) -> List[float]: Embeddings for the text. """ - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.embed_query, text) - ) + return await run_in_executor(None, self.embed_query, text) async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """Asynchronous compute doc embeddings using a Bedrock model. diff --git a/libs/community/langchain_community/embeddings/ernie.py b/libs/community/langchain_community/embeddings/ernie.py index 0e2d19f6b5d24..5467e4c027814 100644 --- a/libs/community/langchain_community/embeddings/ernie.py +++ b/libs/community/langchain_community/embeddings/ernie.py @@ -1,12 +1,12 @@ import asyncio import logging import threading -from functools import partial from typing import Dict, List, Optional import requests from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.runnables.config import run_in_executor from langchain_core.utils import get_from_dict_or_env logger = logging.getLogger(__name__) @@ -134,9 +134,7 @@ async def aembed_query(self, text: str) -> List[float]: List[float]: Embeddings for the text. """ - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.embed_query, text) - ) + return await run_in_executor(None, self.embed_query, text) async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """Asynchronous Embed search docs. diff --git a/libs/community/langchain_community/tools/multion/close_session.py b/libs/community/langchain_community/tools/multion/close_session.py index 7aaead7fa0c44..8232d861e2d56 100644 --- a/libs/community/langchain_community/tools/multion/close_session.py +++ b/libs/community/langchain_community/tools/multion/close_session.py @@ -1,8 +1,6 @@ -import asyncio from typing import TYPE_CHECKING, Optional, Type from langchain_core.callbacks import ( - AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) from langchain_core.pydantic_v1 import BaseModel, Field @@ -57,11 +55,3 @@ def _run( print(f"{e}, retrying...") except Exception as e: raise Exception(f"An error occurred: {e}") - - async def _arun( - self, - sessionId: str, - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, - ) -> None: - loop = asyncio.get_running_loop() - await loop.run_in_executor(None, self._run, sessionId) diff --git a/libs/community/langchain_community/tools/multion/create_session.py b/libs/community/langchain_community/tools/multion/create_session.py index 9f93676ee1810..de6983cb4bc2d 100644 --- a/libs/community/langchain_community/tools/multion/create_session.py +++ b/libs/community/langchain_community/tools/multion/create_session.py @@ -1,8 +1,6 @@ -import asyncio from typing import TYPE_CHECKING, Optional, Type from langchain_core.callbacks import ( - AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) from langchain_core.pydantic_v1 import BaseModel, Field @@ -67,14 +65,3 @@ def _run( } except Exception as e: raise Exception(f"An error occurred: {e}") - - async def _arun( - self, - query: str, - url: Optional[str] = "https://www.google.com/", - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, - ) -> dict: - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(None, self._run, query, url) - - return result diff --git a/libs/community/langchain_community/tools/multion/update_session.py b/libs/community/langchain_community/tools/multion/update_session.py index 97a8f1ff4a36c..fe92c36dd76c0 100644 --- a/libs/community/langchain_community/tools/multion/update_session.py +++ b/libs/community/langchain_community/tools/multion/update_session.py @@ -1,8 +1,6 @@ -import asyncio from typing import TYPE_CHECKING, Optional, Type from langchain_core.callbacks import ( - AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) from langchain_core.pydantic_v1 import BaseModel, Field @@ -74,15 +72,3 @@ def _run( return {"error": f"{e}", "Response": "retrying..."} except Exception as e: raise Exception(f"An error occurred: {e}") - - async def _arun( - self, - sessionId: str, - query: str, - url: Optional[str] = "https://www.google.com/", - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, - ) -> dict: - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(None, self._run, sessionId, query, url) - - return result diff --git a/libs/community/langchain_community/tools/shell/tool.py b/libs/community/langchain_community/tools/shell/tool.py index e92d51445aaf4..5f61631059d92 100644 --- a/libs/community/langchain_community/tools/shell/tool.py +++ b/libs/community/langchain_community/tools/shell/tool.py @@ -1,10 +1,8 @@ -import asyncio import platform import warnings from typing import Any, List, Optional, Type, Union from langchain_core.callbacks import ( - AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) from langchain_core.pydantic_v1 import BaseModel, Field, root_validator @@ -77,13 +75,3 @@ def _run( ) -> str: """Run commands and return final output.""" return self.process.run(commands) - - async def _arun( - self, - commands: Union[str, List[str]], - run_manager: Optional[AsyncCallbackManagerForToolRun] = None, - ) -> str: - """Run commands asynchronously and return final output.""" - return await asyncio.get_event_loop().run_in_executor( - None, self.process.run, commands - ) diff --git a/libs/community/langchain_community/vectorstores/faiss.py b/libs/community/langchain_community/vectorstores/faiss.py index 7473af4a7b555..4e7e5619e192c 100644 --- a/libs/community/langchain_community/vectorstores/faiss.py +++ b/libs/community/langchain_community/vectorstores/faiss.py @@ -1,13 +1,11 @@ from __future__ import annotations -import asyncio import logging import operator import os import pickle import uuid import warnings -from functools import partial from pathlib import Path from typing import ( Any, @@ -24,6 +22,7 @@ import numpy as np from langchain_core.documents import Document from langchain_core.embeddings import Embeddings +from langchain_core.runnables.config import run_in_executor from langchain_core.vectorstores import VectorStore from langchain_community.docstore.base import AddableMixin, Docstore @@ -359,7 +358,8 @@ async def asimilarity_search_with_score_by_vector( """ # This is a temporary workaround to make the similarity search asynchronous. - func = partial( + return await run_in_executor( + None, self.similarity_search_with_score_by_vector, embedding, k=k, @@ -367,7 +367,6 @@ async def asimilarity_search_with_score_by_vector( fetch_k=fetch_k, **kwargs, ) - return await asyncio.get_event_loop().run_in_executor(None, func) def similarity_search_with_score( self, @@ -640,7 +639,8 @@ async def amax_marginal_relevance_search_with_score_by_vector( relevance and score for each. """ # This is a temporary workaround to make the similarity search asynchronous. - func = partial( + return await run_in_executor( + None, self.max_marginal_relevance_search_with_score_by_vector, embedding, k=k, @@ -648,7 +648,6 @@ async def amax_marginal_relevance_search_with_score_by_vector( lambda_mult=lambda_mult, filter=filter, ) - return await asyncio.get_event_loop().run_in_executor(None, func) def max_marginal_relevance_search_by_vector( self, diff --git a/libs/community/langchain_community/vectorstores/pgvector.py b/libs/community/langchain_community/vectorstores/pgvector.py index b7b3e529c6de8..8a75297ba057e 100644 --- a/libs/community/langchain_community/vectorstores/pgvector.py +++ b/libs/community/langchain_community/vectorstores/pgvector.py @@ -1,11 +1,9 @@ from __future__ import annotations -import asyncio import contextlib import enum import logging import uuid -from functools import partial from typing import ( Any, Callable, @@ -31,6 +29,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings +from langchain_core.runnables.config import run_in_executor from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore @@ -941,7 +940,8 @@ async def amax_marginal_relevance_search_by_vector( # This is a temporary workaround to make the similarity search # asynchronous. The proper solution is to make the similarity search # asynchronous in the vector store implementations. - func = partial( + return await run_in_executor( + None, self.max_marginal_relevance_search_by_vector, embedding, k=k, @@ -950,4 +950,3 @@ async def amax_marginal_relevance_search_by_vector( filter=filter, **kwargs, ) - return await asyncio.get_event_loop().run_in_executor(None, func) diff --git a/libs/community/langchain_community/vectorstores/qdrant.py b/libs/community/langchain_community/vectorstores/qdrant.py index ea881c4cbd6df..7c18da82ec0d9 100644 --- a/libs/community/langchain_community/vectorstores/qdrant.py +++ b/libs/community/langchain_community/vectorstores/qdrant.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import functools import uuid import warnings @@ -25,6 +24,7 @@ import numpy as np from langchain_core.documents import Document from langchain_core.embeddings import Embeddings +from langchain_core.runnables.config import run_in_executor from langchain_core.vectorstores import VectorStore from langchain_community.vectorstores.utils import maximal_marginal_relevance @@ -58,10 +58,9 @@ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: # by removing the first letter from the method name. For example, # if the async method is called ``aaad_texts``, the synchronous method # will be called ``aad_texts``. - sync_method = functools.partial( - getattr(self, method.__name__[1:]), *args, **kwargs + return await run_in_executor( + None, getattr(self, method.__name__[1:]), *args, **kwargs ) - return await asyncio.get_event_loop().run_in_executor(None, sync_method) return wrapper diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index f13af30b2c50a..db7e3b1708f0f 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -23,7 +23,7 @@ RunnableSerializable, coerce_to_runnable, ) -from langchain_core.runnables.config import RunnableConfig, patch_config +from langchain_core.runnables.config import RunnableConfig, ensure_config, patch_config from langchain_core.runnables.utils import ConfigurableFieldSpec, Input, Output T = TypeVar("T") @@ -186,7 +186,7 @@ def config_specs(self) -> List[ConfigurableFieldSpec]: ] def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: - config = config or {} + config = ensure_config(config) configurable = config.get("configurable", {}) if isinstance(self.key, list): return {key: configurable[id_]() for key, id_ in zip(self.key, self.ids)} @@ -196,7 +196,7 @@ def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: async def ainvoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: - config = config or {} + config = ensure_config(config) configurable = config.get("configurable", {}) if isinstance(self.key, list): values = await asyncio.gather(*(configurable[id_]() for id_ in self.ids)) @@ -281,7 +281,7 @@ def config_specs(self) -> List[ConfigurableFieldSpec]: ] def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: - config = config or {} + config = ensure_config(config) configurable = config.get("configurable", {}) for id_, mapper in zip(self.ids, self.keys.values()): if mapper is not None: @@ -293,7 +293,7 @@ def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: async def ainvoke( self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Any: - config = config or {} + config = ensure_config(config) configurable = config.get("configurable", {}) for id_, mapper in zip(self.ids, self.keys.values()): if mapper is not None: diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index fba38e57ab534..771edb72de2c8 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -4,13 +4,15 @@ import functools import logging import uuid +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager, contextmanager -from contextvars import Context, copy_context +from contextvars import copy_context from typing import ( TYPE_CHECKING, Any, AsyncGenerator, + Callable, Coroutine, Dict, Generator, @@ -272,25 +274,14 @@ def handle_event( # we end up in a deadlock, as we'd have gotten here from a # running coroutine, which we cannot interrupt to run this one. # The solution is to create a new loop in a new thread. - with _executor_w_context(1) as executor: - executor.submit(_run_coros, coros).result() + with ThreadPoolExecutor(1) as executor: + executor.submit( + cast(Callable, copy_context().run), _run_coros, coros + ).result() else: _run_coros(coros) -def _set_context(context: Context) -> None: - for var, value in context.items(): - var.set(value) - - -def _executor_w_context(max_workers: Optional[int] = None) -> ThreadPoolExecutor: - return ThreadPoolExecutor( - max_workers=max_workers, - initializer=_set_context, - initargs=(copy_context(),), - ) - - def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: if hasattr(asyncio, "Runner"): # Python 3.11+ @@ -315,7 +306,6 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: async def _ahandle_event_for_handler( - executor: ThreadPoolExecutor, handler: BaseCallbackHandler, event_name: str, ignore_condition_name: Optional[str], @@ -332,13 +322,18 @@ async def _ahandle_event_for_handler( event(*args, **kwargs) else: await asyncio.get_event_loop().run_in_executor( - executor, functools.partial(event, *args, **kwargs) + None, + cast( + Callable, + functools.partial( + copy_context().run, event, *args, **kwargs + ), + ), ) except NotImplementedError as e: if event_name == "on_chat_model_start": message_strings = [get_buffer_string(m) for m in args[1]] await _ahandle_event_for_handler( - executor, handler, "on_llm_start", "ignore_llm", @@ -380,25 +375,23 @@ async def ahandle_event( *args: The arguments to pass to the event handler **kwargs: The keyword arguments to pass to the event handler """ - with _executor_w_context() as executor: - for handler in [h for h in handlers if h.run_inline]: - await _ahandle_event_for_handler( - executor, handler, event_name, ignore_condition_name, *args, **kwargs - ) - await asyncio.gather( - *( - _ahandle_event_for_handler( - executor, - handler, - event_name, - ignore_condition_name, - *args, - **kwargs, - ) - for handler in handlers - if not handler.run_inline + for handler in [h for h in handlers if h.run_inline]: + await _ahandle_event_for_handler( + handler, event_name, ignore_condition_name, *args, **kwargs + ) + await asyncio.gather( + *( + _ahandle_event_for_handler( + handler, + event_name, + ignore_condition_name, + *args, + **kwargs, ) + for handler in handlers + if not handler.run_inline ) + ) BRM = TypeVar("BRM", bound="BaseRunManager") @@ -526,9 +519,17 @@ def get_child(self, tag: Optional[str] = None) -> CallbackManager: return manager -class AsyncRunManager(BaseRunManager): +class AsyncRunManager(BaseRunManager, ABC): """Async Run Manager.""" + @abstractmethod + def get_sync(self) -> RunManager: + """Get the equivalent sync RunManager. + + Returns: + RunManager: The sync RunManager. + """ + async def on_text( self, text: str, @@ -664,6 +665,23 @@ def on_llm_error( class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): """Async callback manager for LLM run.""" + def get_sync(self) -> CallbackManagerForLLMRun: + """Get the equivalent sync RunManager. + + Returns: + CallbackManagerForLLMRun: The sync RunManager. + """ + return CallbackManagerForLLMRun( + run_id=self.run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + async def on_llm_new_token( self, token: str, @@ -818,6 +836,23 @@ def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): """Async callback manager for chain run.""" + def get_sync(self) -> CallbackManagerForChainRun: + """Get the equivalent sync RunManager. + + Returns: + CallbackManagerForChainRun: The sync RunManager. + """ + return CallbackManagerForChainRun( + run_id=self.run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + async def on_chain_end( self, outputs: Union[Dict[str, Any], Any], **kwargs: Any ) -> None: @@ -948,6 +983,23 @@ def on_tool_error( class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): """Async callback manager for tool run.""" + def get_sync(self) -> CallbackManagerForToolRun: + """Get the equivalent sync RunManager. + + Returns: + CallbackManagerForToolRun: The sync RunManager. + """ + return CallbackManagerForToolRun( + run_id=self.run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + async def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running. @@ -1031,6 +1083,23 @@ class AsyncCallbackManagerForRetrieverRun( ): """Async callback manager for retriever run.""" + def get_sync(self) -> CallbackManagerForRetrieverRun: + """Get the equivalent sync RunManager. + + Returns: + CallbackManagerForRetrieverRun: The sync RunManager. + """ + return CallbackManagerForRetrieverRun( + run_id=self.run_id, + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + async def on_retriever_end( self, documents: Sequence[Document], **kwargs: Any ) -> None: diff --git a/libs/core/langchain_core/documents/transformers.py b/libs/core/langchain_core/documents/transformers.py index 5d0418cbb356a..245e6a715c295 100644 --- a/libs/core/langchain_core/documents/transformers.py +++ b/libs/core/langchain_core/documents/transformers.py @@ -1,10 +1,10 @@ from __future__ import annotations -import asyncio from abc import ABC, abstractmethod -from functools import partial from typing import TYPE_CHECKING, Any, Sequence +from langchain_core.runnables.config import run_in_executor + if TYPE_CHECKING: from langchain_core.documents import Document @@ -69,6 +69,6 @@ async def atransform_documents( Returns: A list of transformed Documents. """ - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.transform_documents, **kwargs), documents + return await run_in_executor( + None, self.transform_documents, documents, **kwargs ) diff --git a/libs/core/langchain_core/embeddings.py b/libs/core/langchain_core/embeddings.py index c08a279750b8d..ffc963097b8ed 100644 --- a/libs/core/langchain_core/embeddings.py +++ b/libs/core/langchain_core/embeddings.py @@ -1,7 +1,8 @@ -import asyncio from abc import ABC, abstractmethod from typing import List +from langchain_core.runnables.config import run_in_executor + class Embeddings(ABC): """Interface for embedding models.""" @@ -16,12 +17,8 @@ def embed_query(self, text: str) -> List[float]: async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """Asynchronous Embed search docs.""" - return await asyncio.get_running_loop().run_in_executor( - None, self.embed_documents, texts - ) + return await run_in_executor(None, self.embed_documents, texts) async def aembed_query(self, text: str) -> List[float]: """Asynchronous Embed query text.""" - return await asyncio.get_running_loop().run_in_executor( - None, self.embed_query, text - ) + return await run_in_executor(None, self.embed_query, text) diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index dba21ba71c16e..047908f06e1f8 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -4,7 +4,6 @@ import inspect import warnings from abc import ABC, abstractmethod -from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -45,6 +44,7 @@ ) from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.runnables.config import ensure_config, run_in_executor if TYPE_CHECKING: from langchain_core.runnables import RunnableConfig @@ -158,7 +158,7 @@ def invoke( stop: Optional[List[str]] = None, **kwargs: Any, ) -> BaseMessage: - config = config or {} + config = ensure_config(config) return cast( ChatGeneration, self.generate_prompt( @@ -180,7 +180,7 @@ async def ainvoke( stop: Optional[List[str]] = None, **kwargs: Any, ) -> BaseMessage: - config = config or {} + config = ensure_config(config) llm_result = await self.agenerate_prompt( [self._convert_input(input)], stop=stop, @@ -206,7 +206,7 @@ def stream( BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs) ) else: - config = config or {} + config = ensure_config(config) messages = self._convert_input(input).to_messages() params = self._get_invocation_params(stop=stop, **kwargs) options = {"stop": stop, **kwargs} @@ -264,7 +264,7 @@ async def astream( await self.ainvoke(input, config=config, stop=stop, **kwargs), ) else: - config = config or {} + config = ensure_config(config) messages = self._convert_input(input).to_messages() params = self._get_invocation_params(stop=stop, **kwargs) options = {"stop": stop, **kwargs} @@ -605,8 +605,13 @@ async def _agenerate( **kwargs: Any, ) -> ChatResult: """Top Level call""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._generate, **kwargs), messages, stop, run_manager + return await run_in_executor( + None, + self._generate, + messages, + stop, + run_manager.get_sync() if run_manager else None, + **kwargs, ) def _stream( @@ -766,7 +771,11 @@ async def _agenerate( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - func = partial( - self._generate, messages, stop=stop, run_manager=run_manager, **kwargs + return await run_in_executor( + None, + self._generate, + messages, + stop=stop, + run_manager=run_manager.get_sync() if run_manager else None, + **kwargs, ) - return await asyncio.get_event_loop().run_in_executor(None, func) diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index e0e830d10be7e..4ecfc93521fdd 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -8,7 +8,6 @@ import logging import warnings from abc import ABC, abstractmethod -from functools import partial from pathlib import Path from typing import ( Any, @@ -52,7 +51,8 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue from langchain_core.pydantic_v1 import Field, root_validator, validator -from langchain_core.runnables import RunnableConfig, get_config_list +from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list +from langchain_core.runnables.config import run_in_executor logger = logging.getLogger(__name__) @@ -221,7 +221,7 @@ def invoke( stop: Optional[List[str]] = None, **kwargs: Any, ) -> str: - config = config or {} + config = ensure_config(config) return ( self.generate_prompt( [self._convert_input(input)], @@ -244,7 +244,7 @@ async def ainvoke( stop: Optional[List[str]] = None, **kwargs: Any, ) -> str: - config = config or {} + config = ensure_config(config) llm_result = await self.agenerate_prompt( [self._convert_input(input)], stop=stop, @@ -362,7 +362,7 @@ def stream( yield self.invoke(input, config=config, stop=stop, **kwargs) else: prompt = self._convert_input(input).to_string() - config = config or {} + config = ensure_config(config) params = self.dict() params["stop"] = stop params = {**params, **kwargs} @@ -419,7 +419,7 @@ async def astream( yield await self.ainvoke(input, config=config, stop=stop, **kwargs) else: prompt = self._convert_input(input).to_string() - config = config or {} + config = ensure_config(config) params = self.dict() params["stop"] = stop params = {**params, **kwargs} @@ -483,8 +483,13 @@ async def _agenerate( **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompts.""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._generate, **kwargs), prompts, stop, run_manager + return await run_in_executor( + None, + self._generate, + prompts, + stop, + run_manager.get_sync() if run_manager else None, + **kwargs, ) def _stream( @@ -1049,8 +1054,13 @@ async def _acall( **kwargs: Any, ) -> str: """Run the LLM on the given prompt and input.""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._call, **kwargs), prompt, stop, run_manager + return await run_in_executor( + None, + self._call, + prompt, + stop, + run_manager.get_sync() if run_manager else None, + **kwargs, ) def _generate( diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index c5c2e379e8c65..5972b2f3b2006 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -1,7 +1,5 @@ from __future__ import annotations -import asyncio -import functools from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, @@ -20,6 +18,7 @@ from langchain_core.messages import AnyMessage, BaseMessage from langchain_core.outputs import ChatGeneration, Generation from langchain_core.runnables import RunnableConfig, RunnableSerializable +from langchain_core.runnables.config import run_in_executor if TYPE_CHECKING: from langchain_core.prompt_values import PromptValue @@ -54,9 +53,7 @@ async def aparse_result( Returns: Structured output. """ - return await asyncio.get_running_loop().run_in_executor( - None, self.parse_result, result - ) + return await run_in_executor(None, self.parse_result, result) class BaseGenerationOutputParser( @@ -247,9 +244,7 @@ async def aparse_result( Returns: Structured output. """ - return await asyncio.get_running_loop().run_in_executor( - None, functools.partial(self.parse_result, partial=partial), result - ) + return await run_in_executor(None, self.parse_result, result, partial=partial) async def aparse(self, text: str) -> T: """Parse a single string model output into some structure. @@ -260,7 +255,7 @@ async def aparse(self, text: str) -> T: Returns: Structured output. """ - return await asyncio.get_running_loop().run_in_executor(None, self.parse, text) + return await run_in_executor(None, self.parse, text) # TODO: rename 'completion' -> 'text'. def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any: diff --git a/libs/core/langchain_core/retrievers.py b/libs/core/langchain_core/retrievers.py index 5fe912e032143..f215636695998 100644 --- a/libs/core/langchain_core/retrievers.py +++ b/libs/core/langchain_core/retrievers.py @@ -1,15 +1,19 @@ from __future__ import annotations -import asyncio import warnings from abc import ABC, abstractmethod -from functools import partial from inspect import signature from typing import TYPE_CHECKING, Any, Dict, List, Optional from langchain_core.documents import Document from langchain_core.load.dump import dumpd -from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable +from langchain_core.runnables import ( + Runnable, + RunnableConfig, + RunnableSerializable, + ensure_config, +) +from langchain_core.runnables.config import run_in_executor if TYPE_CHECKING: from langchain_core.callbacks.manager import ( @@ -113,7 +117,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: def invoke( self, input: str, config: Optional[RunnableConfig] = None ) -> List[Document]: - config = config or {} + config = ensure_config(config) return self.get_relevant_documents( input, callbacks=config.get("callbacks"), @@ -128,7 +132,7 @@ async def ainvoke( config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> List[Document]: - config = config or {} + config = ensure_config(config) return await self.aget_relevant_documents( input, callbacks=config.get("callbacks"), @@ -159,8 +163,11 @@ async def _aget_relevant_documents( Returns: List of relevant documents """ - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._get_relevant_documents, run_manager=run_manager), query + return await run_in_executor( + None, + self._get_relevant_documents, + query, + run_manager=run_manager.get_sync(), ) def get_relevant_documents( diff --git a/libs/core/langchain_core/runnables/__init__.py b/libs/core/langchain_core/runnables/__init__.py index e1c9a995cb3b0..2d23a78dc17a6 100644 --- a/libs/core/langchain_core/runnables/__init__.py +++ b/libs/core/langchain_core/runnables/__init__.py @@ -27,8 +27,10 @@ from langchain_core.runnables.branch import RunnableBranch from langchain_core.runnables.config import ( RunnableConfig, + ensure_config, get_config_list, patch_config, + run_in_executor, ) from langchain_core.runnables.fallbacks import RunnableWithFallbacks from langchain_core.runnables.passthrough import ( @@ -42,6 +44,7 @@ ConfigurableField, ConfigurableFieldMultiOption, ConfigurableFieldSingleOption, + ConfigurableFieldSpec, aadd, add, ) @@ -51,6 +54,9 @@ "ConfigurableField", "ConfigurableFieldSingleOption", "ConfigurableFieldMultiOption", + "ConfigurableFieldSpec", + "ensure_config", + "run_in_executor", "patch_config", "RouterInput", "RouterRunnable", diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index b42a17ec4ec58..27ee8409a7166 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from concurrent.futures import FIRST_COMPLETED, wait from copy import deepcopy -from functools import partial, wraps +from functools import wraps from itertools import groupby, tee from operator import itemgetter from typing import ( @@ -47,6 +47,7 @@ get_executor_for_config, merge_configs, patch_config, + run_in_executor, ) from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( @@ -472,10 +473,7 @@ async def ainvoke( Subclasses should override this method if they can run asynchronously. """ - with get_executor_for_config(config) as executor: - return await asyncio.get_running_loop().run_in_executor( - executor, partial(self.invoke, **kwargs), input, config - ) + return await run_in_executor(config, self.invoke, input, config, **kwargs) def batch( self, @@ -665,7 +663,7 @@ async def astream_log( ) # Assign the stream handler to the config - config = config or {} + config = ensure_config(config) callbacks = config.get("callbacks") if callbacks is None: config["callbacks"] = [stream] @@ -2883,10 +2881,7 @@ async def _ainvoke( @wraps(self.func) async def f(*args, **kwargs): # type: ignore[no-untyped-def] - with get_executor_for_config(config) as executor: - return await asyncio.get_running_loop().run_in_executor( - executor, partial(self.func, **kwargs), *args - ) + return await run_in_executor(config, self.func, *args, **kwargs) afunc = f @@ -2913,7 +2908,7 @@ async def f(*args, **kwargs): # type: ignore[no-untyped-def] def _config( self, config: Optional[RunnableConfig], callable: Callable[..., Any] ) -> RunnableConfig: - config = config or {} + config = ensure_config(config) if config.get("run_name") is None: try: @@ -3052,9 +3047,7 @@ async def _atransform( @wraps(self.func) async def f(*args, **kwargs): # type: ignore[no-untyped-def] - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.func, **kwargs), *args - ) + return await run_in_executor(config, self.func, *args, **kwargs) afunc = f diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index c803f52b29558..080dfa9cdbea8 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -1,8 +1,10 @@ from __future__ import annotations -from concurrent.futures import Executor, ThreadPoolExecutor +import asyncio +from concurrent.futures import Executor, Future, ThreadPoolExecutor from contextlib import contextmanager -from contextvars import Context, copy_context +from contextvars import ContextVar, copy_context +from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -10,13 +12,16 @@ Callable, Dict, Generator, + Iterable, + Iterator, List, Optional, + TypeVar, Union, cast, ) -from typing_extensions import TypedDict +from typing_extensions import ParamSpec, TypedDict from langchain_core.runnables.utils import ( Input, @@ -91,6 +96,11 @@ class RunnableConfig(TypedDict, total=False): """ +var_child_runnable_config = ContextVar( + "child_runnable_config", default=RunnableConfig() +) + + def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: """Ensure that a config is a dict with all keys present. @@ -107,6 +117,10 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: callbacks=None, recursion_limit=25, ) + if var_config := var_child_runnable_config.get(): + empty.update( + cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None}) + ) if config is not None: empty.update( cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}) @@ -388,9 +402,51 @@ def get_async_callback_manager_for_config( ) -def _set_context(context: Context) -> None: - for var, value in context.items(): - var.set(value) +P = ParamSpec("P") +T = TypeVar("T") + + +class ContextThreadPoolExecutor(ThreadPoolExecutor): + """ThreadPoolExecutor that copies the context to the child thread.""" + + def submit( # type: ignore[override] + self, + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, + ) -> Future[T]: + """Submit a function to the executor. + + Args: + func (Callable[..., T]): The function to submit. + *args (Any): The positional arguments to the function. + **kwargs (Any): The keyword arguments to the function. + + Returns: + Future[T]: The future for the function. + """ + return super().submit( + cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs)) + ) + + def map( + self, + fn: Callable[..., T], + *iterables: Iterable[Any], + timeout: float | None = None, + chunksize: int = 1, + ) -> Iterator[T]: + contexts = [copy_context() for _ in range(len(iterables[0]))] # type: ignore[arg-type] + + def _wrapped_fn(*args: Any) -> T: + return contexts.pop().run(fn, *args) + + return super().map( + _wrapped_fn, + *iterables, + timeout=timeout, + chunksize=chunksize, + ) @contextmanager @@ -406,9 +462,36 @@ def get_executor_for_config( Generator[Executor, None, None]: The executor. """ config = config or {} - with ThreadPoolExecutor( - max_workers=config.get("max_concurrency"), - initializer=_set_context, - initargs=(copy_context(),), + with ContextThreadPoolExecutor( + max_workers=config.get("max_concurrency") ) as executor: yield executor + + +async def run_in_executor( + executor_or_config: Optional[Union[Executor, RunnableConfig]], + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> T: + """Run a function in an executor. + + Args: + executor (Executor): The executor. + func (Callable[P, Output]): The function. + *args (Any): The positional arguments to the function. + **kwargs (Any): The keyword arguments to the function. + + Returns: + Output: The output of the function. + """ + if executor_or_config is None or isinstance(executor_or_config, dict): + # Use default executor with context copied from current context + return await asyncio.get_running_loop().run_in_executor( + None, + cast(Callable[..., T], partial(copy_context().run, func, *args, **kwargs)), + ) + + return await asyncio.get_running_loop().run_in_executor( + executor_or_config, partial(func, **kwargs), *args + ) diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index f7ad523ca558f..81ab33d8fd445 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -23,6 +23,7 @@ from langchain_core.runnables.base import Runnable, RunnableSerializable from langchain_core.runnables.config import ( RunnableConfig, + ensure_config, get_config_list, get_executor_for_config, ) @@ -259,7 +260,7 @@ def configurable_fields( def _prepare( self, config: Optional[RunnableConfig] = None ) -> Tuple[Runnable[Input, Output], RunnableConfig]: - config = config or {} + config = ensure_config(config) specs_by_id = {spec.id: (key, spec) for key, spec in self.fields.items()} configurable_fields = { specs_by_id[k][0]: v @@ -392,7 +393,7 @@ def configurable_fields( def _prepare( self, config: Optional[RunnableConfig] = None ) -> Tuple[Runnable[Input, Output], RunnableConfig]: - config = config or {} + config = ensure_config(config) which = config.get("configurable", {}).get(self.which.id, self.default_key) # remap configurable keys for the chosen alternative if self.prefix_keys: diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index c1b6b7f94c2c4..99be9e7e13033 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import inspect from typing import ( TYPE_CHECKING, @@ -18,6 +17,7 @@ from langchain_core.load import load from langchain_core.pydantic_v1 import BaseModel, create_model from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda +from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.utils import ( ConfigurableFieldSpec, @@ -331,9 +331,7 @@ def _enter_history(self, input: Any, config: RunnableConfig) -> List[BaseMessage async def _aenter_history( self, input: Dict[str, Any], config: RunnableConfig ) -> List[BaseMessage]: - return await asyncio.get_running_loop().run_in_executor( - None, self._enter_history, input, config - ) + return await run_in_executor(config, self._enter_history, input, config) def _exit_history(self, run: Run, config: RunnableConfig) -> None: hist = config["configurable"]["message_history"] diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index 9ff41d8d21e4b..1f4367d9d5ead 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -31,6 +31,7 @@ RunnableConfig, acall_func_with_variable_args, call_func_with_variable_args, + ensure_config, get_executor_for_config, patch_config, ) @@ -206,7 +207,9 @@ def invoke( self, input: Other, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Other: if self.func is not None: - call_func_with_variable_args(self.func, input, config or {}, **kwargs) + call_func_with_variable_args( + self.func, input, ensure_config(config), **kwargs + ) return self._call_with_config(identity, input, config) async def ainvoke( @@ -217,10 +220,12 @@ async def ainvoke( ) -> Other: if self.afunc is not None: await acall_func_with_variable_args( - self.afunc, input, config or {}, **kwargs + self.afunc, input, ensure_config(config), **kwargs ) elif self.func is not None: - call_func_with_variable_args(self.func, input, config or {}, **kwargs) + call_func_with_variable_args( + self.func, input, ensure_config(config), **kwargs + ) return await self._acall_with_config(aidentity, input, config) def transform( @@ -243,7 +248,9 @@ def transform( final = final + chunk if final is not None: - call_func_with_variable_args(self.func, final, config or {}, **kwargs) + call_func_with_variable_args( + self.func, final, ensure_config(config), **kwargs + ) async def atransform( self, @@ -269,7 +276,7 @@ async def atransform( final = final + chunk if final is not None: - config = config or {} + config = ensure_config(config) if self.afunc is not None: await acall_func_with_variable_args( self.afunc, final, config, **kwargs @@ -458,7 +465,7 @@ def _transform( ) # get executor to start map output stream in background - with get_executor_for_config(config or {}) as executor: + with get_executor_for_config(config) as executor: # start map output stream first_map_chunk_future = executor.submit( next, diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 8b73580358a29..e13e641d7cd45 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -1,11 +1,9 @@ """Base implementation for tools or skills.""" from __future__ import annotations -import asyncio import inspect import warnings from abc import abstractmethod -from functools import partial from inspect import signature from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union @@ -26,7 +24,13 @@ root_validator, validate_arguments, ) -from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable +from langchain_core.runnables import ( + Runnable, + RunnableConfig, + RunnableSerializable, + ensure_config, +) +from langchain_core.runnables.config import run_in_executor class SchemaAnnotationError(TypeError): @@ -202,7 +206,7 @@ def invoke( config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - config = config or {} + config = ensure_config(config) return self.run( input, callbacks=config.get("callbacks"), @@ -218,7 +222,7 @@ async def ainvoke( config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Any: - config = config or {} + config = ensure_config(config) return await self.arun( input, callbacks=config.get("callbacks"), @@ -280,11 +284,7 @@ async def _arun( Add run_manager: Optional[AsyncCallbackManagerForToolRun] = None to child implementations to enable tracing, """ - return await asyncio.get_running_loop().run_in_executor( - None, - partial(self._run, **kwargs), - *args, - ) + return await run_in_executor(None, self._run, *args, **kwargs) def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: # For backwards compatibility, if run_input is a string, @@ -468,9 +468,7 @@ async def ainvoke( ) -> Any: if not self.coroutine: # If the tool does not implement async, fall back to default implementation - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, **kwargs) - ) + return await run_in_executor(config, self.invoke, input, config, **kwargs) return await super().ainvoke(input, config, **kwargs) @@ -538,8 +536,12 @@ async def _arun( else await self.coroutine(*args, **kwargs) ) else: - return await asyncio.get_running_loop().run_in_executor( - None, partial(self._run, run_manager=run_manager, **kwargs), *args + return await run_in_executor( + None, + self._run, + run_manager=run_manager.get_sync() if run_manager else None, + *args, + **kwargs, ) # TODO: this is for backwards compatibility, remove in future @@ -599,9 +601,7 @@ async def ainvoke( ) -> Any: if not self.coroutine: # If the tool does not implement async, fall back to default implementation - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.invoke, input, config, **kwargs) - ) + return await run_in_executor(config, self.invoke, input, config, **kwargs) return await super().ainvoke(input, config, **kwargs) @@ -652,10 +652,12 @@ async def _arun( if new_argument_supported else await self.coroutine(*args, **kwargs) ) - return await asyncio.get_running_loop().run_in_executor( + return await run_in_executor( None, - partial(self._run, run_manager=run_manager, **kwargs), + self._run, + run_manager=run_manager.get_sync() if run_manager else None, *args, + **kwargs, ) @classmethod diff --git a/libs/core/langchain_core/vectorstores.py b/libs/core/langchain_core/vectorstores.py index 8a23dcaec1940..2fb32f86b1acb 100644 --- a/libs/core/langchain_core/vectorstores.py +++ b/libs/core/langchain_core/vectorstores.py @@ -1,11 +1,9 @@ from __future__ import annotations -import asyncio import logging import math import warnings from abc import ABC, abstractmethod -from functools import partial from typing import ( TYPE_CHECKING, Any, @@ -24,6 +22,7 @@ from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.retrievers import BaseRetriever +from langchain_core.runnables.config import run_in_executor if TYPE_CHECKING: from langchain_core.callbacks.manager import ( @@ -103,9 +102,7 @@ async def aadd_texts( **kwargs: Any, ) -> List[str]: """Run more texts through the embeddings and add to the vectorstore.""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.add_texts, **kwargs), texts, metadatas - ) + return await run_in_executor(None, self.add_texts, texts, metadatas, **kwargs) def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]: """Run more documents through the embeddings and add to the vectorstore. @@ -224,8 +221,9 @@ async def asimilarity_search_with_score( # This is a temporary workaround to make the similarity search # asynchronous. The proper solution is to make the similarity search # asynchronous in the vector store implementations. - func = partial(self.similarity_search_with_score, *args, **kwargs) - return await asyncio.get_event_loop().run_in_executor(None, func) + return await run_in_executor( + None, self.similarity_search_with_score, *args, **kwargs + ) def _similarity_search_with_relevance_scores( self, @@ -383,8 +381,7 @@ async def asimilarity_search( # This is a temporary workaround to make the similarity search # asynchronous. The proper solution is to make the similarity search # asynchronous in the vector store implementations. - func = partial(self.similarity_search, query, k=k, **kwargs) - return await asyncio.get_event_loop().run_in_executor(None, func) + return await run_in_executor(None, self.similarity_search, query, k=k, **kwargs) def similarity_search_by_vector( self, embedding: List[float], k: int = 4, **kwargs: Any @@ -408,8 +405,9 @@ async def asimilarity_search_by_vector( # This is a temporary workaround to make the similarity search # asynchronous. The proper solution is to make the similarity search # asynchronous in the vector store implementations. - func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs) - return await asyncio.get_event_loop().run_in_executor(None, func) + return await run_in_executor( + None, self.similarity_search_by_vector, embedding, k=k, **kwargs + ) def max_marginal_relevance_search( self, @@ -450,7 +448,8 @@ async def amax_marginal_relevance_search( # This is a temporary workaround to make the similarity search # asynchronous. The proper solution is to make the similarity search # asynchronous in the vector store implementations. - func = partial( + return await run_in_executor( + None, self.max_marginal_relevance_search, query, k=k, @@ -458,7 +457,6 @@ async def amax_marginal_relevance_search( lambda_mult=lambda_mult, **kwargs, ) - return await asyncio.get_event_loop().run_in_executor(None, func) def max_marginal_relevance_search_by_vector( self, @@ -541,8 +539,8 @@ async def afrom_texts( **kwargs: Any, ) -> VST: """Return VectorStore initialized from texts and embeddings.""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas + return await run_in_executor( + None, cls.from_texts, texts, embedding, metadatas, **kwargs ) def _get_retriever_tags(self) -> List[str]: diff --git a/libs/core/tests/unit_tests/runnables/test_imports.py b/libs/core/tests/unit_tests/runnables/test_imports.py index c0bd73cd3ed82..8300292af1291 100644 --- a/libs/core/tests/unit_tests/runnables/test_imports.py +++ b/libs/core/tests/unit_tests/runnables/test_imports.py @@ -5,6 +5,9 @@ "ConfigurableField", "ConfigurableFieldSingleOption", "ConfigurableFieldMultiOption", + "ConfigurableFieldSpec", + "ensure_config", + "run_in_executor", "patch_config", "RouterInput", "RouterRunnable", diff --git a/libs/experimental/langchain_experimental/tools/python/tool.py b/libs/experimental/langchain_experimental/tools/python/tool.py index eb9b1c08a3637..6f1f6d72c01fd 100644 --- a/libs/experimental/langchain_experimental/tools/python/tool.py +++ b/libs/experimental/langchain_experimental/tools/python/tool.py @@ -1,7 +1,6 @@ """A tool for running python code in a REPL.""" import ast -import asyncio import re import sys from contextlib import redirect_stdout @@ -14,6 +13,7 @@ ) from langchain.pydantic_v1 import BaseModel, Field, root_validator from langchain.tools.base import BaseTool +from langchain_core.runnables.config import run_in_executor from langchain_experimental.utilities.python import PythonREPL @@ -72,10 +72,7 @@ async def _arun( if self.sanitize_input: query = sanitize_input(query) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(None, self.run, query) - - return result + return await run_in_executor(None, self.run, query) class PythonInputs(BaseModel): @@ -144,7 +141,4 @@ async def _arun( ) -> Any: """Use the tool asynchronously.""" - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(None, self._run, query) - - return result + return await run_in_executor(None, self._run, query) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index a28377625315c..d19af1778c5a9 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -30,7 +30,7 @@ from langchain_core.prompts.few_shot import FewShotPromptTemplate from langchain_core.prompts.prompt import PromptTemplate from langchain_core.pydantic_v1 import BaseModel, root_validator -from langchain_core.runnables import Runnable, RunnableConfig +from langchain_core.runnables import Runnable, RunnableConfig, ensure_config from langchain_core.runnables.utils import AddableDict from langchain_core.tools import BaseTool from langchain_core.utils.input import get_color_mapping @@ -1437,7 +1437,7 @@ def stream( **kwargs: Any, ) -> Iterator[AddableDict]: """Enables streaming over steps taken to reach final output.""" - config = config or {} + config = ensure_config(config) iterator = AgentExecutorIterator( self, input, @@ -1458,7 +1458,7 @@ async def astream( **kwargs: Any, ) -> AsyncIterator[AddableDict]: """Enables streaming over steps taken to reach final output.""" - config = config or {} + config = ensure_config(config) iterator = AgentExecutorIterator( self, input, diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index d1a45f9c763c8..3e98f5442a6c4 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -8,7 +8,7 @@ from langchain_core.agents import AgentAction, AgentFinish from langchain_core.load import dumpd from langchain_core.pydantic_v1 import Field -from langchain_core.runnables import RunnableConfig, RunnableSerializable +from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config from langchain_core.tools import BaseTool from langchain.callbacks.manager import CallbackManager @@ -222,7 +222,7 @@ def invoke( Union[List[ThreadMessage], List[RequiredActionFunctionToolCall]]. """ - config = config or {} + config = ensure_config(config) callback_manager = CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), inheritable_tags=config.get("tags"), diff --git a/libs/langchain/langchain/agents/output_parsers/openai_functions.py b/libs/langchain/langchain/agents/output_parsers/openai_functions.py index b0e0436bf7b49..04778d177bd42 100644 --- a/libs/langchain/langchain/agents/output_parsers/openai_functions.py +++ b/libs/langchain/langchain/agents/output_parsers/openai_functions.py @@ -1,4 +1,3 @@ -import asyncio import json from json import JSONDecodeError from typing import List, Union @@ -85,12 +84,5 @@ def parse_result( message = result[0].message return self._parse_ai_message(message) - async def aparse_result( - self, result: List[Generation], *, partial: bool = False - ) -> Union[AgentAction, AgentFinish]: - return await asyncio.get_running_loop().run_in_executor( - None, self.parse_result, result - ) - def parse(self, text: str) -> Union[AgentAction, AgentFinish]: raise ValueError("Can only parse messages") diff --git a/libs/langchain/langchain/agents/output_parsers/openai_tools.py b/libs/langchain/langchain/agents/output_parsers/openai_tools.py index 4c4759d58ad7a..f4b2cdd9cebc6 100644 --- a/libs/langchain/langchain/agents/output_parsers/openai_tools.py +++ b/libs/langchain/langchain/agents/output_parsers/openai_tools.py @@ -1,4 +1,3 @@ -import asyncio import json from json import JSONDecodeError from typing import List, Union @@ -92,12 +91,5 @@ def parse_result( message = result[0].message return parse_ai_message_to_openai_tool_action(message) - async def aparse_result( - self, result: List[Generation], *, partial: bool = False - ) -> Union[List[AgentAction], AgentFinish]: - return await asyncio.get_running_loop().run_in_executor( - None, self.parse_result, result - ) - def parse(self, text: str) -> Union[List[AgentAction], AgentFinish]: raise ValueError("Can only parse messages") diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index f02a76eaa5ab0..64138770f04e6 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -1,5 +1,4 @@ """Base interface that all chains should implement.""" -import asyncio import inspect import json import logging @@ -19,7 +18,12 @@ root_validator, validator, ) -from langchain_core.runnables import RunnableConfig, RunnableSerializable +from langchain_core.runnables import ( + RunnableConfig, + RunnableSerializable, + ensure_config, + run_in_executor, +) from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import ( @@ -85,7 +89,7 @@ def invoke( config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Dict[str, Any]: - config = config or {} + config = ensure_config(config) return self( input, callbacks=config.get("callbacks"), @@ -101,7 +105,7 @@ async def ainvoke( config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Dict[str, Any]: - config = config or {} + config = ensure_config(config) return await self.acall( input, callbacks=config.get("callbacks"), @@ -245,8 +249,8 @@ async def _acall( A dict of named outputs. Should contain all outputs specified in `Chain.output_keys`. """ - return await asyncio.get_running_loop().run_in_executor( - None, self._call, inputs, run_manager + return await run_in_executor( + None, self._call, inputs, run_manager.get_sync() if run_manager else None ) def __call__( diff --git a/libs/langchain/langchain/evaluation/schema.py b/libs/langchain/langchain/evaluation/schema.py index bb9d459344341..e4fa139ca6ecf 100644 --- a/libs/langchain/langchain/evaluation/schema.py +++ b/libs/langchain/langchain/evaluation/schema.py @@ -1,16 +1,15 @@ """Interfaces to be implemented by general evaluators.""" from __future__ import annotations -import asyncio import logging from abc import ABC, abstractmethod from enum import Enum -from functools import partial from typing import Any, Optional, Sequence, Tuple, Union from warnings import warn from langchain_core.agents import AgentAction from langchain_core.language_models import BaseLanguageModel +from langchain_core.runnables.config import run_in_executor from langchain.chains.base import Chain @@ -189,15 +188,13 @@ async def _aevaluate_strings( - value: the string value of the evaluation, if applicable. - reasoning: the reasoning for the evaluation, if applicable. """ # noqa: E501 - return await asyncio.get_running_loop().run_in_executor( + return await run_in_executor( None, - partial( - self._evaluate_strings, - prediction=prediction, - reference=reference, - input=input, - **kwargs, - ), + self._evaluate_strings, + prediction=prediction, + reference=reference, + input=input, + **kwargs, ) def evaluate_strings( @@ -292,16 +289,14 @@ async def _aevaluate_string_pairs( Returns: dict: A dictionary containing the preference, scores, and/or other information. """ # noqa: E501 - return await asyncio.get_running_loop().run_in_executor( + return await run_in_executor( None, - partial( - self._evaluate_string_pairs, - prediction=prediction, - prediction_b=prediction_b, - reference=reference, - input=input, - **kwargs, - ), + self._evaluate_string_pairs, + prediction=prediction, + prediction_b=prediction_b, + reference=reference, + input=input, + **kwargs, ) def evaluate_string_pairs( @@ -415,16 +410,14 @@ async def _aevaluate_agent_trajectory( Returns: dict: The evaluation result. """ - return await asyncio.get_running_loop().run_in_executor( + return await run_in_executor( None, - partial( - self._evaluate_agent_trajectory, - prediction=prediction, - agent_trajectory=agent_trajectory, - reference=reference, - input=input, - **kwargs, - ), + self._evaluate_agent_trajectory, + prediction=prediction, + agent_trajectory=agent_trajectory, + reference=reference, + input=input, + **kwargs, ) def evaluate_agent_trajectory( diff --git a/libs/langchain/langchain/retrievers/document_compressors/base.py b/libs/langchain/langchain/retrievers/document_compressors/base.py index 75e05e29006a5..0acb81e9e2606 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/base.py +++ b/libs/langchain/langchain/retrievers/document_compressors/base.py @@ -1,10 +1,10 @@ -import asyncio from abc import ABC, abstractmethod from inspect import signature from typing import List, Optional, Sequence, Union from langchain_core.documents import BaseDocumentTransformer, Document from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables.config import run_in_executor from langchain.callbacks.manager import Callbacks @@ -28,7 +28,7 @@ async def acompress_documents( callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: """Compress retrieved documents given the query context.""" - return await asyncio.get_running_loop().run_in_executor( + return await run_in_executor( None, self.compress_documents, documents, query, callbacks ) diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py index efb55c1b984d3..be0cb5bdfa695 100644 --- a/libs/langchain/langchain/text_splitter.py +++ b/libs/langchain/langchain/text_splitter.py @@ -21,7 +21,6 @@ from __future__ import annotations -import asyncio import copy import logging import pathlib @@ -29,7 +28,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from functools import partial from io import BytesIO, StringIO from typing import ( AbstractSet, @@ -283,14 +281,6 @@ def transform_documents( """Transform sequence of documents by splitting them.""" return self.split_documents(list(documents)) - async def atransform_documents( - self, documents: Sequence[Document], **kwargs: Any - ) -> Sequence[Document]: - """Asynchronously transform a sequence of documents by splitting them.""" - return await asyncio.get_running_loop().run_in_executor( - None, partial(self.transform_documents, **kwargs), documents - ) - class CharacterTextSplitter(TextSplitter): """Splitting text that looks at characters."""