Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Add JsonElement String to Primitive Implicit Conversion Support (SLM Function Calling) #9784

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
14 changes: 8 additions & 6 deletions dotnet/samples/Demos/OllamaFunctionCalling/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
var settings = new OllamaPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.Auto() };

Console.WriteLine("Ask questions or give instructions to the copilot such as:\n" +
"- Change the alarm to 8\n" +
"- What is the current alarm set?\n" +
"- Is the light on?\n" +
"- Turn the light off please.\n" +
"- Set an alarm for 6:00 am.\n");
Console.WriteLine("""
Ask questions or give instructions to the copilot such as:
- Change the alarm to 8
- What is the current alarm set?
- Is the light on?
- Turn the light off please.
- Set an alarm for 6:00 am.
""");

Console.Write("> ");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<IsTestProject>true</IsTestProject>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<NoWarn>$(NoWarn);SKEXP0001;SKEXP0070;CS1591;IDE1006;RCS1261;CA1031;CA1308;CA1861;CA2007;CA2234;VSTHRD111</NoWarn>
<NoWarn>$(NoWarn);SYSLIB1222;SKEXP0001;SKEXP0070;CS1591;IDE1006;RCS1261;CA1031;CA1308;CA1861;CA2007;CA2234;VSTHRD111</NoWarn>
</PropertyGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
<RootNamespace>$(AssemblyName)</RootNamespace>
<TargetFrameworks>net8.0;netstandard2.0</TargetFrameworks>
<VersionSuffix>alpha</VersionSuffix>
<NoWarn>$(NoWarn);SYSLIB1222</NoWarn>
<IsAotCompatible Condition="$([MSBuild]::IsTargetFrameworkCompatible('$(TargetFramework)', 'net7.0'))">true</IsAotCompatible>
</PropertyGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,22 @@ private static (Func<KernelFunction, Kernel, KernelArguments, CancellationToken,

var converter = GetConverter(type);

var jsonStringParsers = new Dictionary<Type, Func<string, object>>(12)
{
{ typeof(bool), s => bool.Parse(s) },
{ typeof(int), s => int.Parse(s) },
{ typeof(uint), s => uint.Parse(s) },
{ typeof(long), s => long.Parse(s) },
{ typeof(ulong), s => ulong.Parse(s) },
{ typeof(float), s => float.Parse(s) },
{ typeof(double), s => double.Parse(s) },
{ typeof(decimal), s => decimal.Parse(s) },
{ typeof(short), s => short.Parse(s) },
{ typeof(ushort), s => ushort.Parse(s) },
{ typeof(byte), s => byte.Parse(s) },
{ typeof(sbyte), s => sbyte.Parse(s) }
};

object? parameterFunc(KernelFunction _, Kernel kernel, KernelArguments arguments, CancellationToken __)
{
// 1. Use the value of the variable if it exists.
Expand All @@ -710,26 +726,34 @@ private static (Func<KernelFunction, Kernel, KernelArguments, CancellationToken,

object? Process(object? value)
{
if (!type.IsAssignableFrom(value?.GetType()))
if (type.IsAssignableFrom(value?.GetType()))
{
if (converter is not null)
return value;
}

if (converter is not null && value is not JsonElement or JsonDocument or JsonNode)
{
try
{
try
{
return converter(value, kernel.Culture);
}
catch (Exception e) when (!e.IsCriticalException())
{
throw new ArgumentOutOfRangeException(name, value, e.Message);
}
return converter(value, kernel.Culture);
}

if (value is not null && TryToDeserializeValue(value, type, jsonSerializerOptions, out var deserializedValue))
catch (Exception e) when (!e.IsCriticalException())
{
return deserializedValue;
throw new ArgumentOutOfRangeException(name, value, e.Message);
}
}

if (value is JsonElement element && element.ValueKind == JsonValueKind.String
&& jsonStringParsers.TryGetValue(type, out var jsonStringParser))
{
return jsonStringParser(element.GetString()!);
}

if (value is not null && TryToDeserializeValue(value, type, jsonSerializerOptions, out var deserializedValue))
{
return deserializedValue;
}

return value;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,66 @@ public async Task ItSupportsArgumentsImplicitConversionAsync()
await function.InvokeAsync(this._kernel, arguments);
}

[Fact]
public async Task ItSupportsJsonElementArgumentsImplicitConversionAsync()
{
//Arrange
var arguments = new KernelArguments()
{
["l"] = JsonSerializer.Deserialize<JsonElement>((long)1), //Passed to parameter of type long
["i"] = JsonSerializer.Deserialize<JsonElement>((byte)1), //Passed to parameter of type int
["d"] = JsonSerializer.Deserialize<JsonElement>((float)1.0), //Passed to parameter of type double
["f"] = JsonSerializer.Deserialize<JsonElement>((uint)1.0), //Passed to parameter of type float
["g"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize(new Guid("35626209-b0ab-458c-bfc4-43e6c7bd13dc"))), //Passed to parameter of type string
["dof"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize(DayOfWeek.Thursday)), //Passed to parameter of type int
["b"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("true")), //Passed to parameter of type bool
};

var function = KernelFunctionFactory.CreateFromMethod((long l, int i, double d, float f, string g, int dof, bool b) =>
{
Assert.Equal(1, l);
Assert.Equal(1, i);
Assert.Equal(1.0, d);
Assert.Equal("35626209-b0ab-458c-bfc4-43e6c7bd13dc", g);
Assert.Equal(4, dof);
Assert.True(b);
},
functionName: "Test");

await function.InvokeAsync(this._kernel, arguments);
await function.AsAIFunction().InvokeAsync(arguments);
}

[Fact]
public async Task ItSupportsStringJsonElementArgumentsImplicitConversionAsync()
{
//Arrange
var arguments = new KernelArguments()
{
["l"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("1")), //Passed to parameter of type long
["i"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("1")), //Passed to parameter of type int
["d"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("1.0")), //Passed to parameter of type double
["f"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("1.0")), //Passed to parameter of type float
["g"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("35626209-b0ab-458c-bfc4-43e6c7bd13dc")), //Passed to parameter of type Guid
["dof"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("4")), //Passed to parameter of type int
["b"] = JsonSerializer.Deserialize<JsonElement>(JsonSerializer.Serialize("false")), //Passed to parameter of type bool
};

var function = KernelFunctionFactory.CreateFromMethod((long l, int i, double d, float f, Guid g, int dof, bool b) =>
{
Assert.Equal(1, l);
Assert.Equal(1, i);
Assert.Equal(1.0, d);
Assert.Equal(new Guid("35626209-b0ab-458c-bfc4-43e6c7bd13dc"), g);
Assert.Equal(4, dof);
Assert.False(b);
},
functionName: "Test");

await function.InvokeAsync(this._kernel, arguments);
await function.AsAIFunction().InvokeAsync(arguments);
}

[Fact]
public async Task ItSupportsParametersWithDefaultValuesAsync()
{
Expand Down
Loading