From 81985b31e6ff839ba073d1107722e1756d30e678 Mon Sep 17 00:00:00 2001 From: Liang Zhang Date: Tue, 5 Mar 2024 18:04:45 -0800 Subject: [PATCH] community[patch]: Databricks SerDe uses cloudpickle instead of pickle (#18607) - **Description:** Databricks SerDe uses cloudpickle instead of pickle when serializing a user-defined function transform_input_fn since pickle does not support functions defined in `__main__`, and cloudpickle supports this. - **Dependencies:** cloudpickle>=2.0.0 Added a unit test. --- .../langchain_community/llms/databricks.py | 15 ++++++++++++--- libs/community/poetry.lock | 8 ++++---- libs/community/pyproject.toml | 5 ++++- .../tests/unit_tests/llms/test_databricks.py | 16 +++++++++++++--- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/libs/community/langchain_community/llms/databricks.py b/libs/community/langchain_community/llms/databricks.py index c92f8ba221fdd..3b0bca7edb114 100644 --- a/libs/community/langchain_community/llms/databricks.py +++ b/libs/community/langchain_community/llms/databricks.py @@ -1,5 +1,4 @@ import os -import pickle import re import warnings from abc import ABC, abstractmethod @@ -225,7 +224,12 @@ def _is_hex_string(data: str) -> bool: def _load_pickled_fn_from_hex_string(data: str) -> Callable: """Loads a pickled function from a hexadecimal string.""" try: - return pickle.loads(bytes.fromhex(data)) + import cloudpickle + except Exception as e: + raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}") + + try: + return cloudpickle.loads(bytes.fromhex(data)) except Exception as e: raise ValueError( f"Failed to load the pickled function from a hexadecimal string. Error: {e}" @@ -235,7 +239,12 @@ def _load_pickled_fn_from_hex_string(data: str) -> Callable: def _pickle_fn_to_hex_string(fn: Callable) -> str: """Pickles a function and returns the hexadecimal string.""" try: - return pickle.dumps(fn).hex() + import cloudpickle + except Exception as e: + raise ValueError(f"Please install cloudpickle>=2.0.0. Error: {e}") + + try: + return cloudpickle.dumps(fn).hex() except Exception as e: raise ValueError(f"Failed to pickle the function: {e}") diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock index c1f0c0877fbf8..41b3801366476 100644 --- a/libs/community/poetry.lock +++ b/libs/community/poetry.lock @@ -3650,7 +3650,7 @@ files = [ [[package]] name = "langchain-core" -version = "0.1.28" +version = "0.1.29" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" @@ -3687,7 +3687,7 @@ develop = true langchain-core = "^0.1.28" [package.extras] -extended-testing = [] +extended-testing = ["lxml (>=5.1.0,<6.0.0)"] [package.source] type = "directory" @@ -9176,9 +9176,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] cli = ["typer"] -extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"] +extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "xata", "xmltodict", "zhipuai"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "d64381a1891a09e6215818c25ba7ca7b14a8708351695feab9ae53f4485f3b3e" +content-hash = "d110eaaa4ecba8f6ed7faa2577b058c1f7c74171a6dbc53bc880f3c8598fc34b" diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 40f3163bf4133..b1ed9e7aeaafd 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -81,6 +81,7 @@ hologres-vector = {version = "^0.0.6", optional = true} praw = {version = "^7.7.1", optional = true} msal = {version = "^1.25.0", optional = true} databricks-vectorsearch = {version = "^0.21", optional = true} +cloudpickle = {version = ">=2.0.0", optional = true} dgml-utils = {version = "^0.3.0", optional = true} datasets = {version = "^2.15.0", optional = true} tree-sitter = {version = "^0.20.2", optional = true} @@ -249,6 +250,7 @@ extended_testing = [ "hologres-vector", "praw", "databricks-vectorsearch", + "cloudpickle", "dgml-utils", "cohere", "tree-sitter", @@ -260,7 +262,8 @@ extended_testing = [ "elasticsearch", "hdbcli", "oci", - "rdflib" + "rdflib", + "cloudpickle", ] [tool.ruff] diff --git a/libs/community/tests/unit_tests/llms/test_databricks.py b/libs/community/tests/unit_tests/llms/test_databricks.py index 7d3809e270cf9..6b1a3d988175c 100644 --- a/libs/community/tests/unit_tests/llms/test_databricks.py +++ b/libs/community/tests/unit_tests/llms/test_databricks.py @@ -1,10 +1,13 @@ """test Databricks LLM""" -import pickle from typing import Any, Dict +import pytest from pytest import MonkeyPatch -from langchain_community.llms.databricks import Databricks +from langchain_community.llms.databricks import ( + Databricks, + _load_pickled_fn_from_hex_string, +) class MockDatabricksServingEndpointClient: @@ -29,7 +32,10 @@ def transform_input(**request: Any) -> Dict[str, Any]: return request +@pytest.mark.requires("cloudpickle") def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None: + import cloudpickle + monkeypatch.setattr( "langchain_community.llms.databricks._DatabricksServingEndpointClient", MockDatabricksServingEndpointClient, @@ -42,5 +48,9 @@ def test_serde_transform_input_fn(monkeypatch: MonkeyPatch) -> None: transform_input_fn=transform_input, ) params = llm._default_params - pickled_string = pickle.dumps(transform_input).hex() + pickled_string = cloudpickle.dumps(transform_input).hex() assert params["transform_input_fn"] == pickled_string + + request = {"prompt": "What is the meaning of life?"} + fn = _load_pickled_fn_from_hex_string(params["transform_input_fn"]) + assert fn(**request) == transform_input(**request)