Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat add volcano embedding #14693

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions docs/docs/integrations/text_embedding/bytedance_volcano.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
{
"cells": [
{
"cell_type": "raw",
"source": [
"# Bytedance Volcano\n",
"\n",
"This notebook provides you with a guide on how to load the Volcano Embedding class.\n",
"\n",
"\n",
"## API Initialization\n",
"\n",
"To use the LLM services based on [Bytedance Volcano](https://www.volcengine.com/docs/82379/1099455), you have to initialize these parameters:\n",
"\n",
"You can either set the AK and SK as environment variables or pass them as init params:\n",
"\n",
"```bash\n",
"export VOLC_ACCESSKEY=XXX\n",
"export VOLC_SECRETKEY=XXX\n",
"```"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2023-12-14T03:05:29.857798Z"
}
},
"outputs": [],
"source": [
"\"\"\"For basic init and call\"\"\"\n",
"import os\n",
"\n",
"from langchain.embeddings import VolcanoEmbeddings\n",
"\n",
"os.environ[\"VOLC_ACCESSKEY\"] = \"\"\n",
"os.environ[\"VOLC_SECRETKEY\"] = \"==\"\n",
"\n",
"embed = VolcanoEmbeddings(\n",
" # volcano_ak='xxx',\n",
" # volcano_sk='xxx'\n",
")\n",
"\n",
"print(\"embed_documents result:\")\n",
"res1 = embed.embed_documents([\"foo\", \"bar\"])\n",
"for r in res1:\n",
" print(\"\", r[:8])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2023-12-14T03:05:29.859276Z"
}
},
"outputs": [],
"source": [
"print(\"embed_query result:\")\n",
"res2 = embed.embed_query(\"foo\")\n",
"print(\"\", res2[:8])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2023-12-14T03:05:29.860282Z"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
},
"vscode": {
"interpreter": {
"hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
2 changes: 2 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from langchain_community.embeddings.bedrock import BedrockEmbeddings
from langchain_community.embeddings.bookend import BookendEmbeddings
from langchain_community.embeddings.bytedance_volcano import VolcanoEmbeddings
from langchain_community.embeddings.clarifai import ClarifaiEmbeddings
from langchain_community.embeddings.cohere import CohereEmbeddings
from langchain_community.embeddings.dashscope import DashScopeEmbeddings
Expand Down Expand Up @@ -136,6 +137,7 @@
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"VolcanoEmbeddings",
]


Expand Down
128 changes: 128 additions & 0 deletions libs/community/langchain_community/embeddings/bytedance_volcano.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class VolcanoEmbeddings(BaseModel, Embeddings):
    """`Bytedance Volcano Embeddings` embedding models.

    Credentials are taken from the ``volcano_ak`` / ``volcano_sk`` init params
    or, when those are not given, from the ``VOLC_ACCESSKEY`` /
    ``VOLC_SECRETKEY`` environment variables.

    Example:
        .. code-block:: python

            from langchain_community.embeddings import VolcanoEmbeddings

            embed = VolcanoEmbeddings(volcano_ak="...", volcano_sk="...")
            vectors = embed.embed_documents(["foo", "bar"])
    """

    volcano_ak: Optional[str] = None
    """volcano access key
    learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""

    volcano_sk: Optional[str] = None
    """volcano secret key
    learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""

    host: str = "maas-api.ml-platform-cn-beijing.volces.com"
    """host
    learn more from https://www.volcengine.com/docs/82379/1174746"""
    region: str = "cn-beijing"
    """region
    learn more from https://www.volcengine.com/docs/82379/1174746"""

    model: str = "bge-large-zh"
    """Model name
    you could get from https://www.volcengine.com/docs/82379/1174746
    for now, we support bge-large-zh
    """

    version: str = "1.0"
    """ model version """

    chunk_size: int = 100
    """Chunk size when multiple texts are input"""

    client: Any
    """volcano client"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the credentials and build the volcano client.

        ``volcano_ak`` and ``volcano_sk`` are looked up with
        ``get_from_dict_or_env``: first in ``values``, then in the
        ``VOLC_ACCESSKEY`` / ``VOLC_SECRETKEY`` environment variables.
        A ``MaasService`` client is then created from ``host`` and ``region``,
        given the access key and secret key, and stored under ``client``.

        Args:
            values: the field values of the model being validated.

        Returns:
            ``values`` with ``volcano_ak``, ``volcano_sk`` and ``client`` set.

        Raises:
            ImportError: volcengine package not found, please install it with
                `pip install volcengine`.

        NOTE(review): behaviour when a credential is missing from both the
        init params and the environment is decided by ``get_from_dict_or_env``
        (expected to raise) - confirm against its documentation.
        """
        values["volcano_ak"] = get_from_dict_or_env(
            values,
            "volcano_ak",
            "VOLC_ACCESSKEY",
        )
        values["volcano_sk"] = get_from_dict_or_env(
            values,
            "volcano_sk",
            "VOLC_SECRETKEY",
        )

        # Guard only the optional import, so that errors raised while
        # building the client are not mistaken for a missing package.
        try:
            from volcengine.maas import MaasService
        except ImportError as err:
            raise ImportError(
                "volcengine package not found, please install it with "
                "`pip install volcengine`"
            ) from err

        client = MaasService(values["host"], values["region"])
        client.set_ak(values["volcano_ak"])
        client.set_sk(values["volcano_sk"])
        values["client"] = client
        return values

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text.

        Args:
            text: The text to embed.

        Returns:
            The embedding of ``text``, i.e. the first (and only) result of
            ``embed_documents([text])``.
        """
        return self.embed_documents([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of text documents with the configured volcano model.

        The texts are sent to the client in consecutive chunks of at most
        ``chunk_size`` items; the embeddings of all chunks are concatenated
        in request order.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: A list of embeddings for each document in the
            input list. Each embedding is represented as a list of float
            values.

        Raises:
            ValueError: if the client raises a ``MaasException``; the
                original exception is attached as ``__cause__``.
        """
        # Imported once, outside the loop and outside the ``try``, so that
        # the ``except`` clause below can always resolve the name.
        from volcengine.maas import MaasException

        embeddings: List[List[float]] = []
        for start in range(0, len(texts), self.chunk_size):
            req = {
                "model": {
                    "name": self.model,
                    "version": self.version,
                },
                "input": texts[start : start + self.chunk_size],
            }
            try:
                resp = self.client.embeddings(req)
            except MaasException as err:
                raise ValueError(f"Error: {err!r}") from err
            embeddings.extend(res["embedding"] for res in resp["data"])
        return embeddings
19 changes: 19 additions & 0 deletions libs/community/tests/integration_tests/embeddings/test_volcano.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Test Bytedance Volcano Embedding."""
from langchain_community.embeddings import VolcanoEmbeddings


def test_modelscope_embedding_documents() -> None:
    """Check that two documents yield two 1024-dimensional Volcano embeddings."""
    vectors = VolcanoEmbeddings().embed_documents(["foo", "bar"])
    assert len(vectors) == 2
    assert len(vectors[0]) == 1024


def test_modelscope_embedding_query() -> None:
    """Check that a single query yields one 1024-dimensional Volcano embedding."""
    vector = VolcanoEmbeddings().embed_query("foo bar")
    assert len(vector) == 1024
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"VolcanoEmbeddings",
]


Expand Down
2 changes: 2 additions & 0 deletions libs/langchain/langchain/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint
from langchain.embeddings.bedrock import BedrockEmbeddings
from langchain.embeddings.bookend import BookendEmbeddings
from langchain.embeddings.bytedance_volcano import VolcanoEmbeddings
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to do anything in langchain. existing imports are kept for backwards compatibility.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

from langchain.embeddings.cache import CacheBackedEmbeddings
from langchain.embeddings.clarifai import ClarifaiEmbeddings
from langchain.embeddings.cohere import CohereEmbeddings
Expand Down Expand Up @@ -129,6 +130,7 @@
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"VolcanoEmbeddings",
]


Expand Down
5 changes: 5 additions & 0 deletions libs/langchain/langchain/embeddings/bytedance_volcano.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Re-export of the community implementation under the ``langchain`` namespace.
from langchain_community.embeddings.bytedance_volcano import VolcanoEmbeddings

__all__ = [
    "VolcanoEmbeddings",
]
1 change: 1 addition & 0 deletions libs/langchain/tests/unit_tests/embeddings/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"VolcanoEmbeddings",
]


Expand Down
Loading