Skip to content

Commit

Permalink
com.openai.unity 7.7.5 (#205)
Browse files Browse the repository at this point in the history
- Allow FunctionPropertyAttribute to be assignable to fields
- Updated Function schema generation
  - Fall back to complex types, and use $ref for discovered types
  - Fixed schema generation to properly assign unsigned integer types
  • Loading branch information
StephenHodgson authored Mar 3, 2024
1 parent 43212ee commit 052512c
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ namespace OpenAI
[AttributeUsage(AttributeTargets.Parameter)]
public sealed class FunctionParameterAttribute : Attribute
{
/// <summary>
/// Function parameter attribute to help describe the parameter for the function.
/// </summary>
/// <param name="description">The description of the parameter and its usage.</param>
public FunctionParameterAttribute(string description)
{
Description = description;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace OpenAI
{
[AttributeUsage(AttributeTargets.Property)]
[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field)]
public sealed class FunctionPropertyAttribute : Attribute
{
/// <summary>
Expand Down
177 changes: 128 additions & 49 deletions OpenAI/Packages/com.openai.unity/Runtime/Extensions/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public static JObject GenerateJsonSchema(this MethodInfo methodInfo)
requiredParameters.Add(parameter.Name);
}

schema["properties"]![parameter.Name] = GenerateJsonSchema(parameter.ParameterType);
schema["properties"]![parameter.Name] = GenerateJsonSchema(parameter.ParameterType, schema);

var functionParameterAttribute = parameter.GetCustomAttribute<FunctionParameterAttribute>();

Expand All @@ -62,12 +62,57 @@ public static JObject GenerateJsonSchema(this MethodInfo methodInfo)
return schema;
}

public static JObject GenerateJsonSchema(this Type type)
public static JObject GenerateJsonSchema(this Type type, JObject rootSchema)
{
var schema = new JObject();
var serializer = JsonSerializer.Create(OpenAIClient.JsonSerializationOptions);

if (type.IsEnum)
if (!type.IsPrimitive &&
type != typeof(Guid) &&
type != typeof(DateTime) &&
type != typeof(DateTimeOffset) &&
rootSchema["definitions"] != null &&
((JObject)rootSchema["definitions"]).ContainsKey(type.FullName))
{
return new JObject { ["$ref"] = $"#/definitions/{type.FullName}" };
}

if (type == typeof(string))
{
schema["type"] = "string";
}
else if (type == typeof(int) ||
type == typeof(long) ||
type == typeof(uint) ||
type == typeof(byte) ||
type == typeof(sbyte) ||
type == typeof(ulong) ||
type == typeof(short) ||
type == typeof(ushort))
{
schema["type"] = "integer";
}
else if (type == typeof(float) ||
type == typeof(double) ||
type == typeof(decimal))
{
schema["type"] = "number";
}
else if (type == typeof(bool))
{
schema["type"] = "boolean";
}
else if (type == typeof(DateTime) || type == typeof(DateTimeOffset))
{
schema["type"] = "string";
schema["format"] = "date-time";
}
else if (type == typeof(Guid))
{
schema["type"] = "string";
schema["format"] = "uuid";
}
else if (type.IsEnum)
{
schema["type"] = "string";
schema["enum"] = new JArray();
Expand All @@ -80,21 +125,54 @@ public static JObject GenerateJsonSchema(this Type type)
else if (type.IsArray || (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(List<>)))
{
schema["type"] = "array";
schema["items"] = GenerateJsonSchema(type.GetElementType() ?? type.GetGenericArguments()[0]);
var elementType = type.GetElementType() ?? type.GetGenericArguments()[0];

if (rootSchema["definitions"] != null &&
((JObject)rootSchema["definitions"]).ContainsKey(elementType.FullName))
{
schema["items"] = new JObject { ["$ref"] = $"#/definitions/{elementType.FullName}" };
}
else
{
schema["items"] = GenerateJsonSchema(elementType, rootSchema);
}
}
else if (type.IsClass && type != typeof(string))
else
{
schema["type"] = "object";
var properties = type.GetProperties();
var propertiesInfo = new JObject();
var requiredProperties = new JArray();
rootSchema["definitions"] ??= new JObject();
rootSchema["definitions"][type.FullName] = new JObject();

var properties = type.GetProperties(BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly);
var fields = type.GetFields(BindingFlags.Public | BindingFlags.Instance | BindingFlags.DeclaredOnly);
var members = new List<MemberInfo>(properties.Length + fields.Length);
members.AddRange(properties);
members.AddRange(fields);

var memberInfo = new JObject();
var requiredMembers = new JArray();

foreach (var property in properties)
foreach (var member in members)
{
var propertyInfo = GenerateJsonSchema(property.PropertyType);
var functionPropertyAttribute = property.GetCustomAttribute<FunctionPropertyAttribute>();
var jsonPropertyAttribute = property.GetCustomAttribute<JsonPropertyAttribute>();
var propertyName = jsonPropertyAttribute?.PropertyName ?? property.Name;
var memberType = GetMemberType(member);
var functionPropertyAttribute = member.GetCustomAttribute<FunctionPropertyAttribute>();
var jsonPropertyAttribute = member.GetCustomAttribute<JsonPropertyAttribute>();
var propertyName = jsonPropertyAttribute?.PropertyName ?? member.Name;

// skip unity engine property for Items
if (memberType == typeof(float) && propertyName.Equals("Item")) { continue; }

JObject propertyInfo;

if (rootSchema["definitions"] != null &&
((JObject)rootSchema["definitions"]).ContainsKey(memberType.FullName))
{
propertyInfo = new JObject { ["$ref"] = $"#/definitions/{memberType.FullName}" };
}
else
{
propertyInfo = GenerateJsonSchema(memberType, rootSchema);
}

// override properties with values from function property attribute
if (functionPropertyAttribute != null)
Expand All @@ -103,7 +181,7 @@ public static JObject GenerateJsonSchema(this Type type)

if (functionPropertyAttribute.Required)
{
requiredProperties.Add(propertyName);
requiredMembers.Add(propertyName);
}

JToken defaultValue = null;
Expand Down Expand Up @@ -143,52 +221,53 @@ public static JObject GenerateJsonSchema(this Type type)
propertyInfo["enum"] = enums;
}
}
else if (Nullable.GetUnderlyingType(property.PropertyType) == null)
else if (jsonPropertyAttribute != null)
{
requiredProperties.Add(propertyName);
switch (jsonPropertyAttribute.Required)
{
case Required.Always:
case Required.AllowNull:
requiredMembers.Add(propertyName);
break;
case Required.Default:
case Required.DisallowNull:
default:
requiredMembers.Remove(propertyName);
break;
}
}
else if (Nullable.GetUnderlyingType(memberType) == null)
{
if (member is FieldInfo)
{
requiredMembers.Add(propertyName);
}
}

propertiesInfo[propertyName] = propertyInfo;
memberInfo[propertyName] = propertyInfo;
}

schema["properties"] = propertiesInfo;
schema["properties"] = memberInfo;

if (requiredProperties.Count > 0)
{
schema["required"] = requiredProperties;
}
}
else
{
if (type == typeof(int) || type == typeof(long) || type == typeof(short) || type == typeof(byte))
if (requiredMembers.Count > 0)
{
schema["type"] = "integer";
}
else if (type == typeof(float) || type == typeof(double) || type == typeof(decimal))
{
schema["type"] = "number";
}
else if (type == typeof(bool))
{
schema["type"] = "boolean";
}
else if (type == typeof(DateTime) || type == typeof(DateTimeOffset))
{
schema["type"] = "string";
schema["format"] = "date-time";
}
else if (type == typeof(Guid))
{
schema["type"] = "string";
schema["format"] = "uuid";
}
else
{
schema["type"] = type.Name.ToLower();
schema["required"] = requiredMembers;
}

rootSchema["definitions"] ??= new JObject();
rootSchema["definitions"][type.FullName] = schema;
return new JObject { ["$ref"] = $"#/definitions/{type.FullName}" };
}

return schema;
}

private static Type GetMemberType(MemberInfo member)
=> member switch
{
FieldInfo fieldInfo => fieldInfo.FieldType,
PropertyInfo propertyInfo => propertyInfo.PropertyType,
_ => throw new ArgumentException($"{nameof(MemberInfo)} must be of type {nameof(FieldInfo)}, {nameof(PropertyInfo)}", nameof(member))
};
}
}
1 change: 1 addition & 0 deletions OpenAI/Packages/com.openai.unity/Runtime/OpenAIClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ protected override void ValidateAuthentication()
{
NullValueHandling = NullValueHandling.Ignore,
DefaultValueHandling = DefaultValueHandling.Ignore,
ReferenceLoopHandling = ReferenceLoopHandling.Ignore,
Converters = new List<JsonConverter>
{
new StringEnumConverter(new SnakeCaseNamingStrategy())
Expand Down
19 changes: 10 additions & 9 deletions OpenAI/Packages/com.openai.unity/Tests/TestFixture_00_02_Tools.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using NUnit.Framework;
using OpenAI.Images;
using OpenAI.Tests.Weather;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
Expand All @@ -21,6 +22,7 @@ public void Test_01_GetTools()
Assert.IsNotNull(tools);
Assert.IsNotEmpty(tools);
tools.Add(Tool.GetOrCreateTool(OpenAIClient.ImagesEndPoint, nameof(ImagesEndpoint.GenerateImageAsync)));
tools.Add(Tool.FromFunc<GameObject, Vector2, Vector3, Quaternion, string>("complex_objects", (gameObject, vector2, vector3, quaternion) => "success"));
var json = JsonConvert.SerializeObject(tools, Formatting.Indented, OpenAIClient.JsonSerializationOptions);
Debug.Log(json);
}
Expand All @@ -31,7 +33,7 @@ public async Task Test_02_Tool_Funcs()
var tools = new List<Tool>
{
Tool.FromFunc("test_func", Function),
Tool.FromFunc<string, string, string>("test_func_with_args", FunctionWithArgs),
Tool.FromFunc<DateTime, Vector3, string>("test_func_with_args", FunctionWithArgs),
Tool.FromFunc("test_func_weather", () => WeatherService.GetCurrentWeatherAsync("my location", WeatherService.WeatherUnit.Celsius))
};

Expand All @@ -44,13 +46,12 @@ public async Task Test_02_Tool_Funcs()
Assert.AreEqual("success", result);
var toolWithArgs = tools[1];
Assert.IsNotNull(toolWithArgs);
toolWithArgs.Function.Arguments = new JObject
{
["arg1"] = "arg1",
["arg2"] = "arg2"
};
var testValue = new { arg1 = DateTime.UtcNow, arg2 = Vector3.one };
toolWithArgs.Function.Arguments = JToken.FromObject(testValue, JsonSerializer.Create(OpenAIClient.JsonSerializationOptions));
var resultWithArgs = toolWithArgs.InvokeFunction<string>();
Assert.AreEqual("arg1 arg2", resultWithArgs);
Debug.Log(resultWithArgs);
var testResult = JsonConvert.DeserializeObject(resultWithArgs, testValue.GetType(), OpenAIClient.JsonSerializationOptions);
Assert.AreEqual(testResult, testValue);

var toolWeather = tools[2];
Assert.IsNotNull(toolWeather);
Expand All @@ -64,9 +65,9 @@ private string Function()
return "success";
}

private string FunctionWithArgs(string arg1, string arg2)
private string FunctionWithArgs(DateTime arg1, Vector3 arg2)
{
return $"{arg1} {arg2}";
return JsonConvert.SerializeObject(new { arg1, arg2 }, OpenAIClient.JsonSerializationOptions);
}
}
}
4 changes: 2 additions & 2 deletions OpenAI/Packages/com.openai.unity/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"displayName": "OpenAI",
"description": "A OpenAI package for the Unity Game Engine to use GPT-4, GPT-3.5, GPT-3 and Dall-E though their RESTful API (currently in beta).\n\nIndependently developed, this is not an official library and I am not affiliated with OpenAI.\n\nAn OpenAI API account is required.",
"keywords": [],
"version": "7.7.4",
"version": "7.7.5",
"unity": "2021.3",
"documentationUrl": "https://github.com/RageAgainstThePixel/com.openai.unity#documentation",
"changelogUrl": "https://github.com/RageAgainstThePixel/com.openai.unity/releases",
Expand All @@ -17,7 +17,7 @@
"url": "https://github.com/StephenHodgson"
},
"dependencies": {
"com.utilities.rest": "2.5.3",
"com.utilities.rest": "2.5.4",
"com.utilities.encoder.wav": "1.1.5"
},
"samples": [
Expand Down
4 changes: 2 additions & 2 deletions OpenAI/Packages/manifest.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{
"dependencies": {
"com.unity.ide.rider": "3.0.27",
"com.unity.ide.rider": "3.0.28",
"com.unity.ide.visualstudio": "2.0.22",
"com.unity.textmeshpro": "3.0.8",
"com.utilities.buildpipeline": "1.2.2"
"com.utilities.buildpipeline": "1.2.3"
},
"scopedRegistries": [
{
Expand Down

0 comments on commit 052512c

Please sign in to comment.