Skip to content

Commit

Permalink
[backend] Use join map to avoid duplicate join
Browse files Browse the repository at this point in the history
  • Loading branch information
RomuDeuxfois authored and savacano28 committed Oct 11, 2024
1 parent 9dfb5ce commit a4f97d7
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 57 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.openbas.rest.scenario;

import io.openbas.aop.LogExecutionTime;
import io.openbas.database.model.*;
import io.openbas.database.raw.RawPaginationScenario;
import io.openbas.database.repository.*;
Expand Down Expand Up @@ -82,6 +83,7 @@ public List<ScenarioSimple> scenarios() {
return this.scenarioService.scenarios();
}

@LogExecutionTime
@PostMapping(SCENARIO_URI + "/search")
public Page<RawPaginationScenario> scenarios(@RequestBody @Valid final SearchPaginationInput searchPaginationInput) {
return this.scenarioService.scenarios(searchPaginationInput);
Expand Down Expand Up @@ -297,3 +299,19 @@ public Exercise createRunningExerciseFromScenario(@PathVariable @NotBlank final
}

}

// -- After --

// Without filters
// Time: back-end 0.06 front-end 0.08

// With filters
// Time: back-end 0.47 front-end 0.5

// -- Before --

// Without filters
// Time: back-end 0.02 front-end 0.04

// With filters
// Time: back-end 0.02 front-end 0.03
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import org.springframework.data.jpa.domain.Specification;

import java.util.Optional;
import java.util.function.Function;
import java.util.function.UnaryOperator;

import static io.openbas.utils.CustomFilterUtils.computeMode;
Expand All @@ -23,14 +22,9 @@ private ScenarioUtils() {
private static final String SCENARIO_RECURRENCE_FILTER = "scenario_recurrence";

/**
* Manage filters that are not directly managed by the generic mechanics -> scenario_kill_chain_phases
* Manage filters that are not directly managed by the generic mechanics
*/
public static Function<Specification<Scenario>, Specification<Scenario>> handleDeepFilter(
@NotNull final SearchPaginationInput searchPaginationInput) {
return handleCustomFilter(searchPaginationInput);
}

private static UnaryOperator<Specification<Scenario>> handleCustomFilter(
public static UnaryOperator<Specification<Scenario>> handleCustomFilter(
@NotNull final SearchPaginationInput searchPaginationInput) {
// Existence of the filter
Optional<Filters.Filter> scenarioRecurrenceFilterOpt = ofNullable(searchPaginationInput.getFilterGroup())
Expand Down
73 changes: 43 additions & 30 deletions openbas-api/src/main/java/io/openbas/service/ScenarioService.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
import java.io.InputStream;
import java.time.Instant;
import java.util.*;
import java.util.function.Function;
import java.util.function.BiFunction;
import java.util.function.UnaryOperator;
import java.util.logging.Level;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
Expand All @@ -56,7 +57,7 @@
import static io.openbas.database.criteria.GenericCriteria.countQuery;
import static io.openbas.database.specification.ScenarioSpecification.findGrantedFor;
import static io.openbas.helper.StreamHelper.fromIterable;
import static io.openbas.rest.scenario.utils.ScenarioUtils.handleDeepFilter;
import static io.openbas.rest.scenario.utils.ScenarioUtils.handleCustomFilter;
import static io.openbas.service.ImportService.EXPORT_ENTRY_ATTACHMENT;
import static io.openbas.service.ImportService.EXPORT_ENTRY_SCENARIO;
import static io.openbas.utils.Constants.ARTICLES;
Expand Down Expand Up @@ -122,39 +123,54 @@ public List<ScenarioSimple> scenarios() {
}

public Page<RawPaginationScenario> scenarios(@NotNull final SearchPaginationInput searchPaginationInput) {
Function<Specification<Scenario>, Specification<Scenario>> finalSpecification = handleDeepFilter(searchPaginationInput);
Map<String, Join<Base, Base>> joinMap = new HashMap<>();

// Compute custom filter
UnaryOperator<Specification<Scenario>> deepFilterSpecification = handleCustomFilter(
searchPaginationInput
);

// Compute find all method
BiFunction<Specification<Scenario>, Pageable, Page<RawPaginationScenario>> findAll = getFindAllFunction(
deepFilterSpecification, joinMap
);

// Compute pagination from find all
return buildPaginationCriteriaBuilder(findAll, searchPaginationInput, Scenario.class, joinMap);
}

private BiFunction<Specification<Scenario>, Pageable, Page<RawPaginationScenario>> getFindAllFunction(
UnaryOperator<Specification<Scenario>> deepFilterSpecification,
Map<String, Join<Base, Base>> joinMap) {

if (currentUser().isAdmin()) {
return buildPaginationCriteriaBuilder(
(Specification<Scenario> specification, Pageable pageable) -> this.findAllWithCriteriaBuilder(
finalSpecification.apply(specification),
pageable
),
searchPaginationInput,
Scenario.class
return (specification, pageable) -> this.findAllWithCriteriaBuilder(
deepFilterSpecification.apply(specification),
pageable,
joinMap
);
} else {
return buildPaginationCriteriaBuilder(
(Specification<Scenario> specification, Pageable pageable) -> this.findAllWithCriteriaBuilder(
findGrantedFor(currentUser().getId()).and(finalSpecification.apply(specification)),
pageable
),
searchPaginationInput,
Scenario.class
return (specification, pageable) -> this.findAllWithCriteriaBuilder(
findGrantedFor(currentUser().getId()).and(deepFilterSpecification.apply(specification)),
pageable,
joinMap
);
}
}

private Page<RawPaginationScenario> findAllWithCriteriaBuilder(
Specification<Scenario> specification,
Pageable pageable) {
Pageable pageable,
Map<String, Join<Base, Base>> joinMap) {
CriteriaBuilder cb = entityManager.getCriteriaBuilder();

// -- Create Query --
CriteriaQuery<Tuple> cq = cb.createTupleQuery();
// FROM
Root<Scenario> scenarioRoot = cq.from(Scenario.class);
// Join on TAG
Join<Scenario, Tag> scenarioTagsJoin = scenarioRoot.join("tags", JoinType.LEFT);
Join<Base, Base> scenarioTagsJoin = scenarioRoot.join("tags", JoinType.LEFT);
joinMap.put("tags", scenarioTagsJoin);
Expression<String[]> tagIdsExpression =
cb.function(
"array_remove",
Expand All @@ -163,8 +179,10 @@ private Page<RawPaginationScenario> findAllWithCriteriaBuilder(
cb.nullLiteral(String.class)
);
// Join on INJECT and INJECTOR CONTRACT
Join<Scenario, Inject> injectsJoin = scenarioRoot.join("injects", JoinType.LEFT);
Join<Inject, InjectorContract> injectorsContractsJoin = injectsJoin.join("injectorContract", JoinType.LEFT);
Join<Base, Base> injectsJoin = scenarioRoot.join("injects", JoinType.LEFT);
joinMap.put("injects", injectsJoin);
Join<Base, Base> injectorsContractsJoin = injectsJoin.join("injectorContract", JoinType.LEFT);
joinMap.put("injects.injectorContract", injectorsContractsJoin);
Expression<String[]> platformExpression =
cb.function(
"array_union_agg",
Expand Down Expand Up @@ -226,9 +244,7 @@ private Page<RawPaginationScenario> findAllWithCriteriaBuilder(
}

/**
* Scenario is recurring
* AND start date is before now
* AND end date is after now
* Scenario is recurring AND start date is before now AND end date is after now
*/
public List<Scenario> recurringScenarios(@NotNull final Instant instant) {
return this.scenarioRepository.findAll(
Expand All @@ -239,10 +255,7 @@ public List<Scenario> recurringScenarios(@NotNull final Instant instant) {
}

/**
* Scenario is recurring
* AND
* start date is before now
* OR stop date is before now
* Scenario is recurring AND start date is before now OR stop date is before now
*/
public List<Scenario> potentialOutdatedRecurringScenario(@NotNull final Instant instant) {
return this.scenarioRepository.findAll(
Expand Down Expand Up @@ -611,7 +624,7 @@ private void getListOfVariables(Scenario scenario, Scenario scenarioOrigin) {
variableService.createVariables(variableList);
}

private void getLessonsCategories(Scenario duplicatedScenario, Scenario originalScenario){
private void getLessonsCategories(Scenario duplicatedScenario, Scenario originalScenario) {
List<LessonsCategory> duplicatedCategories = new ArrayList<>();
for (LessonsCategory originalCategory : originalScenario.getLessonsCategories()) {
LessonsCategory duplicatedCategory = new LessonsCategory();
Expand Down Expand Up @@ -649,7 +662,7 @@ private void getLessonsCategories(Scenario duplicatedScenario, Scenario original
duplicatedScenario.setLessonsCategories(duplicatedCategories);
}

private void getObjectives(Scenario scenario, Scenario scenarioOrigin){
private void getObjectives(Scenario scenario, Scenario scenarioOrigin) {
List<Objective> duplicatedObjectives = new ArrayList<>();
for (Objective originalObjective : scenarioOrigin.getObjectives()) {
Objective duplicatedObjective = new Objective();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package io.openbas.utils.pagination;

import io.openbas.database.model.Base;
import jakarta.persistence.criteria.Join;
import jakarta.validation.constraints.NotNull;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageRequest;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.domain.Specification;

import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;

import static io.openbas.utils.FilterUtilsJpa.computeFilterGroupJpa;
Expand Down Expand Up @@ -39,9 +43,10 @@ public static <T> Page<T> buildPaginationJPA(
public static <T, U> Page<U> buildPaginationCriteriaBuilder(
@NotNull final BiFunction<Specification<T>, Pageable, Page<U>> findAll,
@NotNull final SearchPaginationInput input,
@NotNull final Class<T> clazz) {
@NotNull final Class<T> clazz,
Map<String, Join<Base, Base>> joinMap) {
// Specification
Specification<T> filterSpecifications = computeFilterGroupJpa(input.getFilterGroup());
Specification<T> filterSpecifications = computeFilterGroupJpa(input.getFilterGroup(), joinMap);
Specification<T> searchSpecifications = computeSearchJpa(input.getTextSearch());

// Pageable
Expand All @@ -50,6 +55,18 @@ public static <T, U> Page<U> buildPaginationCriteriaBuilder(
return findAll.apply(filterSpecifications.and(searchSpecifications), pageable);
}

public static <T, U> Page<U> buildPaginationCriteriaBuilder(
@NotNull final BiFunction<Specification<T>, Pageable, Page<U>> findAll,
@NotNull final SearchPaginationInput input,
@NotNull final Class<T> clazz) {
return buildPaginationCriteriaBuilder(
findAll,
input,
clazz,
new HashMap<>()
);
}

/**
* Build PaginationJPA with a specified search specifications that replace the default ones
* @param findAll the find all method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.springframework.data.jpa.domain.Specification;

import javax.annotation.Nullable;
import java.util.HashMap;
import java.util.List;

import static io.openbas.utils.JpaUtils.toPath;
Expand All @@ -35,7 +36,7 @@ public static <T> Specification<T> computeSearchJpa(@Nullable final String searc
List<PropertySchema> searchableProperties = getSearchableProperties(propertySchemas);
List<Predicate> predicates = searchableProperties.stream()
.map(propertySchema -> {
Expression<String> paths = toPath(propertySchema, root);
Expression<String> paths = toPath(propertySchema, root, new HashMap<>());
return toPredicate(paths, search, cb, propertySchema.getType());
})
.toList();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.openbas.utils;

import io.openbas.database.model.Base;
import io.openbas.database.model.Filters.Filter;
import io.openbas.database.model.Filters.FilterGroup;
import io.openbas.database.model.Filters.FilterMode;
Expand All @@ -8,15 +9,14 @@
import io.openbas.utils.schema.SchemaUtils;
import jakarta.persistence.criteria.CriteriaBuilder;
import jakarta.persistence.criteria.Expression;
import jakarta.persistence.criteria.Join;
import jakarta.persistence.criteria.Predicate;
import jakarta.validation.constraints.NotNull;
import org.jetbrains.annotations.Nullable;
import org.springframework.data.jpa.domain.Specification;

import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Function;

Expand All @@ -39,8 +39,17 @@ public record Option(String id, String label) {

private static final Specification<?> EMPTY_SPECIFICATION = (root, query, cb) -> cb.conjunction();


@SuppressWarnings("unchecked")
public static <T> Specification<T> computeFilterGroupJpa(
@Nullable final FilterGroup filterGroup) {
return computeFilterGroupJpa(filterGroup, new HashMap<>());
}

@SuppressWarnings("unchecked")
public static <T> Specification<T> computeFilterGroupJpa(@Nullable final FilterGroup filterGroup) {
public static <T> Specification<T> computeFilterGroupJpa(
@Nullable final FilterGroup filterGroup,
Map<String, Join<Base, Base>> joinMap) {
if (filterGroup == null) {
return (Specification<T>) EMPTY_SPECIFICATION;
}
Expand All @@ -50,7 +59,7 @@ public static <T> Specification<T> computeFilterGroupJpa(@Nullable final FilterG
if (!filters.isEmpty()) {
List<Specification<T>> list = filters
.stream()
.map((Function<? super Filter, Specification<T>>) FilterUtilsJpa::computeFilter)
.map((Function<? super Filter, Specification<T>>) f -> FilterUtilsJpa.computeFilter(f, joinMap))
.toList();
Specification<T> result = null;
for (Specification<T> el : list) {
Expand All @@ -71,7 +80,9 @@ public static <T> Specification<T> computeFilterGroupJpa(@Nullable final FilterG
}

@SuppressWarnings("unchecked")
private static <T, U> Specification<T> computeFilter(@Nullable final Filter filter) {
private static <T, U> Specification<T> computeFilter(
@Nullable final Filter filter,
Map<String, Join<Base, Base>> joinMap) {
if (filter == null) {
return (Specification<T>) EMPTY_SPECIFICATION;
}
Expand All @@ -81,7 +92,7 @@ private static <T, U> Specification<T> computeFilter(@Nullable final Filter filt
List<PropertySchema> propertySchemas = SchemaUtils.schema(root.getJavaType());
List<PropertySchema> filterableProperties = getFilterableProperties(propertySchemas);
PropertySchema filterableProperty = retrieveProperty(filterableProperties, filterKey);
Expression<U> paths = toPath(filterableProperty, root);
Expression<U> paths = toPath(filterableProperty, root, joinMap);
// In case of join table, we will use ID so type is String
return toPredicate(
paths, filter, cb, filterableProperty.getJoinTable() != null ? String.class : filterableProperty.getType()
Expand Down
Loading

0 comments on commit a4f97d7

Please sign in to comment.