Skip to content

Commit

Permalink
Merge branch 'main' into gai/opt-in-features
Browse files Browse the repository at this point in the history
  • Loading branch information
glen-84 committed Nov 14, 2024
2 parents 8e5e2d1 + a18706c commit 3d79b77
Show file tree
Hide file tree
Showing 202 changed files with 7,849 additions and 2,420 deletions.
4 changes: 2 additions & 2 deletions src/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.11.0" />
<PackageVersion Include="Microsoft.OpenApi" Version="1.6.14" />
<PackageVersion Include="Microsoft.OpenApi.Readers" Version="1.6.14" />
<PackageVersion Include="MongoDB.Driver" Version="2.29.0" />
<PackageVersion Include="MongoDB.Driver" Version="3.0.0" />
<PackageVersion Include="Moq" Version="4.20.70" />
<PackageVersion Include="NetTopologySuite" Version="2.0.0" />
<PackageVersion Include="Newtonsoft.Json" Version="13.0.2" />
Expand All @@ -47,7 +47,7 @@
<PackageVersion Include="Snapshooter.Xunit" Version="0.5.4" />
<PackageVersion Include="sqlite-net-pcl" Version="1.9.172" />
<PackageVersion Include="SQLitePCLRaw.bundle_green" Version="2.1.8" />
<PackageVersion Include="Squadron.Mongo" Version="0.21.0" />
<PackageVersion Include="Squadron.Mongo" Version="0.23.0" />
<PackageVersion Include="Squadron.Nats" Version="0.18.0" />
<PackageVersion Include="Squadron.PostgreSql" Version="0.18.0" />
<PackageVersion Include="Squadron.RabbitMQ" Version="0.18.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ public static IServiceCollection AddDataLoader<T>(
return services;
}

public static IServiceCollection AddDataLoader<TService, TImplementation>(
this IServiceCollection services,
Func<IServiceProvider, TImplementation> factory)
where TService : class, IDataLoader
where TImplementation : class, TService
{
services.TryAddDataLoaderCore();
services.AddSingleton(new DataLoaderRegistration(typeof(TService), typeof(TImplementation), sp => factory(sp)));
services.TryAddScoped<TImplementation>(sp => sp.GetDataLoader<TImplementation>());
services.TryAddScoped<TService>(sp => sp.GetDataLoader<TService>());
return services;
}

public static IServiceCollection TryAddDataLoaderCore(
this IServiceCollection services)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using GreenDonut;
using Xunit;
using static GreenDonut.TestHelpers;

namespace Microsoft.Extensions.DependencyInjection;

public class DataLoaderServiceCollectionExtensionsTests
{
[Fact]
public void ImplFactoryIsCalledWhenServiceIsResolved()
{
// arrange
var factoryCalled = false;
var fetch = CreateFetch<string, string>();
var services = new ServiceCollection()
.AddScoped<IBatchScheduler, ManualBatchScheduler>()
.AddDataLoader(sp =>
{
factoryCalled = true;
return new DataLoader<string, string>(fetch, sp.GetRequiredService<IBatchScheduler>());
});
var scope = services.BuildServiceProvider().CreateScope();

// act
var dataLoader = scope.ServiceProvider.GetRequiredService<DataLoader<string, string>>();

// assert
Assert.NotNull(dataLoader);
Assert.True(factoryCalled);
}

[Fact]
public void InterfaceImplFactoryIsCalledWhenServiceIsResolved()
{
// arrange
var factoryCalled = false;
var fetch = CreateFetch<string, string>();
var services = new ServiceCollection()
.AddScoped<IBatchScheduler, ManualBatchScheduler>()
.AddDataLoader<IDataLoader<string, string>, DataLoader<string, string>>(sp =>
{
factoryCalled = true;
return new DataLoader<string, string>(fetch, sp.GetRequiredService<IBatchScheduler>());
});
var scope = services.BuildServiceProvider().CreateScope();

// act
var dataLoader = scope.ServiceProvider.GetRequiredService<DataLoader<string, string>>();
var asInterface = scope.ServiceProvider.GetRequiredService<IDataLoader<string, string>>();

// assert
Assert.NotNull(dataLoader);
Assert.NotNull(asInterface);
Assert.True(factoryCalled);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,48 +4,21 @@

namespace HotChocolate.AspNetCore.Authorization;

internal sealed class AuthorizationPolicyCache(IAuthorizationPolicyProvider policyProvider)
internal sealed class AuthorizationPolicyCache
{
private readonly ConcurrentDictionary<string, Task<AuthorizationPolicy>> _cache = new();
private readonly ConcurrentDictionary<string, AuthorizationPolicy> _cache = new();

public Task<AuthorizationPolicy> GetOrCreatePolicyAsync(AuthorizeDirective directive)
public AuthorizationPolicy? LookupPolicy(AuthorizeDirective directive)
{
var cacheKey = directive.GetPolicyCacheKey();

return _cache.GetOrAdd(cacheKey, _ => BuildAuthorizationPolicy(directive.Policy, directive.Roles));
return _cache.GetValueOrDefault(cacheKey);
}

private async Task<AuthorizationPolicy> BuildAuthorizationPolicy(
string? policyName,
IReadOnlyList<string>? roles)
public void CachePolicy(AuthorizeDirective directive, AuthorizationPolicy policy)
{
var policyBuilder = new AuthorizationPolicyBuilder();

if (!string.IsNullOrWhiteSpace(policyName))
{
var policy = await policyProvider.GetPolicyAsync(policyName).ConfigureAwait(false);

if (policy is not null)
{
policyBuilder = policyBuilder.Combine(policy);
}
else
{
throw new MissingAuthorizationPolicyException(policyName);
}
}
else
{
var defaultPolicy = await policyProvider.GetDefaultPolicyAsync().ConfigureAwait(false);

policyBuilder = policyBuilder.Combine(defaultPolicy);
}

if (roles is not null)
{
policyBuilder = policyBuilder.RequireRole(roles);
}
var cacheKey = directive.GetPolicyCacheKey();

return policyBuilder.Build();
_cache.TryAdd(cacheKey, policy);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,39 @@ namespace HotChocolate.AspNetCore.Authorization;
internal sealed class DefaultAuthorizationHandler : IAuthorizationHandler
{
private readonly IAuthorizationService _authSvc;
private readonly AuthorizationPolicyCache _policyCache;
private readonly IAuthorizationPolicyProvider _authorizationPolicyProvider;
private readonly AuthorizationPolicyCache _authorizationPolicyCache;
private readonly bool _canCachePolicies;

/// <summary>
/// Initializes a new instance <see cref="DefaultAuthorizationHandler"/>.
/// </summary>
/// <param name="authorizationService">
/// The authorization service.
/// </param>
/// <param name="policyCache">
/// <param name="authorizationPolicyProvider">
/// The authorization policy provider.
/// </param>
/// <param name="authorizationPolicyCache">
/// The authorization policy cache.
/// </param>
/// <exception cref="ArgumentNullException">
/// <paramref name="authorizationService"/> is <c>null</c>.
/// <paramref name="policyCache"/> is <c>null</c>.
/// <paramref name="authorizationPolicyCache"/> is <c>null</c>.
/// </exception>
public DefaultAuthorizationHandler(
IAuthorizationService authorizationService,
AuthorizationPolicyCache policyCache)
IAuthorizationPolicyProvider authorizationPolicyProvider,
AuthorizationPolicyCache authorizationPolicyCache)
{
_authSvc = authorizationService ??
throw new ArgumentNullException(nameof(authorizationService));
_policyCache = policyCache ??
throw new ArgumentNullException(nameof(policyCache));
_authorizationPolicyProvider = authorizationPolicyProvider ??
throw new ArgumentNullException(nameof(authorizationPolicyProvider));
_authorizationPolicyCache = authorizationPolicyCache ??
throw new ArgumentNullException(nameof(authorizationPolicyCache));

_canCachePolicies = _authorizationPolicyProvider.AllowsCachingPolicies;
}

/// <summary>
Expand Down Expand Up @@ -123,9 +133,24 @@ private async ValueTask<AuthorizeResult> AuthorizeAsync(
{
try
{
var combinedPolicy = await _policyCache.GetOrCreatePolicyAsync(directive);
AuthorizationPolicy? authorizationPolicy = null;

if (_canCachePolicies)
{
authorizationPolicy = _authorizationPolicyCache.LookupPolicy(directive);
}

if (authorizationPolicy is null)
{
authorizationPolicy = await BuildAuthorizationPolicy(directive.Policy, directive.Roles);

var result = await _authSvc.AuthorizeAsync(user, context, combinedPolicy).ConfigureAwait(false);
if (_canCachePolicies)
{
_authorizationPolicyCache.CachePolicy(directive, authorizationPolicy);
}
}

var result = await _authSvc.AuthorizeAsync(user, context, authorizationPolicy).ConfigureAwait(false);

return result.Succeeded
? AuthorizeResult.Allowed
Expand All @@ -137,6 +162,40 @@ private async ValueTask<AuthorizeResult> AuthorizeAsync(
}
}

private async Task<AuthorizationPolicy> BuildAuthorizationPolicy(
string? policyName,
IReadOnlyList<string>? roles)
{
var policyBuilder = new AuthorizationPolicyBuilder();

if (!string.IsNullOrWhiteSpace(policyName))
{
var policy = await _authorizationPolicyProvider.GetPolicyAsync(policyName).ConfigureAwait(false);

if (policy is not null)
{
policyBuilder = policyBuilder.Combine(policy);
}
else
{
throw new MissingAuthorizationPolicyException(policyName);
}
}
else
{
var defaultPolicy = await _authorizationPolicyProvider.GetDefaultPolicyAsync().ConfigureAwait(false);

policyBuilder = policyBuilder.Combine(defaultPolicy);
}

if (roles is not null)
{
policyBuilder = policyBuilder.RequireRole(roles);
}

return policyBuilder.Build();
}

private static UserState GetUserState(IDictionary<string, object?> contextData)
{
if (contextData.TryGetValue(WellKnownContextData.UserState, out var value) &&
Expand Down
Loading

0 comments on commit 3d79b77

Please sign in to comment.