Skip to content

Commit

Permalink
Merge branch 'dev' into tefa/refactor-service-endpoint-codes
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan authored Nov 25, 2024
2 parents 418aec8 + 3c6b848 commit efa08f8
Show file tree
Hide file tree
Showing 33 changed files with 843 additions and 921 deletions.
15 changes: 11 additions & 4 deletions build/dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@
<SystemIOPipelinesPackageVersion>8.0.0</SystemIOPipelinesPackageVersion>
<MicrosoftOwinPackageVersion>4.2.2</MicrosoftOwinPackageVersion>
<OwinPackageVersion>1.0.0</OwinPackageVersion>
<MicrosoftAspNetCoreConnectionsAbstractionsNet8PackageVersion>8.0.8</MicrosoftAspNetCoreConnectionsAbstractionsNet8PackageVersion>
<MicrosoftAspNetCoreConnectionsAbstractionsNet8PackageVersion>8.0.11</MicrosoftAspNetCoreConnectionsAbstractionsNet8PackageVersion>
<MicrosoftExtensionsDependencyInjectionNet8PackageVersion>8.0.0</MicrosoftExtensionsDependencyInjectionNet8PackageVersion>
<MicrosoftAspNetCoreHttpConnectionsCommonNet8PackageVersion>8.0.8</MicrosoftAspNetCoreHttpConnectionsCommonNet8PackageVersion>
<MicrosoftAspNetCoreHttpConnectionsCommonNet8PackageVersion>8.0.11</MicrosoftAspNetCoreHttpConnectionsCommonNet8PackageVersion>
<MicrosoftExtensionsLoggingAbstractionsNet8PackageVersion>8.0.1</MicrosoftExtensionsLoggingAbstractionsNet8PackageVersion>
<MicrosoftAspNetCoreSignalRCommonNet8PackageVersion>8.0.8</MicrosoftAspNetCoreSignalRCommonNet8PackageVersion>
<MicrosoftAspNetCoreSignalRCommonNet8PackageVersion>8.0.11</MicrosoftAspNetCoreSignalRCommonNet8PackageVersion>

<!-- SignalR Management -->
<AzureCorePackageVersion>1.39.0</AzureCorePackageVersion>
Expand All @@ -55,6 +55,13 @@

<!--Emulator, self-contained, always try the latest version -->
<MicrosoftExtensionsCommandLineUtilsPackageVersion>1.1.1</MicrosoftExtensionsCommandLineUtilsPackageVersion>
<EmulatorMicrosoftPackageVersion>8.0.8</EmulatorMicrosoftPackageVersion></PropertyGroup>
<EmulatorMicrosoftPackageVersion>8.0.11</EmulatorMicrosoftPackageVersion>

<!-- Vulnerability fix-->
<jQueryPackageVersion>3.7.1</jQueryPackageVersion>
<MessagePackPackageVersion>2.5.192</MessagePackPackageVersion>
<SystemNetHttpPackageVersion>4.3.4</SystemNetHttpPackageVersion>
<SystemTextRegularExpressionsPackageVersion>4.3.1</SystemTextRegularExpressionsPackageVersion>
</PropertyGroup>
<Import Project="$(DotNetPackageVersionPropsPath)" Condition=" '$(DotNetPackageVersionPropsPath)' != '' " />
</Project>
6 changes: 6 additions & 0 deletions samples/Directory.Build.props
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
<Project>
<Import Project="..\Directory.Build.props" />
<Import Project="..\build\dependencies.private.props" />
<PropertyGroup>
<WarningsNotAsErrors>NU1902</WarningsNotAsErrors>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="MessagePack" Version="$(MessagePackPackageVersion)" />
</ItemGroup>
</Project>
17 changes: 9 additions & 8 deletions src/Common/MemoryBufferWriter.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
#nullable enable

using System;
using System.Buffers;
Expand All @@ -14,7 +15,7 @@ namespace Microsoft.Azure.SignalR
internal sealed class MemoryBufferWriter : Stream, IBufferWriter<byte>
{
[ThreadStatic]
private static MemoryBufferWriter _cachedInstance;
private static MemoryBufferWriter? _cachedInstance;

#if DEBUG
private bool _inUse;
Expand All @@ -23,8 +24,8 @@ internal sealed class MemoryBufferWriter : Stream, IBufferWriter<byte>
private readonly int _minimumSegmentSize;
private int _bytesWritten;

private List<CompletedBuffer> _completedSegments;
private byte[] _currentSegment;
private List<CompletedBuffer>? _completedSegments;
private byte[]? _currentSegment;
private int _position;

public MemoryBufferWriter(int minimumSegmentSize = 4096)
Expand Down Expand Up @@ -107,14 +108,14 @@ public Memory<byte> GetMemory(int sizeHint = 0)
{
EnsureCapacity(sizeHint);

return _currentSegment.AsMemory(_position, _currentSegment.Length - _position);
return _currentSegment.AsMemory(_position, _currentSegment!.Length - _position);
}

public Span<byte> GetSpan(int sizeHint = 0)
{
EnsureCapacity(sizeHint);

return _currentSegment.AsSpan(_position, _currentSegment.Length - _position);
return _currentSegment.AsSpan(_position, _currentSegment!.Length - _position);
}

public void CopyTo(IBufferWriter<byte> destination)
Expand All @@ -137,7 +138,7 @@ public override Task CopyToAsync(Stream destination, int bufferSize, Cancellatio
if (_completedSegments == null)
{
// There is only one segment so write without awaiting.
return destination.WriteAsync(_currentSegment, 0, _position);
return destination.WriteAsync(_currentSegment!, 0, _position);
}

return CopyToSlowAsync(destination);
Expand Down Expand Up @@ -194,7 +195,7 @@ private async Task CopyToSlowAsync(Stream destination)
}
}

await destination.WriteAsync(_currentSegment, 0, _position);
await destination.WriteAsync(_currentSegment!, 0, _position);
}

public byte[] ToArray()
Expand Down Expand Up @@ -270,7 +271,7 @@ public override void WriteByte(byte value)
else
{
AddSegment();
_currentSegment[0] = value;
_currentSegment![0] = value;
}

_position++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
<LangVersion>11</LangVersion>
<RootNamespace>Microsoft.Azure.SignalR.AspNet</RootNamespace>
<TargetFramework>net462</TargetFramework>
<!-- jQuery is not used in this library, upgrading jQuery to 3.x might potentially break users using 1.x, so ignore this warning -->
<WarningsNotAsErrors>NU1902</WarningsNotAsErrors>
</PropertyGroup>

<ItemGroup>
<!-- Directly reference Microsoft.Owin 4.x for security fix -->
<PackageReference Include="Microsoft.Owin" Version="$(MicrosoftOwinPackageVersion)" />
<PackageReference Include="Microsoft.AspNet.SignalR" Version="$(MicrosoftAspNetSignalRPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.EventSource" Version="$(MicrosoftExtensionsLoggingEventSourcePackageVersion)" />

</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Microsoft.Azure.SignalR;

internal sealed class AccessKeySynchronizer : IAccessKeySynchronizer, IDisposable
{
private readonly ConcurrentDictionary<ServiceEndpoint, object> _endpoints = new ConcurrentDictionary<ServiceEndpoint, object>(ReferenceEqualityComparer.Instance);
private readonly ConcurrentDictionary<MicrosoftEntraAccessKey, bool> _keyMap = new(ReferenceEqualityComparer.Instance);

private readonly ILoggerFactory _factory;
private readonly ILogger<AccessKeySynchronizer> _logger;

private readonly TimerAwaitable _timer = new TimerAwaitable(TimeSpan.Zero, TimeSpan.FromMinutes(1));

internal IEnumerable<MicrosoftEntraAccessKey> AccessKeyForMicrosoftEntraList => _endpoints.Select(e => e.Key.AccessKey).OfType<MicrosoftEntraAccessKey>();
internal IEnumerable<MicrosoftEntraAccessKey> InitializedKeyList => _keyMap.Where(x => x.Key.Initialized).Select(x => x.Key);

public AccessKeySynchronizer(ILoggerFactory loggerFactory) : this(loggerFactory, true)
{
Expand All @@ -32,65 +34,74 @@ internal AccessKeySynchronizer(ILoggerFactory loggerFactory, bool start)
{
if (start)
{
_ = UpdateAccessKeyAsync();
_ = UpdateAllAccessKeyAsync();
}
_factory = loggerFactory ?? throw new ArgumentNullException(nameof(loggerFactory));
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<AccessKeySynchronizer>();
}

public void AddServiceEndpoint(ServiceEndpoint endpoint)
{
if (endpoint.AccessKey is MicrosoftEntraAccessKey key)
{
_ = key.UpdateAccessKeyAsync();
_keyMap.TryAdd(key, true);
}
_endpoints.TryAdd(endpoint, null);
}

public void Dispose() => _timer.Stop();

public void UpdateServiceEndpoints(IEnumerable<ServiceEndpoint> endpoints)
{
_endpoints.Clear();
_keyMap.Clear();
foreach (var endpoint in endpoints)
{
AddServiceEndpoint(endpoint);
}
}

internal bool ContainsServiceEndpoint(ServiceEndpoint e) => _endpoints.ContainsKey(e);
/// <summary>
/// Test only
/// </summary>
/// <param name="e"></param>
/// <returns></returns>
internal bool ContainsKey(ServiceEndpoint e) => _keyMap.ContainsKey(e.AccessKey as MicrosoftEntraAccessKey);

internal int ServiceEndpointsCount() => _endpoints.Count;
/// <summary>
/// Test only
/// </summary>
/// <returns></returns>
internal int Count() => _keyMap.Count;

private async Task UpdateAccessKeyAsync()
private async Task UpdateAllAccessKeyAsync()
{
using (_timer)
{
_timer.Start();

while (await _timer)
{
foreach (var key in AccessKeyForMicrosoftEntraList)
foreach (var key in InitializedKeyList)
{
_ = key.UpdateAccessKeyAsync();
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = key.UpdateAccessKeyAsync(source.Token);
}
}
}
}

private sealed class ReferenceEqualityComparer : IEqualityComparer<ServiceEndpoint>
private sealed class ReferenceEqualityComparer : IEqualityComparer<MicrosoftEntraAccessKey>
{
internal static readonly ReferenceEqualityComparer Instance = new ReferenceEqualityComparer();

private ReferenceEqualityComparer()
{
}

public bool Equals(ServiceEndpoint x, ServiceEndpoint y)
public bool Equals(MicrosoftEntraAccessKey x, MicrosoftEntraAccessKey y)
{
return ReferenceEquals(x, y);
}

public int GetHashCode(ServiceEndpoint obj)
public int GetHashCode(MicrosoftEntraAccessKey obj)
{
return RuntimeHelpers.GetHashCode(obj);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ internal class MicrosoftEntraAccessKey : IAccessKey
{
internal static readonly TimeSpan GetAccessKeyTimeout = TimeSpan.FromSeconds(100);

private const int UpdateTaskIdle = 0;

private const int UpdateTaskRunning = 1;

private const int GetAccessKeyMaxRetryTimes = 3;

private const int GetMicrosoftEntraTokenMaxRetryTimes = 3;
Expand All @@ -36,10 +40,12 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private static readonly TimeSpan AccessKeyExpireTime = TimeSpan.FromMinutes(120);

private readonly TaskCompletionSource<object?> _initializedTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly TaskCompletionSource<object?> _initializedTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

private readonly IHttpClientFactory _httpClientFactory;

private volatile int _updateState = 0;

private volatile bool _isAuthorized = false;

private DateTime _updateAt = DateTime.MinValue;
Expand All @@ -48,6 +54,8 @@ internal class MicrosoftEntraAccessKey : IAccessKey

private volatile byte[]? _keyBytes;

public bool Initialized => _initializedTcs.Task.IsCompleted;

public bool Available
{
get => _isAuthorized && DateTime.UtcNow - _updateAt < AccessKeyExpireTime;
Expand Down Expand Up @@ -116,6 +124,12 @@ public async Task<string> GenerateAccessTokenAsync(string audience,
AccessTokenAlgorithm algorithm,
CancellationToken ctoken = default)
{
if (!_initializedTcs.Task.IsCompleted)
{
var source = new CancellationTokenSource(Constants.Periods.DefaultUpdateAccessKeyTimeout);
_ = UpdateAccessKeyAsync(source.Token);
}

await _initializedTcs.Task.OrCancelAsync(ctoken, "The access key initialization timed out.");

return Available
Expand All @@ -142,26 +156,35 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
return;
}

if (Interlocked.CompareExchange(ref _updateState, UpdateTaskRunning, UpdateTaskIdle) != UpdateTaskIdle)
{
return;
}

for (var i = 0; i < GetAccessKeyMaxRetryTimes; i++)
{
if (ctoken.IsCancellationRequested)
{
break;
}

var source = new CancellationTokenSource(GetAccessKeyTimeout);
var linkedSource = CancellationTokenSource.CreateLinkedTokenSource(source.Token, ctoken);
try
{
await UpdateAccessKeyInternalAsync(linkedSource.Token);
await UpdateAccessKeyInternalAsync(source.Token).OrCancelAsync(ctoken);
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
return;
}
catch (OperationCanceledException e)
{
LastException = e;
break;
LastException = e; // retry immediately
}
catch (Exception e)
{
LastException = e;
try
{
await Task.Delay(GetAccessKeyRetryInterval, ctoken);
await Task.Delay(GetAccessKeyRetryInterval, ctoken); // retry after interval.
}
catch (OperationCanceledException)
{
Expand All @@ -175,6 +198,7 @@ internal async Task UpdateAccessKeyAsync(CancellationToken ctoken = default)
// Update the status only when it becomes "not available" due to expiration to refresh updateAt.
Available = false;
}
Interlocked.Exchange(ref _updateState, UpdateTaskIdle);
}

private static string GetExceptionMessage(Exception? exception)
Expand Down Expand Up @@ -232,11 +256,11 @@ private async Task UpdateAccessKeyInternalAsync(CancellationToken ctoken)
await ThrowExceptionOnResponseFailureAsync(request, response);
}

private async Task<bool> HandleHttpResponseAsync(HttpResponseMessage response)
private async Task HandleHttpResponseAsync(HttpResponseMessage response)
{
if (response.StatusCode != HttpStatusCode.OK)
{
return false;
return;
}

var content = await response.Content.ReadAsStringAsync();
Expand All @@ -250,8 +274,6 @@ private async Task<bool> HandleHttpResponseAsync(HttpResponseMessage response)
{
throw new AzureSignalRException("Missing required <AccessKey> field.");
}

UpdateAccessKey(obj.KeyId, obj.AccessKey);
return true;
}
}
5 changes: 3 additions & 2 deletions src/Microsoft.Azure.SignalR.Common/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ internal static class Constants

public const string AsrsDefaultScope = "https://signalr.azure.com/.default";


public const int DefaultCloseTimeoutMilliseconds = 30000;
public const int DefaultCloseTimeoutMilliseconds = 10000;

public static class Keys
{
Expand All @@ -45,6 +44,8 @@ public static class Periods

public const int MaxCustomHandshakeTimeout = 30;

public static readonly TimeSpan DefaultUpdateAccessKeyTimeout = TimeSpan.FromMinutes(2);

public static readonly TimeSpan DefaultAccessTokenLifetime = TimeSpan.FromHours(1);

public static readonly TimeSpan DefaultScaleTimeout = TimeSpan.FromMinutes(5);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static async Task OrCancelAsync(this Task task, CancellationToken token,
{
// make sure the task throws exception if any
await anyTask;
tcs.TrySetCanceled();
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@
<PackageReference Include="Microsoft.AspNetCore.Mvc.NewtonsoftJson" Version="$(EmulatorMicrosoftPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Authentication.JwtBearer" Version="$(EmulatorMicrosoftPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.CommandLineUtils" Version="$(MicrosoftExtensionsCommandLineUtilsPackageVersion)" />
<PackageReference Include="MessagePack" Version="$(MessagePackPackageVersion)" />
<PackageReference Include="System.Net.Http" Version="$(SystemNetHttpPackageVersion)" />
<PackageReference Include="System.Text.RegularExpressions" Version="$(SystemTextRegularExpressionsPackageVersion)" />
</ItemGroup>
</Project>
Loading

0 comments on commit efa08f8

Please sign in to comment.