Skip to content

Commit

Permalink
#12 WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
sguldmund committed Jan 17, 2024
1 parent 456721a commit e31f97b
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 81 deletions.
4 changes: 2 additions & 2 deletions src/Pose/Helpers/ShimHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ public static void ValidateReplacementMethodSignature(MethodBase original, Metho
.Skip(isStaticOrConstructor ? 0 : 1)
.ToArray();

if (validReturnType != shimReturnType)
throw new InvalidShimSignatureException($"Mismatched return types. Expected {validReturnType}. Got {shimReturnType}");
// if (validReturnType != shimReturnType)
// throw new InvalidShimSignatureException($"Mismatched return types. Expected {validReturnType}. Got {shimReturnType}");

if (!isStaticOrConstructor)
{
Expand Down
88 changes: 75 additions & 13 deletions src/Pose/PoseContext.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Pose.IL;

namespace System.Runtime.CompilerServices
namespace System.Runtime.CompilerServices1
{
// AsyncVoidMethodBuilder.cs in your project
public class AsyncTaskMethodBuilder
Expand All @@ -27,41 +30,98 @@ public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
{

}

public void SetStateMachine(IAsyncStateMachine stateMachine) {}

public void SetException(Exception exception) {}

public Task Task => null;

public AsyncTaskMethodBuilder()
=> Console.WriteLine(".ctor");

public static AsyncTaskMethodBuilder Create()
=> new AsyncTaskMethodBuilder();

public void SetResult() => Console.WriteLine("SetResult");

public void Start<TStateMachine>(ref TStateMachine stateMachine)
where TStateMachine : IAsyncStateMachine
{
Console.WriteLine("Start");
var methodInfos = stateMachine.GetType().GetMethods(BindingFlags.Instance | BindingFlags.NonPublic);
var methodRewriter = MethodRewriter.CreateRewriter(methodInfos[0], false);
var methodInfo = methodInfos[0];
var methodRewriter = MethodRewriter.CreateRewriter(methodInfo, false);
var methodBase = methodRewriter.Rewrite();
methodBase.Invoke(this, new object[] { stateMachine });
stateMachine.MoveNext();
}

// AwaitOnCompleted, AwaitUnsafeOnCompleted, SetException
// and SetStateMachine are empty
}
}

namespace Pose
{
/// <summary>
/// A helper class to run Async code from a synchronize methods
/// </summary>
/// <remarks>
/// Use this helper when your method isn't decorated with 'async', so you can't implement 'await' on the call to the async-method.
/// </remarks>
public static class AsyncHelper
{
private static readonly TaskFactory MyTaskFactory = new TaskFactory(CancellationToken.None, TaskCreationOptions.None, TaskContinuationOptions.None, TaskScheduler.Default);

/// <summary>
/// Call this method when you need the result back from the async-method you are calling.
/// </summary>
/// <example>
/// <code>
/// var result = AsyncHelper.RunASync&lt;bool&gt;(() => IsValueTrueAsync(true));
/// </code>
/// </example>
/// <typeparam name="TResult">The type of the result.</typeparam>
/// <param name="func">The function to run.</param>
/// <returns>The result from running <paramref name="func"/>.</returns>
/// <exception cref="ArgumentNullException">If <paramref name="func"/> is null.</exception>
public static TResult RunASync<TResult>(Func<Task<TResult>> func)
{
if (func == null) throw new ArgumentNullException(nameof(func));

return MyTaskFactory
.StartNew(func)
.Unwrap()
.GetAwaiter()
.GetResult();
}

/// <summary>
/// Call this method when you don't need any result back
/// </summary>
/// <example>
/// <code>
/// AsyncHelper.RunASync(() => Save(person));
/// </code>
/// </example>
/// <param name="func">The function to run.</param>
/// <exception cref="ArgumentNullException">If <paramref name="func"/> is null.</exception>
public static void RunASync(Func<Task> func)
{
if (func == null) throw new ArgumentNullException(nameof(func));

MyTaskFactory
.StartNew(func)
.Unwrap()
.GetAwaiter()
.GetResult();
}
}

public static class PoseContext
{
internal static Shim[] Shims { private set; get; }
public static Shim[] Shims { set; get; }
internal static Dictionary<MethodBase, DynamicMethod> StubCache { private set; get; }

public static void Isolate(Action entryPoint, params Shim[] shims)
Expand All @@ -72,7 +132,9 @@ public static void Isolate(Action entryPoint, params Shim[] shims)
return;
}

Shims = shims;
var enumerable = new Shim[]{Shim.Replace(() => System.Runtime.CompilerServices.AsyncTaskMethodBuilder.Create())
.With(() => System.Runtime.CompilerServices1.AsyncTaskMethodBuilder.Create())};
Shims = shims.Concat(enumerable).ToArray();
StubCache = new Dictionary<MethodBase, DynamicMethod>();

var delegateType = typeof(Action<>).MakeGenericType(entryPoint.Target.GetType());
Expand All @@ -84,11 +146,11 @@ public static void Isolate(Action entryPoint, params Shim[] shims)
methodInfo.CreateDelegate(delegateType).DynamicInvoke(entryPoint.Target);
}

public static async Task IsolateAsync(Func<Task> entryPoint, params Shim[] shims)
public static void IsolateAsync(Func<Task> entryPoint, params Shim[] shims)
{
if (shims == null || shims.Length == 0)
{
await entryPoint.Invoke();
AsyncHelper.RunASync(entryPoint.Invoke);
return;
}

Expand All @@ -102,7 +164,7 @@ public static async Task IsolateAsync(Func<Task> entryPoint, params Shim[] shims

Console.WriteLine("----------------------------- Invoking ----------------------------- ");
var @delegate = methodInfo.CreateDelegate(delegateType);
@delegate.DynamicInvoke();
AsyncHelper.RunASync(() => @delegate.DynamicInvoke() as Task);
}
}
}
140 changes: 74 additions & 66 deletions src/Sandbox/Program.cs
Original file line number Diff line number Diff line change
@@ -1,68 +1,70 @@
// See https://aka.ms/new-console-template for more information

using System;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using Pose.IL;
using System.Runtime.CompilerServices;

namespace Pose.Sandbox
{
namespace System.Runtime.CompilerServices
{
// AsyncVoidMethodBuilder.cs in your project
public class AsyncTaskMethodBuilder
{
public void AwaitOnCompleted<TAwaiter, TStateMachine>(
ref TAwaiter awaiter,
ref TStateMachine stateMachine
)
where TAwaiter : INotifyCompletion
where TStateMachine : IAsyncStateMachine
{

}

public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
ref TAwaiter awaiter, ref TStateMachine stateMachine)
where TAwaiter : ICriticalNotifyCompletion
where TStateMachine : IAsyncStateMachine
{

}

public void SetStateMachine(IAsyncStateMachine stateMachine) {}

public void SetException(Exception exception) {}

public Task Task => null;

public AsyncTaskMethodBuilder()
=> Console.WriteLine(".ctor");

public static AsyncTaskMethodBuilder Create()
=> new AsyncTaskMethodBuilder();

public void SetResult() => Console.WriteLine("SetResult");

public void Start<TStateMachine>(ref TStateMachine stateMachine)
where TStateMachine : IAsyncStateMachine
{
Console.WriteLine("Start");
var methodInfos = stateMachine.GetType().GetMethods(BindingFlags.Instance | BindingFlags.NonPublic);
var methodRewriter = MethodRewriter.CreateRewriter(methodInfos[0], false);
var methodBase = methodRewriter.Rewrite();
stateMachine.MoveNext();
}

// AwaitOnCompleted, AwaitUnsafeOnCompleted, SetException
// and SetStateMachine are empty
}
}
// namespace System.Runtime.CompilerServices
// {
// // AsyncVoidMethodBuilder.cs in your project
// public class AsyncTaskMethodBuilder
// {
// public void AwaitOnCompleted<TAwaiter, TStateMachine>(
// ref TAwaiter awaiter,
// ref TStateMachine stateMachine
// )
// where TAwaiter : INotifyCompletion
// where TStateMachine : IAsyncStateMachine
// {
//
// }
//
// public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
// ref TAwaiter awaiter, ref TStateMachine stateMachine)
// where TAwaiter : ICriticalNotifyCompletion
// where TStateMachine : IAsyncStateMachine
// {
//
// }
//
// public void SetStateMachine(IAsyncStateMachine stateMachine) {}
//
// public void SetException(Exception exception) {}
//
// public Task Task => null;
//
// public AsyncTaskMethodBuilder()
// => Console.WriteLine(".ctor");
//
// public static AsyncTaskMethodBuilder Create()
// => new AsyncTaskMethodBuilder();
//
// public void SetResult() => Console.WriteLine("SetResult");
//
// public void Start<TStateMachine>(ref TStateMachine stateMachine)
// where TStateMachine : IAsyncStateMachine
// {
// Console.WriteLine("Start");
// var methodInfos = stateMachine.GetType().GetMethods(BindingFlags.Instance | BindingFlags.NonPublic);
// var methodRewriter = MethodRewriter.CreateRewriter(methodInfos[0], false);
// var methodBase = methodRewriter.Rewrite();
// stateMachine.MoveNext();
// }
//
// // AwaitOnCompleted, AwaitUnsafeOnCompleted, SetException
// // and SetStateMachine are empty
// }
// }

public class Program
{
public static async Task<int> GetAsyncInt() => await Task.FromResult(1);
public static async Task<int> GetAsyncInt()
{
await Task.Delay(1000);
return await Task.FromResult(1);
}

public static async Task Lol()
{
Expand All @@ -72,18 +74,24 @@ public static async Task Lol()

public static void Main(string[] args)
{
Lol().GetAwaiter().GetResult();
//Lol().GetAwaiter().GetResult();

// var shim = Shim
// .Replace(() => Program.GetAsyncInt())
// .With(() => Task.FromResult(2));
//
// PoseContext.IsolateAsync(
// async () =>
// {
// var @int = await GetAsyncInt();
// Console.WriteLine(@int);
// }, shim).GetAwaiter().GetResult();
var shim = Shim
.Replace(() => Program.GetAsyncInt())
.With(() => Task.FromResult(2));

PoseContext.Shims = new Shim[] { shim };

// var shim1 = Shim
// .Replace(() => System.Runtime.CompilerServices.AsyncTaskMethodBuilder.Create())
// .With(() => System.Runtime.CompilerServices1.AsyncTaskMethodBuilder.Create());

PoseContext.IsolateAsync(
async () =>
{
var @int = await GetAsyncInt();
Console.WriteLine(@int);
}, shim);
/*
#if NET48
Console.WriteLine("4.8");
Expand Down

0 comments on commit e31f97b

Please sign in to comment.