diff --git a/src/FastExpressionCompiler/FastExpressionCompiler.cs b/src/FastExpressionCompiler/FastExpressionCompiler.cs index 680a55fe..de10819b 100644 --- a/src/FastExpressionCompiler/FastExpressionCompiler.cs +++ b/src/FastExpressionCompiler/FastExpressionCompiler.cs @@ -503,6 +503,9 @@ internal static object TryCompileBoundToFirstClosureParam(Type delegateType, Exp EmittingVisitor.EmitLoadConstantsAndNestedLambdasIntoVars(il, ref closureInfo); var parent = returnType == typeof(void) ? ParentFlags.IgnoreResult : ParentFlags.LambdaCall; + if (returnType.IsByRef) + parent |= ParentFlags.ReturnByRef; + if (!EmittingVisitor.TryEmit(bodyExpr, paramExprs, il, ref closureInfo, flags, parent)) return null; il.Demit(OpCodes.Ret); @@ -1691,6 +1694,9 @@ private static bool TryCompileNestedLambda(ref ClosureInfo nestedClosureInfo, Ne EmittingVisitor.EmitLoadConstantsAndNestedLambdasIntoVars(il, ref nestedClosureInfo); var parent = nestedReturnType == typeof(void) ? ParentFlags.IgnoreResult : ParentFlags.LambdaCall; + if (nestedReturnType.IsByRef) + parent |= ParentFlags.ReturnByRef; + if (!EmittingVisitor.TryEmit(nestedLambdaBody, nestedLambdaParamExprs, il, ref nestedClosureInfo, setup, parent)) return false; il.Demit(OpCodes.Ret); @@ -1859,8 +1865,10 @@ public enum ParentFlags AssignmentByRef = 1 << 14, /// Indicates the root lambda call LambdaCall = 1 << 15, + /// ReturnByRef + ReturnByRef = 1 << 16, /// The block result - BlockResult = 1 << 16, + BlockResult = 1 << 17, } [MethodImpl((MethodImplOptions)256)] @@ -2261,7 +2269,7 @@ private static bool TryEmitNew(Expression expr, IReadOnlyList paramExprs, IL else if (newExpr.Type.IsValueType) { ctor = newExpr.Type.GetConstructor( - BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, + BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, default, CallingConventions.Any, Tools.Empty(), default); if (ctor != null) il.Demit(OpCodes.Newobj, ctor); @@ -2670,10 +2678,10 @@ public static bool TryEmitParameter(ParameterExpression paramExpr, // `a[i].Foo()` -> false, see #281 // `a[i].Bar` -> false, see #265 // `a[i]` -> true, see #413 - (parent & ParentFlags.IndexAccess) == 0 | + (parent & ParentFlags.IndexAccess) == 0 | (parent & (ParentFlags.Call | ParentFlags.MemberAccess)) == 0 ); - + closure.LastEmitIsAddress = !isParamOrVarByRef & (isPassedRef | valueTypeMemberButNotIndexAccess); if (closure.LastEmitIsAddress) @@ -2689,7 +2697,7 @@ public static bool TryEmitParameter(ParameterExpression paramExpr, EmitLoadIndirectlyByRef(il, paramType); return true; } - + var byAddress = isParamOrVarByRef & isPassedRef & isValueType; if (byAddress) EmitStoreAndLoadLocalVariableAddress(il, varIndex); @@ -2739,7 +2747,7 @@ public static bool TryEmitParameter(ParameterExpression paramExpr, // `a[i].Foo()` -> false, see #281 // `a[i].Bar` -> false, see #265 // `a[i]` -> true, see #413 - (parent & ParentFlags.IndexAccess) == 0 | + (parent & ParentFlags.IndexAccess) == 0 | (parent & (ParentFlags.Call | ParentFlags.MemberAccess)) == 0 ); @@ -2757,15 +2765,18 @@ public static bool TryEmitParameter(ParameterExpression paramExpr, { // #248 - skip the cases with `ref param.Field` were we are actually want to load the `Field` address not the `param` // this means the parameter is the argument to the method call and not the instance in the method call or member access - if (!isPassedRef & (parent & (ParentFlags.Call | ParentFlags.LambdaCall)) != 0 & - (parent & ParentFlags.InstanceAccess) == 0 || + if (!isPassedRef & ( + ((parent & ParentFlags.Call) != 0 & (parent & ParentFlags.InstanceAccess) == 0) | + ((parent & ParentFlags.LambdaCall) != 0 & (parent & ParentFlags.ReturnByRef) == 0) + ) || (parent & (ParentFlags.Arithmetic | ParentFlags.AssignmentRightValue)) != 0 & (parent & (ParentFlags.MemberAccess | ParentFlags.InstanceAccess | ParentFlags.AssignmentLeftValue)) == 0) EmitLoadIndirectlyByRef(il, paramType); } else { - if (!isPassedRef & ((parent & ParentFlags.Call) != 0) || + if (!isPassedRef & ( + (parent & ParentFlags.Call) != 0) || (parent & (ParentFlags.Coalesce | ParentFlags.MemberAccess | ParentFlags.IndexAccess | ParentFlags.AssignmentRightValue)) != 0) il.Demit(OpCodes.Ldind_Ref); } @@ -3709,7 +3720,10 @@ private static bool TryEmitListInit(ListInitExpression expr, IReadOnlyList p var ok = true; // see the TryEmitMethodCall for the reason of the callFlags - var callFlags = parent & ~ParentFlags.IgnoreResult & ~ParentFlags.MemberAccess & ~ParentFlags.InstanceAccess | ParentFlags.Call; + var callFlags = (parent + & ~(ParentFlags.IgnoreResult | ParentFlags.MemberAccess | ParentFlags.InstanceAccess | + ParentFlags.LambdaCall | ParentFlags.ReturnByRef)) + | ParentFlags.Call; for (var i = 0; i < initCount; ++i) { if (valueVarIndex != -1) // load local value address, to set its members @@ -3849,9 +3863,12 @@ private static bool TryEmitArithmeticAndOrAssign( // Remove the InstanceCall because we need to operate on the (nullable) field value and not on `ref` to return the value. // We may avoid it in case of not returning the value or PreIncrement/PreDecrement, but let's do less checks and branching. - var baseFlags = parent & ~ParentFlags.IgnoreResult & ~ParentFlags.InstanceCall; + var baseFlags = parent & + ~(ParentFlags.IgnoreResult | ParentFlags.InstanceCall | + ParentFlags.LambdaCall | ParentFlags.ReturnByRef); var rightOnlyFlags = baseFlags | ParentFlags.AssignmentRightValue; + var memberOrIndexFlags = leftMemberExpr != null ? ParentFlags.MemberAccess : ParentFlags.IndexAccess; var leftArLeastFlags = baseFlags | ParentFlags.AssignmentLeftValue | memberOrIndexFlags; @@ -4622,8 +4639,9 @@ public static bool TryEmitMemberGet(MemberExpression expr, if (objExpr != null) { var p = (parent | ParentFlags.InstanceCall) - & ~ParentFlags.MemberAccess // removing ParentFlags.MemberAccess here because we are calling the method instead of accessing the field - & ~ParentFlags.IgnoreResult & ~ParentFlags.DupIt; + // removing ParentFlags.MemberAccess here because we are calling the method instead of accessing the field + & ~(ParentFlags.IgnoreResult | ParentFlags.MemberAccess | ParentFlags.DupIt | + ParentFlags.LambdaCall | ParentFlags.ReturnByRef); if (!TryEmit(objExpr, paramExprs, il, ref closure, setup, p)) return false; @@ -5296,7 +5314,11 @@ private static bool TryEmitArithmetic(Expression left, Expression right, Express #endif ILGenerator il, ref ClosureInfo closure, CompilerFlags setup, ParentFlags parent) { - var flags = (parent & ~ParentFlags.IgnoreResult & ~ParentFlags.InstanceCall) | ParentFlags.Arithmetic; + var flags = (parent + & ~(ParentFlags.IgnoreResult | ParentFlags.InstanceCall | + ParentFlags.LambdaCall | ParentFlags.ReturnByRef)) + | ParentFlags.Arithmetic; + var leftNoValueLabel = default(Label); var leftType = left.Type; var leftIsNullable = leftType.IsNullable(); @@ -7339,7 +7361,7 @@ internal static StringBuilder ToCSharpString(this Expression e, StringBuilder sb var methodReturnType = mc.Method.ReturnType; if (methodReturnType.IsByRef) sb.Append("ref "); - + // output convert only if it is required, e.g. it may happen for custom expressions designed by users var diffTypes = mc.Type != methodReturnType; if (diffTypes) sb.Append("((").Append(mc.Type.ToCode(stripNamespace, printType)).Append(')'); @@ -7491,7 +7513,7 @@ internal static StringBuilder ToCSharpString(this Expression e, StringBuilder sb { var newLineIdent = lineIdent + identSpaces; body.ToCSharpString(sb.NewLineIdent(newLineIdent), - EnclosedIn.LambdaBody, newLineIdent, stripNamespace, printType, identSpaces, + EnclosedIn.LambdaBody, newLineIdent, stripNamespace, printType, identSpaces, notRecognizedToCode, lambdaMethod.ReturnType.IsByRef); } else diff --git a/test/FastExpressionCompiler.IssueTests/Issue414_Incorrect_il_when_passing_by_ref_value.cs b/test/FastExpressionCompiler.IssueTests/Issue414_Incorrect_il_when_passing_by_ref_value.cs index ae21d861..cb17064c 100644 --- a/test/FastExpressionCompiler.IssueTests/Issue414_Incorrect_il_when_passing_by_ref_value.cs +++ b/test/FastExpressionCompiler.IssueTests/Issue414_Incorrect_il_when_passing_by_ref_value.cs @@ -16,18 +16,18 @@ public int Run() { Issue413_ParameterStructIndexer(); Issue413_VariableStructIndexer(); - + Issue414_ReturnRefParameter(); Issue414_PassByRefParameter(); - + #if LIGHT_EXPRESSION Issue414_PassByRefVariable(); - // Issue415_ReturnRefParameterByRef(); - // Issue415_ReturnRefParameterByRef_ReturnRefCall(); - return 3; + Issue415_ReturnRefParameterByRef(); + Issue415_ReturnRefParameterByRef_ReturnRefCall(); + return 7; #else - return 2; + return 4; #endif } @@ -135,12 +135,12 @@ public void Issue413_ParameterStructIndexer() } delegate int MyDelegateNoArgs(); - + [Test] public void Issue413_VariableStructIndexer() { var p = Parameter(typeof(MyStruct)); - + var expr = Lambda( Block( new[] { p }, @@ -192,10 +192,10 @@ public void Issue415_ReturnRefParameterByRef() var ff = expr.CompileFast(true, CompilerFlags.ThrowOnNotSupportedExpression); ff.PrintIL(); - // ff.AssertOpCodes( - // OpCodes.Ldarg_1, - // OpCodes.Ret - // ); + ff.AssertOpCodes( + OpCodes.Ldarg_1, + OpCodes.Ret + ); var x = 17; ++ff(ref x);