Skip to content

Commit

Permalink
plugins system dependencies forced to be platform version
Browse files Browse the repository at this point in the history
  • Loading branch information
caunt committed Sep 25, 2024
1 parent fe3c34a commit 0c63d15
Show file tree
Hide file tree
Showing 11 changed files with 289 additions and 54 deletions.
8 changes: 8 additions & 0 deletions src/API/Plugins/IPluginDependencyService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
using System.Reflection;

namespace Void.Proxy.API.Plugins;

public interface IPluginDependencyService
{
public string? ResolveAssemblyPath(AssemblyName assemblyName);
}
1 change: 1 addition & 0 deletions src/Platform/EntryPoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
builder.Services.AddSingleton<ISettings, Settings>();
builder.Services.AddSingleton<ICryptoService, RsaCryptoService>();
builder.Services.AddSingleton<IEventService, EventService>();
builder.Services.AddSingleton<IPluginDependencyService, PluginDependencyService>();
builder.Services.AddSingleton<IPluginService, PluginService>();
builder.Services.AddSingleton<IPlayerService, PlayerService>();
builder.Services.AddSingleton<IServerService, ServerService>();
Expand Down
26 changes: 17 additions & 9 deletions src/Platform/Events/EventService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Void.Proxy.Events;

public class EventService : IEventService
public class EventService(ILogger<EventService> logger) : IEventService
{
private readonly List<IEventListener> _listeners = [];
private readonly List<MethodInfo> _methods = [];
Expand All @@ -30,6 +30,7 @@ public class EventService : IEventService
public async ValueTask ThrowAsync<T>(T @event, CancellationToken cancellationToken = default) where T : IEvent
{
var eventType = @event.GetType();

var simpleParameters = (object[]) [@event];
var cancellableParameters = (object[]) [@event, cancellationToken];

Expand All @@ -56,15 +57,22 @@ public async ValueTask ThrowAsync<T>(T @event, CancellationToken cancellationTok

await Task.Yield();

var value = method.Invoke(listener, parameters.Length == 1 ? simpleParameters : cancellableParameters);
var handle = value switch
try
{
Task task => new ValueTask(task),
ValueTask task => task,
_ => ValueTask.CompletedTask
};

await handle;
var value = method.Invoke(listener, parameters.Length == 1 ? simpleParameters : cancellableParameters);
var handle = value switch
{
Task task => new ValueTask(task),
ValueTask task => task,
_ => ValueTask.CompletedTask
};

await handle;
}
catch (TargetInvocationException exception)
{
logger.LogError(exception.InnerException, "{EventName} cannot be invoked on {ListenerName}", eventType.Name, listener.GetType().FullName);
}
}
}
}
Expand Down
180 changes: 180 additions & 0 deletions src/Platform/Plugins/PluginDependencyService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
using System.Reflection;
using System.Runtime.Versioning;
using NuGet.Common;
using NuGet.Configuration;
using NuGet.Frameworks;
using NuGet.PackageManagement;
using NuGet.Packaging;
using NuGet.Packaging.Core;
using NuGet.Protocol;
using NuGet.Protocol.Core.Types;
using Void.Proxy.API.Plugins;

namespace Void.Proxy.Plugins;

public class PluginDependencyService(ILogger<PluginDependencyService> logger) : IPluginDependencyService
{
private static readonly SourceRepository NuGetRepository = Repository.Factory.GetCoreV3(new PackageSource("https://api.nuget.org/v3/index.json").Source);
private static readonly SourceCacheContext NuGetCache = new();
private static readonly string NuGetPackagesPath = Path.Combine(Directory.GetCurrentDirectory(), SettingsUtility.DefaultGlobalPackagesFolderPath);

public string? ResolveAssemblyPath(AssemblyName assemblyName)
{
logger.LogInformation("Resolving {AssemblyName} dependency", assemblyName.Name);

if (assemblyName.FullName.StartsWith(nameof(Void) + '.'))
{
logger.LogCritical("Void packages shouldn't be searched in NuGet");
return null;
}

var assemblyPath = ResolveAssemblyFromNuGetAsync(assemblyName, CancellationToken.None).GetAwaiter().GetResult();

return assemblyPath;
}

private async ValueTask<string?> ResolveAssemblyFromNuGetAsync(AssemblyName assemblyName, CancellationToken cancellationToken)
{
var identity = await TryResolveNuGetIdentityAsync(assemblyName, cancellationToken);

if (identity is null)
{
logger.LogError("Dependency {DependencyName} not found in NuGet at all", assemblyName.Name);
return null;
}

var packagePath = Path.Combine(NuGetPackagesPath, identity.Id.ToLower(), identity.Version.ToString());

if (!Directory.Exists(packagePath))
await TryDownloadNuGetPackageAsync(identity, cancellationToken);

if (!Directory.Exists(packagePath))
{
logger.LogError("Dependency {DependencyName} cannot be downloaded from NuGet", assemblyName.Name);
return null;
}

var targetFrameworkName = Assembly.GetExecutingAssembly().GetCustomAttribute<TargetFrameworkAttribute>()?.FrameworkName;

if (targetFrameworkName == null)
throw new InvalidOperationException("Cannot determine the target framework.");

var packageReader = new PackageFolderReader(packagePath);
var frameworks = await packageReader.GetLibItemsAsync(cancellationToken);
var targetFramework = NuGetFramework.ParseFrameworkName(targetFrameworkName, new DefaultFrameworkNameProvider());

foreach (var framework in frameworks)
{
if (!DefaultCompatibilityProvider.Instance.IsCompatible(targetFramework, framework.TargetFramework))
continue;

var assembly = framework.Items.FirstOrDefault(fileName => Path.GetFileName(fileName).Equals(assemblyName.Name + ".dll", StringComparison.InvariantCultureIgnoreCase)) ?? framework.Items.FirstOrDefault();

if (assembly is null)
throw new FileNotFoundException($"Dependency {identity.Id} was downloaded but file cannot be located");

return Path.Combine(packagePath, assembly);
}

return null;
}

private async ValueTask TryDownloadNuGetPackageAsync(PackageIdentity identity, CancellationToken cancellationToken)
{
try
{
using var result = await PackageDownloader.GetDownloadResourceResultAsync(NuGetRepository, identity, new PackageDownloadContext(NuGetCache), NuGetPackagesPath, NullLogger.Instance, cancellationToken);
logger.LogInformation("Downloaded {PackageId} {PackageVersion}", identity.Id, identity.Version);
}
catch (FatalProtocolException exception)
{
logger.LogCritical("Dependency {PackageId} cannot be resolved: {Reason}", identity.Id, exception.Message);
}
catch (RetriableProtocolException exception)
{
logger.LogError("Dependency {PackageId} loading was cancelled: {Message}", identity.Id, exception.Message);
}
}

private async ValueTask<PackageIdentity?> TryResolveNuGetIdentityAsync(AssemblyName assemblyName, CancellationToken cancellationToken)
{
logger.LogInformation("Looking for dependency {DependencyName} as Identity in NuGet", assemblyName.Name);
var identity = await TryResolveNuGetPackageIdAsync(assemblyName, cancellationToken);

if (identity is not null)
return identity;

logger.LogInformation("Looking for dependency {DependencyName} with Search in NuGet", assemblyName.Name);
identity = await TryResolveNuGetPackageSearchAsync(assemblyName, cancellationToken);

if (identity is not null)
return identity;

logger.LogWarning("Dependency {DependencyName} not found in NuGet", assemblyName.Name);
return null;
}

private async ValueTask<PackageIdentity?> TryResolveNuGetPackageIdAsync(AssemblyName assemblyName, CancellationToken cancellationToken)
{
if (string.IsNullOrWhiteSpace(assemblyName.Name))
return null;

var packages = await GetNuGetPackageVersionAsync(assemblyName.Name, cancellationToken);
var best = SelectBestNuGetPackageVersion(packages, assemblyName.Version);

return best;
}

private async ValueTask<PackageIdentity?> TryResolveNuGetPackageSearchAsync(AssemblyName assemblyName, CancellationToken cancellationToken)
{
var packageSearchResource = await NuGetRepository.GetResourceAsync<PackageSearchResource>(cancellationToken);
var packageSearchResults = await packageSearchResource.SearchAsync(assemblyName.Name, new SearchFilter(true), 0, 1, NullLogger.Instance, cancellationToken);

// actually always 1
foreach (var packageSearchResult in packageSearchResults)
{
var packages = await GetNuGetPackageVersionAsync(packageSearchResult.Identity.Id, cancellationToken);
var best = SelectBestNuGetPackageVersion(packages, assemblyName.Version);

return best;
}

return null;
}

private PackageIdentity? SelectBestNuGetPackageVersion(IEnumerable<IPackageSearchMetadata> packages, Version? assemblyVersion)
{
IPackageSearchMetadata? result = null;

foreach (var package in packages)
{
if (result is null)
{
result = package;
continue;
}

if (!package.Identity.HasVersion)
continue;

if (package.Identity.Version.CompareTo(result.Identity.Version) < 0)
continue;

if (assemblyVersion is null)
result = package;
else if (assemblyVersion.Major == package.Identity.Version.Major) result = package;
}

if (result is null)
return null;

logger.LogInformation("Dependency {DependencyName} resolved with version {DependencyVersion}", result.Identity.Id, result.Identity.Version);
return result.Identity;
}

private static async ValueTask<IEnumerable<IPackageSearchMetadata>> GetNuGetPackageVersionAsync(string packageId, CancellationToken cancellationToken)
{
var packageMetadataResource = await NuGetRepository.GetResourceAsync<PackageMetadataResource>(cancellationToken);
return await packageMetadataResource.GetMetadataAsync(packageId, true, false, NuGetCache, NullLogger.Instance, cancellationToken);
}
}
11 changes: 5 additions & 6 deletions src/Platform/Plugins/PluginService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace Void.Proxy.Plugins;

public class PluginService(ILogger<PluginService> logger, IEventService events, IServiceProvider services) : IPluginService
public class PluginService(ILogger<PluginService> logger, IEventService events, IServiceProvider services, IPluginDependencyService dependencies) : IPluginService
{
private readonly TimeSpan _gcRate = TimeSpan.FromMilliseconds(500);
private readonly List<IPlugin> _plugins = [];
Expand All @@ -32,20 +32,19 @@ public async ValueTask LoadAsync(string path = "plugins", CancellationToken canc

foreach (var pluginPath in pluginPaths)
{
var context = new PluginLoadContext(pluginPath);
var context = new PluginLoadContext(dependencies, pluginPath);

logger.LogInformation("Loading {PluginName} plugin", context.Name);

var assembly = context.LoadFromAssemblyName(new AssemblyName(Path.GetFileNameWithoutExtension(pluginPath)));
var plugins = RegisterPlugins(context.Name, assembly);
var plugins = RegisterPlugins(context.Name, context.PluginAssembly);

if (plugins.Length == 0)
{
logger.LogWarning("Plugin {PluginName} has no IPlugin implementations", context.Name);
continue;
}

var listeners = assembly.GetTypes().Where(typeof(IEventListener).IsAssignableFrom).Select(CreateListenerInstance).Cast<IEventListener?>().WhereNotNull().ToArray();
var listeners = context.PluginAssembly.GetTypes().Where(typeof(IEventListener).IsAssignableFrom).Select(CreateListenerInstance).Cast<IEventListener?>().WhereNotNull().ToArray();

if (listeners.Length == 0)
logger.LogWarning("Plugin {PluginName} has no event listeners", context.Name);
Expand Down Expand Up @@ -116,7 +115,7 @@ public IPlugin[] RegisterPlugins(string? name, Assembly assembly)
{
logger.LogError("Assembly {AssemblyName} cannot be loaded:", name);

var noStackTrace = exception.LoaderExceptions.WhereNotNull().Where(loaderException => string.IsNullOrWhiteSpace(loaderException?.StackTrace)).ToArray();
var noStackTrace = exception.LoaderExceptions.WhereNotNull().Where(loaderException => string.IsNullOrWhiteSpace(loaderException.StackTrace)).ToArray();

if (noStackTrace.Length == exception.LoaderExceptions.Length)
logger.LogError("{Exceptions}", string.Join(", ", noStackTrace.Select(loaderException => loaderException.Message)));
Expand Down
67 changes: 53 additions & 14 deletions src/Platform/Reflection/PluginLoadContext.cs
Original file line number Diff line number Diff line change
@@ -1,35 +1,74 @@
using System.Reflection;
using System.Runtime.Loader;
using Void.Proxy.API.Plugins;

namespace Void.Proxy.Reflection;

public class PluginLoadContext(string pluginPath) : AssemblyLoadContext(Path.GetFileName(pluginPath), true)
public class PluginLoadContext : AssemblyLoadContext
{
private static readonly string[] SharedDependencies = [nameof(Void), nameof(Microsoft), nameof(System)];
private readonly AssemblyDependencyResolver _resolver = new(pluginPath);
private static readonly string[] VersionedDependencies = [nameof(Void)];
private static readonly string[] SharedDependencies = [nameof(Microsoft)];
private static readonly string[] SystemDependencies = [nameof(System), "netstandard"];

protected override Assembly? Load(AssemblyName assemblyName)
private readonly IPluginDependencyService _dependencies;
private readonly AssemblyDependencyResolver _localResolver;

public PluginLoadContext(IPluginDependencyService dependencies, string pluginPath) : base(Path.GetFileName(pluginPath), true)
{
_dependencies = dependencies;
_localResolver = new AssemblyDependencyResolver(pluginPath);
PluginAssembly = LoadFromAssemblyPath(pluginPath);
}

public Assembly PluginAssembly { get; }

protected override Assembly Load(AssemblyName assemblyName)
{
if (SharedDependencies.Any(prefix => !string.IsNullOrWhiteSpace(assemblyName.Name) && assemblyName.Name.StartsWith(prefix)))
if (VersionedDependencies.Any(assemblyName.FullName.StartsWith))
{
var sharedAssembly = Default.Assemblies.FirstOrDefault(loadedAssembly => loadedAssembly.FullName == assemblyName.FullName);
var loadedAssembly = Default.Assemblies.FirstOrDefault(loadedAssembly => loadedAssembly.FullName == assemblyName.FullName);

if (loadedAssembly is not null)
return loadedAssembly;

if (sharedAssembly is not null)
return sharedAssembly;
// version mismatch here
}

var assemblyPath = _resolver.ResolveAssemblyToPath(assemblyName);
if (SharedDependencies.Any(assemblyName.FullName.StartsWith) || SystemDependencies.Any(assemblyName.FullName.StartsWith))
{
var loadedAssembly = Default.Assemblies.FirstOrDefault(loadedAssembly => loadedAssembly.GetName().Name == assemblyName.Name);

if (loadedAssembly is not null)
return loadedAssembly;
}

if (SystemDependencies.Any(assemblyName.FullName.StartsWith))
{
var loadedAssembly = Default.Assemblies.FirstOrDefault(loadedAssembly => loadedAssembly.GetName().Name == assemblyName.Name);

// if System dependency still not loaded, load it manually
return loadedAssembly ?? Default.LoadFromAssemblyName(assemblyName);
}

// fallback to local folder and NuGet
var assembly = _localResolver.ResolveAssemblyToPath(assemblyName) switch
{
{ } assemblyPath => LoadFromAssemblyPath(assemblyPath),
_ when _dependencies.ResolveAssemblyPath(assemblyName) is { } assemblyPath => LoadFromAssemblyPath(assemblyPath),
_ => null
};

if (assemblyPath is not null)
return LoadFromAssemblyPath(assemblyPath);
// sorry, but where am I supposed to find your dependency?
// throw is mandatory to prevent search in Default context
if (assembly is null)
throw new FileNotFoundException("Unable to resolve requested dependency");

// TODO: implement NuGet resolver here
return null;
return assembly;
}

protected override IntPtr LoadUnmanagedDll(string unmanagedDllName)
{
var libraryPath = _resolver.ResolveUnmanagedDllToPath(unmanagedDllName);
var libraryPath = _localResolver.ResolveUnmanagedDllToPath(unmanagedDllName);
return libraryPath != null ? LoadUnmanagedDllFromPath(libraryPath) : IntPtr.Zero;
}
}
1 change: 1 addition & 0 deletions src/Platform/Void.Proxy.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
<ItemGroup>
<PackageReference Include="ini-parser-netstandard" Version="2.5.2" />
<PackageReference Include="NET.Minecraft.Component" Version="1.0.5" />
<PackageReference Include="NuGet.PackageManagement" Version="6.11.0" />
<PackageReference Include="Serilog.Extensions.Hosting" Version="8.0.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="6.0.0" />
</ItemGroup>
Expand Down
Loading

0 comments on commit 0c63d15

Please sign in to comment.