Skip to content

Commit

Permalink
add: test cases for initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Gunpal Jain committed Feb 16, 2025
1 parent c5fcc02 commit 9eea7e3
Show file tree
Hide file tree
Showing 15 changed files with 483 additions and 31 deletions.
16 changes: 11 additions & 5 deletions src/GenerativeAI.Microsoft/GenerativeAIChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,31 @@

namespace GenerativeAI.Microsoft;


/// <inheritdoc/>
public class GenerativeAIChatClient : IChatClient
{
public GenerativeModel model { get; }

/// <inheritdoc/>
public GenerativeAIChatClient(string apiKey,string modelName = GoogleAIModels.DefaultGeminiModel)
{
model = new GenerativeModel(apiKey, modelName);
}

/// <inheritdoc/>
public GenerativeAIChatClient(IPlatformAdapter adapter, string modelName = GoogleAIModels.DefaultGeminiModel)
{
model = new GenerativeModel(adapter, modelName);
}


/// <inheritdoc/>
public void Dispose()
{

}

/// <inheritdoc/>
public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions? options = null,
CancellationToken cancellationToken = default)
{
Expand All @@ -32,7 +38,7 @@ public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages,
var response = await model.GenerateContentAsync(request, cancellationToken);
return response.ToChatCompletion() ?? throw new Exception("Failed to generate content");
}

/// <inheritdoc/>
public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(IList<ChatMessage> chatMessages,
ChatOptions? options = null,
CancellationToken cancellationToken = new CancellationToken())
Expand All @@ -45,15 +51,15 @@ public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAs
yield return response.ToStreamingChatCompletionUpdate();
}
}

/// <inheritdoc/>
public object? GetService(Type serviceType, object? serviceKey = null)
{
if (serviceKey == null && (serviceType is GenerativeAIChatClient))
if (serviceKey == null && (bool)serviceType?.IsInstanceOfType(this))
{
return this;
}
else return null;
}

/// <inheritdoc/>
public ChatClientMetadata Metadata { get; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public async IAsyncEnumerable<GenerateContentResponse> StreamContentAsync(
var request = new GenerateContentRequest();

request.AddContent(new Content() { Role = Roles.User });
var uri = new Uri(filePath);

await AppendFile(filePath, request, cancellationToken);


Expand Down
15 changes: 9 additions & 6 deletions src/GenerativeAI/Constants/DefaultSerializerOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@ namespace GenerativeAI;

internal class DefaultSerializerOptions
{
internal static readonly JsonSerializerOptions Options = new JsonSerializerOptions
internal static JsonSerializerOptions Options
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
PropertyNameCaseInsensitive = true,
Converters = { new JsonStringEnumConverter() },
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
};
get => new JsonSerializerOptions
{
PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
PropertyNameCaseInsensitive = true,
Converters = { new JsonStringEnumConverter() },
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
};
}
}
19 changes: 17 additions & 2 deletions src/GenerativeAI/Extensions/GenerateContentResponseExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using GenerativeAI.Core;
using System.Text;
using GenerativeAI.Core;
using GenerativeAI.Types;

namespace GenerativeAI;
Expand All @@ -15,7 +16,21 @@ public static class GenerateContentResponseExtensions
/// <returns>The text if found; otherwise null.</returns>
public static string? Text(this GenerateContentResponse response)
{
return response?.Candidates?[0].Content?.Parts?[0].Text;

StringBuilder sb = new StringBuilder();
foreach (var candidate in response?.Candidates)
{
if(candidate.Content==null)
continue;
foreach (var p in candidate.Content?.Parts)
{
sb.AppendLine(p.Text);
}
}
var text = sb.ToString();
if (string.IsNullOrEmpty(text))
return null;
else return text;
}

/// <summary>
Expand Down
6 changes: 5 additions & 1 deletion src/GenerativeAI/Platforms/GoogleAICredentials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class GoogleAICredentials : ICredentials
/// The API Key provides an easy way to access public resources or perform
/// authorized operations without requiring OAuth2 tokens.
/// </summary>
public string ApiKey { get; }
public string ApiKey { get; set; }

/// <summary>
/// Gets the Access Token for authenticating with Google AI APIs.
Expand All @@ -33,6 +33,10 @@ public GoogleAICredentials(string apiKey,string? accessToken = null, DateTime? e
this.AuthToken = new AuthTokens(accessToken, expiryTime:expiry);
}

public GoogleAICredentials()
{

}
/// <summary>
/// Validates the API credentials for the GoogleAICredentials instance.
/// Ensures that either an API Key or an Access Token is provided.
Expand Down
4 changes: 3 additions & 1 deletion src/GenerativeAI/Platforms/GoogleAIPlatformAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ public async Task AuthorizeAsync(CancellationToken cancellationToken = default)
public string GetBaseUrl(bool appendVesion = true)
{
if (appendVesion)
return $"{BaseUrl}/{ApiVersion}";
return $"{BaseUrl}/{GetApiVersion()}";
return BaseUrl;
}

Expand Down Expand Up @@ -245,6 +245,8 @@ public string CreateUrlForTunedModel(string modelId, string task)
/// <return>The API version string.</return>
public string GetApiVersion()
{
if(string.IsNullOrEmpty(ApiVersion))
ApiVersion = ApiVersions.v1Beta;
return ApiVersion;
}

Expand Down
13 changes: 7 additions & 6 deletions src/GenerativeAI/Platforms/VertextPlatformAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,16 @@ public VertextPlatformAdapter(string? projectId = null, string? region = null, b
if (string.IsNullOrEmpty(apiKey))
throw new Exception("API Key is required for Vertex AI Express.");
}
else if (string.IsNullOrEmpty(accessToken))

if (authenticator == null)
{
if (authenticator == null)
if(string.IsNullOrEmpty(accessToken))
this.Authenticator = new GoogleCloudAdcAuthenticator(credentialsFile, logger);
else this.Authenticator = authenticator;


}
else

else this.Authenticator = authenticator;

if (!string.IsNullOrEmpty(accessToken))
{
this.Credentials = new GoogleAICredentials(apiKey, accessToken);
}
Expand Down
1 change: 1 addition & 0 deletions tests/GenerativeAI.Tests/GenerativeAI.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.11.1" />
<PackageReference Include="Moq" Version="4.20.72" />
<PackageReference Include="Shouldly" Version="4.2.1" />
<PackageReference Include="xunit" Version="2.7.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.7">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ public async Task ShouldProcessAudioWithFilePath()
result.ShouldNotBeNull();
var text = result.Text();
text.ShouldNotBeNull();
text.ShouldContain("theological", Case.Insensitive);
// if(!text.Contains("theological",StringComparison.InvariantCultureIgnoreCase) && !text.Contains("Friedrich",StringComparison.InvariantCultureIgnoreCase))
// text.ShouldContain("theological", Case.Insensitive);
Console.WriteLine(result.Text());
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
using System;
using System.Net.Http;
using GenerativeAI;
using GenerativeAI.Core;
using Microsoft.Extensions.Logging;
using Moq;
using Shouldly;
using Xunit;

public class GoogleAIPlatformAdapterTests
{
[Fact]
public void Constructor_WithNullApiKey_ShouldThrowException()
{
// Arrange / Act / Assert
var ex = Should.Throw<Exception>(() =>
{
// Passing null for the API key (and environment variable checks are skipped here)
_ = new GoogleAIPlatformAdapter(null!);
});

ex.Message.ShouldContain("API Key is required");
}

[Fact]
public void Constructor_WithValidApiKey_ShouldInitializeCredentialsProperly()
{
// Arrange
const string testApiKey = "TEST_API_KEY";

// Act
var adapter = new GoogleAIPlatformAdapter(testApiKey);

// Assert
adapter.Credentials.ShouldNotBeNull();
adapter.Credentials.ApiKey.ShouldBe(testApiKey);
adapter.Credentials.AuthToken.ShouldBeNull("No access token was passed, so AuthToken should be null.");
}

[Fact]
public void Constructor_WithApiKeyAndAccessToken_ShouldPopulateCredentialsAndToken()
{
// Arrange
const string testApiKey = "TEST_API_KEY";
const string testAccessToken = "TEST_ACCESS_TOKEN";

// Act
var adapter = new GoogleAIPlatformAdapter(testApiKey, accessToken: testAccessToken);

// Assert
adapter.Credentials.ShouldNotBeNull();
adapter.Credentials.ApiKey.ShouldBe(testApiKey);
adapter.Credentials.AuthToken.ShouldNotBeNull();
adapter.Credentials.AuthToken!.AccessToken.ShouldBe(testAccessToken);
}

[Fact]
public void Constructor_WithCustomApiVersion_ShouldSetApiVersionProperty()
{
// Arrange
const string customVersion = "v2Test";

// Act
var adapter = new GoogleAIPlatformAdapter("TEST_API_KEY", apiVersion: customVersion);

// Assert
adapter.ApiVersion.ShouldBe(customVersion);
}

[Fact]
public void Constructor_DefaultBaseUrl_ShouldBeSetToGoogleGenerativeAI()
{
// Arrange
const string testApiKey = "TEST_API_KEY";

// Act
var adapter = new GoogleAIPlatformAdapter(testApiKey);

// Assert
adapter.BaseUrl.ShouldNotBeNullOrEmpty();
// Use whichever default value you expect it to have:
// For example, "https://generativelanguage.googleapis.com"
// adapter.BaseUrl.ShouldBe("https://generativelanguage.googleapis.com");
}

[Fact]
public void Constructor_ValidateAccessToken_ShouldBeTrueByDefault()
{
// Arrange / Act
var adapter = new GoogleAIPlatformAdapter("TEST_API_KEY");

// Assert
// Using reflection or a test accessor to ensure the internal property
// is set to true (the code snippet shows "bool ValidateAccessToken = true;")
adapter.ShouldSatisfyAllConditions(
() => adapter.ShouldNotBeNull(),
() => adapter.GetType().GetProperty("ValidateAccessToken",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)
.ShouldNotBeNull(),
() => adapter.GetType().GetProperty("ValidateAccessToken",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance)!
.GetValue(adapter)
.ShouldBeOfType<bool>()
.ShouldBeTrue()
);
}

[Fact]
public void Constructor_WithLogger_ShouldInitializeLogger()
{
// Arrange
var loggerMock = new Mock<ILogger>();

// Act
var adapter = new GoogleAIPlatformAdapter("TEST_API_KEY", logger: loggerMock.Object);

// Assert
// Using reflection because Logger is not public:
var loggerProperty = adapter.GetType().GetProperty("Logger",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
loggerProperty.ShouldNotBeNull("Logger property should exist.");

var actualLogger = loggerProperty!.GetValue(adapter) as ILogger;
actualLogger.ShouldNotBeNull();
actualLogger.ShouldBe(loggerMock.Object);
}

[Fact]
public void Constructor_WithAuthenticator_ShouldSetAuthenticator()
{
// Arrange
var mockAuthenticator = new Mock<IGoogleAuthenticator>();

// Act
var adapter = new GoogleAIPlatformAdapter("TEST_API_KEY", authenticator: mockAuthenticator.Object);

// Assert
var authenticatorProperty = adapter.GetType().GetProperty("Authenticator",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);

authenticatorProperty.ShouldNotBeNull("Authenticator property should exist.");
var actualAuthenticator = authenticatorProperty!.GetValue(adapter) as IGoogleAuthenticator;
actualAuthenticator.ShouldBe(mockAuthenticator.Object);
}

[Fact]
public void Constructor_WithCustomValidateAccessToken_ShouldSetProperly()
{
// Arrange
const bool testFlag = false;

// Act
var adapter = new GoogleAIPlatformAdapter("TEST_API_KEY", validateAccessToken: testFlag);

// Assert
var validateAccessTokenProperty = adapter.GetType().GetProperty("ValidateAccessToken",
System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);

validateAccessTokenProperty.ShouldNotBeNull("ValidateAccessToken property should exist.");
var actualValue = (bool)validateAccessTokenProperty!.GetValue(adapter)!;
actualValue.ShouldBeFalse();
}
}
19 changes: 12 additions & 7 deletions tests/GenerativeAI.Tests/Platforms/GoogleAI_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@ public GoogleAI_Tests(ITestOutputHelper helper) : base(helper)
public async Task ShouldThrowException_WhenProjectIdsAreInvalid()
{
var apiKey = Environment.GetEnvironmentVariable("Gemini_Api_Key",EnvironmentVariableTarget.User);
var googleAi = new GoogleAi();
var model = googleAi.CreateGenerativeModel(GoogleAIModels.Gemini15Flash);
var response = await model.GenerateContentAsync("write a poem about the sun");

response.ShouldNotBeNull();
var text = response.Text();
text.ShouldNotBeNullOrWhiteSpace();
Console.WriteLine(text);
Should.Throw<Exception>(() =>
{
var googleAi = new GoogleAi();
});

// var model = googleAi.CreateGenerativeModel(GoogleAIModels.Gemini15Flash);
// var response = await model.GenerateContentAsync("write a poem about the sun");
//
// response.ShouldNotBeNull();
// var text = response.Text();
// text.ShouldNotBeNullOrWhiteSpace();
// Console.WriteLine(text);
}

// [Fact]
Expand Down
Loading

0 comments on commit 9eea7e3

Please sign in to comment.