diff --git a/openbas-api/src/main/java/io/openbas/rest/scenario/ScenarioApi.java b/openbas-api/src/main/java/io/openbas/rest/scenario/ScenarioApi.java index 553697a029..1a2d666ce0 100644 --- a/openbas-api/src/main/java/io/openbas/rest/scenario/ScenarioApi.java +++ b/openbas-api/src/main/java/io/openbas/rest/scenario/ScenarioApi.java @@ -1,5 +1,6 @@ package io.openbas.rest.scenario; +import io.openbas.aop.LogExecutionTime; import io.openbas.database.model.*; import io.openbas.database.raw.RawPaginationScenario; import io.openbas.database.repository.*; @@ -82,6 +83,7 @@ public List scenarios() { return this.scenarioService.scenarios(); } + @LogExecutionTime @PostMapping(SCENARIO_URI + "/search") public Page scenarios(@RequestBody @Valid final SearchPaginationInput searchPaginationInput) { return this.scenarioService.scenarios(searchPaginationInput); @@ -297,3 +299,19 @@ public Exercise createRunningExerciseFromScenario(@PathVariable @NotBlank final } } + +// -- After -- + +// Without filters +// Time: back-end 0.06 front-end 0.08 + +// With filters +// Time: back-end 0.47 front-end 0.5 + +// -- Before -- + +// Without filters +// Time: back-end 0.02 front-end 0.04 + +// With filters +// Time: back-end 0.02 front-end 0.03 diff --git a/openbas-api/src/main/java/io/openbas/rest/scenario/utils/ScenarioUtils.java b/openbas-api/src/main/java/io/openbas/rest/scenario/utils/ScenarioUtils.java index b4bfd00f62..e038cfef3f 100644 --- a/openbas-api/src/main/java/io/openbas/rest/scenario/utils/ScenarioUtils.java +++ b/openbas-api/src/main/java/io/openbas/rest/scenario/utils/ScenarioUtils.java @@ -8,7 +8,6 @@ import org.springframework.data.jpa.domain.Specification; import java.util.Optional; -import java.util.function.Function; import java.util.function.UnaryOperator; import static io.openbas.utils.CustomFilterUtils.computeMode; @@ -23,14 +22,9 @@ private ScenarioUtils() { private static final String SCENARIO_RECURRENCE_FILTER = "scenario_recurrence"; /** - * Manage filters that are not directly managed by the generic mechanics -> scenario_kill_chain_phases + * Manage filters that are not directly managed by the generic mechanics */ - public static Function, Specification> handleDeepFilter( - @NotNull final SearchPaginationInput searchPaginationInput) { - return handleCustomFilter(searchPaginationInput); - } - - private static UnaryOperator> handleCustomFilter( + public static UnaryOperator> handleCustomFilter( @NotNull final SearchPaginationInput searchPaginationInput) { // Existence of the filter Optional scenarioRecurrenceFilterOpt = ofNullable(searchPaginationInput.getFilterGroup()) diff --git a/openbas-api/src/main/java/io/openbas/service/ScenarioService.java b/openbas-api/src/main/java/io/openbas/service/ScenarioService.java index 21bbf2be17..6072304f31 100644 --- a/openbas-api/src/main/java/io/openbas/service/ScenarioService.java +++ b/openbas-api/src/main/java/io/openbas/service/ScenarioService.java @@ -46,7 +46,8 @@ import java.io.InputStream; import java.time.Instant; import java.util.*; -import java.util.function.Function; +import java.util.function.BiFunction; +import java.util.function.UnaryOperator; import java.util.logging.Level; import java.util.stream.Collectors; import java.util.zip.ZipEntry; @@ -56,7 +57,7 @@ import static io.openbas.database.criteria.GenericCriteria.countQuery; import static io.openbas.database.specification.ScenarioSpecification.findGrantedFor; import static io.openbas.helper.StreamHelper.fromIterable; -import static io.openbas.rest.scenario.utils.ScenarioUtils.handleDeepFilter; +import static io.openbas.rest.scenario.utils.ScenarioUtils.handleCustomFilter; import static io.openbas.service.ImportService.EXPORT_ENTRY_ATTACHMENT; import static io.openbas.service.ImportService.EXPORT_ENTRY_SCENARIO; import static io.openbas.utils.Constants.ARTICLES; @@ -122,31 +123,45 @@ public List scenarios() { } public Page scenarios(@NotNull final SearchPaginationInput searchPaginationInput) { - Function, Specification> finalSpecification = handleDeepFilter(searchPaginationInput); + Map> joinMap = new HashMap<>(); + + // Compute custom filter + UnaryOperator> deepFilterSpecification = handleCustomFilter( + searchPaginationInput + ); + + // Compute find all method + BiFunction, Pageable, Page> findAll = getFindAllFunction( + deepFilterSpecification, joinMap + ); + + // Compute pagination from find all + return buildPaginationCriteriaBuilder(findAll, searchPaginationInput, Scenario.class, joinMap); + } + + private BiFunction, Pageable, Page> getFindAllFunction( + UnaryOperator> deepFilterSpecification, + Map> joinMap) { + if (currentUser().isAdmin()) { - return buildPaginationCriteriaBuilder( - (Specification specification, Pageable pageable) -> this.findAllWithCriteriaBuilder( - finalSpecification.apply(specification), - pageable - ), - searchPaginationInput, - Scenario.class + return (specification, pageable) -> this.findAllWithCriteriaBuilder( + deepFilterSpecification.apply(specification), + pageable, + joinMap ); } else { - return buildPaginationCriteriaBuilder( - (Specification specification, Pageable pageable) -> this.findAllWithCriteriaBuilder( - findGrantedFor(currentUser().getId()).and(finalSpecification.apply(specification)), - pageable - ), - searchPaginationInput, - Scenario.class + return (specification, pageable) -> this.findAllWithCriteriaBuilder( + findGrantedFor(currentUser().getId()).and(deepFilterSpecification.apply(specification)), + pageable, + joinMap ); } } private Page findAllWithCriteriaBuilder( Specification specification, - Pageable pageable) { + Pageable pageable, + Map> joinMap) { CriteriaBuilder cb = entityManager.getCriteriaBuilder(); // -- Create Query -- @@ -154,7 +169,8 @@ private Page findAllWithCriteriaBuilder( // FROM Root scenarioRoot = cq.from(Scenario.class); // Join on TAG - Join scenarioTagsJoin = scenarioRoot.join("tags", JoinType.LEFT); + Join scenarioTagsJoin = scenarioRoot.join("tags", JoinType.LEFT); + joinMap.put("tags", scenarioTagsJoin); Expression tagIdsExpression = cb.function( "array_remove", @@ -163,8 +179,10 @@ private Page findAllWithCriteriaBuilder( cb.nullLiteral(String.class) ); // Join on INJECT and INJECTOR CONTRACT - Join injectsJoin = scenarioRoot.join("injects", JoinType.LEFT); - Join injectorsContractsJoin = injectsJoin.join("injectorContract", JoinType.LEFT); + Join injectsJoin = scenarioRoot.join("injects", JoinType.LEFT); + joinMap.put("injects", injectsJoin); + Join injectorsContractsJoin = injectsJoin.join("injectorContract", JoinType.LEFT); + joinMap.put("injects.injectorContract", injectorsContractsJoin); Expression platformExpression = cb.function( "array_union_agg", @@ -226,9 +244,7 @@ private Page findAllWithCriteriaBuilder( } /** - * Scenario is recurring - * AND start date is before now - * AND end date is after now + * Scenario is recurring AND start date is before now AND end date is after now */ public List recurringScenarios(@NotNull final Instant instant) { return this.scenarioRepository.findAll( @@ -239,10 +255,7 @@ public List recurringScenarios(@NotNull final Instant instant) { } /** - * Scenario is recurring - * AND - * start date is before now - * OR stop date is before now + * Scenario is recurring AND start date is before now OR stop date is before now */ public List potentialOutdatedRecurringScenario(@NotNull final Instant instant) { return this.scenarioRepository.findAll( @@ -611,7 +624,7 @@ private void getListOfVariables(Scenario scenario, Scenario scenarioOrigin) { variableService.createVariables(variableList); } - private void getLessonsCategories(Scenario duplicatedScenario, Scenario originalScenario){ + private void getLessonsCategories(Scenario duplicatedScenario, Scenario originalScenario) { List duplicatedCategories = new ArrayList<>(); for (LessonsCategory originalCategory : originalScenario.getLessonsCategories()) { LessonsCategory duplicatedCategory = new LessonsCategory(); @@ -649,7 +662,7 @@ private void getLessonsCategories(Scenario duplicatedScenario, Scenario original duplicatedScenario.setLessonsCategories(duplicatedCategories); } - private void getObjectives(Scenario scenario, Scenario scenarioOrigin){ + private void getObjectives(Scenario scenario, Scenario scenarioOrigin) { List duplicatedObjectives = new ArrayList<>(); for (Objective originalObjective : scenarioOrigin.getObjectives()) { Objective duplicatedObjective = new Objective(); diff --git a/openbas-api/src/main/java/io/openbas/utils/pagination/PaginationUtils.java b/openbas-api/src/main/java/io/openbas/utils/pagination/PaginationUtils.java index ca669fc4aa..b3666a6b47 100644 --- a/openbas-api/src/main/java/io/openbas/utils/pagination/PaginationUtils.java +++ b/openbas-api/src/main/java/io/openbas/utils/pagination/PaginationUtils.java @@ -1,11 +1,15 @@ package io.openbas.utils.pagination; +import io.openbas.database.model.Base; +import jakarta.persistence.criteria.Join; import jakarta.validation.constraints.NotNull; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import org.springframework.data.jpa.domain.Specification; +import java.util.HashMap; +import java.util.Map; import java.util.function.BiFunction; import static io.openbas.utils.FilterUtilsJpa.computeFilterGroupJpa; @@ -39,9 +43,10 @@ public static Page buildPaginationJPA( public static Page buildPaginationCriteriaBuilder( @NotNull final BiFunction, Pageable, Page> findAll, @NotNull final SearchPaginationInput input, - @NotNull final Class clazz) { + @NotNull final Class clazz, + Map> joinMap) { // Specification - Specification filterSpecifications = computeFilterGroupJpa(input.getFilterGroup()); + Specification filterSpecifications = computeFilterGroupJpa(input.getFilterGroup(), joinMap); Specification searchSpecifications = computeSearchJpa(input.getTextSearch()); // Pageable @@ -50,6 +55,18 @@ public static Page buildPaginationCriteriaBuilder( return findAll.apply(filterSpecifications.and(searchSpecifications), pageable); } + public static Page buildPaginationCriteriaBuilder( + @NotNull final BiFunction, Pageable, Page> findAll, + @NotNull final SearchPaginationInput input, + @NotNull final Class clazz) { + return buildPaginationCriteriaBuilder( + findAll, + input, + clazz, + new HashMap<>() + ); + } + /** * Build PaginationJPA with a specified search specifications that replace the default ones * @param findAll the find all method diff --git a/openbas-api/src/main/java/io/openbas/utils/pagination/SearchUtilsJpa.java b/openbas-api/src/main/java/io/openbas/utils/pagination/SearchUtilsJpa.java index b414f73ddf..4c6ea0be8c 100644 --- a/openbas-api/src/main/java/io/openbas/utils/pagination/SearchUtilsJpa.java +++ b/openbas-api/src/main/java/io/openbas/utils/pagination/SearchUtilsJpa.java @@ -10,6 +10,7 @@ import org.springframework.data.jpa.domain.Specification; import javax.annotation.Nullable; +import java.util.HashMap; import java.util.List; import static io.openbas.utils.JpaUtils.toPath; @@ -35,7 +36,7 @@ public static Specification computeSearchJpa(@Nullable final String searc List searchableProperties = getSearchableProperties(propertySchemas); List predicates = searchableProperties.stream() .map(propertySchema -> { - Expression paths = toPath(propertySchema, root); + Expression paths = toPath(propertySchema, root, new HashMap<>()); return toPredicate(paths, search, cb, propertySchema.getType()); }) .toList(); diff --git a/openbas-framework/src/main/java/io/openbas/utils/FilterUtilsJpa.java b/openbas-framework/src/main/java/io/openbas/utils/FilterUtilsJpa.java index 03a4075d98..4a72a282e1 100644 --- a/openbas-framework/src/main/java/io/openbas/utils/FilterUtilsJpa.java +++ b/openbas-framework/src/main/java/io/openbas/utils/FilterUtilsJpa.java @@ -1,5 +1,6 @@ package io.openbas.utils; +import io.openbas.database.model.Base; import io.openbas.database.model.Filters.Filter; import io.openbas.database.model.Filters.FilterGroup; import io.openbas.database.model.Filters.FilterMode; @@ -8,15 +9,14 @@ import io.openbas.utils.schema.SchemaUtils; import jakarta.persistence.criteria.CriteriaBuilder; import jakarta.persistence.criteria.Expression; +import jakarta.persistence.criteria.Join; import jakarta.persistence.criteria.Predicate; import jakarta.validation.constraints.NotNull; import org.jetbrains.annotations.Nullable; import org.springframework.data.jpa.domain.Specification; import java.time.Instant; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; +import java.util.*; import java.util.function.BiFunction; import java.util.function.Function; @@ -39,8 +39,17 @@ public record Option(String id, String label) { private static final Specification EMPTY_SPECIFICATION = (root, query, cb) -> cb.conjunction(); + + @SuppressWarnings("unchecked") + public static Specification computeFilterGroupJpa( + @Nullable final FilterGroup filterGroup) { + return computeFilterGroupJpa(filterGroup, new HashMap<>()); + } + @SuppressWarnings("unchecked") - public static Specification computeFilterGroupJpa(@Nullable final FilterGroup filterGroup) { + public static Specification computeFilterGroupJpa( + @Nullable final FilterGroup filterGroup, + Map> joinMap) { if (filterGroup == null) { return (Specification) EMPTY_SPECIFICATION; } @@ -50,7 +59,7 @@ public static Specification computeFilterGroupJpa(@Nullable final FilterG if (!filters.isEmpty()) { List> list = filters .stream() - .map((Function>) FilterUtilsJpa::computeFilter) + .map((Function>) f -> FilterUtilsJpa.computeFilter(f, joinMap)) .toList(); Specification result = null; for (Specification el : list) { @@ -71,7 +80,9 @@ public static Specification computeFilterGroupJpa(@Nullable final FilterG } @SuppressWarnings("unchecked") - private static Specification computeFilter(@Nullable final Filter filter) { + private static Specification computeFilter( + @Nullable final Filter filter, + Map> joinMap) { if (filter == null) { return (Specification) EMPTY_SPECIFICATION; } @@ -81,7 +92,7 @@ private static Specification computeFilter(@Nullable final Filter filt List propertySchemas = SchemaUtils.schema(root.getJavaType()); List filterableProperties = getFilterableProperties(propertySchemas); PropertySchema filterableProperty = retrieveProperty(filterableProperties, filterKey); - Expression paths = toPath(filterableProperty, root); + Expression paths = toPath(filterableProperty, root, joinMap); // In case of join table, we will use ID so type is String return toPredicate( paths, filter, cb, filterableProperty.getJoinTable() != null ? String.class : filterableProperty.getType() diff --git a/openbas-framework/src/main/java/io/openbas/utils/JpaUtils.java b/openbas-framework/src/main/java/io/openbas/utils/JpaUtils.java index a3d539bd5e..d29de14259 100644 --- a/openbas-framework/src/main/java/io/openbas/utils/JpaUtils.java +++ b/openbas-framework/src/main/java/io/openbas/utils/JpaUtils.java @@ -1,9 +1,12 @@ package io.openbas.utils; +import io.openbas.database.model.Base; import io.openbas.utils.schema.PropertySchema; import jakarta.persistence.criteria.*; import jakarta.validation.constraints.NotNull; +import java.util.Map; + import static org.springframework.util.StringUtils.hasText; public class JpaUtils { @@ -12,19 +15,71 @@ private JpaUtils() { } + private static Path computePath( + @NotNull final From from, + @NotNull final String key) { + String[] jsonPaths = key.split("\\."); + + // Deep path -> use join + if (jsonPaths.length > 1) { + From currentFrom = from; + for (int i = 0; i < jsonPaths.length - 1; i++) { + currentFrom = currentFrom.join(jsonPaths[i], JoinType.LEFT); + } + // Last path part -> use get + return currentFrom.get(jsonPaths[jsonPaths.length - 1]); + } + + // Simple path -> use get + else if (jsonPaths.length == 1) { + return from.get(jsonPaths[0]); + } + + return null; + } + public static Expression toPath( @NotNull final PropertySchema propertySchema, - @NotNull final Root root) { + @NotNull final Root root, + @NotNull final Map> joinMap) { // Path if (hasText(propertySchema.getPath())) { - String[] jsonPaths = propertySchema.getPath().split("\\."); - if (jsonPaths.length > 0) { - Join paths = root.join(jsonPaths[0], JoinType.LEFT); - for (int i = 1; i < jsonPaths.length - 1; i++) { - paths = paths.join(jsonPaths[i], JoinType.LEFT); + if (joinMap.isEmpty()) { + return computePath(root, propertySchema.getPath()); + } + + String existingPath = propertySchema.getPath(); + Join existingJoin = null; + String existingKey = null; + + // Compute existing join + while (hasText(existingPath)) { + if (joinMap.containsKey(existingPath)) { + existingJoin = joinMap.get(existingPath); + existingKey = existingPath; + break; + } + // Nothing found -> exit + int lastDotIndex = existingPath.lastIndexOf("."); + if (lastDotIndex == -1) { + break; } - return paths.get(jsonPaths[jsonPaths.length - 1]); + existingPath = existingPath.substring(0, existingPath.lastIndexOf(".")); } + + // If existing join in joinMap + if (existingJoin != null) { + // If equals to key -> return it + if (existingKey.equals(propertySchema.getPath())) { + return (Expression) existingJoin; + // If not, compute the remaining path and return it + } else { + String remainingPath = propertySchema.getPath().substring(existingKey.length() + 1); + return computePath(existingJoin, remainingPath); + } + } + + return computePath(root, propertySchema.getPath()); } // Join if (propertySchema.getJoinTable() != null) { @@ -50,11 +105,16 @@ public static Expression arrayAggOnId( // -- JOIN -- - public static Join createLeftJoin(Root root, String attributeName) { + public static Join createLeftJoin( + Root root, + String attributeName) { return root.join(attributeName, JoinType.LEFT); } - public static Expression createJoinArrayAggOnId(CriteriaBuilder cb, Root root, String attributeName) { + public static Expression createJoinArrayAggOnId( + CriteriaBuilder cb, + Root root, + String attributeName) { Join join = createLeftJoin(root, attributeName); return arrayAggOnId(cb, join); }