diff --git a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs index 37e00ec56154..7d10f74d6fb2 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI.UnitTests/Client/MistralClientTests.cs @@ -50,7 +50,7 @@ public async Task ValidateChatMessageRequestAsync() // Act var executionSettings = new MistralAIPromptExecutionSettings { MaxTokens = 1024, Temperature = 0.9 }; - await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings); + await client.GetChatMessageContentsAsync(chatHistory, executionSettings); // Assert var request = this.DelegatingHandler!.RequestContent; @@ -76,7 +76,7 @@ public async Task ValidateGetChatMessageContentsAsync() { new ChatMessageContent(AuthorRole.User, "What is the best French cheese?") }; - var response = await client.GetChatMessageContentsAsync(chatHistory, default); + var response = await client.GetChatMessageContentsAsync(chatHistory); // Assert Assert.NotNull(response); @@ -96,7 +96,7 @@ public async Task ValidateGenerateEmbeddingsAsync() // Act List data = ["Hello", "world"]; - var response = await client.GenerateEmbeddingsAsync(data, default); + var response = await client.GenerateEmbeddingsAsync(data); // Assert Assert.NotNull(response); @@ -117,7 +117,7 @@ public async Task ValidateGetStreamingChatMessageContentsAsync() }; // Act - var response = client.GetStreamingChatMessageContentsAsync(chatHistory, default); + var response = client.GetStreamingChatMessageContentsAsync(chatHistory); var chunks = new List(); await foreach (var chunk in response) { @@ -150,7 +150,7 @@ public async Task ValidateChatHistoryFirstSystemOrUserMessageAsync() }; // Act & Assert - await Assert.ThrowsAsync(async () => await client.GetChatMessageContentsAsync(chatHistory, default)); + await Assert.ThrowsAsync(async () => await client.GetChatMessageContentsAsync(chatHistory)); } [Fact] @@ -161,7 +161,7 @@ public async Task ValidateEmptyChatHistoryAsync() var chatHistory = new ChatHistory(); // Act & Assert - await Assert.ThrowsAsync(async () => await client.GetChatMessageContentsAsync(chatHistory, default)); + await Assert.ThrowsAsync(async () => await client.GetChatMessageContentsAsync(chatHistory)); } [Fact] @@ -181,7 +181,7 @@ public async Task ValidateChatMessageRequestWithToolsAsync() kernel.Plugins.AddFromType(); // Act - await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel); + await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); // Assert var request = this.DelegatingHandler!.RequestContent; @@ -212,7 +212,7 @@ public async Task ValidateGetStreamingChatMessageContentsWithToolsAsync() // Act var executionSettings = new MistralAIPromptExecutionSettings { ToolCallBehavior = MistralAIToolCallBehavior.AutoInvokeKernelFunctions }; - var response = client.GetStreamingChatMessageContentsAsync(chatHistory, default, executionSettings, kernel); + var response = client.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel); var chunks = new List(); await foreach (var chunk in response) { @@ -253,7 +253,7 @@ public async Task ValidateGetChatMessageContentsWithFunctionCallAsync() { new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") }; - var response = await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel); + var response = await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); // Assert Assert.NotNull(response); @@ -279,7 +279,7 @@ public async Task ValidateGetChatMessageContentsWithFunctionCallNoneAsync() { new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") }; - var response = await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel); + var response = await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); // Assert Assert.NotNull(response); @@ -307,7 +307,7 @@ public async Task ValidateGetChatMessageContentsWithFunctionCallRequiredAsync() { new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") }; - var response = await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel); + var response = await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); // Assert Assert.NotNull(response); @@ -345,7 +345,7 @@ public async Task ValidateGetChatMessageContentsWithFunctionInvocationFilterAsyn { new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") }; - var response = await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel); + var response = await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); // Assert Assert.NotNull(response); @@ -389,11 +389,11 @@ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming) if (isStreaming) { - await client.GetStreamingChatMessageContentsAsync(chatHistory, default, executionSettings, kernel).ToListAsync(); + await client.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel).ToListAsync(); } else { - await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel); + await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); } // Assert @@ -428,7 +428,7 @@ public async Task ValidateGetChatMessageContentsWithAutoFunctionInvocationFilter { new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?") }; - var response = await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel); + var response = await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel); // Assert Assert.NotNull(response); @@ -465,7 +465,7 @@ public async Task ValidateGetStreamingChatMessageContentWithAutoFunctionInvocati List streamingContent = []; // Act - await foreach (var item in client.GetStreamingChatMessageContentsAsync(chatHistory, default, executionSettings, kernel)) + await foreach (var item in client.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel)) { streamingContent.Add(item); } @@ -497,7 +497,7 @@ public void ValidateToMistralChatMessages(string roleLabel, string content) }; // Act - var messages = client.ToMistralChatMessages(chatMessage, default); + var messages = client.ToMistralChatMessages(chatMessage); // Assert Assert.NotNull(messages); @@ -517,7 +517,7 @@ public void ValidateToMistralChatMessagesWithFunctionCallContent() }; // Act - var messages = client.ToMistralChatMessages(content, default); + var messages = client.ToMistralChatMessages(content); // Assert Assert.NotNull(messages); @@ -537,7 +537,7 @@ public void ValidateToMistralChatMessagesWithFunctionResultContent() }; // Act - var messages = client.ToMistralChatMessages(content, default); + var messages = client.ToMistralChatMessages(content); // Assert Assert.NotNull(messages); diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs index 9157073b244c..d43b9af5eab9 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs @@ -16,6 +16,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.FunctionCalling; using Microsoft.SemanticKernel.Diagnostics; using Microsoft.SemanticKernel.Http; using Microsoft.SemanticKernel.Text; @@ -27,6 +28,14 @@ namespace Microsoft.SemanticKernel.Connectors.MistralAI.Client; /// internal sealed class MistralClient { + /// + /// Create an instance of + /// + /// The model id + /// The HTTP Client + /// The API key + /// An optional endpoint URI + /// An optional logger internal MistralClient( string modelId, HttpClient httpClient, @@ -44,9 +53,17 @@ internal MistralClient( this._httpClient = httpClient; this._logger = logger ?? NullLogger.Instance; this._streamJsonParser = new StreamJsonParser(); + this._functionCallsProcessor = new FunctionCallsProcessor(this._logger); } - internal async Task> GetChatMessageContentsAsync(ChatHistory chatHistory, CancellationToken cancellationToken, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null) + /// + /// Generate new chat message responses. + /// + /// Chat history + /// Prompt execution settings + /// Kernel instance + /// Cancellation token + internal async Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { this.ValidateChatHistory(chatHistory); @@ -264,7 +281,14 @@ internal async Task> GetChatMessageContentsAsy } } - internal async IAsyncEnumerable GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, [EnumeratorCancellation] CancellationToken cancellationToken, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null) + /// + /// Generate new streaming chat message responses. + /// + /// Chat history + /// Prompt execution settings + /// Kernel instance + /// Cancellation token + internal async IAsyncEnumerable GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { this.ValidateChatHistory(chatHistory); @@ -498,75 +522,93 @@ internal async IAsyncEnumerable GetStreamingChatMes } } - private async IAsyncEnumerable StreamChatMessageContentsAsync(ChatHistory chatHistory, MistralAIPromptExecutionSettings executionSettings, ChatCompletionRequest chatRequest, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken) + /// + /// Generate embeddings for the provided data. + /// + /// Data to create embeddings for + /// Prompt execution settings + /// Kernel instance + /// Cancellation token + internal async Task>> GenerateEmbeddingsAsync(IList data, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) { - this.ValidateChatHistory(chatHistory); + var request = new TextEmbeddingRequest(this._modelId, data); + var mistralExecutionSettings = MistralAIPromptExecutionSettings.FromExecutionSettings(executionSettings); + var endpoint = this.GetEndpoint(mistralExecutionSettings, path: "embeddings"); + using var httpRequestMessage = this.CreatePost(request, endpoint, this._apiKey, false); - var endpoint = this.GetEndpoint(executionSettings, path: "chat/completions"); - using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: true); - using var response = await this.SendStreamingRequestAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); - using var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false); - await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) - { - yield return streamingChatContent; - } + var response = await this.SendRequestAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + + return response.Data!.Select(item => new ReadOnlyMemory([.. item.Embedding])).ToList(); } - private async IAsyncEnumerable ProcessChatResponseStreamAsync(Stream stream, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken) + /// + /// Convert to + /// + /// Chat message content + /// Tool call behavior + internal List ToMistralChatMessages(ChatMessageContent content, MistralAIToolCallBehavior? toolCallBehavior = null) { - IAsyncEnumerator? responseEnumerator = null; - - try + if (content.Role == AuthorRole.Assistant) { - var responseEnumerable = this.ParseChatResponseStreamAsync(stream, cancellationToken); - responseEnumerator = responseEnumerable.GetAsyncEnumerator(cancellationToken); - - string? currentRole = null; - while (await responseEnumerator.MoveNextAsync().ConfigureAwait(false)) + // Handling function calls supplied via ChatMessageContent.Items collection elements of the FunctionCallContent type. + var message = new MistralChatMessage(content.Role.ToString(), content.Content ?? string.Empty); + Dictionary toolCalls = []; + foreach (var item in content.Items) { - var chunk = responseEnumerator.Current!; - - for (int i = 0; i < chunk.GetChoiceCount(); i++) + if (item is not FunctionCallContent callRequest) { - currentRole ??= chunk.GetRole(i); + continue; + } - yield return new(role: new AuthorRole(currentRole ?? "assistant"), - content: chunk.GetContent(i), - choiceIndex: i, - modelId: modelId, - encoding: chunk.GetEncoding(), - innerContent: chunk, - metadata: chunk.GetMetadata()); + if (callRequest.Id is null || toolCalls.ContainsKey(callRequest.Id)) + { + continue; } + + var arguments = JsonSerializer.Serialize(callRequest.Arguments); + var toolCall = new MistralToolCall() + { + Id = callRequest.Id, + Function = new MistralFunction( + callRequest.FunctionName, + callRequest.PluginName) + { + Arguments = arguments + } + }; + toolCalls.Add(callRequest.Id, toolCall); } - } - finally - { - if (responseEnumerator != null) + if (toolCalls.Count > 0) { - await responseEnumerator.DisposeAsync().ConfigureAwait(false); + message.ToolCalls = [.. toolCalls.Values]; } + return [message]; } - } - private async IAsyncEnumerable ParseChatResponseStreamAsync(Stream responseStream, [EnumeratorCancellation] CancellationToken cancellationToken) - { - await foreach (var json in this._streamJsonParser.ParseAsync(responseStream, cancellationToken: cancellationToken).ConfigureAwait(false)) + if (content.Role == AuthorRole.Tool) { - yield return DeserializeResponse(json); - } - } + List? messages = null; + foreach (var item in content.Items) + { + if (item is not FunctionResultContent resultContent) + { + continue; + } - internal async Task>> GenerateEmbeddingsAsync(IList data, CancellationToken cancellationToken, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null) - { - var request = new TextEmbeddingRequest(this._modelId, data); - var mistralExecutionSettings = MistralAIPromptExecutionSettings.FromExecutionSettings(executionSettings); - var endpoint = this.GetEndpoint(mistralExecutionSettings, path: "embeddings"); - using var httpRequestMessage = this.CreatePost(request, endpoint, this._apiKey, false); + messages ??= []; - var response = await this.SendRequestAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + var stringResult = ProcessFunctionResult(resultContent.Result ?? string.Empty, toolCallBehavior); + messages.Add(new MistralChatMessage(content.Role.ToString(), stringResult)); + } + if (messages is not null) + { + return messages; + } - return response.Data!.Select(item => new ReadOnlyMemory([.. item.Embedding])).ToList(); + throw new NotSupportedException("No function result provided in the tool message."); + } + + return [new MistralChatMessage(content.Role.ToString(), content.Content ?? string.Empty)]; } #region private @@ -576,6 +618,7 @@ internal async Task>> GenerateEmbeddingsAsync(IList< private readonly HttpClient _httpClient; private readonly ILogger _logger; private readonly StreamJsonParser _streamJsonParser; + private readonly FunctionCallsProcessor _functionCallsProcessor; /// Provider name used for diagnostics. private const string ModelProvider = "mistralai"; @@ -701,69 +744,63 @@ private ChatCompletionRequest CreateChatCompletionRequest(string modelId, bool s return request; } - internal List ToMistralChatMessages(ChatMessageContent content, MistralAIToolCallBehavior? toolCallBehavior) + private async IAsyncEnumerable StreamChatMessageContentsAsync(ChatHistory chatHistory, MistralAIPromptExecutionSettings executionSettings, ChatCompletionRequest chatRequest, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken) { - if (content.Role == AuthorRole.Assistant) - { - // Handling function calls supplied via ChatMessageContent.Items collection elements of the FunctionCallContent type. - var message = new MistralChatMessage(content.Role.ToString(), content.Content ?? string.Empty); - Dictionary toolCalls = []; - foreach (var item in content.Items) - { - if (item is not FunctionCallContent callRequest) - { - continue; - } - - if (callRequest.Id is null || toolCalls.ContainsKey(callRequest.Id)) - { - continue; - } + this.ValidateChatHistory(chatHistory); - var arguments = JsonSerializer.Serialize(callRequest.Arguments); - var toolCall = new MistralToolCall() - { - Id = callRequest.Id, - Function = new MistralFunction( - callRequest.FunctionName, - callRequest.PluginName) - { - Arguments = arguments - } - }; - toolCalls.Add(callRequest.Id, toolCall); - } - if (toolCalls.Count > 0) - { - message.ToolCalls = [.. toolCalls.Values]; - } - return [message]; + var endpoint = this.GetEndpoint(executionSettings, path: "chat/completions"); + using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: true); + using var response = await this.SendStreamingRequestAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); + using var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false); + await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false)) + { + yield return streamingChatContent; } + } - if (content.Role == AuthorRole.Tool) + private async IAsyncEnumerable ProcessChatResponseStreamAsync(Stream stream, string modelId, [EnumeratorCancellation] CancellationToken cancellationToken) + { + IAsyncEnumerator? responseEnumerator = null; + + try { - List? messages = null; - foreach (var item in content.Items) + var responseEnumerable = this.ParseChatResponseStreamAsync(stream, cancellationToken); + responseEnumerator = responseEnumerable.GetAsyncEnumerator(cancellationToken); + + string? currentRole = null; + while (await responseEnumerator.MoveNextAsync().ConfigureAwait(false)) { - if (item is not FunctionResultContent resultContent) - { - continue; - } + var chunk = responseEnumerator.Current!; - messages ??= []; + for (int i = 0; i < chunk.GetChoiceCount(); i++) + { + currentRole ??= chunk.GetRole(i); - var stringResult = ProcessFunctionResult(resultContent.Result ?? string.Empty, toolCallBehavior); - messages.Add(new MistralChatMessage(content.Role.ToString(), stringResult)); + yield return new(role: new AuthorRole(currentRole ?? "assistant"), + content: chunk.GetContent(i), + choiceIndex: i, + modelId: modelId, + encoding: chunk.GetEncoding(), + innerContent: chunk, + metadata: chunk.GetMetadata()); + } } - if (messages is not null) + } + finally + { + if (responseEnumerator != null) { - return messages; + await responseEnumerator.DisposeAsync().ConfigureAwait(false); } - - throw new NotSupportedException("No function result provided in the tool message."); } + } - return [new MistralChatMessage(content.Role.ToString(), content.Content ?? string.Empty)]; + private async IAsyncEnumerable ParseChatResponseStreamAsync(Stream responseStream, [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (var json in this._streamJsonParser.ParseAsync(responseStream, cancellationToken: cancellationToken).ConfigureAwait(false)) + { + yield return DeserializeResponse(json); + } } private HttpRequestMessage CreatePost(object requestData, Uri endpoint, string apiKey, bool stream) diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Connectors.MistralAI.csproj b/dotnet/src/Connectors/Connectors.MistralAI/Connectors.MistralAI.csproj index 8edcf0ed416e..02ebdc127f87 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Connectors.MistralAI.csproj +++ b/dotnet/src/Connectors/Connectors.MistralAI/Connectors.MistralAI.csproj @@ -12,6 +12,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Services/MistralAIChatCompletionService.cs b/dotnet/src/Connectors/Connectors.MistralAI/Services/MistralAIChatCompletionService.cs index bbaa136ea07d..1af3238baf34 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Services/MistralAIChatCompletionService.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Services/MistralAIChatCompletionService.cs @@ -45,11 +45,11 @@ public MistralAIChatCompletionService(string modelId, string apiKey, Uri? endpoi /// public Task> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) - => this.Client.GetChatMessageContentsAsync(chatHistory, cancellationToken, executionSettings, kernel); + => this.Client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken); /// public IAsyncEnumerable GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default) - => this.Client.GetStreamingChatMessageContentsAsync(chatHistory, cancellationToken, executionSettings, kernel); + => this.Client.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel, cancellationToken); #region private private Dictionary AttributesInternal { get; } = new(); diff --git a/dotnet/src/Connectors/Connectors.MistralAI/Services/MistralAITextEmbeddingGenerationService.cs b/dotnet/src/Connectors/Connectors.MistralAI/Services/MistralAITextEmbeddingGenerationService.cs index 018418f79184..8a6cca7ecfa0 100644 --- a/dotnet/src/Connectors/Connectors.MistralAI/Services/MistralAITextEmbeddingGenerationService.cs +++ b/dotnet/src/Connectors/Connectors.MistralAI/Services/MistralAITextEmbeddingGenerationService.cs @@ -45,7 +45,7 @@ public MistralAITextEmbeddingGenerationService(string modelId, string apiKey, Ur /// public Task>> GenerateEmbeddingsAsync(IList data, Kernel? kernel = null, CancellationToken cancellationToken = default) - => this.Client.GenerateEmbeddingsAsync(data, cancellationToken, executionSettings: null, kernel); + => this.Client.GenerateEmbeddingsAsync(data, executionSettings: null, kernel, cancellationToken); #region private private Dictionary AttributesInternal { get; } = [];