Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat add volcano embedding #14693

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions docs/docs/integrations/text_embedding/volcengine.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": false
},
"source": [
"# Volc Engine\n",
"\n",
"This notebook provides you with a guide on how to load the Volcano Embedding class.\n",
"\n",
"\n",
"## API Initialization\n",
"\n",
"To use the LLM services based on [VolcEngine](https://www.volcengine.com/docs/82379/1099455), you have to initialize these parameters:\n",
"\n",
"You could either choose to init the AK,SK in environment variables or init params:\n",
"\n",
"```base\n",
"export VOLC_ACCESSKEY=XXX\n",
"export VOLC_SECRETKEY=XXX\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"start_time": "2023-12-14T03:05:29.857798Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"embed_documents result:\n",
" [0.02929673343896866, -0.009310632012784481, -0.060323506593704224, 0.0031018739100545645, -0.002218986628577113, -0.0023125179577618837, -0.04864659160375595, -2.062115163425915e-05]\n",
" [0.01987231895327568, -0.026041055098176003, -0.08395249396562576, 0.020043574273586273, -0.028862033039331436, 0.004629664588719606, -0.023107370361685753, -0.0342753604054451]\n"
]
}
],
"source": [
"\"\"\"For basic init and call\"\"\"\n",
"import os\n",
"\n",
"from langchain_community.embeddings import VolcanoEmbeddings\n",
"\n",
"os.environ[\"VOLC_ACCESSKEY\"] = \"\"\n",
"os.environ[\"VOLC_SECRETKEY\"] = \"\"\n",
"\n",
"embed = VolcanoEmbeddings(volcano_ak=\"\", volcano_sk=\"\")\n",
"print(\"embed_documents result:\")\n",
"res1 = embed.embed_documents([\"foo\", \"bar\"])\n",
"for r in res1:\n",
" print(\"\", r[:8])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"start_time": "2023-12-14T03:05:29.859276Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"embed_query result:\n",
" [0.01987231895327568, -0.026041055098176003, -0.08395249396562576, 0.020043574273586273, -0.028862033039331436, 0.004629664588719606, -0.023107370361685753, -0.0342753604054451]\n"
]
}
],
"source": [
"print(\"embed_query result:\")\n",
"res2 = embed.embed_query(\"foo\")\n",
"print(\"\", r[:8])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"ExecuteTime": {
"start_time": "2023-12-14T03:05:29.860282Z"
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
},
"vscode": {
"interpreter": {
"hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
2 changes: 2 additions & 0 deletions libs/community/langchain_community/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from langchain_community.embeddings.spacy_embeddings import SpacyEmbeddings
from langchain_community.embeddings.tensorflow_hub import TensorflowHubEmbeddings
from langchain_community.embeddings.vertexai import VertexAIEmbeddings
from langchain_community.embeddings.volcengine import VolcanoEmbeddings
from langchain_community.embeddings.voyageai import VoyageEmbeddings
from langchain_community.embeddings.xinference import XinferenceEmbeddings

Expand Down Expand Up @@ -136,6 +137,7 @@
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"VolcanoEmbeddings",
]


Expand Down
128 changes: 128 additions & 0 deletions libs/community/langchain_community/embeddings/volcengine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class VolcanoEmbeddings(BaseModel, Embeddings):
"""`Volcengine Embeddings` embedding models."""

volcano_ak: Optional[str] = None
"""volcano access key
learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""

volcano_sk: Optional[str] = None
"""volcano secret key
learn more from: https://www.volcengine.com/docs/6459/76491#ak-sk"""

host: str = "maas-api.ml-platform-cn-beijing.volces.com"
"""host
learn more from https://www.volcengine.com/docs/82379/1174746"""
region: str = "cn-beijing"
"""region
learn more from https://www.volcengine.com/docs/82379/1174746"""

model: str = "bge-large-zh"
"""Model name
you could get from https://www.volcengine.com/docs/82379/1174746
for now, we support bge_large_zh
"""

version: str = "1.0"
""" model version """

chunk_size: int = 100
"""Chunk size when multiple texts are input"""

client: Any
"""volcano client"""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""
Validate whether volcano_ak and volcano_sk in the environment variables or
configuration file are available or not.

init volcano embedding client with `ak`, `sk`, `host`, `region`

Args:

values: a dictionary containing configuration information, must include the
fields of volcano_ak and volcano_sk
Returns:

a dictionary containing configuration information. If volcano_ak and
volcano_sk are not provided in the environment variables or configuration
file,the original values will be returned; otherwise, values containing
volcano_ak and volcano_sk will be returned.
Raises:

ValueError: volcengine package not found, please install it with
`pip install volcengine`
"""
values["volcano_ak"] = get_from_dict_or_env(
values,
"volcano_ak",
"VOLC_ACCESSKEY",
)
values["volcano_sk"] = get_from_dict_or_env(
values,
"volcano_sk",
"VOLC_SECRETKEY",
)

try:
from volcengine.maas import MaasService

client = MaasService(values["host"], values["region"])
client.set_ak(values["volcano_ak"])
client.set_sk(values["volcano_sk"])
values["client"] = client
except ImportError:
raise ImportError(
"volcengine package not found, please install it with "
"`pip install volcengine`"
)
return values

def embed_query(self, text: str) -> List[float]:
return self.embed_documents([text])[0]

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
Embeds a list of text documents using the AutoVOT algorithm.

Args:
texts (List[str]): A list of text documents to embed.

Returns:
List[List[float]]: A list of embeddings for each document in the input list.
Each embedding is represented as a list of float values.
"""
text_in_chunks = [
texts[i : i + self.chunk_size]
for i in range(0, len(texts), self.chunk_size)
]
lst = []
for chunk in text_in_chunks:
req = {
"model": {
"name": self.model,
"version": self.version,
},
"input": chunk,
}
try:
from volcengine.maas import MaasException

resp = self.client.embeddings(req)
lst.extend([res["embedding"] for res in resp["data"]])
except MaasException as e:
raise ValueError(f"embed by volcengine Error: {e}")
return lst
19 changes: 19 additions & 0 deletions libs/community/tests/integration_tests/embeddings/test_volcano.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""Test Bytedance Volcano Embedding."""
from langchain_community.embeddings import VolcanoEmbeddings


def test_embedding_documents() -> None:
"""Test embeddings for documents."""
documents = ["foo", "bar"]
embedding = VolcanoEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 2
assert len(output[0]) == 1024


def test_embedding_query() -> None:
"""Test embeddings for query."""
document = "foo bar"
embedding = VolcanoEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 1024
1 change: 1 addition & 0 deletions libs/community/tests/unit_tests/embeddings/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"JohnSnowLabsEmbeddings",
"VoyageEmbeddings",
"BookendEmbeddings",
"VolcanoEmbeddings",
]


Expand Down
Loading