From 167a8369dd61f6592e6c7321fab5a9f801329ea8 Mon Sep 17 00:00:00 2001 From: Wendong-Fan <133094783+Wendong-Fan@users.noreply.github.com> Date: Fri, 24 Jan 2025 23:24:51 +0800 Subject: [PATCH] feat: Integrate jina embedding (#1500) --- .env | 3 + .github/workflows/build_package.yml | 1 + .github/workflows/pytest_package.yml | 3 + camel/embeddings/__init__.py | 2 + camel/embeddings/jina_embedding.py | 156 ++++++++++++++++++ camel/types/enums.py | 30 +++- examples/embeddings/jina_embedding_example.py | 96 +++++++++++ test/embeddings/test_jina_embedding.py | 142 ++++++++++++++++ 8 files changed, 432 insertions(+), 1 deletion(-) create mode 100644 camel/embeddings/jina_embedding.py create mode 100644 examples/embeddings/jina_embedding_example.py create mode 100644 test/embeddings/test_jina_embedding.py diff --git a/.env b/.env index 6b84096d88..0600f02192 100644 --- a/.env +++ b/.env @@ -53,6 +53,9 @@ # InternLM API (https://internlm.intern-ai.org.cn/api/tokens) # INTERNLM_API_KEY="Fill your API key here" +# JINA API (https://jina.ai/) +# JINA_API_KEY="Fill your API key here" + #=========================================== # Tools & Services API #=========================================== diff --git a/.github/workflows/build_package.yml b/.github/workflows/build_package.yml index e062074a64..4c61f07ec5 100644 --- a/.github/workflows/build_package.yml +++ b/.github/workflows/build_package.yml @@ -79,6 +79,7 @@ jobs: DAPPIER_API_KEY: "${{ secrets.DAPPIER_API_KEY }}" DISCORD_BOT_TOKEN: "${{ secrets.DISCORD_BOT_TOKEN }}" INTERNLM_API_KEY: "${{ secrets.INTERNLM_API_KEY }}" + JINA_API_KEY: "${{ secrets.JINA_API_KEY }}" run: | source venv/bin/activate pytest --fast-test-mode ./test diff --git a/.github/workflows/pytest_package.yml b/.github/workflows/pytest_package.yml index 4dd092f659..70c0f01151 100644 --- a/.github/workflows/pytest_package.yml +++ b/.github/workflows/pytest_package.yml @@ -58,6 +58,7 @@ jobs: DAPPIER_API_KEY: "${{ secrets.DAPPIER_API_KEY }}" 
DISCORD_BOT_TOKEN: "${{ secrets.DISCORD_BOT_TOKEN }}" INTERNLM_API_KEY: "${{ secrets.INTERNLM_API_KEY }}" + JINA_API_KEY: "${{ secrets.JINA_API_KEY }}" run: poetry run pytest --fast-test-mode test/ pytest_package_llm_test: @@ -105,6 +106,7 @@ jobs: DAPPIER_API_KEY: "${{ secrets.DAPPIER_API_KEY }}" DISCORD_BOT_TOKEN: "${{ secrets.DISCORD_BOT_TOKEN }}" INTERNLM_API_KEY: "${{ secrets.INTERNLM_API_KEY }}" + JINA_API_KEY: "${{ secrets.JINA_API_KEY }}" run: poetry run pytest --llm-test-only test/ pytest_package_very_slow_test: @@ -152,4 +154,5 @@ jobs: DAPPIER_API_KEY: "${{ secrets.DAPPIER_API_KEY }}" DISCORD_BOT_TOKEN: "${{ secrets.DISCORD_BOT_TOKEN }}" INTERNLM_API_KEY: "${{ secrets.INTERNLM_API_KEY }}" + JINA_API_KEY: "${{ secrets.JINA_API_KEY }}" run: poetry run pytest --very-slow-test-only test/ diff --git a/camel/embeddings/__init__.py b/camel/embeddings/__init__.py index e61e2768a8..a40d260758 100644 --- a/camel/embeddings/__init__.py +++ b/camel/embeddings/__init__.py @@ -12,6 +12,7 @@ # limitations under the License. # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= from .base import BaseEmbedding +from .jina_embedding import JinaEmbedding from .mistral_embedding import MistralEmbedding from .openai_compatible_embedding import OpenAICompatibleEmbedding from .openai_embedding import OpenAIEmbedding @@ -25,4 +26,5 @@ "VisionLanguageEmbedding", "MistralEmbedding", "OpenAICompatibleEmbedding", + "JinaEmbedding", ] diff --git a/camel/embeddings/jina_embedding.py b/camel/embeddings/jina_embedding.py new file mode 100644 index 0000000000..eca4473dea --- /dev/null +++ b/camel/embeddings/jina_embedding.py @@ -0,0 +1,156 @@ +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+import base64
+import io
+import os
+from typing import Any, Optional, Union
+
+import requests
+from PIL import Image
+
+from camel.embeddings import BaseEmbedding
+from camel.types.enums import EmbeddingModelType
+from camel.utils import api_keys_required
+
+
+class JinaEmbedding(BaseEmbedding[Union[str, Image.Image]]):
+    r"""Provides text and image embedding functionalities using Jina AI's API.
+
+    Args:
+        model_type (EmbeddingModelType, optional): The model to use for
+            embeddings. (default: :obj:`JINA_EMBEDDINGS_V3`)
+        api_key (Optional[str], optional): The API key for authenticating with
+            Jina AI. (default: :obj:`None`)
+        dimensions (Optional[int], optional): The dimension of the output
+            embeddings, sent to the API when set. (default: :obj:`None`)
+        embedding_type (Optional[str], optional): The ``embedding_type`` sent
+            to the API; ``float`` is sent when unset. (default: :obj:`None`)
+        task (Optional[str], optional): The type of task for text embeddings.
+            Options: retrieval.query, retrieval.passage, text-matching,
+            classification, separation. (default: :obj:`None`)
+        late_chunking (bool, optional): If true, concatenates all sentences in
+            input and treats as a single input. (default: :obj:`False`)
+        normalized (bool, optional): If true, embeddings are normalized to unit
+            L2 norm. (default: :obj:`False`)
+    """
+
+    @api_keys_required([("api_key", 'JINA_API_KEY')])
+    def __init__(
+        self,
+        model_type: EmbeddingModelType = EmbeddingModelType.JINA_EMBEDDINGS_V3,
+        api_key: Optional[str] = None,
+        dimensions: Optional[int] = None,
+        embedding_type: Optional[str] = None,
+        task: Optional[str] = None,
+        late_chunking: bool = False,
+        normalized: bool = False,
+    ) -> None:
+        if not model_type.is_jina:
+            raise ValueError(
+                f"Model type {model_type} is not a Jina model. "
+                "Please use a valid Jina model type."
+            )
+        self.model_type = model_type
+        self.dimensions = dimensions
+        self.output_dim = (
+            dimensions if dimensions is not None else model_type.output_dim
+        )
+        self._api_key = api_key or os.environ.get("JINA_API_KEY")
+
+        self.embedding_type = embedding_type
+        self.task = task
+        self.late_chunking = late_chunking
+        self.normalized = normalized
+        self.url = 'https://api.jina.ai/v1/embeddings'
+        self.headers = {
+            'Content-Type': 'application/json',
+            'Accept': 'application/json',
+            'Authorization': f'Bearer {self._api_key}',
+        }
+
+    def embed_list(
+        self,
+        objs: list[Union[str, Image.Image]],
+        **kwargs: Any,
+    ) -> list[list[float]]:
+        r"""Generates embeddings for the given texts or images.
+
+        Args:
+            objs (list[Union[str, Image.Image]]): The texts or images for which
+                to generate the embeddings.
+            **kwargs (Any): Extra keyword arguments. Accepted but not used
+                in this implementation.
+
+        Returns:
+            list[list[float]]: The embeddings returned by the API, each given
+                as a list of floating-point numbers.
+
+        Raises:
+            ValueError: If the input type is not supported.
+            RuntimeError: If the API request fails.
+        """
+        is_clip = self.model_type == EmbeddingModelType.JINA_CLIP_V2
+        input_data: list[Any] = []
+        for obj in objs:
+            if isinstance(obj, str):
+                input_data.append({"text": obj} if is_clip else obj)
+            elif isinstance(obj, Image.Image):
+                if not is_clip:
+                    raise ValueError(
+                        f"Model {self.model_type} does not support "
+                        "image input. Use JINA_CLIP_V2 for image embeddings."
+                    )
+                # Convert PIL Image to base64 string
+                buffered = io.BytesIO()
+                obj.save(buffered, format="PNG")
+                img_str = base64.b64encode(buffered.getvalue()).decode()
+                input_data.append({"image": img_str})
+            else:
+                raise ValueError(
+                    f"Input type {type(obj)} is not supported. "
+                    "Must be either str or PIL.Image"
+                )
+
+        data: dict[str, Any] = {
+            "model": self.model_type.value,
+            "input": input_data,
+            "embedding_type": "float",
+        }
+        # Send optional settings only when set; `or None` drops False flags.
+        optional = {
+            "embedding_type": self.embedding_type,
+            "dimensions": self.dimensions,
+            "task": self.task,
+            "late_chunking": self.late_chunking or None,
+            "normalized": self.normalized or None,
+        }
+        data.update({k: v for k, v in optional.items() if v is not None})
+        try:
+            response = requests.post(
+                self.url, headers=self.headers, json=data, timeout=180
+            )
+            response.raise_for_status()
+            return [item["embedding"] for item in response.json()["data"]]
+        except requests.exceptions.RequestException as e:
+            msg = f"Failed to get embeddings from Jina AI: {e}"
+            raise RuntimeError(msg) from e
+
+    def get_output_dim(self) -> int:
+        r"""Returns the output dimension of the embeddings.
+
+        Returns:
+            int: The dimensionality of the embedding for the current model.
+        """
+        return self.output_dim
diff --git a/camel/types/enums.py b/camel/types/enums.py
index 5622dece99..2b32f77cf5 100644
--- a/camel/types/enums.py
+++ b/camel/types/enums.py
@@ -567,6 +567,11 @@ class EmbeddingModelType(Enum):
     TEXT_EMBEDDING_3_SMALL = "text-embedding-3-small"
     TEXT_EMBEDDING_3_LARGE = "text-embedding-3-large"
 
+    JINA_EMBEDDINGS_V3 = "jina-embeddings-v3"
+    JINA_CLIP_V2 = "jina-clip-v2"
+    JINA_COLBERT_V2 = "jina-colbert-v2"
+    JINA_EMBEDDINGS_V2_BASE_CODE = "jina-embeddings-v2-base-code"
+
     MISTRAL_EMBED = "mistral-embed"
 
     @property
@@ -578,6 +583,16 @@ def is_openai(self) -> bool:
             EmbeddingModelType.TEXT_EMBEDDING_3_LARGE,
         }
 
+    @property
+    def is_jina(self) -> bool:
+        r"""Returns whether this type of models is a Jina model."""
+        return self in {
+            EmbeddingModelType.JINA_EMBEDDINGS_V3,
+            EmbeddingModelType.JINA_CLIP_V2,
+            EmbeddingModelType.JINA_COLBERT_V2,
+            EmbeddingModelType.JINA_EMBEDDINGS_V2_BASE_CODE,
+        }
+
     @property
     def is_mistral(self) -> bool:
         r"""Returns whether this type of models is an Mistral-released
@@ -589,7 +604,20 @@ def is_mistral(self) -> bool:
 
     @property
     def output_dim(self) -> int:
-        if self is EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
+        if self in {
+            EmbeddingModelType.JINA_COLBERT_V2,
+        }:
+            return 128
+        elif self in {
+            EmbeddingModelType.JINA_EMBEDDINGS_V2_BASE_CODE,
+        }:
+            return 768
+        elif self in {
+            EmbeddingModelType.JINA_EMBEDDINGS_V3,
+            EmbeddingModelType.JINA_CLIP_V2,
+        }:
+            return 1024
+        elif self is EmbeddingModelType.TEXT_EMBEDDING_ADA_2:
             return 1536
         elif self is EmbeddingModelType.TEXT_EMBEDDING_3_SMALL:
             return 1536
diff --git a/examples/embeddings/jina_embedding_example.py b/examples/embeddings/jina_embedding_example.py
new file mode 100644
index 0000000000..193d46add2
--- /dev/null
+++ b/examples/embeddings/jina_embedding_example.py
@@ -0,0 +1,96 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved.
=========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+import requests
+from PIL import Image
+
+from camel.embeddings import JinaEmbedding
+from camel.types import EmbeddingModelType
+
+# Set the text embedding instance
+jina_text_embed = JinaEmbedding(
+    model_type=EmbeddingModelType.JINA_EMBEDDINGS_V3,
+)
+
+# Embed the text
+text_embeddings = jina_text_embed.embed_list(
+    ["What is the capital of France?"]
+)
+
+print(len(text_embeddings[0]))
+'''
+===============================================================================
+1024
+===============================================================================
+'''
+
+
+# Set the code embedding instance
+jina_code_embed = JinaEmbedding(
+    model_type=EmbeddingModelType.JINA_EMBEDDINGS_V2_BASE_CODE,
+    normalized=True,
+)
+
+# Embed the code
+code_embeddings = jina_code_embed.embed_list(
+    [
+        "Calculates the square of a number. Parameters: number (int or float)"
+        " - The number to square. Returns: int or float - The square of the"
+        " number.",
+        "This function calculates the square of a number you give it.",
+        "def square(number): return number ** 2",
+        "print(square(5))",
+        "Output: 25",
+        "Each text can be up to 8192 tokens long",
+    ]
+)
+
+print(len(code_embeddings[0]))
+'''
+===============================================================================
+768
+===============================================================================
+'''
+
+# Set the clip embedding instance
+jina_clip_embed = JinaEmbedding(
+    model_type=EmbeddingModelType.JINA_CLIP_V2,
+)
+
+# Embed the text
+text_embeddings = jina_clip_embed.embed_list(
+    ["What is the capital of France?"]
+)
+
+# Download the example image; the timeout keeps a stalled server from hanging
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+image_example = Image.open(requests.get(url, stream=True, timeout=30).raw)
+
+# Embed the image
+image_embeddings = jina_clip_embed.embed_list([image_example])
+
+print(len(text_embeddings[0]))
+'''
+===============================================================================
+1024
+===============================================================================
+'''
+
+print(len(image_embeddings[0]))
+
+'''
+===============================================================================
+1024
+===============================================================================
+'''
diff --git a/test/embeddings/test_jina_embedding.py b/test/embeddings/test_jina_embedding.py
new file mode 100644
index 0000000000..11debf3660
--- /dev/null
+++ b/test/embeddings/test_jina_embedding.py
@@ -0,0 +1,142 @@
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
+
+import os
+from unittest.mock import MagicMock, patch
+
+import pytest
+import requests
+from PIL import Image
+
+from camel.embeddings import JinaEmbedding
+from camel.types import EmbeddingModelType
+
+
+@patch.dict(os.environ, {"JINA_API_KEY": "fake_api_key"})
+@patch('requests.post')
+def test_text_embed_list(mock_post):
+    # Mock the API response
+    mock_response = MagicMock()
+    mock_response.json.return_value = {
+        "data": [{"embedding": [0.1, 0.2, 0.3]}]
+    }
+    mock_response.raise_for_status.return_value = None
+    mock_post.return_value = mock_response
+
+    # Initialize embedding instance
+    embedding = JinaEmbedding()
+
+    # Test text embedding
+    result = embedding.embed_list(["test text"])
+
+    # Verify the request payload and the parsed embeddings
+    mock_post.assert_called_once()
+    request_data = mock_post.call_args[1]["json"]
+    assert request_data["input"] == ["test text"]
+    assert result == [[0.1, 0.2, 0.3]]
+
+
+@patch.dict(os.environ, {"JINA_API_KEY": "fake_api_key"})
+@patch('requests.post')
+def test_image_embed_list(mock_post):
+    # Mock the API response
+    mock_response = MagicMock()
+    mock_response.json.return_value = {
+        "data": [{"embedding": [0.1, 0.2, 0.3]}]
+    }
+    mock_response.raise_for_status.return_value = None
+    mock_post.return_value = mock_response
+
+    # Create a dummy image
+    img = Image.new('RGB', (60, 30), color='red')
+
+    # Initialize embedding instance with CLIP model
+    embedding = JinaEmbedding(model_type=EmbeddingModelType.JINA_CLIP_V2)
+
+    # Test image embedding
+    result = embedding.embed_list([img])
+
+    # Verify the image was sent as an encoded string and the result parsed
+    mock_post.assert_called_once()
+    request_data = mock_post.call_args[1]["json"]
+    assert isinstance(request_data["input"][0]["image"], str)
+    assert result == [[0.1, 0.2, 0.3]]
+
+
+@patch.dict(os.environ, {"JINA_API_KEY": "fake_api_key"})
+def test_invalid_model_type():
+    # Test initialization with invalid model type
+    with pytest.raises(ValueError, match="is not a Jina model"):
+        JinaEmbedding(model_type=EmbeddingModelType.TEXT_EMBEDDING_3_SMALL)
+
+
+@patch.dict(os.environ, {"JINA_API_KEY": "fake_api_key"})
+@patch('requests.post')
+def test_embed_list_with_options(mock_post):
+    # Mock the API response
+    mock_response = MagicMock()
+    mock_response.json.return_value = {
+        "data": [{"embedding": [0.1, 0.2, 0.3]}]
+    }
+    mock_response.raise_for_status.return_value = None
+    mock_post.return_value = mock_response
+
+    # Initialize embedding instance with options
+    embedding = JinaEmbedding(
+        dimensions=3,
+        task="text-matching",
+        late_chunking=True,
+        normalized=True,
+    )
+
+    # Test embedding with options
+    result = embedding.embed_list(["test text"])
+
+    # Verify the API was called with correct parameters
+    mock_post.assert_called_once()
+    call_kwargs = mock_post.call_args[1]
+    assert len(result[0]) == 3
+    assert "json" in call_kwargs
+    request_data = call_kwargs["json"]
+    assert request_data["task"] == "text-matching"
+    assert request_data["late_chunking"] is True
+    assert request_data["normalized"] is True
+
+
+@patch.dict(os.environ, {"JINA_API_KEY": "fake_api_key"})
+def test_get_output_dim():
+    # Test with default model
+    embedding = JinaEmbedding()
+    assert embedding.get_output_dim() == embedding.output_dim
+
+    # Test with custom dimensions
+    custom_dim = 512
+    embedding_custom = JinaEmbedding(dimensions=custom_dim)
+    assert embedding_custom.get_output_dim() == custom_dim
+
+
+@patch.dict(os.environ, {"JINA_API_KEY": "fake_api_key"})
+@patch('requests.post')
+def test_api_error_handling(mock_post):
+    # Mock a failed API response
+    mock_post.side_effect = requests.exceptions.RequestException("API Error")
+
+    # Initialize embedding instance
+    embedding = JinaEmbedding()
+
+    # Test error handling
+    with pytest.raises(
+        RuntimeError, match="Failed to get embeddings from Jina AI"
+    ):
+        embedding.embed_list(["test text"])