Skip to content

Commit

Permalink
Add ServiceProviderAccessor to allow access to IServiceProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
tdg5 committed Nov 5, 2024
1 parent 46410a6 commit 886015d
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/Temporalio.Extensions.Hosting/IServiceProviderAccessor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using System;

namespace Temporalio.Extensions.Hosting
{
/// <summary>
/// Provides access to the current, scoped <see cref="IServiceProvider"/> if
/// one is available.
/// </summary>
public interface IServiceProviderAccessor
{
/// <summary>
/// Gets or sets the current service provider.
/// </summary>
IServiceProvider? ServiceProvider { get; set; }
}
}
42 changes: 42 additions & 0 deletions src/Temporalio.Extensions.Hosting/ServiceProviderAccessor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using System;
using System.Threading;

namespace Temporalio.Extensions.Hosting
{
/// <summary>
/// Provides an implementation of <see cref="IServiceProvider" /> based on
/// the current execution context.
/// </summary>
public class ServiceProviderAccessor : IServiceProviderAccessor
{
private static readonly AsyncLocal<ServiceProviderHolder> ServiceProviderCurrent = new();

/// <inheritdoc/>
public IServiceProvider? ServiceProvider
{
get => ServiceProviderCurrent.Value?.ServiceProvider;

set
{
var holder = ServiceProviderCurrent.Value;
if (holder != null)
{
// Clear current IServiceProvider trapped in the AsyncLocals, as its done.
holder.ServiceProvider = null;
}

if (value != null)
{
// Use an object indirection to hold the IServiceProvider in the AsyncLocal,
// so it can be cleared in all ExecutionContexts when its cleared.
ServiceProviderCurrent.Value = new ServiceProviderHolder { ServiceProvider = value };
}
}
}

private sealed class ServiceProviderHolder
{
public IServiceProvider? ServiceProvider { get; set; }
}
}
}
12 changes: 12 additions & 0 deletions src/Temporalio.Extensions.Hosting/ServiceProviderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ public static ActivityDefinition CreateTemporalActivityDefinition(
#else
var scope = provider.CreateScope();
#endif
IServiceProviderAccessor? serviceProviderAccessor =
scope.ServiceProvider.GetService<IServiceProviderAccessor>();

if (serviceProviderAccessor is not null)
{
serviceProviderAccessor.ServiceProvider = scope.ServiceProvider;
}

try
{
object? result;
Expand Down Expand Up @@ -111,6 +119,10 @@ public static ActivityDefinition CreateTemporalActivityDefinition(
}
finally
{
if (serviceProviderAccessor is not null)
{
serviceProviderAccessor.ServiceProvider = null;
}
#if NET6_0_OR_GREATER
await scope.DisposeAsync().ConfigureAwait(false);
#else
Expand Down

0 comments on commit 886015d

Please sign in to comment.