Skip to content

Commit

Permalink
Init embeddings (#28370)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Nov 27, 2024
1 parent ffe7bd4 commit 585da22
Show file tree
Hide file tree
Showing 6 changed files with 380 additions and 2 deletions.
2 changes: 2 additions & 0 deletions libs/langchain/langchain/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any

from langchain._api import create_importer
from langchain.embeddings.base import init_embeddings
from langchain.embeddings.cache import CacheBackedEmbeddings

if TYPE_CHECKING:
Expand Down Expand Up @@ -221,4 +222,5 @@ def __getattr__(name: str) -> Any:
"VertexAIEmbeddings",
"VoyageEmbeddings",
"XinferenceEmbeddings",
"init_embeddings",
]
224 changes: 222 additions & 2 deletions libs/langchain/langchain/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,224 @@
import functools
from importlib import util
from typing import Any, List, Optional, Tuple, Union

from langchain_core._api import beta
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import Runnable

_SUPPORTED_PROVIDERS = {
"azure_openai": "langchain_openai",
"bedrock": "langchain_aws",
"cohere": "langchain_cohere",
"google_vertexai": "langchain_google_vertexai",
"huggingface": "langchain_huggingface",
"mistralai": "langchain_mistralai",
"openai": "langchain_openai",
}


def _get_provider_list() -> str:
"""Get formatted list of providers and their packages."""
return "\n".join(
f" - {p}: {pkg.replace('_', '-')}" for p, pkg in _SUPPORTED_PROVIDERS.items()
)


def _parse_model_string(model_name: str) -> Tuple[str, str]:
"""Parse a model string into provider and model name components.
The model string should be in the format 'provider:model-name', where provider
is one of the supported providers.
Args:
model_name: A model string in the format 'provider:model-name'
Returns:
A tuple of (provider, model_name)
.. code-block:: python
_parse_model_string("openai:text-embedding-3-small")
# Returns: ("openai", "text-embedding-3-small")
_parse_model_string("bedrock:amazon.titan-embed-text-v1")
# Returns: ("bedrock", "amazon.titan-embed-text-v1")
Raises:
ValueError: If the model string is not in the correct format or
the provider is unsupported
"""
if ":" not in model_name:
providers = _SUPPORTED_PROVIDERS
raise ValueError(
f"Invalid model format '{model_name}'.\n"
f"Model name must be in format 'provider:model-name'\n"
f"Example valid model strings:\n"
f" - openai:text-embedding-3-small\n"
f" - bedrock:amazon.titan-embed-text-v1\n"
f" - cohere:embed-english-v3.0\n"
f"Supported providers: {providers}"
)

provider, model = model_name.split(":", 1)
provider = provider.lower().strip()
model = model.strip()

if provider not in _SUPPORTED_PROVIDERS:
raise ValueError(
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)
if not model:
raise ValueError("Model name cannot be empty")
return provider, model


def _infer_model_and_provider(
model: str, *, provider: Optional[str] = None
) -> Tuple[str, str]:
if not model.strip():
raise ValueError("Model name cannot be empty")
if provider is None and ":" in model:
provider, model_name = _parse_model_string(model)
else:
provider = provider
model_name = model

if not provider:
providers = _SUPPORTED_PROVIDERS
raise ValueError(
"Must specify either:\n"
"1. A model string in format 'provider:model-name'\n"
" Example: 'openai:text-embedding-3-small'\n"
"2. Or explicitly set provider from: "
f"{providers}"
)

if provider not in _SUPPORTED_PROVIDERS:
raise ValueError(
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)
return provider, model_name


@functools.lru_cache(maxsize=len(_SUPPORTED_PROVIDERS))
def _check_pkg(pkg: str) -> None:
"""Check if a package is installed."""
if not util.find_spec(pkg):
raise ImportError(
f"Could not import {pkg} python package. "
f"Please install it with `pip install {pkg}`"
)


@beta()
def init_embeddings(
model: str,
*,
provider: Optional[str] = None,
**kwargs: Any,
) -> Union[Embeddings, Runnable[Any, List[float]]]:
"""Initialize an embeddings model from a model name and optional provider.
**Note:** Must have the integration package corresponding to the model provider
installed.
Args:
model: Name of the model to use. Can be either:
- A model string like "openai:text-embedding-3-small"
- Just the model name if provider is specified
provider: Optional explicit provider name. If not specified,
will attempt to parse from the model string. Supported providers
and their required packages:
{_get_provider_list()}
**kwargs: Additional model-specific parameters passed to the embedding model.
These vary by provider, see the provider-specific documentation for details.
Returns:
An Embeddings instance that can generate embeddings for text.
Raises:
ValueError: If the model provider is not supported or cannot be determined
ImportError: If the required provider package is not installed
.. dropdown:: Example Usage
:open:
.. code-block:: python
# Using a model string
model = init_embeddings("openai:text-embedding-3-small")
model.embed_query("Hello, world!")
# Using explicit provider
model = init_embeddings(
model="text-embedding-3-small",
provider="openai"
)
model.embed_documents(["Hello, world!", "Goodbye, world!"])
# With additional parameters
model = init_embeddings(
"openai:text-embedding-3-small",
api_key="sk-..."
)
.. versionadded:: 0.3.9
"""
if not model:
providers = _SUPPORTED_PROVIDERS.keys()
raise ValueError(
"Must specify model name. "
f"Supported providers are: {', '.join(providers)}"
)

provider, model_name = _infer_model_and_provider(model, provider=provider)
pkg = _SUPPORTED_PROVIDERS[provider]
_check_pkg(pkg)

if provider == "openai":
from langchain_openai import OpenAIEmbeddings

return OpenAIEmbeddings(model=model_name, **kwargs)
elif provider == "azure_openai":
from langchain_openai import AzureOpenAIEmbeddings

return AzureOpenAIEmbeddings(model=model_name, **kwargs)
elif provider == "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings

return VertexAIEmbeddings(model=model_name, **kwargs)
elif provider == "bedrock":
from langchain_aws import BedrockEmbeddings

return BedrockEmbeddings(model_id=model_name, **kwargs)
elif provider == "cohere":
from langchain_cohere import CohereEmbeddings

return CohereEmbeddings(model=model_name, **kwargs)
elif provider == "mistralai":
from langchain_mistralai import MistralAIEmbeddings

return MistralAIEmbeddings(model=model_name, **kwargs)
elif provider == "huggingface":
from langchain_huggingface import HuggingFaceEmbeddings

return HuggingFaceEmbeddings(model_name=model_name, **kwargs)
else:
raise ValueError(
f"Provider '{provider}' is not supported.\n"
f"Supported providers and their required packages:\n"
f"{_get_provider_list()}"
)


# This is for backwards compatibility
__all__ = ["Embeddings"]
__all__ = [
"init_embeddings",
"Embeddings", # This one is for backwards compatibility
]
Empty file.
44 changes: 44 additions & 0 deletions libs/langchain/tests/integration_tests/embeddings/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Test embeddings base module."""

import importlib

import pytest
from langchain_core.embeddings import Embeddings

from langchain.embeddings.base import _SUPPORTED_PROVIDERS, init_embeddings


@pytest.mark.parametrize(
"provider, model",
[
("openai", "text-embedding-3-large"),
("google_vertexai", "text-embedding-gecko@003"),
("bedrock", "amazon.titan-embed-text-v1"),
("cohere", "embed-english-v2.0"),
],
)
async def test_init_embedding_model(provider: str, model: str) -> None:
package = _SUPPORTED_PROVIDERS[provider]
try:
importlib.import_module(package)
except ImportError:
pytest.skip(f"Package {package} is not installed")

model_colon = init_embeddings(f"{provider}:{model}")
assert isinstance(model_colon, Embeddings)

model_explicit = init_embeddings(
model=model,
provider=provider,
)
assert isinstance(model_explicit, Embeddings)

text = "Hello world"

embedding_colon = await model_colon.aembed_query(text)
assert isinstance(embedding_colon, list)
assert all(isinstance(x, float) for x in embedding_colon)

embedding_explicit = await model_explicit.aembed_query(text)
assert isinstance(embedding_explicit, list)
assert all(isinstance(x, float) for x in embedding_explicit)
111 changes: 111 additions & 0 deletions libs/langchain/tests/unit_tests/embeddings/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Test embeddings base module."""

import pytest

from langchain.embeddings.base import (
_SUPPORTED_PROVIDERS,
_infer_model_and_provider,
_parse_model_string,
)


def test_parse_model_string() -> None:
"""Test parsing model strings into provider and model components."""
assert _parse_model_string("openai:text-embedding-3-small") == (
"openai",
"text-embedding-3-small",
)
assert _parse_model_string("bedrock:amazon.titan-embed-text-v1") == (
"bedrock",
"amazon.titan-embed-text-v1",
)
assert _parse_model_string("huggingface:BAAI/bge-base-en:v1.5") == (
"huggingface",
"BAAI/bge-base-en:v1.5",
)


def test_parse_model_string_errors() -> None:
"""Test error cases for model string parsing."""
with pytest.raises(ValueError, match="Model name must be"):
_parse_model_string("just-a-model-name")

with pytest.raises(ValueError, match="Invalid model format "):
_parse_model_string("")

with pytest.raises(ValueError, match="is not supported"):
_parse_model_string(":model-name")

with pytest.raises(ValueError, match="Model name cannot be empty"):
_parse_model_string("openai:")

with pytest.raises(
ValueError, match="Provider 'invalid-provider' is not supported"
):
_parse_model_string("invalid-provider:model-name")

for provider in _SUPPORTED_PROVIDERS:
with pytest.raises(ValueError, match=f"{provider}"):
_parse_model_string("invalid-provider:model-name")


def test_infer_model_and_provider() -> None:
"""Test model and provider inference from different input formats."""
assert _infer_model_and_provider("openai:text-embedding-3-small") == (
"openai",
"text-embedding-3-small",
)

assert _infer_model_and_provider(
model="text-embedding-3-small", provider="openai"
) == ("openai", "text-embedding-3-small")

assert _infer_model_and_provider(
model="ft:text-embedding-3-small", provider="openai"
) == ("openai", "ft:text-embedding-3-small")

assert _infer_model_and_provider(model="openai:ft:text-embedding-3-small") == (
"openai",
"ft:text-embedding-3-small",
)


def test_infer_model_and_provider_errors() -> None:
"""Test error cases for model and provider inference."""
# Test missing provider
with pytest.raises(ValueError, match="Must specify either"):
_infer_model_and_provider("text-embedding-3-small")

# Test empty model
with pytest.raises(ValueError, match="Model name cannot be empty"):
_infer_model_and_provider("")

# Test empty provider with model
with pytest.raises(ValueError, match="Must specify either"):
_infer_model_and_provider("model", provider="")

# Test invalid provider
with pytest.raises(ValueError, match="is not supported"):
_infer_model_and_provider("model", provider="invalid")

# Test provider list is in error
with pytest.raises(ValueError) as exc:
_infer_model_and_provider("model", provider="invalid")
for provider in _SUPPORTED_PROVIDERS:
assert provider in str(exc.value)


@pytest.mark.parametrize(
"provider",
sorted(_SUPPORTED_PROVIDERS.keys()),
)
def test_supported_providers_package_names(provider: str) -> None:
"""Test that all supported providers have valid package names."""
package = _SUPPORTED_PROVIDERS[provider]
assert "-" not in package
assert package.startswith("langchain_")
assert package.islower()


def test_is_sorted() -> None:
assert list(_SUPPORTED_PROVIDERS) == sorted(_SUPPORTED_PROVIDERS.keys())
Loading

0 comments on commit 585da22

Please sign in to comment.