Python: Upgrading Mistral AI Connector to Version 1.0 #9542

Merged
2 changes: 1 addition & 1 deletion in python/pyproject.toml
@@ -79,7 +79,7 @@ milvus = [
"milvus >= 2.3,<2.3.8; platform_system != 'Windows'"
]
mistralai = [
"mistralai >= 0.4,< 1.0"
"mistralai >= 1.2,< 2.0"
]
ollama = [
"ollama ~= 0.2"
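The dependency bump is a breaking one: the 0.4 line shipped MistralAsyncClient, while the 1.x line replaces it with a single Mistral client and a reorganized models module, which the remaining files in this PR adapt to. A minimal post-upgrade sanity check (a sketch, not part of the PR):

# Sketch: verify the installed SDK is on the new 1.x line before running the connector.
from importlib.metadata import version

from mistralai import Mistral  # this import only exists in the 1.x SDK

print(version("mistralai"))  # expected to fall in the >=1.2,<2.0 range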
@@ -4,12 +4,14 @@
import sys
from typing import Any, Literal

from mistralai import utils

if sys.version_info >= (3, 11):
pass # pragma: no cover
else:
pass # pragma: no cover

from pydantic import Field
from pydantic import Field, field_validator

from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings

@@ -27,13 +29,19 @@ class MistralAIChatPromptExecutionSettings(MistralAIPromptExecutionSettings):

response_format: dict[Literal["type"], Literal["text", "json_object"]] | None = None
messages: list[dict[str, Any]] | None = None
safe_mode: bool = False
safe_mode: bool = Field(False, exclude=True)
safe_prompt: bool = False
max_tokens: int | None = Field(None, gt=0)
seed: int | None = None
temperature: float | None = Field(None, ge=0.0, le=2.0)
top_p: float | None = Field(None, ge=0.0, le=1.0)
random_seed: int | None = None
presence_penalty: float | None = Field(None, gt=0)
frequency_penalty: float | None = Field(None, gt=0)
n: int | None = Field(None, gt=1)
retries: utils.RetryConfig | None = None
server_url: str | None = None
timeout_ms: int | None = None
tools: list[dict[str, Any]] | None = Field(
None,
max_length=64,
@@ -43,3 +51,12 @@ class MistralAIChatPromptExecutionSettings(MistralAIPromptExecutionSettings):
None,
description="Do not set this manually. It is set by the service based on the function choice configuration.",
)

@field_validator("safe_mode")
@classmethod
def check_safe_mode(cls, v: bool) -> bool:
"""The safe_mode setting is no longer supported."""
logger.warning(
"The 'safe_mode' setting is no longer supported and is being ignored, it will be removed in the Future."
)
return v
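Taken together, the settings changes keep safe_mode accepted for backwards compatibility but warn about it and exclude it from the request, while safe_prompt and the new client-level options (retries, server_url, timeout_ms) are forwarded. A small sketch of the intended behavior; the import path is assumed from the connector's package layout:

# Sketch only: import path and exact runtime behavior are assumptions based on this diff.
from semantic_kernel.connectors.ai.mistral_ai import MistralAIChatPromptExecutionSettings

settings = MistralAIChatPromptExecutionSettings(
    safe_mode=True,      # triggers the deprecation warning from the validator above
    safe_prompt=True,    # still supported and sent to the service
    temperature=0.7,
    timeout_ms=30_000,   # new pass-through option for the 1.x client
)

# Because safe_mode is declared with exclude=True, it should not appear in the payload.
payload = settings.prepare_settings_dict()
print("safe_mode" in payload)  # expected: False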
@@ -3,7 +3,7 @@
from abc import ABC
from typing import ClassVar

from mistralai.async_client import MistralAsyncClient
from mistralai import Mistral

from semantic_kernel.kernel_pydantic import KernelBaseModel

@@ -13,4 +13,4 @@ class MistralAIBase(KernelBaseModel, ABC):

MODEL_PROVIDER_NAME: ClassVar[str] = "mistralai"

async_client: MistralAsyncClient
async_client: Mistral
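With the base class now typed against the unified Mistral client, callers that inject their own client must construct the new entry point instead of MistralAsyncClient. A minimal sketch (import path and model name are illustrative):

from mistralai import Mistral

# Import path assumed from the connector's package layout.
from semantic_kernel.connectors.ai.mistral_ai import MistralAIChatCompletion

# A single Mistral instance can be shared by the chat and embedding connectors.
client = Mistral(api_key="...")

chat_service = MistralAIChatCompletion(
    ai_model_id="mistral-large-latest",  # illustrative model name
    api_key="...",                       # satisfies the connector's settings; the passed client is what gets used
    async_client=client,                 # skips the connector's internal client construction
)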
@@ -10,14 +10,15 @@
else:
from typing_extensions import override # pragma: no cover

from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import (
from mistralai import Mistral
from mistralai.models import (
AssistantMessage,
ChatCompletionChoice,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
CompletionChunk,
CompletionResponseStreamChoice,
DeltaMessage,
ToolCall,
)
from pydantic import ValidationError

@@ -63,22 +64,22 @@ def __init__(
ai_model_id: str | None = None,
service_id: str | None = None,
api_key: str | None = None,
async_client: MistralAsyncClient | None = None,
async_client: Mistral | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initialize an MistralAIChatCompletion service.

Args:
ai_model_id (str): MistralAI model name, see
ai_model_id : MistralAI model name, see
https://docs.mistral.ai/getting-started/models/
service_id (str | None): Service ID tied to the execution settings.
api_key (str | None): The optional API key to use. If provided will override,
service_id : Service ID tied to the execution settings.
api_key : The optional API key to use. If provided will override,
the env vars or .env file value.
async_client (MistralAsyncClient | None) : An existing client to use.
env_file_path (str | None): Use the environment settings file as a fallback
async_client : An existing client to use.
env_file_path : Use the environment settings file as a fallback
to environment variables.
env_file_encoding (str | None): The encoding of the environment settings file.
env_file_encoding : The encoding of the environment settings file.
"""
try:
mistralai_settings = MistralAISettings.create(
@@ -94,7 +95,7 @@ def __init__(
raise ServiceInitializationError("The MistralAI chat model ID is required.")

if not async_client:
async_client = MistralAsyncClient(
async_client = Mistral(
api_key=mistralai_settings.api_key.get_secret_value(),
)

@@ -135,15 +136,22 @@ async def _inner_get_chat_message_contents(
settings.messages = self._prepare_chat_history_for_request(chat_history)

try:
response = await self.async_client.chat(**settings.prepare_settings_dict())
response = await self.async_client.chat.complete_async(**settings.prepare_settings_dict())
except Exception as ex:
raise ServiceResponseException(
f"{type(self)} service failed to complete the prompt",
ex,
) from ex

response_metadata = self._get_metadata_from_response(response)
return [self._create_chat_message_content(response, choice, response_metadata) for choice in response.choices]
if isinstance(response, ChatCompletionResponse):
response_metadata = self._get_metadata_from_response(response)
# If there are no choices, return an empty list
if isinstance(response.choices, list) and len(response.choices) > 0:
return [
self._create_chat_message_content(response, choice, response_metadata)
for choice in response.choices
]
return []
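For reference, the non-streaming path now calls chat.complete_async, and the 1.x response model types choices as optional, which is why the result is guarded before building message contents. A standalone sketch of the underlying SDK call (model and prompt are illustrative):

import asyncio

from mistralai import Mistral
from mistralai.models import ChatCompletionResponse

async def main() -> None:
    client = Mistral(api_key="...")
    response = await client.chat.complete_async(
        model="mistral-small-latest",  # illustrative
        messages=[{"role": "user", "content": "Say hello"}],
    )
    # choices may be None or empty in the 1.x models, so guard before indexing.
    if isinstance(response, ChatCompletionResponse) and response.choices:
        print(response.choices[0].message.content)

asyncio.run(main())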

@override
@trace_streaming_chat_completion(MistralAIBase.MODEL_PROVIDER_NAME)
@@ -160,26 +168,30 @@ async def _inner_get_streaming_chat_message_contents(
settings.messages = self._prepare_chat_history_for_request(chat_history)

try:
response = self.async_client.chat_stream(**settings.prepare_settings_dict())
response = await self.async_client.chat.stream_async(**settings.prepare_settings_dict())
except Exception as ex:
raise ServiceResponseException(
f"{type(self)} service failed to complete the prompt",
ex,
) from ex
async for chunk in response:
if len(chunk.choices) == 0:
continue
chunk_metadata = self._get_metadata_from_response(chunk)
yield [
self._create_streaming_chat_message_content(chunk, choice, chunk_metadata) for choice in chunk.choices
]

# If the response is not an async generator, end the generator without yielding
if isinstance(response, AsyncGenerator):
async for chunk in response:
if len(chunk.data.choices) == 0:
continue
chunk_metadata = self._get_metadata_from_response(chunk.data)
yield [
self._create_streaming_chat_message_content(chunk.data, choice, chunk_metadata)
for choice in chunk.data.choices
]
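The streaming path changes shape as well: chat.stream_async is awaited to obtain an async generator, and each emitted event wraps the actual completion chunk in a data attribute, which is why the loop above reads chunk.data.choices. A sketch of consuming the raw stream (model and prompt are illustrative):

from mistralai import Mistral

async def stream_demo(client: Mistral) -> None:
    stream = await client.chat.stream_async(
        model="mistral-small-latest",  # illustrative
        messages=[{"role": "user", "content": "Stream a short reply"}],
    )
    async for event in stream:
        # Each event carries the CompletionChunk under .data in the 1.x SDK.
        if event.data.choices:
            delta = event.data.choices[0].delta
            if delta.content:
                print(delta.content, end="", flush=True)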

# endregion

# region content conversion to SK

def _create_chat_message_content(
self, response: ChatCompletionResponse, choice: ChatCompletionResponseChoice, response_metadata: dict[str, Any]
self, response: ChatCompletionResponse, choice: ChatCompletionChoice, response_metadata: dict[str, Any]
) -> "ChatMessageContent":
"""Create a chat message content object from a choice."""
metadata = self._get_metadata_from_chat_choice(choice)
Expand All @@ -201,8 +213,8 @@ def _create_chat_message_content(

def _create_streaming_chat_message_content(
self,
chunk: ChatCompletionStreamResponse,
choice: ChatCompletionResponseStreamChoice,
chunk: CompletionChunk,
choice: CompletionResponseStreamChoice,
chunk_metadata: dict[str, Any],
) -> StreamingChatMessageContent:
"""Create a streaming chat message content object from a choice."""
Expand All @@ -224,9 +236,7 @@ def _create_streaming_chat_message_content(
items=items,
)

def _get_metadata_from_response(
self, response: ChatCompletionResponse | ChatCompletionStreamResponse
) -> dict[str, Any]:
def _get_metadata_from_response(self, response: ChatCompletionResponse | CompletionChunk) -> dict[str, Any]:
"""Get metadata from a chat response."""
metadata: dict[str, Any] = {
"id": response.id,
@@ -244,19 +254,19 @@ return metadata
return metadata

def _get_metadata_from_chat_choice(
self, choice: ChatCompletionResponseChoice | ChatCompletionResponseStreamChoice
self, choice: ChatCompletionChoice | CompletionResponseStreamChoice
) -> dict[str, Any]:
"""Get metadata from a chat choice."""
return {
"logprobs": getattr(choice, "logprobs", None),
}

def _get_tool_calls_from_chat_choice(
self, choice: ChatCompletionResponseChoice | ChatCompletionResponseStreamChoice
self, choice: ChatCompletionChoice | CompletionResponseStreamChoice
) -> list[FunctionCallContent]:
"""Get tool calls from a chat choice."""
content: ChatMessage | DeltaMessage
content = choice.message if isinstance(choice, ChatCompletionResponseChoice) else choice.delta
content: AssistantMessage | DeltaMessage
content = choice.message if isinstance(choice, ChatCompletionChoice) else choice.delta
if content.tool_calls is None:
return []

@@ -268,6 +278,7 @@ def _get_tool_calls_from_chat_choice(
arguments=tool.function.arguments,
)
for tool in content.tool_calls
if isinstance(tool, ToolCall)
]

# endregion
@@ -9,8 +9,8 @@

import logging

from mistralai.async_client import MistralAsyncClient
from mistralai.models.embeddings import EmbeddingResponse
from mistralai import Mistral
from mistralai.models import EmbeddingResponse
from numpy import array, ndarray
from pydantic import ValidationError

@@ -33,7 +33,7 @@ def __init__(
ai_model_id: str | None = None,
api_key: str | None = None,
service_id: str | None = None,
async_client: MistralAsyncClient | None = None,
async_client: Mistral | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -45,12 +45,12 @@
- MISTRALAI_EMBEDDING_MODEL_ID

Args:
ai_model_id: (str | None): A string that is used to identify the model such as the model name.
api_key (str | None): The API key for the Mistral AI service deployment.
service_id (str | None): Service ID for the embedding completion service.
async_client (MistralAsyncClient | None): The Mistral AI client to use.
env_file_path (str | None): The path to the environment file.
env_file_encoding (str | None): The encoding of the environment file.
ai_model_id : A string that is used to identify the model, such as the model name.
api_key : The API key for the Mistral AI service deployment.
service_id : Service ID for the embedding completion service.
async_client : The Mistral AI client to use.
env_file_path : The path to the environment file.
env_file_encoding : The encoding of the environment file.

Raises:
ServiceInitializationError: If an error occurs during initialization.
@@ -69,8 +69,9 @@ def __init__(
raise ServiceInitializationError("The MistralAI embedding model ID is required.")

if not async_client:
async_client = MistralAsyncClient(api_key=mistralai_settings.api_key.get_secret_value())

async_client = Mistral(
api_key=mistralai_settings.api_key.get_secret_value(),
)
super().__init__(
service_id=service_id or mistralai_settings.embedding_model_id,
ai_model_id=ai_model_id or mistralai_settings.embedding_model_id,
@@ -96,13 +97,12 @@
) -> Any:
"""Generate embeddings from the Mistral AI service."""
try:
embedding_response: EmbeddingResponse = await self.async_client.embeddings(
model=self.ai_model_id, input=texts
)
embedding_response = await self.async_client.embeddings.create_async(model=self.ai_model_id, inputs=texts)
except Exception as ex:
raise ServiceResponseException(
f"{type(self)} service failed to complete the embedding request.",
ex,
) from ex

return [item.embedding for item in embedding_response.data]
if isinstance(embedding_response, EmbeddingResponse):
return [item.embedding for item in embedding_response.data]
return []
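The embedding path mirrors the chat changes: embeddings.create_async takes inputs (plural) and returns an EmbeddingResponse whose data entries carry the vectors. A standalone sketch (model and texts are illustrative):

from mistralai import Mistral
from mistralai.models import EmbeddingResponse

async def embed(texts: list[str]) -> list[list[float]]:
    client = Mistral(api_key="...")
    response = await client.embeddings.create_async(
        model="mistral-embed",  # illustrative embedding model
        inputs=texts,
    )
    if isinstance(response, EmbeddingResponse):
        # Each data item exposes its vector on .embedding.
        return [item.embedding for item in response.data]
    return []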