Skip to content

Commit

Permalink
.Net: First step in moving to use function calling processor (#9562)
Browse files Browse the repository at this point in the history
### Motivation and Context

First step in updating the MistralAI connector to use the function
calling processor

### Description

1. Some code clean up
2. Add the FunctionCallsProcessor

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
markwallace-microsoft authored Nov 6, 2024
1 parent 3332079 commit 3a88dcf
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public async Task ValidateChatMessageRequestAsync()

// Act
var executionSettings = new MistralAIPromptExecutionSettings { MaxTokens = 1024, Temperature = 0.9 };
await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings);
await client.GetChatMessageContentsAsync(chatHistory, executionSettings);

// Assert
var request = this.DelegatingHandler!.RequestContent;
Expand All @@ -76,7 +76,7 @@ public async Task ValidateGetChatMessageContentsAsync()
{
new ChatMessageContent(AuthorRole.User, "What is the best French cheese?")
};
var response = await client.GetChatMessageContentsAsync(chatHistory, default);
var response = await client.GetChatMessageContentsAsync(chatHistory);

// Assert
Assert.NotNull(response);
Expand All @@ -96,7 +96,7 @@ public async Task ValidateGenerateEmbeddingsAsync()

// Act
List<string> data = ["Hello", "world"];
var response = await client.GenerateEmbeddingsAsync(data, default);
var response = await client.GenerateEmbeddingsAsync(data);

// Assert
Assert.NotNull(response);
Expand All @@ -117,7 +117,7 @@ public async Task ValidateGetStreamingChatMessageContentsAsync()
};

// Act
var response = client.GetStreamingChatMessageContentsAsync(chatHistory, default);
var response = client.GetStreamingChatMessageContentsAsync(chatHistory);
var chunks = new List<StreamingChatMessageContent>();
await foreach (var chunk in response)
{
Expand Down Expand Up @@ -150,7 +150,7 @@ public async Task ValidateChatHistoryFirstSystemOrUserMessageAsync()
};

// Act & Assert
await Assert.ThrowsAsync<ArgumentException>(async () => await client.GetChatMessageContentsAsync(chatHistory, default));
await Assert.ThrowsAsync<ArgumentException>(async () => await client.GetChatMessageContentsAsync(chatHistory));
}

[Fact]
Expand All @@ -161,7 +161,7 @@ public async Task ValidateEmptyChatHistoryAsync()
var chatHistory = new ChatHistory();

// Act & Assert
await Assert.ThrowsAsync<ArgumentException>(async () => await client.GetChatMessageContentsAsync(chatHistory, default));
await Assert.ThrowsAsync<ArgumentException>(async () => await client.GetChatMessageContentsAsync(chatHistory));
}

[Fact]
Expand All @@ -181,7 +181,7 @@ public async Task ValidateChatMessageRequestWithToolsAsync()
kernel.Plugins.AddFromType<WeatherPlugin>();

// Act
await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel);
await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel);

// Assert
var request = this.DelegatingHandler!.RequestContent;
Expand Down Expand Up @@ -212,7 +212,7 @@ public async Task ValidateGetStreamingChatMessageContentsWithToolsAsync()

// Act
var executionSettings = new MistralAIPromptExecutionSettings { ToolCallBehavior = MistralAIToolCallBehavior.AutoInvokeKernelFunctions };
var response = client.GetStreamingChatMessageContentsAsync(chatHistory, default, executionSettings, kernel);
var response = client.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel);
var chunks = new List<StreamingChatMessageContent>();
await foreach (var chunk in response)
{
Expand Down Expand Up @@ -253,7 +253,7 @@ public async Task ValidateGetChatMessageContentsWithFunctionCallAsync()
{
new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?")
};
var response = await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel);
var response = await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel);

// Assert
Assert.NotNull(response);
Expand All @@ -279,7 +279,7 @@ public async Task ValidateGetChatMessageContentsWithFunctionCallNoneAsync()
{
new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?")
};
var response = await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel);
var response = await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel);

// Assert
Assert.NotNull(response);
Expand Down Expand Up @@ -307,7 +307,7 @@ public async Task ValidateGetChatMessageContentsWithFunctionCallRequiredAsync()
{
new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?")
};
var response = await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel);
var response = await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel);

// Assert
Assert.NotNull(response);
Expand Down Expand Up @@ -345,7 +345,7 @@ public async Task ValidateGetChatMessageContentsWithFunctionInvocationFilterAsyn
{
new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?")
};
var response = await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel);
var response = await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel);

// Assert
Assert.NotNull(response);
Expand Down Expand Up @@ -389,11 +389,11 @@ public async Task FilterContextHasValidStreamingFlagAsync(bool isStreaming)

if (isStreaming)
{
await client.GetStreamingChatMessageContentsAsync(chatHistory, default, executionSettings, kernel).ToListAsync();
await client.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel).ToListAsync();
}
else
{
await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel);
await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel);
}

// Assert
Expand Down Expand Up @@ -428,7 +428,7 @@ public async Task ValidateGetChatMessageContentsWithAutoFunctionInvocationFilter
{
new ChatMessageContent(AuthorRole.User, "What is the weather like in Paris?")
};
var response = await client.GetChatMessageContentsAsync(chatHistory, default, executionSettings, kernel);
var response = await client.GetChatMessageContentsAsync(chatHistory, executionSettings, kernel);

// Assert
Assert.NotNull(response);
Expand Down Expand Up @@ -465,7 +465,7 @@ public async Task ValidateGetStreamingChatMessageContentWithAutoFunctionInvocati
List<StreamingKernelContent> streamingContent = [];

// Act
await foreach (var item in client.GetStreamingChatMessageContentsAsync(chatHistory, default, executionSettings, kernel))
await foreach (var item in client.GetStreamingChatMessageContentsAsync(chatHistory, executionSettings, kernel))
{
streamingContent.Add(item);
}
Expand Down Expand Up @@ -497,7 +497,7 @@ public void ValidateToMistralChatMessages(string roleLabel, string content)
};

// Act
var messages = client.ToMistralChatMessages(chatMessage, default);
var messages = client.ToMistralChatMessages(chatMessage);

// Assert
Assert.NotNull(messages);
Expand All @@ -517,7 +517,7 @@ public void ValidateToMistralChatMessagesWithFunctionCallContent()
};

// Act
var messages = client.ToMistralChatMessages(content, default);
var messages = client.ToMistralChatMessages(content);

// Assert
Assert.NotNull(messages);
Expand All @@ -537,7 +537,7 @@ public void ValidateToMistralChatMessagesWithFunctionResultContent()
};

// Act
var messages = client.ToMistralChatMessages(content, default);
var messages = client.ToMistralChatMessages(content);

// Assert
Assert.NotNull(messages);
Expand Down
Loading

0 comments on commit 3a88dcf

Please sign in to comment.