diff --git a/docker-compose.yml b/docker-compose.yml index fbb92280..c9db4c86 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -35,8 +35,11 @@ services: - "1433:1433" # login with sa:P@ssword1 environment: - "ACCEPT_EULA=Y" - - "SA_PASSWORD=P@ssword1" + - "MSSQL_SA_PASSWORD=P@ssword1" - "MSSQL_PID=Developer" + user: root + networks: + - foundatio healthcheck: test: [ @@ -49,8 +52,6 @@ services: ] interval: 1s retries: 20 - networks: - - foundatio ready: image: andrewlock/wait-for-dependencies diff --git a/src/Foundatio.Parsers.SqlQueries/Extensions/EnumerableExtensions.cs b/src/Foundatio.Parsers.SqlQueries/Extensions/EnumerableExtensions.cs new file mode 100644 index 00000000..2b61ebeb --- /dev/null +++ b/src/Foundatio.Parsers.SqlQueries/Extensions/EnumerableExtensions.cs @@ -0,0 +1,38 @@ +using System.Collections.Generic; + +namespace Foundatio.Parsers.SqlQueries.Extensions; + +internal static class EnumerableExtensions { + public delegate void ElementAction(T element, ElementInfo info); + + public static void ForEach(this IEnumerable elements, ElementAction action) + { + using IEnumerator enumerator = elements.GetEnumerator(); + bool isFirst = true; + bool hasNext = enumerator.MoveNext(); + int index = 0; + + while (hasNext) + { + T current = enumerator.Current; + hasNext = enumerator.MoveNext(); + action(current, new ElementInfo(index, isFirst, !hasNext)); + isFirst = false; + index++; + } + } + + public struct ElementInfo { + public ElementInfo(int index, bool isFirst, bool isLast) + : this() { + Index = index; + IsFirst = isFirst; + IsLast = isLast; + } + + public int Index { get; private set; } + public bool IsFirst { get; private set; } + public bool IsLast { get; private set; } + } +} + diff --git a/src/Foundatio.Parsers.SqlQueries/Extensions/SqlNodeExtensions.cs b/src/Foundatio.Parsers.SqlQueries/Extensions/SqlNodeExtensions.cs index c1c89c9f..cf9e9efb 100644 --- a/src/Foundatio.Parsers.SqlQueries/Extensions/SqlNodeExtensions.cs +++ b/src/Foundatio.Parsers.SqlQueries/Extensions/SqlNodeExtensions.cs @@ -104,20 +104,15 @@ public static string ToDynamicLinqString(this MissingNode node, ISqlQueryVisitor return builder.ToString(); } - public static EntityFieldInfo GetFieldInfo(List fields, string field) - { - if (fields == null) - return new EntityFieldInfo { Field = field }; - - return fields.FirstOrDefault(f => f.Field.Equals(field, StringComparison.OrdinalIgnoreCase)) ?? - new EntityFieldInfo { Field = field }; - } - public static string ToDynamicLinqString(this TermNode node, ISqlQueryVisitorContext context) { if (!String.IsNullOrEmpty(node.Prefix)) context.AddValidationError("Prefix is not supported for term range queries."); + // support overriding the generated query + if (node.TryGetQuery(out string query)) + return query; + var builder = new StringBuilder(); if (String.IsNullOrEmpty(node.Field)) @@ -128,39 +123,109 @@ public static string ToDynamicLinqString(this TermNode node, ISqlQueryVisitorCon return String.Empty; } - for (int index = 0; index < context.DefaultFields.Length; index++) + var fieldTerms = new Dictionary(); + foreach (string df in context.DefaultFields) + { + var fieldInfo = GetFieldInfo(context.Fields, df); + if (!fieldTerms.TryGetValue(fieldInfo, out var searchTerm)) + { + searchTerm = new SearchTerm + { + FieldInfo = fieldInfo, + Term = node.Term, + Operator = SqlSearchOperator.StartsWith + }; + fieldTerms[fieldInfo] = searchTerm; + } + + context.SearchTokenizer.Invoke(searchTerm); + } + + fieldTerms.ForEach((kvp, x) => { - builder.Append(index == 0 ? "(" : " OR "); + builder.Append(x.IsFirst ? "(" : " OR "); + var searchTerm = kvp.Value; + var tokens = kvp.Value.Tokens ?? [kvp.Value.Term]; - var defaultField = GetFieldInfo(context.Fields, context.DefaultFields[index]); - if (defaultField.IsCollection) + if (searchTerm.FieldInfo.IsCollection) { - int dotIndex = defaultField.Field.LastIndexOf('.'); - string collectionField = defaultField.Field.Substring(0, dotIndex); - string fieldName = defaultField.Field.Substring(dotIndex + 1); - - builder.Append(collectionField); - builder.Append(".Any("); - builder.Append(fieldName); - builder.Append(".Contains(\"").Append(node.Term).Append("\")"); - builder.Append(")"); + int dotIndex = searchTerm.FieldInfo.Field.LastIndexOf('.'); + string collectionField = searchTerm.FieldInfo.Field.Substring(0, dotIndex); + string fieldName = searchTerm.FieldInfo.Field.Substring(dotIndex + 1); + + if (searchTerm.Operator == SqlSearchOperator.Equals) + { + builder.Append(collectionField); + builder.Append(".Any("); + builder.Append(fieldName); + builder.Append(" in ("); + builder.Append(String.Join(',', tokens.Select(t => "\"" + t + "\""))); + builder.Append("))"); + } + else if (searchTerm.Operator == SqlSearchOperator.Contains) + { + tokens.ForEach((token, i) => { + builder.Append(i.IsFirst ? "(" : " OR "); + builder.Append(collectionField); + builder.Append(".Any("); + builder.Append(fieldName); + builder.Append(".Contains(\""); + builder.Append(token); + builder.Append("\"))"); + if (i.IsLast) + builder.Append(")"); + }); + } + else if (searchTerm.Operator == SqlSearchOperator.StartsWith) + { + tokens.ForEach((token, i) => { + builder.Append(i.IsFirst ? "(" : " OR "); + builder.Append(collectionField); + builder.Append(".Any("); + builder.Append(fieldName); + builder.Append(".StartsWith(\""); + builder.Append(token); + builder.Append("\"))"); + if (i.IsLast) + builder.Append(")"); + }); + } } else { - builder.Append(defaultField.Field).Append(".Contains(\"").Append(node.Term).Append("\")"); + if (searchTerm.Operator == SqlSearchOperator.Equals) + { + builder.Append(searchTerm.FieldInfo.Field).Append(" in ("); + builder.Append(String.Join(',', tokens.Select(t => "\"" + t + "\""))); + builder.Append(")"); + } + else if (searchTerm.Operator == SqlSearchOperator.Contains) + { + tokens.ForEach((token, i) => { + builder.Append(i.IsFirst ? "(" : " OR "); + builder.Append(searchTerm.FieldInfo.Field).Append(".Contains(\"").Append(token).Append("\")"); + if (i.IsLast) + builder.Append(")"); + }); + } + else if (searchTerm.Operator == SqlSearchOperator.StartsWith) + { + tokens.ForEach((token, i) => { + builder.Append(i.IsFirst ? "(" : " OR "); + builder.Append(searchTerm.FieldInfo.Field).Append(".StartsWith(\"").Append(token).Append("\")"); + if (i.IsLast) + builder.Append(")"); + }); + } } - if (index == context.DefaultFields.Length - 1) + if (x.IsLast) builder.Append(")"); - } + }); return builder.ToString(); } - // support overriding the generated query - if (node.TryGetQuery(out string query)) - return query; - var field = GetFieldInfo(context.Fields, node.Field); if (node.IsNegated.HasValue && node.IsNegated.Value) @@ -199,52 +264,6 @@ public static string ToDynamicLinqString(this TermNode node, ISqlQueryVisitorCon return builder.ToString(); } - private static void AppendField(StringBuilder builder, EntityFieldInfo field, string term) - { - if (field == null) - return; - - if (field.IsNumber || field.IsBoolean || field.IsMoney) - { - builder.Append(term); - } - else if (field is { IsDate: true }) - { - term = term.Trim(); - if (term.StartsWith("now", StringComparison.OrdinalIgnoreCase)) - { - builder.Append("DateTime.UtcNow"); - - if (term.Length == 3) - return; - - builder.Append("."); - - string method = term[^1..] switch - { - "y" => "AddYears", - "M" => "AddMonths", - "d" => "AddDays", - "h" => "AddHours", - "H" => "AddHours", - "m" => "AddMinutes", - "s" => "AddSeconds", - _ => throw new NotSupportedException("Invalid date operation.") - }; - - bool subtract = term.Substring(3, 1) == "-"; - - builder.Append(method).Append("(").Append(subtract ? "-" : "").Append(term.Substring(4, term.Length - 5)).Append(")"); - } - else - { - builder.Append("DateTime.Parse(\"" + term + "\")"); - } - } - else - builder.Append("\"" + term + "\""); - } - public static string ToDynamicLinqString(this TermRangeNode node, ISqlQueryVisitorContext context) { if (String.IsNullOrEmpty(node.Field)) @@ -306,6 +325,61 @@ public static string ToDynamicLinqString(this IQueryNode node, ISqlQueryVisitorC }; } + public static EntityFieldInfo GetFieldInfo(List fields, string field) + { + if (fields == null) + return new EntityFieldInfo { Field = field }; + + return fields.FirstOrDefault(f => f.Field.Equals(field, StringComparison.OrdinalIgnoreCase)) ?? + new EntityFieldInfo { Field = field }; + } + + private static void AppendField(StringBuilder builder, EntityFieldInfo field, string term) + { + if (field == null) + return; + + if (field.IsNumber || field.IsBoolean || field.IsMoney) + { + builder.Append(term); + } + else if (field is { IsDate: true }) + { + term = term.Trim(); + if (term.StartsWith("now", StringComparison.OrdinalIgnoreCase)) + { + builder.Append("DateTime.UtcNow"); + + if (term.Length == 3) + return; + + builder.Append("."); + + string method = term[^1..] switch + { + "y" => "AddYears", + "M" => "AddMonths", + "d" => "AddDays", + "h" => "AddHours", + "H" => "AddHours", + "m" => "AddMinutes", + "s" => "AddSeconds", + _ => throw new NotSupportedException("Invalid date operation.") + }; + + bool subtract = term.Substring(3, 1) == "-"; + + builder.Append(method).Append("(").Append(subtract ? "-" : "").Append(term.Substring(4, term.Length - 5)).Append(")"); + } + else + { + builder.Append("DateTime.Parse(\"" + term + "\")"); + } + } + else + builder.Append("\"" + term + "\""); + } + private const string QueryKey = "Query"; public static void SetQuery(this IQueryNode node, string query) { diff --git a/src/Foundatio.Parsers.SqlQueries/SqlQueryParser.cs b/src/Foundatio.Parsers.SqlQueries/SqlQueryParser.cs index 5a029f7b..65ed49d0 100644 --- a/src/Foundatio.Parsers.SqlQueries/SqlQueryParser.cs +++ b/src/Foundatio.Parsers.SqlQueries/SqlQueryParser.cs @@ -189,5 +189,10 @@ private void SetupQueryVisitorContextDefaults(IQueryVisitorContext context) if (Configuration.IncludeResolver != null && context.GetIncludeResolver() == null) context.SetIncludeResolver(Configuration.IncludeResolver); } + + if (context is ISqlQueryVisitorContext sqlContext) + { + sqlContext.SearchTokenizer = Configuration.SearchTokenizer; + } } } diff --git a/src/Foundatio.Parsers.SqlQueries/SqlQueryParserConfiguration.cs b/src/Foundatio.Parsers.SqlQueries/SqlQueryParserConfiguration.cs index 2474aebf..6a295d65 100644 --- a/src/Foundatio.Parsers.SqlQueries/SqlQueryParserConfiguration.cs +++ b/src/Foundatio.Parsers.SqlQueries/SqlQueryParserConfiguration.cs @@ -3,6 +3,7 @@ using System.Threading.Tasks; using Foundatio.Parsers.LuceneQueries; using Foundatio.Parsers.LuceneQueries.Visitors; +using Foundatio.Parsers.SqlQueries.Visitors; using Microsoft.EntityFrameworkCore.Metadata; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -25,6 +26,7 @@ public SqlQueryParserConfiguration() public int MaxFieldDepth { get; private set; } = 10; public QueryFieldResolver FieldResolver { get; private set; } + public Action SearchTokenizer { get; set; } = static _ => { }; public EntityTypePropertyFilter EntityTypePropertyFilter { get; private set; } = static _ => true; public EntityTypeNavigationFilter EntityTypeNavigationFilter { get; private set; } = static _ => true; public EntityTypeSkipNavigationFilter EntityTypeSkipNavigationFilter { get; private set; } = static _ => true; @@ -48,6 +50,12 @@ public SqlQueryParserConfiguration SetDefaultFields(string[] fields) return this; } + public SqlQueryParserConfiguration SetSearchTokenizer(Action tokenizer) + { + SearchTokenizer = tokenizer; + return this; + } + public SqlQueryParserConfiguration SetFieldDepth(int maxFieldDepth) { MaxFieldDepth = maxFieldDepth; diff --git a/src/Foundatio.Parsers.SqlQueries/Visitors/ISqlQueryVisitorContext.cs b/src/Foundatio.Parsers.SqlQueries/Visitors/ISqlQueryVisitorContext.cs index 2e68df18..c0d549a8 100644 --- a/src/Foundatio.Parsers.SqlQueries/Visitors/ISqlQueryVisitorContext.cs +++ b/src/Foundatio.Parsers.SqlQueries/Visitors/ISqlQueryVisitorContext.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using Foundatio.Parsers.LuceneQueries.Visitors; namespace Foundatio.Parsers.SqlQueries.Visitors; @@ -6,4 +7,5 @@ namespace Foundatio.Parsers.SqlQueries.Visitors; public interface ISqlQueryVisitorContext : IQueryVisitorContext { List Fields { get; set; } + Action SearchTokenizer { get; set; } } diff --git a/src/Foundatio.Parsers.SqlQueries/Visitors/SqlQueryVisitorContext.cs b/src/Foundatio.Parsers.SqlQueries/Visitors/SqlQueryVisitorContext.cs index f1ec3e9d..f2a9be84 100644 --- a/src/Foundatio.Parsers.SqlQueries/Visitors/SqlQueryVisitorContext.cs +++ b/src/Foundatio.Parsers.SqlQueries/Visitors/SqlQueryVisitorContext.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; using System.Diagnostics; using Foundatio.Parsers.LuceneQueries.Visitors; using Microsoft.EntityFrameworkCore.Metadata; @@ -8,17 +9,52 @@ namespace Foundatio.Parsers.SqlQueries.Visitors; public class SqlQueryVisitorContext : QueryVisitorContext, ISqlQueryVisitorContext { public List Fields { get; set; } + public Action SearchTokenizer { get; set; } = static _ => { }; public IEntityType EntityType { get; set; } } [DebuggerDisplay("{Field} IsNumber: {IsNumber} IsMoney: {IsMoney} IsDate: {IsDate} IsBoolean: {IsBoolean} IsCollection: {IsCollection}")] public class EntityFieldInfo { - public string Field { get; set; } + public string Field { get; init; } public bool IsNumber { get; set; } public bool IsMoney { get; set; } public bool IsDate { get; set; } public bool IsBoolean { get; set; } public bool IsCollection { get; set; } public IDictionary Data { get; set; } = new Dictionary(); + + protected bool Equals(EntityFieldInfo other) => Field == other.Field; + + public override bool Equals(object obj) + { + if (obj is null) + { + return false; + } + + if (ReferenceEquals(this, obj)) + { + return true; + } + + if (obj.GetType() != GetType()) + { + return false; + } + + return Equals((EntityFieldInfo)obj); + } + + public override int GetHashCode() => (Field != null ? Field.GetHashCode() : 0); } + +public class SearchTerm +{ + public EntityFieldInfo FieldInfo { get; set; } + public string Term { get; set; } + public List Tokens { get; set; } + public SqlSearchOperator Operator { get; set; } = SqlSearchOperator.Contains; +} + +public enum SqlSearchOperator { Equals, Contains, StartsWith } diff --git a/tests/Foundatio.Parsers.SqlQueries.Tests/Foundatio.Parsers.SqlQueries.Tests.csproj b/tests/Foundatio.Parsers.SqlQueries.Tests/Foundatio.Parsers.SqlQueries.Tests.csproj index 56be4a1e..9458c114 100644 --- a/tests/Foundatio.Parsers.SqlQueries.Tests/Foundatio.Parsers.SqlQueries.Tests.csproj +++ b/tests/Foundatio.Parsers.SqlQueries.Tests/Foundatio.Parsers.SqlQueries.Tests.csproj @@ -9,6 +9,7 @@ + diff --git a/tests/Foundatio.Parsers.SqlQueries.Tests/SampleContext.cs b/tests/Foundatio.Parsers.SqlQueries.Tests/SampleContext.cs index 38b372cc..2d106b03 100644 --- a/tests/Foundatio.Parsers.SqlQueries.Tests/SampleContext.cs +++ b/tests/Foundatio.Parsers.SqlQueries.Tests/SampleContext.cs @@ -41,10 +41,13 @@ public class Employee { public int Id { get; set; } public string FullName { get; set; } + public string PhoneNumber { get; set; } + public string NationalPhoneNumber { get; set; } public string Title { get; set; } public int Salary { get; set; } public List Companies { get; set; } public List DataValues { get; set; } + public DateTime Created { get; set; } = DateTime.Now; } diff --git a/tests/Foundatio.Parsers.SqlQueries.Tests/SqlQueryParserTests.cs b/tests/Foundatio.Parsers.SqlQueries.Tests/SqlQueryParserTests.cs index 03968063..23586e5a 100644 --- a/tests/Foundatio.Parsers.SqlQueries.Tests/SqlQueryParserTests.cs +++ b/tests/Foundatio.Parsers.SqlQueries.Tests/SqlQueryParserTests.cs @@ -12,6 +12,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Pegasus.Common.Tracing; +using PhoneNumbers; using Xunit; using Xunit.Abstractions; @@ -69,14 +70,63 @@ public async Task CanSearchDefaultFields() var context = parser.GetContext(db.Employees.EntityType); - string sqlExpected = db.Employees.Where(e => e.FullName.Contains("John") || e.Title.Contains("John")).ToQueryString(); - string sqlActual = db.Employees.Where("""FullName.Contains("John") || Title.Contains("John")""").ToQueryString(); + string sqlExpected = db.Employees.Where(e => e.FullName.StartsWith("John") || e.Title.StartsWith("John")).ToQueryString(); + string sqlActual = db.Employees.Where("""FullName.StartsWith("John") || Title.StartsWith("John") """).ToQueryString(); Assert.Equal(sqlExpected, sqlActual); string sql = await parser.ToDynamicLinqAsync("John", context); sqlActual = db.Employees.Where(sql).ToQueryString(); + var results = await db.Employees.Where(sql).ToListAsync(); + Assert.Single(results); Assert.Equal(sqlExpected, sqlActual); } + [Fact] + public async Task CanSearchWithTokenizer() + { + var sp = GetServiceProvider(); + await using var db = await GetSampleContextWithDataAsync(sp); + var phoneNumberUtil = PhoneNumberUtil.GetInstance(); + var parser = sp.GetRequiredService(); + parser.Configuration.SetDefaultFields(["NationalPhoneNumber"]); + parser.Configuration.SetSearchTokenizer(s => + { + if (String.IsNullOrWhiteSpace(s.Term)) + return; + + if (s.FieldInfo.Field != "NationalPhoneNumber") + return; + + s.Tokens = [phoneNumberUtil.Parse(s.Term, "US").NationalNumber.ToString()]; + s.Operator = SqlSearchOperator.StartsWith; + }); + + var context = parser.GetContext(db.Employees.EntityType); + + string sqlExpected = db.Employees.Where(e => e.NationalPhoneNumber.StartsWith("2142222222")).ToQueryString(); + string sqlActual = db.Employees.Where("NationalPhoneNumber.StartsWith(\"2142222222\")").ToQueryString(); + Assert.Equal(sqlExpected, sqlActual); + + string sql = await parser.ToDynamicLinqAsync("214-222-2222", context); + _logger.LogInformation(sql); + sqlActual = db.Employees.Where(sql).ToQueryString(); + var results = await db.Employees.Where(sql).ToListAsync(); + Assert.Single(results); + Assert.Equal(sqlExpected, sqlActual); + + sql = await parser.ToDynamicLinqAsync("2142222222", context); + _logger.LogInformation(sql); + sqlActual = db.Employees.Where(sql).ToQueryString(); + results = await db.Employees.Where(sql).ToListAsync(); + Assert.Single(results); + Assert.Equal(sqlExpected, sqlActual); + + sql = await parser.ToDynamicLinqAsync("21422", context); + _logger.LogInformation(sql); + sqlActual = db.Employees.Where(sql).ToQueryString(); + results = await db.Employees.Where(sql).ToListAsync(); + Assert.Single(results); + } + [Fact] public async Task CanUseDateFilter() { @@ -155,8 +205,8 @@ public async Task CanUseCollectionDefaultFields() var context = parser.GetContext(db.Employees.EntityType); - string sqlExpected = db.Employees.Where(e => e.Companies.Any(c => c.Name.Contains("acme"))).ToQueryString(); - string sqlActual = db.Employees.Where("""Companies.Any(Name.Contains("acme"))""").ToQueryString(); + string sqlExpected = db.Employees.Where(e => e.Companies.Any(c => c.Name.StartsWith("acme"))).ToQueryString(); + string sqlActual = db.Employees.Where("""Companies.Any(Name.StartsWith("acme"))""").ToQueryString(); Assert.Equal(sqlExpected, sqlActual); string sql = await parser.ToDynamicLinqAsync("acme", context); sqlActual = db.Employees.Where(sql).ToQueryString(); @@ -263,6 +313,8 @@ public async Task GetSampleContextWithDataAsync(IServiceProvider var db = sp.GetRequiredService(); var parser = sp.GetRequiredService(); + var phoneNumberUtil = PhoneNumberUtil.GetInstance(); + var dbParser = db.GetService(); Assert.Same(parser, dbParser); var dbSetParser = db.Employees.GetService(); @@ -281,6 +333,8 @@ public async Task GetSampleContextWithDataAsync(IServiceProvider { FullName = "John Doe", Title = "Software Developer", + PhoneNumber = "(214) 222-2222", + NationalPhoneNumber = phoneNumberUtil.Parse("(214) 222-2222", "US").NationalNumber.ToString(), Salary = 80_000, DataValues = [new() { Definition = company.DataDefinitions[0], NumberValue = 30 }], Companies = [company] @@ -289,6 +343,8 @@ public async Task GetSampleContextWithDataAsync(IServiceProvider { FullName = "Jane Doe", Title = "Software Developer", + PhoneNumber = "+52 55 1234 5678", // Mexico + NationalPhoneNumber = phoneNumberUtil.Parse("+52 55 1234 5678", "US").NationalNumber.ToString(), Salary = 90_000, DataValues = [new() { Definition = company.DataDefinitions[0], NumberValue = 23 }], Companies = [company]