Skip to content

Commit

Permalink
[IngestionClient] Use dependency injection for BatchClient, add User-…
Browse files Browse the repository at this point in the history
…Agent request header (#2592)
  • Loading branch information
HenryvanderVegte authored Sep 18, 2024
1 parent 75d8e47 commit 203bd46
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 39 deletions.
53 changes: 30 additions & 23 deletions samples/ingestion/ingestion-client/Connector/BatchClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Connector
using Polly;
using Polly.Retry;

public static class BatchClient
public class BatchClient
{
private const string TranscriptionsBasePath = "speechtotext/v3.0/Transcriptions/";

Expand All @@ -29,36 +29,42 @@ public static class BatchClient

private static readonly TimeSpan GetFilesTimeout = TimeSpan.FromMinutes(5);

private static readonly HttpClient HttpClient = new HttpClient() { Timeout = Timeout.InfiniteTimeSpan };

private static readonly AsyncRetryPolicy RetryPolicy =
Policy
.Handle<Exception>(e => e is HttpStatusCodeException || e is HttpRequestException)
.WaitAndRetryAsync(MaxNumberOfRetries, retryAttempt => TimeSpan.FromSeconds(5));

public static Task<TranscriptionReportFile> GetTranscriptionReportFileFromSasAsync(string sasUri)
private readonly HttpClient httpClient;

/// <summary>
/// Initializes a new instance of the <see cref="BatchClient"/> class.
/// </summary>
/// <param name="httpClientFactory">Factory used to create the <see cref="HttpClient"/> this instance sends its requests with.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="httpClientFactory"/> is null.</exception>
public BatchClient(IHttpClientFactory httpClientFactory)
{
ArgumentNullException.ThrowIfNull(httpClientFactory, nameof(httpClientFactory));
// The client is requested by name (nameof(BatchClient)); the host registers a client under the same name via AddHttpClient.
this.httpClient = httpClientFactory.CreateClient(nameof(BatchClient));
}

/// <summary>
/// Requests the transcription report file from the given URI with an HTTP GET and deserializes the JSON response.
/// No subscription key is passed (null) and <c>DefaultTimeout</c> is used.
/// </summary>
/// <param name="sasUri">The URI to request the report file from.</param>
/// <returns>A task producing the deserialized <see cref="TranscriptionReportFile"/>.</returns>
public Task<TranscriptionReportFile> GetTranscriptionReportFileFromSasAsync(string sasUri)
{
return GetAsync<TranscriptionReportFile>(sasUri, null, DefaultTimeout);
return this.GetAsync<TranscriptionReportFile>(sasUri, null, DefaultTimeout);
}

public static Task<SpeechTranscript> GetSpeechTranscriptFromSasAsync(string sasUri)
/// <summary>
/// Requests a speech transcript from the given URI with an HTTP GET and deserializes the JSON response.
/// No subscription key is passed (null) and <c>DefaultTimeout</c> is used.
/// </summary>
/// <param name="sasUri">The URI to request the transcript from.</param>
/// <returns>A task producing the deserialized <see cref="SpeechTranscript"/>.</returns>
public Task<SpeechTranscript> GetSpeechTranscriptFromSasAsync(string sasUri)
{
return GetAsync<SpeechTranscript>(sasUri, null, DefaultTimeout);
return this.GetAsync<SpeechTranscript>(sasUri, null, DefaultTimeout);
}

public static Task<Transcription> GetTranscriptionAsync(string transcriptionLocation, string subscriptionKey)
/// <summary>
/// Requests the transcription at the given location with an HTTP GET and deserializes the JSON response,
/// using <c>DefaultTimeout</c>.
/// </summary>
/// <param name="transcriptionLocation">The request path of the transcription.</param>
/// <param name="subscriptionKey">The subscription key forwarded to the request helper.</param>
/// <returns>A task producing the deserialized <see cref="Transcription"/>.</returns>
public Task<Transcription> GetTranscriptionAsync(string transcriptionLocation, string subscriptionKey)
{
return GetAsync<Transcription>(transcriptionLocation, subscriptionKey, DefaultTimeout);
return this.GetAsync<Transcription>(transcriptionLocation, subscriptionKey, DefaultTimeout);
}

public static async Task<TranscriptionFiles> GetTranscriptionFilesAsync(string transcriptionLocation, string subscriptionKey)
public async Task<TranscriptionFiles> GetTranscriptionFilesAsync(string transcriptionLocation, string subscriptionKey)
{
var path = $"{transcriptionLocation}/files";
var combinedTranscriptionFiles = new List<TranscriptionFile>();

do
{
var transcriptionFiles = await GetAsync<TranscriptionFiles>(path, subscriptionKey, GetFilesTimeout).ConfigureAwait(false);
var transcriptionFiles = await this.GetAsync<TranscriptionFiles>(path, subscriptionKey, GetFilesTimeout).ConfigureAwait(false);
combinedTranscriptionFiles.AddRange(transcriptionFiles.Values);
path = transcriptionFiles.NextLink;
}
Expand All @@ -67,39 +73,39 @@ public static async Task<TranscriptionFiles> GetTranscriptionFilesAsync(string t
return new TranscriptionFiles(combinedTranscriptionFiles, null);
}

public static Task DeleteTranscriptionAsync(string transcriptionLocation, string subscriptionKey)
/// <summary>
/// Sends an HTTP DELETE for the transcription at the given location, using <c>DefaultTimeout</c>.
/// </summary>
/// <param name="transcriptionLocation">The request path of the transcription to delete.</param>
/// <param name="subscriptionKey">The subscription key forwarded to the request helper.</param>
/// <returns>A task that completes when the delete request has finished.</returns>
public Task DeleteTranscriptionAsync(string transcriptionLocation, string subscriptionKey)
{
return DeleteAsync(transcriptionLocation, subscriptionKey, DefaultTimeout);
return this.DeleteAsync(transcriptionLocation, subscriptionKey, DefaultTimeout);
}

public static async Task<Uri> PostTranscriptionAsync(TranscriptionDefinition transcriptionDefinition, string hostName, string subscriptionKey)
/// <summary>
/// Serializes the transcription definition to JSON and posts it to the transcriptions endpoint of the given host,
/// using <c>PostTimeout</c>.
/// </summary>
/// <param name="transcriptionDefinition">The definition that is serialized as the request payload.</param>
/// <param name="hostName">The host prefix that <c>TranscriptionsBasePath</c> is appended to.</param>
/// <param name="subscriptionKey">The subscription key forwarded to the request helper.</param>
/// <returns>The Location header of the response, as returned by <c>PostAsync</c>.</returns>
public async Task<Uri> PostTranscriptionAsync(TranscriptionDefinition transcriptionDefinition, string hostName, string subscriptionKey)
{
// NOTE(review): hostName and the base path are concatenated without a separator, so hostName presumably has to end with '/' - confirm against callers.
var path = $"{hostName}{TranscriptionsBasePath}";
var payloadString = JsonConvert.SerializeObject(transcriptionDefinition);

return await PostAsync(path, subscriptionKey, payloadString, PostTimeout).ConfigureAwait(false);
return await this.PostAsync(path, subscriptionKey, payloadString, PostTimeout).ConfigureAwait(false);
}

private static async Task<Uri> PostAsync(string path, string subscriptionKey, string payloadString, TimeSpan timeout)
/// <summary>
/// Sends an HTTP POST with the given payload and returns the Location header of the response.
/// </summary>
/// <param name="path">The request URI.</param>
/// <param name="subscriptionKey">The subscription key forwarded to <c>SendHttpRequestMessage</c>.</param>
/// <param name="payloadString">The request body, passed through as the payload.</param>
/// <param name="timeout">The timeout forwarded to <c>SendHttpRequestMessage</c>.</param>
/// <returns>The value of the response's Location header.</returns>
private async Task<Uri> PostAsync(string path, string subscriptionKey, string payloadString, TimeSpan timeout)
{
var responseMessage = await SendHttpRequestMessage(HttpMethod.Post, path, subscriptionKey, payloadString, timeout).ConfigureAwait(false);
var responseMessage = await this.SendHttpRequestMessage(HttpMethod.Post, path, subscriptionKey, payloadString, timeout).ConfigureAwait(false);
return responseMessage.Headers.Location;
}

private static async Task DeleteAsync(string path, string subscriptionKey, TimeSpan timeout)
/// <summary>
/// Sends an HTTP DELETE without a payload to the given path; the response is not inspected here.
/// </summary>
/// <param name="path">The request URI.</param>
/// <param name="subscriptionKey">The subscription key forwarded to <c>SendHttpRequestMessage</c>.</param>
/// <param name="timeout">The timeout forwarded to <c>SendHttpRequestMessage</c>.</param>
private async Task DeleteAsync(string path, string subscriptionKey, TimeSpan timeout)
{
await SendHttpRequestMessage(HttpMethod.Delete, path, subscriptionKey, payload: null, timeout: timeout).ConfigureAwait(false);
await this.SendHttpRequestMessage(HttpMethod.Delete, path, subscriptionKey, payload: null, timeout: timeout).ConfigureAwait(false);
}

private static async Task<TResponse> GetAsync<TResponse>(string path, string subscriptionKey, TimeSpan timeout)
/// <summary>
/// Sends an HTTP GET without a payload, reads the response body as a string and deserializes it from JSON.
/// </summary>
/// <typeparam name="TResponse">The type the response body is deserialized into.</typeparam>
/// <param name="path">The request URI.</param>
/// <param name="subscriptionKey">The subscription key forwarded to <c>SendHttpRequestMessage</c>.</param>
/// <param name="timeout">The timeout forwarded to <c>SendHttpRequestMessage</c>.</param>
/// <returns>The deserialized response body.</returns>
private async Task<TResponse> GetAsync<TResponse>(string path, string subscriptionKey, TimeSpan timeout)
{
var responseMessage = await SendHttpRequestMessage(HttpMethod.Get, path, subscriptionKey, payload: null, timeout: timeout).ConfigureAwait(false);
var responseMessage = await this.SendHttpRequestMessage(HttpMethod.Get, path, subscriptionKey, payload: null, timeout: timeout).ConfigureAwait(false);

// The whole body is buffered as a string before it is handed to Json.NET.
var contentString = await responseMessage.Content.ReadAsStringAsync().ConfigureAwait(false);
return JsonConvert.DeserializeObject<TResponse>(contentString);
}

private static async Task<HttpResponseMessage> SendHttpRequestMessage(HttpMethod httpMethod, string path, string subscriptionKey, string payload, TimeSpan timeout)
private async Task<HttpResponseMessage> SendHttpRequestMessage(HttpMethod httpMethod, string path, string subscriptionKey, string payload, TimeSpan timeout)
{
try
{
Expand All @@ -110,6 +116,7 @@ private static async Task<HttpResponseMessage> SendHttpRequestMessage(HttpMethod
async (token) =>
{
using var httpRequestMessage = new HttpRequestMessage(httpMethod, path);

if (!string.IsNullOrEmpty(subscriptionKey))
{
httpRequestMessage.Headers.Add("Ocp-Apim-Subscription-Key", subscriptionKey);
Expand All @@ -120,7 +127,7 @@ private static async Task<HttpResponseMessage> SendHttpRequestMessage(HttpMethod
httpRequestMessage.Content = new StringContent(payload, Encoding.UTF8, "application/json");
}

var responseMessage = await HttpClient.SendAsync(httpRequestMessage, token).ConfigureAwait(false);
var responseMessage = await this.httpClient.SendAsync(httpRequestMessage, token).ConfigureAwait(false);

await responseMessage.EnsureSuccessStatusCodeAsync().ConfigureAwait(false);
return responseMessage;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,7 @@ public int RetryLimit
public bool CreateAudioProcessedContainer { get; set; }

public string AudioProcessedContainer { get; set; }

public string Version { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class FetchTranscription
private readonly IStorageConnector storageConnector;
private readonly IAzureClientFactory<ServiceBusClient> serviceBusClientFactory;
private readonly ILogger<FetchTranscription> logger;
private readonly BatchClient batchClient;
private readonly AppConfig appConfig;

/// <summary>
Expand All @@ -36,18 +37,21 @@ public class FetchTranscription
/// <param name="logger">The FetchTranscription logger.</param>
/// <param name="storageConnector">Storage Connector dependency</param>
/// <param name="serviceBusClientFactory">Azure client factory for service bus clients</param>
/// <param name="batchClient">The client to call the Azure Speech-To-Text batch API</param>
/// <param name="appConfig">Environment configuration</param>
public FetchTranscription(
IServiceProvider serviceProvider,
ILogger<FetchTranscription> logger,
IStorageConnector storageConnector,
IAzureClientFactory<ServiceBusClient> serviceBusClientFactory,
BatchClient batchClient,
IOptions<AppConfig> appConfig)
{
// The dependencies are stored as given; none of them is null-checked in this constructor.
this.serviceProvider = serviceProvider;
this.logger = logger;
this.storageConnector = storageConnector;
this.serviceBusClientFactory = serviceBusClientFactory;
// Kept so that Run can hand it to the TranscriptionProcessor it creates.
this.batchClient = batchClient;
// NOTE(review): a null appConfig leaves this.appConfig null, while Run dereferences it (this.appConfig.UseSqlDatabase) - confirm the options are always supplied.
this.appConfig = appConfig?.Value;
}

Expand All @@ -72,7 +76,12 @@ public async Task Run([ServiceBusTrigger("fetch_transcription_queue", Connection

var databaseContext = this.appConfig.UseSqlDatabase ? this.serviceProvider.GetRequiredService<IngestionClientDbContext>() : null;

var transcriptionProcessor = new TranscriptionProcessor(this.storageConnector, this.serviceBusClientFactory, databaseContext, Options.Create(this.appConfig));
var transcriptionProcessor = new TranscriptionProcessor(
this.storageConnector,
this.serviceBusClientFactory,
databaseContext,
this.batchClient,
Options.Create(this.appConfig));

await transcriptionProcessor.ProcessTranscriptionJobAsync(serviceBusMessage, this.serviceProvider, this.logger).ConfigureAwait(false);
}
Expand Down
11 changes: 11 additions & 0 deletions samples/ingestion/ingestion-client/FetchTranscription/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
namespace FetchTranscription
{
using System.IO;
using System.Threading;

using Azure.Storage;
using Azure.Storage.Blobs;
Expand Down Expand Up @@ -70,6 +71,16 @@ public static void Main(string[] args)
.WithName(ServiceBusClientName.CompletedTranscriptionServiceBusClient.ToString());
}
});

// Register an HttpClient under the name nameof(BatchClient); BatchClient requests a client by that same name from IHttpClientFactory.
services.AddHttpClient(nameof(BatchClient), httpClient =>
{
// timeouts are managed by BatchClient directly:
httpClient.Timeout = Timeout.InfiniteTimeSpan;
// Adds a User-Agent default request header carrying the configured version string.
// NOTE(review): config.Version is a plain string property - confirm it is always populated in the app settings.
httpClient.DefaultRequestHeaders.UserAgent.ParseAdd($"Ingestion Client ({config.Version})");
});

// One BatchClient instance is shared by everything resolved from this container.
services.AddSingleton<BatchClient>();

services.Configure<AppConfig>(configuration);
})
.Build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,20 @@ public class TranscriptionProcessor

private readonly IStorageConnector storageConnector;

private readonly BatchClient batchClient;

private readonly AppConfig appConfig;

public TranscriptionProcessor(
IStorageConnector storageConnector,
IAzureClientFactory<ServiceBusClient> serviceBusClientFactory,
IngestionClientDbContext databaseContext,
BatchClient batchClient,
IOptions<AppConfig> appConfig)
{
this.storageConnector = storageConnector;
this.databaseContext = databaseContext;
this.batchClient = batchClient;
this.appConfig = appConfig?.Value;

ArgumentNullException.ThrowIfNull(serviceBusClientFactory, nameof(serviceBusClientFactory));
Expand Down Expand Up @@ -86,7 +90,7 @@ public async Task ProcessTranscriptionJobAsync(TranscriptionStartedMessage servi

try
{
var transcription = await BatchClient.GetTranscriptionAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false);
var transcription = await this.batchClient.GetTranscriptionAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false);
log.LogInformation($"Polled {serviceBusMessage.PollingCounter} time(s) for results in total, delay job for {messageDelayTime.TotalMinutes} minutes if not completed.");
switch (transcription.Status)
{
Expand Down Expand Up @@ -189,13 +193,13 @@ private async Task ProcessFailedTranscriptionAsync(string transcriptionLocation,

log.LogInformation(logMessage);

var transcriptionFiles = await BatchClient.GetTranscriptionFilesAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false);
var transcriptionFiles = await this.batchClient.GetTranscriptionFilesAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false);

var errorReportOutput = logMessage;
var reportFile = transcriptionFiles.Values.Where(t => t.Kind == TranscriptionFileKind.TranscriptionReport).FirstOrDefault();
if (reportFile?.Links?.ContentUrl != null)
{
var reportFileContent = await BatchClient.GetTranscriptionReportFileFromSasAsync(reportFile.Links.ContentUrl).ConfigureAwait(false);
var reportFileContent = await this.batchClient.GetTranscriptionReportFileFromSasAsync(reportFile.Links.ContentUrl).ConfigureAwait(false);
errorReportOutput += $"\nReport file: \n {JsonConvert.SerializeObject(reportFileContent)}";
}

Expand Down Expand Up @@ -237,7 +241,7 @@ await this.storageConnector.MoveFileAsync(
}
}

await BatchClient.DeleteTranscriptionAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false);
await this.batchClient.DeleteTranscriptionAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false);
}

private async Task ProcessReportFileAsync(TranscriptionReportFile transcriptionReportFile, ILogger log)
Expand Down Expand Up @@ -290,7 +294,7 @@ private async Task RetryOrFailJobAsync(TranscriptionStartedMessage message, stri
else
{
await this.WriteFailedJobLogToStorageAsync(message, errorMessage, jobName, log).ConfigureAwait(false);
await BatchClient.DeleteTranscriptionAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false);
await this.batchClient.DeleteTranscriptionAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false);
}
}

Expand Down Expand Up @@ -347,7 +351,7 @@ private async Task ProcessSucceededTranscriptionAsync(string transcriptionLocati
return;
}

var transcriptionFiles = await BatchClient.GetTranscriptionFilesAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false);
var transcriptionFiles = await this.batchClient.GetTranscriptionFilesAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false);
log.LogInformation($"Received transcription files.");
var resultFiles = transcriptionFiles.Values.Where(t => t.Kind == TranscriptionFileKind.Transcription);

Expand All @@ -360,7 +364,7 @@ private async Task ProcessSucceededTranscriptionAsync(string transcriptionLocati

try
{
var transcriptionResult = await BatchClient.GetSpeechTranscriptFromSasAsync(resultFile.Links.ContentUrl).ConfigureAwait(false);
var transcriptionResult = await this.batchClient.GetSpeechTranscriptFromSasAsync(resultFile.Links.ContentUrl).ConfigureAwait(false);

if (string.IsNullOrEmpty(transcriptionResult.Source))
{
Expand Down Expand Up @@ -522,10 +526,10 @@ await this.databaseContext.StoreTranscriptionAsync(
}

var reportFile = transcriptionFiles.Values.Where(t => t.Kind == TranscriptionFileKind.TranscriptionReport).FirstOrDefault();
var reportFileContent = await BatchClient.GetTranscriptionReportFileFromSasAsync(reportFile.Links.ContentUrl).ConfigureAwait(false);
var reportFileContent = await this.batchClient.GetTranscriptionReportFileFromSasAsync(reportFile.Links.ContentUrl).ConfigureAwait(false);
await this.ProcessReportFileAsync(reportFileContent, log).ConfigureAwait(false);

BatchClient.DeleteTranscriptionAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false).GetAwaiter().GetResult();
this.batchClient.DeleteTranscriptionAsync(transcriptionLocation, subscriptionKey).ConfigureAwait(false).GetAwaiter().GetResult();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,7 @@ public int InitialPollingDelayInMinutes
public string StartTranscriptionServiceBusConnectionString { get; set; }

public string StartTranscriptionFunctionTimeInterval { get; set; }

public string Version { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
namespace StartTranscriptionByTimer
{
using System.IO;
using System.Threading;

using Azure.Storage;
using Azure.Storage.Blobs;
Expand Down Expand Up @@ -56,6 +57,16 @@ public static void Main(string[] args)
clientBuilder.AddServiceBusClient(config.FetchTranscriptionServiceBusConnectionString)
.WithName(ServiceBusClientName.FetchTranscriptionServiceBusClient.ToString());
});

// Register an HttpClient under the name nameof(BatchClient); BatchClient requests a client by that same name from IHttpClientFactory.
services.AddHttpClient(nameof(BatchClient), httpClient =>
{
// timeouts are managed by BatchClient directly:
httpClient.Timeout = Timeout.InfiniteTimeSpan;
// Adds a User-Agent default request header carrying the configured version string.
// NOTE(review): config.Version is a plain string property - confirm it is always populated in the app settings.
httpClient.DefaultRequestHeaders.UserAgent.ParseAdd($"Ingestion Client ({config.Version})");
});

// One BatchClient instance is shared by everything resolved from this container.
services.AddSingleton<BatchClient>();

services.Configure<AppConfig>(configuration);
})
.Build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,20 @@ public class StartTranscriptionHelper : IStartTranscriptionHelper

private readonly IStorageConnector storageConnector;

private readonly BatchClient batchClient;

private readonly AppConfig appConfig;

public StartTranscriptionHelper(
ILogger<StartTranscriptionHelper> logger,
IStorageConnector storageConnector,
IAzureClientFactory<ServiceBusClient> serviceBusClientFactory,
BatchClient batchClient,
IOptions<AppConfig> appConfig)
{
this.logger = logger;
this.storageConnector = storageConnector;
this.batchClient = batchClient;
this.appConfig = appConfig?.Value;
this.locale = this.appConfig.Locale.Split('|')[0].Trim();

Expand Down Expand Up @@ -229,7 +233,7 @@ private async Task StartBatchTranscriptionJobAsync(IEnumerable<ServiceBusReceive

var transcriptionDefinition = TranscriptionDefinition.Create(jobName, "StartByTimerTranscription", this.locale, audioUrls, properties, modelIdentity);

var transcriptionLocation = await BatchClient.PostTranscriptionAsync(
var transcriptionLocation = await this.batchClient.PostTranscriptionAsync(
transcriptionDefinition,
this.appConfig.AzureSpeechServicesEndpointUri,
this.appConfig.AzureSpeechServicesKey).ConfigureAwait(false);
Expand Down
Loading

0 comments on commit 203bd46

Please sign in to comment.