Skip to content

Commit

Permalink
feat: Integrate mistral (#754)
Browse files Browse the repository at this point in the history
Co-authored-by: Ruslan Serebriakov <[email protected]>
Co-authored-by: 李国豪 <[email protected]>
  • Loading branch information
3 people authored Jul 29, 2024
1 parent a0ea1aa commit e82a63f
Show file tree
Hide file tree
Showing 21 changed files with 1,290 additions and 559 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/build_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,8 @@ jobs:
AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }}"
AZURE_DEPLOYMENT_NAME: ${{ secrets.AZURE_DEPLOYMENT_NAME }}"
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}"
MISTRAL_API_KEY: "${{ secrets.MISTRAL_API_KEY }}"
NEO4J_URI: "${{ secrets.NEO4J_URI }}"
NEO4J_USERNAME: "${{ secrets.NEO4J_USERNAME }}"
NEO4J_PASSWORD: "${{ secrets.NEO4J_PASSWORD }}"
run: pytest --fast-test-mode ./test
12 changes: 12 additions & 0 deletions .github/workflows/pytest_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ jobs:
AZURE_API_VERSION: "to-be-filled"
AZURE_DEPLOYMENT_NAME: "to-be-filled"
AZURE_OPENAI_ENDPOINT: "https://camel.openai.azure.com/"
MISTRAL_API_KEY: "${{ secrets.MISTRAL_API_KEY }}"
NEO4J_URI: "${{ secrets.NEO4J_URI }}"
NEO4J_USERNAME: "${{ secrets.NEO4J_USERNAME }}"
NEO4J_PASSWORD: "${{ secrets.NEO4J_PASSWORD }}"
run: poetry run pytest --fast-test-mode test/

pytest_package_llm_test:
Expand Down Expand Up @@ -67,6 +71,10 @@ jobs:
AZURE_API_VERSION: "to-be-filled"
AZURE_DEPLOYMENT_NAME: "to-be-filled"
AZURE_OPENAI_ENDPOINT: "https://camel.openai.azure.com/"
MISTRAL_API_KEY: "${{ secrets.MISTRAL_API_KEY }}"
NEO4J_URI: "${{ secrets.NEO4J_URI }}"
NEO4J_USERNAME: "${{ secrets.NEO4J_USERNAME }}"
NEO4J_PASSWORD: "${{ secrets.NEO4J_PASSWORD }}"
run: poetry run pytest --llm-test-only test/

pytest_package_very_slow_test:
Expand Down Expand Up @@ -95,4 +103,8 @@ jobs:
AZURE_API_VERSION: ${{ secrets.AZURE_API_VERSION }}"
AZURE_DEPLOYMENT_NAME: ${{ secrets.AZURE_DEPLOYMENT_NAME }}"
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}"
MISTRAL_API_KEY: "${{ secrets.MISTRAL_API_KEY }}"
NEO4J_URI: "${{ secrets.NEO4J_URI }}"
NEO4J_USERNAME: "${{ secrets.NEO4J_USERNAME }}"
NEO4J_PASSWORD: "${{ secrets.NEO4J_PASSWORD }}"
run: poetry run pytest --very-slow-test-only test/
3 changes: 3 additions & 0 deletions camel/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from .groq_config import GROQ_API_PARAMS, GroqConfig
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
from .openai_config import (
OPENAI_API_PARAMS,
Expand Down Expand Up @@ -47,4 +48,6 @@
'Gemini_API_PARAMS',
'VLLMConfig',
'VLLM_API_PARAMS',
'MistralConfig',
'MISTRAL_API_PARAMS',
]
81 changes: 81 additions & 0 deletions camel/configs/mistral_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from __future__ import annotations

from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Dict, Optional, Union

from camel.configs.base_config import BaseConfig

if TYPE_CHECKING:
from mistralai.models.chat_completion import (
ResponseFormat,
)

from camel.toolkits import OpenAIFunction


@dataclass(frozen=True)
class MistralConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using the
Mistral API.
reference: https://github.com/mistralai/client-python/blob/9d238f88c41689821d7b08570f13b43426f97fd6/src/mistralai/client.py#L195
Args:
temperature (Optional[float], optional): temperature the temperature
to use for sampling, e.g. 0.5.
max_tokens (Optional[int], optional): the maximum number of tokens to
generate, e.g. 100. Defaults to None.
top_p (Optional[float], optional): the cumulative probability of
tokens to generate, e.g. 0.9. Defaults to None.
random_seed (Optional[int], optional): the random seed to use for
sampling, e.g. 42. Defaults to None.
safe_mode (bool, optional): deprecated, use safe_prompt instead.
Defaults to False.
safe_prompt (bool, optional): whether to use safe prompt, e.g. true.
Defaults to False.
response_format (Union[Dict[str, str], ResponseFormat): format of the
response.
tools (Optional[list[OpenAIFunction]], optional): a list of tools to
use.
tool_choice (str, optional): Controls which (if
any) tool is called by the model. :obj:`"none"` means the model
will not call any tool and instead generates a message.
:obj:`"auto"` means the model can pick between generating a
message or calling one or more tools. :obj:`"any"` means the
model must call one or more tools. :obj:`"auto"` is the default
value.
"""

temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_p: Optional[float] = None
random_seed: Optional[int] = None
safe_mode: bool = False
safe_prompt: bool = False
response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None
tools: Optional[list[OpenAIFunction]] = None
tool_choice: Optional[str] = "auto"

def __post_init__(self):
if self.tools is not None:
object.__setattr__(
self,
'tools',
[tool.get_openai_tool_schema() for tool in self.tools],
)


MISTRAL_API_PARAMS = {param for param in asdict(MistralConfig()).keys()}
2 changes: 2 additions & 0 deletions camel/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from .base import BaseEmbedding
from .mistral_embedding import MistralEmbedding
from .openai_embedding import OpenAIEmbedding
from .sentence_transformers_embeddings import SentenceTransformerEncoder
from .vlm_embedding import VisionLanguageEmbedding
Expand All @@ -21,4 +22,5 @@
"OpenAIEmbedding",
"SentenceTransformerEncoder",
"VisionLanguageEmbedding",
"MistralEmbedding",
]
89 changes: 89 additions & 0 deletions camel/embeddings/mistral_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from __future__ import annotations

import os
from typing import Any

from camel.embeddings.base import BaseEmbedding
from camel.types import EmbeddingModelType
from camel.utils import api_keys_required


class MistralEmbedding(BaseEmbedding[str]):
r"""Provides text embedding functionalities using Mistral's models.
Args:
model_type (EmbeddingModelType, optional): The model type to be
used for text embeddings.
(default: :obj:`MISTRAL_EMBED`)
api_key (str, optional): The API key for authenticating with the
Mistral service. (default: :obj:`None`)
dimensions (int, optional): The text embedding output dimensions.
(default: :obj:`None`)
Raises:
RuntimeError: If an unsupported model type is specified.
"""

def __init__(
self,
model_type: EmbeddingModelType = (EmbeddingModelType.MISTRAL_EMBED),
api_key: str | None = None,
dimensions: int | None = None,
) -> None:
from mistralai.client import MistralClient

if not model_type.is_mistral:
raise ValueError("Invalid Mistral embedding model type.")
self.model_type = model_type
if dimensions is None:
self.output_dim = model_type.output_dim
else:
assert isinstance(dimensions, int)
self.output_dim = dimensions
self._api_key = api_key or os.environ.get("MISTRAL_API_KEY")
self._client = MistralClient(api_key=self._api_key)

@api_keys_required("MISTRAL_API_KEY")
def embed_list(
self,
objs: list[str],
**kwargs: Any,
) -> list[list[float]]:
r"""Generates embeddings for the given texts.
Args:
objs (list[str]): The texts for which to generate the embeddings.
**kwargs (Any): Extra kwargs passed to the embedding API.
Returns:
list[list[float]]: A list that represents the generated embedding
as a list of floating-point numbers.
"""
# TODO: count tokens
response = self._client.embeddings(
input=objs,
model=self.model_type.value,
**kwargs,
)
return [data.embedding for data in response.data]

def get_output_dim(self) -> int:
r"""Returns the output dimension of the embeddings.
Returns:
int: The dimensionality of the embedding for the current model.
"""
return self.output_dim
2 changes: 2 additions & 0 deletions camel/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .gemini_model import GeminiModel
from .groq_model import GroqModel
from .litellm_model import LiteLLMModel
from .mistral_model import MistralModel
from .model_factory import ModelFactory
from .nemotron_model import NemotronModel
from .ollama_model import OllamaModel
Expand All @@ -32,6 +33,7 @@
'OpenAIModel',
'AzureOpenAIModel',
'AnthropicModel',
'MistralModel',
'GroqModel',
'StubModel',
'ZhipuAIModel',
Expand Down
Loading

0 comments on commit e82a63f

Please sign in to comment.