Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use null propagation to optimize away IS NOT NULL checks #34127

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,100 @@ protected virtual SqlExpression VisitCase(CaseExpression caseExpression, bool al
return elseResult ?? _sqlExpressionFactory.Constant(null, caseExpression.Type, caseExpression.TypeMapping);
}

if (IsNull(elseResult))
{
elseResult = null;
}

// optimize expressions such as expr != null ? expr : null and expr == null ? null : expr
if (testIsCondition && whenClauses is [var clause] && (elseResult is null || IsNull(clause.Result)))
{
HashSet<SqlExpression> nullPropagatedOperands = [];

var (test, expr) = elseResult is null
? (clause.Test, clause.Result)
: (_sqlExpressionFactory.Not(clause.Test), elseResult);

DetectNullPropagatingNodes(expr, nullPropagatedOperands);
test = DropNotNullChecks(test, nullPropagatedOperands);

if (IsTrue(test))
{
return expr;
}

if (elseResult != null)
{
test = _sqlExpressionFactory.Not(test);
}

whenClauses = [new(test, clause.Result)];
}

return _sqlExpressionFactory.Case(operand, whenClauses, elseResult, caseExpression);

SqlExpression DropNotNullChecks(SqlExpression expression, HashSet<SqlExpression> nullPropagatedOperands)
=> expression switch
{
SqlUnaryExpression { OperatorType: ExpressionType.NotEqual } isNotNull
when nullPropagatedOperands.Contains(isNotNull.Operand)
=> _sqlExpressionFactory.Constant(true, expression.Type, expression.TypeMapping),

SqlBinaryExpression { OperatorType: ExpressionType.AndAlso } binary
=> _sqlExpressionFactory.MakeBinary(
ExpressionType.AndAlso,
DropNotNullChecks(binary.Left, nullPropagatedOperands),
DropNotNullChecks(binary.Right, nullPropagatedOperands),
expression.TypeMapping,
expression)!,

_ => expression,
};

// FIXME: unify nullability computations
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this definitely feels like something more general than as a very local utility to case, plus there should be extensibility for provider-specific expression types. We could even compute this information and bubble it up as part of the normal visitation of SqlNullabilityProcessor - assuming it's useful enough for other parts of the visitor.

But of course this can all be done later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I was considering doing in #33889 (comment) (instead of simple boolean nullability, propagate the "nullability expression") 😉
I will most likely do that in a later experiment, at the very least to evaluate whether the nullability expressions are simple to create/consume.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, makes sense!

static void DetectNullPropagatingNodes(SqlExpression expression, HashSet<SqlExpression> operands)
{
operands.Add(expression);

switch (expression)
{
case AtTimeZoneExpression atTimeZone:
DetectNullPropagatingNodes(atTimeZone.Operand, operands);
DetectNullPropagatingNodes(atTimeZone.TimeZone, operands);
break;

case CollateExpression collate:
DetectNullPropagatingNodes(collate.Operand, operands);
break;

case SqlUnaryExpression { OperatorType: not (ExpressionType.Equal or ExpressionType.NotEqual) } unary:
DetectNullPropagatingNodes(unary.Operand, operands);
break;

case SqlBinaryExpression { OperatorType: not (ExpressionType.AndAlso or ExpressionType.OrElse) } binary:
DetectNullPropagatingNodes(binary.Left, operands);
DetectNullPropagatingNodes(binary.Right, operands);
break;

case SqlFunctionExpression { IsNullable: true } func:
if (func.InstancePropagatesNullability is true)
{
DetectNullPropagatingNodes(func.Instance!, operands);
}

if (!func.IsNiladic)
{
for (var i = 0; i < func.ArgumentsPropagateNullability.Count; i++)
{
if (func.ArgumentsPropagateNullability[i])
{
DetectNullPropagatingNodes(func.Arguments[i], operands);
}
}
}
break;
}
}
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,52 @@ public virtual Task Is_null_on_column_followed_by_OrElse_optimizes_nullability_c
? x.NullableBoolA != x.NullableBoolC
: x.NullableBoolC != x.NullableBoolA)));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_not_null_optimizes_unary_op(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Select(
x => x.NullableIntA != null ? ~x.NullableIntA : null));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_not_null_optimizes_binary_op(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Select(
x => x.NullableIntA != null && x.NullableIntB != null
? x.NullableIntA + x.NullableIntB
: null));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_not_null_optimizes_binary_op_with_partial_checks(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Select(
x => x.NullableStringA != null && x.NullableStringB != null
? x.NullableStringA + x.NullableStringB + x.NullableStringC
: null));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_not_null_optimizes_binary_op_with_nested_checks(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Select(
x => x.NullableStringA != null
? x.NullableStringB != null ? x.NullableStringA + x.NullableStringB : null
: null));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Is_not_null_optimizes_binary_op_with_mixed_checks(bool async)
=> AssertQuery(
async,
ss => ss.Set<NullSemanticsEntity1>().Select(
x => x.NullableStringA != null && x.BoolA ? x.NullableStringA + x.NullableStringB : null));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Sum_function_is_always_considered_non_nullable(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ public override async Task Normal_entity_owning_a_split_reference_with_main_frag
AssertSql(
"""
SELECT [e].[Id], CASE
WHEN [e].[OwnedReference_Id] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue1] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue2] IS NOT NULL AND [o0].[OwnedIntValue3] IS NOT NULL AND [o].[OwnedIntValue4] IS NOT NULL THEN [o].[OwnedIntValue4]
WHEN [e].[OwnedReference_Id] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue1] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue2] IS NOT NULL AND [o0].[OwnedIntValue3] IS NOT NULL THEN [o].[OwnedIntValue4]
END AS [OwnedIntValue4], CASE
WHEN [e].[OwnedReference_Id] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue1] IS NOT NULL AND [e].[OwnedReference_OwnedIntValue2] IS NOT NULL AND [o0].[OwnedIntValue3] IS NOT NULL AND [o].[OwnedIntValue4] IS NOT NULL THEN [o].[OwnedStringValue4]
END AS [OwnedStringValue4]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -632,10 +632,7 @@ public override async Task Null_propagation_optimization4(bool async)
"""
SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOfBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank]
FROM [Gears] AS [g]
WHERE CASE
WHEN [g].[LeaderNickname] IS NULL THEN NULL
ELSE CAST(LEN([g].[LeaderNickname]) AS int)
END = 5
WHERE CAST(LEN([g].[LeaderNickname]) AS int) = 5
""");
}

Expand All @@ -648,9 +645,7 @@ public override async Task Null_propagation_optimization5(bool async)
"""
SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOfBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank]
FROM [Gears] AS [g]
WHERE CASE
WHEN [g].[LeaderNickname] IS NOT NULL THEN CAST(LEN([g].[LeaderNickname]) AS int)
END = 5
WHERE CAST(LEN([g].[LeaderNickname]) AS int) = 5
""");
}

Expand All @@ -663,9 +658,7 @@ public override async Task Null_propagation_optimization6(bool async)
"""
SELECT [g].[Nickname], [g].[SquadId], [g].[AssignedCityName], [g].[CityOfBirthName], [g].[Discriminator], [g].[FullName], [g].[HasSoulPatch], [g].[LeaderNickname], [g].[LeaderSquadId], [g].[Rank]
FROM [Gears] AS [g]
WHERE CASE
WHEN [g].[LeaderNickname] IS NOT NULL THEN CAST(LEN([g].[LeaderNickname]) AS int)
END = 5
WHERE CAST(LEN([g].[LeaderNickname]) AS int) = 5
""");
}

Expand All @@ -676,9 +669,7 @@ public override async Task Select_null_propagation_optimization7(bool async)
// issue #16050
AssertSql(
"""
SELECT CASE
WHEN [g].[LeaderNickname] IS NOT NULL THEN [g].[LeaderNickname] + [g].[LeaderNickname]
END
SELECT [g].[LeaderNickname] + [g].[LeaderNickname]
FROM [Gears] AS [g]
""");
}
Expand Down Expand Up @@ -855,9 +846,7 @@ public override async Task Select_null_propagation_works_for_multiple_navigation

AssertSql(
"""
SELECT CASE
WHEN [c].[Name] IS NOT NULL THEN [c].[Name]
END
SELECT [c].[Name]
FROM [Tags] AS [t]
LEFT JOIN [Gears] AS [g] ON [t].[GearNickName] = [g].[Nickname] AND [t].[GearSquadId] = [g].[SquadId]
LEFT JOIN [Tags] AS [t0] ON ([g].[Nickname] = [t0].[GearNickName] OR ([g].[Nickname] IS NULL AND [t0].[GearNickName] IS NULL)) AND ([g].[SquadId] = [t0].[GearSquadId] OR ([g].[SquadId] IS NULL AND [t0].[GearSquadId] IS NULL))
Expand Down Expand Up @@ -3057,9 +3046,7 @@ public override async Task Select_null_conditional_with_inheritance(bool async)

AssertSql(
"""
SELECT CASE
WHEN [f].[CommanderName] IS NOT NULL THEN [f].[CommanderName]
END
SELECT [f].[CommanderName]
FROM [Factions] AS [f]
""");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2499,9 +2499,7 @@ public override async Task GroupBy_group_Where_Select_Distinct_aggregate(bool as

AssertSql(
"""
SELECT [o].[CustomerID] AS [Key], MAX(CASE
WHEN [o].[OrderDate] IS NOT NULL THEN [o].[OrderDate]
END) AS [Max]
SELECT [o].[CustomerID] AS [Key], MAX([o].[OrderDate]) AS [Max]
FROM [Orders] AS [o]
GROUP BY [o].[CustomerID]
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4417,6 +4417,63 @@ ELSE CAST(0 AS bit)
""");
}

public override async Task Is_not_null_optimizes_unary_op(bool async)
{
await base.Is_not_null_optimizes_unary_op(async);

AssertSql(
"""
SELECT ~[e].[NullableIntA]
FROM [Entities1] AS [e]
""");
}

public override async Task Is_not_null_optimizes_binary_op(bool async)
{
await base.Is_not_null_optimizes_binary_op(async);

AssertSql(
"""
SELECT [e].[NullableIntA] + [e].[NullableIntB]
FROM [Entities1] AS [e]
""");
}

public override async Task Is_not_null_optimizes_binary_op_with_partial_checks(bool async)
{
await base.Is_not_null_optimizes_binary_op_with_partial_checks(async);

AssertSql(
"""
SELECT [e].[NullableStringA] + [e].[NullableStringB] + COALESCE([e].[NullableStringC], N'')
FROM [Entities1] AS [e]
""");
}

public override async Task Is_not_null_optimizes_binary_op_with_nested_checks(bool async)
{
await base.Is_not_null_optimizes_binary_op_with_nested_checks(async);

AssertSql(
"""
SELECT [e].[NullableStringA] + [e].[NullableStringB]
FROM [Entities1] AS [e]
""");
}

public override async Task Is_not_null_optimizes_binary_op_with_mixed_checks(bool async)
{
await base.Is_not_null_optimizes_binary_op_with_mixed_checks(async);

AssertSql(
"""
SELECT CASE
WHEN [e].[BoolA] = CAST(1 AS bit) THEN [e].[NullableStringA] + COALESCE([e].[NullableStringB], N'')
END
FROM [Entities1] AS [e]
""");
}

public override async Task Sum_function_is_always_considered_non_nullable(bool async)
{
await base.Sum_function_is_always_considered_non_nullable(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,7 @@ public override async Task AsBinary_with_null_check(bool async)

AssertSql(
"""
SELECT [p].[Id], CASE
WHEN [p].[Point] IS NULL THEN NULL
ELSE [p].[Point].STAsBinary()
END AS [Binary]
SELECT [p].[Id], [p].[Point].STAsBinary() AS [Binary]
FROM [PointEntity] AS [p]
""");
}
Expand Down Expand Up @@ -285,10 +282,7 @@ public override async Task Disjoint_with_null_check(bool async)
"""
@point='0xE6100000010C000000000000F03F000000000000F03F' (Size = 22) (DbType = Object)

SELECT [p].[Id], CASE
WHEN [p].[Polygon] IS NULL THEN NULL
ELSE [p].[Polygon].STDisjoint(@point)
END AS [Disjoint]
SELECT [p].[Id], [p].[Polygon].STDisjoint(@point) AS [Disjoint]
FROM [PolygonEntity] AS [p]
""");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ public override async Task AsBinary_with_null_check(bool async)

AssertSql(
"""
SELECT [p].[Id], CASE
WHEN [p].[Point] IS NULL THEN NULL
ELSE [p].[Point].STAsBinary()
END AS [Binary]
SELECT [p].[Id], [p].[Point].STAsBinary() AS [Binary]
FROM [PointEntity] AS [p]
""");
}
Expand Down Expand Up @@ -404,10 +401,7 @@ public override async Task Disjoint_with_null_check(bool async)
"""
@point='0x00000000010C000000000000F03F000000000000F03F' (Size = 22) (DbType = Object)

SELECT [p].[Id], CASE
WHEN [p].[Polygon] IS NULL THEN NULL
ELSE [p].[Polygon].STDisjoint(@point)
END AS [Disjoint]
SELECT [p].[Id], [p].[Polygon].STDisjoint(@point) AS [Disjoint]
FROM [PolygonEntity] AS [p]
""");
}
Expand Down
Loading