diff --git a/src/FastExpressionCompiler/FastExpressionCompiler.cs b/src/FastExpressionCompiler/FastExpressionCompiler.cs
index 680a55fe..de10819b 100644
--- a/src/FastExpressionCompiler/FastExpressionCompiler.cs
+++ b/src/FastExpressionCompiler/FastExpressionCompiler.cs
@@ -503,6 +503,9 @@ internal static object TryCompileBoundToFirstClosureParam(Type delegateType, Exp
EmittingVisitor.EmitLoadConstantsAndNestedLambdasIntoVars(il, ref closureInfo);
var parent = returnType == typeof(void) ? ParentFlags.IgnoreResult : ParentFlags.LambdaCall;
+ if (returnType.IsByRef)
+ parent |= ParentFlags.ReturnByRef;
+
if (!EmittingVisitor.TryEmit(bodyExpr, paramExprs, il, ref closureInfo, flags, parent))
return null;
il.Demit(OpCodes.Ret);
@@ -1691,6 +1694,9 @@ private static bool TryCompileNestedLambda(ref ClosureInfo nestedClosureInfo, Ne
EmittingVisitor.EmitLoadConstantsAndNestedLambdasIntoVars(il, ref nestedClosureInfo);
var parent = nestedReturnType == typeof(void) ? ParentFlags.IgnoreResult : ParentFlags.LambdaCall;
+ if (nestedReturnType.IsByRef)
+ parent |= ParentFlags.ReturnByRef;
+
if (!EmittingVisitor.TryEmit(nestedLambdaBody, nestedLambdaParamExprs, il, ref nestedClosureInfo, setup, parent))
return false;
il.Demit(OpCodes.Ret);
@@ -1859,8 +1865,10 @@ public enum ParentFlags
AssignmentByRef = 1 << 14,
/// Indicates the root lambda call
LambdaCall = 1 << 15,
+ /// ReturnByRef
+ ReturnByRef = 1 << 16,
/// The block result
- BlockResult = 1 << 16,
+ BlockResult = 1 << 17,
}
[MethodImpl((MethodImplOptions)256)]
@@ -2261,7 +2269,7 @@ private static bool TryEmitNew(Expression expr, IReadOnlyList paramExprs, IL
else if (newExpr.Type.IsValueType)
{
ctor = newExpr.Type.GetConstructor(
- BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic,
+ BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic,
default, CallingConventions.Any, Tools.Empty(), default);
if (ctor != null)
il.Demit(OpCodes.Newobj, ctor);
@@ -2670,10 +2678,10 @@ public static bool TryEmitParameter(ParameterExpression paramExpr,
// `a[i].Foo()` -> false, see #281
// `a[i].Bar` -> false, see #265
// `a[i]` -> true, see #413
- (parent & ParentFlags.IndexAccess) == 0 |
+ (parent & ParentFlags.IndexAccess) == 0 |
(parent & (ParentFlags.Call | ParentFlags.MemberAccess)) == 0
);
-
+
closure.LastEmitIsAddress = !isParamOrVarByRef & (isPassedRef | valueTypeMemberButNotIndexAccess);
if (closure.LastEmitIsAddress)
@@ -2689,7 +2697,7 @@ public static bool TryEmitParameter(ParameterExpression paramExpr,
EmitLoadIndirectlyByRef(il, paramType);
return true;
}
-
+
var byAddress = isParamOrVarByRef & isPassedRef & isValueType;
if (byAddress)
EmitStoreAndLoadLocalVariableAddress(il, varIndex);
@@ -2739,7 +2747,7 @@ public static bool TryEmitParameter(ParameterExpression paramExpr,
// `a[i].Foo()` -> false, see #281
// `a[i].Bar` -> false, see #265
// `a[i]` -> true, see #413
- (parent & ParentFlags.IndexAccess) == 0 |
+ (parent & ParentFlags.IndexAccess) == 0 |
(parent & (ParentFlags.Call | ParentFlags.MemberAccess)) == 0
);
@@ -2757,15 +2765,18 @@ public static bool TryEmitParameter(ParameterExpression paramExpr,
{
// #248 - skip the cases with `ref param.Field` were we are actually want to load the `Field` address not the `param`
// this means the parameter is the argument to the method call and not the instance in the method call or member access
- if (!isPassedRef & (parent & (ParentFlags.Call | ParentFlags.LambdaCall)) != 0 &
- (parent & ParentFlags.InstanceAccess) == 0 ||
+ if (!isPassedRef & (
+ ((parent & ParentFlags.Call) != 0 & (parent & ParentFlags.InstanceAccess) == 0) |
+ ((parent & ParentFlags.LambdaCall) != 0 & (parent & ParentFlags.ReturnByRef) == 0)
+ ) ||
(parent & (ParentFlags.Arithmetic | ParentFlags.AssignmentRightValue)) != 0 &
(parent & (ParentFlags.MemberAccess | ParentFlags.InstanceAccess | ParentFlags.AssignmentLeftValue)) == 0)
EmitLoadIndirectlyByRef(il, paramType);
}
else
{
- if (!isPassedRef & ((parent & ParentFlags.Call) != 0) ||
+ if (!isPassedRef & (
+ (parent & ParentFlags.Call) != 0) ||
(parent & (ParentFlags.Coalesce | ParentFlags.MemberAccess | ParentFlags.IndexAccess | ParentFlags.AssignmentRightValue)) != 0)
il.Demit(OpCodes.Ldind_Ref);
}
@@ -3709,7 +3720,10 @@ private static bool TryEmitListInit(ListInitExpression expr, IReadOnlyList p
var ok = true;
// see the TryEmitMethodCall for the reason of the callFlags
- var callFlags = parent & ~ParentFlags.IgnoreResult & ~ParentFlags.MemberAccess & ~ParentFlags.InstanceAccess | ParentFlags.Call;
+ var callFlags = (parent
+ & ~(ParentFlags.IgnoreResult | ParentFlags.MemberAccess | ParentFlags.InstanceAccess |
+ ParentFlags.LambdaCall | ParentFlags.ReturnByRef))
+ | ParentFlags.Call;
for (var i = 0; i < initCount; ++i)
{
if (valueVarIndex != -1) // load local value address, to set its members
@@ -3849,9 +3863,12 @@ private static bool TryEmitArithmeticAndOrAssign(
// Remove the InstanceCall because we need to operate on the (nullable) field value and not on `ref` to return the value.
// We may avoid it in case of not returning the value or PreIncrement/PreDecrement, but let's do less checks and branching.
- var baseFlags = parent & ~ParentFlags.IgnoreResult & ~ParentFlags.InstanceCall;
+ var baseFlags = parent &
+ ~(ParentFlags.IgnoreResult | ParentFlags.InstanceCall |
+ ParentFlags.LambdaCall | ParentFlags.ReturnByRef);
var rightOnlyFlags = baseFlags | ParentFlags.AssignmentRightValue;
+
var memberOrIndexFlags = leftMemberExpr != null ? ParentFlags.MemberAccess : ParentFlags.IndexAccess;
var leftArLeastFlags = baseFlags | ParentFlags.AssignmentLeftValue | memberOrIndexFlags;
@@ -4622,8 +4639,9 @@ public static bool TryEmitMemberGet(MemberExpression expr,
if (objExpr != null)
{
var p = (parent | ParentFlags.InstanceCall)
- & ~ParentFlags.MemberAccess // removing ParentFlags.MemberAccess here because we are calling the method instead of accessing the field
- & ~ParentFlags.IgnoreResult & ~ParentFlags.DupIt;
+ // removing ParentFlags.MemberAccess here because we are calling the method instead of accessing the field
+ & ~(ParentFlags.IgnoreResult | ParentFlags.MemberAccess | ParentFlags.DupIt |
+ ParentFlags.LambdaCall | ParentFlags.ReturnByRef);
if (!TryEmit(objExpr, paramExprs, il, ref closure, setup, p))
return false;
@@ -5296,7 +5314,11 @@ private static bool TryEmitArithmetic(Expression left, Expression right, Express
#endif
ILGenerator il, ref ClosureInfo closure, CompilerFlags setup, ParentFlags parent)
{
- var flags = (parent & ~ParentFlags.IgnoreResult & ~ParentFlags.InstanceCall) | ParentFlags.Arithmetic;
+ var flags = (parent
+ & ~(ParentFlags.IgnoreResult | ParentFlags.InstanceCall |
+ ParentFlags.LambdaCall | ParentFlags.ReturnByRef))
+ | ParentFlags.Arithmetic;
+
var leftNoValueLabel = default(Label);
var leftType = left.Type;
var leftIsNullable = leftType.IsNullable();
@@ -7339,7 +7361,7 @@ internal static StringBuilder ToCSharpString(this Expression e, StringBuilder sb
var methodReturnType = mc.Method.ReturnType;
if (methodReturnType.IsByRef)
sb.Append("ref ");
-
+
// output convert only if it is required, e.g. it may happen for custom expressions designed by users
var diffTypes = mc.Type != methodReturnType;
if (diffTypes) sb.Append("((").Append(mc.Type.ToCode(stripNamespace, printType)).Append(')');
@@ -7491,7 +7513,7 @@ internal static StringBuilder ToCSharpString(this Expression e, StringBuilder sb
{
var newLineIdent = lineIdent + identSpaces;
body.ToCSharpString(sb.NewLineIdent(newLineIdent),
- EnclosedIn.LambdaBody, newLineIdent, stripNamespace, printType, identSpaces,
+ EnclosedIn.LambdaBody, newLineIdent, stripNamespace, printType, identSpaces,
notRecognizedToCode, lambdaMethod.ReturnType.IsByRef);
}
else
diff --git a/test/FastExpressionCompiler.IssueTests/Issue414_Incorrect_il_when_passing_by_ref_value.cs b/test/FastExpressionCompiler.IssueTests/Issue414_Incorrect_il_when_passing_by_ref_value.cs
index ae21d861..cb17064c 100644
--- a/test/FastExpressionCompiler.IssueTests/Issue414_Incorrect_il_when_passing_by_ref_value.cs
+++ b/test/FastExpressionCompiler.IssueTests/Issue414_Incorrect_il_when_passing_by_ref_value.cs
@@ -16,18 +16,18 @@ public int Run()
{
Issue413_ParameterStructIndexer();
Issue413_VariableStructIndexer();
-
+
Issue414_ReturnRefParameter();
Issue414_PassByRefParameter();
-
+
#if LIGHT_EXPRESSION
Issue414_PassByRefVariable();
- // Issue415_ReturnRefParameterByRef();
- // Issue415_ReturnRefParameterByRef_ReturnRefCall();
- return 3;
+ Issue415_ReturnRefParameterByRef();
+ Issue415_ReturnRefParameterByRef_ReturnRefCall();
+ return 7;
#else
- return 2;
+ return 4;
#endif
}
@@ -135,12 +135,12 @@ public void Issue413_ParameterStructIndexer()
}
delegate int MyDelegateNoArgs();
-
+
[Test]
public void Issue413_VariableStructIndexer()
{
var p = Parameter(typeof(MyStruct));
-
+
var expr = Lambda(
Block(
new[] { p },
@@ -192,10 +192,10 @@ public void Issue415_ReturnRefParameterByRef()
var ff = expr.CompileFast(true, CompilerFlags.ThrowOnNotSupportedExpression);
ff.PrintIL();
- // ff.AssertOpCodes(
- // OpCodes.Ldarg_1,
- // OpCodes.Ret
- // );
+ ff.AssertOpCodes(
+ OpCodes.Ldarg_1,
+ OpCodes.Ret
+ );
var x = 17;
++ff(ref x);