Skip to content

Commit

Permalink
.Net: fix: forwards cancellation token to outbound connections to fre…
Browse files Browse the repository at this point in the history
…e up resources upon cancellation (#9738)

Signed-off-by: Vincent Biret <[email protected]>

---------

Signed-off-by: Vincent Biret <[email protected]>
  • Loading branch information
baywet authored Nov 19, 2024
1 parent 5197f0b commit 7d5b50c
Show file tree
Hide file tree
Showing 19 changed files with 51 additions and 38 deletions.
2 changes: 1 addition & 1 deletion dotnet/src/Connectors/Connectors.Google/Core/ClientBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ protected async Task<string> SendRequestAndGetStringBodyAsync(
{
using var response = await this.HttpClient.SendWithSuccessCheckAsync(httpRequestMessage, cancellationToken)
.ConfigureAwait(false);
var body = await response.Content.ReadAsStringWithExceptionMappingAsync()
var body = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken)
.ConfigureAwait(false);
return body;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ public async IAsyncEnumerable<StreamingChatMessageContent> StreamGenerateChatMes
{
using var httpRequestMessage = await this.CreateHttpRequestAsync(state.GeminiRequest, this._chatStreamingEndpoint).ConfigureAwait(false);
httpResponseMessage = await this.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);
responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false);
responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ internal async Task<string> SendRequestAndGetStringBodyAsync(
using var response = await this._httpClient.SendWithSuccessCheckAsync(httpRequestMessage, cancellationToken)
.ConfigureAwait(false);

var body = await response.Content.ReadAsStringWithExceptionMappingAsync()
var body = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken)
.ConfigureAwait(false);

return body;
Expand Down Expand Up @@ -185,7 +185,7 @@ public async IAsyncEnumerable<StreamingTextContent> StreamGenerateTextAsync(
{
using var httpRequestMessage = this.CreatePost(request, endpoint, this.ApiKey);
httpResponseMessage = await this.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);
responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false);
responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> StreamCompleteChatM
{
using var httpRequestMessage = this._clientCore.CreatePost(request, endpoint, this._clientCore.ApiKey);
httpResponseMessage = await this._clientCore.SendRequestAndGetResponseImmediatelyAfterHeadersReadAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);
responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false);
responseStream = await httpResponseMessage.Content.ReadAsStreamAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ public async Task<ChromaQueryResultModel> QueryEmbeddingsAsync(string collection
{
response = await this._httpClient.SendWithSuccessCheckAsync(request, cancellationToken).ConfigureAwait(false);

responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);
}
catch (HttpOperationException e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ private string GetIndexOperationsApiBasePath()

using HttpResponseMessage response = await this._httpClient.SendWithSuccessCheckAsync(request, cancellationToken).ConfigureAwait(false);

string responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
string responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);

return (response, responseContent);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ private static Uri SanitizeEndpoint(string endpoint, int? port = null)

HttpResponseMessage response = await this._httpClient.SendWithSuccessCheckAsync(request, cancellationToken).ConfigureAwait(false);

string responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
string responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);

return (response, responseContent);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ private static string ToWeaviateFriendlyClassName(string collectionName)
{
HttpResponseMessage response = await this._httpClient.SendWithSuccessCheckAsync(request, cancel).ConfigureAwait(false);

string? responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
string? responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync(cancel).ConfigureAwait(false);

this._logger.LogDebug("Weaviate responded with {StatusCode}", response.StatusCode);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public async IAsyncEnumerable<string> ListCollectionNamesAsync([EnumeratorCancel
using var request = new WeaviateGetCollectionsRequest().Build();

var response = await this._httpClient.SendWithSuccessCheckAsync(request, cancellationToken).ConfigureAwait(false);
var responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
var responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);
var collectionResponse = JsonSerializer.Deserialize<WeaviateGetCollectionsResponse>(responseContent);

if (collectionResponse?.Collections is not null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ private Task<HttpResponseMessage> ExecuteRequestAsync(HttpRequestMessage request
{
var response = await this.ExecuteRequestAsync(request, cancellationToken).ConfigureAwait(false);

var responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
var responseContent = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);

var responseModel = JsonSerializer.Deserialize<TResponse>(responseContent, s_jsonSerializerOptions);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ private async IAsyncEnumerable<StreamingChatMessageContent> StreamChatMessageCon
var endpoint = this.GetEndpoint(executionSettings, path: "chat/completions");
using var httpRequestMessage = this.CreatePost(chatRequest, endpoint, this._apiKey, stream: true);
using var response = await this.SendStreamingRequestAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);
using var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false);
using var responseStream = await response.Content.ReadAsStreamAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false);
await foreach (var streamingChatContent in this.ProcessChatResponseStreamAsync(responseStream, modelId, cancellationToken).ConfigureAwait(false))
{
yield return streamingChatContent;
Expand Down Expand Up @@ -787,7 +787,7 @@ private async Task<T> SendRequestAsync<T>(HttpRequestMessage httpRequestMessage,
{
using var response = await this._httpClient.SendWithSuccessCheckAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);

var body = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
var body = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);

return DeserializeResponse<T>(body);
}
Expand Down Expand Up @@ -934,7 +934,7 @@ private void AddResponseMessage(ChatHistory chat, MistralToolCall toolCall, stri
// Add the tool response message to the chat history
var message = new ChatMessageContent(AuthorRole.Tool, result, metadata: new Dictionary<string, object?> { { nameof(MistralToolCall.Function), toolCall.Function } });

// Add an item of type FunctionResultContent to the ChatMessageContent.Items collection in addition to the function result stored as a string in the ChatMessageContent.Content property.
// Add an item of type FunctionResultContent to the ChatMessageContent.Items collection in addition to the function result stored as a string in the ChatMessageContent.Content property.
// This will enable migration to the new function calling model and facilitate the deprecation of the current one in the future.
if (toolCall.Function is not null)
{
Expand Down Expand Up @@ -989,16 +989,16 @@ private void AddResponseMessage(ChatHistory chat, MistralToolCall toolCall, stri
return stringResult;
}

// This is an optimization to use ChatMessageContent content directly
// without unnecessary serialization of the whole message content class.
// This is an optimization to use ChatMessageContent content directly
// without unnecessary serialization of the whole message content class.
if (functionResult is ChatMessageContent chatMessageContent)
{
return chatMessageContent.ToString();
}

// For polymorphic serialization of unknown in advance child classes of the KernelContent class,
// a corresponding JsonTypeInfoResolver should be provided via the JsonSerializerOptions.TypeInfoResolver property.
// For more details about the polymorphic serialization, see the article at:
// For polymorphic serialization of unknown in advance child classes of the KernelContent class,
// a corresponding JsonTypeInfoResolver should be provided via the JsonSerializerOptions.TypeInfoResolver property.
// For more details about the polymorphic serialization, see the article at:
// https://learn.microsoft.com/en-us/dotnet/standard/serialization/system-text-json/polymorphism?pivots=dotnet-8-0
return JsonSerializer.Serialize(functionResult, toolCallBehavior?.ToolCallResultSerializerOptions);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ private async Task<TModel> ExecuteGetRequestAsync<TModel>(string url, Cancellati
this.AddRequestHeaders(request);
using var response = await this._httpClient.SendWithSuccessCheckAsync(request, cancellationToken).ConfigureAwait(false);

var body = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
var body = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);

var model = JsonSerializer.Deserialize<TModel>(body);

Expand All @@ -230,7 +230,7 @@ private async Task<TModel> ExecuteGetRequestAsync<TModel>(string url, Cancellati
{
return
(new HttpResponseStream(
await response.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false),
await response.Content.ReadAsStreamAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false),
response),
response.Content.Headers.ContentType?.MediaType);
}
Expand All @@ -247,7 +247,7 @@ private async Task<TModel> ExecutePostRequestAsync<TModel>(string url, HttpConte
this.AddRequestHeaders(request);
using var response = await this._httpClient.SendWithSuccessCheckAsync(request, cancellationToken).ConfigureAwait(false);

var body = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
var body = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);

var model = JsonSerializer.Deserialize<TModel>(body);

Expand Down
4 changes: 2 additions & 2 deletions dotnet/src/Functions/Functions.OpenApi/DocumentLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ internal static async Task<string> LoadDocumentFromUriAsync(
CancellationToken cancellationToken)
{
using var response = await LoadDocumentResponseFromUriAsync(uri, logger, httpClient, authCallback, userAgent, cancellationToken).ConfigureAwait(false);
return await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
return await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);
}

internal static async Task<Stream> LoadDocumentFromUriAsStreamAsync(
Expand All @@ -35,7 +35,7 @@ internal static async Task<Stream> LoadDocumentFromUriAsStreamAsync(
{
//disposing the response disposes the stream
var response = await LoadDocumentResponseFromUriAsync(uri, logger, httpClient, authCallback, userAgent, cancellationToken).ConfigureAwait(false);
var stream = await response.Content.ReadAsStreamAndTranslateExceptionAsync().ConfigureAwait(false);
var stream = await response.Content.ReadAsStreamAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false);
return new HttpResponseStream(stream, response);
}

Expand Down
11 changes: 4 additions & 7 deletions dotnet/src/Functions/Functions.OpenApi/RestApiOperationRunner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,12 @@ internal sealed class RestApiOperationRunner
/// <summary>
/// A dictionary containing the content type as the key and the corresponding content reader as the value.
/// </summary>
/// <remarks>
/// TODO: Pass cancelation tokes to the content readers.
/// </remarks>
private static readonly Dictionary<string, HttpResponseContentReader> s_contentReaderByContentType = new()
{
{ "image", async (context, _) => await context.Response.Content.ReadAsByteArrayAndTranslateExceptionAsync().ConfigureAwait(false) },
{ "text", async (context, _) => await context.Response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false) },
{ "application/json", async (context, _) => await context.Response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false)},
{ "application/xml", async (context, _) => await context.Response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false)}
{ "image", async (context, cancellationToken) => await context.Response.Content.ReadAsByteArrayAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false) },
{ "text", async (context, cancellationToken) => await context.Response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false) },
{ "application/json", async (context, cancellationToken) => await context.Response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false)},
{ "application/xml", async (context, cancellationToken) => await context.Response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false)}
};

/// <summary>
Expand Down
22 changes: 19 additions & 3 deletions dotnet/src/InternalUtilities/src/Http/HttpContentExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.SemanticKernel.Http;
Expand All @@ -17,12 +18,17 @@ internal static class HttpContentExtensions
/// Reads the content of the HTTP response as a string and translates any HttpRequestException into an HttpOperationException.
/// </summary>
/// <param name="httpContent">The HTTP content to read.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A string representation of the HTTP content.</returns>
public static async Task<string> ReadAsStringWithExceptionMappingAsync(this HttpContent httpContent)
public static async Task<string> ReadAsStringWithExceptionMappingAsync(this HttpContent httpContent, CancellationToken cancellationToken = default)
{
try
{
#if NET5_0_OR_GREATER
return await httpContent.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
#else
return await httpContent.ReadAsStringAsync().ConfigureAwait(false);
#endif
}
catch (HttpRequestException ex)
{
Expand All @@ -34,12 +40,17 @@ public static async Task<string> ReadAsStringWithExceptionMappingAsync(this Http
/// Reads the content of the HTTP response as a stream and translates any HttpRequestException into an HttpOperationException.
/// </summary>
/// <param name="httpContent">The HTTP content to read.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A stream representing the HTTP content.</returns>
public static async Task<Stream> ReadAsStreamAndTranslateExceptionAsync(this HttpContent httpContent)
public static async Task<Stream> ReadAsStreamAndTranslateExceptionAsync(this HttpContent httpContent, CancellationToken cancellationToken = default)
{
try
{
#if NET5_0_OR_GREATER
return await httpContent.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
#else
return await httpContent.ReadAsStreamAsync().ConfigureAwait(false);
#endif
}
catch (HttpRequestException ex)
{
Expand All @@ -51,12 +62,17 @@ public static async Task<Stream> ReadAsStreamAndTranslateExceptionAsync(this Htt
/// Reads the content of the HTTP response as a byte array and translates any HttpRequestException into an HttpOperationException.
/// </summary>
/// <param name="httpContent">The HTTP content to read.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>A byte array representing the HTTP content.</returns>
public static async Task<byte[]> ReadAsByteArrayAndTranslateExceptionAsync(this HttpContent httpContent)
public static async Task<byte[]> ReadAsByteArrayAndTranslateExceptionAsync(this HttpContent httpContent, CancellationToken cancellationToken = default)
{
try
{
#if NET5_0_OR_GREATER
return await httpContent.ReadAsByteArrayAsync(cancellationToken).ConfigureAwait(false);
#else
return await httpContent.ReadAsByteArrayAsync().ConfigureAwait(false);
#endif
}
catch (HttpRequestException ex)
{
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Plugins/Plugins.Core/HttpPlugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,6 @@ private async Task<string> SendRequestAsync(string uri, HttpMethod method, HttpC
request.Headers.Add("User-Agent", HttpHeaderConstant.Values.UserAgent);
request.Headers.Add(HttpHeaderConstant.Names.SemanticKernelVersion, HttpHeaderConstant.Values.GetAssemblyVersion(typeof(HttpPlugin)));
using var response = await this._client.SendWithSuccessCheckAsync(request, cancellationToken).ConfigureAwait(false);
return await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
return await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);
}
}
2 changes: 1 addition & 1 deletion dotnet/src/Plugins/Plugins.Web/Bing/BingConnector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public async Task<IEnumerable<T>> SearchAsync<T>(string query, int count = 1, in

this._logger.LogDebug("Response received: {StatusCode}", response.StatusCode);

string json = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
string json = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);

// Sensitive data, logging as trace, disabled by default
this._logger.LogTrace("Response content received: {Data}", json);
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Plugins/Plugins.Web/Bing/BingTextSearch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ public async Task<KernelSearchResults<object>> GetSearchResultsAsync(string quer

this._logger.LogDebug("Response received: {StatusCode}", response.StatusCode);

string json = await response.Content.ReadAsStringWithExceptionMappingAsync().ConfigureAwait(false);
string json = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);

// Sensitive data, logging as trace, disabled by default
this._logger.LogTrace("Response content received: {Data}", json);
Expand Down
Loading

0 comments on commit 7d5b50c

Please sign in to comment.