Skip to content

Commit

Permalink
feat: support tools in text generation
Browse files Browse the repository at this point in the history
  • Loading branch information
ikesnowy committed Mar 13, 2024
1 parent a15a621 commit 14a853d
Show file tree
Hide file tree
Showing 17 changed files with 263 additions and 7 deletions.
5 changes: 4 additions & 1 deletion src/Cnblogs.DashScope.Sdk/ChatMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ namespace Cnblogs.DashScope.Sdk;
/// <summary>
/// Represents a chat message between the user and the model.
/// </summary>
public record ChatMessage(string Role, string Content) : IMessage<string>;
/// <param name="Role">The role of this message.</param>
/// <param name="Content">The content of this message.</param>
/// <param name="ToolCalls">Calls to the function.</param>
public record ChatMessage(string Role, string Content, List<ToolCall>? ToolCalls = null) : IMessage<string>;
3 changes: 3 additions & 0 deletions src/Cnblogs.DashScope.Sdk/Cnblogs.DashScope.Sdk.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageTags>Cnblogs;Dashscope;AI;Sdk;Embedding;</PackageTags>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="JsonSchema.Net.Generation" Version="4.1.1" />
</ItemGroup>
</Project>
8 changes: 8 additions & 0 deletions src/Cnblogs.DashScope.Sdk/FunctionCall.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace Cnblogs.DashScope.Sdk;

/// <summary>
/// Represents a call to function.
/// </summary>
/// <param name="Name">Name of the function to call.</param>
/// <param name="Arguments">Arguments of this call, usually a json string.</param>
public record FunctionCall(string Name, string? Arguments);
11 changes: 11 additions & 0 deletions src/Cnblogs.DashScope.Sdk/FunctionDefinition.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using Json.Schema;

namespace Cnblogs.DashScope.Sdk;

/// <summary>
/// Definition of function that can be called by model.
/// </summary>
/// <param name="Name">The name of the function.</param>
/// <param name="Description">Descriptions about this function that help model to decide when to call this function.</param>
/// <param name="Parameters">The parameters JSON schema.</param>
public record FunctionDefinition(string Name, string Description, JsonSchema? Parameters);
5 changes: 5 additions & 0 deletions src/Cnblogs.DashScope.Sdk/ITextGenerationParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,9 @@ public interface ITextGenerationParameters : IIncrementalOutputParameter, ISeedP
/// Enable internet search when generation. Defaults to false.
/// </summary>
public bool? EnableSearch { get; }

/// <summary>
/// Available tools for model to call.
/// </summary>
public List<ToolDefinition>? Tools { get; }
}
5 changes: 5 additions & 0 deletions src/Cnblogs.DashScope.Sdk/TextGenerationInput.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,9 @@ public class TextGenerationInput
/// The collection of context messages associated with this chat completions request.
/// </summary>
public IEnumerable<ChatMessage>? Messages { get; set; }

/// <summary>
/// Available tools for model to use.
/// </summary>
public IEnumerable<ToolDefinition>? Tools { get; set; }
}
3 changes: 3 additions & 0 deletions src/Cnblogs.DashScope.Sdk/TextGenerationParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ public class TextGenerationParameters : ITextGenerationParameters
/// <inheritdoc />
public bool? EnableSearch { get; set; }

/// <inheritdoc />
public List<ToolDefinition>? Tools { get; set; }

/// <inheritdoc />
public bool? IncrementalOutput { get; set; }
}
9 changes: 9 additions & 0 deletions src/Cnblogs.DashScope.Sdk/ToolCall.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace Cnblogs.DashScope.Sdk;

/// <summary>
/// Represents a call to tool.
/// </summary>
/// <param name="Id">Id of this tool call.</param>
/// <param name="Type">Type of the tool.</param>
/// <param name="Function">Not null if type is function.</param>
public record ToolCall(string? Id, string Type, FunctionCall? Function);
8 changes: 8 additions & 0 deletions src/Cnblogs.DashScope.Sdk/ToolDefinition.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace Cnblogs.DashScope.Sdk;

/// <summary>
/// Definition of a tool that model can call during generation.
/// </summary>
/// <param name="Type">The type of this tool. Use <see cref="ToolTypes"/> to get all available options.</param>
/// <param name="Function">Not null when <paramref name="Type"/> is tool.</param>
public record ToolDefinition(string Type, FunctionDefinition? Function);
12 changes: 12 additions & 0 deletions src/Cnblogs.DashScope.Sdk/ToolTypes.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace Cnblogs.DashScope.Sdk;

/// <summary>
/// Available tool types for <see cref="ToolDefinition"/>.
/// </summary>
public static class ToolTypes
{
/// <summary>
/// Function type.
/// </summary>
public const string Function = "function";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
{
"model": "qwen-max",
"input": {
"messages": [
{
"role": "user",
"content": "杭州现在的天气如何?"
}
]
},
"parameters": {
"result_format": "message",
"seed": 1234,
"max_tokens": 1500,
"top_p": 0.8,
"top_k": 100,
"repetition_penalty": 1.1,
"temperature": 0.85,
"stop": [[37763, 367]],
"enable_search": false,
"incremental_output": false,
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "获取现在的天气",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "要获取天气的省市名称,例如浙江省杭州市"
},
"unit": {
"description": "温度单位",
"enum": [
"Celsius",
"Fahrenheit"
]
}
},
"required": [
"location"
]
}
}
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"output":{"choices":[{"finish_reason":"tool_calls","message":{"role":"assistant","tool_calls":[{"function":{"name":"get_current_weather","arguments":"{\"location\": \"浙江省杭州市\", \"unit\": \"Celsius\"}"},"id":"","type":"function"}],"content":""}}]},"usage":{"total_tokens":36,"output_tokens":31,"input_tokens":5},"request_id":"40b4361e-e936-91b5-879d-355a45d670f8"}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
HTTP/1.1 200 OK
eagleeye-traceid: 7328e5207abf69133abfe3a68446fc2d
content-type: application/json
x-dashscope-call-gateway: true
x-dashscope-experiments: 33e6d810-qwen-max-base-default-imbalance-fix-lua
req-cost-time: 3898
req-arrive-time: 1710324737299
resp-start-time: 1710324741198
x-envoy-upstream-service-time: 3893
content-encoding: gzip
vary: Accept-Encoding
date: Wed, 13 Mar 2024 10:12:21 GMT
server: istio-envoy
transfer-encoding: chunked
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,27 @@ public void Configuration_CustomSectionName_Inject()
httpClient.BaseAddress.Should().BeEquivalentTo(new Uri(ProxyApi));
}

[Fact]
public void Configuration_AddMultipleTime_Replace()
{
// Arrange
var services = new ServiceCollection();

// Act
services.AddDashScopeClient(ApiKey, ProxyApi);
services.AddDashScopeClient(ApiKey, ProxyApi);
var provider = services.BuildServiceProvider();
var httpClient = provider.GetRequiredService<IHttpClientFactory>().CreateClient(nameof(IDashScopeClient));

// Assert
provider.GetRequiredService<IDashScopeClient>().Should().NotBeNull().And
.BeOfType<DashScopeClientCore>();
httpClient.Should().NotBeNull();
httpClient.DefaultRequestHeaders.Authorization.Should()
.BeEquivalentTo(new AuthenticationHeaderValue("Bearer", ApiKey));
httpClient.BaseAddress.Should().BeEquivalentTo(new Uri(ProxyApi));
}

[Fact]
public void Configuration_NoApiKey_Throw()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@ public async Task SingleCompletion_TextFormatSse_SuccessAsync()
message.ToString().Should().Be(testCase.ResponseModel.Output.Text);
}

[Fact]
public async Task SingleCompletion_MessageFormatNoSse_SuccessAsync()
[Theory]
[MemberData(nameof(SingleGenerationMessageFormatData))]
public async Task SingleCompletion_MessageFormatNoSse_SuccessAsync(
RequestSnapshot<ModelRequest<TextGenerationInput, ITextGenerationParameters>,
ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>> testCase)
{
// Arrange
const bool sse = false;
var testCase = Snapshots.TextGeneration.MessageFormat.SingleMessage;
var (client, handler) = await Sut.GetTestClientAsync(sse, testCase);

// Act
Expand Down Expand Up @@ -83,7 +85,9 @@ public async Task SingleCompletion_MessageFormatSse_SuccessAsync()
Arg.Is<HttpRequestMessage>(m => Checkers.IsJsonEquivalent(m.Content!, testCase.GetRequestJson(sse))),
Arg.Any<CancellationToken>());
outputs.SkipLast(1).Should().AllSatisfy(x => x.Output.Choices![0].FinishReason.Should().Be("null"));
outputs.Last().Should().BeEquivalentTo(testCase.ResponseModel, o => o.Excluding(y => y.Output.Choices![0].Message.Content));
outputs.Last().Should().BeEquivalentTo(
testCase.ResponseModel,
o => o.Excluding(y => y.Output.Choices![0].Message.Content));
message.ToString().Should().Be(testCase.ResponseModel.Output.Choices![0].Message.Content);
}

Expand All @@ -105,7 +109,14 @@ public async Task ConversationCompletion_MessageFormatSse_SuccessAsync()
Arg.Is<HttpRequestMessage>(m => Checkers.IsJsonEquivalent(m.Content!, testCase.GetRequestJson(sse))),
Arg.Any<CancellationToken>());
outputs.SkipLast(1).Should().AllSatisfy(x => x.Output.Choices![0].FinishReason.Should().Be("null"));
outputs.Last().Should().BeEquivalentTo(testCase.ResponseModel, o => o.Excluding(y => y.Output.Choices![0].Message.Content));
outputs.Last().Should().BeEquivalentTo(
testCase.ResponseModel,
o => o.Excluding(y => y.Output.Choices![0].Message.Content));
message.ToString().Should().Be(testCase.ResponseModel.Output.Choices![0].Message.Content);
}

public static readonly TheoryData<RequestSnapshot<ModelRequest<TextGenerationInput, ITextGenerationParameters>,
ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>>> SingleGenerationMessageFormatData = new(
Snapshots.TextGeneration.MessageFormat.SingleMessage,
Snapshots.TextGeneration.MessageFormat.SingleMessageWithTools);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System.Text.Json.Serialization;
using Json.More;
using Json.Schema.Generation;

namespace Cnblogs.DashScope.Sdk.UnitTests.Utils;

public record GetCurrentWeatherParameters(
[property: Required]
[property: Description("要获取天气的省市名称,例如浙江省杭州市")]
string Location,
[property: JsonConverter(typeof(EnumStringConverter<TemperatureUnit>))]
[property: Description("温度单位")]
TemperatureUnit Unit = TemperatureUnit.Celsius);

public enum TemperatureUnit
{
Celsius,
Fahrenheit
}
74 changes: 73 additions & 1 deletion test/Cnblogs.DashScope.Sdk.UnitTests/Utils/Snapshots.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
namespace Cnblogs.DashScope.Sdk.UnitTests.Utils;
using Json.Schema;
using Json.Schema.Generation;

namespace Cnblogs.DashScope.Sdk.UnitTests.Utils;

public static class Snapshots
{
Expand Down Expand Up @@ -267,6 +270,75 @@ public static class MessageFormat
}
});

public static readonly
RequestSnapshot<ModelRequest<TextGenerationInput, ITextGenerationParameters>,
ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>> SingleMessageWithTools =
new(
"single-generation-message-with-tools",
new()
{
Model = "qwen-max",
Input = new() { Messages = [new("user", "杭州现在的天气如何?")] },
Parameters = new TextGenerationParameters()
{
ResultFormat = "message",
Seed = 1234,
MaxTokens = 1500,
TopP = 0.8f,
TopK = 100,
RepetitionPenalty = 1.1f,
Temperature = 0.85f,
Stop = new([[37763, 367]]),
EnableSearch = false,
IncrementalOutput = false,
Tools =
[
new ToolDefinition(
"function",
new FunctionDefinition(
"get_current_weather",
"获取现在的天气",
new JsonSchemaBuilder().FromType<GetCurrentWeatherParameters>(
new()
{
PropertyNameResolver = PropertyNameResolvers.LowerSnakeCase
})
.Build()))
]
}
},
new()
{
Output = new()
{
Choices =
[
new()
{
FinishReason = "tool_calls",
Message = new(
"assistant",
string.Empty,
[
new(
string.Empty,
ToolTypes.Function,
new(
"get_current_weather",
"""{"location": "浙江省杭州市", "unit": "Celsius"}"""))
])
}
]
},
RequestId = "40b4361e-e936-91b5-879d-355a45d670f8",
Usage = new()
{
InputTokens = 5,
OutputTokens = 31,
TotalTokens = 36
}
});

public static readonly RequestSnapshot<ModelRequest<TextGenerationInput, ITextGenerationParameters>,
ModelResponse<TextGenerationOutput, TextGenerationTokenUsage>>
ConversationMessageIncremental = new(
Expand Down

0 comments on commit 14a853d

Please sign in to comment.