diff --git a/Delete-BIN-OBJ-Folders.bat b/Delete-BIN-OBJ-Folders.bat new file mode 100644 index 0000000..ae50a95 --- /dev/null +++ b/Delete-BIN-OBJ-Folders.bat @@ -0,0 +1,20 @@ +@ECHO off +cls + +ECHO Deleting all BIN and OBJ folders... +ECHO. + +FOR /d /r . %%d in (bin,obj,node_modules,Logs) DO ( + IF EXIST "%%d" ( + ECHO %%d | FIND /I "\node_modules\" > Nul && ( + ECHO.Skipping: %%d + ) || ( + ECHO.Deleting: %%d + rd /s/q "%%d" + ) + ) +) + +ECHO. +ECHO.BIN and OBJ folders have been successfully deleted. Press any key to exit. +pause > nul \ No newline at end of file diff --git a/Google_GenerativeAI.sln b/Google_GenerativeAI.sln new file mode 100644 index 0000000..c2a427e --- /dev/null +++ b/Google_GenerativeAI.sln @@ -0,0 +1,53 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio Version 17 +VisualStudioVersion = 17.7.34202.233 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{AC161F1D-EC76-48D2-86A3-B52584618D49}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "GenerativeAI", "src\GenerativeAI\GenerativeAI.csproj", "{3B56ADE1-2A08-4774-A7C9-7CDB5EE5735B}" +EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tests", "tests", "{FCCDE15A-B121-4D6C-BD56-D1B043A26F18}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "GenerativeAI.Tests", "tests\GenerativeAI.Tests\GenerativeAI.Tests.csproj", "{734EA3CA-DB49-4FD3-ABB5-E185231B1818}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "GenerativeAI.Generators", "src\GenerativeAI.Generators\GenerativeAI.Generators.csproj", "{6B247114-6727-4CE5-8C9F-0C361776589E}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GenerativeAI.IntegrationTests", "tests\GenerativeAI.IntegrationTests\GenerativeAI.IntegrationTests.csproj", "{3C094E6B-A3FB-43BA-920B-0E27DB3F1B3F}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {3B56ADE1-2A08-4774-A7C9-7CDB5EE5735B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3B56ADE1-2A08-4774-A7C9-7CDB5EE5735B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3B56ADE1-2A08-4774-A7C9-7CDB5EE5735B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3B56ADE1-2A08-4774-A7C9-7CDB5EE5735B}.Release|Any CPU.Build.0 = Release|Any CPU + {734EA3CA-DB49-4FD3-ABB5-E185231B1818}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {734EA3CA-DB49-4FD3-ABB5-E185231B1818}.Debug|Any CPU.Build.0 = Debug|Any CPU + {734EA3CA-DB49-4FD3-ABB5-E185231B1818}.Release|Any CPU.ActiveCfg = Release|Any CPU + {734EA3CA-DB49-4FD3-ABB5-E185231B1818}.Release|Any CPU.Build.0 = Release|Any CPU + {6B247114-6727-4CE5-8C9F-0C361776589E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6B247114-6727-4CE5-8C9F-0C361776589E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6B247114-6727-4CE5-8C9F-0C361776589E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6B247114-6727-4CE5-8C9F-0C361776589E}.Release|Any CPU.Build.0 = Release|Any CPU + {3C094E6B-A3FB-43BA-920B-0E27DB3F1B3F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {3C094E6B-A3FB-43BA-920B-0E27DB3F1B3F}.Debug|Any CPU.Build.0 = Debug|Any CPU + {3C094E6B-A3FB-43BA-920B-0E27DB3F1B3F}.Release|Any CPU.ActiveCfg = Release|Any CPU + {3C094E6B-A3FB-43BA-920B-0E27DB3F1B3F}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {3B56ADE1-2A08-4774-A7C9-7CDB5EE5735B} = {AC161F1D-EC76-48D2-86A3-B52584618D49} + {734EA3CA-DB49-4FD3-ABB5-E185231B1818} = {FCCDE15A-B121-4D6C-BD56-D1B043A26F18} + {6B247114-6727-4CE5-8C9F-0C361776589E} = {AC161F1D-EC76-48D2-86A3-B52584618D49} + {3C094E6B-A3FB-43BA-920B-0E27DB3F1B3F} = {FCCDE15A-B121-4D6C-BD56-D1B043A26F18} + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {FFF3E8BB-BACD-4376-8E33-55D6E8A30BE0} + EndGlobalSection +EndGlobal diff --git a/README.md b/README.md index bab78af..9007a0e 100644 --- a/README.md +++ b/README.md @@ -1 +1,91 @@ -# Google_GenerativeAI \ No newline at end of file +# Google GenerativeAI (Gemini) + + + + +- [Google GenerativeAI (Gemini)](#google-generativeai-gemini) + - [Usage](#usage) + - [Quick Start](#quick-start) + - [Function Calling](#function-calling) + + + + +Unofficial C# SDK based on Google GenerativeAI (Gemini Pro) REST APIs. + +This package includes C# Source Generator which allows you to define functions natively through a C# interface, +and also provides extensions that make it easier to call this interface later. +In addition to easy function implementation and readability, +it generates Args classes, extension methods to easily pass a functions to API, +and extension methods to simply call a function via json and return json. +Currently only System.Text.Json is supported. + +### Usage + +### Quick Start + +1) [Obtain an API](https://makersuite.google.com/app/apikey) key to use with the Google AI SDKs. + +```csharp + var apiKey = 'Your API Key'; + + var model = new GenerativeModel(apiKey); + + var res = await model.GenerateContentAsync("How are you doing?"); + +``` + +### Function Calling + +```csharp +using GenerativeAI; + +public enum Unit +{ + Celsius, + Fahrenheit, +} + +public class Weather +{ + public string Location { get; set; } = string.Empty; + public double Temperature { get; set; } + public Unit Unit { get; set; } + public string Description { get; set; } = string.Empty; +} + +[GenerativeAIFunctions] +public interface IWeatherFunctions +{ + [Description("Get the current weather in a given location")] + public Task GetCurrentWeatherAsync( + [Description("The city and state, e.g. San Francisco, CA")] string location, + Unit unit = Unit.Celsius, + CancellationToken cancellationToken = default); +} + +public class WeatherService : IWeatherFunctions +{ + public Task GetCurrentWeatherAsync(string location, Unit unit = Unit.Celsius, CancellationToken cancellationToken = default) + { + return Task.FromResult(new Weather + { + Location = location, + Temperature = 22.0, + Unit = unit, + Description = "Sunny", + }); + } +} + + WeatherService service = new WeatherService(); + + var apiKey = Environment.GetEnvironmentVariable("Gemini_API_Key", EnvironmentVariableTarget.User); + + var model = new GenerativeModel(apiKey, functions:service.AsFunctions(),calls:service.AsCalls()); + + var result = await model.GenerateContentAsync("What is the weather in San Francisco today?"); + + Console.WriteLine(result); +``` + diff --git a/src/GenerativeAI.Generators/GenerativeAI.Generators.csproj b/src/GenerativeAI.Generators/GenerativeAI.Generators.csproj new file mode 100644 index 0000000..13bee3f --- /dev/null +++ b/src/GenerativeAI.Generators/GenerativeAI.Generators.csproj @@ -0,0 +1,23 @@ + + + + netstandard2.0 + false + false + true + true + 10 + $(NoWarn);CA1014;CA1031;CA1308 + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + \ No newline at end of file diff --git a/src/GenerativeAI.Generators/Generators/GenerativeAiFunctionsGenerator.cs b/src/GenerativeAI.Generators/Generators/GenerativeAiFunctionsGenerator.cs new file mode 100644 index 0000000..319bc82 --- /dev/null +++ b/src/GenerativeAI.Generators/Generators/GenerativeAiFunctionsGenerator.cs @@ -0,0 +1,235 @@ +using System; +using System.ComponentModel; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace H.Generators; + +[Generator] +public class GenerativeAiFunctionsGenerator : IIncrementalGenerator +{ + #region Constants + + public const string Name = nameof(GenerativeAiFunctionsGenerator); + public const string Id = "OAFG"; + + #endregion + + #region Methods + + public void Initialize(IncrementalGeneratorInitializationContext context) + { + context.SyntaxProvider + .ForAttributeWithMetadataName("GenerativeAI.GenerativeAIFunctionsAttribute") + .SelectManyAllAttributesOfCurrentInterfaceSyntax() + .SelectAndReportExceptions(PrepareData, context, Id) + .SelectAndReportExceptions(GetClientSourceCode, context, Id) + .AddSource(context); + } + + private static string GetDescription(ISymbol symbol) + { + return symbol.GetAttributes() + .FirstOrDefault(static x => x.AttributeClass?.Name == nameof(DescriptionAttribute))? + .ConstructorArguments.First().Value?.ToString() ?? string.Empty; + } + + private static InterfaceData PrepareData( + (SemanticModel SemanticModel, AttributeData AttributeData, InterfaceDeclarationSyntax InterfaceSyntax, INamedTypeSymbol InterfaceSymbol) tuple) + { + var (_, _, _, interfaceSymbol) = tuple; + + var methods = interfaceSymbol + .GetMembers() + .OfType() + .Where(static x => x.MethodKind == MethodKind.Ordinary) + .Select(static x => new MethodData( + Name: x.Name, + Description: GetDescription(x), + IsAsync: x.IsAsync || x.ReturnType.Name == "Task", + IsVoid: x.ReturnsVoid, + Parameters: x.Parameters + .Where(static x => x.Type.MetadataName != "CancellationToken") + .Select(static x => ToParameterData( + typeSymbol: x.Type, + name: x.Name, + description: GetDescription(x), + isRequired: !x.IsOptional)) + .ToArray())) + .ToArray(); + + return new InterfaceData( + Namespace: interfaceSymbol.ContainingNamespace.ToDisplayString(), + Name: interfaceSymbol.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat), + Methods: methods); + } + + private static ParameterData ToParameterData(ITypeSymbol typeSymbol, string? name = null, string? description = null, bool isRequired = true) + { + string schemaType; + string? format = null; + var properties = Array.Empty(); + ParameterData? arrayItem = null; + switch (typeSymbol.TypeKind) + { + case TypeKind.Enum: + schemaType = "string"; + break; + + case TypeKind.Structure: + switch (typeSymbol.SpecialType) + { + case SpecialType.System_Int32: + schemaType = "integer"; + format = "int32"; + break; + + case SpecialType.System_Int64: + schemaType = "integer"; + format = "int64"; + break; + + case SpecialType.System_Single: + schemaType = "number"; + format = "double"; + break; + + case SpecialType.System_Double: + schemaType = "number"; + format = "float"; + break; + + case SpecialType.System_DateTime: + schemaType = "string"; + format = "date-time"; + break; + + case SpecialType.System_Boolean: + schemaType = "boolean"; + break; + + case SpecialType.None: + switch (typeSymbol.Name) + { + case "DateOnly": + schemaType = "string"; + format = "date"; + break; + + default: + throw new NotImplementedException($"{typeSymbol.Name} is not implemented."); + } + break; + + default: + throw new NotImplementedException($"{typeSymbol.SpecialType} is not implemented."); + } + break; + + case TypeKind.Class: + switch (typeSymbol.SpecialType) + { + case SpecialType.System_String: + schemaType = "string"; + break; + + + case SpecialType.None: + schemaType = "object"; + properties = typeSymbol.GetMembers() + .OfType() + .Select(static y => ToParameterData( + typeSymbol: y.Type, + name: y.Name, + description: GetDescription(y), + isRequired: true)) + .ToArray(); + break; + + default: + throw new NotImplementedException($"{typeSymbol.SpecialType} is not implemented."); + } + break; + + case TypeKind.Interface when typeSymbol.MetadataName == "IReadOnlyCollection`1": + schemaType = "array"; + arrayItem = (typeSymbol as INamedTypeSymbol)?.TypeArguments + .Select(static y => ToParameterData(y)) + .FirstOrDefault(); + break; + + case TypeKind.Array: + schemaType = "array"; + arrayItem = ToParameterData((typeSymbol as IArrayTypeSymbol)?.ElementType!); + break; + + default: + throw new NotImplementedException($"{typeSymbol.TypeKind} is not implemented."); + } + + return new ParameterData( + Name: !string.IsNullOrWhiteSpace(name) + ? name! + : typeSymbol.Name, + Description: !string.IsNullOrWhiteSpace(description) + ? description! + : GetDescription(typeSymbol), + Type: typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + DefaultValue: GetDefaultValue(typeSymbol), + SchemaType: schemaType, + Format: format, + Properties: properties, + ArrayItem: arrayItem != null + ? new []{ arrayItem.Value } + : Array.Empty(), + EnumValues: typeSymbol.TypeKind == TypeKind.Enum + ? typeSymbol + .GetMembers() + .OfType() + .Select(static x => x.Name.ToLowerInvariant()) + .ToArray() + : Array.Empty(), + IsNullable: IsNullable(typeSymbol), + IsRequired: isRequired); + } + + private static bool IsNullable(ITypeSymbol typeSymbol) + { + if (typeSymbol.TypeKind == TypeKind.Enum) + { + return false; + } + if (typeSymbol.TypeKind == TypeKind.Structure) + { + return false; + } + + return typeSymbol.SpecialType switch + { + SpecialType.System_String => false, + _ => true, + }; + } + + private static string GetDefaultValue(ITypeSymbol typeSymbol) + { + switch (typeSymbol.SpecialType) + { + case SpecialType.System_String: + return "string.Empty"; + + default: + return string.Empty; + } + } + + private static FileWithName GetClientSourceCode(InterfaceData @interface) + { + return new FileWithName( + Name: $"{@interface.Name}.Functions.generated.cs", + Text: SourceGenerationHelper.GenerateClientImplementation(@interface)); + } + + #endregion +} \ No newline at end of file diff --git a/src/GenerativeAI.Generators/Generators/Steps.cs b/src/GenerativeAI.Generators/Generators/Steps.cs new file mode 100644 index 0000000..752097c --- /dev/null +++ b/src/GenerativeAI.Generators/Generators/Steps.cs @@ -0,0 +1,34 @@ +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace H.Generators; + +public static class CommonSteps +{ + public static IncrementalValuesProvider + ForAttributeWithMetadataName( + this SyntaxValueProvider source, + string fullyQualifiedMetadataName) + { + return source + .ForAttributeWithMetadataName( + fullyQualifiedMetadataName: fullyQualifiedMetadataName, + predicate: static (node, _) => + node is InterfaceDeclarationSyntax { AttributeLists.Count: > 0 }, + transform: static (context, _) => context); + } + + public static IncrementalValuesProvider<(SemanticModel SemanticModel, AttributeData AttributeData, InterfaceDeclarationSyntax InterfaceSyntax, INamedTypeSymbol InterfaceSymbol)> + SelectManyAllAttributesOfCurrentInterfaceSyntax( + this IncrementalValuesProvider source) + { + return source + .SelectMany(static (context, _) => context.Attributes + .Select(x => ( + context.SemanticModel, + AttributeData: x, + ClassSyntax: (InterfaceDeclarationSyntax)context.TargetNode, + ClassSymbol: (INamedTypeSymbol)context.TargetSymbol))); + } +} diff --git a/src/GenerativeAI.Generators/Models/InterfaceData.cs b/src/GenerativeAI.Generators/Models/InterfaceData.cs new file mode 100644 index 0000000..5ccf7c5 --- /dev/null +++ b/src/GenerativeAI.Generators/Models/InterfaceData.cs @@ -0,0 +1,8 @@ +using System.Collections.Generic; + +namespace H.Generators; + +public readonly record struct InterfaceData( + string Namespace, + string Name, + IReadOnlyCollection Methods); \ No newline at end of file diff --git a/src/GenerativeAI.Generators/Models/MethodData.cs b/src/GenerativeAI.Generators/Models/MethodData.cs new file mode 100644 index 0000000..03a5cf9 --- /dev/null +++ b/src/GenerativeAI.Generators/Models/MethodData.cs @@ -0,0 +1,10 @@ +using System.Collections.Generic; + +namespace H.Generators; + +public readonly record struct MethodData( + string Name, + string Description, + bool IsAsync, + bool IsVoid, + IReadOnlyCollection Parameters); \ No newline at end of file diff --git a/src/GenerativeAI.Generators/Models/ParameterData.cs b/src/GenerativeAI.Generators/Models/ParameterData.cs new file mode 100644 index 0000000..267d663 --- /dev/null +++ b/src/GenerativeAI.Generators/Models/ParameterData.cs @@ -0,0 +1,16 @@ +using System.Collections.Generic; + +namespace H.Generators; + +public readonly record struct ParameterData( + string Name, + string Description, + string Type, + string SchemaType, + string? Format, + IReadOnlyCollection EnumValues, + IReadOnlyCollection Properties, + IReadOnlyCollection ArrayItem, + bool IsRequired, + bool IsNullable, + string DefaultValue); \ No newline at end of file diff --git a/src/GenerativeAI.Generators/SourceGenerationHelper.cs b/src/GenerativeAI.Generators/SourceGenerationHelper.cs new file mode 100644 index 0000000..9e4689b --- /dev/null +++ b/src/GenerativeAI.Generators/SourceGenerationHelper.cs @@ -0,0 +1,276 @@ +using System.Linq; +using H.Generators.Extensions; + +namespace H.Generators; + +internal static class SourceGenerationHelper +{ + /// + /// https://swagger.io/docs/specification/data-models/data-types/ + /// + /// + /// + /// + public static string GenerateOpenApiSchema(ParameterData parameter, int depth = 0) + { + var indent = new string(' ', depth * 4); + if (parameter.ArrayItem.Count != 0) + { + return $@"new +{indent} {{ +{indent} type = ""{parameter.SchemaType}"", +{indent} description = ""{parameter.Description}"", +{indent} items = {GenerateOpenApiSchema(parameter.ArrayItem.First(), depth: depth + 1)}, +{indent} }}"; + } + if (parameter.Properties.Count != 0) + { + return $@"new +{indent} {{ +{indent} type = ""{parameter.SchemaType}"", +{indent} description = ""{parameter.Description}"", +{indent} properties = new Dictionary +{indent} {{ +{indent} {string.Join(",\n " + indent, parameter.Properties.Select(x => $@"[""{x.Name}""] = " + GenerateOpenApiSchema(x, depth: depth + 2)))} +{indent} }}, +{indent} required = new string[] {{ {string.Join(", ", parameter.Properties + .Where(static x => x.IsRequired) + .Select(static x => $"\"{x.Name}\""))} }}, +{indent} }}"; + } + + if (parameter.EnumValues.Count != 0) + { + return $@"new +{indent} {{ +{indent} type = ""{parameter.SchemaType}"", +{indent} description = ""{parameter.Description}"", +{indent} @enum = new string[] {{ {string.Join(", ", parameter.EnumValues.Select(static x => $"\"{x}\""))} }}, +{indent} }}"; + } + + return $@"new +{indent} {{ +{indent} type = ""{parameter.SchemaType}"",{(parameter.Format != null ? $@" +{indent} format = ""{parameter.Format}""," : "")} +{indent} description = ""{parameter.Description}"", +{indent} }}"; + } + + public static string GenerateClientImplementation(InterfaceData @interface) + { + var extensionsClassName = @interface.Name.Substring(startIndex: 1) + "Extensions"; + + return @$" +using System.Collections.Generic; + +#nullable enable + +namespace {@interface.Namespace} +{{ + public static class {extensionsClassName} + {{ +{@interface.Methods.Select(static method => $@" + public class {method.Name}Args + {{ + {string.Join("\n ", method.Parameters.Select(static x => $@"public {x.Type}{(x.IsNullable ? "?" : "")} {x.Name.ToPropertyName()} {{ get; set; }}{(!string.IsNullOrEmpty(x.DefaultValue) ? $" = {x.DefaultValue};" : "")}"))} + }} +").Inject()} + +{@interface.Methods.Select(method => $@" + public static (string Name, string Description, object Obj) {method.Name}AsParametersObject(this {@interface.Name} functions) + {{ + return (""{method.Name}"", ""{method.Description}"", new + {{ + type = ""object"", + properties = new Dictionary + {{ + {string.Join(",\n ", method.Parameters.Select(static parameter => $@"[""{parameter.Name}""] = " + GenerateOpenApiSchema(parameter)))} + }}, + required = new string[] {{ {string.Join(", ", method.Parameters + .Where(static x => x.IsRequired) + .Select(static x => $"\"{x.Name}\""))} }}, + }}); + }} +").Inject()} + +{@interface.Methods.Select(method => $@" + public static (string Name, string Description, Dictionary Dictionary) {method.Name}AsDictionary(this {@interface.Name} functions) + {{ + return (""{method.Name}"", ""{method.Description}"", new Dictionary + {{ + [""type""] = ""object"", + [""properties""] = new Dictionary + {{ + {string.Join(",\n ", method.Parameters.Select(static parameter => $@"[""{parameter.Name}""] = " + GenerateOpenApiSchema(parameter)))} + }}, + [""required""] = new string[] {{ {string.Join(", ", method.Parameters + .Where(static x => x.IsRequired) + .Select(static x => $"\"{x.Name}\""))} }}, + }}); + }} +").Inject()} + + public static global::System.Collections.Generic.ICollection AsFunctions(this {@interface.Name} functions) + {{ +{@interface.Methods.Select((method, i) => $@" + var function{i} = functions.{method.Name}AsDictionary();").Inject()} + + return new global::System.Collections.Generic.List + {{ +{@interface.Methods.Select((_, i) => $@" + new global::GenerativeAI.Tools.ChatCompletionFunction() + {{ + Name = function{i}.Name, + Description = function{i}.Description, + Parameters = new global::GenerativeAI.Tools.ChatCompletionFunctionParameters + {{ + AdditionalProperties = function{i}.Dictionary, + }}, + }}, +").Inject()} + }}; + }} + + public static global::System.Collections.Generic.IReadOnlyDictionary>> AsCalls(this {@interface.Name} service) + {{ + return new global::System.Collections.Generic.Dictionary>> + {{ +{@interface.Methods.Where(static x => x is { IsAsync: false, IsVoid: false }).Select(method => $@" + [""{method.Name}""] = (json, _) => + {{ + return global::System.Threading.Tasks.Task.FromResult(service.Call{method.Name}(json)); + }}, +").Inject()} +{@interface.Methods.Where(static x => x is { IsAsync: false, IsVoid: true }).Select(method => $@" + [""{method.Name}""] = (json, _) => + {{ + service.Call{method.Name}(json); + + return global::System.Threading.Tasks.Task.FromResult(string.Empty); + }}, +").Inject()} +{@interface.Methods.Where(static x => x is { IsAsync: true, IsVoid: false }).Select(method => $@" + [""{method.Name}""] = async (json, cancellationToken) => + {{ + return await service.Call{method.Name}(json, cancellationToken); + }}, +").Inject()} +{@interface.Methods.Where(static x => x is { IsAsync: true, IsVoid: true }).Select(method => $@" + [""{method.Name}""] = async (json, cancellationToken) => + {{ + await service.Call{method.Name}(json, cancellationToken); + + return string.Empty; + }}, +").Inject()} + }}; + }} + +{@interface.Methods.Select(method => $@" + public static object {method.Name}AsFunctionObject(this {@interface.Name} functions) + {{ + var (name, description, parameters) = functions.{method.Name}AsParametersObject(); + + return new + {{ + name = name, + description = description, + parameters = parameters, + }}; + }} +").Inject()} + +{@interface.Methods.Select(method => $@" + public static (string Name, string Description, global::System.Text.Json.Nodes.JsonNode Node) {method.Name}AsParametersJsonNode(this {@interface.Name} functions) + {{ + var (name, description, parameters) = functions.{method.Name}AsParametersObject(); + var node = + global::System.Text.Json.JsonSerializer.SerializeToNode(parameters) ?? + throw new global::System.InvalidOperationException(""Could not serialize parameters.""); + + return (name, description, node); + }} +").Inject()} + +{@interface.Methods.Select(method => $@" + public static {extensionsClassName}.{method.Name}Args As{method.Name}Args( + this {@interface.Name} functions, + string json) + {{ + return + global::System.Text.Json.JsonSerializer.Deserialize<{extensionsClassName}.{method.Name}Args>(json, new global::System.Text.Json.JsonSerializerOptions + {{ + PropertyNamingPolicy = global::System.Text.Json.JsonNamingPolicy.CamelCase, + }}) ?? + throw new global::System.InvalidOperationException(""Could not deserialize JSON.""); + }} +").Inject()} + +{@interface.Methods.Where(static x => x is { IsAsync: false, IsVoid: false }).Select(method => $@" + public static string Call{method.Name}(this {@interface.Name} functions, string json) + {{ + var args = functions.As{method.Name}Args(json); + var jsonResult = functions.{method.Name}({string.Join(", ", method.Parameters.Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}"))}); + + return global::System.Text.Json.JsonSerializer.Serialize(jsonResult, new global::System.Text.Json.JsonSerializerOptions + {{ + PropertyNamingPolicy = global::System.Text.Json.JsonNamingPolicy.CamelCase, + Converters = {{ new global::System.Text.Json.Serialization.JsonStringEnumConverter(global::System.Text.Json.JsonNamingPolicy.CamelCase) }}, + }}); + }} +").Inject()} + +{@interface.Methods.Where(static x => x is { IsAsync: false, IsVoid: true }).Select(method => $@" + public static void Call{method.Name}(this {@interface.Name} functions, string json) + {{ + var args = functions.As{method.Name}Args(json); + functions.{method.Name}({string.Join(", ", method.Parameters.Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}"))}); + }} +").Inject()} + +{@interface.Methods.Where(static x => x is { IsAsync: true, IsVoid: false }).Select(method => $@" + public static async global::System.Threading.Tasks.Task Call{method.Name}( + this {@interface.Name} functions, + string json, + global::System.Threading.CancellationToken cancellationToken = default) + {{ + var args = functions.As{method.Name}Args(json); + var jsonResult = await functions.{method.Name}({string.Join(", ", method.Parameters.Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}"))}, cancellationToken); + + return global::System.Text.Json.JsonSerializer.Serialize(jsonResult, new global::System.Text.Json.JsonSerializerOptions + {{ + PropertyNamingPolicy = global::System.Text.Json.JsonNamingPolicy.CamelCase, + Converters = {{ new global::System.Text.Json.Serialization.JsonStringEnumConverter(global::System.Text.Json.JsonNamingPolicy.CamelCase) }}, + }}); + }} +").Inject()} + +{@interface.Methods.Where(static x => x is { IsAsync: true, IsVoid: true }).Select(method => $@" + public static async global::System.Threading.Tasks.Task Call{method.Name}( + this {@interface.Name} functions, + string json, + global::System.Threading.CancellationToken cancellationToken = default) + {{ + var args = functions.As{method.Name}Args(json); + await functions.{method.Name}({string.Join(", ", method.Parameters.Select(static parameter => $@"args.{parameter.Name.ToPropertyName()}"))}, cancellationToken); + + return string.Empty; + }} +").Inject()} + + public static async global::System.Threading.Tasks.Task CallAsync( + this {@interface.Name} service, + string functionName, + string argumentsAsJson, + global::System.Threading.CancellationToken cancellationToken = default) + {{ + var calls = service.AsCalls(); + var func = calls[functionName]; + + return await func(argumentsAsJson, cancellationToken); + }} + }} +}}"; + } +} diff --git a/src/GenerativeAI/Client/ChatSession.cs b/src/GenerativeAI/Client/ChatSession.cs new file mode 100644 index 0000000..e19fb8d --- /dev/null +++ b/src/GenerativeAI/Client/ChatSession.cs @@ -0,0 +1,105 @@ +using GenerativeAI.Helpers; +using GenerativeAI.Models; +using GenerativeAI.Requests; +using GenerativeAI.Types; + +namespace GenerativeAI.Methods +{ + public class ChatSession + { + #region Properties + public List History { get; private set; } + public GenerativeModel Model { get; private set; } + #endregion + + #region Constructor + + public ChatSession(GenerativeModel model, StartChatParams @params) + { + this.Model = model; + if (@params.History != null) + { + this.History = @params.History.Select(s => RequestExtensions.FormatGenerateContentInput(s.Parts)).ToList(); + } + else this.History = new List(); + } + #endregion + + #region public methods + + public async Task SendMessageAsync(string message, CancellationToken cancellationToken = default) + { + var content = RequestExtensions.FormatGenerateContentInput(message); + + var contents = new List(); + contents.AddRange(this.History); + contents.Add(content); + + var request = new GenerateContentRequest() + { + Contents = contents.ToArray() + }; + + var response = await Model.GenerateContentAsync(request,cancellationToken); + + if (response.Candidates is { Length: > 0 }) + { + this.History.Add(content); + var responseContent = response.Candidates[0].Content; + responseContent.Role = Roles.Model; + + this.History.Add(responseContent); + } + else + { + var blockErrorMessage = ResponseHelper.FormatBlockErrorMessage(response); + if (!string.IsNullOrEmpty(blockErrorMessage)) + { + throw new Exception(blockErrorMessage); + } + } + + return response; + } + + public async Task SendMessage(GenerateContentRequest request) + { + var contents = new List(); + contents.AddRange(this.History); + if (request.Contents != null) + { + contents.AddRange(request.Contents); + + } + var request2 = new GenerateContentRequest() + { + Contents = contents.ToArray() + }; + var response = await Model.GenerateContentAsync(request2); + + if (response.Candidates is { Length: > 0 }) + { + if (request.Contents != null) + this.History.AddRange(request.Contents); + var responseContent = response.Candidates[0].Content; + if (responseContent != null) + { + responseContent.Role = Roles.Model; + + this.History.Add(responseContent); + } + } + else + { + var blockErrorMessage = ResponseHelper.FormatBlockErrorMessage(response); + if (!string.IsNullOrEmpty(blockErrorMessage)) + { + throw new Exception(blockErrorMessage); + } + } + + return response; + } + #endregion + } +} diff --git a/src/GenerativeAI/Common/GenerativeAiSerializeOptions.cs b/src/GenerativeAI/Common/GenerativeAiSerializeOptions.cs new file mode 100644 index 0000000..04a9142 --- /dev/null +++ b/src/GenerativeAI/Common/GenerativeAiSerializeOptions.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace GenerativeAI.Common +{ + +} diff --git a/src/GenerativeAI/Extensions/RequestExtensions.cs b/src/GenerativeAI/Extensions/RequestExtensions.cs new file mode 100644 index 0000000..f26b041 --- /dev/null +++ b/src/GenerativeAI/Extensions/RequestExtensions.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using GenerativeAI.Types; + +namespace GenerativeAI.Helpers +{ + public class RequestExtensions + { + public static Content FormatGenerateContentInput(string @params) + { + var parts = new[]{new Part(){Text = @params}}; + return new Content(parts, Roles.User); + } + + public static Content FormatGenerateContentInput( IEnumerable request) + { + var parts = request.Select(part => new Part() { Text = part }).ToArray(); + + return new Content(parts, Roles.User); + } + + public static Content FormatGenerateContentInput(IEnumerable request) + { + return new Content(request.ToArray(), Roles.User); + } + } +} diff --git a/src/GenerativeAI/GenerativeAI.csproj b/src/GenerativeAI/GenerativeAI.csproj new file mode 100644 index 0000000..be20e10 --- /dev/null +++ b/src/GenerativeAI/GenerativeAI.csproj @@ -0,0 +1,36 @@ + + + + net4.6.2;netstandard2.0;net6.0;net7.0;net8.0 + enable + enable + 10.0 + True + Google_$(AssemblyName) + Unofficial Google GenerativeAI (Gemini) SDK + Gunpal Jain + + Unofficial C# SDK for Google Generative AI + https://github.com/gunpal5/Google_GenerativeAI + README.md + https://github.com/gunpal5/Google_GenerativeAI + Gemini,Google,GenerativeAI + 0.1.0 + 0.1.0 + 0.1.0 + + + + + True + \ + + + + + + + + + + diff --git a/src/GenerativeAI/GenerativeAIFunctionsAttribute.cs b/src/GenerativeAI/GenerativeAIFunctionsAttribute.cs new file mode 100644 index 0000000..01b815c --- /dev/null +++ b/src/GenerativeAI/GenerativeAIFunctionsAttribute.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace GenerativeAI +{ + public class GenerativeAIFunctionsAttribute : Attribute + { + } +} diff --git a/src/GenerativeAI/Models/GenerativeModel.cs b/src/GenerativeAI/Models/GenerativeModel.cs new file mode 100644 index 0000000..9d8333c --- /dev/null +++ b/src/GenerativeAI/Models/GenerativeModel.cs @@ -0,0 +1,154 @@ +using GenerativeAI.Helpers; +using GenerativeAI.Methods; +using GenerativeAI.Types; +using System.Net.Http; +using System.Text.Json; +using System.Text.Json.Nodes; +using GenerativeAI.Tools; +using System.Threading; + +namespace GenerativeAI.Models +{ + public class GenerativeModel : ModelBase + { + #region Properties + public string Model { get; set; } + public GenerationConfig Config { get; set; } + public SafetySetting[] SafetySettings { get; set; } + private string ApiKey { get; set; } + public bool AutoCallFunction { get; set; } = true; + public bool AutoReplyFunction { get; set; } = true; + public List? Functions { get; set; } + + public IReadOnlyDictionary>> Calls { get; set; } + #endregion + + #region Contructors + public GenerativeModel(string apiKey, ModelParams modelParams, HttpClient? client = null, ICollection functions = null, IReadOnlyDictionary>> calls = null) + { + if (modelParams.Model != null && modelParams.Model.StartsWith("models/")) + { + this.Model = modelParams.Model.Split(new[] { "model/" }, StringSplitOptions.RemoveEmptyEntries)[1]; + } + else + { + this.Model = modelParams.Model ?? "gemini-pro"; + } + + this.Config = modelParams.GenerationConfig ?? new GenerationConfig(); + this.SafetySettings = modelParams.SafetySettings ?? new List().ToArray(); + this.ApiKey = apiKey; + this.Functions = functions.ToList(); + this.Calls = calls; + InitClient(client); + } + + private void InitClient(HttpClient? client) + { + if (client == null) + { + this.Client = new HttpClient() { Timeout = new TimeSpan(0, 10, 0) }; + } + else + Client = client; + } + + public GenerativeModel(string apiKey, string model = "gemini-pro", HttpClient? client = null, ICollection functions = null, IReadOnlyDictionary>> calls = null) + { + this.ApiKey = apiKey; + this.Model = model; + this.Config = new GenerationConfig(); + this.SafetySettings = new List().ToArray(); + this.Functions = functions.ToList(); + this.Calls = calls; + + InitClient(client); + } + #endregion + + #region public Methods + public async Task GenerateContentAsync(GenerateContentRequest request, CancellationToken cancellationToken = default) + { + request.GenerationConfig = this.Config; + request.SafetySettings = this.SafetySettings; + var res = await GenerateContent(this.ApiKey, this.Model, request); + return await CallFunction(request,res, cancellationToken); + } + + public async Task GenerateContentAsync(string message, CancellationToken cancellationToken = default) + { + var content = RequestExtensions.FormatGenerateContentInput(message); + var req = new GenerateContentRequest() + { + Contents = new[] { content }, + GenerationConfig = this.Config, + SafetySettings = this.SafetySettings + }; + + if (this.Functions != null) + { + req.Tools = new List(new[]{new GenerativeAITool() + { + FunctionDeclaration = this.Functions + }}); + } + + var res = await GenerateContent(this.ApiKey, this.Model, req); + + res = await CallFunction(req,res, cancellationToken); + return res.Text(); + } + + private async Task CallFunction(GenerateContentRequest req, EnhancedGenerateContentResponse res, CancellationToken cancellationToken = default) + { + if (AutoCallFunction && res.GetFunction() != null) + { + var function = res.GetFunction(); + var name = function.Name ?? string.Empty; + var func = Calls[name]; + var args = function.Arguments != null ? JsonSerializer.Serialize(function.Arguments,SerializerOptions) : string.Empty; + var jsonResult = await func(args, cancellationToken).ConfigureAwait(false); + + if (AutoReplyFunction) + { + var responseContent = JsonSerializer.Deserialize(jsonResult,SerializerOptions); + + var content = new Content(){Role = Roles.Function}; + content.Parts = new[] + { + new Part() + { + FunctionResponse = new ChatFunctionResponse() + { + Name = name, + Response = new FunctionResponse(){Name = name, Content = responseContent} + } + } + }; + + var contents = new List(); + if(req.Contents!=null) + contents.AddRange(req.Contents); + + contents.Add(new Content(res.Candidates[0].Content.Parts, res.Candidates[0].Content.Role)); + contents.Add(content); + res = await GenerateContentAsync(new GenerateContentRequest() { Contents = contents.ToArray()} , + cancellationToken); + } + } + + return res; + } + + public ChatSession StartChat(StartChatParams startSessionParams) + { + return new ChatSession(this, startSessionParams); + } + + public async Task CountTokens(CountTokensRequest request) + { + return await CountTokens(this.ApiKey, this.Model, request); + } + #endregion + } +} diff --git a/src/GenerativeAI/Models/ModelBase.cs b/src/GenerativeAI/Models/ModelBase.cs new file mode 100644 index 0000000..7a280eb --- /dev/null +++ b/src/GenerativeAI/Models/ModelBase.cs @@ -0,0 +1,79 @@ +using GenerativeAI.Requests; +using GenerativeAI.Types; +using System.Net.Http; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace GenerativeAI.Models +{ + public abstract class ModelBase + { + public string BaseUrl { get; set; } = "https://generativelanguage.googleapis.com"; + public string Version { get; set; } = "v1beta"; + + public HttpClient Client { get; protected set; } + public TimeSpan Timeout + { + get => Client.Timeout; + set => Client.Timeout = value; + } + + protected JsonSerializerOptions SerializerOptions => new JsonSerializerOptions() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, Converters = { new JsonStringEnumConverter() }, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }; + + protected virtual async Task GenerateContent(string apiKey, string model, GenerateContentRequest request) + { + var url = new RequestUrl(model, Tasks.GenerateContent, apiKey, false, BaseUrl, Version); + + var json = JsonSerializer.Serialize(request, SerializerOptions); + + var stringContent = new StringContent(json, Encoding.UTF8, "application/json"); + + var response = await Client.PostAsync(url, stringContent); + if (response.IsSuccessStatusCode) + { + var result = await JsonSerializer.DeserializeAsync(await response.Content.ReadAsStreamAsync(), SerializerOptions); + + if (!(result.Candidates is { Length: > 0 })) + { + var blockErrorMessage = ResponseHelper.FormatBlockErrorMessage(result); + if (!string.IsNullOrEmpty(blockErrorMessage)) + { + throw new Exception(blockErrorMessage); + } + } + + return result; + } + else + throw new Exception($"Error while requesting {url.ToString("__API_Key__")}: " + + await response.Content.ReadAsStreamAsync()); + } + + protected virtual async Task CountTokens(string apiKey, string model, + CountTokensRequest request) + { + var url = new RequestUrl(model, Tasks.CountTokens, apiKey, false, this.BaseUrl, this.Version); + + var json = JsonSerializer.Serialize(request, GoogleSerializerContext.Default.CountTokensRequest); + + var stringContent = new StringContent(json, Encoding.UTF8, "application/json"); + + var response = await Client.PostAsync(url, stringContent); + if (response.IsSuccessStatusCode) + { + var str = await response.Content.ReadAsStringAsync(); + var result = await JsonSerializer.DeserializeAsync(await response.Content.ReadAsStreamAsync(), + GoogleSerializerContext.Default.CountTokensResponse); + return result; + } + else + throw new Exception($"Error while requesting {url.ToString("__API_Key__")}: " + + await response.Content.ReadAsStreamAsync()); + } + } +} diff --git a/src/GenerativeAI/Requests/Request.cs b/src/GenerativeAI/Requests/Request.cs new file mode 100644 index 0000000..c14f55b --- /dev/null +++ b/src/GenerativeAI/Requests/Request.cs @@ -0,0 +1,49 @@ +namespace GenerativeAI.Requests +{ + public class RequestUrl + { + public string Model { get; set; } + public string Task { get; set; } + public string ApiKey { get; set; } + public bool Stream { get; set; } + + public string BaseUrl { get; set; } + public string Version { get; set; } + public RequestUrl(string model, string task, string apiKey, bool stream, string baseUrl, string version = "v1") + { + Model = model; + Task = task; + ApiKey = apiKey; + Stream = stream; + BaseUrl = baseUrl; + Version = version; + } + + public override string ToString() + { + return ToString("__API_Key__"); + } + + public string ToString(string apiKey) + { + var url = $"{BaseUrl}/{Version}/models/{this.Model}:{this.Task}?key={ApiKey}"; + if (this.Stream) + { + url += "&alt=sse"; + } + + return url; + } + + public static implicit operator string(RequestUrl d) => d.ToString(d.ApiKey); + } + + public class Tasks + { + public const string GenerateContent = "generateContent"; + public const string StreamGenerateContent = "streamGenerateContent"; + public const string CountTokens = "countTokens"; + public const string EmbedContent = "embedContent"; + public const string BatchEmbedContents = "batchEmbedContents"; + } +} diff --git a/src/GenerativeAI/Requests/ResponseHelper.cs b/src/GenerativeAI/Requests/ResponseHelper.cs new file mode 100644 index 0000000..dee2c2d --- /dev/null +++ b/src/GenerativeAI/Requests/ResponseHelper.cs @@ -0,0 +1,43 @@ +using GenerativeAI.Types; + +namespace GenerativeAI.Requests +{ + public class ResponseHelper + { + public static string FormatBlockErrorMessage(GenerateContentResponse response) + { + var message = ""; + if (response.Candidates == null || response.Candidates.Length == 0 && response.PromptFeedback!=null) + { + message += "Response was blocked"; + if (response.PromptFeedback?.BlockReason >0) + { + message += $" due to {response.PromptFeedback.BlockReason}"; + } + + if (!string.IsNullOrEmpty(response.PromptFeedback?.BlockReasonMessage)) + { + message = $" :{response.PromptFeedback?.BlockReasonMessage}"; + } + } + else if (response.Candidates?[0] != null) + { + var firstCandidate = response.Candidates[0]; + if (hadBadFinishReason(firstCandidate)) + { + message += $": {firstCandidate.FinishMessage}"; + } + } + return message; + } + + public static bool hadBadFinishReason(GenerateContentCandidate candidate) + { + if (candidate.FinishReason == FinishReason.RECITATION || candidate.FinishReason == FinishReason.SAFETY) + { + return false; + } + else return true; + } + } +} diff --git a/src/GenerativeAI/Tools/Function.cs b/src/GenerativeAI/Tools/Function.cs new file mode 100644 index 0000000..9e597b5 --- /dev/null +++ b/src/GenerativeAI/Tools/Function.cs @@ -0,0 +1,65 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using System.Text; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Threading.Tasks; + +namespace GenerativeAI.Tools +{ + public class ChatCompletionFunctionParameters + { + + private IDictionary? _additionalProperties; + + [JsonExtensionData] + public IDictionary AdditionalProperties + { + get { return _additionalProperties ??= new System.Collections.Generic.Dictionary(); } + set + { + _additionalProperties = value; + } + } + + } + + public class ChatCompletionFunction + { + /// + /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + /// + + [JsonPropertyName("name")] + + [JsonIgnore(Condition = JsonIgnoreCondition.Never)] + [Required(AllowEmptyStrings = true)] + public string Name { get; set; } = default!; + + /// + /// A description of what the function does, used by the model to choose when and how to call the function. + /// + + [JsonPropertyName("description")] + + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)] + public string? Description { get; set; } = default!; + + [JsonPropertyName("parameters")] + + [JsonIgnore(Condition = JsonIgnoreCondition.Never)] + [Required] + public ChatCompletionFunctionParameters Parameters { get; set; } = new ChatCompletionFunctionParameters(); + + private IDictionary? _additionalProperties; + + [JsonExtensionData] + public IDictionary AdditionalProperties + { + get { return _additionalProperties ??= new Dictionary(); } + set { _additionalProperties = value; } + } + } +} diff --git a/src/GenerativeAI/Tools/FunctionCall.cs b/src/GenerativeAI/Tools/FunctionCall.cs new file mode 100644 index 0000000..c9f989d --- /dev/null +++ b/src/GenerativeAI/Tools/FunctionCall.cs @@ -0,0 +1,42 @@ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using System.Text; +using System.Text.Json.Serialization; +using System.Threading.Tasks; + +namespace GenerativeAI.Tools +{ + public class ChatFunctionCall + { + /// + /// The name of the function to call. + /// + + [JsonPropertyName("name")] + + [JsonIgnore(Condition = System.Text.Json.Serialization.JsonIgnoreCondition.Never)] + [Required(AllowEmptyStrings = true)] + public string Name { get; set; } = default!; + + /// + /// The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + /// + + [JsonPropertyName("args")] + + [JsonIgnore(Condition = System.Text.Json.Serialization.JsonIgnoreCondition.Never)] + [Required(AllowEmptyStrings = true)] + public IDictionary? Arguments { get; set; } = default!; + + private IDictionary? _additionalProperties; + + [JsonExtensionData] + public IDictionary AdditionalProperties + { + get { return _additionalProperties ?? (_additionalProperties = new System.Collections.Generic.Dictionary()); } + set { _additionalProperties = value; } + } + } +} diff --git a/src/GenerativeAI/Tools/GenerativeAITools.cs b/src/GenerativeAI/Tools/GenerativeAITools.cs new file mode 100644 index 0000000..3d6e98d --- /dev/null +++ b/src/GenerativeAI/Tools/GenerativeAITools.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json.Serialization; +using System.Threading.Tasks; + +namespace GenerativeAI.Tools +{ + /// + /// + /// + public class GenerativeAITool + { + [JsonPropertyName("function_declarations")] + public List? FunctionDeclaration { get; set; } + } + + public class GenerativeAITools : List + { + + } +} diff --git a/src/GenerativeAI/Types/Content.cs b/src/GenerativeAI/Types/Content.cs new file mode 100644 index 0000000..6fc54d4 --- /dev/null +++ b/src/GenerativeAI/Types/Content.cs @@ -0,0 +1,68 @@ +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using GenerativeAI.Tools; + +namespace GenerativeAI.Types +{ + /// + /// Content type for both prompts and response candidates. + /// + public class Content + { + public Content() { } + public Content(Part[]? parts, string? role) + { + Parts = parts; + Role = role; + } + + public Part[]? Parts { get; set; } + public string? Role { get; set; } + } + + /// + /// Content that can be provided as history input to startChat(). + /// + public class InputContent + { + public string? Parts { get; set; } + public string? Role { get; set; } + } + /// + /// Content part - includes text or image part types. + /// + public class Part + { + public string? Text { get; set; } + public string? InlineData { get; set; } + + public ChatFunctionCall? FunctionCall { get; set; } + public ChatFunctionResponse? FunctionResponse { get; set; } + } + + public class ChatFunctionResponse + { + public string Name { get; set; } + public FunctionResponse Response { get; set; } + } + + public class FunctionResponse + { + public string Name { get; set; } + public JsonNode Content { get; set; } + } + /// + /// Interface for sending an image. + /// + public class GenerativeContentBlob + { + /// + /// MimeType of Image + /// + public string? MimeType { get; set; } + /// + /// Image as a base64 string. + /// + public string? Data { get; set; } + } +} diff --git a/src/GenerativeAI/Types/Enums.cs b/src/GenerativeAI/Types/Enums.cs new file mode 100644 index 0000000..1ff3b41 --- /dev/null +++ b/src/GenerativeAI/Types/Enums.cs @@ -0,0 +1,128 @@ +namespace GenerativeAI.Types +{ + /// + /// Harm categories that would cause prompts or candidates to be blocked. + /// + public enum HarmCategory + { + HARM_CATEGORY_UNSPECIFIED, + HARM_CATEGORY_HATE_SPEECH, + HARM_CATEGORY_SEXUALLY_EXPLICIT, + HARM_CATEGORY_HARASSMENT, + HARM_CATEGORY_DANGEROUS_CONTENT, + } + + /// + /// Reason that a prompt was blocked. + /// + public enum BlockReason + { + /// + /// A blocked reason was not specified. + /// + BLOCKED_REASON_UNSPECIFIED, + /// + /// Content was blocked by safety settings. + /// + SAFETY, + /// + /// Content was blocked, but the reason is uncategorized. + /// + OTHER, + } + /// + /// Threshhold above which a prompt or candidate will be blocked. + /// + public enum HarmBlockThreshold + { + /// + /// Threshold is unspecified. + /// + HARM_BLOCK_THRESHOLD_UNSPECIFIED, + /// + /// Content with NEGLIGIBLE will be allowed. + /// + BLOCK_LOW_AND_ABOVE, + /// + /// Content with NEGLIGIBLE and LOW will be allowed. + /// + BLOCK_MEDIUM_AND_ABOVE, + /// + /// Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed. + /// + BLOCK_ONLY_HIGH, + /// + /// All content will be allowed. + /// + BLOCK_NONE, + } + + /// + /// Probability that a prompt or candidate matches a harm category. + /// + public enum HarmProbability + { + /// + /// Probability is unspecified. + /// + HARM_PROBABILITY_UNSPECIFIED, + // Content has a negligible chance of being unsafe. + NEGLIGIBLE, + /// + /// Content has a low chance of being unsafe. + /// + LOW, + /// + /// Content has a medium chance of being unsafe. + /// + MEDIUM, + /// + /// Content has a high chance of being unsafe. + /// + HIGH + } + + /// + /// Reason that a candidate finished. + /// + public enum FinishReason + { + /// + /// Default value. This value is unused. + /// + FINISH_REASON_UNSPECIFIED, + /// + /// Natural stop point of the model or provided stop sequence. + /// + STOP, + /// + /// The maximum number of tokens as specified in the request was reached. + /// + MAX_TOKENS, + /// + /// The candidate content was flagged for safety reasons. + /// + SAFETY, + /// + /// The candidate content was flagged for recitation reasons. + /// + RECITATION, + /// + /// Unknown reason. + /// + OTHER, + } + + /// + /// Task type for embedding content. + /// + public enum TaskType + { + TASK_TYPE_UNSPECIFIED, + RETRIEVAL_QUERY, + RETRIEVAL_DOCUMENT, + SEMANTIC_SIMILARITY, + CLASSIFICATION, + CLUSTERING, + } +} diff --git a/src/GenerativeAI/Types/Requests.cs b/src/GenerativeAI/Types/Requests.cs new file mode 100644 index 0000000..61dcf25 --- /dev/null +++ b/src/GenerativeAI/Types/Requests.cs @@ -0,0 +1,87 @@ +using GenerativeAI.Tools; + +namespace GenerativeAI.Types +{ + /// + /// Base parameters for a number of methods. + /// + public class BaseParams + { + public SafetySetting[]? SafetySettings { get; set; } + public GenerationConfig? GenerationConfig { get; set; } + } + + /// + /// Params passed to {@link GoogleGenerativeAI.getGenerativeModel}. + /// + public class ModelParams : BaseParams + { + public string? Model { get; set; } + } + + /// + /// Request sent to `generateContent` endpoint. + /// + public class GenerateContentRequest : BaseParams + { + public Content[]? Contents { get; set; } + public List Tools { get; set; } + } + + /// + /// Params for {@link GenerativeModel.startChat}. + /// + public class StartChatParams : BaseParams + { + public InputContent[]? History { get; set; } + } + + /// + /// Params for calling {@link GenerativeModel.countTokens} + /// + public class CountTokensRequest + { + public Content[]? Contents + { + get; + set; + } + } + + /// + /// Params for calling {@link GenerativeModel.embedContent} + /// + public class EmbedContentRequest + { + public Content? Content { get; set; } + public TaskType TaskType { get; set; } + public string? Title { get; set; } + } + + public class BatchEmbedContentsRequest + { + public EmbedContentRequest[]? Requests { get; set; } + } + + /// + /// Config options for content-related requests + /// + public class GenerationConfig + { + public int? CandidateCount { get; set; } + public string[]? StopSequences { get; set; } + public int? MaxOutputTokens { get; set; } + public double? Temperature { get; set; } + public double? TopP { get; set; } + public double? TopK { get; set; } + } + + /// + /// Safety setting that can be sent as part of request parameters. + /// + public class SafetySetting + { + public HarmCategory Category { get; set; } + public HarmBlockThreshold Threshold { get; set; } + } +} diff --git a/src/GenerativeAI/Types/Responses.cs b/src/GenerativeAI/Types/Responses.cs new file mode 100644 index 0000000..a6c13d5 --- /dev/null +++ b/src/GenerativeAI/Types/Responses.cs @@ -0,0 +1,122 @@ +using GenerativeAI.Tools; + +namespace GenerativeAI.Types +{ + /// + /// Result object returned from generateContent() call. + /// + public class GenerateContentResult + { + public EnhancedGenerateContentResponse? Response { get; set; } + } + + /// + /// Response from calling {@link GenerativeModel.countTokens}. + /// + public class CountTokensResponse + { + public int TotalTokens { get; set; } + } + + /// + /// Response from calling {@link GenerativeModel.batchEmbedContents}. + /// + public class BatchEmbedContentsResponse + { + public ContentEmbedding[] Embeddings { get; set; } + } + + /// + /// Response from calling {@link GenerativeModel.embedContent}. + /// + public class EmbedContentResponse + { + public ContentEmbedding Embedding { get; set; } + } + + /// + /// A single content embedding. + /// + public class ContentEmbedding + { + public int[] Values { get; set; } + } + + /// + /// Response object wrapped with helper methods. + /// + public class EnhancedGenerateContentResponse: GenerateContentResponse + { + public virtual string? Text() + { + return this.Candidates?[0].Content?.Parts?[0].Text; + } + + internal ChatFunctionCall? GetFunction() + { + return Candidates?[0].Content?.Parts?[0].FunctionCall; + } + } + + /// + /// Individual response from {@link GenerativeModel.generateContent} and + /// {@link GenerativeModel.generateContentStream}. + /// `generateContentStream()` will return one in each chunk until + /// the stream is done. + /// + public class GenerateContentResponse + { + public GenerateContentCandidate[]? Candidates { get; set; } + public PromptFeedback? PromptFeedback { get; set; } + } + /// + /// A candidate returned as part of a {@link GenerateContentResponse}. + /// + public class GenerateContentCandidate + { + public int? Index { get; set; } + public Content? Content { get; set; } + public FinishReason FinishReason { get; set; } + public string FinishMessage { get; set; } + public SafetyRating[]? SafetyRatings { get; set; } + public CitationMetadata? CitationMetadata { get; set; } + } + /// + /// If the prompt was blocked, this will be populated with `blockReason` and + /// the relevant `safetyRatings`. + /// + public class PromptFeedback + { + public BlockReason BlockReason { get; set; } + public SafetyRating[] SafetyRatings { get; set; } + public string? BlockReasonMessage { get; set; } + } + + /// + /// A safety rating associated with a {@link GenerateContentCandidate} + /// + public class SafetyRating + { + public HarmCategory Category { get; set; } + public HarmProbability Probability { get; set; } + } + + /// + /// Citation metadata that may be found on a {@link GenerateContentCandidate}. + /// + public class CitationMetadata + { + public CitationSource[]? CitationSources { get; set; } + } + + /// + /// A single citation source. + /// + public class CitationSource + { + public int? StartIndex { get; set; } + public int? EndIndex { get; set;} + public string? Uri { get; set; } + public string? License { get; set; } + } +} diff --git a/src/GenerativeAI/Types/Roles.cs b/src/GenerativeAI/Types/Roles.cs new file mode 100644 index 0000000..20a5719 --- /dev/null +++ b/src/GenerativeAI/Types/Roles.cs @@ -0,0 +1,9 @@ +namespace GenerativeAI.Types +{ + public class Roles + { + public const string User = "user"; + public const string Model = "model"; + public const string Function = "function"; + } +} diff --git a/src/GenerativeAI/Types/SerializerContext.cs b/src/GenerativeAI/Types/SerializerContext.cs new file mode 100644 index 0000000..515f44d --- /dev/null +++ b/src/GenerativeAI/Types/SerializerContext.cs @@ -0,0 +1,21 @@ +using System.Text.Json.Serialization; + +namespace GenerativeAI.Types +{ + [JsonSerializable(typeof(BatchEmbedContentsResponse))] + [JsonSerializable(typeof(EnhancedGenerateContentResponse))] + [JsonSerializable(typeof(GenerateContentResult))] + [JsonSerializable(typeof(GenerateContentRequest))] + [JsonSerializable(typeof(StartChatParams))] + [JsonSerializable(typeof(GenerateContentResponse))] + [JsonSerializable(typeof(CountTokensRequest))] + [JsonSerializable(typeof(CountTokensResponse))] + [JsonSourceGenerationOptions( + PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull, + UseStringEnumConverter = true + )] + internal partial class GoogleSerializerContext : JsonSerializerContext + { + } +} diff --git a/tests/GenerativeAI.IntegrationTests/GenerativeAI.IntegrationTests.csproj b/tests/GenerativeAI.IntegrationTests/GenerativeAI.IntegrationTests.csproj new file mode 100644 index 0000000..0123b33 --- /dev/null +++ b/tests/GenerativeAI.IntegrationTests/GenerativeAI.IntegrationTests.csproj @@ -0,0 +1,30 @@ + + + + net7.0 + enable + enable + + false + true + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + diff --git a/tests/GenerativeAI.IntegrationTests/GlobalUsings.cs b/tests/GenerativeAI.IntegrationTests/GlobalUsings.cs new file mode 100644 index 0000000..8c927eb --- /dev/null +++ b/tests/GenerativeAI.IntegrationTests/GlobalUsings.cs @@ -0,0 +1 @@ +global using Xunit; \ No newline at end of file diff --git a/tests/GenerativeAI.IntegrationTests/WeatherService.cs b/tests/GenerativeAI.IntegrationTests/WeatherService.cs new file mode 100644 index 0000000..c166ef5 --- /dev/null +++ b/tests/GenerativeAI.IntegrationTests/WeatherService.cs @@ -0,0 +1,62 @@ +using DescriptionAttribute = System.ComponentModel.DescriptionAttribute; + +namespace GenerativeAI.IntegrationTests +{ + public enum Unit + { + Celsius, + Fahrenheit, + Imperial + } + + public class Weather + { + public string Location { get; set; } = string.Empty; + public double Temperature { get; set; } + public Unit Unit { get; set; } + public string Description { get; set; } = string.Empty; + } + + [GenerativeAIFunctions] + public interface IWeatherFunctions + { + [Description("Get the current weather in a given location")] + public Weather GetCurrentWeather( + [Description("The city and state, e.g. San Francisco, CA")] + string location, + Unit unit = Unit.Celsius); + + [Description("Get the current weather in a given location")] + public Task GetCurrentWeatherAsync( + [Description("The city and state, e.g. San Francisco, CA")] + string location, + Unit unit = Unit.Celsius, + CancellationToken cancellationToken = default); + } + + public class WeatherService : IWeatherFunctions + { + public Weather GetCurrentWeather(string location, Unit unit = Unit.Celsius) + { + return new Weather + { + Location = location, + Temperature = 30.0, + Unit = unit, + Description = "Sunny", + }; + } + + public Task GetCurrentWeatherAsync(string location, Unit unit = Unit.Celsius, + CancellationToken cancellationToken = default) + { + return Task.FromResult(new Weather + { + Location = location, + Temperature = 22.0, + Unit = unit, + Description = "Sunny", + }); + } + } +} \ No newline at end of file diff --git a/tests/GenerativeAI.IntegrationTests/WeatherServiceTests.cs b/tests/GenerativeAI.IntegrationTests/WeatherServiceTests.cs new file mode 100644 index 0000000..a25bcbe --- /dev/null +++ b/tests/GenerativeAI.IntegrationTests/WeatherServiceTests.cs @@ -0,0 +1,29 @@ +using GenerativeAI.Models; +using Xunit.Abstractions; + +namespace GenerativeAI.IntegrationTests +{ + public class WeatherServiceTests + { + private ITestOutputHelper Console; + public WeatherServiceTests(ITestOutputHelper helper) + { + this.Console = helper; + } + [Fact] + public async Task ShouldInvokeWetherService() + { + WeatherService service = new WeatherService(); + var functions = service.AsFunctions(); + var calls = service.AsCalls(); + + var apiKey = Environment.GetEnvironmentVariable("Gemini_API_Key", EnvironmentVariableTarget.User); + + var model = new GenerativeModel(apiKey, functions:functions,calls:calls); + + var result = await model.GenerateContentAsync("What is the weather in san fransisco today?"); + + Console.WriteLine(result); + } + } +} \ No newline at end of file diff --git a/tests/GenerativeAI.Tests/GenerativeAI.Tests.csproj b/tests/GenerativeAI.Tests/GenerativeAI.Tests.csproj new file mode 100644 index 0000000..46c41db --- /dev/null +++ b/tests/GenerativeAI.Tests/GenerativeAI.Tests.csproj @@ -0,0 +1,34 @@ + + + + net8.0 + enable + enable + + false + true + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + diff --git a/tests/GenerativeAI.Tests/GlobalUsings.cs b/tests/GenerativeAI.Tests/GlobalUsings.cs new file mode 100644 index 0000000..8c927eb --- /dev/null +++ b/tests/GenerativeAI.Tests/GlobalUsings.cs @@ -0,0 +1 @@ +global using Xunit; \ No newline at end of file diff --git a/tests/GenerativeAI.Tests/Model/ChatSession_Tests.cs b/tests/GenerativeAI.Tests/Model/ChatSession_Tests.cs new file mode 100644 index 0000000..b9f3f5b --- /dev/null +++ b/tests/GenerativeAI.Tests/Model/ChatSession_Tests.cs @@ -0,0 +1,32 @@ +using GenerativeAI.Models; +using GenerativeAI.Types; +using Xunit.Abstractions; + +namespace GenerativeAI.Tests.Model +{ + public class ChatSession_Tests + { + ITestOutputHelper Console; + public ChatSession_Tests(ITestOutputHelper helper) + { + this.Console = helper; + } + + [Fact] + public async Task ChatSession_Run() + { + var apiKey = Environment.GetEnvironmentVariable("Gemini_API_Key", EnvironmentVariableTarget.User); + + var model = new GenerativeModel(apiKey); + + var chat = model.StartChat(new StartChatParams()); + var result = await chat.SendMessage("Write a poem"); + Console.WriteLine("Initial Poem\r\n"); + Console.WriteLine(result.Text()); + + var result2 = await chat.SendMessage("Make it longer"); + Console.WriteLine("\r\nLong Poem\r\n"); + Console.WriteLine(result2.Text()); + } + } +} diff --git a/tests/GenerativeAI.Tests/Model/GeminiPro_Tests.cs b/tests/GenerativeAI.Tests/Model/GeminiPro_Tests.cs new file mode 100644 index 0000000..2748151 --- /dev/null +++ b/tests/GenerativeAI.Tests/Model/GeminiPro_Tests.cs @@ -0,0 +1,48 @@ +using GenerativeAI.Helpers; +using GenerativeAI.Models; +using GenerativeAI.Types; +using Shouldly; +using Xunit.Abstractions; + +namespace GenerativeAI.Tests.Model +{ + public class GeminiPro_Tests + { + private ITestOutputHelper Console; + public GeminiPro_Tests(ITestOutputHelper helper) + { + this.Console = helper; + } + [Fact] + public async Task ShouldGenerateResult() + { + var apiKey = Environment.GetEnvironmentVariable("Gemini_API_Key", EnvironmentVariableTarget.User); + + var model = new GenerativeModel(apiKey); + + var res = await model.GenerateContentAsync("How are you doing?"); + + res.ShouldNotBeNullOrEmpty(); + + Console.WriteLine(res); + } + + [Fact] + public async Task ShouldCountTokens() + { + var apiKey = Environment.GetEnvironmentVariable("Gemini_API_Key", EnvironmentVariableTarget.User); + + var text = + "In the realm of dreams where magic thrives,\nWhere whispers of fantasy come alive,\nA tale unfolds, a journey to behold,\nWhere dreams take flight, like stories untold.\n\nAmidst the stars, a shimmering light,\nA celestial vision, enchanting the night,\nA dreamer awakens, with eyes wide open,\nEmbracing the wonders, that lie unspoken.\n\nThrough landscapes vast, where colors dance,\nGuided by hope, a heart's keen glance,\nWith every step, the dreamer learns,\nThat courage and kindness the spirit earns.\n\nIn shadows deep, where secrets reside,\nMysteries unravel, side by side,\nEach challenge faced, a lesson to embrace,\nThe dreamer's resolve, forever in grace.\n\nThrough trials and triumphs, the path unwinds,\nA tapestry of dreams, where destiny binds,\nFor in the realms of dreams, where magic dwells,\nThe power of belief forever excels.\n\nSo let us wander, with hearts set free,\nIn this realm of dreams, where possibilities decree,\nFor in these enchanted lands, we find,\nThe magic that lies within our own mind.\r\n"; + var content = RequestExtensions.FormatGenerateContentInput(text); + var model = new GenerativeModel(apiKey); + + + var res = await model.CountTokens(new CountTokensRequest(){Contents = new[]{content}}); + res.TotalTokens.ShouldBeGreaterThan(0); + + + Console.WriteLine($"Tokens count = {res.TotalTokens}"); + } + } +}