Skip to content

Commit

Permalink
Propagate context vars in all classes/methods (#15329)
Browse files Browse the repository at this point in the history
- Any direct usage of `ThreadPoolExecutor` or of the event loop's
`run_in_executor` needs manual handling of context vars, because neither
copies the caller's `contextvars` context into the worker thread

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change,
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
  - **Twitter handle:** we announce bigger features on Twitter. If your PR
    gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in the
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
  • Loading branch information
nfcampos authored Dec 29, 2023
2 parents 7eec8f2 + 4e4b119 commit 99000c6
Show file tree
Hide file tree
Showing 39 changed files with 394 additions and 376 deletions.
15 changes: 0 additions & 15 deletions libs/community/langchain_community/chat_models/human.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
"""ChatModel wrapper which returns user input as the response.."""
import asyncio
from functools import partial
from io import StringIO
from typing import Any, Callable, Dict, List, Mapping, Optional

import yaml
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
Expand Down Expand Up @@ -111,15 +108,3 @@ def _generate(
self.message_func(messages, **self.message_kwargs)
user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
return ChatResult(generations=[ChatGeneration(message=user_input)])

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)
15 changes: 0 additions & 15 deletions libs/community/langchain_community/chat_models/mlflow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import asyncio
import logging
from functools import partial
from typing import Any, Dict, List, Mapping, Optional
from urllib.parse import urlparse

from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
Expand Down Expand Up @@ -125,18 +122,6 @@ def _generate(
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
return ChatMlflow._create_chat_result(resp)

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)

@property
def _identifying_params(self) -> Dict[str, Any]:
return self._default_params
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import asyncio
import logging
import warnings
from functools import partial
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
Expand Down Expand Up @@ -116,18 +113,6 @@ def _generate(
resp = mlflow.gateway.query(self.route, data=data)
return ChatMLflowAIGateway._create_chat_result(resp)

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)

@property
def _identifying_params(self) -> Dict[str, Any]:
return self._default_params
Expand Down
24 changes: 0 additions & 24 deletions libs/community/langchain_community/chat_models/pai_eas_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import asyncio
import json
import logging
from functools import partial
from typing import Any, AsyncIterator, Dict, List, Optional, cast

import requests
Expand Down Expand Up @@ -300,25 +298,3 @@ async def _astream(
# break if stop sequence found
if stop_seq_found:
break

async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
if stream if stream is not None else self.streaming:
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
generation = chunk
assert generation is not None
return ChatResult(generations=[generation])

func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)
6 changes: 2 additions & 4 deletions libs/community/langchain_community/embeddings/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import json
import os
from functools import partial
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.runnables.config import run_in_executor


class BedrockEmbeddings(BaseModel, Embeddings):
Expand Down Expand Up @@ -181,9 +181,7 @@ async def aembed_query(self, text: str) -> List[float]:
Embeddings for the text.
"""

return await asyncio.get_running_loop().run_in_executor(
None, partial(self.embed_query, text)
)
return await run_in_executor(None, self.embed_query, text)

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous compute doc embeddings using a Bedrock model.
Expand Down
6 changes: 2 additions & 4 deletions libs/community/langchain_community/embeddings/ernie.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import asyncio
import logging
import threading
from functools import partial
from typing import Dict, List, Optional

import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables.config import run_in_executor
from langchain_core.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -134,9 +134,7 @@ async def aembed_query(self, text: str) -> List[float]:
List[float]: Embeddings for the text.
"""

return await asyncio.get_running_loop().run_in_executor(
None, partial(self.embed_query, text)
)
return await run_in_executor(None, self.embed_query, text)

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs.
Expand Down
10 changes: 0 additions & 10 deletions libs/community/langchain_community/tools/multion/close_session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Optional, Type

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field
Expand Down Expand Up @@ -57,11 +55,3 @@ def _run(
print(f"{e}, retrying...")
except Exception as e:
raise Exception(f"An error occurred: {e}")

async def _arun(
self,
sessionId: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> None:
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._run, sessionId)
13 changes: 0 additions & 13 deletions libs/community/langchain_community/tools/multion/create_session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Optional, Type

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field
Expand Down Expand Up @@ -67,14 +65,3 @@ def _run(
}
except Exception as e:
raise Exception(f"An error occurred: {e}")

async def _arun(
self,
query: str,
url: Optional[str] = "https://www.google.com/",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> dict:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self._run, query, url)

return result
14 changes: 0 additions & 14 deletions libs/community/langchain_community/tools/multion/update_session.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import asyncio
from typing import TYPE_CHECKING, Optional, Type

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field
Expand Down Expand Up @@ -74,15 +72,3 @@ def _run(
return {"error": f"{e}", "Response": "retrying..."}
except Exception as e:
raise Exception(f"An error occurred: {e}")

async def _arun(
self,
sessionId: str,
query: str,
url: Optional[str] = "https://www.google.com/",
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> dict:
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, self._run, sessionId, query, url)

return result
12 changes: 0 additions & 12 deletions libs/community/langchain_community/tools/shell/tool.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import asyncio
import platform
import warnings
from typing import Any, List, Optional, Type, Union

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
Expand Down Expand Up @@ -77,13 +75,3 @@ def _run(
) -> str:
"""Run commands and return final output."""
return self.process.run(commands)

async def _arun(
self,
commands: Union[str, List[str]],
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
"""Run commands asynchronously and return final output."""
return await asyncio.get_event_loop().run_in_executor(
None, self.process.run, commands
)
11 changes: 5 additions & 6 deletions libs/community/langchain_community/vectorstores/faiss.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from __future__ import annotations

import asyncio
import logging
import operator
import os
import pickle
import uuid
import warnings
from functools import partial
from pathlib import Path
from typing import (
Any,
Expand All @@ -24,6 +22,7 @@
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore

from langchain_community.docstore.base import AddableMixin, Docstore
Expand Down Expand Up @@ -359,15 +358,15 @@ async def asimilarity_search_with_score_by_vector(
"""

# This is a temporary workaround to make the similarity search asynchronous.
func = partial(
return await run_in_executor(
None,
self.similarity_search_with_score_by_vector,
embedding,
k=k,
filter=filter,
fetch_k=fetch_k,
**kwargs,
)
return await asyncio.get_event_loop().run_in_executor(None, func)

def similarity_search_with_score(
self,
Expand Down Expand Up @@ -640,15 +639,15 @@ async def amax_marginal_relevance_search_with_score_by_vector(
relevance and score for each.
"""
# This is a temporary workaround to make the similarity search asynchronous.
func = partial(
return await run_in_executor(
None,
self.max_marginal_relevance_search_with_score_by_vector,
embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
filter=filter,
)
return await asyncio.get_event_loop().run_in_executor(None, func)

def max_marginal_relevance_search_by_vector(
self,
Expand Down
7 changes: 3 additions & 4 deletions libs/community/langchain_community/vectorstores/pgvector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import asyncio
import contextlib
import enum
import logging
import uuid
from functools import partial
from typing import (
Any,
Callable,
Expand All @@ -31,6 +29,7 @@

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore

Expand Down Expand Up @@ -941,7 +940,8 @@ async def amax_marginal_relevance_search_by_vector(
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(
return await run_in_executor(
None,
self.max_marginal_relevance_search_by_vector,
embedding,
k=k,
Expand All @@ -950,4 +950,3 @@ async def amax_marginal_relevance_search_by_vector(
filter=filter,
**kwargs,
)
return await asyncio.get_event_loop().run_in_executor(None, func)
7 changes: 3 additions & 4 deletions libs/community/langchain_community/vectorstores/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import functools
import uuid
import warnings
Expand All @@ -25,6 +24,7 @@
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
from langchain_core.vectorstores import VectorStore

from langchain_community.vectorstores.utils import maximal_marginal_relevance
Expand Down Expand Up @@ -58,10 +58,9 @@ async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
# by removing the first letter from the method name. For example,
# if the async method is called ``aadd_texts``, the synchronous method
# will be called ``add_texts``.
sync_method = functools.partial(
getattr(self, method.__name__[1:]), *args, **kwargs
return await run_in_executor(
None, getattr(self, method.__name__[1:]), *args, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, sync_method)

return wrapper

Expand Down
Loading

0 comments on commit 99000c6

Please sign in to comment.