diff --git a/src/Foundatio.Parsers.SqlQueries/Extensions/SqlNodeExtensions.cs b/src/Foundatio.Parsers.SqlQueries/Extensions/SqlNodeExtensions.cs index cf9e9ef..10224d0 100644 --- a/src/Foundatio.Parsers.SqlQueries/Extensions/SqlNodeExtensions.cs +++ b/src/Foundatio.Parsers.SqlQueries/Extensions/SqlNodeExtensions.cs @@ -5,6 +5,7 @@ using Foundatio.Parsers.LuceneQueries.Extensions; using Foundatio.Parsers.LuceneQueries.Nodes; using Foundatio.Parsers.SqlQueries.Visitors; +using Microsoft.Extensions.Primitives; namespace Foundatio.Parsers.SqlQueries.Extensions; @@ -70,13 +71,18 @@ public static string ToDynamicLinqString(this ExistsNode node, ISqlQueryVisitorC if (node.TryGetQuery(out string query)) return query; + var field = GetFieldInfo(context.Fields, node.Field); + var (fieldPrefix, fieldSuffix) = field.GetFieldPrefixAndSuffix(); + var builder = new StringBuilder(); - builder.Append(node.Field); + builder.Append(fieldPrefix); + builder.Append(field.Name); if (!node.IsNegated.HasValue || !node.IsNegated.Value) builder.Append(" != null"); else builder.Append(" == null"); + builder.Append(fieldSuffix); return builder.ToString(); } @@ -93,13 +99,17 @@ public static string ToDynamicLinqString(this MissingNode node, ISqlQueryVisitor if (node.TryGetQuery(out string query)) return query; - var builder = new StringBuilder(); + var field = GetFieldInfo(context.Fields, node.Field); + var (fieldPrefix, fieldSuffix) = field.GetFieldPrefixAndSuffix(); - builder.Append(node.Field); + var builder = new StringBuilder(); + builder.Append(fieldPrefix); + builder.Append(field.Name); if (!node.IsNegated.HasValue || !node.IsNegated.Value) builder.Append(" == null"); else builder.Append(" != null"); + builder.Append(fieldSuffix); return builder.ToString(); } @@ -143,80 +153,51 @@ public static string ToDynamicLinqString(this TermNode node, ISqlQueryVisitorCon fieldTerms.ForEach((kvp, x) => { + if (x.IsFirst && node.IsNegated.HasValue && node.IsNegated.Value) + builder.Append("!"); + builder.Append(x.IsFirst ? "(" : " OR "); + var searchTerm = kvp.Value; var tokens = kvp.Value.Tokens ?? [kvp.Value.Term]; + var (fieldPrefix, fieldSuffix) = kvp.Key.GetFieldPrefixAndSuffix(); - if (searchTerm.FieldInfo.IsCollection) + if (searchTerm.Operator == SqlSearchOperator.Equals) { - int dotIndex = searchTerm.FieldInfo.Field.LastIndexOf('.'); - string collectionField = searchTerm.FieldInfo.Field.Substring(0, dotIndex); - string fieldName = searchTerm.FieldInfo.Field.Substring(dotIndex + 1); - - if (searchTerm.Operator == SqlSearchOperator.Equals) - { - builder.Append(collectionField); - builder.Append(".Any("); - builder.Append(fieldName); - builder.Append(" in ("); - builder.Append(String.Join(',', tokens.Select(t => "\"" + t + "\""))); - builder.Append("))"); - } - else if (searchTerm.Operator == SqlSearchOperator.Contains) - { - tokens.ForEach((token, i) => { - builder.Append(i.IsFirst ? "(" : " OR "); - builder.Append(collectionField); - builder.Append(".Any("); - builder.Append(fieldName); - builder.Append(".Contains(\""); - builder.Append(token); - builder.Append("\"))"); - if (i.IsLast) - builder.Append(")"); - }); - } - else if (searchTerm.Operator == SqlSearchOperator.StartsWith) - { - tokens.ForEach((token, i) => { - builder.Append(i.IsFirst ? "(" : " OR "); - builder.Append(collectionField); - builder.Append(".Any("); - builder.Append(fieldName); - builder.Append(".StartsWith(\""); - builder.Append(token); - builder.Append("\"))"); - if (i.IsLast) - builder.Append(")"); - }); - } + builder.Append(fieldPrefix); + builder.Append(kvp.Key.Name); + builder.Append(" in ("); + builder.Append(String.Join(',', tokens.Select(t => "\"" + t + "\""))); + builder.Append(")"); + builder.Append(fieldSuffix); } - else + else if (searchTerm.Operator == SqlSearchOperator.Contains) { - if (searchTerm.Operator == SqlSearchOperator.Equals) - { - builder.Append(searchTerm.FieldInfo.Field).Append(" in ("); - builder.Append(String.Join(',', tokens.Select(t => "\"" + t + "\""))); - builder.Append(")"); - } - else if (searchTerm.Operator == SqlSearchOperator.Contains) - { - tokens.ForEach((token, i) => { - builder.Append(i.IsFirst ? "(" : " OR "); - builder.Append(searchTerm.FieldInfo.Field).Append(".Contains(\"").Append(token).Append("\")"); - if (i.IsLast) - builder.Append(")"); - }); - } - else if (searchTerm.Operator == SqlSearchOperator.StartsWith) - { - tokens.ForEach((token, i) => { - builder.Append(i.IsFirst ? "(" : " OR "); - builder.Append(searchTerm.FieldInfo.Field).Append(".StartsWith(\"").Append(token).Append("\")"); - if (i.IsLast) - builder.Append(")"); - }); - } + tokens.ForEach((token, i) => { + builder.Append(i.IsFirst ? "(" : " OR "); + builder.Append(fieldPrefix); + builder.Append(kvp.Key.Name); + builder.Append(".Contains(\""); + builder.Append(token); + builder.Append("\")"); + builder.Append(fieldSuffix); + if (i.IsLast) + builder.Append(")"); + }); + } + else if (searchTerm.Operator == SqlSearchOperator.StartsWith) + { + tokens.ForEach((token, i) => { + builder.Append(i.IsFirst ? "(" : " OR "); + builder.Append(fieldPrefix); + builder.Append(kvp.Key.Name); + builder.Append(".StartsWith(\""); + builder.Append(token); + builder.Append("\")"); + builder.Append(fieldSuffix); + if (i.IsLast) + builder.Append(")"); + }); } if (x.IsLast) @@ -227,38 +208,42 @@ public static string ToDynamicLinqString(this TermNode node, ISqlQueryVisitorCon } var field = GetFieldInfo(context.Fields, node.Field); + var (fieldPrefix, fieldSuffix) = field.GetFieldPrefixAndSuffix(); + var searchOperator = SqlSearchOperator.Equals; + if (node.Term.StartsWith("*") && node.Term.EndsWith("*")) + searchOperator = SqlSearchOperator.Contains; + else if (node.Term.EndsWith("*")) + searchOperator = SqlSearchOperator.StartsWith; if (node.IsNegated.HasValue && node.IsNegated.Value) builder.Append("!"); - if (field.IsCollection) + if (searchOperator == SqlSearchOperator.Equals) { - int index = node.Field.LastIndexOf('.'); - string collectionField = node.Field.Substring(0, index); - string fieldName = node.Field.Substring(index + 1); - - builder.Append(collectionField); - builder.Append(".Any("); - builder.Append(fieldName); - - if (node.IsNegated.HasValue && node.IsNegated.Value) - builder.Append(" != "); - else - builder.Append(" = "); - - AppendField(builder, field, node.Term); - - builder.Append(")"); + builder.Append(fieldPrefix); + builder.Append(field.Name); + builder.Append(" = \""); + builder.Append(node.Term); + builder.Append("\""); + builder.Append(fieldSuffix); + } + else if (searchOperator == SqlSearchOperator.Contains) + { + builder.Append(fieldPrefix); + builder.Append(field.Name); + builder.Append(".Contains(\""); + builder.Append(node.Term); + builder.Append("\")"); + builder.Append(fieldSuffix); } else { - builder.Append(node.Field); - if (node.IsNegated.HasValue && node.IsNegated.Value) - builder.Append(" != "); - else - builder.Append(" = "); - - AppendField(builder, field, node.Term); + builder.Append(fieldPrefix); + builder.Append(field.Name); + builder.Append(".StartsWith(\""); + builder.Append(node.Term); + builder.Append("\")"); + builder.Append(fieldSuffix); } return builder.ToString(); @@ -281,19 +266,23 @@ public static string ToDynamicLinqString(this TermRangeNode node, ISqlQueryVisit if (!field.IsNumber && !field.IsDate && !field.IsMoney) context.AddValidationError("Field must be a number, money or date for term range queries."); + var (fieldPrefix, fieldSuffix) = field.GetFieldPrefixAndSuffix(); + var builder = new StringBuilder(); if (node.IsNegated.HasValue && node.IsNegated.Value) - builder.Append("NOT "); + builder.Append("!"); if (node.Min != null && node.Max != null) builder.Append("("); if (node.Min != null) { - builder.Append(node.Field); + builder.Append(fieldPrefix); + builder.Append(field.Name); builder.Append(node.MinInclusive == true ? " >= " : " > "); AppendField(builder, field, node.Min); + builder.Append(fieldSuffix); } if (node.Min != null && node.Max != null) @@ -301,9 +290,11 @@ public static string ToDynamicLinqString(this TermRangeNode node, ISqlQueryVisit if (node.Max != null) { - builder.Append(node.Field); + builder.Append(fieldPrefix); + builder.Append(field.Name); builder.Append(node.MaxInclusive == true ? " <= " : " < "); AppendField(builder, field, node.Max); + builder.Append(fieldSuffix); } if (node.Min != null && node.Max != null) @@ -328,10 +319,10 @@ public static string ToDynamicLinqString(this IQueryNode node, ISqlQueryVisitorC public static EntityFieldInfo GetFieldInfo(List fields, string field) { if (fields == null) - return new EntityFieldInfo { Field = field }; + return new EntityFieldInfo { Name = field, FullName = field}; - return fields.FirstOrDefault(f => f.Field.Equals(field, StringComparison.OrdinalIgnoreCase)) ?? - new EntityFieldInfo { Field = field }; + return fields.FirstOrDefault(f => f.FullName.Equals(field, StringComparison.OrdinalIgnoreCase)) ?? + new EntityFieldInfo { Name = field, FullName = field}; } private static void AppendField(StringBuilder builder, EntityFieldInfo field, string term) diff --git a/src/Foundatio.Parsers.SqlQueries/SqlQueryParser.cs b/src/Foundatio.Parsers.SqlQueries/SqlQueryParser.cs index 65ed49d..93a76ef 100644 --- a/src/Foundatio.Parsers.SqlQueries/SqlQueryParser.cs +++ b/src/Foundatio.Parsers.SqlQueries/SqlQueryParser.cs @@ -82,7 +82,7 @@ public SqlQueryVisitorContext GetContext(IEntityType entityType) if (!_entityFieldCache.TryGetValue(entityType, out var fields)) { fields = new List(); - AddEntityFields(fields, entityType); + AddEntityFields(fields, null, entityType); _entityFieldCache.TryAdd(entityType, fields); } @@ -90,7 +90,7 @@ public SqlQueryVisitorContext GetContext(IEntityType entityType) fields = fields.ToList(); var validationOptions = new QueryValidationOptions(); - foreach (string field in fields.Select(f => f.Field)) + foreach (string field in fields.Where(f => !f.IsNavigation).Select(f => f.FullName)) validationOptions.AllowedFields.Add(field); Configuration.SetValidationOptions(validationOptions); @@ -101,7 +101,7 @@ public SqlQueryVisitorContext GetContext(IEntityType entityType) }; } - private void AddEntityFields(List fields, IEntityType entityType, Stack entityTypeStack = null, string prefix = null, bool isCollection = false, int depth = 0) + private void AddEntityFields(List fields, EntityFieldInfo parent, IEntityType entityType, Stack entityTypeStack = null, string prefix = null, int depth = 0) { entityTypeStack ??= new Stack(); @@ -123,11 +123,12 @@ private void AddEntityFields(List fields, IEntityType entityTyp string propertyPath = prefix + property.Name; fields.Add(new EntityFieldInfo { - Field = propertyPath, + Name = property.Name, + FullName = propertyPath, IsNumber = property.ClrType.UnwrapNullable().IsNumeric(), IsDate = property.ClrType.UnwrapNullable().IsDateTime(), IsBoolean = property.ClrType.UnwrapNullable().IsBoolean(), - IsCollection = isCollection + Parent = parent }); } @@ -139,7 +140,17 @@ private void AddEntityFields(List fields, IEntityType entityTyp string propertyPath = prefix + nav.Name; bool isNavCollection = nav is IReadOnlyNavigationBase { IsCollection: true }; - AddEntityFields(fields, nav.TargetEntityType, entityTypeStack, propertyPath + ".", isNavCollection, depth + 1); + var navFieldInfo = new EntityFieldInfo + { + IsCollection = isNavCollection, + IsNavigation = true, + Name = nav.Name, + FullName = propertyPath, + Parent = parent + }; + fields.Add(navFieldInfo); + + AddEntityFields(fields, navFieldInfo, nav.TargetEntityType, entityTypeStack, propertyPath + ".", depth + 1); } foreach (var skipNav in entityType.GetSkipNavigations()) @@ -149,7 +160,17 @@ private void AddEntityFields(List fields, IEntityType entityTyp string propertyPath = prefix + skipNav.Name; - AddEntityFields(fields, skipNav.TargetEntityType, entityTypeStack, propertyPath + ".", skipNav.IsCollection, depth + 1); + var navFieldInfo = new EntityFieldInfo + { + IsCollection = skipNav.IsCollection, + IsNavigation = true, + Name = skipNav.Name, + FullName = propertyPath, + Parent = parent + }; + fields.Add(navFieldInfo); + + AddEntityFields(fields, navFieldInfo, skipNav.TargetEntityType, entityTypeStack, propertyPath + ".", depth + 1); } entityTypeStack.Pop(); diff --git a/src/Foundatio.Parsers.SqlQueries/Visitors/SqlQueryVisitorContext.cs b/src/Foundatio.Parsers.SqlQueries/Visitors/SqlQueryVisitorContext.cs index f2a9be8..d8c8f88 100644 --- a/src/Foundatio.Parsers.SqlQueries/Visitors/SqlQueryVisitorContext.cs +++ b/src/Foundatio.Parsers.SqlQueries/Visitors/SqlQueryVisitorContext.cs @@ -1,8 +1,11 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Text; using Foundatio.Parsers.LuceneQueries.Visitors; +using Foundatio.Parsers.SqlQueries.Extensions; using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.Extensions.Primitives; namespace Foundatio.Parsers.SqlQueries.Visitors; @@ -13,18 +16,21 @@ public class SqlQueryVisitorContext : QueryVisitorContext, ISqlQueryVisitorConte public IEntityType EntityType { get; set; } } -[DebuggerDisplay("{Field} IsNumber: {IsNumber} IsMoney: {IsMoney} IsDate: {IsDate} IsBoolean: {IsBoolean} IsCollection: {IsCollection}")] +[DebuggerDisplay("{FullName} IsNumber: {IsNumber} IsMoney: {IsMoney} IsDate: {IsDate} IsBoolean: {IsBoolean} IsCollection: {IsCollection}")] public class EntityFieldInfo { - public string Field { get; init; } + public required string Name { get; init; } + public required string FullName { get; init; } public bool IsNumber { get; set; } public bool IsMoney { get; set; } public bool IsDate { get; set; } public bool IsBoolean { get; set; } public bool IsCollection { get; set; } + public bool IsNavigation { get; set; } + public EntityFieldInfo Parent { get; set; } public IDictionary Data { get; set; } = new Dictionary(); - protected bool Equals(EntityFieldInfo other) => Field == other.Field; + protected bool Equals(EntityFieldInfo other) => Name == other.Name; public override bool Equals(object obj) { @@ -46,7 +52,36 @@ public override bool Equals(object obj) return Equals((EntityFieldInfo)obj); } - public override int GetHashCode() => (Field != null ? Field.GetHashCode() : 0); + public override int GetHashCode() => (Name != null ? Name.GetHashCode() : 0); + + public (string fieldPrefix, string fieldSuffix) GetFieldPrefixAndSuffix() + { + var fieldTree = new List(); + EntityFieldInfo current = Parent; + while (current != null) + { + fieldTree.Add(current); + current = current.Parent; + } + + fieldTree.Reverse(); + + var prefix = new StringBuilder(); + var suffix = new StringBuilder(); + foreach (var field in fieldTree) { + if (field.IsCollection) + { + prefix.Append($"{field.Name}.Any("); + suffix.Append(")"); + } + else + { + prefix.Append(field.Name).Append("."); + } + }; + + return (prefix.ToString(), suffix.ToString()); + } } public class SearchTerm diff --git a/tests/Foundatio.Parsers.SqlQueries.Tests/SqlQueryParserTests.cs b/tests/Foundatio.Parsers.SqlQueries.Tests/SqlQueryParserTests.cs index 23586e5..7591f6f 100644 --- a/tests/Foundatio.Parsers.SqlQueries.Tests/SqlQueryParserTests.cs +++ b/tests/Foundatio.Parsers.SqlQueries.Tests/SqlQueryParserTests.cs @@ -93,7 +93,7 @@ public async Task CanSearchWithTokenizer() if (String.IsNullOrWhiteSpace(s.Term)) return; - if (s.FieldInfo.Field != "NationalPhoneNumber") + if (s.FieldInfo.FullName != "NationalPhoneNumber") return; s.Tokens = [phoneNumberUtil.Parse(s.Term, "US").NationalNumber.ToString()]; @@ -213,6 +213,25 @@ public async Task CanUseCollectionDefaultFields() Assert.Equal(sqlExpected, sqlActual); } + [Fact] + public async Task CanUseCollectionDefaultFieldsWithNestedDepth() + { + var sp = GetServiceProvider(); + await using var db = await GetSampleContextWithDataAsync(sp); + var parser = sp.GetRequiredService(); + parser.Configuration.SetDefaultFields(["Companies.DataDefinitions.Key"]); + + var context = parser.GetContext(db.Employees.EntityType); + + string sqlExpected = db.Employees.Where(e => e.Companies.Any(c => c.DataDefinitions.Any(e => e.Key.StartsWith("age")))).ToQueryString(); + string sqlActual = db.Employees.Where("""Companies.Any(DataDefinitions.Any(Key.StartsWith("age")))""").ToQueryString(); + Assert.Equal(sqlExpected, sqlActual); + string sql = await parser.ToDynamicLinqAsync("age", context); + _logger.LogInformation(sql); + sqlActual = db.Employees.Where(sql).ToQueryString(); + Assert.Equal(sqlExpected, sqlActual); + } + [Fact] public async Task CanUseNavigationFields() { @@ -269,7 +288,7 @@ public async Task CanGenerateSql() var parser = sp.GetRequiredService(); var context = parser.GetContext(db.Employees.EntityType); - context.Fields.Add(new EntityFieldInfo { Field = "age", IsNumber = true, Data = { { "DataDefinitionId", 1 } } }); + context.Fields.Add(new EntityFieldInfo { Name = "age", FullName = "age", IsNumber = true, Data = { { "DataDefinitionId", 1 } } }); context.ValidationOptions.AllowedFields.Add("age"); string sqlExpected = db.Employees.Where(e => e.Companies.Any(c => c.Name == "acme") && e.DataValues.Any(dv => dv.DataDefinitionId == 1 && dv.NumberValue == 30)).ToQueryString(); @@ -383,7 +402,7 @@ private async Task ParseAndValidateQuery(string query, string expected, bool isV { Fields = [ - new EntityFieldInfo { Field = "field", IsNumber = true } + new EntityFieldInfo { Name = "field", FullName = "field", IsNumber = true } ] }; string generatedQuery = await GenerateSqlVisitor.RunAsync(result, context);