Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AzureMLEndpointEmbeddings for embedding support via Azure ML serverless API #27148

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,27 @@
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from langchain_community.embeddings.aleph_alpha import (
AlephAlphaAsymmetricSemanticEmbedding,
AlephAlphaSymmetricSemanticEmbedding,
)
from langchain_community.embeddings.anyscale import (
AnyscaleEmbeddings,
)
from langchain_community.embeddings.ascend import (
AscendEmbeddings,
)
from langchain_community.embeddings.awa import (
AwaEmbeddings,
)
from langchain_community.embeddings.azure_openai import (
AzureOpenAIEmbeddings,
)
from langchain_community.embeddings.azure_ml_endpoint import (
AzureMLEndpointEmbeddings,

Check failure on line 35 in libs/community/langchain_community/embeddings/__init__.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (F401)

langchain_community/embeddings/__init__.py:35:9: F401 `langchain_community.embeddings.azure_ml_endpoint.AzureMLEndpointEmbeddings` imported but unused; consider removing, adding to `__all__`, or using a redundant alias

Check failure on line 35 in libs/community/langchain_community/embeddings/__init__.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (F401)

langchain_community/embeddings/__init__.py:35:9: F401 `langchain_community.embeddings.azure_ml_endpoint.AzureMLEndpointEmbeddings` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
)
from langchain_community.embeddings.baichuan import (
BaichuanTextEmbeddings,

Check failure on line 38 in libs/community/langchain_community/embeddings/__init__.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (F401)

langchain_community/embeddings/__init__.py:38:9: F401 `langchain_community.embeddings.baichuan.BaichuanTextEmbeddings` imported but unused; consider removing, adding to `__all__`, or using a redundant alias

Check failure on line 38 in libs/community/langchain_community/embeddings/__init__.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (F401)

langchain_community/embeddings/__init__.py:38:9: F401 `langchain_community.embeddings.baichuan.BaichuanTextEmbeddings` imported but unused; consider removing, adding to `__all__`, or using a redundant alias
)
from langchain_community.embeddings.baidu_qianfan_endpoint import (
QianfanEmbeddingsEndpoint,
Expand Down Expand Up @@ -251,6 +254,7 @@
"AscendEmbeddings",
"AwaEmbeddings",
"AzureOpenAIEmbeddings",
"AzureMLEndpointEmbeddings"
"BaichuanTextEmbeddings",
"BedrockEmbeddings",
"BookendEmbeddings",
Expand Down Expand Up @@ -335,6 +339,7 @@
"AnyscaleEmbeddings": "langchain_community.embeddings.anyscale",
"AwaEmbeddings": "langchain_community.embeddings.awa",
"AzureOpenAIEmbeddings": "langchain_community.embeddings.azure_openai",
"AzureMLEndpointEmbeddings": "langchain_community.embeddings.azure_ml_endpoint",
"BaichuanTextEmbeddings": "langchain_community.embeddings.baichuan",
"BedrockEmbeddings": "langchain_community.embeddings.bedrock",
"BookendEmbeddings": "langchain_community.embeddings.bookend",
Expand Down
138 changes: 138 additions & 0 deletions libs/community/langchain_community/embeddings/azure_ml_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import json
import os
from typing import Any, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.utils import from_env
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self
import requests
import asyncio

Check failure on line 10 in libs/community/langchain_community/embeddings/azure_ml_endpoint.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (F401)

langchain_community/embeddings/azure_ml_endpoint.py:10:8: F401 `asyncio` imported but unused

Check failure on line 10 in libs/community/langchain_community/embeddings/azure_ml_endpoint.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (F401)

langchain_community/embeddings/azure_ml_endpoint.py:10:8: F401 `asyncio` imported but unused
import aiohttp

DEFAULT_MODEL = "Cohere-embed-v3-multilingual"

Check failure on line 13 in libs/community/langchain_community/embeddings/azure_ml_endpoint.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (I001)

langchain_community/embeddings/azure_ml_endpoint.py:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 13 in libs/community/langchain_community/embeddings/azure_ml_endpoint.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (I001)

langchain_community/embeddings/azure_ml_endpoint.py:1:1: I001 Import block is un-sorted or un-formatted

class AzureMLEndpointEmbeddings(BaseModel, Embeddings):
"""Azure ML embedding endpoint for embeddings.

To use, set up Azure ML API endpoint and provide the endpoint URL and API key.

Example:
.. code-block:: python

from langchain_community.embeddings import AzureMLEndpointEmbeddings
azure_ml = AzureMLEndpointEmbeddings(
embed_url="Endpoint URL from Azure ML Serverless API",
api_key="your-api-key"
)
"""

client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
embed_url: Optional[str] = None
"""Azure ML endpoint URL to use for embedding."""
api_key: Optional[str] = Field(
default_factory=from_env("AZURE_ML_API_KEY", default=None)
)
"""API Key to use for authentication."""

model_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the embedding API."""

model_config = ConfigDict(
extra="forbid",
protected_namespaces=(),
)

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that api key exists in environment."""
if not self.api_key:
self.api_key = os.getenv("AZURE_ML_API_KEY")

if not self.api_key:
raise ValueError("API Key must be provided or set in the environment.")

if not self.embed_url:
raise ValueError("Azure ML endpoint URL must be provided.")

return self

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call Azure ML embedding endpoint to embed documents.

Args:
texts: The list of texts to embed.

Returns:
List of embeddings, one for each text.
"""
texts = [text.replace("\n", " ") for text in texts]
data = {
"input": texts
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}

response = requests.post(self.embed_url, headers=headers, data=json.dumps(data))

if response.status_code == 200:
response_data = response.json()
embeddings = [item['embedding'] for item in response_data['data']]
return embeddings
else:
raise Exception(f"Error: {response.status_code}, {response.text}")

async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Async Call to Azure ML embedding endpoint to embed documents.

Args:
texts: The list of texts to embed.

Returns:
List of embeddings, one for each text.
"""
texts = [text.replace("\n", " ") for text in texts]
data = {
"input": texts
}
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}

async with aiohttp.ClientSession() as session:
async with session.post(self.embed_url, headers=headers, json=data) as response:

Check failure on line 107 in libs/community/langchain_community/embeddings/azure_ml_endpoint.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (E501)

langchain_community/embeddings/azure_ml_endpoint.py:107:89: E501 Line too long (92 > 88)

Check failure on line 107 in libs/community/langchain_community/embeddings/azure_ml_endpoint.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (E501)

langchain_community/embeddings/azure_ml_endpoint.py:107:89: E501 Line too long (92 > 88)
if response.status == 200:
response_data = await response.json()
embeddings = [item['embedding'] for item in response_data['data']]
return embeddings
else:
response_text = await response.text()
raise Exception(f"Error: {response.status}, {response_text}")

def embed_query(self, text: str) -> List[float]:
"""Call Azure ML embedding endpoint to embed a single query text.

Args:
text: The text to embed.

Returns:
Embeddings for the text.
"""
response = self.embed_documents([text])[0]
return response

async def aembed_query(self, text: str) -> List[float]:
"""Async Call to Azure ML embedding endpoint to embed a single query text.

Args:
text: The text to embed.

Returns:
Embeddings for the text.
"""
response = (await self.aembed_documents([text]))[0]
return response
Loading